♻️ improve dependent structure (#1227)

This commit is contained in:
Ju4tCode
2022-09-07 09:59:05 +08:00
committed by GitHub
parent 595c64e760
commit a0b186aff3
3 changed files with 191 additions and 210 deletions

View File

@ -6,15 +6,19 @@ FrontMatter:
"""
import abc
import asyncio
import inspect
from dataclasses import field, dataclass
from typing import (
Any,
Dict,
List,
Type,
Tuple,
Generic,
TypeVar,
Callable,
Iterable,
Optional,
Awaitable,
cast,
@ -25,7 +29,6 @@ from pydantic.schema import get_annotation_from_field_info
from pydantic.fields import Required, FieldInfo, Undefined, ModelField
from nonebot.log import logger
from nonebot.exception import TypeMisMatch
from nonebot.typing import _DependentCallable
from nonebot.utils import run_sync, is_coroutine_callable
@ -43,25 +46,29 @@ class Param(abc.ABC, FieldInfo):
@classmethod
def _check_param(
cls, dependent: "Dependent", name: str, param: inspect.Parameter
cls, param: inspect.Parameter, allow_types: Tuple[Type["Param"], ...]
) -> Optional["Param"]:
return None
return
@classmethod
def _check_parameterless(
cls, dependent: "Dependent", value: Any
cls, value: Any, allow_types: Tuple[Type["Param"], ...]
) -> Optional["Param"]:
return None
return
@abc.abstractmethod
async def _solve(self, **kwargs: Any) -> Any:
raise NotImplementedError
async def _check(self, **kwargs: Any) -> None:
return
class CustomConfig(BaseConfig):
arbitrary_types_allowed = True
@dataclass(frozen=True)
class Dependent(Generic[R]):
"""依赖注入容器
@ -73,76 +80,34 @@ class Dependent(Generic[R]):
allow_types: 允许的参数类型
"""
def __init__(
self,
*,
call: _DependentCallable[R],
pre_checkers: Optional[List[Param]] = None,
params: Optional[List[ModelField]] = None,
parameterless: Optional[List[Param]] = None,
allow_types: Optional[List[Type[Param]]] = None,
) -> None:
self.call = call
self.pre_checkers = pre_checkers or []
self.params = params or []
self.parameterless = parameterless or []
self.allow_types = allow_types or []
call: _DependentCallable[R]
params: Tuple[ModelField] = field(default_factory=tuple)
parameterless: Tuple[Param] = field(default_factory=tuple)
def __repr__(self) -> str:
return (
f"<Dependent call={self.call}, params={self.params},"
f" parameterless={self.parameterless}>"
)
def __str__(self) -> str:
return self.__repr__()
return f"<Dependent call={self.call}>"
async def __call__(self, **kwargs: Any) -> R:
# do pre-check
await self.check(**kwargs)
# solve param values
values = await self.solve(**kwargs)
# call function
if is_coroutine_callable(self.call):
return await cast(Callable[..., Awaitable[R]], self.call)(**values)
else:
return await run_sync(cast(Callable[..., R], self.call))(**values)
def parse_param(self, name: str, param: inspect.Parameter) -> Param:
for allow_type in self.allow_types:
if field_info := allow_type._check_param(self, name, param):
return field_info
raise ValueError(
f"Unknown parameter {name} for function {self.call} with type {param.annotation}"
)
@staticmethod
def parse_params(
call: _DependentCallable[R], allow_types: Tuple[Type[Param], ...]
) -> Tuple[ModelField]:
fields: List[ModelField] = []
params = get_typed_signature(call).parameters.values()
def parse_parameterless(self, value: Any) -> Param:
for allow_type in self.allow_types:
if field_info := allow_type._check_parameterless(self, value):
return field_info
raise ValueError(
f"Unknown parameterless {value} for function {self.call} with type {type(value)}"
)
def prepend_parameterless(self, value: Any) -> None:
self.parameterless.insert(0, self.parse_parameterless(value))
def append_parameterless(self, value: Any) -> None:
self.parameterless.append(self.parse_parameterless(value))
@classmethod
def parse(
cls,
*,
call: _DependentCallable[R],
parameterless: Optional[List[Any]] = None,
allow_types: Optional[List[Type[Param]]] = None,
) -> "Dependent[R]":
signature = get_typed_signature(call)
params = signature.parameters
dependent = cls(
call=call,
allow_types=allow_types,
)
for param_name, param in params.items():
for param in params:
default_value = Required
if param.default != param.empty:
default_value = param.default
@ -150,7 +115,13 @@ class Dependent(Generic[R]):
if isinstance(default_value, Param):
field_info = default_value
else:
field_info = dependent.parse_param(param_name, param)
for allow_type in allow_types:
if field_info := allow_type._check_param(param, allow_types):
break
else:
raise ValueError(
f"Unknown parameter {param.name} for function {call} with type {param.annotation}"
)
default_value = field_info.default
@ -159,11 +130,12 @@ class Dependent(Generic[R]):
if param.annotation != param.empty:
annotation = param.annotation
annotation = get_annotation_from_field_info(
annotation, field_info, param_name
annotation, field_info, param.name
)
dependent.params.append(
fields.append(
ModelField(
name=param_name,
name=param.name,
type_=annotation,
class_validators=None,
model_config=CustomConfig,
@ -173,49 +145,69 @@ class Dependent(Generic[R]):
)
)
parameterless_params = [
dependent.parse_parameterless(param) for param in (parameterless or [])
]
dependent.parameterless.extend(parameterless_params)
return tuple(fields)
@staticmethod
def parse_parameterless(
parameterless: Tuple[Any, ...], allow_types: Tuple[Type[Param], ...]
) -> Tuple[Param, ...]:
parameterless_params: List[Param] = []
for value in parameterless:
for allow_type in allow_types:
if param := allow_type._check_parameterless(value, allow_types):
break
else:
raise ValueError(f"Unknown parameterless {value}")
parameterless_params.append(param)
return tuple(parameterless_params)
@classmethod
def parse(
cls,
*,
call: _DependentCallable[R],
parameterless: Optional[Iterable[Any]] = None,
allow_types: Iterable[Type[Param]],
) -> "Dependent[R]":
allow_types = tuple(allow_types)
params = cls.parse_params(call, allow_types)
parameterless_params = (
tuple()
if parameterless is None
else cls.parse_parameterless(tuple(parameterless), allow_types)
)
logger.trace(
f"Parsed dependent with call={call}, "
f"params={[param.field_info for param in dependent.params]}, "
f"parameterless={dependent.parameterless}"
f"params={params}, "
f"parameterless={parameterless_params}"
)
return dependent
return cls(call, params, parameterless_params)
async def solve(
self,
**params: Any,
) -> Dict[str, Any]:
values: Dict[str, Any] = {}
async def check(self, **params: Any) -> None:
await asyncio.gather(*(param._check(**params) for param in self.parameterless))
await asyncio.gather(
*(cast(Param, param.field_info)._check(**params) for param in self.params)
)
for checker in self.pre_checkers:
await checker._solve(**params)
async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any:
value = await cast(Param, field.field_info)._solve(**params)
if value is Undefined:
value = field.get_default()
return check_field_type(field, value)
async def solve(self, **params: Any) -> Dict[str, Any]:
# solve parameterless
for param in self.parameterless:
await param._solve(**params)
for field in self.params:
field_info = field.field_info
assert isinstance(field_info, Param), "Params must be subclasses of Param"
value = await field_info._solve(**params)
if value is Undefined:
value = field.get_default()
try:
values[field.name] = check_field_type(field, value)
except TypeMisMatch:
logger.debug(
f"{field_info} "
f"type {type(value)} not match depends {self.call} "
f"annotation {field._type_display()}, ignored"
)
raise
return values
# solve param values
values = await asyncio.gather(
*(self._solve_field(field, params) for field in self.params)
)
return {field.name: value for field, value in zip(self.params, values)}
__autodoc__ = {"CustomConfig": False}