🐛 fix matcher got receive type error

This commit is contained in:
yanyongyu
2021-12-31 22:43:29 +08:00
parent c48231454e
commit ec35f292bd
6 changed files with 130 additions and 43 deletions

View File

@ -49,11 +49,13 @@ class Dependent(Generic[R]):
self,
*,
call: Callable[..., Any],
pre_checkers: Optional[List[Param]] = None,
params: Optional[List[ModelField]] = None,
parameterless: Optional[List[Param]] = None,
allow_types: Optional[List[Type[Param]]] = None,
) -> None:
self.call = call
self.pre_checkers = pre_checkers or []
self.params = params or []
self.parameterless = parameterless or []
self.allow_types = allow_types or []
@ -116,11 +118,6 @@ class Dependent(Generic[R]):
allow_types=allow_types,
)
parameterless_params = [
dependent.parse_parameterless(param) for param in (parameterless or [])
]
dependent.parameterless.extend(parameterless_params)
for param_name, param in params.items():
default_value = Required
if param.default != param.empty:
@ -152,6 +149,11 @@ class Dependent(Generic[R]):
)
)
parameterless_params = [
dependent.parse_parameterless(param) for param in (parameterless or [])
]
dependent.parameterless.extend(parameterless_params)
logger.trace(
f"Parsed dependent with call={call}, "
f"params={[param.field_info for param in dependent.params]}, "
@ -166,6 +168,9 @@ class Dependent(Generic[R]):
) -> Dict[str, Any]:
values: Dict[str, Any] = {}
for checker in self.pre_checkers:
await checker._solve(**params)
for param in self.parameterless:
await param._solve(**params)

View File

@ -484,7 +484,6 @@ class Matcher(metaclass=MatcherMeta):
"""
async def _key_getter(event: Event, matcher: "Matcher"):
print(key, matcher.state)
matcher.set_target(ARG_KEY.format(key=key))
if matcher.get_target() == ARG_KEY.format(key=key):
matcher.set_arg(key, event.get_message())

View File

@ -4,10 +4,12 @@ from typing_extensions import Literal
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from pydantic.fields import Required, Undefined
from pydantic.fields import Required, Undefined, ModelField
from nonebot.log import logger
from nonebot.exception import TypeMisMatch
from nonebot.adapters import Bot, Event, Message
from nonebot.dependencies import Param, Dependent
from nonebot.dependencies import Param, Dependent, CustomConfig
from nonebot.typing import T_State, T_Handler, T_DependencyCache
from nonebot.consts import (
CMD_KEY,
@ -94,11 +96,15 @@ class DependParam(Param):
dependency = param.annotation
else:
dependency = param.default.dependency
dependent = Dependent[Any].parse(
sub_dependent = Dependent[Any].parse(
call=dependency,
allow_types=dependent.allow_types,
)
return cls(Required, use_cache=param.default.use_cache, dependent=dependent)
dependent.pre_checkers.extend(sub_dependent.pre_checkers)
sub_dependent.pre_checkers.clear()
return cls(
Required, use_cache=param.default.use_cache, dependent=sub_dependent
)
@classmethod
def _check_parameterless(
@ -158,31 +164,81 @@ class DependParam(Param):
return solved
class _BotChecker(Param):
async def _solve(self, bot: Bot, **kwargs: Any) -> Any:
field: ModelField = self.extra["field"]
_, errs_ = field.validate(bot, {}, loc=("bot",))
if errs_:
logger.debug(
f"Bot type {type(bot)} not match "
f"annotation {field._type_display()}, ignored"
)
raise TypeMisMatch(field, bot)
class BotParam(Param):
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["BotParam"]:
if param.default == param.empty and (
generic_check_issubclass(param.annotation, Bot)
or (param.annotation == param.empty and name == "bot")
):
return cls(Required)
if param.default == param.empty:
if generic_check_issubclass(param.annotation, Bot):
dependent.pre_checkers.append(
_BotChecker(
Required,
field=ModelField(
name="",
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
),
)
)
return cls(Required)
elif param.annotation == param.empty and name == "bot":
return cls(Required)
async def _solve(self, bot: Bot, **kwargs: Any) -> Any:
return bot
class _EventChecker(Param):
async def _solve(self, event: Event, **kwargs: Any) -> Any:
field: ModelField = self.extra["field"]
_, errs_ = field.validate(event, {}, loc=("event",))
if errs_:
logger.debug(
f"Event type {type(event)} not match "
f"annotation {field._type_display()}, ignored"
)
raise TypeMisMatch(field, event)
class EventParam(Param):
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["EventParam"]:
if param.default == param.empty and (
generic_check_issubclass(param.annotation, Event)
or (param.annotation == param.empty and name == "event")
):
return cls(Required)
if param.default == param.empty:
if generic_check_issubclass(param.annotation, Event):
dependent.pre_checkers.append(
_EventChecker(
Required,
field=ModelField(
name="",
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
),
)
)
return cls(Required)
elif param.annotation == param.empty and name == "event":
return cls(Required)
async def _solve(self, event: Event, **kwargs: Any) -> Any:
return event