mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-17 19:40:44 +00:00
🐛 Fix: Bot Hook 没有捕获跳过异常 (#905)
This commit is contained in:
@ -56,7 +56,7 @@ class CustomEnvSettings(EnvSettingsSource):
|
||||
if env_path.is_file():
|
||||
env_file_vars = read_env_file(
|
||||
env_path,
|
||||
encoding=env_file_encoding,
|
||||
encoding=env_file_encoding, # type: ignore
|
||||
case_sensitive=settings.__config__.case_sensitive,
|
||||
)
|
||||
env_vars = {**env_file_vars, **env_vars}
|
||||
|
@ -38,7 +38,7 @@ try:
|
||||
from quart.datastructures import FileStorage
|
||||
from quart import Websocket as QuartWebSocket
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
raise ImportError(
|
||||
"Please install Quart by using `pip install nonebot2[quart]`"
|
||||
) from None
|
||||
|
||||
|
@ -7,11 +7,11 @@ NoneBotException
|
||||
├── ParserExit
|
||||
├── ProcessException
|
||||
| ├── IgnoredException
|
||||
| ├── SkippedException
|
||||
| | └── TypeMisMatch
|
||||
| ├── MockApiException
|
||||
| └── StopPropagation
|
||||
├── MatcherException
|
||||
| ├── SkippedException
|
||||
| | └── TypeMisMatch
|
||||
| ├── PausedException
|
||||
| ├── RejectedException
|
||||
| └── FinishedException
|
||||
@ -75,6 +75,37 @@ class IgnoredException(ProcessException):
|
||||
return self.__repr__()
|
||||
|
||||
|
||||
class SkippedException(ProcessException):
|
||||
"""指示 NoneBot 立即结束当前 `Dependent` 的运行。
|
||||
|
||||
例如,可以在 `Handler` 中通过 {ref}`nonebot.matcher.Matcher.skip` 抛出。
|
||||
|
||||
用法:
|
||||
```python
|
||||
def always_skip():
|
||||
Matcher.skip()
|
||||
|
||||
@matcher.handle()
|
||||
async def handler(dependency = Depends(always_skip)):
|
||||
# never run
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class TypeMisMatch(SkippedException):
|
||||
"""当前 `Handler` 的参数类型不匹配。"""
|
||||
|
||||
def __init__(self, param: ModelField, value: Any):
|
||||
self.param: ModelField = param
|
||||
self.value: Any = value
|
||||
|
||||
def __repr__(self):
|
||||
return f"<TypeMisMatch, param={self.param}, value={self.value}>"
|
||||
|
||||
def __str__(self):
|
||||
self.__repr__()
|
||||
|
||||
|
||||
class MockApiException(ProcessException):
|
||||
"""指示 NoneBot 阻止本次 API 调用或修改本次调用返回值,并返回自定义内容。可由 api hook 抛出。
|
||||
|
||||
@ -114,37 +145,6 @@ class MatcherException(NoneBotException):
|
||||
"""所有 Matcher 发生的异常基类。"""
|
||||
|
||||
|
||||
class SkippedException(MatcherException):
|
||||
"""指示 NoneBot 立即结束当前 `Handler` 的处理,继续处理下一个 `Handler`。
|
||||
|
||||
可以在 `Handler` 中通过 {ref}`nonebot.matcher.Matcher.skip` 抛出。
|
||||
|
||||
用法:
|
||||
```python
|
||||
def always_skip():
|
||||
Matcher.skip()
|
||||
|
||||
@matcher.handle()
|
||||
async def handler(dependency = Depends(always_skip)):
|
||||
...
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class TypeMisMatch(SkippedException):
|
||||
"""当前 `Handler` 的参数类型不匹配。"""
|
||||
|
||||
def __init__(self, param: ModelField, value: Any):
|
||||
self.param: ModelField = param
|
||||
self.value: Any = value
|
||||
|
||||
def __repr__(self):
|
||||
return f"<TypeMisMatch, param={self.param}, value={self.value}>"
|
||||
|
||||
def __str__(self):
|
||||
self.__repr__()
|
||||
|
||||
|
||||
class PausedException(MatcherException):
|
||||
"""指示 NoneBot 结束当前 `Handler` 并等待下一条消息后继续下一个 `Handler`。可用于用户输入新信息。
|
||||
|
||||
|
@ -4,9 +4,10 @@ from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.utils import escape_tag
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.exception import SkippedException
|
||||
from nonebot.utils import escape_tag, run_coro_with_catch
|
||||
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
|
||||
from nonebot.internal.params import BotParam, DependParam, DefaultParam
|
||||
|
||||
@ -128,7 +129,12 @@ class Driver(abc.ABC):
|
||||
self._clients[bot.self_id] = bot
|
||||
|
||||
async def _run_hook(bot: "Bot") -> None:
|
||||
coros = list(map(lambda x: x(bot=bot), self._bot_connection_hook))
|
||||
coros = list(
|
||||
map(
|
||||
lambda x: run_coro_with_catch(x(bot=bot), (SkippedException,)),
|
||||
self._bot_connection_hook,
|
||||
)
|
||||
)
|
||||
if coros:
|
||||
try:
|
||||
await asyncio.gather(*coros)
|
||||
@ -146,7 +152,12 @@ class Driver(abc.ABC):
|
||||
del self._clients[bot.self_id]
|
||||
|
||||
async def _run_hook(bot: "Bot") -> None:
|
||||
coros = list(map(lambda x: x(bot=bot), self._bot_disconnection_hook))
|
||||
coros = list(
|
||||
map(
|
||||
lambda x: run_coro_with_catch(x(bot=bot), (SkippedException,)),
|
||||
self._bot_disconnection_hook,
|
||||
)
|
||||
)
|
||||
if coros:
|
||||
try:
|
||||
await asyncio.gather(*coros)
|
||||
|
@ -248,8 +248,8 @@ class Cookies(MutableMapping):
|
||||
self,
|
||||
name: str,
|
||||
default: Optional[str] = None,
|
||||
domain: str = None,
|
||||
path: str = None,
|
||||
domain: Optional[str] = None,
|
||||
path: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
value: Optional[str] = None
|
||||
for cookie in self.jar:
|
||||
|
@ -3,6 +3,7 @@ from contextlib import AsyncExitStack
|
||||
from typing import Any, Set, Tuple, Union, NoReturn, Optional, Coroutine
|
||||
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.utils import run_coro_with_catch
|
||||
from nonebot.exception import SkippedException
|
||||
from nonebot.typing import T_DependencyCache, T_PermissionChecker
|
||||
|
||||
@ -10,13 +11,6 @@ from .adapter import Bot, Event
|
||||
from .params import BotParam, EventParam, DependParam, DefaultParam
|
||||
|
||||
|
||||
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]):
|
||||
try:
|
||||
return await coro
|
||||
except SkippedException:
|
||||
return False
|
||||
|
||||
|
||||
class Permission:
|
||||
"""{ref}`nonebot.matcher.Matcher` 权限类。
|
||||
|
||||
@ -72,13 +66,15 @@ class Permission:
|
||||
return True
|
||||
results = await asyncio.gather(
|
||||
*(
|
||||
_run_coro_with_catch(
|
||||
run_coro_with_catch(
|
||||
checker(
|
||||
bot=bot,
|
||||
event=event,
|
||||
stack=stack,
|
||||
dependency_cache=dependency_cache,
|
||||
)
|
||||
),
|
||||
(SkippedException,),
|
||||
False,
|
||||
)
|
||||
for checker in self.checkers
|
||||
),
|
||||
|
@ -14,9 +14,9 @@ from typing import TYPE_CHECKING, Any, Set, Dict, Type, Optional, Coroutine
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.rule import TrieRule
|
||||
from nonebot.utils import escape_tag
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.matcher import Matcher, matchers
|
||||
from nonebot.utils import escape_tag, run_coro_with_catch
|
||||
from nonebot.exception import (
|
||||
NoLogException,
|
||||
StopPropagation,
|
||||
@ -110,13 +110,6 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
|
||||
return func
|
||||
|
||||
|
||||
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]) -> Any:
|
||||
try:
|
||||
return await coro
|
||||
except SkippedException:
|
||||
pass
|
||||
|
||||
|
||||
async def _check_matcher(
|
||||
priority: int,
|
||||
Matcher: Type[Matcher],
|
||||
@ -167,7 +160,7 @@ async def _run_matcher(
|
||||
|
||||
coros = list(
|
||||
map(
|
||||
lambda x: _run_coro_with_catch(
|
||||
lambda x: run_coro_with_catch(
|
||||
x(
|
||||
matcher=matcher,
|
||||
bot=bot,
|
||||
@ -175,7 +168,8 @@ async def _run_matcher(
|
||||
state=state,
|
||||
stack=stack,
|
||||
dependency_cache=dependency_cache,
|
||||
)
|
||||
),
|
||||
(SkippedException,),
|
||||
),
|
||||
_run_preprocessors,
|
||||
)
|
||||
@ -208,7 +202,7 @@ async def _run_matcher(
|
||||
|
||||
coros = list(
|
||||
map(
|
||||
lambda x: _run_coro_with_catch(
|
||||
lambda x: run_coro_with_catch(
|
||||
x(
|
||||
matcher=matcher,
|
||||
exception=exception,
|
||||
@ -217,7 +211,8 @@ async def _run_matcher(
|
||||
state=state,
|
||||
stack=stack,
|
||||
dependency_cache=dependency_cache,
|
||||
)
|
||||
),
|
||||
(SkippedException,),
|
||||
),
|
||||
_run_postprocessors,
|
||||
)
|
||||
@ -263,14 +258,15 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
|
||||
async with AsyncExitStack() as stack:
|
||||
coros = list(
|
||||
map(
|
||||
lambda x: _run_coro_with_catch(
|
||||
lambda x: run_coro_with_catch(
|
||||
x(
|
||||
bot=bot,
|
||||
event=event,
|
||||
state=state,
|
||||
stack=stack,
|
||||
dependency_cache=dependency_cache,
|
||||
)
|
||||
),
|
||||
(SkippedException,),
|
||||
),
|
||||
_event_preprocessors,
|
||||
)
|
||||
@ -330,14 +326,15 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
|
||||
|
||||
coros = list(
|
||||
map(
|
||||
lambda x: _run_coro_with_catch(
|
||||
lambda x: run_coro_with_catch(
|
||||
x(
|
||||
bot=bot,
|
||||
event=event,
|
||||
state=state,
|
||||
stack=stack,
|
||||
dependency_cache=dependency_cache,
|
||||
)
|
||||
),
|
||||
(SkippedException,),
|
||||
),
|
||||
_event_postprocessors,
|
||||
)
|
||||
|
@ -45,9 +45,23 @@ T_State = Dict[Any, Any]
|
||||
"""事件处理状态 State 类型"""
|
||||
|
||||
T_BotConnectionHook = Callable[..., Awaitable[Any]]
|
||||
"""Bot 连接建立时钩子函数"""
|
||||
"""Bot 连接建立时钩子函数
|
||||
|
||||
依赖参数:
|
||||
|
||||
- DependParam: 子依赖参数
|
||||
- BotParam: Bot 对象
|
||||
- DefaultParam: 带有默认值的参数
|
||||
"""
|
||||
T_BotDisconnectionHook = Callable[..., Awaitable[Any]]
|
||||
"""Bot 连接断开时钩子函数"""
|
||||
"""Bot 连接断开时钩子函数
|
||||
|
||||
依赖参数:
|
||||
|
||||
- DependParam: 子依赖参数
|
||||
- BotParam: Bot 对象
|
||||
- DefaultParam: 带有默认值的参数
|
||||
"""
|
||||
T_CallingAPIHook = Callable[["Bot", str, Dict[str, Any]], Awaitable[Any]]
|
||||
"""`bot.call_api` 钩子函数"""
|
||||
T_CalledAPIHook = Callable[
|
||||
|
@ -21,7 +21,6 @@ from typing import (
|
||||
TypeVar,
|
||||
Callable,
|
||||
Optional,
|
||||
Awaitable,
|
||||
Coroutine,
|
||||
AsyncGenerator,
|
||||
ContextManager,
|
||||
@ -132,6 +131,17 @@ async def run_sync_ctx_manager(
|
||||
await run_sync(cm.__exit__)(None, None, None)
|
||||
|
||||
|
||||
async def run_coro_with_catch(
|
||||
coro: Coroutine[Any, Any, T],
|
||||
exc: Tuple[Type[Exception], ...],
|
||||
return_on_err: R = None,
|
||||
) -> Union[T, R]:
|
||||
try:
|
||||
return await coro
|
||||
except exc:
|
||||
return return_on_err
|
||||
|
||||
|
||||
def get_name(obj: Any) -> str:
|
||||
"""获取对象的名称"""
|
||||
if inspect.isfunction(obj) or inspect.isclass(obj):
|
||||
|
Reference in New Issue
Block a user