Feature: 迁移至结构化并发框架 AnyIO (#3053)

This commit is contained in:
Ju4tCode
2024-10-26 15:36:01 +08:00
committed by GitHub
parent bd9befbb55
commit ff21ceb946
39 changed files with 5422 additions and 4080 deletions

View File

@ -1,6 +1,11 @@
from collections.abc import Awaitable
from types import TracebackType
from typing_extensions import TypeAlias
from typing import Any, Union, Callable, cast
from collections.abc import Iterable, Awaitable
from typing import Any, Union, Callable, Optional, cast
import anyio
from anyio.abc import TaskGroup
from exceptiongroup import suppress
from nonebot.utils import run_sync, is_coroutine_callable
@ -11,10 +16,24 @@ LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
class Lifespan:
def __init__(self) -> None:
self._task_group: Optional[TaskGroup] = None
self._startup_funcs: list[LIFESPAN_FUNC] = []
self._ready_funcs: list[LIFESPAN_FUNC] = []
self._shutdown_funcs: list[LIFESPAN_FUNC] = []
@property
def task_group(self) -> TaskGroup:
if self._task_group is None:
raise RuntimeError("Lifespan not started")
return self._task_group
@task_group.setter
def task_group(self, task_group: TaskGroup) -> None:
if self._task_group is not None:
raise RuntimeError("Lifespan already started")
self._task_group = task_group
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._startup_funcs.append(func)
return func
@ -29,7 +48,7 @@ class Lifespan:
@staticmethod
async def _run_lifespan_func(
funcs: list[LIFESPAN_FUNC],
funcs: Iterable[LIFESPAN_FUNC],
) -> None:
for func in funcs:
if is_coroutine_callable(func):
@ -38,18 +57,44 @@ class Lifespan:
await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()
async def startup(self) -> None:
# create background task group
self.task_group = anyio.create_task_group()
await self.task_group.__aenter__()
# run startup funcs
if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs)
# run ready funcs
if self._ready_funcs:
await self._run_lifespan_func(self._ready_funcs)
async def shutdown(self) -> None:
async def shutdown(
self,
*,
exc_type: Optional[type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional[TracebackType] = None,
) -> None:
if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs)
# reverse shutdown funcs to ensure stack order
await self._run_lifespan_func(reversed(self._shutdown_funcs))
# shutdown background task group
self.task_group.cancel_scope.cancel()
with suppress(Exception):
await self.task_group.__aexit__(exc_type, exc_val, exc_tb)
self._task_group = None
async def __aenter__(self) -> None:
await self.startup()
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.shutdown()
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.shutdown(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)