♻️ rewrite dependency injection system

This commit is contained in:
yanyongyu
2021-12-12 18:19:08 +08:00
parent 6b5a5e53eb
commit 66ba25494a
17 changed files with 728 additions and 733 deletions

View File

@ -5,227 +5,170 @@
该模块实现了依赖注入的定义与处理。
"""
import abc
import inspect
from itertools import chain
from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from typing import Any, Dict, List, Type, Generic, TypeVar, Callable, Optional
from pydantic import BaseConfig
from pydantic.schema import get_annotation_from_field_info
from pydantic.fields import Required, Undefined, ModelField
from pydantic.fields import Required, FieldInfo, Undefined, ModelField
from nonebot.log import logger
from .models import Param as Param
from .utils import get_typed_signature
from .models import Dependent as Dependent
from nonebot.exception import SkippedException
from .models import DependsWrapper as DependsWrapper
from nonebot.typing import T_Handler, T_DependencyCache
from nonebot.utils import (
CacheLock,
run_sync,
is_gen_callable,
run_sync_ctx_manager,
is_async_gen_callable,
is_coroutine_callable,
)
from nonebot.utils import run_sync, is_coroutine_callable
cache_lock = CacheLock()
T = TypeVar("T", bound="Dependent")
R = TypeVar("R")
class Param(abc.ABC, FieldInfo):
@classmethod
def _check_param(
cls, dependent: "Dependent", name: str, param: inspect.Parameter
) -> Optional["Param"]:
return None
@classmethod
def _check_parameterless(
cls, dependent: "Dependent", value: Any
) -> Optional["Param"]:
return None
@abc.abstractmethod
async def _solve(self, **kwargs: Any) -> Any:
raise NotImplementedError
class CustomConfig(BaseConfig):
arbitrary_types_allowed = True
def get_param_sub_dependent(
*, param: inspect.Parameter, allow_types: Optional[List[Type[Param]]] = None
) -> Dependent:
depends: DependsWrapper = param.default
if depends.dependency:
dependency = depends.dependency
else:
dependency = param.annotation
return get_sub_dependant(
depends=depends, dependency=dependency, name=param.name, allow_types=allow_types
)
class Dependent(Generic[R]):
def __init__(
self,
*,
call: Callable[..., Any],
params: Optional[List[ModelField]] = None,
parameterless: Optional[List[Param]] = None,
allow_types: Optional[List[Type[Param]]] = None,
) -> None:
self.call = call
self.params = params or []
self.parameterless = parameterless or []
self.allow_types = allow_types or []
async def __call__(self, **kwargs: Any) -> R:
values = await self.solve(**kwargs)
def get_parameterless_sub_dependant(
*, depends: DependsWrapper, allow_types: Optional[List[Type[Param]]] = None
) -> Dependent:
assert callable(
depends.dependency
), "A parameter-less dependency must have a callable dependency"
return get_sub_dependant(
depends=depends, dependency=depends.dependency, allow_types=allow_types
)
def get_sub_dependant(
*,
depends: DependsWrapper,
dependency: T_Handler,
name: Optional[str] = None,
allow_types: Optional[List[Type[Param]]] = None,
) -> Dependent:
sub_dependant = get_dependent(
call=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types
)
return sub_dependant
def get_dependent(
*,
call: T_Handler,
name: Optional[str] = None,
use_cache: bool = True,
allow_types: Optional[List[Type[Param]]] = None,
) -> Dependent:
signature = get_typed_signature(call)
params = signature.parameters
dependent = Dependent(
call=call, name=name, allow_types=allow_types, use_cache=use_cache
)
for param_name, param in params.items():
if isinstance(param.default, DependsWrapper):
sub_dependent = get_param_sub_dependent(
param=param, allow_types=allow_types
)
dependent.dependencies.append(sub_dependent)
continue
default_value = Required
if param.default != param.empty:
default_value = param.default
if isinstance(default_value, Param):
field_info = default_value
default_value = field_info.default
if is_coroutine_callable(self.call):
return await self.call(**values)
else:
for allow_type in dependent.allow_types:
if allow_type._check(param_name, param):
field_info = allow_type(default_value)
break
return await run_sync(self.call)(**values)
def parse_param(self, name: str, param: inspect.Parameter) -> Param:
for allow_type in self.allow_types:
field_info = allow_type._check_param(self, name, param)
if field_info:
return field_info
else:
raise ValueError(
f"Unknown parameter {name} for function {self.call} with type {param.annotation}"
)
def parse_parameterless(self, value: Any) -> Param:
for allow_type in self.allow_types:
field_info = allow_type._check_parameterless(self, value)
if field_info:
return field_info
else:
raise ValueError(
f"Unknown parameterless {value} for function {self.call} with type {type(value)}"
)
def prepend_parameterless(self, value: Any) -> None:
self.parameterless.insert(0, self.parse_parameterless(value))
def append_parameterless(self, value: Any) -> None:
self.parameterless.append(self.parse_parameterless(value))
@classmethod
def parse(
cls: Type[T],
*,
call: Callable[..., Any],
parameterless: Optional[List[Any]] = None,
allow_types: Optional[List[Type[Param]]] = None,
) -> T:
signature = get_typed_signature(call)
params = signature.parameters
dependent = cls(
call=call,
allow_types=allow_types,
)
parameterless_params = [
dependent.parse_parameterless(param) for param in (parameterless or [])
]
dependent.parameterless.extend(parameterless_params)
for param_name, param in params.items():
default_value = Required
if param.default != param.empty:
default_value = param.default
if isinstance(default_value, Param):
field_info = default_value
default_value = field_info.default
else:
raise ValueError(
f"Unknown parameter {param_name} for function {call} with type {param.annotation}"
field_info = dependent.parse_param(param_name, param)
default_value = field_info.default
annotation: Any = Any
required = default_value == Required
if param.annotation != param.empty:
annotation = param.annotation
annotation = get_annotation_from_field_info(
annotation, field_info, param_name
)
dependent.params.append(
ModelField(
name=param_name,
type_=annotation,
class_validators=None,
model_config=CustomConfig,
default=None if required else default_value,
required=required,
field_info=field_info,
)
annotation: Any = Any
required = default_value == Required
if param.annotation != param.empty:
annotation = param.annotation
annotation = get_annotation_from_field_info(annotation, field_info, param_name)
dependent.params.append(
ModelField(
name=param_name,
type_=annotation,
class_validators=None,
model_config=CustomConfig,
default=None if required else default_value,
required=required,
field_info=field_info,
)
)
return dependent
return dependent
async def solve(
self,
**params: Any,
) -> Dict[str, Any]:
values: Dict[str, Any] = {}
async def solve_dependencies(
*,
_dependent: Dependent,
_stack: Optional[AsyncExitStack] = None,
_sub_dependents: Optional[List[Dependent]] = None,
_dependency_cache: Optional[T_DependencyCache] = None,
**params: Any,
) -> Tuple[Dict[str, Any], T_DependencyCache]:
values: Dict[str, Any] = {}
dependency_cache = {} if _dependency_cache is None else _dependency_cache
# usual dependency
for field in _dependent.params:
field_info = field.field_info
assert isinstance(field_info, Param), "Params must be subclasses of Param"
value = field_info._solve(**params)
if value == Undefined:
value = field.get_default()
_, errs_ = field.validate(value, values, loc=(str(field_info), field.alias))
if errs_:
logger.debug(
f"{field_info} "
f"type {type(value)} not match depends {_dependent.call} "
f"annotation {field._type_display()}, ignored"
)
raise SkippedException(field, value)
else:
values[field.name] = value
# solve sub dependencies
sub_dependent: Dependent
for sub_dependent in chain(_sub_dependents or tuple(), _dependent.dependencies):
sub_dependent.call = cast(Callable[..., Any], sub_dependent.call)
sub_dependent.cache_key = cast(Callable[..., Any], sub_dependent.cache_key)
call = sub_dependent.call
# solve sub dependency with current cache
solved_result = await solve_dependencies(
_dependent=sub_dependent, _dependency_cache=dependency_cache, **params
)
sub_values, sub_dependency_cache = solved_result
# update cache?
# dependency_cache.update(sub_dependency_cache)
# run dependency function
async with cache_lock:
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
solved = dependency_cache[sub_dependent.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance(
_stack, AsyncExitStack
), "Generator dependency should be called in context"
if is_gen_callable(call):
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
else:
cm = asynccontextmanager(call)(**sub_values)
solved = await _stack.enter_async_context(cm)
elif is_coroutine_callable(call):
solved = await call(**sub_values)
for field in self.params:
field_info = field.field_info
assert isinstance(field_info, Param), "Params must be subclasses of Param"
value = await field_info._solve(**params)
if value == Undefined:
value = field.get_default()
_, errs_ = field.validate(value, values, loc=(str(field_info), field.alias))
if errs_:
logger.debug(
f"{field_info} "
f"type {type(value)} not match depends {self.call} "
f"annotation {field._type_display()}, ignored"
)
raise SkippedException(field, value)
else:
solved = await run_sync(call)(**sub_values)
values[field.name] = value
# parameter dependency
if sub_dependent.name is not None:
values[sub_dependent.name] = solved
# save current dependency to cache
if sub_dependent.cache_key not in dependency_cache:
dependency_cache[sub_dependent.cache_key] = solved
for param in self.parameterless:
await param._solve(**params)
return values, dependency_cache
def Depends(dependency: Optional[T_Handler] = None, *, use_cache: bool = True) -> Any:
"""
:说明:
参数依赖注入装饰器
:参数:
* ``dependency: Optional[Callable[..., Any]] = None``: 依赖函数。默认为参数的类型注释。
* ``use_cache: bool = True``: 是否使用缓存。默认为 ``True``。
.. code-block:: python
def depend_func() -> Any:
return ...
def depend_gen_func():
try:
yield ...
finally:
...
async def handler(param_name: Any = Depends(depend_func), gen: Any = Depends(depend_gen_func)):
...
"""
return DependsWrapper(dependency=dependency, use_cache=use_cache)
return values