Feature: 迁移至结构化并发框架 AnyIO (#3053)

This commit is contained in:
Ju4tCode
2024-10-26 15:36:01 +08:00
committed by GitHub
parent bd9befbb55
commit ff21ceb946
39 changed files with 5422 additions and 4080 deletions

View File

@ -1,6 +1,11 @@
from collections.abc import Awaitable
from types import TracebackType
from typing_extensions import TypeAlias
from typing import Any, Union, Callable, cast
from collections.abc import Iterable, Awaitable
from typing import Any, Union, Callable, Optional, cast
import anyio
from anyio.abc import TaskGroup
from exceptiongroup import suppress
from nonebot.utils import run_sync, is_coroutine_callable
@ -11,10 +16,24 @@ LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
class Lifespan:
def __init__(self) -> None:
self._task_group: Optional[TaskGroup] = None
self._startup_funcs: list[LIFESPAN_FUNC] = []
self._ready_funcs: list[LIFESPAN_FUNC] = []
self._shutdown_funcs: list[LIFESPAN_FUNC] = []
@property
def task_group(self) -> TaskGroup:
if self._task_group is None:
raise RuntimeError("Lifespan not started")
return self._task_group
@task_group.setter
def task_group(self, task_group: TaskGroup) -> None:
if self._task_group is not None:
raise RuntimeError("Lifespan already started")
self._task_group = task_group
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._startup_funcs.append(func)
return func
@ -29,7 +48,7 @@ class Lifespan:
@staticmethod
async def _run_lifespan_func(
funcs: list[LIFESPAN_FUNC],
funcs: Iterable[LIFESPAN_FUNC],
) -> None:
for func in funcs:
if is_coroutine_callable(func):
@ -38,18 +57,44 @@ class Lifespan:
await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()
async def startup(self) -> None:
# create background task group
self.task_group = anyio.create_task_group()
await self.task_group.__aenter__()
# run startup funcs
if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs)
# run ready funcs
if self._ready_funcs:
await self._run_lifespan_func(self._ready_funcs)
async def shutdown(self) -> None:
async def shutdown(
self,
*,
exc_type: Optional[type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional[TracebackType] = None,
) -> None:
if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs)
# reverse shutdown funcs to ensure stack order
await self._run_lifespan_func(reversed(self._shutdown_funcs))
# shutdown background task group
self.task_group.cancel_scope.cancel()
with suppress(Exception):
await self.task_group.__aexit__(exc_type, exc_val, exc_tb)
self._task_group = None
async def __aenter__(self) -> None:
await self.startup()
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.shutdown()
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.shutdown(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)

View File

@ -1,17 +1,20 @@
import abc
import asyncio
from types import TracebackType
from collections.abc import AsyncGenerator
from typing_extensions import Self, TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional
from anyio.abc import TaskGroup
from anyio import CancelScope, create_task_group
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.config import Env, Config
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.utils import escape_tag, run_coro_with_catch
from nonebot.internal.params import BotParam, DependParam, DefaultParam
from nonebot.utils import escape_tag, run_coro_with_catch, flatten_exception_group
from nonebot.typing import (
T_DependencyCache,
T_BotConnectionHook,
@ -61,7 +64,6 @@ class Driver(abc.ABC):
self.config: Config = config
"""全局配置对象"""
self._bots: dict[str, "Bot"] = {}
self._bot_tasks: set[asyncio.Task] = set()
self._lifespan = Lifespan()
def __repr__(self) -> str:
@ -75,6 +77,10 @@ class Driver(abc.ABC):
"""获取当前所有已连接的 Bot"""
return self._bots
@property
def task_group(self) -> TaskGroup:
return self._lifespan.task_group
def register_adapter(self, adapter: type["Adapter"], **kwargs) -> None:
"""注册一个协议适配器
@ -112,8 +118,6 @@ class Driver(abc.ABC):
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
)
self.on_shutdown(self._cleanup)
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)
@ -154,66 +158,57 @@ class Driver(abc.ABC):
raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
self._bots[bot.self_id] = bot
def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketConnection hook:"
"</bg #f8bbd0></r>"
)
async def _run_hook(bot: "Bot") -> None:
dependency_cache: T_DependencyCache = {}
async with AsyncExitStack() as stack:
if coros := [
run_coro_with_catch(
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
(SkippedException,),
)
for hook in self._bot_connection_hook
]:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketConnection hook. "
"Running cancelled!"
"</bg #f8bbd0></r>"
with CancelScope(shield=True), catch({Exception: handle_exception}):
async with AsyncExitStack() as stack, create_task_group() as tg:
for hook in self._bot_connection_hook:
tg.start_soon(
run_coro_with_catch,
hook(
bot=bot, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
)
task = asyncio.create_task(_run_hook(bot))
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
self.task_group.start_soon(_run_hook, bot)
def _bot_disconnect(self, bot: "Bot") -> None:
"""在连接断开后,调用该函数来注销 bot 对象"""
if bot.self_id in self._bots:
del self._bots[bot.self_id]
def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketDisConnection hook:"
"</bg #f8bbd0></r>"
)
async def _run_hook(bot: "Bot") -> None:
dependency_cache: T_DependencyCache = {}
async with AsyncExitStack() as stack:
if coros := [
run_coro_with_catch(
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
(SkippedException,),
)
for hook in self._bot_disconnection_hook
]:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketDisConnection hook. "
"Running cancelled!"
"</bg #f8bbd0></r>"
# shield cancellation to ensure bot disconnect hooks are always run
with CancelScope(shield=True), catch({Exception: handle_exception}):
async with create_task_group() as tg, AsyncExitStack() as stack:
for hook in self._bot_disconnection_hook:
tg.start_soon(
run_coro_with_catch,
hook(
bot=bot, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
)
task = asyncio.create_task(_run_hook(bot))
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
async def _cleanup(self) -> None:
"""清理驱动器资源"""
if self._bot_tasks:
logger.opt(colors=True).debug(
"<y>Waiting for running bot connection hooks...</y>"
)
await asyncio.gather(*self._bot_tasks, return_exceptions=True)
self.task_group.start_soon(_run_hook, bot)
class Mixin(abc.ABC):