mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-09-07 12:36:47 +00:00
✨ Feature: 迁移至结构化并发框架 AnyIO (#3053)
This commit is contained in:
@ -8,17 +8,20 @@ FrontMatter:
|
||||
"""
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import inspect
|
||||
from functools import partial
|
||||
from dataclasses import field, dataclass
|
||||
from collections.abc import Iterable, Awaitable
|
||||
from typing import Any, Generic, TypeVar, Callable, Optional, cast
|
||||
|
||||
import anyio
|
||||
from exceptiongroup import BaseExceptionGroup, catch
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.typing import _DependentCallable
|
||||
from nonebot.exception import SkippedException
|
||||
from nonebot.utils import run_sync, is_coroutine_callable
|
||||
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
|
||||
from nonebot.utils import run_sync, is_coroutine_callable, flatten_exception_group
|
||||
|
||||
from .utils import check_field_type, get_typed_signature
|
||||
|
||||
@ -84,7 +87,16 @@ class Dependent(Generic[R]):
|
||||
)
|
||||
|
||||
async def __call__(self, **kwargs: Any) -> R:
|
||||
try:
|
||||
exception: Optional[BaseExceptionGroup[SkippedException]] = None
|
||||
|
||||
def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
|
||||
nonlocal exception
|
||||
exception = exc_group
|
||||
# raise one of the exceptions instead
|
||||
excs = list(flatten_exception_group(exc_group))
|
||||
logger.trace(f"{self} skipped due to {excs}")
|
||||
|
||||
with catch({SkippedException: _handle_skipped}):
|
||||
# do pre-check
|
||||
await self.check(**kwargs)
|
||||
|
||||
@ -96,9 +108,8 @@ class Dependent(Generic[R]):
|
||||
return await cast(Callable[..., Awaitable[R]], self.call)(**values)
|
||||
else:
|
||||
return await run_sync(cast(Callable[..., R], self.call))(**values)
|
||||
except SkippedException as e:
|
||||
logger.trace(f"{self} skipped due to {e}")
|
||||
raise
|
||||
|
||||
raise exception
|
||||
|
||||
@staticmethod
|
||||
def parse_params(
|
||||
@ -166,10 +177,13 @@ class Dependent(Generic[R]):
|
||||
return cls(call, params, parameterless_params)
|
||||
|
||||
async def check(self, **params: Any) -> None:
|
||||
await asyncio.gather(*(param._check(**params) for param in self.parameterless))
|
||||
await asyncio.gather(
|
||||
*(cast(Param, param.field_info)._check(**params) for param in self.params)
|
||||
)
|
||||
async with anyio.create_task_group() as tg:
|
||||
for param in self.parameterless:
|
||||
tg.start_soon(partial(param._check, **params))
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
for param in self.params:
|
||||
tg.start_soon(partial(cast(Param, param.field_info)._check, **params))
|
||||
|
||||
async def _solve_field(self, field: ModelField, params: dict[str, Any]) -> Any:
|
||||
param = cast(Param, field.field_info)
|
||||
@ -185,10 +199,17 @@ class Dependent(Generic[R]):
|
||||
await param._solve(**params)
|
||||
|
||||
# solve param values
|
||||
values = await asyncio.gather(
|
||||
*(self._solve_field(field, params) for field in self.params)
|
||||
)
|
||||
return {field.name: value for field, value in zip(self.params, values)}
|
||||
result: dict[str, Any] = {}
|
||||
|
||||
async def _solve_field(field: ModelField, params: dict[str, Any]) -> None:
|
||||
value = await self._solve_field(field, params)
|
||||
result[field.name] = value
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
for field in self.params:
|
||||
tg.start_soon(_solve_field, field, params)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
__autodoc__ = {"CustomConfig": False}
|
||||
|
Reference in New Issue
Block a user