Feature: 迁移至结构化并发框架 AnyIO (#3053)

This commit is contained in:
Ju4tCode
2024-10-26 15:36:01 +08:00
committed by GitHub
parent bd9befbb55
commit ff21ceb946
39 changed files with 5422 additions and 4080 deletions

View File

@ -8,17 +8,20 @@ FrontMatter:
"""
import abc
import asyncio
import inspect
from functools import partial
from dataclasses import field, dataclass
from collections.abc import Iterable, Awaitable
from typing import Any, Generic, TypeVar, Callable, Optional, cast
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.typing import _DependentCallable
from nonebot.exception import SkippedException
from nonebot.utils import run_sync, is_coroutine_callable
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
from nonebot.utils import run_sync, is_coroutine_callable, flatten_exception_group
from .utils import check_field_type, get_typed_signature
@ -84,7 +87,16 @@ class Dependent(Generic[R]):
)
async def __call__(self, **kwargs: Any) -> R:
try:
exception: Optional[BaseExceptionGroup[SkippedException]] = None
def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
nonlocal exception
exception = exc_group
# raise one of the exceptions instead
excs = list(flatten_exception_group(exc_group))
logger.trace(f"{self} skipped due to {excs}")
with catch({SkippedException: _handle_skipped}):
# do pre-check
await self.check(**kwargs)
@ -96,9 +108,8 @@ class Dependent(Generic[R]):
return await cast(Callable[..., Awaitable[R]], self.call)(**values)
else:
return await run_sync(cast(Callable[..., R], self.call))(**values)
except SkippedException as e:
logger.trace(f"{self} skipped due to {e}")
raise
raise exception
@staticmethod
def parse_params(
@ -166,10 +177,13 @@ class Dependent(Generic[R]):
return cls(call, params, parameterless_params)
async def check(self, **params: Any) -> None:
await asyncio.gather(*(param._check(**params) for param in self.parameterless))
await asyncio.gather(
*(cast(Param, param.field_info)._check(**params) for param in self.params)
)
async with anyio.create_task_group() as tg:
for param in self.parameterless:
tg.start_soon(partial(param._check, **params))
async with anyio.create_task_group() as tg:
for param in self.params:
tg.start_soon(partial(cast(Param, param.field_info)._check, **params))
async def _solve_field(self, field: ModelField, params: dict[str, Any]) -> Any:
param = cast(Param, field.field_info)
@ -185,10 +199,17 @@ class Dependent(Generic[R]):
await param._solve(**params)
# solve param values
values = await asyncio.gather(
*(self._solve_field(field, params) for field in self.params)
)
return {field.name: value for field, value in zip(self.params, values)}
result: dict[str, Any] = {}
async def _solve_field(field: ModelField, params: dict[str, Any]) -> None:
value = await self._solve_field(field, params)
result[field.name] = value
async with anyio.create_task_group() as tg:
for field in self.params:
tg.start_soon(_solve_field, field, params)
return result
__autodoc__ = {"CustomConfig": False}

View File

@ -12,14 +12,18 @@ FrontMatter:
"""
import signal
import asyncio
import threading
from typing import Optional
from typing_extensions import override
import anyio
from anyio.abc import TaskGroup
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.consts import WINDOWS
from nonebot.config import Env, Config
from nonebot.drivers import Driver as BaseDriver
from nonebot.utils import flatten_exception_group
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
@ -35,8 +39,8 @@ class Driver(BaseDriver):
def __init__(self, env: Env, config: Config):
super().__init__(env, config)
self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False
self.should_exit: anyio.Event = anyio.Event()
self.force_exit: anyio.Event = anyio.Event()
@property
@override
@ -54,85 +58,98 @@ class Driver(BaseDriver):
def run(self, *args, **kwargs):
"""启动 none driver"""
super().run(*args, **kwargs)
loop = asyncio.get_event_loop()
loop.run_until_complete(self._serve())
anyio.run(self._serve)
async def _serve(self):
self._install_signal_handlers()
await self._startup()
if self.should_exit.is_set():
return
await self._main_loop()
await self._shutdown()
async with anyio.create_task_group() as driver_tg:
driver_tg.start_soon(self._handle_signals)
driver_tg.start_soon(self._listen_force_exit, driver_tg)
driver_tg.start_soon(self._handle_lifespan, driver_tg)
async def _startup(self):
async def _handle_signals(self):
try:
await self._lifespan.startup()
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Application startup failed. "
"Exiting.</bg #f8bbd0></r>"
)
self.should_exit.set()
return
logger.info("Application startup completed.")
async def _main_loop(self):
await self.should_exit.wait()
async def _shutdown(self):
logger.info("Shutting down")
logger.info("Waiting for application shutdown.")
try:
await self._lifespan.shutdown()
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>"
)
for task in asyncio.all_tasks():
if task is not asyncio.current_task() and not task.done():
task.cancel()
await asyncio.sleep(0.1)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
if tasks and not self.force_exit:
logger.info("Waiting for tasks to finish. (CTRL+C to force quit)")
while tasks and not self.force_exit:
await asyncio.sleep(0.1)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
logger.info("Application shutdown complete.")
loop = asyncio.get_event_loop()
loop.stop()
def _install_signal_handlers(self) -> None:
if threading.current_thread() is not threading.main_thread():
# Signals can only be listened to from the main thread.
return
loop = asyncio.get_event_loop()
try:
for sig in HANDLED_SIGNALS:
loop.add_signal_handler(sig, self._handle_exit, sig, None)
with anyio.open_signal_receiver(*HANDLED_SIGNALS) as signal_receiver:
async for sig in signal_receiver:
self.exit(force=self.should_exit.is_set())
except NotImplementedError:
# Windows
for sig in HANDLED_SIGNALS:
signal.signal(sig, self._handle_exit)
signal.signal(sig, self._handle_legacy_signal)
def _handle_exit(self, sig, frame):
# backport for Windows signal handling
def _handle_legacy_signal(self, sig, frame):
self.exit(force=self.should_exit.is_set())
async def _handle_lifespan(self, tg: TaskGroup):
try:
await self._startup()
if self.should_exit.is_set():
return
await self._listen_exit()
await self._shutdown()
finally:
tg.cancel_scope.cancel()
async def _startup(self):
def handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
self.should_exit.set()
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error occurred while running startup hook."
"</bg #f8bbd0></r>"
)
logger.error(
"<r><bg #f8bbd0>Application startup failed. "
"Exiting.</bg #f8bbd0></r>"
)
with catch({Exception: handle_exception}):
await self._lifespan.startup()
if not self.should_exit.is_set():
logger.info("Application startup completed.")
async def _listen_exit(self, tg: Optional[TaskGroup] = None):
await self.should_exit.wait()
if tg is not None:
tg.cancel_scope.cancel()
async def _shutdown(self):
logger.info("Shutting down")
logger.info("Waiting for application shutdown. (CTRL+C to force quit)")
error_occurred: bool = False
def handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
nonlocal error_occurred
error_occurred = True
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error occurred while running shutdown hook."
"</bg #f8bbd0></r>"
)
logger.error(
"<r><bg #f8bbd0>Application shutdown failed. "
"Exiting.</bg #f8bbd0></r>"
)
with catch({Exception: handle_exception}):
await self._lifespan.shutdown()
if not error_occurred:
logger.info("Application shutdown complete.")
async def _listen_force_exit(self, tg: TaskGroup):
await self.force_exit.wait()
tg.cancel_scope.cancel()
def exit(self, force: bool = False):
"""退出 none driver
@ -142,4 +159,4 @@ class Driver(BaseDriver):
if not self.should_exit.is_set():
self.should_exit.set()
if force:
self.force_exit = True
self.force_exit.set()

View File

@ -1,11 +1,14 @@
import abc
import asyncio
from functools import partial
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional, Protocol
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.config import Config
from nonebot.exception import MockApiException
from nonebot.utils import flatten_exception_group
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook
if TYPE_CHECKING:
@ -76,47 +79,98 @@ class Bot(abc.ABC):
skip_calling_api: bool = False
exception: Optional[Exception] = None
if coros := [hook(self, api, data) for hook in self._calling_api_hook]:
try:
logger.debug("Running CallingAPI hooks...")
await asyncio.gather(*coros)
except MockApiException as e:
if self._calling_api_hook:
logger.debug("Running CallingAPI hooks...")
def _handle_mock_api_exception(
exc_group: BaseExceptionGroup[MockApiException],
) -> None:
nonlocal skip_calling_api, result
excs = [
exc
for exc in flatten_exception_group(exc_group)
if isinstance(exc, MockApiException)
]
if not excs:
return
elif len(excs) > 1:
logger.warning(
"Multiple hooks want to mock API result. Use the first one."
)
skip_calling_api = True
result = e.result
result = excs[0].result
logger.debug(
f"Calling API {api} is cancelled. Return {result} instead."
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
f"Calling API {api} is cancelled. Return {result!r} instead."
)
def _handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
with catch(
{
MockApiException: _handle_mock_api_exception,
Exception: _handle_exception,
}
):
async with anyio.create_task_group() as tg:
for hook in self._calling_api_hook:
tg.start_soon(hook, self, api, data)
if not skip_calling_api:
try:
result = await self.adapter._call_api(self, api, **data)
except Exception as e:
exception = e
if coros := [
hook(self, exception, api, data, result) for hook in self._called_api_hook
]:
try:
logger.debug("Running CalledAPI hooks...")
await asyncio.gather(*coros)
except MockApiException as e:
# mock api result
result = e.result
# ignore exception
if self._called_api_hook:
logger.debug("Running CalledAPI hooks...")
def _handle_mock_api_exception(
exc_group: BaseExceptionGroup[MockApiException],
) -> None:
nonlocal result, exception
excs = [
exc
for exc in flatten_exception_group(exc_group)
if isinstance(exc, MockApiException)
]
if not excs:
return
elif len(excs) > 1:
logger.warning(
"Multiple hooks want to mock API result. Use the first one."
)
result = excs[0].result
exception = None
logger.debug(
f"Calling API {api} result is mocked. Return {result} instead."
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CalledAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
def _handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error when running CalledAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
with catch(
{
MockApiException: _handle_mock_api_exception,
Exception: _handle_exception,
}
):
async with anyio.create_task_group() as tg:
for hook in self._called_api_hook:
tg.start_soon(hook, self, exception, api, data, result)
if exception:
raise exception

View File

@ -1,6 +1,11 @@
from collections.abc import Awaitable
from types import TracebackType
from typing_extensions import TypeAlias
from typing import Any, Union, Callable, cast
from collections.abc import Iterable, Awaitable
from typing import Any, Union, Callable, Optional, cast
import anyio
from anyio.abc import TaskGroup
from exceptiongroup import suppress
from nonebot.utils import run_sync, is_coroutine_callable
@ -11,10 +16,24 @@ LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
class Lifespan:
def __init__(self) -> None:
self._task_group: Optional[TaskGroup] = None
self._startup_funcs: list[LIFESPAN_FUNC] = []
self._ready_funcs: list[LIFESPAN_FUNC] = []
self._shutdown_funcs: list[LIFESPAN_FUNC] = []
@property
def task_group(self) -> TaskGroup:
if self._task_group is None:
raise RuntimeError("Lifespan not started")
return self._task_group
@task_group.setter
def task_group(self, task_group: TaskGroup) -> None:
if self._task_group is not None:
raise RuntimeError("Lifespan already started")
self._task_group = task_group
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._startup_funcs.append(func)
return func
@ -29,7 +48,7 @@ class Lifespan:
@staticmethod
async def _run_lifespan_func(
funcs: list[LIFESPAN_FUNC],
funcs: Iterable[LIFESPAN_FUNC],
) -> None:
for func in funcs:
if is_coroutine_callable(func):
@ -38,18 +57,44 @@ class Lifespan:
await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()
async def startup(self) -> None:
# create background task group
self.task_group = anyio.create_task_group()
await self.task_group.__aenter__()
# run startup funcs
if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs)
# run ready funcs
if self._ready_funcs:
await self._run_lifespan_func(self._ready_funcs)
async def shutdown(self) -> None:
async def shutdown(
self,
*,
exc_type: Optional[type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional[TracebackType] = None,
) -> None:
if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs)
# reverse shutdown funcs to ensure stack order
await self._run_lifespan_func(reversed(self._shutdown_funcs))
# shutdown background task group
self.task_group.cancel_scope.cancel()
with suppress(Exception):
await self.task_group.__aexit__(exc_type, exc_val, exc_tb)
self._task_group = None
async def __aenter__(self) -> None:
await self.startup()
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.shutdown()
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.shutdown(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)

View File

@ -1,17 +1,20 @@
import abc
import asyncio
from types import TracebackType
from collections.abc import AsyncGenerator
from typing_extensions import Self, TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional
from anyio.abc import TaskGroup
from anyio import CancelScope, create_task_group
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.config import Env, Config
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.utils import escape_tag, run_coro_with_catch
from nonebot.internal.params import BotParam, DependParam, DefaultParam
from nonebot.utils import escape_tag, run_coro_with_catch, flatten_exception_group
from nonebot.typing import (
T_DependencyCache,
T_BotConnectionHook,
@ -61,7 +64,6 @@ class Driver(abc.ABC):
self.config: Config = config
"""全局配置对象"""
self._bots: dict[str, "Bot"] = {}
self._bot_tasks: set[asyncio.Task] = set()
self._lifespan = Lifespan()
def __repr__(self) -> str:
@ -75,6 +77,10 @@ class Driver(abc.ABC):
"""获取当前所有已连接的 Bot"""
return self._bots
@property
def task_group(self) -> TaskGroup:
return self._lifespan.task_group
def register_adapter(self, adapter: type["Adapter"], **kwargs) -> None:
"""注册一个协议适配器
@ -112,8 +118,6 @@ class Driver(abc.ABC):
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
)
self.on_shutdown(self._cleanup)
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)
@ -154,66 +158,57 @@ class Driver(abc.ABC):
raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
self._bots[bot.self_id] = bot
def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketConnection hook:"
"</bg #f8bbd0></r>"
)
async def _run_hook(bot: "Bot") -> None:
dependency_cache: T_DependencyCache = {}
async with AsyncExitStack() as stack:
if coros := [
run_coro_with_catch(
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
(SkippedException,),
)
for hook in self._bot_connection_hook
]:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketConnection hook. "
"Running cancelled!"
"</bg #f8bbd0></r>"
with CancelScope(shield=True), catch({Exception: handle_exception}):
async with AsyncExitStack() as stack, create_task_group() as tg:
for hook in self._bot_connection_hook:
tg.start_soon(
run_coro_with_catch,
hook(
bot=bot, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
)
task = asyncio.create_task(_run_hook(bot))
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
self.task_group.start_soon(_run_hook, bot)
def _bot_disconnect(self, bot: "Bot") -> None:
"""在连接断开后,调用该函数来注销 bot 对象"""
if bot.self_id in self._bots:
del self._bots[bot.self_id]
def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketDisConnection hook:"
"</bg #f8bbd0></r>"
)
async def _run_hook(bot: "Bot") -> None:
dependency_cache: T_DependencyCache = {}
async with AsyncExitStack() as stack:
if coros := [
run_coro_with_catch(
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
(SkippedException,),
)
for hook in self._bot_disconnection_hook
]:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketDisConnection hook. "
"Running cancelled!"
"</bg #f8bbd0></r>"
# shield cancellation to ensure bot disconnect hooks are always run
with CancelScope(shield=True), catch({Exception: handle_exception}):
async with create_task_group() as tg, AsyncExitStack() as stack:
for hook in self._bot_disconnection_hook:
tg.start_soon(
run_coro_with_catch,
hook(
bot=bot, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
)
task = asyncio.create_task(_run_hook(bot))
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
async def _cleanup(self) -> None:
"""清理驱动器资源"""
if self._bot_tasks:
logger.opt(colors=True).debug(
"<y>Waiting for running bot connection hooks...</y>"
)
await asyncio.gather(*self._bot_tasks, return_exceptions=True)
self.task_group.start_soon(_run_hook, bot)
class Mixin(abc.ABC):

View File

@ -22,11 +22,13 @@ from typing import ( # noqa: UP035
overload,
)
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.internal.rule import Rule
from nonebot.utils import classproperty
from nonebot.dependencies import Param, Dependent
from nonebot.internal.permission import User, Permission
from nonebot.utils import classproperty, flatten_exception_group
from nonebot.internal.adapter import (
Bot,
Event,
@ -812,28 +814,34 @@ class Matcher(metaclass=MatcherMeta):
f"bot={bot}, event={event!r}, state={state!r}"
)
def _handle_stop_propagation(exc_group: BaseExceptionGroup[StopPropagation]):
self.block = True
with self.ensure_context(bot, event):
try:
# Refresh preprocess state
self.state.update(state)
with catch({StopPropagation: _handle_stop_propagation}):
# Refresh preprocess state
self.state.update(state)
while self.remain_handlers:
handler = self.remain_handlers.pop(0)
current_handler.set(handler)
logger.debug(f"Running handler {handler}")
try:
await handler(
matcher=self,
bot=bot,
event=event,
state=self.state,
stack=stack,
dependency_cache=dependency_cache,
)
except SkippedException:
logger.debug(f"Handler {handler} skipped")
except StopPropagation:
self.block = True
while self.remain_handlers:
handler = self.remain_handlers.pop(0)
current_handler.set(handler)
logger.debug(f"Running handler {handler}")
def _handle_skipped(
exc_group: BaseExceptionGroup[SkippedException],
):
logger.debug(f"Handler {handler} skipped")
with catch({SkippedException: _handle_skipped}):
await handler(
matcher=self,
bot=bot,
event=event,
state=self.state,
stack=stack,
dependency_cache=dependency_cache,
)
finally:
logger.info(f"{self} running complete")
@ -846,10 +854,54 @@ class Matcher(metaclass=MatcherMeta):
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
):
try:
exc: Optional[Union[FinishedException, RejectedException, PausedException]] = (
None
)
def _handle_special_exception(
exc_group: BaseExceptionGroup[
Union[FinishedException, RejectedException, PausedException]
]
):
nonlocal exc
excs = list(flatten_exception_group(exc_group))
if len(excs) > 1:
logger.warning(
"Multiple session control exceptions occurred. "
"NoneBot will choose the proper one."
)
finished_exc = next(
(e for e in excs if isinstance(e, FinishedException)),
None,
)
rejected_exc = next(
(e for e in excs if isinstance(e, RejectedException)),
None,
)
paused_exc = next(
(e for e in excs if isinstance(e, PausedException)),
None,
)
exc = finished_exc or rejected_exc or paused_exc
elif isinstance(
excs[0], (FinishedException, RejectedException, PausedException)
):
exc = excs[0]
with catch(
{
(
FinishedException,
RejectedException,
PausedException,
): _handle_special_exception
}
):
await self.simple_run(bot, event, state, stack, dependency_cache)
except RejectedException:
if isinstance(exc, FinishedException):
pass
elif isinstance(exc, RejectedException):
await self.resolve_reject()
type_ = await self.update_type(bot, event, stack, dependency_cache)
permission = await self.update_permission(
@ -870,7 +922,7 @@ class Matcher(metaclass=MatcherMeta):
default_type_updater=self.__class__._default_type_updater,
default_permission_updater=self.__class__._default_permission_updater,
)
except PausedException:
elif isinstance(exc, PausedException):
type_ = await self.update_type(bot, event, stack, dependency_cache)
permission = await self.update_permission(
bot, event, stack, dependency_cache
@ -890,5 +942,3 @@ class Matcher(metaclass=MatcherMeta):
default_type_updater=self.__class__._default_type_updater,
default_permission_updater=self.__class__._default_permission_updater,
)
except FinishedException:
pass

