allow extra param with default value

This commit is contained in:
yanyongyu
2021-11-22 11:38:42 +08:00
parent 23c237cb2a
commit 3120abacb3
7 changed files with 45 additions and 26 deletions

View File

@ -11,8 +11,8 @@ from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from pydantic import BaseConfig
from pydantic.fields import Required, ModelField
from pydantic.schema import get_annotation_from_field_info
from pydantic.fields import Required, Undefined, ModelField
from nonebot.log import logger
from .models import Param as Param
@ -90,16 +90,25 @@ def get_dependent(*,
dependent.dependencies.append(sub_dependent)
continue
for allow_type in dependent.allow_types:
if allow_type._check(param_name, param):
field_info = allow_type(param.default)
break
default_value = Required
if param.default != param.empty:
default_value = param.default
if isinstance(default_value, Param):
field_info = default_value
default_value = field_info.default
else:
raise ValueError(
f"Unknown parameter {param_name} for function {func} with type {param.annotation}"
)
for allow_type in dependent.allow_types:
if allow_type._check(param_name, param):
field_info = allow_type(default_value)
break
else:
raise ValueError(
f"Unknown parameter {param_name} for function {func} with type {param.annotation}"
)
annotation: Any = Any
required = default_value == Required
if param.annotation != param.empty:
annotation = param.annotation
annotation = get_annotation_from_field_info(annotation, field_info,
@ -109,8 +118,8 @@ def get_dependent(*,
type_=annotation,
class_validators=None,
model_config=CustomConfig,
default=Required,
required=True,
default=None if required else default_value,
required=required,
field_info=field_info))
return dependent
@ -176,6 +185,8 @@ async def solve_dependencies(
assert isinstance(field_info,
Param), "Params must be subclasses of Param"
value = field_info._solve(**params)
if value == Undefined:
value = field.get_default()
_, errs_ = field.validate(value,
values,
loc=(str(field_info), field.alias))