🚧 process handler dependency

This commit is contained in:
yanyongyu 2021-11-12 18:10:40 +08:00
parent 57e826a835
commit c454cf0874
6 changed files with 256 additions and 166 deletions

View File

@ -0,0 +1,2 @@
from .models import Depends as Depends
from .utils import get_dependent as get_dependent

View File

@ -0,0 +1,89 @@
from enum import Enum
from typing import Any, List, Callable, Optional
from pydantic.fields import Required, FieldInfo, ModelField
from nonebot.utils import get_name
class Depends:
def __init__(self,
dependency: Optional[Callable[..., Any]] = None,
*,
use_cache: bool = True) -> None:
self.dependency = dependency
self.use_cache = use_cache
def __repr__(self) -> str:
dep = get_name(self.dependency)
cache = "" if self.use_cache else ", use_cache=False"
return f"{self.__class__.__name__}({dep}{cache})"
class Dependent:
def __init__(self,
*,
func: Optional[Callable[..., Any]] = None,
name: Optional[str] = None,
bot_param: Optional[ModelField] = None,
event_param: Optional[ModelField] = None,
state_param: Optional[ModelField] = None,
matcher_param: Optional[ModelField] = None,
simple_params: Optional[List[ModelField]] = None,
dependencies: Optional[List["Dependent"]] = None,
use_cache: bool = True) -> None:
self.func = func
self.name = name
self.bot_param = bot_param
self.event_param = event_param
self.state_param = state_param
self.matcher_param = matcher_param
self.simple_params = simple_params or []
self.dependencies = dependencies or []
self.use_cache = use_cache
self.cache_key = (self.func,)
class ParamTypes(Enum):
BOT = "bot"
EVENT = "event"
STATE = "state"
MATCHER = "matcher"
SIMPLE = "simple"
class Param(FieldInfo):
in_: ParamTypes
def __init__(self, default: Any):
super().__init__(default=default)
def __repr__(self) -> str:
return f"{self.__class__.__name__}"
class BotParam(Param):
in_ = ParamTypes.BOT
class EventParam(Param):
in_ = ParamTypes.EVENT
class StateParam(Param):
in_ = ParamTypes.STATE
class MatcherParam(Param):
in_ = ParamTypes.MATCHER
class SimpleParam(Param):
in_ = ParamTypes.SIMPLE
def __init__(self, default: Any):
if default is Required:
raise ValueError("SimpleParam should be given a default value")
super().__init__(default)

View File

@ -0,0 +1,150 @@
import inspect
from typing import Any, Dict, Type, Union, Callable, Optional, ForwardRef
from pydantic import BaseConfig
from pydantic.class_validators import Validator
from pydantic.typing import evaluate_forwardref
from pydantic.schema import get_annotation_from_field_info
from pydantic.fields import Required, FieldInfo, ModelField, UndefinedType
from .models import Param, Depends, Dependent, ParamTypes, SimpleParam
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_typed_annotation(param, globalns),
) for param in signature.parameters.values()
]
typed_signature = inspect.Signature(typed_params)
return typed_signature
def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str,
Any]) -> Any:
annotation = param.annotation
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns)
return annotation
def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent:
depends: Depends = param.default
if depends.dependency:
dependency = depends.dependency
else:
dependency = param.annotation
return get_sub_dependant(
depends=depends,
dependency=dependency,
name=param.name,
)
def get_parameterless_sub_dependant(*, depends: Depends) -> Dependent:
assert callable(
depends.dependency
), "A parameter-less dependency must have a callable dependency"
return get_sub_dependant(depends=depends, dependency=depends.dependency)
def get_sub_dependant(
*,
depends: Depends,
dependency: Callable[..., Any],
name: Optional[str] = None,
) -> Dependent:
sub_dependant = get_dependent(
func=dependency,
name=name,
use_cache=depends.use_cache,
)
return sub_dependant
def get_dependent(*,
func: Callable[..., Any],
name: Optional[str] = None,
use_cache: bool = True) -> Dependent:
signature = get_typed_signature(func)
params = signature.parameters
dependent = Dependent(func=func, name=name, use_cache=use_cache)
for param_name, param in params.items():
if isinstance(param.default, Depends):
sub_dependent = get_param_sub_dependent(param=param)
dependent.dependencies.append(sub_dependent)
continue
param_field = get_param_field(param=param,
param_name=param_name,
default_field_info=SimpleParam)
return dependent
def get_param_field(*,
param: inspect.Parameter,
param_name: str,
default_field_info: Type[Param] = Param,
force_type: Optional[ParamTypes] = None,
ignore_default: bool = False) -> ModelField:
default_value = Required
if param.default != param.empty and not ignore_default:
default_value = param.default
if isinstance(default_value, FieldInfo):
field_info = default_value
default_value = field_info.default
if (isinstance(field_info, Param) and
getattr(field_info, "in_", None) is None):
field_info.in_ = default_field_info.in_
if force_type:
field_info.in_ = force_type # type: ignore
else:
field_info = default_field_info(default_value)
required: bool = default_value == Required
annotation: Any = Any
if param.annotation != param.empty:
annotation = param.annotation
annotation = get_annotation_from_field_info(annotation, field_info,
param_name)
if not field_info.alias and getattr(field_info, "convert_underscores",
None):
alias = param.name.replace("_", "-")
else:
alias = field_info.alias or param.name
field = create_field(
name=param.name,
type_=annotation,
default=None if required else default_value,
alias=alias,
required=required,
field_info=field_info,
)
# field.required = required
return field
def create_field(name: str,
type_: Type[Any],
class_validators: Optional[Dict[str, Validator]] = None,
default: Optional[Any] = None,
required: Union[bool, UndefinedType] = False,
model_config: Type[BaseConfig] = BaseConfig,
field_info: Optional[FieldInfo] = None,
alias: Optional[str] = None) -> ModelField:
class_validators = class_validators or {}
field_info = field_info or FieldInfo(None)
return ModelField(name=name,
type_=type_,
class_validators=class_validators,
model_config=model_config,
default=default,
required=required,
alias=alias,
field_info=field_info)

View File

@ -6,171 +6,22 @@
""" """
import inspect import inspect
from typing import _eval_type # type: ignore from typing import Optional
from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Optional,
ForwardRef)
from nonebot.log import logger from pydantic.typing import evaluate_forwardref
from nonebot.typing import T_State, T_Handler
if TYPE_CHECKING: from nonebot.utils import get_name
from nonebot.matcher import Matcher from nonebot.typing import T_Handler
from nonebot.adapters import Bot, Event
class Handler: class Handler:
"""事件处理函数类""" """事件处理函数类"""
def __init__(self, func: T_Handler): def __init__(self, func: T_Handler, *, name: Optional[str] = None):
"""装饰事件处理函数以便根据动态参数运行""" """装饰事件处理函数以便根据动态参数运行"""
self.func: T_Handler = func self.func: T_Handler = func
""" """
:类型: ``T_Handler`` :类型: ``T_Handler``
:说明: 事件处理函数 :说明: 事件处理函数
""" """
self.signature: inspect.Signature = self.get_signature() self.name = get_name(func) if name is None else name
"""
:类型: ``inspect.Signature``
:说明: 事件处理函数签名
"""
def __repr__(self) -> str:
return (f"<Handler {self.func.__name__}(bot: {self.bot_type}, "
f"event: {self.event_type}, state: {self.state_type}, "
f"matcher: {self.matcher_type})>")
def __str__(self) -> str:
return repr(self)
async def __call__(self, matcher: "Matcher", bot: "Bot", event: "Event",
state: T_State):
BotType = ((self.bot_type is not inspect.Parameter.empty) and
inspect.isclass(self.bot_type) and self.bot_type)
if BotType and not isinstance(bot, BotType):
logger.debug(
f"Matcher {matcher} bot type {type(bot)} not match annotation {BotType}, ignored"
)
return
EventType = ((self.event_type is not inspect.Parameter.empty) and
inspect.isclass(self.event_type) and self.event_type)
if EventType and not isinstance(event, EventType):
logger.debug(
f"Matcher {matcher} event type {type(event)} not match annotation {EventType}, ignored"
)
return
args = {"bot": bot, "event": event, "state": state, "matcher": matcher}
await self.func(
**{
k: v
for k, v in args.items()
if self.signature.parameters.get(k, None) is not None
})
@property
def bot_type(self) -> Union[Type["Bot"], inspect.Parameter.empty]:
"""
:类型: ``Union[Type["Bot"], inspect.Parameter.empty]``
:说明: 事件处理函数接受的 Bot 对象类型"""
return self.signature.parameters["bot"].annotation
@property
def event_type(
self) -> Optional[Union[Type["Event"], inspect.Parameter.empty]]:
"""
:类型: ``Optional[Union[Type[Event], inspect.Parameter.empty]]``
:说明: 事件处理函数接受的 event 类型 / 不需要 event 参数
"""
if "event" not in self.signature.parameters:
return None
return self.signature.parameters["event"].annotation
@property
def state_type(self) -> Optional[Union[T_State, inspect.Parameter.empty]]:
"""
:类型: ``Optional[Union[T_State, inspect.Parameter.empty]]``
:说明: 事件处理函数是否接受 state 参数
"""
if "state" not in self.signature.parameters:
return None
return self.signature.parameters["state"].annotation
@property
def matcher_type(
self) -> Optional[Union[Type["Matcher"], inspect.Parameter.empty]]:
"""
:类型: ``Optional[Union[Type["Matcher"], inspect.Parameter.empty]]``
:说明: 事件处理函数是否接受 matcher 参数
"""
if "matcher" not in self.signature.parameters:
return None
return self.signature.parameters["matcher"].annotation
def get_signature(self) -> inspect.Signature:
wrapped_signature = self._get_typed_signature()
signature = self._get_typed_signature(False)
self._check_params(signature)
self._check_bot_param(signature)
self._check_bot_param(wrapped_signature)
signature.parameters["bot"].replace(
annotation=wrapped_signature.parameters["bot"].annotation)
if "event" in wrapped_signature.parameters and "event" in signature.parameters:
signature.parameters["event"].replace(
annotation=wrapped_signature.parameters["event"].annotation)
return signature
def update_signature(
self, **kwargs: Union[None, Type["Bot"], Type["Event"], Type["Matcher"],
T_State, inspect.Parameter.empty]
) -> None:
params: List[inspect.Parameter] = []
for param in ["bot", "event", "state", "matcher"]:
sig = self.signature.parameters.get(param, None)
if param in kwargs:
sig = inspect.Parameter(param,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=kwargs[param])
if sig:
params.append(sig)
self.signature = inspect.Signature(params)
def _get_typed_signature(self,
follow_wrapped: bool = True) -> inspect.Signature:
signature = inspect.signature(self.func, follow_wrapped=follow_wrapped)
globalns = getattr(self.func, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=param.annotation if follow_wrapped else
self._get_typed_annotation(param, globalns),
) for param in signature.parameters.values()
]
typed_signature = inspect.Signature(typed_params)
return typed_signature
def _get_typed_annotation(self, param: inspect.Parameter,
globalns: Dict[str, Any]) -> Any:
try:
if isinstance(param.annotation, str):
return _eval_type(ForwardRef(param.annotation), globalns,
globalns)
else:
return param.annotation
except Exception:
return param.annotation
def _check_params(self, signature: inspect.Signature):
if not set(signature.parameters.keys()) <= {
"bot", "event", "state", "matcher"
}:
raise ValueError(
"Handler param names must in `bot`/`event`/`state`/`matcher`")
def _check_bot_param(self, signature: inspect.Signature):
if not any(
param.name == "bot" for param in signature.parameters.values()):
raise ValueError("Handler missing parameter 'bot'")