View File

@ -1,5 +1,5 @@
import asyncio
import inspect
from enum import Enum
from typing_extensions import Self, get_args, override, get_origin
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from typing import (
@ -13,8 +13,11 @@ from typing import (
cast,
)
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from pydantic.fields import FieldInfo as PydanticFieldInfo
from nonebot.exception import SkippedException
from nonebot.dependencies import Param, Dependent
from nonebot.dependencies.utils import check_field_type
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
@ -93,6 +96,75 @@ def Depends(
return DependsInner(dependency, use_cache=use_cache, validate=validate)
class CacheState(str, Enum):
"""子依赖缓存状态"""
PENDING = "PENDING"
FINISHED = "FINISHED"
class DependencyCache:
"""子依赖结果缓存。
用于缓存子依赖的结果,以避免重复计算。
"""
def __init__(self):
self._state = CacheState.PENDING
self._result: Any = None
self._exception: Optional[BaseException] = None
self._waiter = anyio.Event()
def result(self) -> Any:
"""获取子依赖结果"""
if self._state != CacheState.FINISHED:
raise RuntimeError("Result is not ready")
if self._exception is not None:
raise self._exception
return self._result
def exception(self) -> Optional[BaseException]:
"""获取子依赖异常"""
if self._state != CacheState.FINISHED:
raise RuntimeError("Result is not ready")
return self._exception
def set_result(self, result: Any) -> None:
"""设置子依赖结果"""
if self._state != CacheState.PENDING:
raise RuntimeError(f"Cache state invalid: {self._state}")
self._result = result
self._state = CacheState.FINISHED
self._waiter.set()
def set_exception(self, exception: BaseException) -> None:
"""设置子依赖异常"""
if self._state != CacheState.PENDING:
raise RuntimeError(f"Cache state invalid: {self._state}")
self._exception = exception
self._state = CacheState.FINISHED
self._waiter.set()
async def wait(self):
"""等待子依赖结果"""
await self._waiter.wait()
if self._state != CacheState.FINISHED:
raise RuntimeError("Invalid cache state")
if self._exception is not None:
raise self._exception
return self._result
class DependParam(Param):
"""子依赖注入参数。
@ -194,17 +266,27 @@ class DependParam(Param):
call = cast(Callable[..., Any], sub_dependent.call)
# solve sub dependency with current cache
sub_values = await sub_dependent.solve(
stack=stack,
dependency_cache=dependency_cache,
**kwargs,
)
exc: Optional[BaseExceptionGroup[SkippedException]] = None
def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
nonlocal exc
exc = exc_group
with catch({SkippedException: _handle_skipped}):
sub_values = await sub_dependent.solve(
stack=stack,
dependency_cache=dependency_cache,
**kwargs,
)
if exc is not None:
raise exc
# run dependency function
task: asyncio.Task[Any]
if use_cache and call in dependency_cache:
return await dependency_cache[call]
elif is_gen_callable(call) or is_async_gen_callable(call):
return await dependency_cache[call].wait()
if is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance(
stack, AsyncExitStack
), "Generator dependency should be called in context"
@ -212,17 +294,21 @@ class DependParam(Param):
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
else:
cm = asynccontextmanager(call)(**sub_values)
task = asyncio.create_task(stack.enter_async_context(cm))
dependency_cache[call] = task
return await task
target = stack.enter_async_context(cm)
elif is_coroutine_callable(call):
task = asyncio.create_task(call(**sub_values))
dependency_cache[call] = task
return await task
target = call(**sub_values)
else:
task = asyncio.create_task(run_sync(call)(**sub_values))
dependency_cache[call] = task
return await task
target = run_sync(call)(**sub_values)
dependency_cache[call] = cache = DependencyCache()
try:
result = await target
cache.set_result(result)
return result
except BaseException as e:
cache.set_exception(e)
raise
@override
async def _check(self, **kwargs: Any) -> None:

View File

@ -1,8 +1,9 @@
import asyncio
from typing_extensions import Self
from contextlib import AsyncExitStack
from typing import Union, ClassVar, NoReturn, Optional
import anyio
from nonebot.dependencies import Dependent
from nonebot.utils import run_coro_with_catch
from nonebot.exception import SkippedException
@ -70,22 +71,26 @@ class Permission:
"""
if not self.checkers:
return True
results = await asyncio.gather(
*(
run_coro_with_catch(
checker(
bot=bot,
event=event,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
False,
)
for checker in self.checkers
),
)
return any(results)
result = False
async def _run_checker(checker: Dependent[bool]) -> None:
nonlocal result
# calculate the result first to avoid data racing
is_passed = await run_coro_with_catch(
checker(
bot=bot, event=event, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
False,
)
result |= is_passed
async with anyio.create_task_group() as tg:
for checker in self.checkers:
tg.start_soon(_run_checker, checker)
return result
def __and__(self, other: object) -> NoReturn:
raise RuntimeError("And operation between Permissions is not allowed.")

View File

@ -1,7 +1,9 @@
import asyncio
from contextlib import AsyncExitStack
from typing import Union, ClassVar, NoReturn, Optional
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.typing import T_State, T_RuleChecker, T_DependencyCache
@ -71,22 +73,33 @@ class Rule:
"""
if not self.checkers:
return True
try:
results = await asyncio.gather(
*(
checker(
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
)
for checker in self.checkers
)
result = True
def _handle_skipped_exception(
exc_group: BaseExceptionGroup[SkippedException],
) -> None:
nonlocal result
result = False
async def _run_checker(checker: Dependent[bool]) -> None:
nonlocal result
# calculate the result first to avoid data racing
is_passed = await checker(
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
)
except SkippedException:
return False
return all(results)
result &= is_passed
with catch({SkippedException: _handle_skipped_exception}):
async with anyio.create_task_group() as tg:
for checker in self.checkers:
tg.start_soon(_run_checker, checker)
return result
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
if other is None:

View File

@ -9,23 +9,30 @@ FrontMatter:
description: nonebot.message 模块
"""
import asyncio
import contextlib
from datetime import datetime
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.rule import TrieRule
from nonebot.dependencies import Dependent
from nonebot.matcher import Matcher, matchers
from nonebot.utils import escape_tag, run_coro_with_catch
from nonebot.exception import (
NoLogException,
StopPropagation,
IgnoredException,
SkippedException,
)
from nonebot.utils import (
escape_tag,
run_coro_with_catch,
run_coro_with_shield,
flatten_exception_group,
)
from nonebot.typing import (
T_State,
T_DependencyCache,
@ -125,6 +132,21 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
return func
def _handle_ignored_exception(msg: str) -> Callable[[BaseExceptionGroup], None]:
def _handle(exc_group: BaseExceptionGroup[IgnoredException]) -> None:
logger.opt(colors=True).info(msg)
return _handle
def _handle_exception(msg: str) -> Callable[[BaseExceptionGroup], None]:
def _handle(exc_group: BaseExceptionGroup[Exception]) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(msg)
return _handle
async def _apply_event_preprocessors(
bot: "Bot",
event: "Event",
@ -152,10 +174,21 @@ async def _apply_event_preprocessors(
if show_log:
logger.debug("Running PreProcessors...")
try:
await asyncio.gather(
*(
run_coro_with_catch(
with catch(
{
IgnoredException: _handle_ignored_exception(
f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>"
),
Exception: _handle_exception(
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
"Event ignored!</bg #f8bbd0></r>"
),
}
):
async with anyio.create_task_group() as tg:
for proc in _event_preprocessors:
tg.start_soon(
run_coro_with_catch,
proc(
bot=bot,
event=event,
@ -165,22 +198,10 @@ async def _apply_event_preprocessors(
),
(SkippedException,),
)
for proc in _event_preprocessors
)
)
except IgnoredException:
logger.opt(colors=True).info(
f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>"
)
return False
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
"Event ignored!</bg #f8bbd0></r>"
)
return False
return True
return True
return False
async def _apply_event_postprocessors(
@ -207,10 +228,17 @@ async def _apply_event_postprocessors(
if show_log:
logger.debug("Running PostProcessors...")
try:
await asyncio.gather(
*(
run_coro_with_catch(
with catch(
{
Exception: _handle_exception(
"<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>"
)
}
):
async with anyio.create_task_group() as tg:
for proc in _event_postprocessors:
tg.start_soon(
run_coro_with_catch,
proc(
bot=bot,
event=event,
@ -220,13 +248,6 @@ async def _apply_event_postprocessors(
),
(SkippedException,),
)
for proc in _event_postprocessors
)
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>"
)
async def _apply_run_preprocessors(
@ -254,35 +275,38 @@ async def _apply_run_preprocessors(
return True
# ensure matcher function can be correctly called
with matcher.ensure_context(bot, event):
try:
await asyncio.gather(
*(
run_coro_with_catch(
proc(
matcher=matcher,
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
)
for proc in _run_preprocessors
with (
matcher.ensure_context(bot, event),
catch(
{
IgnoredException: _handle_ignored_exception(
f"{matcher} running is <b>cancelled</b>"
),
Exception: _handle_exception(
"<r><bg #f8bbd0>Error when running RunPreProcessors. "
"Running cancelled!</bg #f8bbd0></r>"
),
}
),
):
async with anyio.create_task_group() as tg:
for proc in _run_preprocessors:
tg.start_soon(
run_coro_with_catch,
proc(
matcher=matcher,
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
)
)
except IgnoredException:
logger.opt(colors=True).info(f"{matcher} running is <b>cancelled</b>")
return False
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running RunPreProcessors. "
"Running cancelled!</bg #f8bbd0></r>"
)
return False
return True
return True
return False
async def _apply_run_postprocessors(
@ -306,29 +330,32 @@ async def _apply_run_postprocessors(
if not _run_postprocessors:
return
with matcher.ensure_context(bot, event):
try:
await asyncio.gather(
*(
run_coro_with_catch(
proc(
matcher=matcher,
exception=exception,
bot=bot,
event=event,
state=matcher.state,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
)
for proc in _run_postprocessors
with (
matcher.ensure_context(bot, event),
catch(
{
Exception: _handle_exception(
"<r><bg #f8bbd0>Error when running RunPostProcessors"
"</bg #f8bbd0></r>"
)
}
),
):
async with anyio.create_task_group() as tg:
for proc in _run_postprocessors:
tg.start_soon(
run_coro_with_catch,
proc(
matcher=matcher,
exception=exception,
bot=bot,
event=event,
state=matcher.state,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
)
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running RunPostProcessors</bg #f8bbd0></r>"
)
async def _check_matcher(
@ -425,8 +452,9 @@ async def _run_matcher(
exception = None
logger.debug(f"Running {matcher}")
try:
logger.debug(f"Running {matcher}")
await matcher.run(bot, event, state, stack, dependency_cache)
except Exception as e:
logger.opt(colors=True, exception=e).error(
@ -494,8 +522,7 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
用法:
```python
import asyncio
asyncio.create_task(handle_event(bot, event))
driver.task_group.start_soon(handle_event, bot, event)
```
"""
show_log = True
@ -530,6 +557,13 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
)
break_flag = False
def _handle_stop_propagation(exc_group: BaseExceptionGroup) -> None:
nonlocal break_flag
break_flag = True
logger.debug("Stop event propagation")
# iterate through all priority until stop propagation
for priority in sorted(matchers.keys()):
if break_flag:
@ -538,23 +572,30 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
if show_log:
logger.debug(f"Checking for matchers in priority {priority}...")
pending_tasks = [
check_and_run_matcher(
matcher, bot, event, state.copy(), stack, dependency_cache
)
for matcher in matchers[priority]
]
results = await asyncio.gather(*pending_tasks, return_exceptions=True)
for result in results:
if not isinstance(result, Exception):
continue
if isinstance(result, StopPropagation):
break_flag = True
logger.debug("Stop event propagation")
else:
logger.opt(colors=True, exception=result).error(
if not (priority_matchers := matchers[priority]):
continue
with catch(
{
StopPropagation: _handle_stop_propagation,
Exception: _handle_exception(
"<r><bg #f8bbd0>Error when checking Matcher.</bg #f8bbd0></r>"
)
),
}
):
async with anyio.create_task_group() as tg:
for matcher in priority_matchers:
tg.start_soon(
run_coro_with_shield,
check_and_run_matcher(
matcher,
bot,
event,
state.copy(),
stack,
dependency_cache,
),
)
if show_log:
logger.debug("Checking for matchers completed")

View File

@ -22,7 +22,7 @@ from . import _managers, get_plugin, _module_name_to_plugin_id
try: # pragma: py-gte-311
import tomllib # pyright: ignore[reportMissingImports]
except ModuleNotFoundError: # pragma: py-lt-311
import tomli as tomllib
import tomli as tomllib # pyright: ignore[reportMissingImports]
def load_plugin(module_path: Union[str, Path]) -> Optional[Plugin]:

View File

@ -21,10 +21,9 @@ from typing import TYPE_CHECKING, TypeVar
from typing_extensions import ParamSpec, TypeAlias, get_args, override, get_origin
if TYPE_CHECKING:
from asyncio import Task
from nonebot.adapters import Bot
from nonebot.permission import Permission
from nonebot.internal.params import DependencyCache
T = TypeVar("T")
P = ParamSpec("P")
@ -258,5 +257,5 @@ T_PermissionUpdater: TypeAlias = _DependentCallable["Permission"]
- MatcherParam: Matcher 对象
- DefaultParam: 带有默认值的参数
"""
T_DependencyCache: TypeAlias = dict[_DependentCallable[t.Any], "Task[t.Any]"]
T_DependencyCache: TypeAlias = dict[_DependentCallable[t.Any], "DependencyCache"]
"""依赖缓存, 用于存储依赖函数的返回值"""

View File

@ -9,21 +9,22 @@ FrontMatter:
import re
import json
import asyncio
import inspect
import importlib
import contextlib
import dataclasses
from pathlib import Path
from collections import deque
from contextvars import copy_context
from functools import wraps, partial
from contextlib import AbstractContextManager, asynccontextmanager
from typing_extensions import ParamSpec, get_args, override, get_origin
from collections.abc import Mapping, Sequence, Coroutine, AsyncGenerator
from typing import Any, Union, Generic, TypeVar, Callable, Optional, overload
from collections.abc import Mapping, Sequence, Coroutine, Generator, AsyncGenerator
import anyio
import anyio.to_thread
from pydantic import BaseModel
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.typing import (
@ -39,6 +40,7 @@ R = TypeVar("R")
T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")
E = TypeVar("E", bound=BaseException)
def escape_tag(s: str) -> str:
@ -178,11 +180,9 @@ def run_sync(call: Callable[P, R]) -> Callable[P, Coroutine[None, None, R]]:
@wraps(call)
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
loop = asyncio.get_running_loop()
pfunc = partial(call, *args, **kwargs)
context = copy_context()
result = await loop.run_in_executor(None, partial(context.run, pfunc))
return result
return await anyio.to_thread.run_sync(
partial(call, *args, **kwargs), abandon_on_cancel=True
)
return _wrapper
@ -234,10 +234,34 @@ async def run_coro_with_catch(
协程的返回值或发生异常时的指定值
"""
try:
with catch({exc: lambda exc_group: None}):
return await coro
except exc:
return return_on_err
return return_on_err
async def run_coro_with_shield(coro: Coroutine[Any, Any, T]) -> T:
"""运行协程并在取消时屏蔽取消异常。
参数:
coro: 要运行的协程
返回:
协程的返回值
"""
with anyio.CancelScope(shield=True):
return await coro
def flatten_exception_group(
exc_group: BaseExceptionGroup[E],
) -> Generator[E, None, None]:
for exc in exc_group.exceptions:
if isinstance(exc, BaseExceptionGroup):
yield from flatten_exception_group(exc)
else:
yield exc
def get_name(obj: Any) -> str: