Feature: 迁移至结构化并发框架 AnyIO (#3053)

This commit is contained in:
Ju4tCode
2024-10-26 15:36:01 +08:00
committed by GitHub
parent bd9befbb55
commit ff21ceb946
39 changed files with 5422 additions and 4080 deletions

View File

@ -1,17 +1,20 @@
import abc
import asyncio
from types import TracebackType
from collections.abc import AsyncGenerator
from typing_extensions import Self, TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional
from anyio.abc import TaskGroup
from anyio import CancelScope, create_task_group
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.config import Env, Config
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.utils import escape_tag, run_coro_with_catch
from nonebot.internal.params import BotParam, DependParam, DefaultParam
from nonebot.utils import escape_tag, run_coro_with_catch, flatten_exception_group
from nonebot.typing import (
T_DependencyCache,
T_BotConnectionHook,
@ -61,7 +64,6 @@ class Driver(abc.ABC):
self.config: Config = config
"""全局配置对象"""
self._bots: dict[str, "Bot"] = {}
self._bot_tasks: set[asyncio.Task] = set()
self._lifespan = Lifespan()
def __repr__(self) -> str:
@ -75,6 +77,10 @@ class Driver(abc.ABC):
"""获取当前所有已连接的 Bot"""
return self._bots
@property
def task_group(self) -> TaskGroup:
return self._lifespan.task_group
def register_adapter(self, adapter: type["Adapter"], **kwargs) -> None:
"""注册一个协议适配器
@ -112,8 +118,6 @@ class Driver(abc.ABC):
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
)
self.on_shutdown(self._cleanup)
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)
@ -154,66 +158,57 @@ class Driver(abc.ABC):
raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
self._bots[bot.self_id] = bot
def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketConnection hook:"
"</bg #f8bbd0></r>"
)
async def _run_hook(bot: "Bot") -> None:
dependency_cache: T_DependencyCache = {}
async with AsyncExitStack() as stack:
if coros := [
run_coro_with_catch(
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
(SkippedException,),
)
for hook in self._bot_connection_hook
]:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketConnection hook. "
"Running cancelled!"
"</bg #f8bbd0></r>"
with CancelScope(shield=True), catch({Exception: handle_exception}):
async with AsyncExitStack() as stack, create_task_group() as tg:
for hook in self._bot_connection_hook:
tg.start_soon(
run_coro_with_catch,
hook(
bot=bot, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
)
task = asyncio.create_task(_run_hook(bot))
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
self.task_group.start_soon(_run_hook, bot)
def _bot_disconnect(self, bot: "Bot") -> None:
"""在连接断开后,调用该函数来注销 bot 对象"""
if bot.self_id in self._bots:
del self._bots[bot.self_id]
def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketDisConnection hook:"
"</bg #f8bbd0></r>"
)
async def _run_hook(bot: "Bot") -> None:
dependency_cache: T_DependencyCache = {}
async with AsyncExitStack() as stack:
if coros := [
run_coro_with_catch(
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
(SkippedException,),
)
for hook in self._bot_disconnection_hook
]:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketDisConnection hook. "
"Running cancelled!"
"</bg #f8bbd0></r>"
# shield cancellation to ensure bot disconnect hooks are always run
with CancelScope(shield=True), catch({Exception: handle_exception}):
async with create_task_group() as tg, AsyncExitStack() as stack:
for hook in self._bot_disconnection_hook:
tg.start_soon(
run_coro_with_catch,
hook(
bot=bot, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
)
task = asyncio.create_task(_run_hook(bot))
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
async def _cleanup(self) -> None:
"""清理驱动器资源"""
if self._bot_tasks:
logger.opt(colors=True).debug(
"<y>Waiting for running bot connection hooks...</y>"
)
await asyncio.gather(*self._bot_tasks, return_exceptions=True)
self.task_group.start_soon(_run_hook, bot)
class Mixin(abc.ABC):