🐛 Fix: Bot Hook 没有捕获跳过异常 (#905)

This commit is contained in:
Ju4tCode
2022-04-04 10:35:14 +08:00
committed by GitHub
parent 494b9c625d
commit 2f3324ce0c
12 changed files with 333 additions and 303 deletions

View File

@ -56,7 +56,7 @@ class CustomEnvSettings(EnvSettingsSource):
if env_path.is_file():
env_file_vars = read_env_file(
env_path,
encoding=env_file_encoding,
encoding=env_file_encoding, # type: ignore
case_sensitive=settings.__config__.case_sensitive,
)
env_vars = {**env_file_vars, **env_vars}

View File

@ -38,7 +38,7 @@ try:
from quart.datastructures import FileStorage
from quart import Websocket as QuartWebSocket
except ImportError:
raise ValueError(
raise ImportError(
"Please install Quart by using `pip install nonebot2[quart]`"
) from None

View File

@ -7,11 +7,11 @@ NoneBotException
├── ParserExit
├── ProcessException
| ├── IgnoredException
| ├── SkippedException
| | └── TypeMisMatch
| ├── MockApiException
| └── StopPropagation
├── MatcherException
| ├── SkippedException
| | └── TypeMisMatch
| ├── PausedException
| ├── RejectedException
| └── FinishedException
@ -75,6 +75,37 @@ class IgnoredException(ProcessException):
return self.__repr__()
class SkippedException(ProcessException):
"""指示 NoneBot 立即结束当前 `Dependent` 的运行。
例如,可以在 `Handler` 中通过 {ref}`nonebot.matcher.Matcher.skip` 抛出。
用法:
```python
def always_skip():
Matcher.skip()
@matcher.handle()
async def handler(dependency = Depends(always_skip)):
# never run
```
"""
class TypeMisMatch(SkippedException):
"""当前 `Handler` 的参数类型不匹配。"""
def __init__(self, param: ModelField, value: Any):
self.param: ModelField = param
self.value: Any = value
def __repr__(self):
return f"<TypeMisMatch, param={self.param}, value={self.value}>"
def __str__(self):
self.__repr__()
class MockApiException(ProcessException):
"""指示 NoneBot 阻止本次 API 调用或修改本次调用返回值,并返回自定义内容。可由 api hook 抛出。
@ -114,37 +145,6 @@ class MatcherException(NoneBotException):
"""所有 Matcher 发生的异常基类。"""
class SkippedException(MatcherException):
"""指示 NoneBot 立即结束当前 `Handler` 的处理,继续处理下一个 `Handler`。
可以在 `Handler` 中通过 {ref}`nonebot.matcher.Matcher.skip` 抛出。
用法:
```python
def always_skip():
Matcher.skip()
@matcher.handle()
async def handler(dependency = Depends(always_skip)):
...
```
"""
class TypeMisMatch(SkippedException):
"""当前 `Handler` 的参数类型不匹配。"""
def __init__(self, param: ModelField, value: Any):
self.param: ModelField = param
self.value: Any = value
def __repr__(self):
return f"<TypeMisMatch, param={self.param}, value={self.value}>"
def __str__(self):
self.__repr__()
class PausedException(MatcherException):
"""指示 NoneBot 结束当前 `Handler` 并等待下一条消息后继续下一个 `Handler`。可用于用户输入新信息。

View File

@ -4,9 +4,10 @@ from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
from nonebot.log import logger
from nonebot.utils import escape_tag
from nonebot.config import Env, Config
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.utils import escape_tag, run_coro_with_catch
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
from nonebot.internal.params import BotParam, DependParam, DefaultParam
@ -128,7 +129,12 @@ class Driver(abc.ABC):
self._clients[bot.self_id] = bot
async def _run_hook(bot: "Bot") -> None:
coros = list(map(lambda x: x(bot=bot), self._bot_connection_hook))
coros = list(
map(
lambda x: run_coro_with_catch(x(bot=bot), (SkippedException,)),
self._bot_connection_hook,
)
)
if coros:
try:
await asyncio.gather(*coros)
@ -146,7 +152,12 @@ class Driver(abc.ABC):
del self._clients[bot.self_id]
async def _run_hook(bot: "Bot") -> None:
coros = list(map(lambda x: x(bot=bot), self._bot_disconnection_hook))
coros = list(
map(
lambda x: run_coro_with_catch(x(bot=bot), (SkippedException,)),
self._bot_disconnection_hook,
)
)
if coros:
try:
await asyncio.gather(*coros)

View File

@ -248,8 +248,8 @@ class Cookies(MutableMapping):
self,
name: str,
default: Optional[str] = None,
domain: str = None,
path: str = None,
domain: Optional[str] = None,
path: Optional[str] = None,
) -> Optional[str]:
value: Optional[str] = None
for cookie in self.jar:

View File

@ -3,6 +3,7 @@ from contextlib import AsyncExitStack
from typing import Any, Set, Tuple, Union, NoReturn, Optional, Coroutine
from nonebot.dependencies import Dependent
from nonebot.utils import run_coro_with_catch
from nonebot.exception import SkippedException
from nonebot.typing import T_DependencyCache, T_PermissionChecker
@ -10,13 +11,6 @@ from .adapter import Bot, Event
from .params import BotParam, EventParam, DependParam, DefaultParam
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]):
try:
return await coro
except SkippedException:
return False
class Permission:
"""{ref}`nonebot.matcher.Matcher` 权限类。
@ -72,13 +66,15 @@ class Permission:
return True
results = await asyncio.gather(
*(
_run_coro_with_catch(
run_coro_with_catch(
checker(
bot=bot,
event=event,
stack=stack,
dependency_cache=dependency_cache,
)
),
(SkippedException,),
False,
)
for checker in self.checkers
),

View File

@ -14,9 +14,9 @@ from typing import TYPE_CHECKING, Any, Set, Dict, Type, Optional, Coroutine
from nonebot.log import logger
from nonebot.rule import TrieRule
from nonebot.utils import escape_tag
from nonebot.dependencies import Dependent
from nonebot.matcher import Matcher, matchers
from nonebot.utils import escape_tag, run_coro_with_catch
from nonebot.exception import (
NoLogException,
StopPropagation,
@ -110,13 +110,6 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
return func
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]) -> Any:
try:
return await coro
except SkippedException:
pass
async def _check_matcher(
priority: int,
Matcher: Type[Matcher],
@ -167,7 +160,7 @@ async def _run_matcher(
coros = list(
map(
lambda x: _run_coro_with_catch(
lambda x: run_coro_with_catch(
x(
matcher=matcher,
bot=bot,
@ -175,7 +168,8 @@ async def _run_matcher(
state=state,
stack=stack,
dependency_cache=dependency_cache,
)
),
(SkippedException,),
),
_run_preprocessors,
)
@ -208,7 +202,7 @@ async def _run_matcher(
coros = list(
map(
lambda x: _run_coro_with_catch(
lambda x: run_coro_with_catch(
x(
matcher=matcher,
exception=exception,
@ -217,7 +211,8 @@ async def _run_matcher(
state=state,
stack=stack,
dependency_cache=dependency_cache,
)
),
(SkippedException,),
),
_run_postprocessors,
)
@ -263,14 +258,15 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
async with AsyncExitStack() as stack:
coros = list(
map(
lambda x: _run_coro_with_catch(
lambda x: run_coro_with_catch(
x(
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
)
),
(SkippedException,),
),
_event_preprocessors,
)
@ -330,14 +326,15 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
coros = list(
map(
lambda x: _run_coro_with_catch(
lambda x: run_coro_with_catch(
x(
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
)
),
(SkippedException,),
),
_event_postprocessors,
)

View File

@ -45,9 +45,23 @@ T_State = Dict[Any, Any]
"""事件处理状态 State 类型"""
T_BotConnectionHook = Callable[..., Awaitable[Any]]
"""Bot 连接建立时钩子函数"""
"""Bot 连接建立时钩子函数
依赖参数:
- DependParam: 子依赖参数
- BotParam: Bot 对象
- DefaultParam: 带有默认值的参数
"""
T_BotDisconnectionHook = Callable[..., Awaitable[Any]]
"""Bot 连接断开时钩子函数"""
"""Bot 连接断开时钩子函数
依赖参数:
- DependParam: 子依赖参数
- BotParam: Bot 对象
- DefaultParam: 带有默认值的参数
"""
T_CallingAPIHook = Callable[["Bot", str, Dict[str, Any]], Awaitable[Any]]
"""`bot.call_api` 钩子函数"""
T_CalledAPIHook = Callable[

View File

@ -21,7 +21,6 @@ from typing import (
TypeVar,
Callable,
Optional,
Awaitable,
Coroutine,
AsyncGenerator,
ContextManager,
@ -132,6 +131,17 @@ async def run_sync_ctx_manager(
await run_sync(cm.__exit__)(None, None, None)
async def run_coro_with_catch(
coro: Coroutine[Any, Any, T],
exc: Tuple[Type[Exception], ...],
return_on_err: R = None,
) -> Union[T, R]:
try:
return await coro
except exc:
return return_on_err
def get_name(obj: Any) -> str:
"""获取对象的名称"""
if inspect.isfunction(obj) or inspect.isclass(obj):