mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-26 12:36:40 +00:00 
			
		
		
		
	✨ Feature: 支持子依赖定义 Pydantic 类型校验 (#2310)
This commit is contained in:
		| @@ -1,11 +1,21 @@ | ||||
| import asyncio | ||||
| import inspect | ||||
| from typing_extensions import Annotated | ||||
| from typing_extensions import Self, Annotated, override | ||||
| from contextlib import AsyncExitStack, contextmanager, asynccontextmanager | ||||
| from typing import TYPE_CHECKING, Any, Type, Tuple, Literal, Callable, Optional, cast | ||||
| from typing import ( | ||||
|     TYPE_CHECKING, | ||||
|     Any, | ||||
|     Type, | ||||
|     Tuple, | ||||
|     Union, | ||||
|     Literal, | ||||
|     Callable, | ||||
|     Optional, | ||||
|     cast, | ||||
| ) | ||||
|  | ||||
| from pydantic.typing import get_args, get_origin | ||||
| from pydantic.fields import Required, Undefined, ModelField | ||||
| from pydantic.fields import Required, FieldInfo, Undefined, ModelField | ||||
|  | ||||
| from nonebot.dependencies.utils import check_field_type | ||||
| from nonebot.dependencies import Param, Dependent, CustomConfig | ||||
| @@ -24,6 +34,23 @@ if TYPE_CHECKING: | ||||
|     from nonebot.matcher import Matcher | ||||
|     from nonebot.adapters import Bot, Event | ||||
|  | ||||
| EXTRA_FIELD_INFO = ( | ||||
|     "gt", | ||||
|     "lt", | ||||
|     "ge", | ||||
|     "le", | ||||
|     "multiple_of", | ||||
|     "allow_inf_nan", | ||||
|     "max_digits", | ||||
|     "decimal_places", | ||||
|     "min_items", | ||||
|     "max_items", | ||||
|     "unique_items", | ||||
|     "min_length", | ||||
|     "max_length", | ||||
|     "regex", | ||||
| ) | ||||
|  | ||||
|  | ||||
| class DependsInner: | ||||
|     def __init__( | ||||
| @@ -31,26 +58,31 @@ class DependsInner: | ||||
|         dependency: Optional[T_Handler] = None, | ||||
|         *, | ||||
|         use_cache: bool = True, | ||||
|         validate: Union[bool, FieldInfo] = False, | ||||
|     ) -> None: | ||||
|         self.dependency = dependency | ||||
|         self.use_cache = use_cache | ||||
|         self.validate = validate | ||||
|  | ||||
|     def __repr__(self) -> str: | ||||
|         dep = get_name(self.dependency) | ||||
|         cache = "" if self.use_cache else ", use_cache=False" | ||||
|         return f"DependsInner({dep}{cache})" | ||||
|         validate = f", validate={self.validate}" if self.validate else "" | ||||
|         return f"DependsInner({dep}{cache}{validate})" | ||||
|  | ||||
|  | ||||
| def Depends( | ||||
|     dependency: Optional[T_Handler] = None, | ||||
|     *, | ||||
|     use_cache: bool = True, | ||||
|     validate: Union[bool, FieldInfo] = False, | ||||
| ) -> Any: | ||||
|     """子依赖装饰器 | ||||
|  | ||||
|     参数: | ||||
|         dependency: 依赖函数。默认为参数的类型注释。 | ||||
|         use_cache: 是否使用缓存。默认为 `True`。 | ||||
|         validate: 是否使用 Pydantic 类型校验。默认为 `False`。 | ||||
|  | ||||
|     用法: | ||||
|         ```python | ||||
| @@ -70,7 +102,7 @@ def Depends( | ||||
|             ... | ||||
|         ``` | ||||
|     """ | ||||
|     return DependsInner(dependency, use_cache=use_cache) | ||||
|     return DependsInner(dependency, use_cache=use_cache, validate=validate) | ||||
|  | ||||
|  | ||||
| class DependParam(Param): | ||||
| @@ -85,23 +117,44 @@ class DependParam(Param): | ||||
|         return f"Depends({self.extra['dependent']})" | ||||
|  | ||||
|     @classmethod | ||||
|     def _from_field( | ||||
|         cls, sub_dependent: Dependent, use_cache: bool, validate: Union[bool, FieldInfo] | ||||
|     ) -> Self: | ||||
|         kwargs = {} | ||||
|         if isinstance(validate, FieldInfo): | ||||
|             kwargs.update((k, getattr(validate, k)) for k in EXTRA_FIELD_INFO) | ||||
|  | ||||
|         return cls( | ||||
|             Required, | ||||
|             validate=bool(validate), | ||||
|             **kwargs, | ||||
|             dependent=sub_dependent, | ||||
|             use_cache=use_cache, | ||||
|         ) | ||||
|  | ||||
|     @classmethod | ||||
|     @override | ||||
|     def _check_param( | ||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] | ||||
|     ) -> Optional["DependParam"]: | ||||
|     ) -> Optional[Self]: | ||||
|         type_annotation, depends_inner = param.annotation, None | ||||
|         # extract type annotation and dependency from Annotated | ||||
|         if get_origin(param.annotation) is Annotated: | ||||
|             type_annotation, *extra_args = get_args(param.annotation) | ||||
|             depends_inner = next( | ||||
|                 (x for x in extra_args if isinstance(x, DependsInner)), None | ||||
|             ) | ||||
|  | ||||
|         # param default value takes higher priority | ||||
|         depends_inner = ( | ||||
|             param.default if isinstance(param.default, DependsInner) else depends_inner | ||||
|         ) | ||||
|         # not a dependent | ||||
|         if depends_inner is None: | ||||
|             return | ||||
|  | ||||
|         dependency: T_Handler | ||||
|         # sub dependency is not specified, use type annotation | ||||
|         if depends_inner.dependency is None: | ||||
|             assert ( | ||||
|                 type_annotation is not inspect.Signature.empty | ||||
| @@ -109,13 +162,18 @@ class DependParam(Param): | ||||
|             dependency = type_annotation | ||||
|         else: | ||||
|             dependency = depends_inner.dependency | ||||
|         # parse sub dependency | ||||
|         sub_dependent = Dependent[Any].parse( | ||||
|             call=dependency, | ||||
|             allow_types=allow_types, | ||||
|         ) | ||||
|         return cls(Required, use_cache=depends_inner.use_cache, dependent=sub_dependent) | ||||
|  | ||||
|         return cls._from_field( | ||||
|             sub_dependent, depends_inner.use_cache, depends_inner.validate | ||||
|         ) | ||||
|  | ||||
|     @classmethod | ||||
|     @override | ||||
|     def _check_parameterless( | ||||
|         cls, value: Any, allow_types: Tuple[Type[Param], ...] | ||||
|     ) -> Optional["Param"]: | ||||
| @@ -124,8 +182,9 @@ class DependParam(Param): | ||||
|             dependent = Dependent[Any].parse( | ||||
|                 call=value.dependency, allow_types=allow_types | ||||
|             ) | ||||
|             return cls(Required, use_cache=value.use_cache, dependent=dependent) | ||||
|             return cls._from_field(dependent, value.use_cache, value.validate) | ||||
|  | ||||
|     @override | ||||
|     async def _solve( | ||||
|         self, | ||||
|         stack: Optional[AsyncExitStack] = None, | ||||
| @@ -169,6 +228,7 @@ class DependParam(Param): | ||||
|             dependency_cache[call] = task | ||||
|             return await task | ||||
|  | ||||
|     @override | ||||
|     async def _check(self, **kwargs: Any) -> None: | ||||
|         # run sub dependent pre-checkers | ||||
|         sub_dependent: Dependent = self.extra["dependent"] | ||||
| @@ -195,9 +255,10 @@ class BotParam(Param): | ||||
|         ) | ||||
|  | ||||
|     @classmethod | ||||
|     @override | ||||
|     def _check_param( | ||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] | ||||
|     ) -> Optional["BotParam"]: | ||||
|     ) -> Optional[Self]: | ||||
|         from nonebot.adapters import Bot | ||||
|  | ||||
|         # param type is Bot(s) or subclass(es) of Bot or None | ||||
| @@ -217,9 +278,11 @@ class BotParam(Param): | ||||
|         elif param.annotation == param.empty and param.name == "bot": | ||||
|             return cls(Required) | ||||
|  | ||||
|     @override | ||||
|     async def _solve(self, bot: "Bot", **kwargs: Any) -> Any: | ||||
|         return bot | ||||
|  | ||||
|     @override | ||||
|     async def _check(self, bot: "Bot", **kwargs: Any) -> None: | ||||
|         if checker := self.extra.get("checker"): | ||||
|             check_field_type(checker, bot) | ||||
| @@ -245,9 +308,10 @@ class EventParam(Param): | ||||
|         ) | ||||
|  | ||||
|     @classmethod | ||||
|     @override | ||||
|     def _check_param( | ||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] | ||||
|     ) -> Optional["EventParam"]: | ||||
|     ) -> Optional[Self]: | ||||
|         from nonebot.adapters import Event | ||||
|  | ||||
|         # param type is Event(s) or subclass(es) of Event or None | ||||
| @@ -267,9 +331,11 @@ class EventParam(Param): | ||||
|         elif param.annotation == param.empty and param.name == "event": | ||||
|             return cls(Required) | ||||
|  | ||||
|     @override | ||||
|     async def _solve(self, event: "Event", **kwargs: Any) -> Any: | ||||
|         return event | ||||
|  | ||||
|     @override | ||||
|     async def _check(self, event: "Event", **kwargs: Any) -> Any: | ||||
|         if checker := self.extra.get("checker", None): | ||||
|             check_field_type(checker, event) | ||||
| @@ -287,9 +353,10 @@ class StateParam(Param): | ||||
|         return "StateParam()" | ||||
|  | ||||
|     @classmethod | ||||
|     @override | ||||
|     def _check_param( | ||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] | ||||
|     ) -> Optional["StateParam"]: | ||||
|     ) -> Optional[Self]: | ||||
|         # param type is T_State | ||||
|         if param.annotation is T_State: | ||||
|             return cls(Required) | ||||
| @@ -297,6 +364,7 @@ class StateParam(Param): | ||||
|         elif param.annotation == param.empty and param.name == "state": | ||||
|             return cls(Required) | ||||
|  | ||||
|     @override | ||||
|     async def _solve(self, state: T_State, **kwargs: Any) -> Any: | ||||
|         return state | ||||
|  | ||||
| @@ -313,9 +381,10 @@ class MatcherParam(Param): | ||||
|         return "MatcherParam()" | ||||
|  | ||||
|     @classmethod | ||||
|     @override | ||||
|     def _check_param( | ||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] | ||||
|     ) -> Optional["MatcherParam"]: | ||||
|     ) -> Optional[Self]: | ||||
|         from nonebot.matcher import Matcher | ||||
|  | ||||
|         # param type is Matcher(s) or subclass(es) of Matcher or None | ||||
| @@ -335,9 +404,11 @@ class MatcherParam(Param): | ||||
|         elif param.annotation == param.empty and param.name == "matcher": | ||||
|             return cls(Required) | ||||
|  | ||||
|     @override | ||||
|     async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: | ||||
|         return matcher | ||||
|  | ||||
|     @override | ||||
|     async def _check(self, matcher: "Matcher", **kwargs: Any) -> Any: | ||||
|         if checker := self.extra.get("checker", None): | ||||
|             check_field_type(checker, matcher) | ||||
| @@ -382,9 +453,10 @@ class ArgParam(Param): | ||||
|         return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})" | ||||
|  | ||||
|     @classmethod | ||||
|     @override | ||||
|     def _check_param( | ||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] | ||||
|     ) -> Optional["ArgParam"]: | ||||
|     ) -> Optional[Self]: | ||||
|         if isinstance(param.default, ArgInner): | ||||
|             return cls( | ||||
|                 Required, key=param.default.key or param.name, type=param.default.type | ||||
| @@ -419,9 +491,10 @@ class ExceptionParam(Param): | ||||
|         return "ExceptionParam()" | ||||
|  | ||||
|     @classmethod | ||||
|     @override | ||||
|     def _check_param( | ||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] | ||||
|     ) -> Optional["ExceptionParam"]: | ||||
|     ) -> Optional[Self]: | ||||
|         # param type is Exception(s) or subclass(es) of Exception or None | ||||
|         if generic_check_issubclass(param.annotation, Exception): | ||||
|             return cls(Required) | ||||
| @@ -429,6 +502,7 @@ class ExceptionParam(Param): | ||||
|         elif param.annotation == param.empty and param.name == "exception": | ||||
|             return cls(Required) | ||||
|  | ||||
|     @override | ||||
|     async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any: | ||||
|         return exception | ||||
|  | ||||
| @@ -445,12 +519,14 @@ class DefaultParam(Param): | ||||
|         return f"DefaultParam(default={self.default!r})" | ||||
|  | ||||
|     @classmethod | ||||
|     @override | ||||
|     def _check_param( | ||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] | ||||
|     ) -> Optional["DefaultParam"]: | ||||
|     ) -> Optional[Self]: | ||||
|         if param.default != param.empty: | ||||
|             return cls(param.default) | ||||
|  | ||||
|     @override | ||||
|     async def _solve(self, **kwargs: Any) -> Any: | ||||
|         return Undefined | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user