Mirror of https://github.com/nonebot/nonebot2.git, synced 2025-09-06 20:16:47 +00:00
✨ Feature: Migrate to the structured concurrency framework AnyIO (#3053)
@@ -8,17 +8,20 @@ FrontMatter:
"""

import abc
import asyncio
import inspect
from functools import partial
from dataclasses import field, dataclass
from collections.abc import Iterable, Awaitable
from typing import Any, Generic, TypeVar, Callable, Optional, cast

import anyio
from exceptiongroup import BaseExceptionGroup, catch

from nonebot.log import logger
from nonebot.typing import _DependentCallable
from nonebot.exception import SkippedException
from nonebot.utils import run_sync, is_coroutine_callable
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
from nonebot.utils import run_sync, is_coroutine_callable, flatten_exception_group

from .utils import check_field_type, get_typed_signature

@@ -84,7 +87,16 @@ class Dependent(Generic[R]):
)

async def __call__(self, **kwargs: Any) -> R:
try:
exception: Optional[BaseExceptionGroup[SkippedException]] = None

def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
nonlocal exception
exception = exc_group
# raise one of the exceptions instead
excs = list(flatten_exception_group(exc_group))
logger.trace(f"{self} skipped due to {excs}")

with catch({SkippedException: _handle_skipped}):
# do pre-check
await self.check(**kwargs)

@@ -96,9 +108,8 @@ class Dependent(Generic[R]):
return await cast(Callable[..., Awaitable[R]], self.call)(**values)
else:
return await run_sync(cast(Callable[..., R], self.call))(**values)
except SkippedException as e:
logger.trace(f"{self} skipped due to {e}")
raise

raise exception

@staticmethod
def parse_params(
@@ -166,10 +177,13 @@ class Dependent(Generic[R]):
return cls(call, params, parameterless_params)

async def check(self, **params: Any) -> None:
await asyncio.gather(*(param._check(**params) for param in self.parameterless))
await asyncio.gather(
*(cast(Param, param.field_info)._check(**params) for param in self.params)
)
async with anyio.create_task_group() as tg:
for param in self.parameterless:
tg.start_soon(partial(param._check, **params))

async with anyio.create_task_group() as tg:
for param in self.params:
tg.start_soon(partial(cast(Param, param.field_info)._check, **params))

async def _solve_field(self, field: ModelField, params: dict[str, Any]) -> Any:
param = cast(Param, field.field_info)
@@ -185,10 +199,17 @@ class Dependent(Generic[R]):
await param._solve(**params)

# solve param values
values = await asyncio.gather(
*(self._solve_field(field, params) for field in self.params)
)
return {field.name: value for field, value in zip(self.params, values)}
result: dict[str, Any] = {}

async def _solve_field(field: ModelField, params: dict[str, Any]) -> None:
value = await self._solve_field(field, params)
result[field.name] = value

async with anyio.create_task_group() as tg:
for field in self.params:
tg.start_soon(_solve_field, field, params)

return result


__autodoc__ = {"CustomConfig": False}
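The hunk above swaps `asyncio.gather` for an AnyIO task group and turns `except SkippedException` into an `exceptiongroup.catch` handler, because a task group re-raises failures as an `ExceptionGroup`. A minimal, self-contained sketch of that pattern (`SkippedException` and `check_all` here are illustrative stand-ins, not the NoneBot API):

```python
from typing import Optional

import anyio
from exceptiongroup import BaseExceptionGroup, catch


class SkippedException(Exception):  # stand-in for nonebot.exception.SkippedException
    pass


async def check_all(checks) -> None:
    skipped: Optional[BaseExceptionGroup] = None

    def _handle_skipped(exc_group: BaseExceptionGroup) -> None:
        # only record the group here; re-raise it outside the catch block
        nonlocal skipped
        skipped = exc_group

    with catch({SkippedException: _handle_skipped}):
        async with anyio.create_task_group() as tg:
            for check in checks:
                tg.start_soon(check)

    if skipped is not None:
        raise skipped


async def main() -> None:
    async def ok() -> None:
        await anyio.sleep(0)

    async def skip() -> None:
        raise SkippedException("not applicable")

    try:
        await check_all([ok, skip])
    except BaseExceptionGroup as e:
        print("skipped:", e.exceptions)


anyio.run(main)
```

Recording the group and re-raising it after the `with catch(...)` block mirrors how `Dependent.__call__` re-raises after its pre-check.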
@@ -12,14 +12,18 @@ FrontMatter:
"""

import signal
import asyncio
import threading
from typing import Optional
from typing_extensions import override

import anyio
from anyio.abc import TaskGroup
from exceptiongroup import BaseExceptionGroup, catch

from nonebot.log import logger
from nonebot.consts import WINDOWS
from nonebot.config import Env, Config
from nonebot.drivers import Driver as BaseDriver
from nonebot.utils import flatten_exception_group

HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
@@ -35,8 +39,8 @@ class Driver(BaseDriver):
def __init__(self, env: Env, config: Config):
super().__init__(env, config)

self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False
self.should_exit: anyio.Event = anyio.Event()
self.force_exit: anyio.Event = anyio.Event()

@property
@override
@@ -54,85 +58,98 @@ class Driver(BaseDriver):
def run(self, *args, **kwargs):
"""启动 none driver"""
super().run(*args, **kwargs)
loop = asyncio.get_event_loop()
loop.run_until_complete(self._serve())
anyio.run(self._serve)

async def _serve(self):
self._install_signal_handlers()
await self._startup()
if self.should_exit.is_set():
return
await self._main_loop()
await self._shutdown()
async with anyio.create_task_group() as driver_tg:
driver_tg.start_soon(self._handle_signals)
driver_tg.start_soon(self._listen_force_exit, driver_tg)
driver_tg.start_soon(self._handle_lifespan, driver_tg)

async def _startup(self):
async def _handle_signals(self):
try:
await self._lifespan.startup()
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Application startup failed. "
"Exiting.</bg #f8bbd0></r>"
)
self.should_exit.set()
return

logger.info("Application startup completed.")

async def _main_loop(self):
await self.should_exit.wait()

async def _shutdown(self):
logger.info("Shutting down")

logger.info("Waiting for application shutdown.")

try:
await self._lifespan.shutdown()
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>"
)

for task in asyncio.all_tasks():
if task is not asyncio.current_task() and not task.done():
task.cancel()
await asyncio.sleep(0.1)

tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
if tasks and not self.force_exit:
logger.info("Waiting for tasks to finish. (CTRL+C to force quit)")
while tasks and not self.force_exit:
await asyncio.sleep(0.1)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]

for task in tasks:
task.cancel()

await asyncio.gather(*tasks, return_exceptions=True)

logger.info("Application shutdown complete.")
loop = asyncio.get_event_loop()
loop.stop()

def _install_signal_handlers(self) -> None:
if threading.current_thread() is not threading.main_thread():
# Signals can only be listened to from the main thread.
return

loop = asyncio.get_event_loop()

try:
for sig in HANDLED_SIGNALS:
loop.add_signal_handler(sig, self._handle_exit, sig, None)
with anyio.open_signal_receiver(*HANDLED_SIGNALS) as signal_receiver:
async for sig in signal_receiver:
self.exit(force=self.should_exit.is_set())
except NotImplementedError:
# Windows
for sig in HANDLED_SIGNALS:
signal.signal(sig, self._handle_exit)
signal.signal(sig, self._handle_legacy_signal)

def _handle_exit(self, sig, frame):
# backport for Windows signal handling
def _handle_legacy_signal(self, sig, frame):
self.exit(force=self.should_exit.is_set())

async def _handle_lifespan(self, tg: TaskGroup):
try:
await self._startup()

if self.should_exit.is_set():
return

await self._listen_exit()

await self._shutdown()
finally:
tg.cancel_scope.cancel()

async def _startup(self):
def handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
self.should_exit.set()

for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error occurred while running startup hook."
"</bg #f8bbd0></r>"
)
logger.error(
"<r><bg #f8bbd0>Application startup failed. "
"Exiting.</bg #f8bbd0></r>"
)

with catch({Exception: handle_exception}):
await self._lifespan.startup()

if not self.should_exit.is_set():
logger.info("Application startup completed.")

async def _listen_exit(self, tg: Optional[TaskGroup] = None):
await self.should_exit.wait()

if tg is not None:
tg.cancel_scope.cancel()

async def _shutdown(self):
logger.info("Shutting down")
logger.info("Waiting for application shutdown. (CTRL+C to force quit)")

error_occurred: bool = False

def handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
nonlocal error_occurred

error_occurred = True

for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error occurred while running shutdown hook."
"</bg #f8bbd0></r>"
)
logger.error(
"<r><bg #f8bbd0>Application shutdown failed. "
"Exiting.</bg #f8bbd0></r>"
)

with catch({Exception: handle_exception}):
await self._lifespan.shutdown()

if not error_occurred:
logger.info("Application shutdown complete.")

async def _listen_force_exit(self, tg: TaskGroup):
await self.force_exit.wait()
tg.cancel_scope.cancel()

def exit(self, force: bool = False):
"""退出 none driver

@@ -142,4 +159,4 @@ class Driver(BaseDriver):
if not self.should_exit.is_set():
self.should_exit.set()
if force:
self.force_exit = True
self.force_exit.set()
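The "none" driver above replaces the hand-rolled asyncio loop with `anyio.run`, two `anyio.Event`s, and one task group that owns the signal handler, the force-exit listener, and the lifespan task. A trimmed-down sketch of that shape, with hypothetical names (`MiniDriver`, `serve`) and without NoneBot's logging:

```python
import signal

import anyio
from anyio.abc import TaskGroup

HANDLED_SIGNALS = (signal.SIGINT, signal.SIGTERM)


class MiniDriver:
    def __init__(self) -> None:
        self.should_exit = anyio.Event()
        self.force_exit = anyio.Event()

    async def serve(self) -> None:
        async with anyio.create_task_group() as tg:
            tg.start_soon(self._handle_signals)
            tg.start_soon(self._listen_force_exit, tg)
            tg.start_soon(self._lifespan, tg)

    async def _handle_signals(self) -> None:
        # a second Ctrl+C while already shutting down escalates to a force exit;
        # platforms without open_signal_receiver support (e.g. asyncio on Windows)
        # need the signal.signal fallback shown in the hunk above
        with anyio.open_signal_receiver(*HANDLED_SIGNALS) as signals:
            async for _ in signals:
                if self.should_exit.is_set():
                    self.force_exit.set()
                else:
                    self.should_exit.set()

    async def _listen_force_exit(self, tg: TaskGroup) -> None:
        await self.force_exit.wait()
        tg.cancel_scope.cancel()

    async def _lifespan(self, tg: TaskGroup) -> None:
        try:
            print("startup")
            await self.should_exit.wait()
            print("graceful shutdown")
        finally:
            tg.cancel_scope.cancel()  # stop the remaining tasks once the lifespan ends


async def main() -> None:
    await MiniDriver().serve()


anyio.run(main)
```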
@@ -1,11 +1,14 @@
import abc
import asyncio
from functools import partial
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional, Protocol

import anyio
from exceptiongroup import BaseExceptionGroup, catch

from nonebot.log import logger
from nonebot.config import Config
from nonebot.exception import MockApiException
from nonebot.utils import flatten_exception_group
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook

if TYPE_CHECKING:
@@ -76,47 +79,98 @@ class Bot(abc.ABC):
skip_calling_api: bool = False
exception: Optional[Exception] = None

if coros := [hook(self, api, data) for hook in self._calling_api_hook]:
try:
logger.debug("Running CallingAPI hooks...")
await asyncio.gather(*coros)
except MockApiException as e:
if self._calling_api_hook:
logger.debug("Running CallingAPI hooks...")

def _handle_mock_api_exception(
exc_group: BaseExceptionGroup[MockApiException],
) -> None:
nonlocal skip_calling_api, result

excs = [
exc
for exc in flatten_exception_group(exc_group)
if isinstance(exc, MockApiException)
]
if not excs:
return
elif len(excs) > 1:
logger.warning(
"Multiple hooks want to mock API result. Use the first one."
)

skip_calling_api = True
result = e.result
result = excs[0].result

logger.debug(
f"Calling API {api} is cancelled. Return {result} instead."
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
f"Calling API {api} is cancelled. Return {result!r} instead."
)

def _handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)

with catch(
{
MockApiException: _handle_mock_api_exception,
Exception: _handle_exception,
}
):
async with anyio.create_task_group() as tg:
for hook in self._calling_api_hook:
tg.start_soon(hook, self, api, data)

if not skip_calling_api:
try:
result = await self.adapter._call_api(self, api, **data)
except Exception as e:
exception = e

if coros := [
hook(self, exception, api, data, result) for hook in self._called_api_hook
]:
try:
logger.debug("Running CalledAPI hooks...")
await asyncio.gather(*coros)
except MockApiException as e:
# mock api result
result = e.result
# ignore exception
if self._called_api_hook:
logger.debug("Running CalledAPI hooks...")

def _handle_mock_api_exception(
exc_group: BaseExceptionGroup[MockApiException],
) -> None:
nonlocal result, exception

excs = [
exc
for exc in flatten_exception_group(exc_group)
if isinstance(exc, MockApiException)
]
if not excs:
return
elif len(excs) > 1:
logger.warning(
"Multiple hooks want to mock API result. Use the first one."
)

result = excs[0].result
exception = None
logger.debug(
f"Calling API {api} result is mocked. Return {result} instead."
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CalledAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)

def _handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error when running CalledAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)

with catch(
{
MockApiException: _handle_mock_api_exception,
Exception: _handle_exception,
}
):
async with anyio.create_task_group() as tg:
for hook in self._called_api_hook:
tg.start_soon(hook, self, exception, api, data, result)

if exception:
raise exception
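The `call_api` hunk above runs the CallingAPI/CalledAPI hooks in a task group and pulls a `MockApiException` back out of the resulting exception group, warning when several hooks try to mock the result at once. A reduced sketch of that flow with illustrative names (`call_api_with_hooks` is not an adapter API):

```python
from typing import Any

import anyio
from exceptiongroup import BaseExceptionGroup, catch


class MockApiException(Exception):  # stand-in for nonebot.exception.MockApiException
    def __init__(self, result: Any) -> None:
        super().__init__(result)
        self.result = result


def _leaves(exc_group: BaseExceptionGroup):
    # same idea as nonebot.utils.flatten_exception_group
    for exc in exc_group.exceptions:
        if isinstance(exc, BaseExceptionGroup):
            yield from _leaves(exc)
        else:
            yield exc


async def call_api_with_hooks(hooks, do_call) -> Any:
    result: Any = None
    skip_call = False

    def _handle_mock(exc_group: BaseExceptionGroup) -> None:
        nonlocal result, skip_call
        skip_call = True
        result = next(_leaves(exc_group)).result  # first mock wins

    def _handle_error(exc_group: BaseExceptionGroup) -> None:
        for exc in _leaves(exc_group):
            print("CallingAPI hook failed:", repr(exc))

    with catch({MockApiException: _handle_mock, Exception: _handle_error}):
        async with anyio.create_task_group() as tg:
            for hook in hooks:
                tg.start_soon(hook)

    return result if skip_call else await do_call()


async def main() -> None:
    async def mock_hook() -> None:
        raise MockApiException({"mocked": True})

    async def real_call() -> Any:
        return {"mocked": False}

    print(await call_api_with_hooks([mock_hook], real_call))  # {'mocked': True}


anyio.run(main)
```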
@@ -1,6 +1,11 @@
from collections.abc import Awaitable
from types import TracebackType
from typing_extensions import TypeAlias
from typing import Any, Union, Callable, cast
from collections.abc import Iterable, Awaitable
from typing import Any, Union, Callable, Optional, cast

import anyio
from anyio.abc import TaskGroup
from exceptiongroup import suppress

from nonebot.utils import run_sync, is_coroutine_callable

@@ -11,10 +16,24 @@ LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]

class Lifespan:
def __init__(self) -> None:
self._task_group: Optional[TaskGroup] = None

self._startup_funcs: list[LIFESPAN_FUNC] = []
self._ready_funcs: list[LIFESPAN_FUNC] = []
self._shutdown_funcs: list[LIFESPAN_FUNC] = []

@property
def task_group(self) -> TaskGroup:
if self._task_group is None:
raise RuntimeError("Lifespan not started")
return self._task_group

@task_group.setter
def task_group(self, task_group: TaskGroup) -> None:
if self._task_group is not None:
raise RuntimeError("Lifespan already started")
self._task_group = task_group

def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._startup_funcs.append(func)
return func
@@ -29,7 +48,7 @@ class Lifespan:

@staticmethod
async def _run_lifespan_func(
funcs: list[LIFESPAN_FUNC],
funcs: Iterable[LIFESPAN_FUNC],
) -> None:
for func in funcs:
if is_coroutine_callable(func):
@@ -38,18 +57,44 @@ class Lifespan:
await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()

async def startup(self) -> None:
# create background task group
self.task_group = anyio.create_task_group()
await self.task_group.__aenter__()

# run startup funcs
if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs)

# run ready funcs
if self._ready_funcs:
await self._run_lifespan_func(self._ready_funcs)

async def shutdown(self) -> None:
async def shutdown(
self,
*,
exc_type: Optional[type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional[TracebackType] = None,
) -> None:
if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs)
# reverse shutdown funcs to ensure stack order
await self._run_lifespan_func(reversed(self._shutdown_funcs))

# shutdown background task group
self.task_group.cancel_scope.cancel()

with suppress(Exception):
await self.task_group.__aexit__(exc_type, exc_val, exc_tb)

self._task_group = None

async def __aenter__(self) -> None:
await self.startup()

async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.shutdown()
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.shutdown(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)
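The new `Lifespan` keeps a background `TaskGroup` open between startup and shutdown by entering and exiting it manually, and it now runs shutdown functions in reverse registration order. A minimal sketch of the task-group half, assuming `exceptiongroup.suppress` as imported above; `MiniLifespan` is an illustrative name, not the NoneBot class:

```python
from typing import Optional

import anyio
from anyio.abc import TaskGroup
from exceptiongroup import suppress


class MiniLifespan:
    def __init__(self) -> None:
        self._task_group: Optional[TaskGroup] = None

    @property
    def task_group(self) -> TaskGroup:
        if self._task_group is None:
            raise RuntimeError("Lifespan not started")
        return self._task_group

    async def startup(self) -> None:
        # enter the task group manually so it outlives this call
        self._task_group = anyio.create_task_group()
        await self._task_group.__aenter__()

    async def shutdown(self) -> None:
        tg = self.task_group
        tg.cancel_scope.cancel()          # stop all background tasks
        with suppress(Exception):         # tolerate errors from cancelled tasks
            await tg.__aexit__(None, None, None)
        self._task_group = None


async def main() -> None:
    lifespan = MiniLifespan()
    await lifespan.startup()

    async def background() -> None:
        while True:
            await anyio.sleep(1)

    lifespan.task_group.start_soon(background)
    await anyio.sleep(0.1)
    await lifespan.shutdown()             # cancels `background`


anyio.run(main)
```

Holding the task group on the lifespan object is what lets other components `start_soon()` long-lived tasks onto the driver, as the later hunks do.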
@@ -1,17 +1,20 @@
import abc
import asyncio
from types import TracebackType
from collections.abc import AsyncGenerator
from typing_extensions import Self, TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional

from anyio.abc import TaskGroup
from anyio import CancelScope, create_task_group
from exceptiongroup import BaseExceptionGroup, catch

from nonebot.log import logger
from nonebot.config import Env, Config
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.utils import escape_tag, run_coro_with_catch
from nonebot.internal.params import BotParam, DependParam, DefaultParam
from nonebot.utils import escape_tag, run_coro_with_catch, flatten_exception_group
from nonebot.typing import (
T_DependencyCache,
T_BotConnectionHook,
@@ -61,7 +64,6 @@ class Driver(abc.ABC):
self.config: Config = config
"""全局配置对象"""
self._bots: dict[str, "Bot"] = {}
self._bot_tasks: set[asyncio.Task] = set()
self._lifespan = Lifespan()

def __repr__(self) -> str:
@@ -75,6 +77,10 @@ class Driver(abc.ABC):
"""获取当前所有已连接的 Bot"""
return self._bots

@property
def task_group(self) -> TaskGroup:
return self._lifespan.task_group

def register_adapter(self, adapter: type["Adapter"], **kwargs) -> None:
"""注册一个协议适配器

@@ -112,8 +118,6 @@ class Driver(abc.ABC):
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
)

self.on_shutdown(self._cleanup)

def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)
@@ -154,66 +158,57 @@ class Driver(abc.ABC):
raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
self._bots[bot.self_id] = bot

def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketConnection hook:"
"</bg #f8bbd0></r>"
)

async def _run_hook(bot: "Bot") -> None:
dependency_cache: T_DependencyCache = {}
async with AsyncExitStack() as stack:
if coros := [
run_coro_with_catch(
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
(SkippedException,),
)
for hook in self._bot_connection_hook
]:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketConnection hook. "
"Running cancelled!"
"</bg #f8bbd0></r>"
with CancelScope(shield=True), catch({Exception: handle_exception}):
async with AsyncExitStack() as stack, create_task_group() as tg:
for hook in self._bot_connection_hook:
tg.start_soon(
run_coro_with_catch,
hook(
bot=bot, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
)

task = asyncio.create_task(_run_hook(bot))
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
self.task_group.start_soon(_run_hook, bot)

def _bot_disconnect(self, bot: "Bot") -> None:
"""在连接断开后,调用该函数来注销 bot 对象"""
if bot.self_id in self._bots:
del self._bots[bot.self_id]

def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketDisConnection hook:"
"</bg #f8bbd0></r>"
)

async def _run_hook(bot: "Bot") -> None:
dependency_cache: T_DependencyCache = {}
async with AsyncExitStack() as stack:
if coros := [
run_coro_with_catch(
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
(SkippedException,),
)
for hook in self._bot_disconnection_hook
]:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketDisConnection hook. "
"Running cancelled!"
"</bg #f8bbd0></r>"
# shield cancellation to ensure bot disconnect hooks are always run
with CancelScope(shield=True), catch({Exception: handle_exception}):
async with create_task_group() as tg, AsyncExitStack() as stack:
for hook in self._bot_disconnection_hook:
tg.start_soon(
run_coro_with_catch,
hook(
bot=bot, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
)

task = asyncio.create_task(_run_hook(bot))
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)

async def _cleanup(self) -> None:
"""清理驱动器资源"""
if self._bot_tasks:
logger.opt(colors=True).debug(
"<y>Waiting for running bot connection hooks...</y>"
)
await asyncio.gather(*self._bot_tasks, return_exceptions=True)
self.task_group.start_soon(_run_hook, bot)


class Mixin(abc.ABC):
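In the abstract driver above, connection and disconnection hooks now run on the lifespan task group, and the disconnect path is wrapped in `CancelScope(shield=True)` so a shutdown-time cancellation cannot interrupt it. A small sketch of the shielding behaviour (hypothetical `run_teardown_hooks`, not the driver API):

```python
import anyio
from anyio import CancelScope
from exceptiongroup import BaseExceptionGroup, catch


async def run_teardown_hooks(hooks) -> None:
    def handle(exc_group: BaseExceptionGroup) -> None:
        print("teardown hook failed:", exc_group.exceptions)

    # shield: an outer cancellation may already be pending, but the hooks still finish
    with CancelScope(shield=True), catch({Exception: handle}):
        async with anyio.create_task_group() as tg:
            for hook in hooks:
                tg.start_soon(hook)


async def main() -> None:
    async def hook() -> None:
        await anyio.sleep(0.1)
        print("hook done")

    with anyio.move_on_after(0.05):  # deadline fires while the hook is still running
        await run_teardown_hooks([hook])
    print("outer scope exited")


anyio.run(main)
```

Without the shield, the `move_on_after` deadline would cancel the sleeping hook; with it, "hook done" is printed before the outer scope unwinds.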
@@ -22,11 +22,13 @@ from typing import ( # noqa: UP035
overload,
)

from exceptiongroup import BaseExceptionGroup, catch

from nonebot.log import logger
from nonebot.internal.rule import Rule
from nonebot.utils import classproperty
from nonebot.dependencies import Param, Dependent
from nonebot.internal.permission import User, Permission
from nonebot.utils import classproperty, flatten_exception_group
from nonebot.internal.adapter import (
Bot,
Event,
@@ -812,28 +814,34 @@ class Matcher(metaclass=MatcherMeta):
f"bot={bot}, event={event!r}, state={state!r}"
)

def _handle_stop_propagation(exc_group: BaseExceptionGroup[StopPropagation]):
self.block = True

with self.ensure_context(bot, event):
try:
# Refresh preprocess state
self.state.update(state)
with catch({StopPropagation: _handle_stop_propagation}):
# Refresh preprocess state
self.state.update(state)

while self.remain_handlers:
handler = self.remain_handlers.pop(0)
current_handler.set(handler)
logger.debug(f"Running handler {handler}")
try:
await handler(
matcher=self,
bot=bot,
event=event,
state=self.state,
stack=stack,
dependency_cache=dependency_cache,
)
except SkippedException:
logger.debug(f"Handler {handler} skipped")
except StopPropagation:
self.block = True
while self.remain_handlers:
handler = self.remain_handlers.pop(0)
current_handler.set(handler)
logger.debug(f"Running handler {handler}")

def _handle_skipped(
exc_group: BaseExceptionGroup[SkippedException],
):
logger.debug(f"Handler {handler} skipped")

with catch({SkippedException: _handle_skipped}):
await handler(
matcher=self,
bot=bot,
event=event,
state=self.state,
stack=stack,
dependency_cache=dependency_cache,
)
finally:
logger.info(f"{self} running complete")

@@ -846,10 +854,54 @@ class Matcher(metaclass=MatcherMeta):
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
):
try:
exc: Optional[Union[FinishedException, RejectedException, PausedException]] = (
None
)

def _handle_special_exception(
exc_group: BaseExceptionGroup[
Union[FinishedException, RejectedException, PausedException]
]
):
nonlocal exc
excs = list(flatten_exception_group(exc_group))
if len(excs) > 1:
logger.warning(
"Multiple session control exceptions occurred. "
"NoneBot will choose the proper one."
)
finished_exc = next(
(e for e in excs if isinstance(e, FinishedException)),
None,
)
rejected_exc = next(
(e for e in excs if isinstance(e, RejectedException)),
None,
)
paused_exc = next(
(e for e in excs if isinstance(e, PausedException)),
None,
)
exc = finished_exc or rejected_exc or paused_exc
elif isinstance(
excs[0], (FinishedException, RejectedException, PausedException)
):
exc = excs[0]

with catch(
{
(
FinishedException,
RejectedException,
PausedException,
): _handle_special_exception
}
):
await self.simple_run(bot, event, state, stack, dependency_cache)

except RejectedException:
if isinstance(exc, FinishedException):
pass
elif isinstance(exc, RejectedException):
await self.resolve_reject()
type_ = await self.update_type(bot, event, stack, dependency_cache)
permission = await self.update_permission(
@@ -870,7 +922,7 @@ class Matcher(metaclass=MatcherMeta):
default_type_updater=self.__class__._default_type_updater,
default_permission_updater=self.__class__._default_permission_updater,
)
except PausedException:
elif isinstance(exc, PausedException):
type_ = await self.update_type(bot, event, stack, dependency_cache)
permission = await self.update_permission(
bot, event, stack, dependency_cache
@@ -890,5 +942,3 @@ class Matcher(metaclass=MatcherMeta):
default_type_updater=self.__class__._default_type_updater,
default_permission_updater=self.__class__._default_permission_updater,
)
except FinishedException:
pass
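The matcher hunks convert `except StopPropagation` / `except SkippedException` blocks into `catch` handlers, and `run()` now inspects the exception group to pick a single session-control exception. A compact sketch of the `StopPropagation`-to-flag conversion (stand-in exception class, illustrative `run_handlers`):

```python
import anyio
from exceptiongroup import BaseExceptionGroup, catch


class StopPropagation(Exception):  # stand-in for nonebot.exception.StopPropagation
    pass


async def run_handlers(handlers) -> bool:
    block = False

    def _handle_stop(exc_group: BaseExceptionGroup) -> None:
        # the group may come from a task group, or wrap a bare exception
        nonlocal block
        block = True

    with catch({StopPropagation: _handle_stop}):
        for handler in handlers:
            await handler()
    return block


async def main() -> None:
    async def first() -> None:
        await anyio.sleep(0)

    async def second() -> None:
        raise StopPropagation

    print("blocked:", await run_handlers([first, second]))  # blocked: True


anyio.run(main)
```

As in the hunk, raising `StopPropagation` aborts the remaining handlers and only flips a flag instead of escaping the caller.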
@@ -1,5 +1,5 @@
import asyncio
import inspect
from enum import Enum
from typing_extensions import Self, get_args, override, get_origin
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from typing import (
@@ -13,8 +13,11 @@ from typing import (
cast,
)

import anyio
from exceptiongroup import BaseExceptionGroup, catch
from pydantic.fields import FieldInfo as PydanticFieldInfo

from nonebot.exception import SkippedException
from nonebot.dependencies import Param, Dependent
from nonebot.dependencies.utils import check_field_type
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
@@ -93,6 +96,75 @@ def Depends(
return DependsInner(dependency, use_cache=use_cache, validate=validate)


class CacheState(str, Enum):
"""子依赖缓存状态"""

PENDING = "PENDING"
FINISHED = "FINISHED"


class DependencyCache:
"""子依赖结果缓存。

用于缓存子依赖的结果,以避免重复计算。
"""

def __init__(self):
self._state = CacheState.PENDING
self._result: Any = None
self._exception: Optional[BaseException] = None
self._waiter = anyio.Event()

def result(self) -> Any:
"""获取子依赖结果"""

if self._state != CacheState.FINISHED:
raise RuntimeError("Result is not ready")

if self._exception is not None:
raise self._exception
return self._result

def exception(self) -> Optional[BaseException]:
"""获取子依赖异常"""

if self._state != CacheState.FINISHED:
raise RuntimeError("Result is not ready")

return self._exception

def set_result(self, result: Any) -> None:
"""设置子依赖结果"""

if self._state != CacheState.PENDING:
raise RuntimeError(f"Cache state invalid: {self._state}")

self._result = result
self._state = CacheState.FINISHED
self._waiter.set()

def set_exception(self, exception: BaseException) -> None:
"""设置子依赖异常"""

if self._state != CacheState.PENDING:
raise RuntimeError(f"Cache state invalid: {self._state}")

self._exception = exception
self._state = CacheState.FINISHED
self._waiter.set()

async def wait(self):
"""等待子依赖结果"""
await self._waiter.wait()
if self._state != CacheState.FINISHED:
raise RuntimeError("Invalid cache state")

if self._exception is not None:
raise self._exception

return self._result


class DependParam(Param):
"""子依赖注入参数。

@@ -194,17 +266,27 @@ class DependParam(Param):
call = cast(Callable[..., Any], sub_dependent.call)

# solve sub dependency with current cache
sub_values = await sub_dependent.solve(
stack=stack,
dependency_cache=dependency_cache,
**kwargs,
)
exc: Optional[BaseExceptionGroup[SkippedException]] = None

def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
nonlocal exc
exc = exc_group

with catch({SkippedException: _handle_skipped}):
sub_values = await sub_dependent.solve(
stack=stack,
dependency_cache=dependency_cache,
**kwargs,
)

if exc is not None:
raise exc

# run dependency function
task: asyncio.Task[Any]
if use_cache and call in dependency_cache:
return await dependency_cache[call]
elif is_gen_callable(call) or is_async_gen_callable(call):
return await dependency_cache[call].wait()

if is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance(
stack, AsyncExitStack
), "Generator dependency should be called in context"
@@ -212,17 +294,21 @@ class DependParam(Param):
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
else:
cm = asynccontextmanager(call)(**sub_values)
task = asyncio.create_task(stack.enter_async_context(cm))
dependency_cache[call] = task
return await task

target = stack.enter_async_context(cm)
elif is_coroutine_callable(call):
task = asyncio.create_task(call(**sub_values))
dependency_cache[call] = task
return await task
target = call(**sub_values)
else:
task = asyncio.create_task(run_sync(call)(**sub_values))
dependency_cache[call] = task
return await task
target = run_sync(call)(**sub_values)

dependency_cache[call] = cache = DependencyCache()
try:
result = await target
cache.set_result(result)
return result
except BaseException as e:
cache.set_exception(e)
raise

@override
async def _check(self, **kwargs: Any) -> None:
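The new `DependencyCache` above replaces the cached `asyncio.Task` with a pending/finished state plus an `anyio.Event`, so concurrent resolutions of the same dependency wait for the first one instead of recomputing it. A stripped-down sketch of that idea (hypothetical `OnceCache` and `resolve`, not NoneBot APIs):

```python
from typing import Any, Optional

import anyio


class OnceCache:
    def __init__(self) -> None:
        self._done = anyio.Event()
        self._result: Any = None
        self._exception: Optional[BaseException] = None

    def set_result(self, value: Any) -> None:
        self._result = value
        self._done.set()

    def set_exception(self, exc: BaseException) -> None:
        self._exception = exc
        self._done.set()

    async def wait(self) -> Any:
        await self._done.wait()
        if self._exception is not None:
            raise self._exception
        return self._result


async def main() -> None:
    cache: dict[str, OnceCache] = {}

    async def resolve(name: str) -> Any:
        if name in cache:                  # someone else is already computing it
            return await cache[name].wait()
        cache[name] = entry = OnceCache()
        try:
            await anyio.sleep(0.1)         # pretend this is expensive
            entry.set_result(42)
            return 42
        except BaseException as e:
            entry.set_exception(e)
            raise

    async def use(tag: str) -> None:
        print(tag, await resolve("answer"))

    async with anyio.create_task_group() as tg:
        tg.start_soon(use, "first")
        tg.start_soon(use, "second")       # waits on the first resolution


anyio.run(main)
```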
@@ -1,8 +1,9 @@
import asyncio
from typing_extensions import Self
from contextlib import AsyncExitStack
from typing import Union, ClassVar, NoReturn, Optional

import anyio

from nonebot.dependencies import Dependent
from nonebot.utils import run_coro_with_catch
from nonebot.exception import SkippedException
@@ -70,22 +71,26 @@ class Permission:
"""
if not self.checkers:
return True
results = await asyncio.gather(
*(
run_coro_with_catch(
checker(
bot=bot,
event=event,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
False,
)
for checker in self.checkers
),
)
return any(results)

result = False

async def _run_checker(checker: Dependent[bool]) -> None:
nonlocal result
# calculate the result first to avoid data racing
is_passed = await run_coro_with_catch(
checker(
bot=bot, event=event, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
False,
)
result |= is_passed

async with anyio.create_task_group() as tg:
for checker in self.checkers:
tg.start_soon(_run_checker, checker)

return result

def __and__(self, other: object) -> NoReturn:
raise RuntimeError("And operation between Permissions is not allowed.")
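`Permission.__call__` above now ORs checker results into a shared variable from inside a task group instead of `any(await asyncio.gather(...))`. The same aggregation in isolation (illustrative `check_any`, plain zero-argument checkers):

```python
import anyio


async def check_any(checkers) -> bool:
    result = False

    async def _run(checker) -> None:
        nonlocal result
        # compute first, then OR, so no checkpoint sits between read and write
        is_passed = await checker()
        result |= is_passed

    async with anyio.create_task_group() as tg:
        for checker in checkers:
            tg.start_soon(_run, checker)
    return result


async def main() -> None:
    async def yes() -> bool:
        return True

    async def no() -> bool:
        return False

    print(await check_any([yes, no]))  # True


anyio.run(main)
```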
@@ -1,7 +1,9 @@
import asyncio
from contextlib import AsyncExitStack
from typing import Union, ClassVar, NoReturn, Optional

import anyio
from exceptiongroup import BaseExceptionGroup, catch

from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.typing import T_State, T_RuleChecker, T_DependencyCache
@@ -71,22 +73,33 @@ class Rule:
"""
if not self.checkers:
return True
try:
results = await asyncio.gather(
*(
checker(
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
)
for checker in self.checkers
)

result = True

def _handle_skipped_exception(
exc_group: BaseExceptionGroup[SkippedException],
) -> None:
nonlocal result
result = False

async def _run_checker(checker: Dependent[bool]) -> None:
nonlocal result
# calculate the result first to avoid data racing
is_passed = await checker(
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
)
except SkippedException:
return False
return all(results)
result &= is_passed

with catch({SkippedException: _handle_skipped_exception}):
async with anyio.create_task_group() as tg:
for checker in self.checkers:
tg.start_soon(_run_checker, checker)

return result

def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
if other is None:
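`Rule.__call__` is the AND-shaped counterpart: checkers run in a task group, results are ANDed, and a `SkippedException` from any checker is caught as a group and simply forces the overall result to `False`. A sketch with stand-in names; under these assumptions, calling `check_all` with one checker that raises `SkippedException` returns `False`:

```python
import anyio
from exceptiongroup import BaseExceptionGroup, catch


class SkippedException(Exception):  # stand-in for nonebot.exception.SkippedException
    pass


async def check_all(checkers) -> bool:
    result = True

    def _handle_skipped(exc_group: BaseExceptionGroup) -> None:
        nonlocal result
        result = False

    async def _run(checker) -> None:
        nonlocal result
        is_passed = await checker()
        result &= is_passed

    with catch({SkippedException: _handle_skipped}):
        async with anyio.create_task_group() as tg:
            for checker in checkers:
                tg.start_soon(_run, checker)

    return result
```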
@@ -9,23 +9,30 @@ FrontMatter:
description: nonebot.message 模块
"""

import asyncio
import contextlib
from datetime import datetime
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional

import anyio
from exceptiongroup import BaseExceptionGroup, catch

from nonebot.log import logger
from nonebot.rule import TrieRule
from nonebot.dependencies import Dependent
from nonebot.matcher import Matcher, matchers
from nonebot.utils import escape_tag, run_coro_with_catch
from nonebot.exception import (
NoLogException,
StopPropagation,
IgnoredException,
SkippedException,
)
from nonebot.utils import (
escape_tag,
run_coro_with_catch,
run_coro_with_shield,
flatten_exception_group,
)
from nonebot.typing import (
T_State,
T_DependencyCache,
@@ -125,6 +132,21 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
return func


def _handle_ignored_exception(msg: str) -> Callable[[BaseExceptionGroup], None]:
def _handle(exc_group: BaseExceptionGroup[IgnoredException]) -> None:
logger.opt(colors=True).info(msg)

return _handle


def _handle_exception(msg: str) -> Callable[[BaseExceptionGroup], None]:
def _handle(exc_group: BaseExceptionGroup[Exception]) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(msg)

return _handle


async def _apply_event_preprocessors(
bot: "Bot",
event: "Event",
@@ -152,10 +174,21 @@ async def _apply_event_preprocessors(
if show_log:
logger.debug("Running PreProcessors...")

try:
await asyncio.gather(
*(
run_coro_with_catch(
with catch(
{
IgnoredException: _handle_ignored_exception(
f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>"
),
Exception: _handle_exception(
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
"Event ignored!</bg #f8bbd0></r>"
),
}
):
async with anyio.create_task_group() as tg:
for proc in _event_preprocessors:
tg.start_soon(
run_coro_with_catch,
proc(
bot=bot,
event=event,
@@ -165,22 +198,10 @@ async def _apply_event_preprocessors(
),
(SkippedException,),
)
for proc in _event_preprocessors
)
)
except IgnoredException:
logger.opt(colors=True).info(
f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>"
)
return False
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
"Event ignored!</bg #f8bbd0></r>"
)
return False

return True
return True

return False


async def _apply_event_postprocessors(
@@ -207,10 +228,17 @@ async def _apply_event_postprocessors(
if show_log:
logger.debug("Running PostProcessors...")

try:
await asyncio.gather(
*(
run_coro_with_catch(
with catch(
{
Exception: _handle_exception(
"<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>"
)
}
):
async with anyio.create_task_group() as tg:
for proc in _event_postprocessors:
tg.start_soon(
run_coro_with_catch,
proc(
bot=bot,
event=event,
@@ -220,13 +248,6 @@ async def _apply_event_postprocessors(
),
(SkippedException,),
)
for proc in _event_postprocessors
)
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>"
)


async def _apply_run_preprocessors(
@@ -254,35 +275,38 @@ async def _apply_run_preprocessors(
return True

# ensure matcher function can be correctly called
with matcher.ensure_context(bot, event):
try:
await asyncio.gather(
*(
run_coro_with_catch(
proc(
matcher=matcher,
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
)
for proc in _run_preprocessors
with (
matcher.ensure_context(bot, event),
catch(
{
IgnoredException: _handle_ignored_exception(
f"{matcher} running is <b>cancelled</b>"
),
Exception: _handle_exception(
"<r><bg #f8bbd0>Error when running RunPreProcessors. "
"Running cancelled!</bg #f8bbd0></r>"
),
}
),
):
async with anyio.create_task_group() as tg:
for proc in _run_preprocessors:
tg.start_soon(
run_coro_with_catch,
proc(
matcher=matcher,
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
)
)
except IgnoredException:
logger.opt(colors=True).info(f"{matcher} running is <b>cancelled</b>")
return False
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running RunPreProcessors. "
"Running cancelled!</bg #f8bbd0></r>"
)
return False

return True
return True

return False


async def _apply_run_postprocessors(
@@ -306,29 +330,32 @@ async def _apply_run_postprocessors(
if not _run_postprocessors:
return

with matcher.ensure_context(bot, event):
try:
await asyncio.gather(
*(
run_coro_with_catch(
proc(
matcher=matcher,
exception=exception,
bot=bot,
event=event,
state=matcher.state,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
)
for proc in _run_postprocessors
with (
matcher.ensure_context(bot, event),
catch(
{
Exception: _handle_exception(
"<r><bg #f8bbd0>Error when running RunPostProcessors"
"</bg #f8bbd0></r>"
)
}
),
):
async with anyio.create_task_group() as tg:
for proc in _run_postprocessors:
tg.start_soon(
run_coro_with_catch,
proc(
matcher=matcher,
exception=exception,
bot=bot,
event=event,
state=matcher.state,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
)
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running RunPostProcessors</bg #f8bbd0></r>"
)


async def _check_matcher(
@@ -425,8 +452,9 @@ async def _run_matcher(

exception = None

logger.debug(f"Running {matcher}")

try:
logger.debug(f"Running {matcher}")
await matcher.run(bot, event, state, stack, dependency_cache)
except Exception as e:
logger.opt(colors=True, exception=e).error(
@@ -494,8 +522,7 @@ async def handle_event(bot: "Bot", event: "Event") -> None:

用法:
```python
import asyncio
asyncio.create_task(handle_event(bot, event))
driver.task_group.start_soon(handle_event, bot, event)
```
"""
show_log = True
@@ -530,6 +557,13 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
)

break_flag = False

def _handle_stop_propagation(exc_group: BaseExceptionGroup) -> None:
nonlocal break_flag

break_flag = True
logger.debug("Stop event propagation")

# iterate through all priority until stop propagation
for priority in sorted(matchers.keys()):
if break_flag:
@@ -538,23 +572,30 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
if show_log:
logger.debug(f"Checking for matchers in priority {priority}...")

pending_tasks = [
check_and_run_matcher(
matcher, bot, event, state.copy(), stack, dependency_cache
)
for matcher in matchers[priority]
]
results = await asyncio.gather(*pending_tasks, return_exceptions=True)
for result in results:
if not isinstance(result, Exception):
continue
if isinstance(result, StopPropagation):
break_flag = True
logger.debug("Stop event propagation")
else:
logger.opt(colors=True, exception=result).error(
if not (priority_matchers := matchers[priority]):
continue

with catch(
{
StopPropagation: _handle_stop_propagation,
Exception: _handle_exception(
"<r><bg #f8bbd0>Error when checking Matcher.</bg #f8bbd0></r>"
)
),
}
):
async with anyio.create_task_group() as tg:
for matcher in priority_matchers:
tg.start_soon(
run_coro_with_shield,
check_and_run_matcher(
matcher,
bot,
event,
state.copy(),
stack,
dependency_cache,
),
)

if show_log:
logger.debug("Checking for matchers completed")
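The message pipeline above builds its `catch` handlers from small factories (`_handle_ignored_exception(msg)`, `_handle_exception(msg)`) and keeps the `return True` inside the `with catch(...)` block so that any handled failure falls through to `return False`. A minimal sketch of that control flow (hypothetical `run_stage`; a real handler would flatten nested groups the way `flatten_exception_group` does):

```python
from typing import Callable

import anyio
from exceptiongroup import BaseExceptionGroup, catch


def log_errors(msg: str) -> Callable[[BaseExceptionGroup], None]:
    # factory: the same logging behaviour reused for several pipeline stages
    def _handle(exc_group: BaseExceptionGroup) -> None:
        for exc in exc_group.exceptions:
            print(f"{msg}: {exc!r}")

    return _handle


async def run_stage(procs) -> bool:
    with catch({Exception: log_errors("stage failed")}):
        async with anyio.create_task_group() as tg:
            for proc in procs:
                tg.start_soon(proc)
        return True  # only reached when every processor succeeded

    return False  # reached when the handler swallowed an error


async def main() -> None:
    async def good() -> None: ...

    async def bad() -> None:
        raise RuntimeError("boom")

    print(await run_stage([good]))       # True
    print(await run_stage([good, bad]))  # False, error logged


anyio.run(main)
```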
@@ -22,7 +22,7 @@ from . import _managers, get_plugin, _module_name_to_plugin_id
try: # pragma: py-gte-311
import tomllib # pyright: ignore[reportMissingImports]
except ModuleNotFoundError: # pragma: py-lt-311
import tomli as tomllib
import tomli as tomllib # pyright: ignore[reportMissingImports]


def load_plugin(module_path: Union[str, Path]) -> Optional[Plugin]:
@@ -21,10 +21,9 @@ from typing import TYPE_CHECKING, TypeVar
from typing_extensions import ParamSpec, TypeAlias, get_args, override, get_origin

if TYPE_CHECKING:
from asyncio import Task

from nonebot.adapters import Bot
from nonebot.permission import Permission
from nonebot.internal.params import DependencyCache

T = TypeVar("T")
P = ParamSpec("P")
@@ -258,5 +257,5 @@ T_PermissionUpdater: TypeAlias = _DependentCallable["Permission"]
- MatcherParam: Matcher 对象
- DefaultParam: 带有默认值的参数
"""
T_DependencyCache: TypeAlias = dict[_DependentCallable[t.Any], "Task[t.Any]"]
T_DependencyCache: TypeAlias = dict[_DependentCallable[t.Any], "DependencyCache"]
"""依赖缓存, 用于存储依赖函数的返回值"""
@@ -9,21 +9,22 @@ FrontMatter:

import re
import json
import asyncio
import inspect
import importlib
import contextlib
import dataclasses
from pathlib import Path
from collections import deque
from contextvars import copy_context
from functools import wraps, partial
from contextlib import AbstractContextManager, asynccontextmanager
from typing_extensions import ParamSpec, get_args, override, get_origin
from collections.abc import Mapping, Sequence, Coroutine, AsyncGenerator
from typing import Any, Union, Generic, TypeVar, Callable, Optional, overload
from collections.abc import Mapping, Sequence, Coroutine, Generator, AsyncGenerator

import anyio
import anyio.to_thread
from pydantic import BaseModel
from exceptiongroup import BaseExceptionGroup, catch

from nonebot.log import logger
from nonebot.typing import (
@@ -39,6 +40,7 @@ R = TypeVar("R")
T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")
E = TypeVar("E", bound=BaseException)


def escape_tag(s: str) -> str:
@@ -178,11 +180,9 @@ def run_sync(call: Callable[P, R]) -> Callable[P, Coroutine[None, None, R]]:

@wraps(call)
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
loop = asyncio.get_running_loop()
pfunc = partial(call, *args, **kwargs)
context = copy_context()
result = await loop.run_in_executor(None, partial(context.run, pfunc))
return result
return await anyio.to_thread.run_sync(
partial(call, *args, **kwargs), abandon_on_cancel=True
)

return _wrapper

@@ -234,10 +234,34 @@ async def run_coro_with_catch(
协程的返回值或发生异常时的指定值
"""

try:
with catch({exc: lambda exc_group: None}):
return await coro
except exc:
return return_on_err

return return_on_err


async def run_coro_with_shield(coro: Coroutine[Any, Any, T]) -> T:
"""运行协程并在取消时屏蔽取消异常。

参数:
coro: 要运行的协程

返回:
协程的返回值
"""

with anyio.CancelScope(shield=True):
return await coro


def flatten_exception_group(
exc_group: BaseExceptionGroup[E],
) -> Generator[E, None, None]:
for exc in exc_group.exceptions:
if isinstance(exc, BaseExceptionGroup):
yield from flatten_exception_group(exc)
else:
yield exc


def get_name(obj: Any) -> str:
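The utils hunk moves `run_sync` onto `anyio.to_thread.run_sync`, reimplements `run_coro_with_catch` on top of `catch`, and adds `run_coro_with_shield` and `flatten_exception_group`. Two of those helpers illustrated with hypothetical call sites (`flatten`, `safe_int`):

```python
import anyio
from exceptiongroup import BaseExceptionGroup, catch


def flatten(exc_group: BaseExceptionGroup):
    # recursively yield only the leaf exceptions of a (possibly nested) group
    for exc in exc_group.exceptions:
        if isinstance(exc, BaseExceptionGroup):
            yield from flatten(exc)
        else:
            yield exc


async def safe_int(text: str, fallback: int) -> int:
    # run_coro_with_catch-style: swallow a chosen exception type and fall back
    with catch({ValueError: lambda eg: None}):
        return int(text)
    return fallback


async def main() -> None:
    group = BaseExceptionGroup(
        "outer", [ValueError("a"), BaseExceptionGroup("inner", [KeyError("b")])]
    )
    print([repr(e) for e in flatten(group)])   # leaves only
    print(await safe_int("7", fallback=0))     # 7
    print(await safe_int("oops", fallback=0))  # 0


anyio.run(main)
```

Because `catch` also wraps a bare exception into a one-member group, the `return fallback` line after the `with` block is reached whenever the conversion fails, which is exactly the shape the new `run_coro_with_catch` relies on.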