mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-30 22:46:40 +00:00 
			
		
		
		
	🚧 process handler dependency
This commit is contained in:
		
							
								
								
									
										2
									
								
								nonebot/dependencies/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								nonebot/dependencies/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | ||||
| from .models import Depends as Depends | ||||
| from .utils import get_dependent as get_dependent | ||||
							
								
								
									
										89
									
								
								nonebot/dependencies/models.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								nonebot/dependencies/models.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,89 @@ | ||||
| from enum import Enum | ||||
| from typing import Any, List, Callable, Optional | ||||
|  | ||||
| from pydantic.fields import Required, FieldInfo, ModelField | ||||
|  | ||||
| from nonebot.utils import get_name | ||||
|  | ||||
|  | ||||
| class Depends: | ||||
|  | ||||
|     def __init__(self, | ||||
|                  dependency: Optional[Callable[..., Any]] = None, | ||||
|                  *, | ||||
|                  use_cache: bool = True) -> None: | ||||
|         self.dependency = dependency | ||||
|         self.use_cache = use_cache | ||||
|  | ||||
|     def __repr__(self) -> str: | ||||
|         dep = get_name(self.dependency) | ||||
|         cache = "" if self.use_cache else ", use_cache=False" | ||||
|         return f"{self.__class__.__name__}({dep}{cache})" | ||||
|  | ||||
|  | ||||
| class Dependent: | ||||
|  | ||||
|     def __init__(self, | ||||
|                  *, | ||||
|                  func: Optional[Callable[..., Any]] = None, | ||||
|                  name: Optional[str] = None, | ||||
|                  bot_param: Optional[ModelField] = None, | ||||
|                  event_param: Optional[ModelField] = None, | ||||
|                  state_param: Optional[ModelField] = None, | ||||
|                  matcher_param: Optional[ModelField] = None, | ||||
|                  simple_params: Optional[List[ModelField]] = None, | ||||
|                  dependencies: Optional[List["Dependent"]] = None, | ||||
|                  use_cache: bool = True) -> None: | ||||
|         self.func = func | ||||
|         self.name = name | ||||
|         self.bot_param = bot_param | ||||
|         self.event_param = event_param | ||||
|         self.state_param = state_param | ||||
|         self.matcher_param = matcher_param | ||||
|         self.simple_params = simple_params or [] | ||||
|         self.dependencies = dependencies or [] | ||||
|         self.use_cache = use_cache | ||||
|         self.cache_key = (self.func,) | ||||
|  | ||||
|  | ||||
| class ParamTypes(Enum): | ||||
|     BOT = "bot" | ||||
|     EVENT = "event" | ||||
|     STATE = "state" | ||||
|     MATCHER = "matcher" | ||||
|     SIMPLE = "simple" | ||||
|  | ||||
|  | ||||
| class Param(FieldInfo): | ||||
|     in_: ParamTypes | ||||
|  | ||||
|     def __init__(self, default: Any): | ||||
|         super().__init__(default=default) | ||||
|  | ||||
|     def __repr__(self) -> str: | ||||
|         return f"{self.__class__.__name__}" | ||||
|  | ||||
|  | ||||
| class BotParam(Param): | ||||
|     in_ = ParamTypes.BOT | ||||
|  | ||||
|  | ||||
| class EventParam(Param): | ||||
|     in_ = ParamTypes.EVENT | ||||
|  | ||||
|  | ||||
| class StateParam(Param): | ||||
|     in_ = ParamTypes.STATE | ||||
|  | ||||
|  | ||||
| class MatcherParam(Param): | ||||
|     in_ = ParamTypes.MATCHER | ||||
|  | ||||
|  | ||||
| class SimpleParam(Param): | ||||
|     in_ = ParamTypes.SIMPLE | ||||
|  | ||||
|     def __init__(self, default: Any): | ||||
|         if default is Required: | ||||
|             raise ValueError("SimpleParam should be given a default value") | ||||
|         super().__init__(default) | ||||
							
								
								
									
										150
									
								
								nonebot/dependencies/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										150
									
								
								nonebot/dependencies/utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,150 @@ | ||||