View File

@ -143,20 +143,11 @@ T_PermissionChecker = Callable[["Bot", "Event"], Union[bool, Awaitable[bool]]]
RuleChecker 即判断是否响应消息的处理函数 RuleChecker 即判断是否响应消息的处理函数
""" """
T_Handler = Union[Callable[[Any, Any, Any, Any], Union[Awaitable[None], T_Handler = Callable[..., Union[Awaitable[None], Awaitable[NoReturn]]]
Awaitable[NoReturn]]],
Callable[[Any, Any, Any], Union[Awaitable[None],
Awaitable[NoReturn]]],
Callable[[Any, Any], Union[Awaitable[None],
Awaitable[NoReturn]]],
Callable[[Any], Union[Awaitable[None], Awaitable[NoReturn]]]]
""" """
:类型: :类型:
* ``Callable[[Bot, Event, T_State], Union[Awaitable[None], Awaitable[NoReturn]]]`` * ``Callable[..., Union[Awaitable[None], Awaitable[NoReturn]]]``
* ``Callable[[Bot, Event], Union[Awaitable[None], Awaitable[NoReturn]]]``
* ``Callable[[Bot, T_State], Union[Awaitable[None], Awaitable[NoReturn]]]``
* ``Callable[[Bot], Union[Awaitable[None], Awaitable[NoReturn]]]``
:说明: :说明:

View File

@ -1,6 +1,7 @@
import re import re
import json import json
import asyncio import asyncio
import inspect
import dataclasses import dataclasses
from functools import wraps, partial from functools import wraps, partial
from typing import Any, Callable, Optional, Awaitable from typing import Any, Callable, Optional, Awaitable
@ -51,6 +52,12 @@ def run_sync(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
return _wrapper return _wrapper
def get_name(obj: Any) -> str:
if inspect.isfunction(obj) or inspect.isclass(obj):
return obj.__name__
return obj.__class__.__name__
class DataclassEncoder(json.JSONEncoder): class DataclassEncoder(json.JSONEncoder):
""" """
:说明: :说明: