♻️ improve dependent structure (#1227)

This commit is contained in:
Ju4tCode
2022-09-07 09:59:05 +08:00
committed by GitHub
parent 595c64e760
commit a0b186aff3
3 changed files with 191 additions and 210 deletions

View File

@ -1,8 +1,7 @@
import asyncio
import inspect
import warnings
from typing import TYPE_CHECKING, Any, Literal, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from typing import TYPE_CHECKING, Any, Type, Tuple, Literal, Callable, Optional, cast
from pydantic.fields import Required, Undefined, ModelField
@ -76,10 +75,7 @@ class DependParam(Param):
@classmethod
def _check_param(
cls,
dependent: Dependent,
name: str,
param: inspect.Parameter,
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["DependParam"]:
if isinstance(param.default, DependsInner):
dependency: T_Handler
@ -90,22 +86,20 @@ class DependParam(Param):
dependency = param.default.dependency
sub_dependent = Dependent[Any].parse(
call=dependency,
allow_types=dependent.allow_types,
allow_types=allow_types,
)
dependent.pre_checkers.extend(sub_dependent.pre_checkers)
sub_dependent.pre_checkers.clear()
return cls(
Required, use_cache=param.default.use_cache, dependent=sub_dependent
)
@classmethod
def _check_parameterless(
cls, dependent: "Dependent", value: Any
cls, value: Any, allow_types: Tuple[Type[Param], ...]
) -> Optional["Param"]:
if isinstance(value, DependsInner):
assert value.dependency, "Dependency cannot be empty"
dependent = Dependent[Any].parse(
call=value.dependency, allow_types=dependent.allow_types
call=value.dependency, allow_types=allow_types
)
return cls(Required, use_cache=value.use_cache, dependent=dependent)
@ -119,8 +113,7 @@ class DependParam(Param):
dependency_cache = {} if dependency_cache is None else dependency_cache
sub_dependent: Dependent = self.extra["dependent"]
sub_dependent.call = cast(Callable[..., Any], sub_dependent.call)
call = sub_dependent.call
call = cast(Callable[..., Any], sub_dependent.call)
# solve sub dependency with current cache
sub_values = await sub_dependent.solve(
@ -132,7 +125,7 @@ class DependParam(Param):
# run dependency function
task: asyncio.Task[Any]
if use_cache and call in dependency_cache:
solved = await dependency_cache[call]
return await dependency_cache[call]
elif is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance(
stack, AsyncExitStack
@ -143,30 +136,20 @@ class DependParam(Param):
cm = asynccontextmanager(call)(**sub_values)
task = asyncio.create_task(stack.enter_async_context(cm))
dependency_cache[call] = task
solved = await task
return await task
elif is_coroutine_callable(call):
task = asyncio.create_task(call(**sub_values))
dependency_cache[call] = task
solved = await task
return await task
else:
task = asyncio.create_task(run_sync(call)(**sub_values))
dependency_cache[call] = task
solved = await task
return await task
return solved
class _BotChecker(Param):
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
field: ModelField = self.extra["field"]
try:
return check_field_type(field, bot)
except TypeMisMatch:
logger.debug(
f"Bot type {type(bot)} not match "
f"annotation {field._type_display()}, ignored"
)
raise
async def _check(self, **kwargs: Any) -> None:
# run sub dependent pre-checkers
sub_dependent: Dependent = self.extra["dependent"]
await sub_dependent.check(**kwargs)
class BotParam(Param):
@ -174,45 +157,32 @@ class BotParam(Param):
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["BotParam"]:
from nonebot.adapters import Bot
if param.default == param.empty:
if generic_check_issubclass(param.annotation, Bot):
checker: Optional[ModelField] = None
if param.annotation is not Bot:
dependent.pre_checkers.append(
_BotChecker(
Required,
field=ModelField(
name=name,
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
),
)
checker = ModelField(
name=param.name,
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
)
return cls(Required)
elif param.annotation == param.empty and name == "bot":
return cls(Required, checker=checker)
elif param.annotation == param.empty and param.name == "bot":
return cls(Required)
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
return bot
class _EventChecker(Param):
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
field: ModelField = self.extra["field"]
try:
return check_field_type(field, event)
except TypeMisMatch:
logger.debug(
f"Event type {type(event)} not match "
f"annotation {field._type_display()}, ignored"
)
raise
async def _check(self, bot: "Bot", **kwargs: Any) -> None:
if checker := self.extra.get("checker", None):
check_field_type(checker, bot)
class EventParam(Param):
@ -220,33 +190,33 @@ class EventParam(Param):
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["EventParam"]:
from nonebot.adapters import Event
if param.default == param.empty:
if generic_check_issubclass(param.annotation, Event):
checker: Optional[ModelField] = None
if param.annotation is not Event:
dependent.pre_checkers.append(
_EventChecker(
Required,
field=ModelField(
name=name,
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
),
)
checker = ModelField(
name=param.name,
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
)
return cls(Required)
elif param.annotation == param.empty and name == "event":
return cls(Required, checker=checker)
elif param.annotation == param.empty and param.name == "event":
return cls(Required)
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
return event
async def _check(self, event: "Event", **kwargs: Any) -> Any:
if checker := self.extra.get("checker", None):
check_field_type(checker, event)
class StateInner(T_State):
...
@ -257,14 +227,14 @@ class StateParam(Param):
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["StateParam"]:
if isinstance(param.default, StateInner):
return cls(Required)
elif param.default == param.empty:
if param.annotation is T_State:
return cls(Required)
elif param.annotation == param.empty and name == "state":
elif param.annotation == param.empty and param.name == "state":
return cls(Required)
async def _solve(self, state: T_State, **kwargs: Any) -> Any:
@ -276,12 +246,12 @@ class MatcherParam(Param):
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["MatcherParam"]:
from nonebot.matcher import Matcher
if generic_check_issubclass(param.annotation, Matcher) or (
param.annotation == param.empty and name == "matcher"
param.annotation == param.empty and param.name == "matcher"
):
return cls(Required)
@ -317,10 +287,12 @@ class ArgParam(Param):
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["ArgParam"]:
if isinstance(param.default, ArgInner):
return cls(Required, key=param.default.key or name, type=param.default.type)
return cls(
Required, key=param.default.key or param.name, type=param.default.type
)
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
message = matcher.get_arg(self.extra["key"])
@ -339,10 +311,10 @@ class ExceptionParam(Param):
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["ExceptionParam"]:
if generic_check_issubclass(param.annotation, Exception) or (
param.annotation == param.empty and name == "exception"
param.annotation == param.empty and param.name == "exception"
):
return cls(Required)
@ -355,7 +327,7 @@ class DefaultParam(Param):
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["DefaultParam"]:
if param.default != param.empty:
return cls(param.default)