mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-27 00:01:27 +00:00
✨ Feature: 支持子依赖定义 Pydantic 类型校验 (#2310)
This commit is contained in:
@ -1,11 +1,21 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing_extensions import Annotated
|
||||
from typing_extensions import Self, Annotated, override
|
||||
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, Type, Tuple, Literal, Callable, Optional, cast
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Type,
|
||||
Tuple,
|
||||
Union,
|
||||
Literal,
|
||||
Callable,
|
||||
Optional,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic.typing import get_args, get_origin
|
||||
from pydantic.fields import Required, Undefined, ModelField
|
||||
from pydantic.fields import Required, FieldInfo, Undefined, ModelField
|
||||
|
||||
from nonebot.dependencies.utils import check_field_type
|
||||
from nonebot.dependencies import Param, Dependent, CustomConfig
|
||||
@ -24,6 +34,23 @@ if TYPE_CHECKING:
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.adapters import Bot, Event
|
||||
|
||||
EXTRA_FIELD_INFO = (
|
||||
"gt",
|
||||
"lt",
|
||||
"ge",
|
||||
"le",
|
||||
"multiple_of",
|
||||
"allow_inf_nan",
|
||||
"max_digits",
|
||||
"decimal_places",
|
||||
"min_items",
|
||||
"max_items",
|
||||
"unique_items",
|
||||
"min_length",
|
||||
"max_length",
|
||||
"regex",
|
||||
)
|
||||
|
||||
|
||||
class DependsInner:
|
||||
def __init__(
|
||||
@ -31,26 +58,31 @@ class DependsInner:
|
||||
dependency: Optional[T_Handler] = None,
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
validate: Union[bool, FieldInfo] = False,
|
||||
) -> None:
|
||||
self.dependency = dependency
|
||||
self.use_cache = use_cache
|
||||
self.validate = validate
|
||||
|
||||
def __repr__(self) -> str:
|
||||
dep = get_name(self.dependency)
|
||||
cache = "" if self.use_cache else ", use_cache=False"
|
||||
return f"DependsInner({dep}{cache})"
|
||||
validate = f", validate={self.validate}" if self.validate else ""
|
||||
return f"DependsInner({dep}{cache}{validate})"
|
||||
|
||||
|
||||
def Depends(
|
||||
dependency: Optional[T_Handler] = None,
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
validate: Union[bool, FieldInfo] = False,
|
||||
) -> Any:
|
||||
"""子依赖装饰器
|
||||
|
||||
参数:
|
||||
dependency: 依赖函数。默认为参数的类型注释。
|
||||
use_cache: 是否使用缓存。默认为 `True`。
|
||||
validate: 是否使用 Pydantic 类型校验。默认为 `False`。
|
||||
|
||||
用法:
|
||||
```python
|
||||
@ -70,7 +102,7 @@ def Depends(
|
||||
...
|
||||
```
|
||||
"""
|
||||
return DependsInner(dependency, use_cache=use_cache)
|
||||
return DependsInner(dependency, use_cache=use_cache, validate=validate)
|
||||
|
||||
|
||||
class DependParam(Param):
|
||||
@ -85,23 +117,44 @@ class DependParam(Param):
|
||||
return f"Depends({self.extra['dependent']})"
|
||||
|
||||
@classmethod
|
||||
def _from_field(
|
||||
cls, sub_dependent: Dependent, use_cache: bool, validate: Union[bool, FieldInfo]
|
||||
) -> Self:
|
||||
kwargs = {}
|
||||
if isinstance(validate, FieldInfo):
|
||||
kwargs.update((k, getattr(validate, k)) for k in EXTRA_FIELD_INFO)
|
||||
|
||||
return cls(
|
||||
Required,
|
||||
validate=bool(validate),
|
||||
**kwargs,
|
||||
dependent=sub_dependent,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _check_param(
|
||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||
) -> Optional["DependParam"]:
|
||||
) -> Optional[Self]:
|
||||
type_annotation, depends_inner = param.annotation, None
|
||||
# extract type annotation and dependency from Annotated
|
||||
if get_origin(param.annotation) is Annotated:
|
||||
type_annotation, *extra_args = get_args(param.annotation)
|
||||
depends_inner = next(
|
||||
(x for x in extra_args if isinstance(x, DependsInner)), None
|
||||
)
|
||||
|
||||
# param default value takes higher priority
|
||||
depends_inner = (
|
||||
param.default if isinstance(param.default, DependsInner) else depends_inner
|
||||
)
|
||||
# not a dependent
|
||||
if depends_inner is None:
|
||||
return
|
||||
|
||||
dependency: T_Handler
|
||||
# sub dependency is not specified, use type annotation
|
||||
if depends_inner.dependency is None:
|
||||
assert (
|
||||
type_annotation is not inspect.Signature.empty
|
||||
@ -109,13 +162,18 @@ class DependParam(Param):
|
||||
dependency = type_annotation
|
||||
else:
|
||||
dependency = depends_inner.dependency
|
||||
# parse sub dependency
|
||||
sub_dependent = Dependent[Any].parse(
|
||||
call=dependency,
|
||||
allow_types=allow_types,
|
||||
)
|
||||
return cls(Required, use_cache=depends_inner.use_cache, dependent=sub_dependent)
|
||||
|
||||
return cls._from_field(
|
||||
sub_dependent, depends_inner.use_cache, depends_inner.validate
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _check_parameterless(
|
||||
cls, value: Any, allow_types: Tuple[Type[Param], ...]
|
||||
) -> Optional["Param"]:
|
||||
@ -124,8 +182,9 @@ class DependParam(Param):
|
||||
dependent = Dependent[Any].parse(
|
||||
call=value.dependency, allow_types=allow_types
|
||||
)
|
||||
return cls(Required, use_cache=value.use_cache, dependent=dependent)
|
||||
return cls._from_field(dependent, value.use_cache, value.validate)
|
||||
|
||||
@override
|
||||
async def _solve(
|
||||
self,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
@ -169,6 +228,7 @@ class DependParam(Param):
|
||||
dependency_cache[call] = task
|
||||
return await task
|
||||
|
||||
@override
|
||||
async def _check(self, **kwargs: Any) -> None:
|
||||
# run sub dependent pre-checkers
|
||||
sub_dependent: Dependent = self.extra["dependent"]
|
||||
@ -195,9 +255,10 @@ class BotParam(Param):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _check_param(
|
||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||
) -> Optional["BotParam"]:
|
||||
) -> Optional[Self]:
|
||||
from nonebot.adapters import Bot
|
||||
|
||||
# param type is Bot(s) or subclass(es) of Bot or None
|
||||
@ -217,9 +278,11 @@ class BotParam(Param):
|
||||
elif param.annotation == param.empty and param.name == "bot":
|
||||
return cls(Required)
|
||||
|
||||
@override
|
||||
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
|
||||
return bot
|
||||
|
||||
@override
|
||||
async def _check(self, bot: "Bot", **kwargs: Any) -> None:
|
||||
if checker := self.extra.get("checker"):
|
||||
check_field_type(checker, bot)
|
||||
@ -245,9 +308,10 @@ class EventParam(Param):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _check_param(
|
||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||
) -> Optional["EventParam"]:
|
||||
) -> Optional[Self]:
|
||||
from nonebot.adapters import Event
|
||||
|
||||
# param type is Event(s) or subclass(es) of Event or None
|
||||
@ -267,9 +331,11 @@ class EventParam(Param):
|
||||
elif param.annotation == param.empty and param.name == "event":
|
||||
return cls(Required)
|
||||
|
||||
@override
|
||||
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
|
||||
return event
|
||||
|
||||
@override
|
||||
async def _check(self, event: "Event", **kwargs: Any) -> Any:
|
||||
if checker := self.extra.get("checker", None):
|
||||
check_field_type(checker, event)
|
||||
@ -287,9 +353,10 @@ class StateParam(Param):
|
||||
return "StateParam()"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _check_param(
|
||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||
) -> Optional["StateParam"]:
|
||||
) -> Optional[Self]:
|
||||
# param type is T_State
|
||||
if param.annotation is T_State:
|
||||
return cls(Required)
|
||||
@ -297,6 +364,7 @@ class StateParam(Param):
|
||||
elif param.annotation == param.empty and param.name == "state":
|
||||
return cls(Required)
|
||||
|
||||
@override
|
||||
async def _solve(self, state: T_State, **kwargs: Any) -> Any:
|
||||
return state
|
||||
|
||||
@ -313,9 +381,10 @@ class MatcherParam(Param):
|
||||
return "MatcherParam()"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _check_param(
|
||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||
) -> Optional["MatcherParam"]:
|
||||
) -> Optional[Self]:
|
||||
from nonebot.matcher import Matcher
|
||||
|
||||
# param type is Matcher(s) or subclass(es) of Matcher or None
|
||||
@ -335,9 +404,11 @@ class MatcherParam(Param):
|
||||
elif param.annotation == param.empty and param.name == "matcher":
|
||||
return cls(Required)
|
||||
|
||||
@override
|
||||
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
||||
return matcher
|
||||
|
||||
@override
|
||||
async def _check(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
||||
if checker := self.extra.get("checker", None):
|
||||
check_field_type(checker, matcher)
|
||||
@ -382,9 +453,10 @@ class ArgParam(Param):
|
||||
return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _check_param(
|
||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||
) -> Optional["ArgParam"]:
|
||||
) -> Optional[Self]:
|
||||
if isinstance(param.default, ArgInner):
|
||||
return cls(
|
||||
Required, key=param.default.key or param.name, type=param.default.type
|
||||
@ -419,9 +491,10 @@ class ExceptionParam(Param):
|
||||
return "ExceptionParam()"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _check_param(
|
||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||
) -> Optional["ExceptionParam"]:
|
||||
) -> Optional[Self]:
|
||||
# param type is Exception(s) or subclass(es) of Exception or None
|
||||
if generic_check_issubclass(param.annotation, Exception):
|
||||
return cls(Required)
|
||||
@ -429,6 +502,7 @@ class ExceptionParam(Param):
|
||||
elif param.annotation == param.empty and param.name == "exception":
|
||||
return cls(Required)
|
||||
|
||||
@override
|
||||
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
|
||||
return exception
|
||||
|
||||
@ -445,12 +519,14 @@ class DefaultParam(Param):
|
||||
return f"DefaultParam(default={self.default!r})"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _check_param(
|
||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||
) -> Optional["DefaultParam"]:
|
||||
) -> Optional[Self]:
|
||||
if param.default != param.empty:
|
||||
return cls(param.default)
|
||||
|
||||
@override
|
||||
async def _solve(self, **kwargs: Any) -> Any:
|
||||
return Undefined
|
||||
|
||||
|
Reference in New Issue
Block a user