| import inspect | ||||
| from typing import Any, Dict, Type, Union, Callable, Optional, ForwardRef | ||||
|  | ||||
| from pydantic import BaseConfig | ||||
| from pydantic.class_validators import Validator | ||||
| from pydantic.typing import evaluate_forwardref | ||||
| from pydantic.schema import get_annotation_from_field_info | ||||
| from pydantic.fields import Required, FieldInfo, ModelField, UndefinedType | ||||
|  | ||||
| from .models import Param, Depends, Dependent, ParamTypes, SimpleParam | ||||
|  | ||||
|  | ||||
| def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: | ||||
|     signature = inspect.signature(call) | ||||
|     globalns = getattr(call, "__globals__", {}) | ||||
|     typed_params = [ | ||||
|         inspect.Parameter( | ||||
|             name=param.name, | ||||
|             kind=param.kind, | ||||
|             default=param.default, | ||||
|             annotation=get_typed_annotation(param, globalns), | ||||
|         ) for param in signature.parameters.values() | ||||
|     ] | ||||
|     typed_signature = inspect.Signature(typed_params) | ||||
|     return typed_signature | ||||
|  | ||||
|  | ||||
| def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, | ||||
|                                                                   Any]) -> Any: | ||||
|     annotation = param.annotation | ||||
|     if isinstance(annotation, str): | ||||
|         annotation = ForwardRef(annotation) | ||||
|         annotation = evaluate_forwardref(annotation, globalns, globalns) | ||||
|     return annotation | ||||
|  | ||||
|  | ||||
| def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent: | ||||
|     depends: Depends = param.default | ||||
|     if depends.dependency: | ||||
|         dependency = depends.dependency | ||||
|     else: | ||||
|         dependency = param.annotation | ||||
|     return get_sub_dependant( | ||||
|         depends=depends, | ||||
|         dependency=dependency, | ||||
|         name=param.name, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def get_parameterless_sub_dependant(*, depends: Depends) -> Dependent: | ||||
|     assert callable( | ||||
|         depends.dependency | ||||
|     ), "A parameter-less dependency must have a callable dependency" | ||||
|     return get_sub_dependant(depends=depends, dependency=depends.dependency) | ||||
|  | ||||
|  | ||||
| def get_sub_dependant( | ||||
|     *, | ||||
|     depends: Depends, | ||||
|     dependency: Callable[..., Any], | ||||
|     name: Optional[str] = None, | ||||
| ) -> Dependent: | ||||
|     sub_dependant = get_dependent( | ||||
|         func=dependency, | ||||
|         name=name, | ||||
|         use_cache=depends.use_cache, | ||||
|     ) | ||||
|     return sub_dependant | ||||
|  | ||||
|  | ||||
| def get_dependent(*, | ||||
|                   func: Callable[..., Any], | ||||
|                   name: Optional[str] = None, | ||||
|                   use_cache: bool = True) -> Dependent: | ||||
|     signature = get_typed_signature(func) | ||||
|     params = signature.parameters | ||||
|     dependent = Dependent(func=func, name=name, use_cache=use_cache) | ||||
|     for param_name, param in params.items(): | ||||
|         if isinstance(param.default, Depends): | ||||
|             sub_dependent = get_param_sub_dependent(param=param) | ||||
|             dependent.dependencies.append(sub_dependent) | ||||
|             continue | ||||
|         param_field = get_param_field(param=param, | ||||
|                                       param_name=param_name, | ||||
|                                       default_field_info=SimpleParam) | ||||
|  | ||||
|     return dependent | ||||
|  | ||||
|  | ||||
| def get_param_field(*, | ||||
|                     param: inspect.Parameter, | ||||
|                     param_name: str, | ||||
|                     default_field_info: Type[Param] = Param, | ||||
|                     force_type: Optional[ParamTypes] = None, | ||||
|                     ignore_default: bool = False) -> ModelField: | ||||
|     default_value = Required | ||||
|     if param.default != param.empty and not ignore_default: | ||||
|         default_value = param.default | ||||
|     if isinstance(default_value, FieldInfo): | ||||
|         field_info = default_value | ||||
|         default_value = field_info.default | ||||
|         if (isinstance(field_info, Param) and | ||||
|                 getattr(field_info, "in_", None) is None): | ||||
|             field_info.in_ = default_field_info.in_ | ||||
|         if force_type: | ||||
|             field_info.in_ = force_type  # type: ignore | ||||
|     else: | ||||
|         field_info = default_field_info(default_value) | ||||
|     required: bool = default_value == Required | ||||
|     annotation: Any = Any | ||||
|     if param.annotation != param.empty: | ||||
|         annotation = param.annotation | ||||
|     annotation = get_annotation_from_field_info(annotation, field_info, | ||||
|                                                 param_name) | ||||
|     if not field_info.alias and getattr(field_info, "convert_underscores", | ||||
|                                         None): | ||||
|         alias = param.name.replace("_", "-") | ||||
|     else: | ||||
|         alias = field_info.alias or param.name | ||||
|     field = create_field( | ||||
|         name=param.name, | ||||
|         type_=annotation, | ||||
|         default=None if required else default_value, | ||||
|         alias=alias, | ||||
|         required=required, | ||||
|         field_info=field_info, | ||||
|     ) | ||||
|     # field.required = required | ||||
|  | ||||
|     return field | ||||
|  | ||||
|  | ||||
| def create_field(name: str, | ||||
|                  type_: Type[Any], | ||||
|                  class_validators: Optional[Dict[str, Validator]] = None, | ||||
|                  default: Optional[Any] = None, | ||||
|                  required: Union[bool, UndefinedType] = False, | ||||
|                  model_config: Type[BaseConfig] = BaseConfig, | ||||
|                  field_info: Optional[FieldInfo] = None, | ||||
|                  alias: Optional[str] = None) -> ModelField: | ||||
|     class_validators = class_validators or {} | ||||
|     field_info = field_info or FieldInfo(None) | ||||
|     return ModelField(name=name, | ||||
|                       type_=type_, | ||||
|                       class_validators=class_validators, | ||||
|                       model_config=model_config, | ||||
|                       default=default, | ||||
|                       required=required, | ||||
|                       alias=alias, | ||||
|                       field_info=field_info) | ||||
| @@ -6,171 +6,22 @@ | ||||
| """ | ||||
|  | ||||
| import inspect | ||||
| from typing import _eval_type  # type: ignore | ||||
| from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Optional, | ||||
|                     ForwardRef) | ||||
| from typing import Optional | ||||
|  | ||||
| from nonebot.log import logger | ||||
| from nonebot.typing import T_State, T_Handler | ||||
| from pydantic.typing import evaluate_forwardref | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from nonebot.matcher import Matcher | ||||
|     from nonebot.adapters import Bot, Event | ||||
| from nonebot.utils import get_name | ||||
| from nonebot.typing import T_Handler | ||||
|  | ||||
|  | ||||
| class Handler: | ||||
|     """事件处理函数类""" | ||||
|  | ||||
|     def __init__(self, func: T_Handler): | ||||
|     def __init__(self, func: T_Handler, *, name: Optional[str] = None): | ||||
|         """装饰事件处理函数以便根据动态参数运行""" | ||||
|         self.func: T_Handler = func | ||||
|         """ | ||||
|         :类型: ``T_Handler`` | ||||
|         :说明: 事件处理函数 | ||||
|         """ | ||||
|         self.signature: inspect.Signature = self.get_signature() | ||||
|         """ | ||||
|         :类型: ``inspect.Signature`` | ||||
|         :说明: 事件处理函数签名 | ||||
|         """ | ||||
|  | ||||
|     def __repr__(self) -> str: | ||||
|         return (f"<Handler {self.func.__name__}(bot: {self.bot_type}, " | ||||
|                 f"event: {self.event_type}, state: {self.state_type}, " | ||||
|                 f"matcher: {self.matcher_type})>") | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         return repr(self) | ||||
|  | ||||
|     async def __call__(self, matcher: "Matcher", bot: "Bot", event: "Event", | ||||
|                        state: T_State): | ||||
|         BotType = ((self.bot_type is not inspect.Parameter.empty) and | ||||
|                    inspect.isclass(self.bot_type) and self.bot_type) | ||||
|         if BotType and not isinstance(bot, BotType): | ||||
|             logger.debug( | ||||
|                 f"Matcher {matcher} bot type {type(bot)} not match annotation {BotType}, ignored" | ||||
|             ) | ||||
|             return | ||||
|  | ||||
|         EventType = ((self.event_type is not inspect.Parameter.empty) and | ||||
|                      inspect.isclass(self.event_type) and self.event_type) | ||||
|         if EventType and not isinstance(event, EventType): | ||||
|             logger.debug( | ||||
|                 f"Matcher {matcher} event type {type(event)} not match annotation {EventType}, ignored" | ||||
|             ) | ||||
|             return | ||||
|  | ||||
|         args = {"bot": bot, "event": event, "state": state, "matcher": matcher} | ||||
|         await self.func( | ||||
|             **{ | ||||
|                 k: v | ||||
|                 for k, v in args.items() | ||||
|                 if self.signature.parameters.get(k, None) is not None | ||||
|             }) | ||||
|  | ||||
|     @property | ||||
|     def bot_type(self) -> Union[Type["Bot"], inspect.Parameter.empty]: | ||||
|         """ | ||||
|         :类型: ``Union[Type["Bot"], inspect.Parameter.empty]`` | ||||
|         :说明: 事件处理函数接受的 Bot 对象类型""" | ||||
|         return self.signature.parameters["bot"].annotation | ||||
|  | ||||
|     @property | ||||
|     def event_type( | ||||
|             self) -> Optional[Union[Type["Event"], inspect.Parameter.empty]]: | ||||
|         """ | ||||
|         :类型: ``Optional[Union[Type[Event], inspect.Parameter.empty]]`` | ||||
|         :说明: 事件处理函数接受的 event 类型 / 不需要 event 参数 | ||||
|         """ | ||||
|         if "event" not in self.signature.parameters: | ||||
|             return None | ||||
|         return self.signature.parameters["event"].annotation | ||||
|  | ||||
|     @property | ||||
|     def state_type(self) -> Optional[Union[T_State, inspect.Parameter.empty]]: | ||||
|         """ | ||||
|         :类型: ``Optional[Union[T_State, inspect.Parameter.empty]]`` | ||||
|         :说明: 事件处理函数是否接受 state 参数 | ||||
|         """ | ||||
|         if "state" not in self.signature.parameters: | ||||
|             return None | ||||
|         return self.signature.parameters["state"].annotation | ||||
|  | ||||
|     @property | ||||
|     def matcher_type( | ||||
|             self) -> Optional[Union[Type["Matcher"], inspect.Parameter.empty]]: | ||||
|         """ | ||||
|         :类型: ``Optional[Union[Type["Matcher"], inspect.Parameter.empty]]`` | ||||
|         :说明: 事件处理函数是否接受 matcher 参数 | ||||
|         """ | ||||
|         if "matcher" not in self.signature.parameters: | ||||
|             return None | ||||
|         return self.signature.parameters["matcher"].annotation | ||||
|  | ||||
|     def get_signature(self) -> inspect.Signature: | ||||
|         wrapped_signature = self._get_typed_signature() | ||||
|         signature = self._get_typed_signature(False) | ||||
|         self._check_params(signature) | ||||
|         self._check_bot_param(signature) | ||||
|         self._check_bot_param(wrapped_signature) | ||||
|         signature.parameters["bot"].replace( | ||||
|             annotation=wrapped_signature.parameters["bot"].annotation) | ||||
|         if "event" in wrapped_signature.parameters and "event" in signature.parameters: | ||||
|             signature.parameters["event"].replace( | ||||
|                 annotation=wrapped_signature.parameters["event"].annotation) | ||||
|         return signature | ||||
|  | ||||
|     def update_signature( | ||||
|         self, **kwargs: Union[None, Type["Bot"], Type["Event"], Type["Matcher"], | ||||
|                               T_State, inspect.Parameter.empty] | ||||
|     ) -> None: | ||||
|         params: List[inspect.Parameter] = [] | ||||
|         for param in ["bot", "event", "state", "matcher"]: | ||||
|             sig = self.signature.parameters.get(param, None) | ||||
|             if param in kwargs: | ||||
|                 sig = inspect.Parameter(param, | ||||
|                                         inspect.Parameter.POSITIONAL_OR_KEYWORD, | ||||
|                                         annotation=kwargs[param]) | ||||
|             if sig: | ||||
|                 params.append(sig) | ||||
|  | ||||
|         self.signature = inspect.Signature(params) | ||||
|  | ||||
|     def _get_typed_signature(self, | ||||
|                              follow_wrapped: bool = True) -> inspect.Signature: | ||||
|         signature = inspect.signature(self.func, follow_wrapped=follow_wrapped) | ||||
|         globalns = getattr(self.func, "__globals__", {}) | ||||
|         typed_params = [ | ||||
|             inspect.Parameter( | ||||
|                 name=param.name, | ||||
|                 kind=param.kind, | ||||
|                 default=param.default, | ||||
|                 annotation=param.annotation if follow_wrapped else | ||||
|                 self._get_typed_annotation(param, globalns), | ||||
|             ) for param in signature.parameters.values() | ||||
|         ] | ||||
|         typed_signature = inspect.Signature(typed_params) | ||||
|         return typed_signature | ||||
|  | ||||
|     def _get_typed_annotation(self, param: inspect.Parameter, | ||||
|                               globalns: Dict[str, Any]) -> Any: | ||||
|         try: | ||||
|             if isinstance(param.annotation, str): | ||||
|                 return _eval_type(ForwardRef(param.annotation), globalns, | ||||
|                                   globalns) | ||||
|             else: | ||||
|                 return param.annotation | ||||
|         except Exception: | ||||
|             return param.annotation | ||||
|  | ||||
|     def _check_params(self, signature: inspect.Signature): | ||||
|         if not set(signature.parameters.keys()) <= { | ||||
|                 "bot", "event", "state", "matcher" | ||||
|         }: | ||||
|             raise ValueError( | ||||
|                 "Handler param names must in `bot`/`event`/`state`/`matcher`") | ||||
|  | ||||
|     def _check_bot_param(self, signature: inspect.Signature): | ||||
|         if not any( | ||||
|                 param.name == "bot" for param in signature.parameters.values()): | ||||
|             raise ValueError("Handler missing parameter 'bot'") | ||||
|         self.name = get_name(func) if name is None else name | ||||
|   | ||||
| @@ -143,20 +143,11 @@ T_PermissionChecker = Callable[["Bot", "Event"], Union[bool, Awaitable[bool]]] | ||||
|   RuleChecker 即判断是否响应消息的处理函数。 | ||||
| """ | ||||
|  | ||||
| T_Handler = Union[Callable[[Any, Any, Any, Any], Union[Awaitable[None], | ||||
|                                                        Awaitable[NoReturn]]], | ||||
|                   Callable[[Any, Any, Any], Union[Awaitable[None], | ||||
|                                                   Awaitable[NoReturn]]], | ||||
|                   Callable[[Any, Any], Union[Awaitable[None], | ||||
|                                              Awaitable[NoReturn]]], | ||||
|                   Callable[[Any], Union[Awaitable[None], Awaitable[NoReturn]]]] | ||||
| T_Handler = Callable[..., Union[Awaitable[None], Awaitable[NoReturn]]] | ||||
| """ | ||||
| :类型: | ||||
|  | ||||
|   * ``Callable[[Bot, Event, T_State], Union[Awaitable[None], Awaitable[NoReturn]]]`` | ||||
|   * ``Callable[[Bot, Event], Union[Awaitable[None], Awaitable[NoReturn]]]`` | ||||
|   * ``Callable[[Bot, T_State], Union[Awaitable[None], Awaitable[NoReturn]]]`` | ||||
|   * ``Callable[[Bot], Union[Awaitable[None], Awaitable[NoReturn]]]`` | ||||
|   * ``Callable[..., Union[Awaitable[None], Awaitable[NoReturn]]]`` | ||||
|  | ||||
| :说明: | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| import re | ||||
| import json | ||||
| import asyncio | ||||
| import inspect | ||||
| import dataclasses | ||||
| from functools import wraps, partial | ||||
| from typing import Any, Callable, Optional, Awaitable | ||||
| @@ -51,6 +52,12 @@ def run_sync(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]: | ||||
|     return _wrapper | ||||
|  | ||||
|  | ||||
| def get_name(obj: Any) -> str: | ||||
|     if inspect.isfunction(obj) or inspect.isclass(obj): | ||||
|         return obj.__name__ | ||||
|     return obj.__class__.__name__ | ||||
|  | ||||
|  | ||||
| class DataclassEncoder(json.JSONEncoder): | ||||
|     """ | ||||
|     :说明: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user