Feature: 兼容 Pydantic v2 (#2544)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ju4tCode
2024-01-26 11:12:57 +08:00
committed by GitHub
parent 82e4ccb227
commit bbd13c04cc
36 changed files with 6535 additions and 414 deletions

View File

@ -24,14 +24,11 @@ from typing import (
cast,
)
from pydantic import BaseConfig
from pydantic.schema import get_annotation_from_field_info
from pydantic.fields import Required, FieldInfo, Undefined, ModelField
from nonebot.log import logger
from nonebot.typing import _DependentCallable
from nonebot.exception import SkippedException
from nonebot.utils import run_sync, is_coroutine_callable
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
from .utils import check_field_type, get_typed_signature
@ -69,10 +66,6 @@ class Param(abc.ABC, FieldInfo):
return
class CustomConfig(BaseConfig):
arbitrary_types_allowed = True
@dataclass(frozen=True)
class Dependent(Generic[R]):
"""依赖注入容器
@ -125,12 +118,8 @@ class Dependent(Generic[R]):
params = get_typed_signature(call).parameters.values()
for param in params:
default_value = Required
if param.default != param.empty:
default_value = param.default
if isinstance(default_value, Param):
field_info = default_value
if isinstance(param.default, Param):
field_info = param.default
else:
for allow_type in allow_types:
if field_info := allow_type._check_param(param, allow_types):
@ -141,25 +130,13 @@ class Dependent(Generic[R]):
f"for function {call} with type {param.annotation}"
)
default_value = field_info.default
annotation: Any = Any
required = default_value == Required
if param.annotation != param.empty:
if param.annotation is not param.empty:
annotation = param.annotation
annotation = get_annotation_from_field_info(
annotation, field_info, param.name
)
fields.append(
ModelField(
name=param.name,
type_=annotation,
class_validators=None,
model_config=CustomConfig,
default=None if required else default_value,
required=required,
field_info=field_info,
ModelField.construct(
name=param.name, annotation=annotation, field_info=field_info
)
)
@ -207,7 +184,7 @@ class Dependent(Generic[R]):
async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any:
param = cast(Param, field.field_info)
value = await param._solve(**params)
if value is Undefined:
if value is PydanticUndefined:
value = field.get_default()
v = check_field_type(field, value)
return v if param.validate else value

View File

@ -8,10 +8,10 @@ import inspect
from typing import Any, Dict, Callable, ForwardRef
from loguru import logger
from pydantic.fields import ModelField
from pydantic.typing import evaluate_forwardref
from nonebot.exception import TypeMisMatch
from nonebot.typing import evaluate_forwardref
from nonebot.compat import DEFAULT_CONFIG, ModelField, model_field_validate
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
@ -50,7 +50,7 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) ->
def check_field_type(field: ModelField, value: Any) -> Any:
"""检查字段类型是否匹配"""
v, errs_ = field.validate(value, {}, loc=())
if errs_:
try:
return model_field_validate(field, value, DEFAULT_CONFIG)
except ValueError:
raise TypeMisMatch(field, value)
return v