mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-28 00:31:14 +00:00
♻️ rewrite dependency injection system
This commit is contained in:
@ -5,227 +5,170 @@
|
||||
该模块实现了依赖注入的定义与处理。
|
||||
"""
|
||||
|
||||
import abc
|
||||
import inspect
|
||||
from itertools import chain
|
||||
from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast
|
||||
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
||||
from typing import Any, Dict, List, Type, Generic, TypeVar, Callable, Optional
|
||||
|
||||
from pydantic import BaseConfig
|
||||
from pydantic.schema import get_annotation_from_field_info
|
||||
from pydantic.fields import Required, Undefined, ModelField
|
||||
from pydantic.fields import Required, FieldInfo, Undefined, ModelField
|
||||
|
||||
from nonebot.log import logger
|
||||
from .models import Param as Param
|
||||
from .utils import get_typed_signature
|
||||
from .models import Dependent as Dependent
|
||||
from nonebot.exception import SkippedException
|
||||
from .models import DependsWrapper as DependsWrapper
|
||||
from nonebot.typing import T_Handler, T_DependencyCache
|
||||
from nonebot.utils import (
|
||||
CacheLock,
|
||||
run_sync,
|
||||
is_gen_callable,
|
||||
run_sync_ctx_manager,
|
||||
is_async_gen_callable,
|
||||
is_coroutine_callable,
|
||||
)
|
||||
from nonebot.utils import run_sync, is_coroutine_callable
|
||||
|
||||
cache_lock = CacheLock()
|
||||
T = TypeVar("T", bound="Dependent")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class Param(abc.ABC, FieldInfo):
|
||||
@classmethod
|
||||
def _check_param(
|
||||
cls, dependent: "Dependent", name: str, param: inspect.Parameter
|
||||
) -> Optional["Param"]:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _check_parameterless(
|
||||
cls, dependent: "Dependent", value: Any
|
||||
) -> Optional["Param"]:
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def _solve(self, **kwargs: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CustomConfig(BaseConfig):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
def get_param_sub_dependent(
|
||||
*, param: inspect.Parameter, allow_types: Optional[List[Type[Param]]] = None
|
||||
) -> Dependent:
|
||||
depends: DependsWrapper = param.default
|
||||
if depends.dependency:
|
||||
dependency = depends.dependency
|
||||
else:
|
||||
dependency = param.annotation
|
||||
return get_sub_dependant(
|
||||
depends=depends, dependency=dependency, name=param.name, allow_types=allow_types
|
||||
)
|
||||
class Dependent(Generic[R]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
call: Callable[..., Any],
|
||||
params: Optional[List[ModelField]] = None,
|
||||
parameterless: Optional[List[Param]] = None,
|
||||
allow_types: Optional[List[Type[Param]]] = None,
|
||||
) -> None:
|
||||
self.call = call
|
||||
self.params = params or []
|
||||
self.parameterless = parameterless or []
|
||||
self.allow_types = allow_types or []
|
||||
|
||||
async def __call__(self, **kwargs: Any) -> R:
|
||||
values = await self.solve(**kwargs)
|
||||
|
||||
def get_parameterless_sub_dependant(
|
||||
*, depends: DependsWrapper, allow_types: Optional[List[Type[Param]]] = None
|
||||
) -> Dependent:
|
||||
assert callable(
|
||||
depends.dependency
|
||||
), "A parameter-less dependency must have a callable dependency"
|
||||
return get_sub_dependant(
|
||||
depends=depends, dependency=depends.dependency, allow_types=allow_types
|
||||
)
|
||||
|
||||
|
||||
def get_sub_dependant(
|
||||
*,
|
||||
depends: DependsWrapper,
|
||||
dependency: T_Handler,
|
||||
name: Optional[str] = None,
|
||||
allow_types: Optional[List[Type[Param]]] = None,
|
||||
) -> Dependent:
|
||||
sub_dependant = get_dependent(
|
||||
call=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types
|
||||
)
|
||||
return sub_dependant
|
||||
|
||||
|
||||
def get_dependent(
|
||||
*,
|
||||
call: T_Handler,
|
||||
name: Optional[str] = None,
|
||||
use_cache: bool = True,
|
||||
allow_types: Optional[List[Type[Param]]] = None,
|
||||
) -> Dependent:
|
||||
signature = get_typed_signature(call)
|
||||
params = signature.parameters
|
||||
dependent = Dependent(
|
||||
call=call, name=name, allow_types=allow_types, use_cache=use_cache
|
||||
)
|
||||
for param_name, param in params.items():
|
||||
if isinstance(param.default, DependsWrapper):
|
||||
sub_dependent = get_param_sub_dependent(
|
||||
param=param, allow_types=allow_types
|
||||
)
|
||||
dependent.dependencies.append(sub_dependent)
|
||||
continue
|
||||
|
||||
default_value = Required
|
||||
if param.default != param.empty:
|
||||
default_value = param.default
|
||||
|
||||
if isinstance(default_value, Param):
|
||||
field_info = default_value
|
||||
default_value = field_info.default
|
||||
if is_coroutine_callable(self.call):
|
||||
return await self.call(**values)
|
||||
else:
|
||||
for allow_type in dependent.allow_types:
|
||||
if allow_type._check(param_name, param):
|
||||
field_info = allow_type(default_value)
|
||||
break
|
||||
return await run_sync(self.call)(**values)
|
||||
|
||||
def parse_param(self, name: str, param: inspect.Parameter) -> Param:
|
||||
for allow_type in self.allow_types:
|
||||
field_info = allow_type._check_param(self, name, param)
|
||||
if field_info:
|
||||
return field_info
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown parameter {name} for function {self.call} with type {param.annotation}"
|
||||
)
|
||||
|
||||
def parse_parameterless(self, value: Any) -> Param:
|
||||
for allow_type in self.allow_types:
|
||||
field_info = allow_type._check_parameterless(self, value)
|
||||
if field_info:
|
||||
return field_info
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown parameterless {value} for function {self.call} with type {type(value)}"
|
||||
)
|
||||
|
||||
def prepend_parameterless(self, value: Any) -> None:
|
||||
self.parameterless.insert(0, self.parse_parameterless(value))
|
||||
|
||||
def append_parameterless(self, value: Any) -> None:
|
||||
self.parameterless.append(self.parse_parameterless(value))
|
||||
|
||||
@classmethod
|
||||
def parse(
|
||||
cls: Type[T],
|
||||
*,
|
||||
call: Callable[..., Any],
|
||||
parameterless: Optional[List[Any]] = None,
|
||||
allow_types: Optional[List[Type[Param]]] = None,
|
||||
) -> T:
|
||||
signature = get_typed_signature(call)
|
||||
params = signature.parameters
|
||||
dependent = cls(
|
||||
call=call,
|
||||
allow_types=allow_types,
|
||||
)
|
||||
|
||||
parameterless_params = [
|
||||
dependent.parse_parameterless(param) for param in (parameterless or [])
|
||||
]
|
||||
dependent.parameterless.extend(parameterless_params)
|
||||
|
||||
for param_name, param in params.items():
|
||||
default_value = Required
|
||||
if param.default != param.empty:
|
||||
default_value = param.default
|
||||
|
||||
if isinstance(default_value, Param):
|
||||
field_info = default_value
|
||||
default_value = field_info.default
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown parameter {param_name} for function {call} with type {param.annotation}"
|
||||
field_info = dependent.parse_param(param_name, param)
|
||||
default_value = field_info.default
|
||||
|
||||
annotation: Any = Any
|
||||
required = default_value == Required
|
||||
if param.annotation != param.empty:
|
||||
annotation = param.annotation
|
||||
annotation = get_annotation_from_field_info(
|
||||
annotation, field_info, param_name
|
||||
)
|
||||
dependent.params.append(
|
||||
ModelField(
|
||||
name=param_name,
|
||||
type_=annotation,
|
||||
class_validators=None,
|
||||
model_config=CustomConfig,
|
||||
default=None if required else default_value,
|
||||
required=required,
|
||||
field_info=field_info,
|
||||
)
|
||||
|
||||
annotation: Any = Any
|
||||
required = default_value == Required
|
||||
if param.annotation != param.empty:
|
||||
annotation = param.annotation
|
||||
annotation = get_annotation_from_field_info(annotation, field_info, param_name)
|
||||
dependent.params.append(
|
||||
ModelField(
|
||||
name=param_name,
|
||||
type_=annotation,
|
||||
class_validators=None,
|
||||
model_config=CustomConfig,
|
||||
default=None if required else default_value,
|
||||
required=required,
|
||||
field_info=field_info,
|
||||
)
|
||||
)
|
||||
|
||||
return dependent
|
||||
return dependent
|
||||
|
||||
async def solve(
|
||||
self,
|
||||
**params: Any,
|
||||
) -> Dict[str, Any]:
|
||||
values: Dict[str, Any] = {}
|
||||
|
||||
async def solve_dependencies(
|
||||
*,
|
||||
_dependent: Dependent,
|
||||
_stack: Optional[AsyncExitStack] = None,
|
||||
_sub_dependents: Optional[List[Dependent]] = None,
|
||||
_dependency_cache: Optional[T_DependencyCache] = None,
|
||||
**params: Any,
|
||||
) -> Tuple[Dict[str, Any], T_DependencyCache]:
|
||||
values: Dict[str, Any] = {}
|
||||
dependency_cache = {} if _dependency_cache is None else _dependency_cache
|
||||
|
||||
# usual dependency
|
||||
for field in _dependent.params:
|
||||
field_info = field.field_info
|
||||
assert isinstance(field_info, Param), "Params must be subclasses of Param"
|
||||
value = field_info._solve(**params)
|
||||
if value == Undefined:
|
||||
value = field.get_default()
|
||||
_, errs_ = field.validate(value, values, loc=(str(field_info), field.alias))
|
||||
if errs_:
|
||||
logger.debug(
|
||||
f"{field_info} "
|
||||
f"type {type(value)} not match depends {_dependent.call} "
|
||||
f"annotation {field._type_display()}, ignored"
|
||||
)
|
||||
raise SkippedException(field, value)
|
||||
else:
|
||||
values[field.name] = value
|
||||
|
||||
# solve sub dependencies
|
||||
sub_dependent: Dependent
|
||||
for sub_dependent in chain(_sub_dependents or tuple(), _dependent.dependencies):
|
||||
sub_dependent.call = cast(Callable[..., Any], sub_dependent.call)
|
||||
sub_dependent.cache_key = cast(Callable[..., Any], sub_dependent.cache_key)
|
||||
call = sub_dependent.call
|
||||
|
||||
# solve sub dependency with current cache
|
||||
solved_result = await solve_dependencies(
|
||||
_dependent=sub_dependent, _dependency_cache=dependency_cache, **params
|
||||
)
|
||||
sub_values, sub_dependency_cache = solved_result
|
||||
# update cache?
|
||||
# dependency_cache.update(sub_dependency_cache)
|
||||
|
||||
# run dependency function
|
||||
async with cache_lock:
|
||||
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
|
||||
solved = dependency_cache[sub_dependent.cache_key]
|
||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||
assert isinstance(
|
||||
_stack, AsyncExitStack
|
||||
), "Generator dependency should be called in context"
|
||||
if is_gen_callable(call):
|
||||
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
|
||||
else:
|
||||
cm = asynccontextmanager(call)(**sub_values)
|
||||
solved = await _stack.enter_async_context(cm)
|
||||
elif is_coroutine_callable(call):
|
||||
solved = await call(**sub_values)
|
||||
for field in self.params:
|
||||
field_info = field.field_info
|
||||
assert isinstance(field_info, Param), "Params must be subclasses of Param"
|
||||
value = await field_info._solve(**params)
|
||||
if value == Undefined:
|
||||
value = field.get_default()
|
||||
_, errs_ = field.validate(value, values, loc=(str(field_info), field.alias))
|
||||
if errs_:
|
||||
logger.debug(
|
||||
f"{field_info} "
|
||||
f"type {type(value)} not match depends {self.call} "
|
||||
f"annotation {field._type_display()}, ignored"
|
||||
)
|
||||
raise SkippedException(field, value)
|
||||
else:
|
||||
solved = await run_sync(call)(**sub_values)
|
||||
values[field.name] = value
|
||||
|
||||
# parameter dependency
|
||||
if sub_dependent.name is not None:
|
||||
values[sub_dependent.name] = solved
|
||||
# save current dependency to cache
|
||||
if sub_dependent.cache_key not in dependency_cache:
|
||||
dependency_cache[sub_dependent.cache_key] = solved
|
||||
for param in self.parameterless:
|
||||
await param._solve(**params)
|
||||
|
||||
return values, dependency_cache
|
||||
|
||||
|
||||
def Depends(dependency: Optional[T_Handler] = None, *, use_cache: bool = True) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
参数依赖注入装饰器
|
||||
|
||||
:参数:
|
||||
|
||||
* ``dependency: Optional[Callable[..., Any]] = None``: 依赖函数。默认为参数的类型注释。
|
||||
* ``use_cache: bool = True``: 是否使用缓存。默认为 ``True``。
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def depend_func() -> Any:
|
||||
return ...
|
||||
|
||||
def depend_gen_func():
|
||||
try:
|
||||
yield ...
|
||||
finally:
|
||||
...
|
||||
|
||||
async def handler(param_name: Any = Depends(depend_func), gen: Any = Depends(depend_gen_func)):
|
||||
...
|
||||
"""
|
||||
return DependsWrapper(dependency=dependency, use_cache=use_cache)
|
||||
return values
|
||||
|
Reference in New Issue
Block a user