mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-27 16:21:28 +00:00
✨ Feature: 迁移至结构化并发框架 AnyIO (#3053)
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from enum import Enum
|
||||
from typing_extensions import Self, get_args, override, get_origin
|
||||
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
||||
from typing import (
|
||||
@ -13,8 +13,11 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
import anyio
|
||||
from exceptiongroup import BaseExceptionGroup, catch
|
||||
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
||||
|
||||
from nonebot.exception import SkippedException
|
||||
from nonebot.dependencies import Param, Dependent
|
||||
from nonebot.dependencies.utils import check_field_type
|
||||
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
|
||||
@ -93,6 +96,75 @@ def Depends(
|
||||
return DependsInner(dependency, use_cache=use_cache, validate=validate)
|
||||
|
||||
|
||||
class CacheState(str, Enum):
|
||||
"""子依赖缓存状态"""
|
||||
|
||||
PENDING = "PENDING"
|
||||
FINISHED = "FINISHED"
|
||||
|
||||
|
||||
class DependencyCache:
|
||||
"""子依赖结果缓存。
|
||||
|
||||
用于缓存子依赖的结果,以避免重复计算。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._state = CacheState.PENDING
|
||||
self._result: Any = None
|
||||
self._exception: Optional[BaseException] = None
|
||||
self._waiter = anyio.Event()
|
||||
|
||||
def result(self) -> Any:
|
||||
"""获取子依赖结果"""
|
||||
|
||||
if self._state != CacheState.FINISHED:
|
||||
raise RuntimeError("Result is not ready")
|
||||
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
return self._result
|
||||
|
||||
def exception(self) -> Optional[BaseException]:
|
||||
"""获取子依赖异常"""
|
||||
|
||||
if self._state != CacheState.FINISHED:
|
||||
raise RuntimeError("Result is not ready")
|
||||
|
||||
return self._exception
|
||||
|
||||
def set_result(self, result: Any) -> None:
|
||||
"""设置子依赖结果"""
|
||||
|
||||
if self._state != CacheState.PENDING:
|
||||
raise RuntimeError(f"Cache state invalid: {self._state}")
|
||||
|
||||
self._result = result
|
||||
self._state = CacheState.FINISHED
|
||||
self._waiter.set()
|
||||
|
||||
def set_exception(self, exception: BaseException) -> None:
|
||||
"""设置子依赖异常"""
|
||||
|
||||
if self._state != CacheState.PENDING:
|
||||
raise RuntimeError(f"Cache state invalid: {self._state}")
|
||||
|
||||
self._exception = exception
|
||||
self._state = CacheState.FINISHED
|
||||
self._waiter.set()
|
||||
|
||||
async def wait(self):
|
||||
"""等待子依赖结果"""
|
||||
await self._waiter.wait()
|
||||
if self._state != CacheState.FINISHED:
|
||||
raise RuntimeError("Invalid cache state")
|
||||
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
|
||||
return self._result
|
||||
|
||||
|
||||
class DependParam(Param):
|
||||
"""子依赖注入参数。
|
||||
|
||||
@ -194,17 +266,27 @@ class DependParam(Param):
|
||||
call = cast(Callable[..., Any], sub_dependent.call)
|
||||
|
||||
# solve sub dependency with current cache
|
||||
sub_values = await sub_dependent.solve(
|
||||
stack=stack,
|
||||
dependency_cache=dependency_cache,
|
||||
**kwargs,
|
||||
)
|
||||
exc: Optional[BaseExceptionGroup[SkippedException]] = None
|
||||
|
||||
def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
|
||||
nonlocal exc
|
||||
exc = exc_group
|
||||
|
||||
with catch({SkippedException: _handle_skipped}):
|
||||
sub_values = await sub_dependent.solve(
|
||||
stack=stack,
|
||||
dependency_cache=dependency_cache,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if exc is not None:
|
||||
raise exc
|
||||
|
||||
# run dependency function
|
||||
task: asyncio.Task[Any]
|
||||
if use_cache and call in dependency_cache:
|
||||
return await dependency_cache[call]
|
||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||
return await dependency_cache[call].wait()
|
||||
|
||||
if is_gen_callable(call) or is_async_gen_callable(call):
|
||||
assert isinstance(
|
||||
stack, AsyncExitStack
|
||||
), "Generator dependency should be called in context"
|
||||
@ -212,17 +294,21 @@ class DependParam(Param):
|
||||
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
|
||||
else:
|
||||
cm = asynccontextmanager(call)(**sub_values)
|
||||
task = asyncio.create_task(stack.enter_async_context(cm))
|
||||
dependency_cache[call] = task
|
||||
return await task
|
||||
|
||||
target = stack.enter_async_context(cm)
|
||||
elif is_coroutine_callable(call):
|
||||
task = asyncio.create_task(call(**sub_values))
|
||||
dependency_cache[call] = task
|
||||
return await task
|
||||
target = call(**sub_values)
|
||||
else:
|
||||
task = asyncio.create_task(run_sync(call)(**sub_values))
|
||||
dependency_cache[call] = task
|
||||
return await task
|
||||
target = run_sync(call)(**sub_values)
|
||||
|
||||
dependency_cache[call] = cache = DependencyCache()
|
||||
try:
|
||||
result = await target
|
||||
cache.set_result(result)
|
||||
return result
|
||||
except BaseException as e:
|
||||
cache.set_exception(e)
|
||||
raise
|
||||
|
||||
@override
|
||||
async def _check(self, **kwargs: Any) -> None:
|
||||
|
Reference in New Issue
Block a user