mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-27 16:21:28 +00:00
✨ Feature: 兼容 Pydantic v2 (#2544)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing_extensions import Self, Annotated, override
|
||||
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
||||
from typing_extensions import Self, Annotated, get_args, override, get_origin
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -14,12 +14,12 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic.typing import get_args, get_origin
|
||||
from pydantic.fields import Required, FieldInfo, Undefined, ModelField
|
||||
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
||||
|
||||
from nonebot.dependencies import Param, Dependent
|
||||
from nonebot.dependencies.utils import check_field_type
|
||||
from nonebot.dependencies import Param, Dependent, CustomConfig
|
||||
from nonebot.typing import T_State, T_Handler, T_DependencyCache
|
||||
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
|
||||
from nonebot.utils import (
|
||||
get_name,
|
||||
run_sync,
|
||||
@ -34,23 +34,6 @@ if TYPE_CHECKING:
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.adapters import Bot, Event
|
||||
|
||||
EXTRA_FIELD_INFO = (
|
||||
"gt",
|
||||
"lt",
|
||||
"ge",
|
||||
"le",
|
||||
"multiple_of",
|
||||
"allow_inf_nan",
|
||||
"max_digits",
|
||||
"decimal_places",
|
||||
"min_items",
|
||||
"max_items",
|
||||
"unique_items",
|
||||
"min_length",
|
||||
"max_length",
|
||||
"regex",
|
||||
)
|
||||
|
||||
|
||||
class DependsInner:
|
||||
def __init__(
|
||||
@ -58,7 +41,7 @@ class DependsInner:
|
||||
dependency: Optional[T_Handler] = None,
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
validate: Union[bool, FieldInfo] = False,
|
||||
validate: Union[bool, PydanticFieldInfo] = False,
|
||||
) -> None:
|
||||
self.dependency = dependency
|
||||
self.use_cache = use_cache
|
||||
@ -75,7 +58,7 @@ def Depends(
|
||||
dependency: Optional[T_Handler] = None,
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
validate: Union[bool, FieldInfo] = False,
|
||||
validate: Union[bool, PydanticFieldInfo] = False,
|
||||
) -> Any:
|
||||
"""子依赖装饰器
|
||||
|
||||
@ -113,24 +96,32 @@ class DependParam(Param):
|
||||
本注入应该具有最高优先级,因此应该在其他参数之前检查。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *args, dependent: Dependent, use_cache: bool, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dependent = dependent
|
||||
self.use_cache = use_cache
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Depends({self.extra['dependent']})"
|
||||
return f"Depends({self.dependent}, use_cache={self.use_cache})"
|
||||
|
||||
@classmethod
|
||||
def _from_field(
|
||||
cls, sub_dependent: Dependent, use_cache: bool, validate: Union[bool, FieldInfo]
|
||||
cls,
|
||||
sub_dependent: Dependent,
|
||||
use_cache: bool,
|
||||
validate: Union[bool, PydanticFieldInfo],
|
||||
) -> Self:
|
||||
kwargs = {}
|
||||
if isinstance(validate, FieldInfo):
|
||||
kwargs.update((k, getattr(validate, k)) for k in EXTRA_FIELD_INFO)
|
||||
if isinstance(validate, PydanticFieldInfo):
|
||||
kwargs.update(extract_field_info(validate))
|
||||
|
||||
return cls(
|
||||
Required,
|
||||
validate=bool(validate),
|
||||
**kwargs,
|
||||
dependent=sub_dependent,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
kwargs["validate"] = bool(validate)
|
||||
kwargs["dependent"] = sub_dependent
|
||||
kwargs["use_cache"] = use_cache
|
||||
|
||||
return cls(**kwargs)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
@ -191,10 +182,10 @@ class DependParam(Param):
|
||||
dependency_cache: Optional[T_DependencyCache] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
use_cache: bool = self.extra["use_cache"]
|
||||
use_cache: bool = self.use_cache
|
||||
dependency_cache = {} if dependency_cache is None else dependency_cache
|
||||
|
||||
sub_dependent: Dependent = self.extra["dependent"]
|
||||
sub_dependent: Dependent = self.dependent
|
||||
call = cast(Callable[..., Any], sub_dependent.call)
|
||||
|
||||
# solve sub dependency with current cache
|
||||
@ -231,8 +222,7 @@ class DependParam(Param):
|
||||
@override
|
||||
async def _check(self, **kwargs: Any) -> None:
|
||||
# run sub dependent pre-checkers
|
||||
sub_dependent: Dependent = self.extra["dependent"]
|
||||
await sub_dependent.check(**kwargs)
|
||||
await self.dependent.check(**kwargs)
|
||||
|
||||
|
||||
class BotParam(Param):
|
||||
@ -243,14 +233,16 @@ class BotParam(Param):
|
||||
为保证兼容性,本注入还会解析名为 `bot` 且没有类型注解的参数。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *args, checker: Optional[ModelField] = None, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.checker = checker
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
"BotParam("
|
||||
+ (
|
||||
repr(cast(ModelField, checker).type_)
|
||||
if (checker := self.extra.get("checker"))
|
||||
else ""
|
||||
)
|
||||
+ (repr(self.checker.annotation) if self.checker is not None else "")
|
||||
+ ")"
|
||||
)
|
||||
|
||||
@ -265,18 +257,13 @@ class BotParam(Param):
|
||||
if generic_check_issubclass(param.annotation, Bot):
|
||||
checker: Optional[ModelField] = None
|
||||
if param.annotation is not Bot:
|
||||
checker = ModelField(
|
||||
name=param.name,
|
||||
type_=param.annotation,
|
||||
class_validators=None,
|
||||
model_config=CustomConfig,
|
||||
default=None,
|
||||
required=True,
|
||||
checker = ModelField.construct(
|
||||
name=param.name, annotation=param.annotation, field_info=FieldInfo()
|
||||
)
|
||||
return cls(Required, checker=checker)
|
||||
return cls(checker=checker)
|
||||
# legacy: param is named "bot" and has no type annotation
|
||||
elif param.annotation == param.empty and param.name == "bot":
|
||||
return cls(Required)
|
||||
return cls()
|
||||
|
||||
@override
|
||||
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
|
||||
@ -284,8 +271,8 @@ class BotParam(Param):
|
||||
|
||||
@override
|
||||
async def _check(self, bot: "Bot", **kwargs: Any) -> None:
|
||||
if checker := self.extra.get("checker"):
|
||||
check_field_type(checker, bot)
|
||||
if self.checker is not None:
|
||||
check_field_type(self.checker, bot)
|
||||
|
||||
|
||||
class EventParam(Param):
|
||||
@ -296,14 +283,16 @@ class EventParam(Param):
|
||||
为保证兼容性,本注入还会解析名为 `event` 且没有类型注解的参数。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *args, checker: Optional[ModelField] = None, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.checker = checker
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
"EventParam("
|
||||
+ (
|
||||
repr(cast(ModelField, checker).type_)
|
||||
if (checker := self.extra.get("checker"))
|
||||
else ""
|
||||
)
|
||||
+ (repr(self.checker.annotation) if self.checker is not None else "")
|
||||
+ ")"
|
||||
)
|
||||
|
||||
@ -318,18 +307,13 @@ class EventParam(Param):
|
||||
if generic_check_issubclass(param.annotation, Event):
|
||||
checker: Optional[ModelField] = None
|
||||
if param.annotation is not Event:
|
||||
checker = ModelField(
|
||||
name=param.name,
|
||||
type_=param.annotation,
|
||||
class_validators=None,
|
||||
model_config=CustomConfig,
|
||||
default=None,
|
||||
required=True,
|
||||
checker = ModelField.construct(
|
||||
name=param.name, annotation=param.annotation, field_info=FieldInfo()
|
||||
)
|
||||
return cls(Required, checker=checker)
|
||||
return cls(checker=checker)
|
||||
# legacy: param is named "event" and has no type annotation
|
||||
elif param.annotation == param.empty and param.name == "event":
|
||||
return cls(Required)
|
||||
return cls()
|
||||
|
||||
@override
|
||||
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
|
||||
@ -337,8 +321,8 @@ class EventParam(Param):
|
||||
|
||||
@override
|
||||
async def _check(self, event: "Event", **kwargs: Any) -> Any:
|
||||
if checker := self.extra.get("checker", None):
|
||||
check_field_type(checker, event)
|
||||
if self.checker is not None:
|
||||
check_field_type(self.checker, event)
|
||||
|
||||
|
||||
class StateParam(Param):
|
||||
@ -359,10 +343,10 @@ class StateParam(Param):
|
||||
) -> Optional[Self]:
|
||||
# param type is T_State
|
||||
if param.annotation is T_State:
|
||||
return cls(Required)
|
||||
return cls()
|
||||
# legacy: param is named "state" and has no type annotation
|
||||
elif param.annotation == param.empty and param.name == "state":
|
||||
return cls(Required)
|
||||
return cls()
|
||||
|
||||
@override
|
||||
async def _solve(self, state: T_State, **kwargs: Any) -> Any:
|
||||
@ -377,8 +361,18 @@ class MatcherParam(Param):
|
||||
为保证兼容性,本注入还会解析名为 `matcher` 且没有类型注解的参数。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *args, checker: Optional[ModelField] = None, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.checker = checker
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "MatcherParam()"
|
||||
return (
|
||||
"MatcherParam("
|
||||
+ (repr(self.checker.annotation) if self.checker is not None else "")
|
||||
+ ")"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
@ -391,18 +385,13 @@ class MatcherParam(Param):
|
||||
if generic_check_issubclass(param.annotation, Matcher):
|
||||
checker: Optional[ModelField] = None
|
||||
if param.annotation is not Matcher:
|
||||
checker = ModelField(
|
||||
name=param.name,
|
||||
type_=param.annotation,
|
||||
class_validators=None,
|
||||
model_config=CustomConfig,
|
||||
default=None,
|
||||
required=True,
|
||||
checker = ModelField.construct(
|
||||
name=param.name, annotation=param.annotation, field_info=FieldInfo()
|
||||
)
|
||||
return cls(Required, checker=checker)
|
||||
return cls(checker=checker)
|
||||
# legacy: param is named "matcher" and has no type annotation
|
||||
elif param.annotation == param.empty and param.name == "matcher":
|
||||
return cls(Required)
|
||||
return cls()
|
||||
|
||||
@override
|
||||
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
||||
@ -410,16 +399,16 @@ class MatcherParam(Param):
|
||||
|
||||
@override
|
||||
async def _check(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
||||
if checker := self.extra.get("checker", None):
|
||||
check_field_type(checker, matcher)
|
||||
if self.checker is not None:
|
||||
check_field_type(self.checker, matcher)
|
||||
|
||||
|
||||
class ArgInner:
|
||||
def __init__(
|
||||
self, key: Optional[str], type: Literal["message", "str", "plaintext"]
|
||||
) -> None:
|
||||
self.key = key
|
||||
self.type = type
|
||||
self.key: Optional[str] = key
|
||||
self.type: Literal["message", "str", "plaintext"] = type
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"ArgInner(key={self.key!r}, type={self.type!r})"
|
||||
@ -449,8 +438,19 @@ class ArgParam(Param):
|
||||
留空则会根据参数名称获取。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
key: str,
|
||||
type: Literal["message", "str", "plaintext"],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.key = key
|
||||
self.type = type
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})"
|
||||
return f"ArgParam(key={self.key!r}, type={self.type!r})"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
@ -458,22 +458,19 @@ class ArgParam(Param):
|
||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||
) -> Optional[Self]:
|
||||
if isinstance(param.default, ArgInner):
|
||||
return cls(
|
||||
Required, key=param.default.key or param.name, type=param.default.type
|
||||
)
|
||||
return cls(key=param.default.key or param.name, type=param.default.type)
|
||||
elif get_origin(param.annotation) is Annotated:
|
||||
for arg in get_args(param.annotation)[:0:-1]:
|
||||
if isinstance(arg, ArgInner):
|
||||
return cls(Required, key=arg.key or param.name, type=arg.type)
|
||||
return cls(key=arg.key or param.name, type=arg.type)
|
||||
|
||||
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
||||
key: str = self.extra["key"]
|
||||
message = matcher.get_arg(key)
|
||||
message = matcher.get_arg(self.key)
|
||||
if message is None:
|
||||
return message
|
||||
if self.extra["type"] == "message":
|
||||
if self.type == "message":
|
||||
return message
|
||||
elif self.extra["type"] == "str":
|
||||
elif self.type == "str":
|
||||
return str(message)
|
||||
else:
|
||||
return message.extract_plain_text()
|
||||
@ -497,10 +494,10 @@ class ExceptionParam(Param):
|
||||
) -> Optional[Self]:
|
||||
# param type is Exception(s) or subclass(es) of Exception or None
|
||||
if generic_check_issubclass(param.annotation, Exception):
|
||||
return cls(Required)
|
||||
return cls()
|
||||
# legacy: param is named "exception" and has no type annotation
|
||||
elif param.annotation == param.empty and param.name == "exception":
|
||||
return cls(Required)
|
||||
return cls()
|
||||
|
||||
@override
|
||||
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
|
||||
@ -524,11 +521,11 @@ class DefaultParam(Param):
|
||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||
) -> Optional[Self]:
|
||||
if param.default != param.empty:
|
||||
return cls(param.default)
|
||||
return cls(default=param.default)
|
||||
|
||||
@override
|
||||
async def _solve(self, **kwargs: Any) -> Any:
|
||||
return Undefined
|
||||
return PydanticUndefined
|
||||
|
||||
|
||||
__autodoc__ = {
|
||||
|
Reference in New Issue
Block a user