Feature: 迁移至结构化并发框架 AnyIO (#3053)

This commit is contained in:
Ju4tCode
2024-10-26 15:36:01 +08:00
committed by GitHub
parent bd9befbb55
commit ff21ceb946
39 changed files with 5422 additions and 4080 deletions

View File

@ -1,11 +1,14 @@
import abc
import asyncio
from functools import partial
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional, Protocol
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.config import Config
from nonebot.exception import MockApiException
from nonebot.utils import flatten_exception_group
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook
if TYPE_CHECKING:
@ -76,47 +79,98 @@ class Bot(abc.ABC):
skip_calling_api: bool = False
exception: Optional[Exception] = None
if coros := [hook(self, api, data) for hook in self._calling_api_hook]:
try:
logger.debug("Running CallingAPI hooks...")
await asyncio.gather(*coros)
except MockApiException as e:
if self._calling_api_hook:
logger.debug("Running CallingAPI hooks...")
def _handle_mock_api_exception(
exc_group: BaseExceptionGroup[MockApiException],
) -> None:
nonlocal skip_calling_api, result
excs = [
exc
for exc in flatten_exception_group(exc_group)
if isinstance(exc, MockApiException)
]
if not excs:
return
elif len(excs) > 1:
logger.warning(
"Multiple hooks want to mock API result. Use the first one."
)
skip_calling_api = True
result = e.result
result = excs[0].result
logger.debug(
f"Calling API {api} is cancelled. Return {result} instead."
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
f"Calling API {api} is cancelled. Return {result!r} instead."
)
def _handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
with catch(
{
MockApiException: _handle_mock_api_exception,
Exception: _handle_exception,
}
):
async with anyio.create_task_group() as tg:
for hook in self._calling_api_hook:
tg.start_soon(hook, self, api, data)
if not skip_calling_api:
try:
result = await self.adapter._call_api(self, api, **data)
except Exception as e:
exception = e
if coros := [
hook(self, exception, api, data, result) for hook in self._called_api_hook
]:
try:
logger.debug("Running CalledAPI hooks...")
await asyncio.gather(*coros)
except MockApiException as e:
# mock api result
result = e.result
# ignore exception
if self._called_api_hook:
logger.debug("Running CalledAPI hooks...")
def _handle_mock_api_exception(
exc_group: BaseExceptionGroup[MockApiException],
) -> None:
nonlocal result, exception
excs = [
exc
for exc in flatten_exception_group(exc_group)
if isinstance(exc, MockApiException)
]
if not excs:
return
elif len(excs) > 1:
logger.warning(
"Multiple hooks want to mock API result. Use the first one."
)
result = excs[0].result
exception = None
logger.debug(
f"Calling API {api} result is mocked. Return {result} instead."
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CalledAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
def _handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error when running CalledAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
with catch(
{
MockApiException: _handle_mock_api_exception,
Exception: _handle_exception,
}
):
async with anyio.create_task_group() as tg:
for hook in self._called_api_hook:
tg.start_soon(hook, self, exception, api, data, result)
if exception:
raise exception

View File

@ -1,6 +1,11 @@
from collections.abc import Awaitable
from types import TracebackType
from typing_extensions import TypeAlias
from typing import Any, Union, Callable, cast
from collections.abc import Iterable, Awaitable
from typing import Any, Union, Callable, Optional, cast
import anyio
from anyio.abc import TaskGroup
from exceptiongroup import suppress
from nonebot.utils import run_sync, is_coroutine_callable
@ -11,10 +16,24 @@ LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
class Lifespan:
def __init__(self) -> None:
self._task_group: Optional[TaskGroup] = None
self._startup_funcs: list[LIFESPAN_FUNC] = []
self._ready_funcs: list[LIFESPAN_FUNC] = []
self._shutdown_funcs: list[LIFESPAN_FUNC] = []
@property
def task_group(self) -> TaskGroup:
if self._task_group is None:
raise RuntimeError("Lifespan not started")
return self._task_group
@task_group.setter
def task_group(self, task_group: TaskGroup) -> None:
if self._task_group is not None:
raise RuntimeError("Lifespan already started")
self._task_group = task_group
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._startup_funcs.append(func)
return func
@ -29,7 +48,7 @@ class Lifespan:
@staticmethod
async def _run_lifespan_func(
funcs: list[LIFESPAN_FUNC],
funcs: Iterable[LIFESPAN_FUNC],
) -> None:
for func in funcs:
if is_coroutine_callable(func):
@ -38,18 +57,44 @@ class Lifespan:
await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()
async def startup(self) -> None:
# create background task group
self.task_group = anyio.create_task_group()
await self.task_group.__aenter__()
# run startup funcs
if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs)
# run ready funcs
if self._ready_funcs:
await self._run_lifespan_func(self._ready_funcs)
async def shutdown(self) -> None:
async def shutdown(
self,
*,
exc_type: Optional[type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional[TracebackType] = None,
) -> None:
if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs)
# reverse shutdown funcs to ensure stack order
await self._run_lifespan_func(reversed(self._shutdown_funcs))
# shutdown background task group
self.task_group.cancel_scope.cancel()
with suppress(Exception):
await self.task_group.__aexit__(exc_type, exc_val, exc_tb)
self._task_group = None
async def __aenter__(self) -> None:
await self.startup()
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.shutdown()
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.shutdown(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)

View File

@ -1,17 +1,20 @@
import abc
import asyncio
from types import TracebackType
from collections.abc import AsyncGenerator
from typing_extensions import Self, TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional
from anyio.abc import TaskGroup
from anyio import CancelScope, create_task_group
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.config import Env, Config
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.utils import escape_tag, run_coro_with_catch
from nonebot.internal.params import BotParam, DependParam, DefaultParam
from nonebot.utils import escape_tag, run_coro_with_catch, flatten_exception_group
from nonebot.typing import (
T_DependencyCache,
T_BotConnectionHook,
@ -61,7 +64,6 @@ class Driver(abc.ABC):
self.config: Config = config
"""全局配置对象"""
self._bots: dict[str, "Bot"] = {}
self._bot_tasks: set[asyncio.Task] = set()
self._lifespan = Lifespan()
def __repr__(self) -> str:
@ -75,6 +77,10 @@ class Driver(abc.ABC):
"""获取当前所有已连接的 Bot"""
return self._bots
@property
def task_group(self) -> TaskGroup:
return self._lifespan.task_group
def register_adapter(self, adapter: type["Adapter"], **kwargs) -> None:
"""注册一个协议适配器
@ -112,8 +118,6 @@ class Driver(abc.ABC):
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
)
self.on_shutdown(self._cleanup)
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)
@ -154,66 +158,57 @@ class Driver(abc.ABC):
raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
self._bots[bot.self_id] = bot
def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketConnection hook:"
"</bg #f8bbd0></r>"
)
async def _run_hook(bot: "Bot") -> None:
dependency_cache: T_DependencyCache = {}
async with AsyncExitStack() as stack:
if coros := [
run_coro_with_catch(
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
(SkippedException,),
)
for hook in self._bot_connection_hook
]:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketConnection hook. "
"Running cancelled!"
"</bg #f8bbd0></r>"
with CancelScope(shield=True), catch({Exception: handle_exception}):
async with AsyncExitStack() as stack, create_task_group() as tg:
for hook in self._bot_connection_hook:
tg.start_soon(
run_coro_with_catch,
hook(
bot=bot, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
)
task = asyncio.create_task(_run_hook(bot))
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
self.task_group.start_soon(_run_hook, bot)
def _bot_disconnect(self, bot: "Bot") -> None:
"""在连接断开后,调用该函数来注销 bot 对象"""
if bot.self_id in self._bots:
del self._bots[bot.self_id]
def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketDisConnection hook:"
"</bg #f8bbd0></r>"
)
async def _run_hook(bot: "Bot") -> None:
dependency_cache: T_DependencyCache = {}
async with AsyncExitStack() as stack:
if coros := [
run_coro_with_catch(
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
(SkippedException,),
)
for hook in self._bot_disconnection_hook
]:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketDisConnection hook. "
"Running cancelled!"
"</bg #f8bbd0></r>"
# shield cancellation to ensure bot disconnect hooks are always run
with CancelScope(shield=True), catch({Exception: handle_exception}):
async with create_task_group() as tg, AsyncExitStack() as stack:
for hook in self._bot_disconnection_hook:
tg.start_soon(
run_coro_with_catch,
hook(
bot=bot, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
)
task = asyncio.create_task(_run_hook(bot))
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
async def _cleanup(self) -> None:
"""清理驱动器资源"""
if self._bot_tasks:
logger.opt(colors=True).debug(
"<y>Waiting for running bot connection hooks...</y>"
)
await asyncio.gather(*self._bot_tasks, return_exceptions=True)
self.task_group.start_soon(_run_hook, bot)
class Mixin(abc.ABC):

View File

@ -22,11 +22,13 @@ from typing import ( # noqa: UP035
overload,
)
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.internal.rule import Rule
from nonebot.utils import classproperty
from nonebot.dependencies import Param, Dependent
from nonebot.internal.permission import User, Permission
from nonebot.utils import classproperty, flatten_exception_group
from nonebot.internal.adapter import (
Bot,
Event,
@ -812,28 +814,34 @@ class Matcher(metaclass=MatcherMeta):
f"bot={bot}, event={event!r}, state={state!r}"
)
def _handle_stop_propagation(exc_group: BaseExceptionGroup[StopPropagation]):
self.block = True
with self.ensure_context(bot, event):
try:
# Refresh preprocess state
self.state.update(state)
with catch({StopPropagation: _handle_stop_propagation}):
# Refresh preprocess state
self.state.update(state)
while self.remain_handlers:
handler = self.remain_handlers.pop(0)
current_handler.set(handler)
logger.debug(f"Running handler {handler}")
try:
await handler(
matcher=self,
bot=bot,
event=event,
state=self.state,
stack=stack,
dependency_cache=dependency_cache,
)
except SkippedException:
logger.debug(f"Handler {handler} skipped")
except StopPropagation:
self.block = True
while self.remain_handlers:
handler = self.remain_handlers.pop(0)
current_handler.set(handler)
logger.debug(f"Running handler {handler}")
def _handle_skipped(
exc_group: BaseExceptionGroup[SkippedException],
):
logger.debug(f"Handler {handler} skipped")
with catch({SkippedException: _handle_skipped}):
await handler(
matcher=self,
bot=bot,
event=event,
state=self.state,
stack=stack,
dependency_cache=dependency_cache,
)
finally:
logger.info(f"{self} running complete")
@ -846,10 +854,54 @@ class Matcher(metaclass=MatcherMeta):
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
):
try:
exc: Optional[Union[FinishedException, RejectedException, PausedException]] = (
None
)
def _handle_special_exception(
exc_group: BaseExceptionGroup[
Union[FinishedException, RejectedException, PausedException]
]
):
nonlocal exc
excs = list(flatten_exception_group(exc_group))
if len(excs) > 1:
logger.warning(
"Multiple session control exceptions occurred. "
"NoneBot will choose the proper one."
)
finished_exc = next(
(e for e in excs if isinstance(e, FinishedException)),
None,
)
rejected_exc = next(
(e for e in excs if isinstance(e, RejectedException)),
None,
)
paused_exc = next(
(e for e in excs if isinstance(e, PausedException)),
None,
)
exc = finished_exc or rejected_exc or paused_exc
elif isinstance(
excs[0], (FinishedException, RejectedException, PausedException)
):
exc = excs[0]
with catch(
{
(
FinishedException,
RejectedException,
PausedException,
): _handle_special_exception
}
):
await self.simple_run(bot, event, state, stack, dependency_cache)
except RejectedException:
if isinstance(exc, FinishedException):
pass
elif isinstance(exc, RejectedException):
await self.resolve_reject()
type_ = await self.update_type(bot, event, stack, dependency_cache)
permission = await self.update_permission(
@ -870,7 +922,7 @@ class Matcher(metaclass=MatcherMeta):
default_type_updater=self.__class__._default_type_updater,
default_permission_updater=self.__class__._default_permission_updater,
)
except PausedException:
elif isinstance(exc, PausedException):
type_ = await self.update_type(bot, event, stack, dependency_cache)
permission = await self.update_permission(
bot, event, stack, dependency_cache
@ -890,5 +942,3 @@ class Matcher(metaclass=MatcherMeta):
default_type_updater=self.__class__._default_type_updater,
default_permission_updater=self.__class__._default_permission_updater,
)
except FinishedException:
pass

View File

@ -1,5 +1,5 @@
import asyncio
import inspect
from enum import Enum
from typing_extensions import Self, get_args, override, get_origin
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from typing import (
@ -13,8 +13,11 @@ from typing import (
cast,
)
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from pydantic.fields import FieldInfo as PydanticFieldInfo
from nonebot.exception import SkippedException
from nonebot.dependencies import Param, Dependent
from nonebot.dependencies.utils import check_field_type
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
@ -93,6 +96,75 @@ def Depends(
return DependsInner(dependency, use_cache=use_cache, validate=validate)
class CacheState(str, Enum):
"""子依赖缓存状态"""
PENDING = "PENDING"
FINISHED = "FINISHED"
class DependencyCache:
"""子依赖结果缓存。
用于缓存子依赖的结果,以避免重复计算。
"""
def __init__(self):
self._state = CacheState.PENDING
self._result: Any = None
self._exception: Optional[BaseException] = None
self._waiter = anyio.Event()
def result(self) -> Any:
"""获取子依赖结果"""
if self._state != CacheState.FINISHED:
raise RuntimeError("Result is not ready")
if self._exception is not None:
raise self._exception
return self._result
def exception(self) -> Optional[BaseException]:
"""获取子依赖异常"""
if self._state != CacheState.FINISHED:
raise RuntimeError("Result is not ready")
return self._exception
def set_result(self, result: Any) -> None:
"""设置子依赖结果"""
if self._state != CacheState.PENDING:
raise RuntimeError(f"Cache state invalid: {self._state}")
self._result = result
self._state = CacheState.FINISHED
self._waiter.set()
def set_exception(self, exception: BaseException) -> None:
"""设置子依赖异常"""
if self._state != CacheState.PENDING:
raise RuntimeError(f"Cache state invalid: {self._state}")
self._exception = exception
self._state = CacheState.FINISHED
self._waiter.set()
async def wait(self):
"""等待子依赖结果"""
await self._waiter.wait()
if self._state != CacheState.FINISHED:
raise RuntimeError("Invalid cache state")
if self._exception is not None:
raise self._exception
return self._result
class DependParam(Param):
"""子依赖注入参数。
@ -194,17 +266,27 @@ class DependParam(Param):
call = cast(Callable[..., Any], sub_dependent.call)
# solve sub dependency with current cache
sub_values = await sub_dependent.solve(
stack=stack,
dependency_cache=dependency_cache,
**kwargs,
)
exc: Optional[BaseExceptionGroup[SkippedException]] = None
def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
nonlocal exc
exc = exc_group
with catch({SkippedException: _handle_skipped}):
sub_values = await sub_dependent.solve(
stack=stack,
dependency_cache=dependency_cache,
**kwargs,
)
if exc is not None:
raise exc
# run dependency function
task: asyncio.Task[Any]
if use_cache and call in dependency_cache:
return await dependency_cache[call]
elif is_gen_callable(call) or is_async_gen_callable(call):
return await dependency_cache[call].wait()
if is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance(
stack, AsyncExitStack
), "Generator dependency should be called in context"
@ -212,17 +294,21 @@ class DependParam(Param):
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
else:
cm = asynccontextmanager(call)(**sub_values)
task = asyncio.create_task(stack.enter_async_context(cm))
dependency_cache[call] = task
return await task
target = stack.enter_async_context(cm)
elif is_coroutine_callable(call):
task = asyncio.create_task(call(**sub_values))
dependency_cache[call] = task
return await task
target = call(**sub_values)
else:
task = asyncio.create_task(run_sync(call)(**sub_values))
dependency_cache[call] = task
return await task
target = run_sync(call)(**sub_values)
dependency_cache[call] = cache = DependencyCache()
try:
result = await target
cache.set_result(result)
return result
except BaseException as e:
cache.set_exception(e)
raise
@override
async def _check(self, **kwargs: Any) -> None:

View File

@ -1,8 +1,9 @@
import asyncio
from typing_extensions import Self
from contextlib import AsyncExitStack
from typing import Union, ClassVar, NoReturn, Optional
import anyio
from nonebot.dependencies import Dependent
from nonebot.utils import run_coro_with_catch
from nonebot.exception import SkippedException
@ -70,22 +71,26 @@ class Permission:
"""
if not self.checkers:
return True
results = await asyncio.gather(
*(
run_coro_with_catch(
checker(
bot=bot,
event=event,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
False,
)
for checker in self.checkers
),
)
return any(results)
result = False
async def _run_checker(checker: Dependent[bool]) -> None:
nonlocal result
# calculate the result first to avoid data racing
is_passed = await run_coro_with_catch(
checker(
bot=bot, event=event, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
False,
)
result |= is_passed
async with anyio.create_task_group() as tg:
for checker in self.checkers:
tg.start_soon(_run_checker, checker)
return result
def __and__(self, other: object) -> NoReturn:
raise RuntimeError("And operation between Permissions is not allowed.")

View File

@ -1,7 +1,9 @@
import asyncio
from contextlib import AsyncExitStack
from typing import Union, ClassVar, NoReturn, Optional
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.typing import T_State, T_RuleChecker, T_DependencyCache
@ -71,22 +73,33 @@ class Rule:
"""
if not self.checkers:
return True
try:
results = await asyncio.gather(
*(
checker(
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
)
for checker in self.checkers
)
result = True
def _handle_skipped_exception(
exc_group: BaseExceptionGroup[SkippedException],
) -> None:
nonlocal result
result = False
async def _run_checker(checker: Dependent[bool]) -> None:
nonlocal result
# calculate the result first to avoid data racing
is_passed = await checker(
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
)
except SkippedException:
return False
return all(results)
result &= is_passed
with catch({SkippedException: _handle_skipped_exception}):
async with anyio.create_task_group() as tg:
for checker in self.checkers:
tg.start_soon(_run_checker, checker)
return result
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
if other is None: