mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-09-06 20:16:47 +00:00
✨ Feature: 迁移至结构化并发框架 AnyIO (#3053)
This commit is contained in:
@ -1,6 +1,11 @@
|
||||
from collections.abc import Awaitable
|
||||
from types import TracebackType
|
||||
from typing_extensions import TypeAlias
|
||||
from typing import Any, Union, Callable, cast
|
||||
from collections.abc import Iterable, Awaitable
|
||||
from typing import Any, Union, Callable, Optional, cast
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
from exceptiongroup import suppress
|
||||
|
||||
from nonebot.utils import run_sync, is_coroutine_callable
|
||||
|
||||
@ -11,10 +16,24 @@ LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
|
||||
|
||||
class Lifespan:
|
||||
def __init__(self) -> None:
|
||||
self._task_group: Optional[TaskGroup] = None
|
||||
|
||||
self._startup_funcs: list[LIFESPAN_FUNC] = []
|
||||
self._ready_funcs: list[LIFESPAN_FUNC] = []
|
||||
self._shutdown_funcs: list[LIFESPAN_FUNC] = []
|
||||
|
||||
@property
|
||||
def task_group(self) -> TaskGroup:
|
||||
if self._task_group is None:
|
||||
raise RuntimeError("Lifespan not started")
|
||||
return self._task_group
|
||||
|
||||
@task_group.setter
|
||||
def task_group(self, task_group: TaskGroup) -> None:
|
||||
if self._task_group is not None:
|
||||
raise RuntimeError("Lifespan already started")
|
||||
self._task_group = task_group
|
||||
|
||||
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
self._startup_funcs.append(func)
|
||||
return func
|
||||
@ -29,7 +48,7 @@ class Lifespan:
|
||||
|
||||
@staticmethod
|
||||
async def _run_lifespan_func(
|
||||
funcs: list[LIFESPAN_FUNC],
|
||||
funcs: Iterable[LIFESPAN_FUNC],
|
||||
) -> None:
|
||||
for func in funcs:
|
||||
if is_coroutine_callable(func):
|
||||
@ -38,18 +57,44 @@ class Lifespan:
|
||||
await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()
|
||||
|
||||
async def startup(self) -> None:
|
||||
# create background task group
|
||||
self.task_group = anyio.create_task_group()
|
||||
await self.task_group.__aenter__()
|
||||
|
||||
# run startup funcs
|
||||
if self._startup_funcs:
|
||||
await self._run_lifespan_func(self._startup_funcs)
|
||||
|
||||
# run ready funcs
|
||||
if self._ready_funcs:
|
||||
await self._run_lifespan_func(self._ready_funcs)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
async def shutdown(
|
||||
self,
|
||||
*,
|
||||
exc_type: Optional[type[BaseException]] = None,
|
||||
exc_val: Optional[BaseException] = None,
|
||||
exc_tb: Optional[TracebackType] = None,
|
||||
) -> None:
|
||||
if self._shutdown_funcs:
|
||||
await self._run_lifespan_func(self._shutdown_funcs)
|
||||
# reverse shutdown funcs to ensure stack order
|
||||
await self._run_lifespan_func(reversed(self._shutdown_funcs))
|
||||
|
||||
# shutdown background task group
|
||||
self.task_group.cancel_scope.cancel()
|
||||
|
||||
with suppress(Exception):
|
||||
await self.task_group.__aexit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
self._task_group = None
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
await self.startup()
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
await self.shutdown()
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
await self.shutdown(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)
|
||||
|
@ -1,17 +1,20 @@
|
||||
import abc
|
||||
import asyncio
|
||||
from types import TracebackType
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing_extensions import Self, TypeAlias
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional
|
||||
|
||||
from anyio.abc import TaskGroup
|
||||
from anyio import CancelScope, create_task_group
|
||||
from exceptiongroup import BaseExceptionGroup, catch
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.exception import SkippedException
|
||||
from nonebot.utils import escape_tag, run_coro_with_catch
|
||||
from nonebot.internal.params import BotParam, DependParam, DefaultParam
|
||||
from nonebot.utils import escape_tag, run_coro_with_catch, flatten_exception_group
|
||||
from nonebot.typing import (
|
||||
T_DependencyCache,
|
||||
T_BotConnectionHook,
|
||||
@ -61,7 +64,6 @@ class Driver(abc.ABC):
|
||||
self.config: Config = config
|
||||
"""全局配置对象"""
|
||||
self._bots: dict[str, "Bot"] = {}
|
||||
self._bot_tasks: set[asyncio.Task] = set()
|
||||
self._lifespan = Lifespan()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@ -75,6 +77,10 @@ class Driver(abc.ABC):
|
||||
"""获取当前所有已连接的 Bot"""
|
||||
return self._bots
|
||||
|
||||
@property
|
||||
def task_group(self) -> TaskGroup:
|
||||
return self._lifespan.task_group
|
||||
|
||||
def register_adapter(self, adapter: type["Adapter"], **kwargs) -> None:
|
||||
"""注册一个协议适配器
|
||||
|
||||
@ -112,8 +118,6 @@ class Driver(abc.ABC):
|
||||
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
|
||||
)
|
||||
|
||||
self.on_shutdown(self._cleanup)
|
||||
|
||||
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
"""注册一个启动时执行的函数"""
|
||||
return self._lifespan.on_startup(func)
|
||||
@ -154,66 +158,57 @@ class Driver(abc.ABC):
|
||||
raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
|
||||
self._bots[bot.self_id] = bot
|
||||
|
||||
def handle_exception(exc_group: BaseExceptionGroup) -> None:
|
||||
for exc in flatten_exception_group(exc_group):
|
||||
logger.opt(colors=True, exception=exc).error(
|
||||
"<r><bg #f8bbd0>"
|
||||
"Error when running WebSocketConnection hook:"
|
||||
"</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
async def _run_hook(bot: "Bot") -> None:
|
||||
dependency_cache: T_DependencyCache = {}
|
||||
async with AsyncExitStack() as stack:
|
||||
if coros := [
|
||||
run_coro_with_catch(
|
||||
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
|
||||
(SkippedException,),
|
||||
)
|
||||
for hook in self._bot_connection_hook
|
||||
]:
|
||||
try:
|
||||
await asyncio.gather(*coros)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>"
|
||||
"Error when running WebSocketConnection hook. "
|
||||
"Running cancelled!"
|
||||
"</bg #f8bbd0></r>"
|
||||
with CancelScope(shield=True), catch({Exception: handle_exception}):
|
||||
async with AsyncExitStack() as stack, create_task_group() as tg:
|
||||
for hook in self._bot_connection_hook:
|
||||
tg.start_soon(
|
||||
run_coro_with_catch,
|
||||
hook(
|
||||
bot=bot, stack=stack, dependency_cache=dependency_cache
|
||||
),
|
||||
(SkippedException,),
|
||||
)
|
||||
|
||||
task = asyncio.create_task(_run_hook(bot))
|
||||
task.add_done_callback(self._bot_tasks.discard)
|
||||
self._bot_tasks.add(task)
|
||||
self.task_group.start_soon(_run_hook, bot)
|
||||
|
||||
def _bot_disconnect(self, bot: "Bot") -> None:
|
||||
"""在连接断开后,调用该函数来注销 bot 对象"""
|
||||
if bot.self_id in self._bots:
|
||||
del self._bots[bot.self_id]
|
||||
|
||||
def handle_exception(exc_group: BaseExceptionGroup) -> None:
|
||||
for exc in flatten_exception_group(exc_group):
|
||||
logger.opt(colors=True, exception=exc).error(
|
||||
"<r><bg #f8bbd0>"
|
||||
"Error when running WebSocketDisConnection hook:"
|
||||
"</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
async def _run_hook(bot: "Bot") -> None:
|
||||
dependency_cache: T_DependencyCache = {}
|
||||
async with AsyncExitStack() as stack:
|
||||
if coros := [
|
||||
run_coro_with_catch(
|
||||
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
|
||||
(SkippedException,),
|
||||
)
|
||||
for hook in self._bot_disconnection_hook
|
||||
]:
|
||||
try:
|
||||
await asyncio.gather(*coros)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>"
|
||||
"Error when running WebSocketDisConnection hook. "
|
||||
"Running cancelled!"
|
||||
"</bg #f8bbd0></r>"
|
||||
# shield cancellation to ensure bot disconnect hooks are always run
|
||||
with CancelScope(shield=True), catch({Exception: handle_exception}):
|
||||
async with create_task_group() as tg, AsyncExitStack() as stack:
|
||||
for hook in self._bot_disconnection_hook:
|
||||
tg.start_soon(
|
||||
run_coro_with_catch,
|
||||
hook(
|
||||
bot=bot, stack=stack, dependency_cache=dependency_cache
|
||||
),
|
||||
(SkippedException,),
|
||||
)
|
||||
|
||||
task = asyncio.create_task(_run_hook(bot))
|
||||
task.add_done_callback(self._bot_tasks.discard)
|
||||
self._bot_tasks.add(task)
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
"""清理驱动器资源"""
|
||||
if self._bot_tasks:
|
||||
logger.opt(colors=True).debug(
|
||||
"<y>Waiting for running bot connection hooks...</y>"
|
||||
)
|
||||
await asyncio.gather(*self._bot_tasks, return_exceptions=True)
|
||||
self.task_group.start_soon(_run_hook, bot)
|
||||
|
||||
|
||||
class Mixin(abc.ABC):
|
||||
|
Reference in New Issue
Block a user