Feature: 兼容 Pydantic v2 (#2544)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ju4tCode
2024-01-26 11:12:57 +08:00
committed by GitHub
parent 82e4ccb227
commit bbd13c04cc
36 changed files with 6535 additions and 414 deletions

View File

@ -1,7 +1,7 @@
import asyncio
import inspect
from typing_extensions import Self, Annotated, override
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from typing_extensions import Self, Annotated, get_args, override, get_origin
from typing import (
TYPE_CHECKING,
Any,
@ -14,12 +14,12 @@ from typing import (
cast,
)
from pydantic.typing import get_args, get_origin
from pydantic.fields import Required, FieldInfo, Undefined, ModelField
from pydantic.fields import FieldInfo as PydanticFieldInfo
from nonebot.dependencies import Param, Dependent
from nonebot.dependencies.utils import check_field_type
from nonebot.dependencies import Param, Dependent, CustomConfig
from nonebot.typing import T_State, T_Handler, T_DependencyCache
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
from nonebot.utils import (
get_name,
run_sync,
@ -34,23 +34,6 @@ if TYPE_CHECKING:
from nonebot.matcher import Matcher
from nonebot.adapters import Bot, Event
EXTRA_FIELD_INFO = (
"gt",
"lt",
"ge",
"le",
"multiple_of",
"allow_inf_nan",
"max_digits",
"decimal_places",
"min_items",
"max_items",
"unique_items",
"min_length",
"max_length",
"regex",
)
class DependsInner:
def __init__(
@ -58,7 +41,7 @@ class DependsInner:
dependency: Optional[T_Handler] = None,
*,
use_cache: bool = True,
validate: Union[bool, FieldInfo] = False,
validate: Union[bool, PydanticFieldInfo] = False,
) -> None:
self.dependency = dependency
self.use_cache = use_cache
@ -75,7 +58,7 @@ def Depends(
dependency: Optional[T_Handler] = None,
*,
use_cache: bool = True,
validate: Union[bool, FieldInfo] = False,
validate: Union[bool, PydanticFieldInfo] = False,
) -> Any:
"""子依赖装饰器
@ -113,24 +96,32 @@ class DependParam(Param):
本注入应该具有最高优先级,因此应该在其他参数之前检查。
"""
def __init__(
self, *args, dependent: Dependent, use_cache: bool, **kwargs: Any
) -> None:
super().__init__(*args, **kwargs)
self.dependent = dependent
self.use_cache = use_cache
def __repr__(self) -> str:
return f"Depends({self.extra['dependent']})"
return f"Depends({self.dependent}, use_cache={self.use_cache})"
@classmethod
def _from_field(
cls, sub_dependent: Dependent, use_cache: bool, validate: Union[bool, FieldInfo]
cls,
sub_dependent: Dependent,
use_cache: bool,
validate: Union[bool, PydanticFieldInfo],
) -> Self:
kwargs = {}
if isinstance(validate, FieldInfo):
kwargs.update((k, getattr(validate, k)) for k in EXTRA_FIELD_INFO)
if isinstance(validate, PydanticFieldInfo):
kwargs.update(extract_field_info(validate))
return cls(
Required,
validate=bool(validate),
**kwargs,
dependent=sub_dependent,
use_cache=use_cache,
)
kwargs["validate"] = bool(validate)
kwargs["dependent"] = sub_dependent
kwargs["use_cache"] = use_cache
return cls(**kwargs)
@classmethod
@override
@ -191,10 +182,10 @@ class DependParam(Param):
dependency_cache: Optional[T_DependencyCache] = None,
**kwargs: Any,
) -> Any:
use_cache: bool = self.extra["use_cache"]
use_cache: bool = self.use_cache
dependency_cache = {} if dependency_cache is None else dependency_cache
sub_dependent: Dependent = self.extra["dependent"]
sub_dependent: Dependent = self.dependent
call = cast(Callable[..., Any], sub_dependent.call)
# solve sub dependency with current cache
@ -231,8 +222,7 @@ class DependParam(Param):
@override
async def _check(self, **kwargs: Any) -> None:
# run sub dependent pre-checkers
sub_dependent: Dependent = self.extra["dependent"]
await sub_dependent.check(**kwargs)
await self.dependent.check(**kwargs)
class BotParam(Param):
@ -243,14 +233,16 @@ class BotParam(Param):
为保证兼容性,本注入还会解析名为 `bot` 且没有类型注解的参数。
"""
def __init__(
self, *args, checker: Optional[ModelField] = None, **kwargs: Any
) -> None:
super().__init__(*args, **kwargs)
self.checker = checker
def __repr__(self) -> str:
return (
"BotParam("
+ (
repr(cast(ModelField, checker).type_)
if (checker := self.extra.get("checker"))
else ""
)
+ (repr(self.checker.annotation) if self.checker is not None else "")
+ ")"
)
@ -265,18 +257,13 @@ class BotParam(Param):
if generic_check_issubclass(param.annotation, Bot):
checker: Optional[ModelField] = None
if param.annotation is not Bot:
checker = ModelField(
name=param.name,
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
checker = ModelField.construct(
name=param.name, annotation=param.annotation, field_info=FieldInfo()
)
return cls(Required, checker=checker)
return cls(checker=checker)
# legacy: param is named "bot" and has no type annotation
elif param.annotation == param.empty and param.name == "bot":
return cls(Required)
return cls()
@override
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
@ -284,8 +271,8 @@ class BotParam(Param):
@override
async def _check(self, bot: "Bot", **kwargs: Any) -> None:
if checker := self.extra.get("checker"):
check_field_type(checker, bot)
if self.checker is not None:
check_field_type(self.checker, bot)
class EventParam(Param):
@ -296,14 +283,16 @@ class EventParam(Param):
为保证兼容性,本注入还会解析名为 `event` 且没有类型注解的参数。
"""
def __init__(
self, *args, checker: Optional[ModelField] = None, **kwargs: Any
) -> None:
super().__init__(*args, **kwargs)
self.checker = checker
def __repr__(self) -> str:
return (
"EventParam("
+ (
repr(cast(ModelField, checker).type_)
if (checker := self.extra.get("checker"))
else ""
)
+ (repr(self.checker.annotation) if self.checker is not None else "")
+ ")"
)
@ -318,18 +307,13 @@ class EventParam(Param):
if generic_check_issubclass(param.annotation, Event):
checker: Optional[ModelField] = None
if param.annotation is not Event:
checker = ModelField(
name=param.name,
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
checker = ModelField.construct(
name=param.name, annotation=param.annotation, field_info=FieldInfo()
)
return cls(Required, checker=checker)
return cls(checker=checker)
# legacy: param is named "event" and has no type annotation
elif param.annotation == param.empty and param.name == "event":
return cls(Required)
return cls()
@override
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
@ -337,8 +321,8 @@ class EventParam(Param):
@override
async def _check(self, event: "Event", **kwargs: Any) -> Any:
if checker := self.extra.get("checker", None):
check_field_type(checker, event)
if self.checker is not None:
check_field_type(self.checker, event)
class StateParam(Param):
@ -359,10 +343,10 @@ class StateParam(Param):
) -> Optional[Self]:
# param type is T_State
if param.annotation is T_State:
return cls(Required)
return cls()
# legacy: param is named "state" and has no type annotation
elif param.annotation == param.empty and param.name == "state":
return cls(Required)
return cls()
@override
async def _solve(self, state: T_State, **kwargs: Any) -> Any:
@ -377,8 +361,18 @@ class MatcherParam(Param):
为保证兼容性,本注入还会解析名为 `matcher` 且没有类型注解的参数。
"""
def __init__(
self, *args, checker: Optional[ModelField] = None, **kwargs: Any
) -> None:
super().__init__(*args, **kwargs)
self.checker = checker
def __repr__(self) -> str:
return "MatcherParam()"
return (
"MatcherParam("
+ (repr(self.checker.annotation) if self.checker is not None else "")
+ ")"
)
@classmethod
@override
@ -391,18 +385,13 @@ class MatcherParam(Param):
if generic_check_issubclass(param.annotation, Matcher):
checker: Optional[ModelField] = None
if param.annotation is not Matcher:
checker = ModelField(
name=param.name,
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
checker = ModelField.construct(
name=param.name, annotation=param.annotation, field_info=FieldInfo()
)
return cls(Required, checker=checker)
return cls(checker=checker)
# legacy: param is named "matcher" and has no type annotation
elif param.annotation == param.empty and param.name == "matcher":
return cls(Required)
return cls()
@override
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
@ -410,16 +399,16 @@ class MatcherParam(Param):
@override
async def _check(self, matcher: "Matcher", **kwargs: Any) -> Any:
if checker := self.extra.get("checker", None):
check_field_type(checker, matcher)
if self.checker is not None:
check_field_type(self.checker, matcher)
class ArgInner:
def __init__(
self, key: Optional[str], type: Literal["message", "str", "plaintext"]
) -> None:
self.key = key
self.type = type
self.key: Optional[str] = key
self.type: Literal["message", "str", "plaintext"] = type
def __repr__(self) -> str:
return f"ArgInner(key={self.key!r}, type={self.type!r})"
@ -449,8 +438,19 @@ class ArgParam(Param):
留空则会根据参数名称获取。
"""
def __init__(
self,
*args,
key: str,
type: Literal["message", "str", "plaintext"],
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.key = key
self.type = type
def __repr__(self) -> str:
return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})"
return f"ArgParam(key={self.key!r}, type={self.type!r})"
@classmethod
@override
@ -458,22 +458,19 @@ class ArgParam(Param):
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional[Self]:
if isinstance(param.default, ArgInner):
return cls(
Required, key=param.default.key or param.name, type=param.default.type
)
return cls(key=param.default.key or param.name, type=param.default.type)
elif get_origin(param.annotation) is Annotated:
for arg in get_args(param.annotation)[:0:-1]:
if isinstance(arg, ArgInner):
return cls(Required, key=arg.key or param.name, type=arg.type)
return cls(key=arg.key or param.name, type=arg.type)
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
key: str = self.extra["key"]
message = matcher.get_arg(key)
message = matcher.get_arg(self.key)
if message is None:
return message
if self.extra["type"] == "message":
if self.type == "message":
return message
elif self.extra["type"] == "str":
elif self.type == "str":
return str(message)
else:
return message.extract_plain_text()
@ -497,10 +494,10 @@ class ExceptionParam(Param):
) -> Optional[Self]:
# param type is Exception(s) or subclass(es) of Exception or None
if generic_check_issubclass(param.annotation, Exception):
return cls(Required)
return cls()
# legacy: param is named "exception" and has no type annotation
elif param.annotation == param.empty and param.name == "exception":
return cls(Required)
return cls()
@override
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
@ -524,11 +521,11 @@ class DefaultParam(Param):
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional[Self]:
if param.default != param.empty:
return cls(param.default)
return cls(default=param.default)
@override
async def _solve(self, **kwargs: Any) -> Any:
return Undefined
return PydanticUndefined
__autodoc__ = {