mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-06-09 13:56:09 +00:00
🚧 process handler dependency
This commit is contained in:
parent
57e826a835
commit
c454cf0874
2
nonebot/dependencies/__init__.py
Normal file
2
nonebot/dependencies/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
from .models import Depends as Depends
|
||||||
|
from .utils import get_dependent as get_dependent
|
89
nonebot/dependencies/models.py
Normal file
89
nonebot/dependencies/models.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import Any, List, Callable, Optional
|
||||||
|
|
||||||
|
from pydantic.fields import Required, FieldInfo, ModelField
|
||||||
|
|
||||||
|
from nonebot.utils import get_name
|
||||||
|
|
||||||
|
|
||||||
|
class Depends:
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dependency: Optional[Callable[..., Any]] = None,
|
||||||
|
*,
|
||||||
|
use_cache: bool = True) -> None:
|
||||||
|
self.dependency = dependency
|
||||||
|
self.use_cache = use_cache
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
dep = get_name(self.dependency)
|
||||||
|
cache = "" if self.use_cache else ", use_cache=False"
|
||||||
|
return f"{self.__class__.__name__}({dep}{cache})"
|
||||||
|
|
||||||
|
|
||||||
|
class Dependent:
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
*,
|
||||||
|
func: Optional[Callable[..., Any]] = None,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
bot_param: Optional[ModelField] = None,
|
||||||
|
event_param: Optional[ModelField] = None,
|
||||||
|
state_param: Optional[ModelField] = None,
|
||||||
|
matcher_param: Optional[ModelField] = None,
|
||||||
|
simple_params: Optional[List[ModelField]] = None,
|
||||||
|
dependencies: Optional[List["Dependent"]] = None,
|
||||||
|
use_cache: bool = True) -> None:
|
||||||
|
self.func = func
|
||||||
|
self.name = name
|
||||||
|
self.bot_param = bot_param
|
||||||
|
self.event_param = event_param
|
||||||
|
self.state_param = state_param
|
||||||
|
self.matcher_param = matcher_param
|
||||||
|
self.simple_params = simple_params or []
|
||||||
|
self.dependencies = dependencies or []
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.cache_key = (self.func,)
|
||||||
|
|
||||||
|
|
||||||
|
class ParamTypes(Enum):
|
||||||
|
BOT = "bot"
|
||||||
|
EVENT = "event"
|
||||||
|
STATE = "state"
|
||||||
|
MATCHER = "matcher"
|
||||||
|
SIMPLE = "simple"
|
||||||
|
|
||||||
|
|
||||||
|
class Param(FieldInfo):
|
||||||
|
in_: ParamTypes
|
||||||
|
|
||||||
|
def __init__(self, default: Any):
|
||||||
|
super().__init__(default=default)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}"
|
||||||
|
|
||||||
|
|
||||||
|
class BotParam(Param):
|
||||||
|
in_ = ParamTypes.BOT
|
||||||
|
|
||||||
|
|
||||||
|
class EventParam(Param):
|
||||||
|
in_ = ParamTypes.EVENT
|
||||||
|
|
||||||
|
|
||||||
|
class StateParam(Param):
|
||||||
|
in_ = ParamTypes.STATE
|
||||||
|
|
||||||
|
|
||||||
|
class MatcherParam(Param):
|
||||||
|
in_ = ParamTypes.MATCHER
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleParam(Param):
|
||||||
|
in_ = ParamTypes.SIMPLE
|
||||||
|
|
||||||
|
def __init__(self, default: Any):
|
||||||
|
if default is Required:
|
||||||
|
raise ValueError("SimpleParam should be given a default value")
|
||||||
|
super().__init__(default)
|
150
nonebot/dependencies/utils.py
Normal file
150
nonebot/dependencies/utils.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
import inspect
|
||||||
|
from typing import Any, Dict, Type, Union, Callable, Optional, ForwardRef
|
||||||
|
|
||||||
|
from pydantic import BaseConfig
|
||||||
|
from pydantic.class_validators import Validator
|
||||||
|
from pydantic.typing import evaluate_forwardref
|
||||||
|
from pydantic.schema import get_annotation_from_field_info
|
||||||
|
from pydantic.fields import Required, FieldInfo, ModelField, UndefinedType
|
||||||
|
|
||||||
|
from .models import Param, Depends, Dependent, ParamTypes, SimpleParam
|
||||||
|
|
||||||
|
|
||||||
|
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||||
|
signature = inspect.signature(call)
|
||||||
|
globalns = getattr(call, "__globals__", {})
|
||||||
|
typed_params = [
|
||||||
|
inspect.Parameter(
|
||||||
|
name=param.name,
|
||||||
|
kind=param.kind,
|
||||||
|
default=param.default,
|
||||||
|
annotation=get_typed_annotation(param, globalns),
|
||||||
|
) for param in signature.parameters.values()
|
||||||
|
]
|
||||||
|
typed_signature = inspect.Signature(typed_params)
|
||||||
|
return typed_signature
|
||||||
|
|
||||||
|
|
||||||
|
def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str,
|
||||||
|
Any]) -> Any:
|
||||||
|
annotation = param.annotation
|
||||||
|
if isinstance(annotation, str):
|
||||||
|
annotation = ForwardRef(annotation)
|
||||||
|
annotation = evaluate_forwardref(annotation, globalns, globalns)
|
||||||
|
return annotation
|
||||||
|
|
||||||
|
|
||||||
|
def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent:
|
||||||
|
depends: Depends = param.default
|
||||||
|
if depends.dependency:
|
||||||
|
dependency = depends.dependency
|
||||||
|
else:
|
||||||
|
dependency = param.annotation
|
||||||
|
return get_sub_dependant(
|
||||||
|
depends=depends,
|
||||||
|
dependency=dependency,
|
||||||
|
name=param.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parameterless_sub_dependant(*, depends: Depends) -> Dependent:
|
||||||
|
assert callable(
|
||||||
|
depends.dependency
|
||||||
|
), "A parameter-less dependency must have a callable dependency"
|
||||||
|
return get_sub_dependant(depends=depends, dependency=depends.dependency)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sub_dependant(
|
||||||
|
*,
|
||||||
|
depends: Depends,
|
||||||
|
dependency: Callable[..., Any],
|
||||||
|
name: Optional[str] = None,
|
||||||
|
) -> Dependent:
|
||||||
|
sub_dependant = get_dependent(
|
||||||
|
func=dependency,
|
||||||
|
name=name,
|
||||||
|
use_cache=depends.use_cache,
|
||||||
|
)
|
||||||
|
return sub_dependant
|
||||||
|
|
||||||
|
|
||||||
|
def get_dependent(*,
|
||||||
|
func: Callable[..., Any],
|
||||||
|
name: Optional[str] = None,
|
||||||
|
use_cache: bool = True) -> Dependent:
|
||||||
|
signature = get_typed_signature(func)
|
||||||
|
params = signature.parameters
|
||||||
|
dependent = Dependent(func=func, name=name, use_cache=use_cache)
|
||||||
|
for param_name, param in params.items():
|
||||||
|
if isinstance(param.default, Depends):
|
||||||
|
sub_dependent = get_param_sub_dependent(param=param)
|
||||||
|
dependent.dependencies.append(sub_dependent)
|
||||||
|
continue
|
||||||
|
param_field = get_param_field(param=param,
|
||||||
|
param_name=param_name,
|
||||||
|
default_field_info=SimpleParam)
|
||||||
|
|
||||||
|
return dependent
|
||||||
|
|
||||||
|
|
||||||
|
def get_param_field(*,
|
||||||
|
param: inspect.Parameter,
|
||||||
|
param_name: str,
|
||||||
|
default_field_info: Type[Param] = Param,
|
||||||
|
force_type: Optional[ParamTypes] = None,
|
||||||
|
ignore_default: bool = False) -> ModelField:
|
||||||
|
default_value = Required
|
||||||
|
if param.default != param.empty and not ignore_default:
|
||||||
|
default_value = param.default
|
||||||
|
if isinstance(default_value, FieldInfo):
|
||||||
|
field_info = default_value
|
||||||
|
default_value = field_info.default
|
||||||
|
if (isinstance(field_info, Param) and
|
||||||
|
getattr(field_info, "in_", None) is None):
|
||||||
|
field_info.in_ = default_field_info.in_
|
||||||
|
if force_type:
|
||||||
|
field_info.in_ = force_type # type: ignore
|
||||||
|
else:
|
||||||
|
field_info = default_field_info(default_value)
|
||||||
|
required: bool = default_value == Required
|
||||||
|
annotation: Any = Any
|
||||||
|
if param.annotation != param.empty:
|
||||||
|
annotation = param.annotation
|
||||||
|
annotation = get_annotation_from_field_info(annotation, field_info,
|
||||||
|
param_name)
|
||||||
|
if not field_info.alias and getattr(field_info, "convert_underscores",
|
||||||
|
None):
|
||||||
|
alias = param.name.replace("_", "-")
|
||||||
|
else:
|
||||||
|
alias = field_info.alias or param.name
|
||||||
|
field = create_field(
|
||||||
|
name=param.name,
|
||||||
|
type_=annotation,
|
||||||
|
default=None if required else default_value,
|
||||||
|
alias=alias,
|
||||||
|
required=required,
|
||||||
|
field_info=field_info,
|
||||||
|
)
|
||||||
|
# field.required = required
|
||||||
|
|
||||||
|
return field
|
||||||
|
|
||||||
|
|
||||||
|
def create_field(name: str,
|
||||||
|
type_: Type[Any],
|
||||||
|
class_validators: Optional[Dict[str, Validator]] = None,
|
||||||
|
default: Optional[Any] = None,
|
||||||
|
required: Union[bool, UndefinedType] = False,
|
||||||
|
model_config: Type[BaseConfig] = BaseConfig,
|
||||||
|
field_info: Optional[FieldInfo] = None,
|
||||||
|
alias: Optional[str] = None) -> ModelField:
|
||||||
|
class_validators = class_validators or {}
|
||||||
|
field_info = field_info or FieldInfo(None)
|
||||||
|
return ModelField(name=name,
|
||||||
|
type_=type_,
|
||||||
|
class_validators=class_validators,
|
||||||
|
model_config=model_config,
|
||||||
|
default=default,
|
||||||
|
required=required,
|
||||||
|
alias=alias,
|
||||||
|
field_info=field_info)
|
@ -6,171 +6,22 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import _eval_type # type: ignore
|
from typing import Optional
|
||||||
from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Optional,
|
|
||||||
ForwardRef)
|
|
||||||
|
|
||||||
from nonebot.log import logger
|
from pydantic.typing import evaluate_forwardref
|
||||||
from nonebot.typing import T_State, T_Handler
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
from nonebot.utils import get_name
|
||||||
from nonebot.matcher import Matcher
|
from nonebot.typing import T_Handler
|
||||||
from nonebot.adapters import Bot, Event
|
|
||||||
|
|
||||||
|
|
||||||
class Handler:
|
class Handler:
|
||||||
"""事件处理函数类"""
|
"""事件处理函数类"""
|
||||||
|
|
||||||
def __init__(self, func: T_Handler):
|
def __init__(self, func: T_Handler, *, name: Optional[str] = None):
|
||||||
"""装饰事件处理函数以便根据动态参数运行"""
|
"""装饰事件处理函数以便根据动态参数运行"""
|
||||||
self.func: T_Handler = func
|
self.func: T_Handler = func
|
||||||
"""
|
"""
|
||||||
:类型: ``T_Handler``
|
:类型: ``T_Handler``
|
||||||
:说明: 事件处理函数
|
:说明: 事件处理函数
|
||||||
"""
|
"""
|
||||||
self.signature: inspect.Signature = self.get_signature()
|
self.name = get_name(func) if name is None else name
|
||||||
"""
|
|
||||||
:类型: ``inspect.Signature``
|
|
||||||
:说明: 事件处理函数签名
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return (f"<Handler {self.func.__name__}(bot: {self.bot_type}, "
|
|
||||||
f"event: {self.event_type}, state: {self.state_type}, "
|
|
||||||
f"matcher: {self.matcher_type})>")
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return repr(self)
|
|
||||||
|
|
||||||
async def __call__(self, matcher: "Matcher", bot: "Bot", event: "Event",
|
|
||||||
state: T_State):
|
|
||||||
BotType = ((self.bot_type is not inspect.Parameter.empty) and
|
|
||||||
inspect.isclass(self.bot_type) and self.bot_type)
|
|
||||||
if BotType and not isinstance(bot, BotType):
|
|
||||||
logger.debug(
|
|
||||||
f"Matcher {matcher} bot type {type(bot)} not match annotation {BotType}, ignored"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
EventType = ((self.event_type is not inspect.Parameter.empty) and
|
|
||||||
inspect.isclass(self.event_type) and self.event_type)
|
|
||||||
if EventType and not isinstance(event, EventType):
|
|
||||||
logger.debug(
|
|
||||||
f"Matcher {matcher} event type {type(event)} not match annotation {EventType}, ignored"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
args = {"bot": bot, "event": event, "state": state, "matcher": matcher}
|
|
||||||
await self.func(
|
|
||||||
**{
|
|
||||||
k: v
|
|
||||||
for k, v in args.items()
|
|
||||||
if self.signature.parameters.get(k, None) is not None
|
|
||||||
})
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bot_type(self) -> Union[Type["Bot"], inspect.Parameter.empty]:
|
|
||||||
"""
|
|
||||||
:类型: ``Union[Type["Bot"], inspect.Parameter.empty]``
|
|
||||||
:说明: 事件处理函数接受的 Bot 对象类型"""
|
|
||||||
return self.signature.parameters["bot"].annotation
|
|
||||||
|
|
||||||
@property
|
|
||||||
def event_type(
|
|
||||||
self) -> Optional[Union[Type["Event"], inspect.Parameter.empty]]:
|
|
||||||
"""
|
|
||||||
:类型: ``Optional[Union[Type[Event], inspect.Parameter.empty]]``
|
|
||||||
:说明: 事件处理函数接受的 event 类型 / 不需要 event 参数
|
|
||||||
"""
|
|
||||||
if "event" not in self.signature.parameters:
|
|
||||||
return None
|
|
||||||
return self.signature.parameters["event"].annotation
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state_type(self) -> Optional[Union[T_State, inspect.Parameter.empty]]:
|
|
||||||
"""
|
|
||||||
:类型: ``Optional[Union[T_State, inspect.Parameter.empty]]``
|
|
||||||
:说明: 事件处理函数是否接受 state 参数
|
|
||||||
"""
|
|
||||||
if "state" not in self.signature.parameters:
|
|
||||||
return None
|
|
||||||
return self.signature.parameters["state"].annotation
|
|
||||||
|
|
||||||
@property
|
|
||||||
def matcher_type(
|
|
||||||
self) -> Optional[Union[Type["Matcher"], inspect.Parameter.empty]]:
|
|
||||||
"""
|
|
||||||
:类型: ``Optional[Union[Type["Matcher"], inspect.Parameter.empty]]``
|
|
||||||
:说明: 事件处理函数是否接受 matcher 参数
|
|
||||||
"""
|
|
||||||
if "matcher" not in self.signature.parameters:
|
|
||||||
return None
|
|
||||||
return self.signature.parameters["matcher"].annotation
|
|
||||||
|
|
||||||
def get_signature(self) -> inspect.Signature:
|
|
||||||
wrapped_signature = self._get_typed_signature()
|
|
||||||
signature = self._get_typed_signature(False)
|
|
||||||
self._check_params(signature)
|
|
||||||
self._check_bot_param(signature)
|
|
||||||
self._check_bot_param(wrapped_signature)
|
|
||||||
signature.parameters["bot"].replace(
|
|
||||||
annotation=wrapped_signature.parameters["bot"].annotation)
|
|
||||||
if "event" in wrapped_signature.parameters and "event" in signature.parameters:
|
|
||||||
signature.parameters["event"].replace(
|
|
||||||
annotation=wrapped_signature.parameters["event"].annotation)
|
|
||||||
return signature
|
|
||||||
|
|
||||||
def update_signature(
|
|
||||||
self, **kwargs: Union[None, Type["Bot"], Type["Event"], Type["Matcher"],
|
|
||||||
T_State, inspect.Parameter.empty]
|
|
||||||
) -> None:
|
|
||||||
params: List[inspect.Parameter] = []
|
|
||||||
for param in ["bot", "event", "state", "matcher"]:
|
|
||||||
sig = self.signature.parameters.get(param, None)
|
|
||||||
if param in kwargs:
|
|
||||||
sig = inspect.Parameter(param,
|
|
||||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
||||||
annotation=kwargs[param])
|
|
||||||
if sig:
|
|
||||||
params.append(sig)
|
|
||||||
|
|
||||||
self.signature = inspect.Signature(params)
|
|
||||||
|
|
||||||
def _get_typed_signature(self,
|
|
||||||
follow_wrapped: bool = True) -> inspect.Signature:
|
|
||||||
signature = inspect.signature(self.func, follow_wrapped=follow_wrapped)
|
|
||||||
globalns = getattr(self.func, "__globals__", {})
|
|
||||||
typed_params = [
|
|
||||||
inspect.Parameter(
|
|
||||||
name=param.name,
|
|
||||||
kind=param.kind,
|
|
||||||
default=param.default,
|
|
||||||
annotation=param.annotation if follow_wrapped else
|
|
||||||
self._get_typed_annotation(param, globalns),
|
|
||||||
) for param in signature.parameters.values()
|
|
||||||
]
|
|
||||||
typed_signature = inspect.Signature(typed_params)
|
|
||||||
return typed_signature
|
|
||||||
|
|
||||||
def _get_typed_annotation(self, param: inspect.Parameter,
|
|
||||||
globalns: Dict[str, Any]) -> Any:
|
|
||||||
try:
|
|
||||||
if isinstance(param.annotation, str):
|
|
||||||
return _eval_type(ForwardRef(param.annotation), globalns,
|
|
||||||
globalns)
|
|
||||||
else:
|
|
||||||
return param.annotation
|
|
||||||
except Exception:
|
|
||||||
return param.annotation
|
|
||||||
|
|
||||||
def _check_params(self, signature: inspect.Signature):
|
|
||||||
if not set(signature.parameters.keys()) <= {
|
|
||||||
"bot", "event", "state", "matcher"
|
|
||||||
}:
|
|
||||||
raise ValueError(
|
|
||||||
"Handler param names must in `bot`/`event`/`state`/`matcher`")
|
|
||||||
|
|
||||||
def _check_bot_param(self, signature: inspect.Signature):
|
|
||||||
if not any(
|
|
||||||
param.name == "bot" for param in signature.parameters.values()):
|
|
||||||
raise ValueError("Handler missing parameter 'bot'")
|
|
||||||
|
@ -143,20 +143,11 @@ T_PermissionChecker = Callable[["Bot", "Event"], Union[bool, Awaitable[bool]]]
|
|||||||
RuleChecker 即判断是否响应消息的处理函数。
|
RuleChecker 即判断是否响应消息的处理函数。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
T_Handler = Union[Callable[[Any, Any, Any, Any], Union[Awaitable[None],
|
T_Handler = Callable[..., Union[Awaitable[None], Awaitable[NoReturn]]]
|
||||||
Awaitable[NoReturn]]],
|
|
||||||
Callable[[Any, Any, Any], Union[Awaitable[None],
|
|
||||||
Awaitable[NoReturn]]],
|
|
||||||
Callable[[Any, Any], Union[Awaitable[None],
|
|
||||||
Awaitable[NoReturn]]],
|
|
||||||
Callable[[Any], Union[Awaitable[None], Awaitable[NoReturn]]]]
|
|
||||||
"""
|
"""
|
||||||
:类型:
|
:类型:
|
||||||
|
|
||||||
* ``Callable[[Bot, Event, T_State], Union[Awaitable[None], Awaitable[NoReturn]]]``
|
* ``Callable[..., Union[Awaitable[None], Awaitable[NoReturn]]]``
|
||||||
* ``Callable[[Bot, Event], Union[Awaitable[None], Awaitable[NoReturn]]]``
|
|
||||||
* ``Callable[[Bot, T_State], Union[Awaitable[None], Awaitable[NoReturn]]]``
|
|
||||||
* ``Callable[[Bot], Union[Awaitable[None], Awaitable[NoReturn]]]``
|
|
||||||
|
|
||||||
:说明:
|
:说明:
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from functools import wraps, partial
|
from functools import wraps, partial
|
||||||
from typing import Any, Callable, Optional, Awaitable
|
from typing import Any, Callable, Optional, Awaitable
|
||||||
@ -51,6 +52,12 @@ def run_sync(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
|
|||||||
return _wrapper
|
return _wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def get_name(obj: Any) -> str:
|
||||||
|
if inspect.isfunction(obj) or inspect.isclass(obj):
|
||||||
|
return obj.__name__
|
||||||
|
return obj.__class__.__name__
|
||||||
|
|
||||||
|
|
||||||
class DataclassEncoder(json.JSONEncoder):
|
class DataclassEncoder(json.JSONEncoder):
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user