mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-28 00:31:14 +00:00
♿ allow extra param with default value
This commit is contained in:
@ -11,8 +11,8 @@ from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast
|
||||
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
||||
|
||||
from pydantic import BaseConfig
|
||||
from pydantic.fields import Required, ModelField
|
||||
from pydantic.schema import get_annotation_from_field_info
|
||||
from pydantic.fields import Required, Undefined, ModelField
|
||||
|
||||
from nonebot.log import logger
|
||||
from .models import Param as Param
|
||||
@ -90,16 +90,25 @@ def get_dependent(*,
|
||||
dependent.dependencies.append(sub_dependent)
|
||||
continue
|
||||
|
||||
for allow_type in dependent.allow_types:
|
||||
if allow_type._check(param_name, param):
|
||||
field_info = allow_type(param.default)
|
||||
break
|
||||
default_value = Required
|
||||
if param.default != param.empty:
|
||||
default_value = param.default
|
||||
|
||||
if isinstance(default_value, Param):
|
||||
field_info = default_value
|
||||
default_value = field_info.default
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown parameter {param_name} for function {func} with type {param.annotation}"
|
||||
)
|
||||
for allow_type in dependent.allow_types:
|
||||
if allow_type._check(param_name, param):
|
||||
field_info = allow_type(default_value)
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown parameter {param_name} for function {func} with type {param.annotation}"
|
||||
)
|
||||
|
||||
annotation: Any = Any
|
||||
required = default_value == Required
|
||||
if param.annotation != param.empty:
|
||||
annotation = param.annotation
|
||||
annotation = get_annotation_from_field_info(annotation, field_info,
|
||||
@ -109,8 +118,8 @@ def get_dependent(*,
|
||||
type_=annotation,
|
||||
class_validators=None,
|
||||
model_config=CustomConfig,
|
||||
default=Required,
|
||||
required=True,
|
||||
default=None if required else default_value,
|
||||
required=required,
|
||||
field_info=field_info))
|
||||
|
||||
return dependent
|
||||
@ -176,6 +185,8 @@ async def solve_dependencies(
|
||||
assert isinstance(field_info,
|
||||
Param), "Params must be subclasses of Param"
|
||||
value = field_info._solve(**params)
|
||||
if value == Undefined:
|
||||
value = field.get_default()
|
||||
_, errs_ = field.validate(value,
|
||||
values,
|
||||
loc=(str(field_info), field.alias))
|
||||
|
Reference in New Issue
Block a user