mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-31 15:06:42 +00:00 
			
		
		
		
	✨ Feature: 支持子依赖定义 Pydantic 类型校验 (#2310)
This commit is contained in:
		| @@ -45,6 +45,10 @@ class Param(abc.ABC, FieldInfo): | |||||||
|     继承自 `pydantic.fields.FieldInfo`,用于描述参数信息(不包括参数名)。 |     继承自 `pydantic.fields.FieldInfo`,用于描述参数信息(不包括参数名)。 | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, *args, validate: bool = False, **kwargs: Any) -> None: | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  |         self.validate = validate | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _check_param( |     def _check_param( | ||||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type["Param"], ...] |         cls, param: inspect.Parameter, allow_types: Tuple[Type["Param"], ...] | ||||||
| @@ -206,10 +210,12 @@ class Dependent(Generic[R]): | |||||||
|             raise |             raise | ||||||
|  |  | ||||||
|     async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any: |     async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any: | ||||||
|         value = await cast(Param, field.field_info)._solve(**params) |         param = cast(Param, field.field_info) | ||||||
|  |         value = await param._solve(**params) | ||||||
|         if value is Undefined: |         if value is Undefined: | ||||||
|             value = field.get_default() |             value = field.get_default() | ||||||
|         return check_field_type(field, value) |         v = check_field_type(field, value) | ||||||
|  |         return v if param.validate else value | ||||||
|  |  | ||||||
|     async def solve(self, **params: Any) -> Dict[str, Any]: |     async def solve(self, **params: Any) -> Dict[str, Any]: | ||||||
|         # solve parameterless |         # solve parameterless | ||||||
|   | |||||||
| @@ -5,7 +5,7 @@ FrontMatter: | |||||||
| """ | """ | ||||||
|  |  | ||||||
| import inspect | import inspect | ||||||
| from typing import Any, Dict, TypeVar, Callable, ForwardRef | from typing import Any, Dict, Callable, ForwardRef | ||||||
|  |  | ||||||
| from loguru import logger | from loguru import logger | ||||||
| from pydantic.fields import ModelField | from pydantic.fields import ModelField | ||||||
| @@ -13,8 +13,6 @@ from pydantic.typing import evaluate_forwardref | |||||||
|  |  | ||||||
| from nonebot.exception import TypeMisMatch | from nonebot.exception import TypeMisMatch | ||||||
|  |  | ||||||
| V = TypeVar("V") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: | def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: | ||||||
|     """获取可调用对象签名""" |     """获取可调用对象签名""" | ||||||
| @@ -49,10 +47,10 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> | |||||||
|     return annotation |     return annotation | ||||||
|  |  | ||||||
|  |  | ||||||
| def check_field_type(field: ModelField, value: V) -> V: | def check_field_type(field: ModelField, value: Any) -> Any: | ||||||
|     """检查字段类型是否匹配""" |     """检查字段类型是否匹配""" | ||||||
|  |  | ||||||
|     _, errs_ = field.validate(value, {}, loc=()) |     v, errs_ = field.validate(value, {}, loc=()) | ||||||
|     if errs_: |     if errs_: | ||||||
|         raise TypeMisMatch(field, value) |         raise TypeMisMatch(field, value) | ||||||
|     return value |     return v | ||||||
|   | |||||||
| @@ -1,11 +1,21 @@ | |||||||
| import asyncio | import asyncio | ||||||
| import inspect | import inspect | ||||||
| from typing_extensions import Annotated | from typing_extensions import Self, Annotated, override | ||||||
| from contextlib import AsyncExitStack, contextmanager, asynccontextmanager | from contextlib import AsyncExitStack, contextmanager, asynccontextmanager | ||||||
| from typing import TYPE_CHECKING, Any, Type, Tuple, Literal, Callable, Optional, cast | from typing import ( | ||||||
|  |     TYPE_CHECKING, | ||||||
|  |     Any, | ||||||
|  |     Type, | ||||||
|  |     Tuple, | ||||||
|  |     Union, | ||||||
|  |     Literal, | ||||||
|  |     Callable, | ||||||
|  |     Optional, | ||||||
|  |     cast, | ||||||
|  | ) | ||||||
|  |  | ||||||
| from pydantic.typing import get_args, get_origin | from pydantic.typing import get_args, get_origin | ||||||
| from pydantic.fields import Required, Undefined, ModelField | from pydantic.fields import Required, FieldInfo, Undefined, ModelField | ||||||
|  |  | ||||||
| from nonebot.dependencies.utils import check_field_type | from nonebot.dependencies.utils import check_field_type | ||||||
| from nonebot.dependencies import Param, Dependent, CustomConfig | from nonebot.dependencies import Param, Dependent, CustomConfig | ||||||
| @@ -24,6 +34,23 @@ if TYPE_CHECKING: | |||||||
|     from nonebot.matcher import Matcher |     from nonebot.matcher import Matcher | ||||||
|     from nonebot.adapters import Bot, Event |     from nonebot.adapters import Bot, Event | ||||||
|  |  | ||||||
|  | EXTRA_FIELD_INFO = ( | ||||||
|  |     "gt", | ||||||
|  |     "lt", | ||||||
|  |     "ge", | ||||||
|  |     "le", | ||||||
|  |     "multiple_of", | ||||||
|  |     "allow_inf_nan", | ||||||
|  |     "max_digits", | ||||||
|  |     "decimal_places", | ||||||
|  |     "min_items", | ||||||
|  |     "max_items", | ||||||
|  |     "unique_items", | ||||||
|  |     "min_length", | ||||||
|  |     "max_length", | ||||||
|  |     "regex", | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class DependsInner: | class DependsInner: | ||||||
|     def __init__( |     def __init__( | ||||||
| @@ -31,26 +58,31 @@ class DependsInner: | |||||||
|         dependency: Optional[T_Handler] = None, |         dependency: Optional[T_Handler] = None, | ||||||
|         *, |         *, | ||||||
|         use_cache: bool = True, |         use_cache: bool = True, | ||||||
|  |         validate: Union[bool, FieldInfo] = False, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         self.dependency = dependency |         self.dependency = dependency | ||||||
|         self.use_cache = use_cache |         self.use_cache = use_cache | ||||||
|  |         self.validate = validate | ||||||
|  |  | ||||||
|     def __repr__(self) -> str: |     def __repr__(self) -> str: | ||||||
|         dep = get_name(self.dependency) |         dep = get_name(self.dependency) | ||||||
|         cache = "" if self.use_cache else ", use_cache=False" |         cache = "" if self.use_cache else ", use_cache=False" | ||||||
|         return f"DependsInner({dep}{cache})" |         validate = f", validate={self.validate}" if self.validate else "" | ||||||
|  |         return f"DependsInner({dep}{cache}{validate})" | ||||||
|  |  | ||||||
|  |  | ||||||
| def Depends( | def Depends( | ||||||
|     dependency: Optional[T_Handler] = None, |     dependency: Optional[T_Handler] = None, | ||||||
|     *, |     *, | ||||||
|     use_cache: bool = True, |     use_cache: bool = True, | ||||||
|  |     validate: Union[bool, FieldInfo] = False, | ||||||
| ) -> Any: | ) -> Any: | ||||||
|     """子依赖装饰器 |     """子依赖装饰器 | ||||||
|  |  | ||||||
|     参数: |     参数: | ||||||
|         dependency: 依赖函数。默认为参数的类型注释。 |         dependency: 依赖函数。默认为参数的类型注释。 | ||||||
|         use_cache: 是否使用缓存。默认为 `True`。 |         use_cache: 是否使用缓存。默认为 `True`。 | ||||||
|  |         validate: 是否使用 Pydantic 类型校验。默认为 `False`。 | ||||||
|  |  | ||||||
|     用法: |     用法: | ||||||
|         ```python |         ```python | ||||||
| @@ -70,7 +102,7 @@ def Depends( | |||||||
|             ... |             ... | ||||||
|         ``` |         ``` | ||||||
|     """ |     """ | ||||||
|     return DependsInner(dependency, use_cache=use_cache) |     return DependsInner(dependency, use_cache=use_cache, validate=validate) | ||||||
|  |  | ||||||
|  |  | ||||||
| class DependParam(Param): | class DependParam(Param): | ||||||
| @@ -85,23 +117,44 @@ class DependParam(Param): | |||||||
|         return f"Depends({self.extra['dependent']})" |         return f"Depends({self.extra['dependent']})" | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|  |     def _from_field( | ||||||
|  |         cls, sub_dependent: Dependent, use_cache: bool, validate: Union[bool, FieldInfo] | ||||||
|  |     ) -> Self: | ||||||
|  |         kwargs = {} | ||||||
|  |         if isinstance(validate, FieldInfo): | ||||||
|  |             kwargs.update((k, getattr(validate, k)) for k in EXTRA_FIELD_INFO) | ||||||
|  |  | ||||||
|  |         return cls( | ||||||
|  |             Required, | ||||||
|  |             validate=bool(validate), | ||||||
|  |             **kwargs, | ||||||
|  |             dependent=sub_dependent, | ||||||
|  |             use_cache=use_cache, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     @override | ||||||
|     def _check_param( |     def _check_param( | ||||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] |         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] | ||||||
|     ) -> Optional["DependParam"]: |     ) -> Optional[Self]: | ||||||
|         type_annotation, depends_inner = param.annotation, None |         type_annotation, depends_inner = param.annotation, None | ||||||
|  |         # extract type annotation and dependency from Annotated | ||||||
|         if get_origin(param.annotation) is Annotated: |         if get_origin(param.annotation) is Annotated: | ||||||
|             type_annotation, *extra_args = get_args(param.annotation) |             type_annotation, *extra_args = get_args(param.annotation) | ||||||
|             depends_inner = next( |             depends_inner = next( | ||||||
|                 (x for x in extra_args if isinstance(x, DependsInner)), None |                 (x for x in extra_args if isinstance(x, DependsInner)), None | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  |         # param default value takes higher priority | ||||||
|         depends_inner = ( |         depends_inner = ( | ||||||
|             param.default if isinstance(param.default, DependsInner) else depends_inner |             param.default if isinstance(param.default, DependsInner) else depends_inner | ||||||
|         ) |         ) | ||||||
|  |         # not a dependent | ||||||
|         if depends_inner is None: |         if depends_inner is None: | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         dependency: T_Handler |         dependency: T_Handler | ||||||
|  |         # sub dependency is not specified, use type annotation | ||||||
|         if depends_inner.dependency is None: |         if depends_inner.dependency is None: | ||||||
|             assert ( |             assert ( | ||||||
|                 type_annotation is not inspect.Signature.empty |                 type_annotation is not inspect.Signature.empty | ||||||
| @@ -109,13 +162,18 @@ class DependParam(Param): | |||||||
|             dependency = type_annotation |             dependency = type_annotation | ||||||
|         else: |         else: | ||||||
|             dependency = depends_inner.dependency |             dependency = depends_inner.dependency | ||||||
|  |         # parse sub dependency | ||||||
|         sub_dependent = Dependent[Any].parse( |         sub_dependent = Dependent[Any].parse( | ||||||
|             call=dependency, |             call=dependency, | ||||||
|             allow_types=allow_types, |             allow_types=allow_types, | ||||||
|         ) |         ) | ||||||
|         return cls(Required, use_cache=depends_inner.use_cache, dependent=sub_dependent) |  | ||||||
|  |         return cls._from_field( | ||||||
|  |             sub_dependent, depends_inner.use_cache, depends_inner.validate | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|  |     @override | ||||||
|     def _check_parameterless( |     def _check_parameterless( | ||||||
|         cls, value: Any, allow_types: Tuple[Type[Param], ...] |         cls, value: Any, allow_types: Tuple[Type[Param], ...] | ||||||
|     ) -> Optional["Param"]: |     ) -> Optional["Param"]: | ||||||
| @@ -124,8 +182,9 @@ class DependParam(Param): | |||||||
|             dependent = Dependent[Any].parse( |             dependent = Dependent[Any].parse( | ||||||
|                 call=value.dependency, allow_types=allow_types |                 call=value.dependency, allow_types=allow_types | ||||||
|             ) |             ) | ||||||
|             return cls(Required, use_cache=value.use_cache, dependent=dependent) |             return cls._from_field(dependent, value.use_cache, value.validate) | ||||||
|  |  | ||||||
|  |     @override | ||||||
|     async def _solve( |     async def _solve( | ||||||
|         self, |         self, | ||||||
|         stack: Optional[AsyncExitStack] = None, |         stack: Optional[AsyncExitStack] = None, | ||||||
| @@ -169,6 +228,7 @@ class DependParam(Param): | |||||||
|             dependency_cache[call] = task |             dependency_cache[call] = task | ||||||
|             return await task |             return await task | ||||||
|  |  | ||||||
|  |     @override | ||||||
|     async def _check(self, **kwargs: Any) -> None: |     async def _check(self, **kwargs: Any) -> None: | ||||||
|         # run sub dependent pre-checkers |         # run sub dependent pre-checkers | ||||||
|         sub_dependent: Dependent = self.extra["dependent"] |         sub_dependent: Dependent = self.extra["dependent"] | ||||||
| @@ -195,9 +255,10 @@ class BotParam(Param): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|  |     @override | ||||||
|     def _check_param( |     def _check_param( | ||||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] |         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] | ||||||
|     ) -> Optional["BotParam"]: |     ) -> Optional[Self]: | ||||||
|         from nonebot.adapters import Bot |         from nonebot.adapters import Bot | ||||||
|  |  | ||||||
|         # param type is Bot(s) or subclass(es) of Bot or None |         # param type is Bot(s) or subclass(es) of Bot or None | ||||||
| @@ -217,9 +278,11 @@ class BotParam(Param): | |||||||
|         elif param.annotation == param.empty and param.name == "bot": |         elif param.annotation == param.empty and param.name == "bot": | ||||||
|             return cls(Required) |             return cls(Required) | ||||||
|  |  | ||||||
|  |     @override | ||||||
|     async def _solve(self, bot: "Bot", **kwargs: Any) -> Any: |     async def _solve(self, bot: "Bot", **kwargs: Any) -> Any: | ||||||
|         return bot |         return bot | ||||||
|  |  | ||||||
|  |     @override | ||||||
|     async def _check(self, bot: "Bot", **kwargs: Any) -> None: |     async def _check(self, bot: "Bot", **kwargs: Any) -> None: | ||||||
|         if checker := self.extra.get("checker"): |         if checker := self.extra.get("checker"): | ||||||
|             check_field_type(checker, bot) |             check_field_type(checker, bot) | ||||||
| @@ -245,9 +308,10 @@ class EventParam(Param): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|  |     @override | ||||||
|     def _check_param( |     def _check_param( | ||||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] |         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] | ||||||
|     ) -> Optional["EventParam"]: |     ) -> Optional[Self]: | ||||||
|         from nonebot.adapters import Event |         from nonebot.adapters import Event | ||||||
|  |  | ||||||
|         # param type is Event(s) or subclass(es) of Event or None |         # param type is Event(s) or subclass(es) of Event or None | ||||||
| @@ -267,9 +331,11 @@ class EventParam(Param): | |||||||
|         elif param.annotation == param.empty and param.name == "event": |         elif param.annotation == param.empty and param.name == "event": | ||||||
|             return cls(Required) |             return cls(Required) | ||||||
|  |  | ||||||
|  |     @override | ||||||
|     async def _solve(self, event: "Event", **kwargs: Any) -> Any: |     async def _solve(self, event: "Event", **kwargs: Any) -> Any: | ||||||
|         return event |         return event | ||||||
|  |  | ||||||
|  |     @override | ||||||
|     async def _check(self, event: "Event", **kwargs: Any) -> Any: |     async def _check(self, event: "Event", **kwargs: Any) -> Any: | ||||||
|         if checker := self.extra.get("checker", None): |         if checker := self.extra.get("checker", None): | ||||||
|             check_field_type(checker, event) |             check_field_type(checker, event) | ||||||
| @@ -287,9 +353,10 @@ class StateParam(Param): | |||||||
|         return "StateParam()" |         return "StateParam()" | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|  |     @override | ||||||
|     def _check_param( |     def _check_param( | ||||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] |         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] | ||||||
|     ) -> Optional["StateParam"]: |     ) -> Optional[Self]: | ||||||
|         # param type is T_State |         # param type is T_State | ||||||
|         if param.annotation is T_State: |         if param.annotation is T_State: | ||||||
|             return cls(Required) |             return cls(Required) | ||||||
| @@ -297,6 +364,7 @@ class StateParam(Param): | |||||||
|         elif param.annotation == param.empty and param.name == "state": |         elif param.annotation == param.empty and param.name == "state": | ||||||
|             return cls(Required) |             return cls(Required) | ||||||
|  |  | ||||||
|  |     @override | ||||||
|     async def _solve(self, state: T_State, **kwargs: Any) -> Any: |     async def _solve(self, state: T_State, **kwargs: Any) -> Any: | ||||||
|         return state |         return state | ||||||
|  |  | ||||||
| @@ -313,9 +381,10 @@ class MatcherParam(Param): | |||||||
|         return "MatcherParam()" |         return "MatcherParam()" | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|  |     @override | ||||||
|     def _check_param( |     def _check_param( | ||||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] |         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] | ||||||
|     ) -> Optional["MatcherParam"]: |     ) -> Optional[Self]: | ||||||
|         from nonebot.matcher import Matcher |         from nonebot.matcher import Matcher | ||||||
|  |  | ||||||
|         # param type is Matcher(s) or subclass(es) of Matcher or None |         # param type is Matcher(s) or subclass(es) of Matcher or None | ||||||
| @@ -335,9 +404,11 @@ class MatcherParam(Param): | |||||||
|         elif param.annotation == param.empty and param.name == "matcher": |         elif param.annotation == param.empty and param.name == "matcher": | ||||||
|             return cls(Required) |             return cls(Required) | ||||||
|  |  | ||||||
|  |     @override | ||||||
|     async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: |     async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: | ||||||
|         return matcher |         return matcher | ||||||
|  |  | ||||||
|  |     @override | ||||||
|     async def _check(self, matcher: "Matcher", **kwargs: Any) -> Any: |     async def _check(self, matcher: "Matcher", **kwargs: Any) -> Any: | ||||||
|         if checker := self.extra.get("checker", None): |         if checker := self.extra.get("checker", None): | ||||||
|             check_field_type(checker, matcher) |             check_field_type(checker, matcher) | ||||||
| @@ -382,9 +453,10 @@ class ArgParam(Param): | |||||||
|         return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})" |         return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})" | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|  |     @override | ||||||
|     def _check_param( |     def _check_param( | ||||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] |         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] | ||||||
|     ) -> Optional["ArgParam"]: |     ) -> Optional[Self]: | ||||||
|         if isinstance(param.default, ArgInner): |         if isinstance(param.default, ArgInner): | ||||||
|             return cls( |             return cls( | ||||||
|                 Required, key=param.default.key or param.name, type=param.default.type |                 Required, key=param.default.key or param.name, type=param.default.type | ||||||
| @@ -419,9 +491,10 @@ class ExceptionParam(Param): | |||||||
|         return "ExceptionParam()" |         return "ExceptionParam()" | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|  |     @override | ||||||
|     def _check_param( |     def _check_param( | ||||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] |         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] | ||||||
|     ) -> Optional["ExceptionParam"]: |     ) -> Optional[Self]: | ||||||
|         # param type is Exception(s) or subclass(es) of Exception or None |         # param type is Exception(s) or subclass(es) of Exception or None | ||||||
|         if generic_check_issubclass(param.annotation, Exception): |         if generic_check_issubclass(param.annotation, Exception): | ||||||
|             return cls(Required) |             return cls(Required) | ||||||
| @@ -429,6 +502,7 @@ class ExceptionParam(Param): | |||||||
|         elif param.annotation == param.empty and param.name == "exception": |         elif param.annotation == param.empty and param.name == "exception": | ||||||
|             return cls(Required) |             return cls(Required) | ||||||
|  |  | ||||||
|  |     @override | ||||||
|     async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any: |     async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any: | ||||||
|         return exception |         return exception | ||||||
|  |  | ||||||
| @@ -445,12 +519,14 @@ class DefaultParam(Param): | |||||||
|         return f"DefaultParam(default={self.default!r})" |         return f"DefaultParam(default={self.default!r})" | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|  |     @override | ||||||
|     def _check_param( |     def _check_param( | ||||||
|         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] |         cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] | ||||||
|     ) -> Optional["DefaultParam"]: |     ) -> Optional[Self]: | ||||||
|         if param.default != param.empty: |         if param.default != param.empty: | ||||||
|             return cls(param.default) |             return cls(param.default) | ||||||
|  |  | ||||||
|  |     @override | ||||||
|     async def _solve(self, **kwargs: Any) -> Any: |     async def _solve(self, **kwargs: Any) -> Any: | ||||||
|         return Undefined |         return Undefined | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,7 +1,10 @@ | |||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from typing_extensions import Annotated | from typing_extensions import Annotated | ||||||
|  |  | ||||||
|  | from pydantic import Field | ||||||
|  |  | ||||||
| from nonebot import on_message | from nonebot import on_message | ||||||
|  | from nonebot.adapters import Bot | ||||||
| from nonebot.params import Depends | from nonebot.params import Depends | ||||||
|  |  | ||||||
| test_depends = on_message() | test_depends = on_message() | ||||||
| @@ -33,6 +36,14 @@ class ClassDependency: | |||||||
|     y: int = Depends(gen_async) |     y: int = Depends(gen_async) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FooBot(Bot): | ||||||
|  |     ... | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def sub_bot(b: FooBot) -> FooBot: | ||||||
|  |     return b | ||||||
|  |  | ||||||
|  |  | ||||||
| # test parameterless | # test parameterless | ||||||
| @test_depends.handle(parameterless=[Depends(parameterless)]) | @test_depends.handle(parameterless=[Depends(parameterless)]) | ||||||
| async def depends(x: int = Depends(dependency)): | async def depends(x: int = Depends(dependency)): | ||||||
| @@ -46,19 +57,46 @@ async def depends_cache(y: int = Depends(dependency, use_cache=True)): | |||||||
|     return y |     return y | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # test class dependency | ||||||
| async def class_depend(c: ClassDependency = Depends()): | async def class_depend(c: ClassDependency = Depends()): | ||||||
|     return c |     return c | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # test annotated dependency | ||||||
| async def annotated_depend(x: Annotated[int, Depends(dependency)]): | async def annotated_depend(x: Annotated[int, Depends(dependency)]): | ||||||
|     return x |     return x | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # test annotated class dependency | ||||||
| async def annotated_class_depend(c: Annotated[ClassDependency, Depends()]): | async def annotated_class_depend(c: Annotated[ClassDependency, Depends()]): | ||||||
|     return c |     return c | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # test dependency priority | ||||||
| async def annotated_prior_depend( | async def annotated_prior_depend( | ||||||
|     x: Annotated[int, Depends(lambda: 2)] = Depends(dependency) |     x: Annotated[int, Depends(lambda: 2)] = Depends(dependency) | ||||||
| ): | ): | ||||||
|     return x |     return x | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # test sub dependency type mismatch | ||||||
|  | async def sub_type_mismatch(b: FooBot = Depends(sub_bot)): | ||||||
|  |     return b | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # test type validate | ||||||
|  | async def validate(x: int = Depends(lambda: "1", validate=True)): | ||||||
|  |     return x | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def validate_fail(x: int = Depends(lambda: "not_number", validate=True)): | ||||||
|  |     return x | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # test FieldInfo validate | ||||||
|  | async def validate_field(x: int = Depends(lambda: "1", validate=Field(gt=0))): | ||||||
|  |     return x | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def validate_field_fail(x: int = Depends(lambda: "0", validate=Field(gt=0))): | ||||||
|  |     return x | ||||||
|   | |||||||
| @@ -42,9 +42,14 @@ async def test_depend(app: App): | |||||||
|         ClassDependency, |         ClassDependency, | ||||||
|         runned, |         runned, | ||||||
|         depends, |         depends, | ||||||
|  |         validate, | ||||||
|         class_depend, |         class_depend, | ||||||
|         test_depends, |         test_depends, | ||||||
|  |         validate_fail, | ||||||
|  |         validate_field, | ||||||
|         annotated_depend, |         annotated_depend, | ||||||
|  |         sub_type_mismatch, | ||||||
|  |         validate_field_fail, | ||||||
|         annotated_class_depend, |         annotated_class_depend, | ||||||
|         annotated_prior_depend, |         annotated_prior_depend, | ||||||
|     ) |     ) | ||||||
| @@ -62,8 +67,7 @@ async def test_depend(app: App): | |||||||
|         event_next = make_fake_event()() |         event_next = make_fake_event()() | ||||||
|         ctx.receive_event(bot, event_next) |         ctx.receive_event(bot, event_next) | ||||||
|  |  | ||||||
|     assert len(runned) == 2 |     assert runned == [1, 1] | ||||||
|     assert runned[0] == runned[1] == 1 |  | ||||||
|  |  | ||||||
|     runned.clear() |     runned.clear() | ||||||
|  |  | ||||||
| @@ -84,6 +88,29 @@ async def test_depend(app: App): | |||||||
|     ) as ctx: |     ) as ctx: | ||||||
|         ctx.should_return(ClassDependency(x=1, y=2)) |         ctx.should_return(ClassDependency(x=1, y=2)) | ||||||
|  |  | ||||||
|  |     with pytest.raises(TypeMisMatch):  # noqa: PT012 | ||||||
|  |         async with app.test_dependent( | ||||||
|  |             sub_type_mismatch, allow_types=[DependParam, BotParam] | ||||||
|  |         ) as ctx: | ||||||
|  |             bot = ctx.create_bot() | ||||||
|  |             ctx.pass_params(bot=bot) | ||||||
|  |  | ||||||
|  |     async with app.test_dependent(validate, allow_types=[DependParam]) as ctx: | ||||||
|  |         ctx.should_return(1) | ||||||
|  |  | ||||||
|  |     with pytest.raises(TypeMisMatch): | ||||||
|  |         async with app.test_dependent(validate_fail, allow_types=[DependParam]) as ctx: | ||||||
|  |             ... | ||||||
|  |  | ||||||
|  |     async with app.test_dependent(validate_field, allow_types=[DependParam]) as ctx: | ||||||
|  |         ctx.should_return(1) | ||||||
|  |  | ||||||
|  |     with pytest.raises(TypeMisMatch): | ||||||
|  |         async with app.test_dependent( | ||||||
|  |             validate_field_fail, allow_types=[DependParam] | ||||||
|  |         ) as ctx: | ||||||
|  |             ... | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.asyncio | @pytest.mark.asyncio | ||||||
| async def test_bot(app: App): | async def test_bot(app: App): | ||||||
|   | |||||||
| @@ -353,6 +353,80 @@ async def _(x: int = Depends(random_result, use_cache=False)): | |||||||
| 缓存的生命周期与当前接收到的事件相同。接收到事件后,子依赖在首次执行时缓存,在该事件处理完成后,缓存就会被清除。 | 缓存的生命周期与当前接收到的事件相同。接收到事件后,子依赖在首次执行时缓存,在该事件处理完成后,缓存就会被清除。 | ||||||
| ::: | ::: | ||||||
|  |  | ||||||
|  | ### 类型转换与校验 | ||||||
|  |  | ||||||
|  | 在依赖注入系统中,我们可以对子依赖的返回值进行自动类型转换与校验。这个功能由 Pydantic 支持,因此我们通过参数类型注解自动使用 Pydantic 支持的类型转换。例如: | ||||||
|  |  | ||||||
|  | <Tabs groupId="python"> | ||||||
|  |   <TabItem value="3.9" label="Python 3.9+" default> | ||||||
|  |  | ||||||
|  | ```python {6,9} | ||||||
|  | from typing import Annotated | ||||||
|  |  | ||||||
|  | from nonebot.params import Depends | ||||||
|  | from nonebot.adapters import Event | ||||||
|  |  | ||||||
|  | def get_user_id(event: Event) -> str: | ||||||
|  |     return event.get_user_id() | ||||||
|  |  | ||||||
|  | async def _(user_id: Annotated[int, Depends(get_user_id, validate=True)]): | ||||||
|  |     print(user_id) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  |   </TabItem> | ||||||
|  |   <TabItem value="3.8" label="Python 3.8+"> | ||||||
|  |  | ||||||
|  | ```python {4,7} | ||||||
|  | from nonebot.params import Depends | ||||||
|  | from nonebot.adapters import Event | ||||||
|  |  | ||||||
|  | def get_user_id(event: Event) -> str: | ||||||
|  |     return event.get_user_id() | ||||||
|  |  | ||||||
|  | async def _(user_id: int = Depends(get_user_id, validate=True)): | ||||||
|  |     print(user_id) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  |   </TabItem> | ||||||
|  | </Tabs> | ||||||
|  |  | ||||||
|  | 在进行类型自动转换的同时,Pydantic 还支持对数据进行更多的限制,如:大于、小于、长度等。使用方法如下: | ||||||
|  |  | ||||||
|  | <Tabs groupId="python"> | ||||||
|  |   <TabItem value="3.9" label="Python 3.9+" default> | ||||||
|  |  | ||||||
|  | ```python {7,10} | ||||||
|  | from typing import Annotated | ||||||
|  |  | ||||||
|  | from pydantic import Field | ||||||
|  | from nonebot.params import Depends | ||||||
|  | from nonebot.adapters import Event | ||||||
|  |  | ||||||
|  | def get_user_id(event: Event) -> str: | ||||||
|  |     return event.get_user_id() | ||||||
|  |  | ||||||
|  | async def _(user_id: Annotated[int, Depends(get_user_id, validate=Field(gt=100))]): | ||||||
|  |     print(user_id) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  |   </TabItem> | ||||||
|  |   <TabItem value="3.8" label="Python 3.8+"> | ||||||
|  |  | ||||||
|  | ```python {5,8} | ||||||
|  | from pydantic import Field | ||||||
|  | from nonebot.params import Depends | ||||||
|  | from nonebot.adapters import Event | ||||||
|  |  | ||||||
|  | def get_user_id(event: Event) -> str: | ||||||
|  |     return event.get_user_id() | ||||||
|  |  | ||||||
|  | async def _(user_id: int = Depends(get_user_id, validate=Field(gt=100))): | ||||||
|  |     print(user_id) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  |   </TabItem> | ||||||
|  | </Tabs> | ||||||
|  |  | ||||||
| ### 类作为依赖 | ### 类作为依赖 | ||||||
|  |  | ||||||
| 在前面的事例中,我们使用了函数作为子依赖。实际上,我们还可以使用类作为依赖。当我们在实例化一个类的时候,其实我们就在调用它,类本身也是一个可调用对象。例如: | 在前面的事例中,我们使用了函数作为子依赖。实际上,我们还可以使用类作为依赖。当我们在实例化一个类的时候,其实我们就在调用它,类本身也是一个可调用对象。例如: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user