Feature: 迁移至结构化并发框架 AnyIO (#3053)

This commit is contained in:
Ju4tCode
2024-10-26 15:36:01 +08:00
committed by GitHub
parent bd9befbb55
commit ff21ceb946
39 changed files with 5422 additions and 4080 deletions

View File

@ -8,17 +8,20 @@ FrontMatter:
"""
import abc
import asyncio
import inspect
from functools import partial
from dataclasses import field, dataclass
from collections.abc import Iterable, Awaitable
from typing import Any, Generic, TypeVar, Callable, Optional, cast
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.typing import _DependentCallable
from nonebot.exception import SkippedException
from nonebot.utils import run_sync, is_coroutine_callable
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
from nonebot.utils import run_sync, is_coroutine_callable, flatten_exception_group
from .utils import check_field_type, get_typed_signature
@ -84,7 +87,16 @@ class Dependent(Generic[R]):
)
async def __call__(self, **kwargs: Any) -> R:
try:
exception: Optional[BaseExceptionGroup[SkippedException]] = None
def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
nonlocal exception
exception = exc_group
# raise one of the exceptions instead
excs = list(flatten_exception_group(exc_group))
logger.trace(f"{self} skipped due to {excs}")
with catch({SkippedException: _handle_skipped}):
# do pre-check
await self.check(**kwargs)
@ -96,9 +108,8 @@ class Dependent(Generic[R]):
return await cast(Callable[..., Awaitable[R]], self.call)(**values)
else:
return await run_sync(cast(Callable[..., R], self.call))(**values)
except SkippedException as e:
logger.trace(f"{self} skipped due to {e}")
raise
raise exception
@staticmethod
def parse_params(
@ -166,10 +177,13 @@ class Dependent(Generic[R]):
return cls(call, params, parameterless_params)
async def check(self, **params: Any) -> None:
await asyncio.gather(*(param._check(**params) for param in self.parameterless))
await asyncio.gather(
*(cast(Param, param.field_info)._check(**params) for param in self.params)
)
async with anyio.create_task_group() as tg:
for param in self.parameterless:
tg.start_soon(partial(param._check, **params))
async with anyio.create_task_group() as tg:
for param in self.params:
tg.start_soon(partial(cast(Param, param.field_info)._check, **params))
async def _solve_field(self, field: ModelField, params: dict[str, Any]) -> Any:
param = cast(Param, field.field_info)
@ -185,10 +199,17 @@ class Dependent(Generic[R]):
await param._solve(**params)
# solve param values
values = await asyncio.gather(
*(self._solve_field(field, params) for field in self.params)
)
return {field.name: value for field, value in zip(self.params, values)}
result: dict[str, Any] = {}
async def _solve_field(field: ModelField, params: dict[str, Any]) -> None:
value = await self._solve_field(field, params)
result[field.name] = value
async with anyio.create_task_group() as tg:
for field in self.params:
tg.start_soon(_solve_field, field, params)
return result
__autodoc__ = {"CustomConfig": False}