mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-09-07 04:26:45 +00:00
✨ Feature: 兼容 Pydantic v2 (#2544)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -24,14 +24,11 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseConfig
|
||||
from pydantic.schema import get_annotation_from_field_info
|
||||
from pydantic.fields import Required, FieldInfo, Undefined, ModelField
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.typing import _DependentCallable
|
||||
from nonebot.exception import SkippedException
|
||||
from nonebot.utils import run_sync, is_coroutine_callable
|
||||
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
|
||||
|
||||
from .utils import check_field_type, get_typed_signature
|
||||
|
||||
@ -69,10 +66,6 @@ class Param(abc.ABC, FieldInfo):
|
||||
return
|
||||
|
||||
|
||||
class CustomConfig(BaseConfig):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Dependent(Generic[R]):
|
||||
"""依赖注入容器
|
||||
@ -125,12 +118,8 @@ class Dependent(Generic[R]):
|
||||
params = get_typed_signature(call).parameters.values()
|
||||
|
||||
for param in params:
|
||||
default_value = Required
|
||||
if param.default != param.empty:
|
||||
default_value = param.default
|
||||
|
||||
if isinstance(default_value, Param):
|
||||
field_info = default_value
|
||||
if isinstance(param.default, Param):
|
||||
field_info = param.default
|
||||
else:
|
||||
for allow_type in allow_types:
|
||||
if field_info := allow_type._check_param(param, allow_types):
|
||||
@ -141,25 +130,13 @@ class Dependent(Generic[R]):
|
||||
f"for function {call} with type {param.annotation}"
|
||||
)
|
||||
|
||||
default_value = field_info.default
|
||||
|
||||
annotation: Any = Any
|
||||
required = default_value == Required
|
||||
if param.annotation != param.empty:
|
||||
if param.annotation is not param.empty:
|
||||
annotation = param.annotation
|
||||
annotation = get_annotation_from_field_info(
|
||||
annotation, field_info, param.name
|
||||
)
|
||||
|
||||
fields.append(
|
||||
ModelField(
|
||||
name=param.name,
|
||||
type_=annotation,
|
||||
class_validators=None,
|
||||
model_config=CustomConfig,
|
||||
default=None if required else default_value,
|
||||
required=required,
|
||||
field_info=field_info,
|
||||
ModelField.construct(
|
||||
name=param.name, annotation=annotation, field_info=field_info
|
||||
)
|
||||
)
|
||||
|
||||
@ -207,7 +184,7 @@ class Dependent(Generic[R]):
|
||||
async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any:
|
||||
param = cast(Param, field.field_info)
|
||||
value = await param._solve(**params)
|
||||
if value is Undefined:
|
||||
if value is PydanticUndefined:
|
||||
value = field.get_default()
|
||||
v = check_field_type(field, value)
|
||||
return v if param.validate else value
|
||||
|
@ -8,10 +8,10 @@ import inspect
|
||||
from typing import Any, Dict, Callable, ForwardRef
|
||||
|
||||
from loguru import logger
|
||||
from pydantic.fields import ModelField
|
||||
from pydantic.typing import evaluate_forwardref
|
||||
|
||||
from nonebot.exception import TypeMisMatch
|
||||
from nonebot.typing import evaluate_forwardref
|
||||
from nonebot.compat import DEFAULT_CONFIG, ModelField, model_field_validate
|
||||
|
||||
|
||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
@ -50,7 +50,7 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) ->
|
||||
def check_field_type(field: ModelField, value: Any) -> Any:
|
||||
"""检查字段类型是否匹配"""
|
||||
|
||||
v, errs_ = field.validate(value, {}, loc=())
|
||||
if errs_:
|
||||
try:
|
||||
return model_field_validate(field, value, DEFAULT_CONFIG)
|
||||
except ValueError:
|
||||
raise TypeMisMatch(field, value)
|
||||
return v
|
||||
|
Reference in New Issue
Block a user