mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-09-12 23:16:39 +00:00
@ -278,6 +278,7 @@ def run(host: Optional[str] = None,
|
||||
get_driver().run(host, port, *args, **kwargs)
|
||||
|
||||
|
||||
import nonebot.params as params
|
||||
from nonebot.plugin import export as export
|
||||
from nonebot.plugin import require as require
|
||||
from nonebot.plugin import on_regex as on_regex
|
||||
|
@ -67,8 +67,7 @@ class CustomEnvSettings(EnvSettingsSource):
|
||||
env_val = settings.__config__.json_loads(env_val)
|
||||
except ValueError as e:
|
||||
raise SettingsError(
|
||||
f'error parsing JSON for "{env_name}"' # type: ignore
|
||||
) from e
|
||||
f'error parsing JSON for "{env_name}"') from e
|
||||
d[field.alias] = env_val
|
||||
|
||||
if env_file_vars:
|
||||
|
232
nonebot/dependencies/__init__.py
Normal file
232
nonebot/dependencies/__init__.py
Normal file
@ -0,0 +1,232 @@
|
||||
"""
|
||||
依赖注入处理模块
|
||||
===============
|
||||
|
||||
该模块实现了依赖注入的定义与处理。
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from itertools import chain
|
||||
from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast
|
||||
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
||||
|
||||
from pydantic import BaseConfig
|
||||
from pydantic.schema import get_annotation_from_field_info
|
||||
from pydantic.fields import Required, Undefined, ModelField
|
||||
|
||||
from nonebot.log import logger
|
||||
from .models import Param as Param
|
||||
from .utils import get_typed_signature
|
||||
from .models import Dependent as Dependent
|
||||
from nonebot.exception import SkippedException
|
||||
from .models import DependsWrapper as DependsWrapper
|
||||
from nonebot.typing import T_Handler, T_DependencyCache
|
||||
from nonebot.utils import (CacheLock, run_sync, is_gen_callable,
|
||||
run_sync_ctx_manager, is_async_gen_callable,
|
||||
is_coroutine_callable)
|
||||
|
||||
cache_lock = CacheLock()
|
||||
|
||||
|
||||
class CustomConfig(BaseConfig):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
def get_param_sub_dependent(
|
||||
*,
|
||||
param: inspect.Parameter,
|
||||
allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
|
||||
depends: DependsWrapper = param.default
|
||||
if depends.dependency:
|
||||
dependency = depends.dependency
|
||||
else:
|
||||
dependency = param.annotation
|
||||
return get_sub_dependant(depends=depends,
|
||||
dependency=dependency,
|
||||
name=param.name,
|
||||
allow_types=allow_types)
|
||||
|
||||
|
||||
def get_parameterless_sub_dependant(
|
||||
*,
|
||||
depends: DependsWrapper,
|
||||
allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
|
||||
assert callable(
|
||||
depends.dependency
|
||||
), "A parameter-less dependency must have a callable dependency"
|
||||
return get_sub_dependant(depends=depends,
|
||||
dependency=depends.dependency,
|
||||
allow_types=allow_types)
|
||||
|
||||
|
||||
def get_sub_dependant(
|
||||
*,
|
||||
depends: DependsWrapper,
|
||||
dependency: T_Handler,
|
||||
name: Optional[str] = None,
|
||||
allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
|
||||
sub_dependant = get_dependent(func=dependency,
|
||||
name=name,
|
||||
use_cache=depends.use_cache,
|
||||
allow_types=allow_types)
|
||||
return sub_dependant
|
||||
|
||||
|
||||
def get_dependent(*,
|
||||
func: T_Handler,
|
||||
name: Optional[str] = None,
|
||||
use_cache: bool = True,
|
||||
allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
|
||||
signature = get_typed_signature(func)
|
||||
params = signature.parameters
|
||||
dependent = Dependent(func=func,
|
||||
name=name,
|
||||
allow_types=allow_types,
|
||||
use_cache=use_cache)
|
||||
for param_name, param in params.items():
|
||||
if isinstance(param.default, DependsWrapper):
|
||||
sub_dependent = get_param_sub_dependent(param=param,
|
||||
allow_types=allow_types)
|
||||
dependent.dependencies.append(sub_dependent)
|
||||
continue
|
||||
|
||||
default_value = Required
|
||||
if param.default != param.empty:
|
||||
default_value = param.default
|
||||
|
||||
if isinstance(default_value, Param):
|
||||
field_info = default_value
|
||||
default_value = field_info.default
|
||||
else:
|
||||
for allow_type in dependent.allow_types:
|
||||
if allow_type._check(param_name, param):
|
||||
field_info = allow_type(default_value)
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown parameter {param_name} for function {func} with type {param.annotation}"
|
||||
)
|
||||
|
||||
annotation: Any = Any
|
||||
required = default_value == Required
|
||||
if param.annotation != param.empty:
|
||||
annotation = param.annotation
|
||||
annotation = get_annotation_from_field_info(annotation, field_info,
|
||||
param_name)
|
||||
dependent.params.append(
|
||||
ModelField(name=param_name,
|
||||
type_=annotation,
|
||||
class_validators=None,
|
||||
model_config=CustomConfig,
|
||||
default=None if required else default_value,
|
||||
required=required,
|
||||
field_info=field_info))
|
||||
|
||||
return dependent
|
||||
|
||||
|
||||
async def solve_dependencies(
|
||||
*,
|
||||
_dependent: Dependent,
|
||||
_stack: Optional[AsyncExitStack] = None,
|
||||
_sub_dependents: Optional[List[Dependent]] = None,
|
||||
_dependency_cache: Optional[T_DependencyCache] = None,
|
||||
**params: Any) -> Tuple[Dict[str, Any], T_DependencyCache]:
|
||||
values: Dict[str, Any] = {}
|
||||
dependency_cache = {} if _dependency_cache is None else _dependency_cache
|
||||
|
||||
# solve sub dependencies
|
||||
sub_dependent: Dependent
|
||||
for sub_dependent in chain(_sub_dependents or tuple(),
|
||||
_dependent.dependencies):
|
||||
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
|
||||
sub_dependent.cache_key = cast(Callable[..., Any],
|
||||
sub_dependent.cache_key)
|
||||
func = sub_dependent.func
|
||||
|
||||
# solve sub dependency with current cache
|
||||
solved_result = await solve_dependencies(
|
||||
_dependent=sub_dependent,
|
||||
_dependency_cache=dependency_cache,
|
||||
**params)
|
||||
sub_values, sub_dependency_cache = solved_result
|
||||
# update cache?
|
||||
# dependency_cache.update(sub_dependency_cache)
|
||||
|
||||
# run dependency function
|
||||
async with cache_lock:
|
||||
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
|
||||
solved = dependency_cache[sub_dependent.cache_key]
|
||||
elif is_gen_callable(func) or is_async_gen_callable(func):
|
||||
assert isinstance(
|
||||
_stack, AsyncExitStack
|
||||
), "Generator dependency should be called in context"
|
||||
if is_gen_callable(func):
|
||||
cm = run_sync_ctx_manager(
|
||||
contextmanager(func)(**sub_values))
|
||||
else:
|
||||
cm = asynccontextmanager(func)(**sub_values)
|
||||
solved = await _stack.enter_async_context(cm)
|
||||
elif is_coroutine_callable(func):
|
||||
solved = await func(**sub_values)
|
||||
else:
|
||||
solved = await run_sync(func)(**sub_values)
|
||||
|
||||
# parameter dependency
|
||||
if sub_dependent.name is not None:
|
||||
values[sub_dependent.name] = solved
|
||||
# save current dependency to cache
|
||||
if sub_dependent.cache_key not in dependency_cache:
|
||||
dependency_cache[sub_dependent.cache_key] = solved
|
||||
|
||||
# usual dependency
|
||||
for field in _dependent.params:
|
||||
field_info = field.field_info
|
||||
assert isinstance(field_info,
|
||||
Param), "Params must be subclasses of Param"
|
||||
value = field_info._solve(**params)
|
||||
if value == Undefined:
|
||||
value = field.get_default()
|
||||
_, errs_ = field.validate(value,
|
||||
values,
|
||||
loc=(str(field_info), field.alias))
|
||||
if errs_:
|
||||
logger.debug(
|
||||
f"{field_info} "
|
||||
f"type {type(value)} not match depends {_dependent.func} "
|
||||
f"annotation {field._type_display()}, ignored")
|
||||
raise SkippedException
|
||||
else:
|
||||
values[field.name] = value
|
||||
|
||||
return values, dependency_cache
|
||||
|
||||
|
||||
def Depends(dependency: Optional[T_Handler] = None,
|
||||
*,
|
||||
use_cache: bool = True) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
参数依赖注入装饰器
|
||||
|
||||
:参数:
|
||||
|
||||
* ``dependency: Optional[Callable[..., Any]] = None``: 依赖函数。默认为参数的类型注释。
|
||||
* ``use_cache: bool = True``: 是否使用缓存。默认为 ``True``。
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def depend_func() -> Any:
|
||||
return ...
|
||||
|
||||
def depend_gen_func():
|
||||
try:
|
||||
yield ...
|
||||
finally:
|
||||
...
|
||||
|
||||
async def handler(param_name: Any = Depends(depend_func), gen: Any = Depends(depend_gen_func)):
|
||||
...
|
||||
"""
|
||||
return DependsWrapper(dependency=dependency, use_cache=use_cache)
|
54
nonebot/dependencies/models.py
Normal file
54
nonebot/dependencies/models.py
Normal file
@ -0,0 +1,54 @@
|
||||
import abc
|
||||
import inspect
|
||||
from typing import Any, List, Type, Optional
|
||||
|
||||
from pydantic.fields import FieldInfo, ModelField
|
||||
|
||||
from nonebot.utils import get_name
|
||||
from nonebot.typing import T_Handler
|
||||
|
||||
|
||||
class Param(abc.ABC, FieldInfo):
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def _check(cls, name: str, param: inspect.Parameter) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def _solve(self, **kwargs: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DependsWrapper:
|
||||
|
||||
def __init__(self,
|
||||
dependency: Optional[T_Handler] = None,
|
||||
*,
|
||||
use_cache: bool = True) -> None:
|
||||
self.dependency = dependency
|
||||
self.use_cache = use_cache
|
||||
|
||||
def __repr__(self) -> str:
|
||||
dep = get_name(self.dependency)
|
||||
cache = "" if self.use_cache else ", use_cache=False"
|
||||
return f"{self.__class__.__name__}({dep}{cache})"
|
||||
|
||||
|
||||
class Dependent:
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
func: Optional[T_Handler] = None,
|
||||
name: Optional[str] = None,
|
||||
params: Optional[List[ModelField]] = None,
|
||||
allow_types: Optional[List[Type[Param]]] = None,
|
||||
dependencies: Optional[List["Dependent"]] = None,
|
||||
use_cache: bool = True) -> None:
|
||||
self.func = func
|
||||
self.name = name
|
||||
self.params = params or []
|
||||
self.allow_types = allow_types or []
|
||||
self.dependencies = dependencies or []
|
||||
self.use_cache = use_cache
|
||||
self.cache_key = self.func
|
37
nonebot/dependencies/utils.py
Normal file
37
nonebot/dependencies/utils.py
Normal file
@ -0,0 +1,37 @@
|
||||
import inspect
|
||||
from typing import Any, Dict
|
||||
|
||||
from loguru import logger
|
||||
from pydantic.typing import ForwardRef, evaluate_forwardref
|
||||
|
||||
from nonebot.typing import T_Handler
|
||||
|
||||
|
||||
def get_typed_signature(func: T_Handler) -> inspect.Signature:
|
||||
signature = inspect.signature(func)
|
||||
globalns = getattr(func, "__globals__", {})
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
kind=param.kind,
|
||||
default=param.default,
|
||||
annotation=get_typed_annotation(param, globalns),
|
||||
) for param in signature.parameters.values()
|
||||
]
|
||||
typed_signature = inspect.Signature(typed_params)
|
||||
return typed_signature
|
||||
|
||||
|
||||
def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str,
|
||||
Any]) -> Any:
|
||||
annotation = param.annotation
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
try:
|
||||
annotation = evaluate_forwardref(annotation, globalns, globalns)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).warning(
|
||||
f"Unknown ForwardRef[\"{param.annotation}\"] for parameter {param.name}"
|
||||
)
|
||||
return inspect.Parameter.empty
|
||||
return annotation
|
@ -248,6 +248,8 @@ class Driver(ForwardDriver):
|
||||
await asyncio.sleep(3)
|
||||
continue
|
||||
|
||||
setup_ = cast(HTTPPollingSetup, setup_)
|
||||
|
||||
if not bot:
|
||||
request = await _build_request(setup_)
|
||||
if not request:
|
||||
@ -264,7 +266,6 @@ class Driver(ForwardDriver):
|
||||
bot.request = request
|
||||
|
||||
request = cast(HTTPRequest, request)
|
||||
setup_ = cast(HTTPPollingSetup, setup_)
|
||||
|
||||
headers = request.headers
|
||||
timeout = aiohttp.ClientTimeout(30)
|
||||
|
@ -409,6 +409,8 @@ class Driver(ReverseDriver, ForwardDriver):
|
||||
await asyncio.sleep(3)
|
||||
continue
|
||||
|
||||
setup_ = cast(HTTPPollingSetup, setup_)
|
||||
|
||||
if not bot:
|
||||
request = await _build_request(setup_)
|
||||
if not request:
|
||||
@ -423,7 +425,6 @@ class Driver(ReverseDriver, ForwardDriver):
|
||||
continue
|
||||
bot.request = request
|
||||
|
||||
setup_ = cast(HTTPPollingSetup, setup_)
|
||||
request = cast(HTTPRequest, request)
|
||||
headers = request.headers
|
||||
|
||||
|
@ -6,6 +6,8 @@
|
||||
这些异常并非所有需要用户处理,在 NoneBot 内部运行时被捕获,并进行对应操作。
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class NoneBotException(Exception):
|
||||
"""
|
||||
@ -13,9 +15,33 @@ class NoneBotException(Exception):
|
||||
|
||||
所有 NoneBot 发生的异常基类。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# Rule Exception
|
||||
class ParserExit(NoneBotException):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
``shell command`` 处理消息失败时返回的异常
|
||||
|
||||
:参数:
|
||||
|
||||
* ``status``
|
||||
* ``message``
|
||||
"""
|
||||
|
||||
def __init__(self, status: int = 0, message: Optional[str] = None):
|
||||
self.status = status
|
||||
self.message = message
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ParserExit status={self.status} message={self.message}>"
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
|
||||
# Processor Exception
|
||||
class IgnoredException(NoneBotException):
|
||||
"""
|
||||
:说明:
|
||||
@ -37,71 +63,6 @@ class IgnoredException(NoneBotException):
|
||||
return self.__repr__()
|
||||
|
||||
|
||||
class ParserExit(NoneBotException):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
``shell command`` 处理消息失败时返回的异常
|
||||
|
||||
:参数:
|
||||
|
||||
* ``status``
|
||||
* ``message``
|
||||
"""
|
||||
|
||||
def __init__(self, status=0, message=None):
|
||||
self.status = status
|
||||
self.message = message
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ParserExit status={self.status} message={self.message}>"
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
|
||||
class PausedException(NoneBotException):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后继续下一个 ``Handler``。
|
||||
可用于用户输入新信息。
|
||||
|
||||
:用法:
|
||||
|
||||
可以在 ``Handler`` 中通过 ``Matcher.pause()`` 抛出。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class RejectedException(NoneBotException):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后重新运行当前 ``Handler``。
|
||||
可用于用户重新输入。
|
||||
|
||||
:用法:
|
||||
|
||||
可以在 ``Handler`` 中通过 ``Matcher.reject()`` 抛出。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class FinishedException(NoneBotException):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
指示 NoneBot 结束当前 ``Handler`` 且后续 ``Handler`` 不再被运行。
|
||||
可用于结束用户会话。
|
||||
|
||||
:用法:
|
||||
|
||||
可以在 ``Handler`` 中通过 ``Matcher.finish()`` 抛出。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StopPropagation(NoneBotException):
|
||||
"""
|
||||
:说明:
|
||||
@ -112,9 +73,69 @@ class StopPropagation(NoneBotException):
|
||||
|
||||
在 ``Matcher.block == True`` 时抛出。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# Matcher Exceptions
|
||||
class MatcherException(NoneBotException):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
所有 Matcher 发生的异常基类。
|
||||
"""
|
||||
|
||||
|
||||
class SkippedException(MatcherException):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
指示 NoneBot 立即结束当前 ``Handler`` 的处理,继续处理下一个 ``Handler``。
|
||||
|
||||
:用法:
|
||||
|
||||
可以在 ``Handler`` 中通过 ``Matcher.skip()`` 抛出。
|
||||
"""
|
||||
|
||||
|
||||
class PausedException(MatcherException):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后继续下一个 ``Handler``。
|
||||
可用于用户输入新信息。
|
||||
|
||||
:用法:
|
||||
|
||||
可以在 ``Handler`` 中通过 ``Matcher.pause()`` 抛出。
|
||||
"""
|
||||
|
||||
|
||||
class RejectedException(MatcherException):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后重新运行当前 ``Handler``。
|
||||
可用于用户重新输入。
|
||||
|
||||
:用法:
|
||||
|
||||
可以在 ``Handler`` 中通过 ``Matcher.reject()`` 抛出。
|
||||
"""
|
||||
|
||||
|
||||
class FinishedException(MatcherException):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
指示 NoneBot 结束当前 ``Handler`` 且后续 ``Handler`` 不再被运行。
|
||||
可用于结束用户会话。
|
||||
|
||||
:用法:
|
||||
|
||||
可以在 ``Handler`` 中通过 ``Matcher.finish()`` 抛出。
|
||||
"""
|
||||
|
||||
|
||||
# Adapter Exceptions
|
||||
class AdapterException(NoneBotException):
|
||||
"""
|
||||
:说明:
|
||||
@ -130,7 +151,7 @@ class AdapterException(NoneBotException):
|
||||
self.adapter_name = adapter_name
|
||||
|
||||
|
||||
class NoLogException(Exception):
|
||||
class NoLogException(AdapterException):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
|
@ -5,172 +5,114 @@
|
||||
该模块实现事件处理函数的封装,以实现动态参数等功能。
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from typing import _eval_type # type: ignore
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Optional,
|
||||
ForwardRef)
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any, Dict, List, Type, Callable, Optional
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.typing import T_State, T_Handler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.utils import get_name, run_sync
|
||||
from nonebot.dependencies import (Param, Dependent, DependsWrapper,
|
||||
get_dependent, solve_dependencies,
|
||||
get_parameterless_sub_dependant)
|
||||
|
||||
|
||||
class Handler:
|
||||
"""事件处理函数类"""
|
||||
"""事件处理器类。支持依赖注入。"""
|
||||
|
||||
def __init__(self, func: T_Handler):
|
||||
"""装饰事件处理函数以便根据动态参数运行"""
|
||||
self.func: T_Handler = func
|
||||
def __init__(self,
|
||||
func: Callable[..., Any],
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
dependencies: Optional[List[DependsWrapper]] = None,
|
||||
allow_types: Optional[List[Type[Param]]] = None):
|
||||
"""
|
||||
:类型: ``T_Handler``
|
||||
:说明:
|
||||
|
||||
装饰一个函数为事件处理器。
|
||||
|
||||
:参数:
|
||||
|
||||
* ``func: Callable[..., Any]``: 事件处理函数。
|
||||
* ``name: Optional[str]``: 事件处理器名称。默认为函数名。
|
||||
* ``dependencies: Optional[List[DependsWrapper]]``: 额外的非参数依赖注入。
|
||||
* ``allow_types: Optional[List[Type[Param]]]``: 允许的参数类型。
|
||||
"""
|
||||
self.func = func
|
||||
"""
|
||||
:类型: ``Callable[..., Any]``
|
||||
:说明: 事件处理函数
|
||||
"""
|
||||
self.signature: inspect.Signature = self.get_signature()
|
||||
self.name = get_name(func) if name is None else name
|
||||
"""
|
||||
:类型: ``inspect.Signature``
|
||||
:说明: 事件处理函数签名
|
||||
:类型: ``str``
|
||||
:说明: 事件处理函数名
|
||||
"""
|
||||
self.allow_types = allow_types or []
|
||||
"""
|
||||
:类型: ``List[Type[Param]]``
|
||||
:说明: 事件处理器允许的参数类型
|
||||
"""
|
||||
|
||||
self.dependencies = dependencies or []
|
||||
"""
|
||||
:类型: ``List[DependsWrapper]``
|
||||
:说明: 事件处理器的额外依赖
|
||||
"""
|
||||
self.sub_dependents: Dict[Callable[..., Any], Dependent] = {}
|
||||
if dependencies:
|
||||
for depends in dependencies:
|
||||
self.cache_dependent(depends)
|
||||
self.dependent = get_dependent(func=func, allow_types=self.allow_types)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"<Handler {self.func.__name__}(bot: {self.bot_type}, "
|
||||
f"event: {self.event_type}, state: {self.state_type}, "
|
||||
f"matcher: {self.matcher_type})>")
|
||||
return (
|
||||
f"<Handler {self.name}({', '.join(map(str, self.dependent.params))})>"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return repr(self)
|
||||
|
||||
async def __call__(self, matcher: "Matcher", bot: "Bot", event: "Event",
|
||||
state: T_State):
|
||||
BotType = ((self.bot_type is not inspect.Parameter.empty) and
|
||||
inspect.isclass(self.bot_type) and self.bot_type)
|
||||
if BotType and not isinstance(bot, BotType):
|
||||
logger.debug(
|
||||
f"Matcher {matcher} bot type {type(bot)} not match annotation {BotType}, ignored"
|
||||
)
|
||||
return
|
||||
async def __call__(self,
|
||||
*,
|
||||
_stack: Optional[AsyncExitStack] = None,
|
||||
_dependency_cache: Optional[Dict[Callable[..., Any],
|
||||
Any]] = None,
|
||||
**params) -> Any:
|
||||
values, _ = await solve_dependencies(
|
||||
_dependent=self.dependent,
|
||||
_stack=_stack,
|
||||
_sub_dependents=[
|
||||
self.sub_dependents[dependency.dependency] # type: ignore
|
||||
for dependency in self.dependencies
|
||||
],
|
||||
_dependency_cache=_dependency_cache,
|
||||
**params)
|
||||
|
||||
EventType = ((self.event_type is not inspect.Parameter.empty) and
|
||||
inspect.isclass(self.event_type) and self.event_type)
|
||||
if EventType and not isinstance(event, EventType):
|
||||
logger.debug(
|
||||
f"Matcher {matcher} event type {type(event)} not match annotation {EventType}, ignored"
|
||||
)
|
||||
return
|
||||
if asyncio.iscoroutinefunction(self.func):
|
||||
return await self.func(**values)
|
||||
else:
|
||||
return await run_sync(self.func)(**values)
|
||||
|
||||
args = {"bot": bot, "event": event, "state": state, "matcher": matcher}
|
||||
await self.func(
|
||||
**{
|
||||
k: v
|
||||
for k, v in args.items()
|
||||
if self.signature.parameters.get(k, None) is not None
|
||||
})
|
||||
def cache_dependent(self, dependency: DependsWrapper):
|
||||
if not dependency.dependency:
|
||||
raise ValueError(f"{dependency} has no dependency")
|
||||
if dependency.dependency in self.sub_dependents:
|
||||
raise ValueError(f"{dependency} is already in dependencies")
|
||||
sub_dependant = get_parameterless_sub_dependant(
|
||||
depends=dependency, allow_types=self.allow_types)
|
||||
self.sub_dependents[dependency.dependency] = sub_dependant
|
||||
|
||||
@property
|
||||
def bot_type(self) -> Union[Type["Bot"], inspect.Parameter.empty]:
|
||||
"""
|
||||
:类型: ``Union[Type["Bot"], inspect.Parameter.empty]``
|
||||
:说明: 事件处理函数接受的 Bot 对象类型"""
|
||||
return self.signature.parameters["bot"].annotation
|
||||
def prepend_dependency(self, dependency: DependsWrapper):
|
||||
self.cache_dependent(dependency)
|
||||
self.dependencies.insert(0, dependency)
|
||||
|
||||
@property
|
||||
def event_type(
|
||||
self) -> Optional[Union[Type["Event"], inspect.Parameter.empty]]:
|
||||
"""
|
||||
:类型: ``Optional[Union[Type[Event], inspect.Parameter.empty]]``
|
||||
:说明: 事件处理函数接受的 event 类型 / 不需要 event 参数
|
||||
"""
|
||||
if "event" not in self.signature.parameters:
|
||||
return None
|
||||
return self.signature.parameters["event"].annotation
|
||||
def append_dependency(self, dependency: DependsWrapper):
|
||||
self.cache_dependent(dependency)
|
||||
self.dependencies.append(dependency)
|
||||
|
||||
@property
|
||||
def state_type(self) -> Optional[Union[T_State, inspect.Parameter.empty]]:
|
||||
"""
|
||||
:类型: ``Optional[Union[T_State, inspect.Parameter.empty]]``
|
||||
:说明: 事件处理函数是否接受 state 参数
|
||||
"""
|
||||
if "state" not in self.signature.parameters:
|
||||
return None
|
||||
return self.signature.parameters["state"].annotation
|
||||
|
||||
@property
|
||||
def matcher_type(
|
||||
self) -> Optional[Union[Type["Matcher"], inspect.Parameter.empty]]:
|
||||
"""
|
||||
:类型: ``Optional[Union[Type["Matcher"], inspect.Parameter.empty]]``
|
||||
:说明: 事件处理函数是否接受 matcher 参数
|
||||
"""
|
||||
if "matcher" not in self.signature.parameters:
|
||||
return None
|
||||
return self.signature.parameters["matcher"].annotation
|
||||
|
||||
def get_signature(self) -> inspect.Signature:
|
||||
wrapped_signature = self._get_typed_signature()
|
||||
signature = self._get_typed_signature(False)
|
||||
self._check_params(signature)
|
||||
self._check_bot_param(signature)
|
||||
self._check_bot_param(wrapped_signature)
|
||||
signature.parameters["bot"].replace(
|
||||
annotation=wrapped_signature.parameters["bot"].annotation)
|
||||
if "event" in wrapped_signature.parameters and "event" in signature.parameters:
|
||||
signature.parameters["event"].replace(
|
||||
annotation=wrapped_signature.parameters["event"].annotation)
|
||||
return signature
|
||||
|
||||
def update_signature(
|
||||
self, **kwargs: Union[None, Type["Bot"], Type["Event"], Type["Matcher"],
|
||||
T_State, inspect.Parameter.empty]
|
||||
) -> None:
|
||||
params: List[inspect.Parameter] = []
|
||||
for param in ["bot", "event", "state", "matcher"]:
|
||||
sig = self.signature.parameters.get(param, None)
|
||||
if param in kwargs:
|
||||
sig = inspect.Parameter(param,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=kwargs[param])
|
||||
if sig:
|
||||
params.append(sig)
|
||||
|
||||
self.signature = inspect.Signature(params)
|
||||
|
||||
def _get_typed_signature(self,
|
||||
follow_wrapped: bool = True) -> inspect.Signature:
|
||||
signature = inspect.signature(self.func, follow_wrapped=follow_wrapped)
|
||||
globalns = getattr(self.func, "__globals__", {})
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
kind=param.kind,
|
||||
default=param.default,
|
||||
annotation=param.annotation if follow_wrapped else
|
||||
self._get_typed_annotation(param, globalns),
|
||||
) for param in signature.parameters.values()
|
||||
]
|
||||
typed_signature = inspect.Signature(typed_params)
|
||||
return typed_signature
|
||||
|
||||
def _get_typed_annotation(self, param: inspect.Parameter,
|
||||
globalns: Dict[str, Any]) -> Any:
|
||||
try:
|
||||
if isinstance(param.annotation, str):
|
||||
return _eval_type(ForwardRef(param.annotation), globalns,
|
||||
globalns)
|
||||
else:
|
||||
return param.annotation
|
||||
except Exception:
|
||||
return param.annotation
|
||||
|
||||
def _check_params(self, signature: inspect.Signature):
|
||||
if not set(signature.parameters.keys()) <= {
|
||||
"bot", "event", "state", "matcher"
|
||||
}:
|
||||
raise ValueError(
|
||||
"Handler param names must in `bot`/`event`/`state`/`matcher`")
|
||||
|
||||
def _check_bot_param(self, signature: inspect.Signature):
|
||||
if not any(
|
||||
param.name == "bot" for param in signature.parameters.values()):
|
||||
raise ValueError("Handler missing parameter 'bot'")
|
||||
def remove_dependency(self, dependency: DependsWrapper):
|
||||
if not dependency.dependency:
|
||||
raise ValueError(f"{dependency} has no dependency")
|
||||
if dependency.dependency in self.sub_dependents:
|
||||
del self.sub_dependents[dependency.dependency]
|
||||
if dependency in self.dependencies:
|
||||
self.dependencies.remove(dependency)
|
||||
|
@ -5,35 +5,39 @@
|
||||
该模块实现事件响应器的创建与运行,并提供一些快捷方法来帮助用户更好的与机器人进行对话 。
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from types import ModuleType
|
||||
from datetime import datetime
|
||||
from contextvars import ContextVar
|
||||
from collections import defaultdict
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Callable,
|
||||
NoReturn, Optional)
|
||||
|
||||
from nonebot import params
|
||||
from nonebot.rule import Rule
|
||||
from nonebot.log import logger
|
||||
from nonebot.handler import Handler
|
||||
from nonebot.adapters import MessageTemplate
|
||||
from nonebot.dependencies import DependsWrapper
|
||||
from nonebot.permission import USER, Permission
|
||||
from nonebot.adapters import (Bot, Event, Message, MessageSegment,
|
||||
MessageTemplate)
|
||||
from nonebot.exception import (PausedException, StopPropagation,
|
||||
FinishedException, RejectedException)
|
||||
SkippedException, FinishedException,
|
||||
RejectedException)
|
||||
from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater,
|
||||
T_StateFactory, T_PermissionUpdater)
|
||||
T_StateFactory, T_DependencyCache,
|
||||
T_PermissionUpdater)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.plugin import Plugin
|
||||
from nonebot.adapters import Bot, Event, Message, MessageSegment
|
||||
|
||||
matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list)
|
||||
"""
|
||||
:类型: ``Dict[int, List[Type[Matcher]]]``
|
||||
:说明: 用于存储当前所有的事件响应器
|
||||
"""
|
||||
current_bot: ContextVar["Bot"] = ContextVar("current_bot")
|
||||
current_event: ContextVar["Event"] = ContextVar("current_event")
|
||||
current_bot: ContextVar[Bot] = ContextVar("current_bot")
|
||||
current_event: ContextVar[Event] = ContextVar("current_event")
|
||||
current_state: ContextVar[T_State] = ContextVar("current_state")
|
||||
|
||||
|
||||
@ -152,6 +156,11 @@ class Matcher(metaclass=MatcherMeta):
|
||||
:说明: 事件响应器权限更新函数
|
||||
"""
|
||||
|
||||
HANDLER_PARAM_TYPES = [
|
||||
params.BotParam, params.EventParam, params.StateParam,
|
||||
params.MatcherParam, params.DefaultParam
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
"""实例化 Matcher 以便运行"""
|
||||
self.handlers = self.handlers.copy()
|
||||
@ -228,8 +237,8 @@ class Matcher(metaclass=MatcherMeta):
|
||||
"permission":
|
||||
permission or Permission(),
|
||||
"handlers": [
|
||||
handler
|
||||
if isinstance(handler, Handler) else Handler(handler)
|
||||
handler if isinstance(handler, Handler) else Handler(
|
||||
handler, allow_types=cls.HANDLER_PARAM_TYPES)
|
||||
for handler in handlers
|
||||
] if handlers else [],
|
||||
"temp":
|
||||
@ -258,7 +267,13 @@ class Matcher(metaclass=MatcherMeta):
|
||||
return NewMatcher
|
||||
|
||||
@classmethod
|
||||
async def check_perm(cls, bot: "Bot", event: "Event") -> bool:
|
||||
async def check_perm(
|
||||
cls,
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[Dict[Callable[..., Any],
|
||||
Any]] = None) -> bool:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -275,11 +290,17 @@ class Matcher(metaclass=MatcherMeta):
|
||||
"""
|
||||
event_type = event.get_type()
|
||||
return (event_type == (cls.type or event_type) and
|
||||
await cls.permission(bot, event))
|
||||
await cls.permission(bot, event, stack, dependency_cache))
|
||||
|
||||
@classmethod
|
||||
async def check_rule(cls, bot: "Bot", event: "Event",
|
||||
state: T_State) -> bool:
|
||||
async def check_rule(
|
||||
cls,
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
state: T_State,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[Dict[Callable[..., Any],
|
||||
Any]] = None) -> bool:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -297,7 +318,7 @@ class Matcher(metaclass=MatcherMeta):
|
||||
"""
|
||||
event_type = event.get_type()
|
||||
return (event_type == (cls.type or event_type) and
|
||||
await cls.rule(bot, event, state))
|
||||
await cls.rule(bot, event, state, stack, dependency_cache))
|
||||
|
||||
@classmethod
|
||||
def args_parser(cls, func: T_ArgsParser) -> T_ArgsParser:
|
||||
@ -343,8 +364,13 @@ class Matcher(metaclass=MatcherMeta):
|
||||
return func
|
||||
|
||||
@classmethod
|
||||
def append_handler(cls, handler: T_Handler) -> Handler:
|
||||
handler_ = Handler(handler)
|
||||
def append_handler(
|
||||
cls,
|
||||
handler: T_Handler,
|
||||
dependencies: Optional[List[DependsWrapper]] = None) -> Handler:
|
||||
handler_ = Handler(handler,
|
||||
dependencies=dependencies,
|
||||
allow_types=cls.HANDLER_PARAM_TYPES)
|
||||
cls.handlers.append(handler_)
|
||||
return handler_
|
||||
|
||||
@ -378,22 +404,22 @@ class Matcher(metaclass=MatcherMeta):
|
||||
* 无
|
||||
"""
|
||||
|
||||
async def _receive(bot: "Bot", event: "Event") -> NoReturn:
|
||||
raise PausedException
|
||||
|
||||
if cls.handlers:
|
||||
# 已有前置handlers则接受一条新的消息,否则视为接收初始消息
|
||||
receive_handler = cls.append_handler(_receive)
|
||||
else:
|
||||
receive_handler = None
|
||||
async def _receive(state: T_State) -> Union[None, NoReturn]:
|
||||
if state.get(_receive):
|
||||
return
|
||||
state[_receive] = True
|
||||
raise RejectedException
|
||||
|
||||
def _decorator(func: T_Handler) -> T_Handler:
|
||||
if not cls.handlers or cls.handlers[-1] is not func:
|
||||
func_handler = cls.append_handler(func)
|
||||
if receive_handler:
|
||||
receive_handler.update_signature(
|
||||
bot=func_handler.bot_type,
|
||||
event=func_handler.event_type)
|
||||
|
||||
depend = DependsWrapper(_receive)
|
||||
|
||||
if cls.handlers and cls.handlers[-1].func is func:
|
||||
func_handler = cls.handlers[-1]
|
||||
func_handler.prepend_dependency(depend)
|
||||
else:
|
||||
cls.append_handler(
|
||||
func, dependencies=[depend] if cls.handlers else [])
|
||||
|
||||
return func
|
||||
|
||||
@ -403,7 +429,7 @@ class Matcher(metaclass=MatcherMeta):
|
||||
def got(
|
||||
cls,
|
||||
key: str,
|
||||
prompt: Optional[Union[str, "Message", "MessageSegment",
|
||||
prompt: Optional[Union[str, Message, MessageSegment,
|
||||
MessageTemplate]] = None,
|
||||
args_parser: Optional[T_ArgsParser] = None
|
||||
) -> Callable[[T_Handler], T_Handler]:
|
||||
@ -419,8 +445,12 @@ class Matcher(metaclass=MatcherMeta):
|
||||
* ``args_parser: Optional[T_ArgsParser]``: 可选参数解析函数,空则使用默认解析函数
|
||||
"""
|
||||
|
||||
async def _key_getter(bot: "Bot", event: "Event", state: T_State):
|
||||
async def _key_getter(bot: Bot, event: Event, state: T_State):
|
||||
if state.get(f"_{key}_prompted"):
|
||||
return
|
||||
|
||||
state["_current_key"] = key
|
||||
state[f"_{key}_prompted"] = True
|
||||
if key not in state:
|
||||
if prompt is not None:
|
||||
if isinstance(prompt, MessageTemplate):
|
||||
@ -428,52 +458,40 @@ class Matcher(metaclass=MatcherMeta):
|
||||
else:
|
||||
_prompt = prompt
|
||||
await bot.send(event=event, message=_prompt)
|
||||
raise PausedException
|
||||
raise RejectedException
|
||||
else:
|
||||
state["_skip_key"] = True
|
||||
state[f"_{key}_parsed"] = True
|
||||
|
||||
async def _key_parser(bot: "Bot", event: "Event", state: T_State):
|
||||
if key in state and state.get("_skip_key"):
|
||||
del state["_skip_key"]
|
||||
async def _key_parser(bot: Bot, event: Event, state: T_State):
|
||||
if key in state and state.get(f"_{key}_parsed"):
|
||||
return
|
||||
|
||||
parser = args_parser or cls._default_parser
|
||||
if parser:
|
||||
# parser = cast(T_ArgsParser["Bot", "Event"], parser)
|
||||
await parser(bot, event, state)
|
||||
else:
|
||||
state[state["_current_key"]] = str(event.get_message())
|
||||
|
||||
getter_handler = cls.append_handler(_key_getter)
|
||||
parser_handler = cls.append_handler(_key_parser)
|
||||
state[key] = str(event.get_message())
|
||||
state[f"_{key}_parsed"] = True
|
||||
|
||||
def _decorator(func: T_Handler) -> T_Handler:
|
||||
if not hasattr(cls.handlers[-1].func, "__wrapped__"):
|
||||
parser = cls.handlers.pop()
|
||||
func_handler = Handler(func)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(bot: "Bot", event: "Event", state: T_State,
|
||||
matcher: Matcher):
|
||||
await parser(matcher, bot, event, state)
|
||||
await func_handler(matcher, bot, event, state)
|
||||
if "_current_key" in state:
|
||||
del state["_current_key"]
|
||||
get_depend = DependsWrapper(_key_getter)
|
||||
parser_depend = DependsWrapper(_key_parser)
|
||||
|
||||
wrapper_handler = cls.append_handler(wrapper)
|
||||
|
||||
getter_handler.update_signature(
|
||||
bot=wrapper_handler.bot_type,
|
||||
event=wrapper_handler.event_type)
|
||||
parser_handler.update_signature(
|
||||
bot=wrapper_handler.bot_type,
|
||||
event=wrapper_handler.event_type)
|
||||
if cls.handlers and cls.handlers[-1].func is func:
|
||||
func_handler = cls.handlers[-1]
|
||||
func_handler.prepend_dependency(parser_depend)
|
||||
func_handler.prepend_dependency(get_depend)
|
||||
else:
|
||||
cls.append_handler(func,
|
||||
dependencies=[get_depend, parser_depend])
|
||||
|
||||
return func
|
||||
|
||||
return _decorator
|
||||
|
||||
@classmethod
|
||||
async def send(cls, message: Union[str, "Message", "MessageSegment",
|
||||
async def send(cls, message: Union[str, Message, MessageSegment,
|
||||
MessageTemplate], **kwargs) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
@ -496,7 +514,7 @@ class Matcher(metaclass=MatcherMeta):
|
||||
|
||||
@classmethod
|
||||
async def finish(cls,
|
||||
message: Optional[Union[str, "Message", "MessageSegment",
|
||||
message: Optional[Union[str, Message, MessageSegment,
|
||||
MessageTemplate]] = None,
|
||||
**kwargs) -> NoReturn:
|
||||
"""
|
||||
@ -522,7 +540,7 @@ class Matcher(metaclass=MatcherMeta):
|
||||
|
||||
@classmethod
|
||||
async def pause(cls,
|
||||
prompt: Optional[Union[str, "Message", "MessageSegment",
|
||||
prompt: Optional[Union[str, Message, MessageSegment,
|
||||
MessageTemplate]] = None,
|
||||
**kwargs) -> NoReturn:
|
||||
"""
|
||||
@ -548,8 +566,8 @@ class Matcher(metaclass=MatcherMeta):
|
||||
|
||||
@classmethod
|
||||
async def reject(cls,
|
||||
prompt: Optional[Union[str, "Message",
|
||||
"MessageSegment"]] = None,
|
||||
prompt: Optional[Union[str, Message,
|
||||
MessageSegment]] = None,
|
||||
**kwargs) -> NoReturn:
|
||||
"""
|
||||
:说明:
|
||||
@ -564,6 +582,8 @@ class Matcher(metaclass=MatcherMeta):
|
||||
bot = current_bot.get()
|
||||
event = current_event.get()
|
||||
state = current_state.get()
|
||||
if "_current_key" in state and f"_{state['_current_key']}_parsed" in state:
|
||||
del state[f"_{state['_current_key']}_parsed"]
|
||||
if isinstance(prompt, MessageTemplate):
|
||||
_prompt = prompt.format(**state)
|
||||
else:
|
||||
@ -581,7 +601,12 @@ class Matcher(metaclass=MatcherMeta):
|
||||
self.block = True
|
||||
|
||||
# 运行handlers
|
||||
async def run(self, bot: "Bot", event: "Event", state: T_State):
|
||||
async def run(self,
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
state: T_State,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[T_DependencyCache] = None):
|
||||
b_t = current_bot.set(bot)
|
||||
e_t = current_event.set(event)
|
||||
s_t = current_state.set(self.state)
|
||||
@ -594,7 +619,15 @@ class Matcher(metaclass=MatcherMeta):
|
||||
while self.handlers:
|
||||
handler = self.handlers.pop(0)
|
||||
logger.debug(f"Running handler {handler}")
|
||||
await handler(self, bot, event, self.state)
|
||||
try:
|
||||
await handler(matcher=self,
|
||||
bot=bot,
|
||||
event=event,
|
||||
state=self.state,
|
||||
_stack=stack,
|
||||
_dependency_cache=dependency_cache)
|
||||
except SkippedException:
|
||||
pass
|
||||
|
||||
except RejectedException:
|
||||
self.handlers.insert(0, handler) # type: ignore
|
||||
@ -610,11 +643,8 @@ class Matcher(metaclass=MatcherMeta):
|
||||
|
||||
updater = self.__class__._default_permission_updater
|
||||
if updater:
|
||||
permission = await updater(
|
||||
bot,
|
||||
event,
|
||||
self.state, # type: ignore
|
||||
self.permission)
|
||||
permission = await updater(bot, event, self.state,
|
||||
self.permission)
|
||||
else:
|
||||
permission = USER(event.get_session_id(), perm=self.permission)
|
||||
|
||||
@ -647,11 +677,8 @@ class Matcher(metaclass=MatcherMeta):
|
||||
|
||||
updater = self.__class__._default_permission_updater
|
||||
if updater:
|
||||
permission = await updater(
|
||||
bot,
|
||||
event,
|
||||
self.state, # type: ignore
|
||||
self.permission)
|
||||
permission = await updater(bot, event, self.state,
|
||||
self.permission)
|
||||
else:
|
||||
permission = USER(event.get_session_id(), perm=self.permission)
|
||||
|
||||
|
@ -7,23 +7,39 @@ NoneBot 内部处理并按优先级分发事件给所有事件响应器,提供
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Set, Type, Optional
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Optional
|
||||
|
||||
from nonebot import params
|
||||
from nonebot.log import logger
|
||||
from nonebot.rule import TrieRule
|
||||
from nonebot.handler import Handler
|
||||
from nonebot.utils import escape_tag
|
||||
from nonebot.matcher import Matcher, matchers
|
||||
from nonebot.exception import NoLogException, StopPropagation, IgnoredException
|
||||
from nonebot.typing import (T_State, T_RunPreProcessor, T_RunPostProcessor,
|
||||
T_EventPreProcessor, T_EventPostProcessor)
|
||||
from nonebot.typing import (T_State, T_DependencyCache, T_RunPreProcessor,
|
||||
T_RunPostProcessor, T_EventPreProcessor,
|
||||
T_EventPostProcessor)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.adapters import Bot, Event
|
||||
|
||||
_event_preprocessors: Set[T_EventPreProcessor] = set()
|
||||
_event_postprocessors: Set[T_EventPostProcessor] = set()
|
||||
_run_preprocessors: Set[T_RunPreProcessor] = set()
|
||||
_run_postprocessors: Set[T_RunPostProcessor] = set()
|
||||
_event_preprocessors: Set[Handler] = set()
|
||||
_event_postprocessors: Set[Handler] = set()
|
||||
_run_preprocessors: Set[Handler] = set()
|
||||
_run_postprocessors: Set[Handler] = set()
|
||||
|
||||
EVENT_PCS_PARAMS = [
|
||||
params.BotParam, params.EventParam, params.StateParam, params.DefaultParam
|
||||
]
|
||||
RUN_PREPCS_PARAMS = [
|
||||
params.MatcherParam, params.BotParam, params.EventParam, params.StateParam,
|
||||
params.DefaultParam
|
||||
]
|
||||
RUN_POSTPCS_PARAMS = [
|
||||
params.MatcherParam, params.ExceptionParam, params.BotParam,
|
||||
params.EventParam, params.StateParam, params.DefaultParam
|
||||
]
|
||||
|
||||
|
||||
def event_preprocessor(func: T_EventPreProcessor) -> T_EventPreProcessor:
|
||||
@ -31,16 +47,8 @@ def event_preprocessor(func: T_EventPreProcessor) -> T_EventPreProcessor:
|
||||
:说明:
|
||||
|
||||
事件预处理。装饰一个函数,使它在每次接收到事件并分发给各响应器之前执行。
|
||||
|
||||
:参数:
|
||||
|
||||
事件预处理函数接收三个参数。
|
||||
|
||||
* ``bot: Bot``: Bot 对象
|
||||
* ``event: Event``: Event 对象
|
||||
* ``state: T_State``: 当前 State
|
||||
"""
|
||||
_event_preprocessors.add(func)
|
||||
_event_preprocessors.add(Handler(func, allow_types=EVENT_PCS_PARAMS))
|
||||
return func
|
||||
|
||||
|
||||
@ -49,16 +57,8 @@ def event_postprocessor(func: T_EventPostProcessor) -> T_EventPostProcessor:
|
||||
:说明:
|
||||
|
||||
事件后处理。装饰一个函数,使它在每次接收到事件并分发给各响应器之后执行。
|
||||
|
||||
:参数:
|
||||
|
||||
事件后处理函数接收三个参数。
|
||||
|
||||
* ``bot: Bot``: Bot 对象
|
||||
* ``event: Event``: Event 对象
|
||||
* ``state: T_State``: 当前事件运行前 State
|
||||
"""
|
||||
_event_postprocessors.add(func)
|
||||
_event_postprocessors.add(Handler(func, allow_types=EVENT_PCS_PARAMS))
|
||||
return func
|
||||
|
||||
|
||||
@ -67,17 +67,8 @@ def run_preprocessor(func: T_RunPreProcessor) -> T_RunPreProcessor:
|
||||
:说明:
|
||||
|
||||
运行预处理。装饰一个函数,使它在每次事件响应器运行前执行。
|
||||
|
||||
:参数:
|
||||
|
||||
运行预处理函数接收四个参数。
|
||||
|
||||
* ``matcher: Matcher``: 当前要运行的事件响应器
|
||||
* ``bot: Bot``: Bot 对象
|
||||
* ``event: Event``: Event 对象
|
||||
* ``state: T_State``: 当前 State
|
||||
"""
|
||||
_run_preprocessors.add(func)
|
||||
_run_preprocessors.add(Handler(func, allow_types=RUN_PREPCS_PARAMS))
|
||||
return func
|
||||
|
||||
|
||||
@ -86,23 +77,19 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
|
||||
:说明:
|
||||
|
||||
运行后处理。装饰一个函数,使它在每次事件响应器运行后执行。
|
||||
|
||||
:参数:
|
||||
|
||||
运行后处理函数接收五个参数。
|
||||
|
||||
* ``matcher: Matcher``: 运行完毕的事件响应器
|
||||
* ``exception: Optional[Exception]``: 事件响应器运行错误(如果存在)
|
||||
* ``bot: Bot``: Bot 对象
|
||||
* ``event: Event``: Event 对象
|
||||
* ``state: T_State``: 当前 State
|
||||
"""
|
||||
_run_postprocessors.add(func)
|
||||
_run_postprocessors.add(Handler(func, allow_types=RUN_POSTPCS_PARAMS))
|
||||
return func
|
||||
|
||||
|
||||
async def _check_matcher(priority: int, Matcher: Type[Matcher], bot: "Bot",
|
||||
event: "Event", state: T_State) -> None:
|
||||
async def _check_matcher(
|
||||
priority: int,
|
||||
Matcher: Type[Matcher],
|
||||
bot: "Bot",
|
||||
event: "Event",
|
||||
state: T_State,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[T_DependencyCache] = None) -> None:
|
||||
if Matcher.expire_time and datetime.now() > Matcher.expire_time:
|
||||
try:
|
||||
matchers[priority].remove(Matcher)
|
||||
@ -112,7 +99,9 @@ async def _check_matcher(priority: int, Matcher: Type[Matcher], bot: "Bot",
|
||||
|
||||
try:
|
||||
if not await Matcher.check_perm(
|
||||
bot, event) or not await Matcher.check_rule(bot, event, state):
|
||||
bot, event, stack,
|
||||
dependency_cache) or not await Matcher.check_rule(
|
||||
bot, event, state, stack, dependency_cache):
|
||||
return
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
@ -125,17 +114,29 @@ async def _check_matcher(priority: int, Matcher: Type[Matcher], bot: "Bot",
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await _run_matcher(Matcher, bot, event, state)
|
||||
await _run_matcher(Matcher, bot, event, state, stack, dependency_cache)
|
||||
|
||||
|
||||
async def _run_matcher(Matcher: Type[Matcher], bot: "Bot", event: "Event",
|
||||
state: T_State) -> None:
|
||||
async def _run_matcher(
|
||||
Matcher: Type[Matcher],
|
||||
bot: "Bot",
|
||||
event: "Event",
|
||||
state: T_State,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[T_DependencyCache] = None) -> None:
|
||||
logger.info(f"Event will be handled by {Matcher}")
|
||||
|
||||
matcher = Matcher()
|
||||
|
||||
coros = list(
|
||||
map(lambda x: x(matcher, bot, event, state), _run_preprocessors))
|
||||
map(
|
||||
lambda x: x(matcher=matcher,
|
||||
bot=bot,
|
||||
event=event,
|
||||
state=state,
|
||||
_stack=stack,
|
||||
_dependency_cache=dependency_cache),
|
||||
_run_preprocessors))
|
||||
if coros:
|
||||
try:
|
||||
await asyncio.gather(*coros)
|
||||
@ -153,7 +154,7 @@ async def _run_matcher(Matcher: Type[Matcher], bot: "Bot", event: "Event",
|
||||
|
||||
try:
|
||||
logger.debug(f"Running matcher {matcher}")
|
||||
await matcher.run(bot, event, state)
|
||||
await matcher.run(bot, event, state, stack, dependency_cache)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
f"<r><bg #f8bbd0>Running matcher {matcher} failed.</bg #f8bbd0></r>"
|
||||
@ -161,7 +162,14 @@ async def _run_matcher(Matcher: Type[Matcher], bot: "Bot", event: "Event",
|
||||
exception = e
|
||||
|
||||
coros = list(
|
||||
map(lambda x: x(matcher, exception, bot, event, state),
|
||||
map(
|
||||
lambda x: x(matcher=matcher,
|
||||
exception=exception,
|
||||
bot=bot,
|
||||
event=event,
|
||||
state=state,
|
||||
_stack=stack,
|
||||
_dependency_cache=dependency_cache),
|
||||
_run_postprocessors))
|
||||
if coros:
|
||||
try:
|
||||
@ -203,59 +211,79 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
|
||||
if show_log:
|
||||
logger.opt(colors=True).success(log_msg)
|
||||
|
||||
state = {}
|
||||
coros = list(map(lambda x: x(bot, event, state), _event_preprocessors))
|
||||
if coros:
|
||||
try:
|
||||
if show_log:
|
||||
logger.debug("Running PreProcessors...")
|
||||
await asyncio.gather(*coros)
|
||||
except IgnoredException as e:
|
||||
logger.opt(colors=True).info(
|
||||
f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
|
||||
"Event ignored!</bg #f8bbd0></r>")
|
||||
return
|
||||
state: Dict[Any, Any] = {}
|
||||
dependency_cache: T_DependencyCache = {}
|
||||
|
||||
# Trie Match
|
||||
_, _ = TrieRule.get_value(bot, event, state)
|
||||
|
||||
break_flag = False
|
||||
for priority in sorted(matchers.keys()):
|
||||
if break_flag:
|
||||
break
|
||||
|
||||
if show_log:
|
||||
logger.debug(f"Checking for matchers in priority {priority}...")
|
||||
|
||||
pending_tasks = [
|
||||
_check_matcher(priority, matcher, bot, event, state.copy())
|
||||
for matcher in matchers[priority]
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*pending_tasks, return_exceptions=True)
|
||||
|
||||
for result in results:
|
||||
if not isinstance(result, Exception):
|
||||
continue
|
||||
if isinstance(result, StopPropagation):
|
||||
break_flag = True
|
||||
logger.debug("Stop event propagation")
|
||||
else:
|
||||
logger.opt(colors=True, exception=result).error(
|
||||
"<r><bg #f8bbd0>Error when checking Matcher.</bg #f8bbd0></r>"
|
||||
async with AsyncExitStack() as stack:
|
||||
coros = list(
|
||||
map(
|
||||
lambda x: x(bot=bot,
|
||||
event=event,
|
||||
state=state,
|
||||
_stack=stack,
|
||||
_dependency_cache=dependency_cache),
|
||||
_event_preprocessors))
|
||||
if coros:
|
||||
try:
|
||||
if show_log:
|
||||
logger.debug("Running PreProcessors...")
|
||||
await asyncio.gather(*coros)
|
||||
except IgnoredException as e:
|
||||
logger.opt(colors=True).info(
|
||||
f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>"
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
|
||||
"Event ignored!</bg #f8bbd0></r>")
|
||||
return
|
||||
|
||||
# Trie Match
|
||||
_, _ = TrieRule.get_value(bot, event, state)
|
||||
|
||||
break_flag = False
|
||||
for priority in sorted(matchers.keys()):
|
||||
if break_flag:
|
||||
break
|
||||
|
||||
coros = list(map(lambda x: x(bot, event, state), _event_postprocessors))
|
||||
if coros:
|
||||
try:
|
||||
if show_log:
|
||||
logger.debug("Running PostProcessors...")
|
||||
await asyncio.gather(*coros)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>"
|
||||
)
|
||||
logger.debug(f"Checking for matchers in priority {priority}...")
|
||||
|
||||
pending_tasks = [
|
||||
_check_matcher(priority, matcher, bot, event, state.copy(),
|
||||
stack, dependency_cache)
|
||||
for matcher in matchers[priority]
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*pending_tasks,
|
||||
return_exceptions=True)
|
||||
|
||||
for result in results:
|
||||
if not isinstance(result, Exception):
|
||||
continue
|
||||
if isinstance(result, StopPropagation):
|
||||
break_flag = True
|
||||
logger.debug("Stop event propagation")
|
||||
else:
|
||||
logger.opt(colors=True, exception=result).error(
|
||||
"<r><bg #f8bbd0>Error when checking Matcher.</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
coros = list(
|
||||
map(
|
||||
lambda x: x(bot=bot,
|
||||
event=event,
|
||||
state=state,
|
||||
_stack=stack,
|
||||
_dependency_cache=dependency_cache),
|
||||
_event_postprocessors))
|
||||
if coros:
|
||||
try:
|
||||
if show_log:
|
||||
logger.debug("Running PostProcessors...")
|
||||
await asyncio.gather(*coros)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>"
|
||||
)
|
||||
|
84
nonebot/params.py
Normal file
84
nonebot/params.py
Normal file
@ -0,0 +1,84 @@
|
||||
import inspect
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic.fields import Undefined
|
||||
|
||||
from nonebot.typing import T_State
|
||||
from nonebot.dependencies import Param
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.utils import generic_check_issubclass
|
||||
|
||||
|
||||
class BotParam(Param):
|
||||
|
||||
@classmethod
|
||||
def _check(cls, name: str, param: inspect.Parameter) -> bool:
|
||||
return generic_check_issubclass(
|
||||
param.annotation, Bot) or (param.annotation == param.empty and
|
||||
name == "bot")
|
||||
|
||||
def _solve(self, bot: Bot, **kwargs: Any) -> Any:
|
||||
return bot
|
||||
|
||||
|
||||
class EventParam(Param):
|
||||
|
||||
@classmethod
|
||||
def _check(cls, name: str, param: inspect.Parameter) -> bool:
|
||||
return generic_check_issubclass(
|
||||
param.annotation, Event) or (param.annotation == param.empty and
|
||||
name == "event")
|
||||
|
||||
def _solve(self, event: Event, **kwargs: Any) -> Any:
|
||||
return event
|
||||
|
||||
|
||||
class StateParam(Param):
|
||||
|
||||
@classmethod
|
||||
def _check(cls, name: str, param: inspect.Parameter) -> bool:
|
||||
return generic_check_issubclass(
|
||||
param.annotation, Dict) or (param.annotation == param.empty and
|
||||
name == "state")
|
||||
|
||||
def _solve(self, state: T_State, **kwargs: Any) -> Any:
|
||||
return state
|
||||
|
||||
|
||||
class MatcherParam(Param):
|
||||
|
||||
@classmethod
|
||||
def _check(cls, name: str, param: inspect.Parameter) -> bool:
|
||||
return generic_check_issubclass(
|
||||
param.annotation, Matcher) or (param.annotation == param.empty and
|
||||
name == "matcher")
|
||||
|
||||
def _solve(self, matcher: Optional["Matcher"] = None, **kwargs: Any) -> Any:
|
||||
return matcher
|
||||
|
||||
|
||||
class ExceptionParam(Param):
|
||||
|
||||
@classmethod
|
||||
def _check(cls, name: str, param: inspect.Parameter) -> bool:
|
||||
return generic_check_issubclass(
|
||||
param.annotation, Exception) or (param.annotation == param.empty and
|
||||
name == "exception")
|
||||
|
||||
def _solve(self,
|
||||
exception: Optional[Exception] = None,
|
||||
**kwargs: Any) -> Any:
|
||||
return exception
|
||||
|
||||
|
||||
class DefaultParam(Param):
|
||||
|
||||
@classmethod
|
||||
def _check(cls, name: str, param: inspect.Parameter) -> bool:
|
||||
return param.default != param.empty
|
||||
|
||||
def _solve(self, **kwargs: Any) -> Any:
|
||||
return Undefined
|
||||
|
||||
|
||||
from nonebot.matcher import Matcher
|
@ -2,7 +2,7 @@ r"""
|
||||
权限
|
||||
====
|
||||
|
||||
每个 ``Matcher`` 拥有一个 ``Permission`` ,其中是 **异步** ``PermissionChecker`` 的集合,只要有一个 ``PermissionChecker`` 检查结果为 ``True`` 时就会继续运行。
|
||||
每个 ``Matcher`` 拥有一个 ``Permission`` ,其中是 ``PermissionChecker`` 的集合,只要有一个 ``PermissionChecker`` 检查结果为 ``True`` 时就会继续运行。
|
||||
|
||||
\:\:\:tip 提示
|
||||
``PermissionChecker`` 既可以是 async function 也可以是 sync function
|
||||
@ -10,14 +10,14 @@ r"""
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Union, Callable, NoReturn, Optional, Awaitable
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any, Dict, List, Type, Union, Callable, NoReturn, Optional
|
||||
|
||||
from nonebot.utils import run_sync
|
||||
from nonebot import params
|
||||
from nonebot.handler import Handler
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.typing import T_PermissionChecker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.adapters import Bot, Event
|
||||
|
||||
|
||||
class Permission:
|
||||
"""
|
||||
@ -36,15 +36,21 @@ class Permission:
|
||||
"""
|
||||
__slots__ = ("checkers",)
|
||||
|
||||
def __init__(
|
||||
self, *checkers: Callable[["Bot", "Event"],
|
||||
Awaitable[bool]]) -> None:
|
||||
HANDLER_PARAM_TYPES = [
|
||||
params.BotParam, params.EventParam, params.DefaultParam
|
||||
]
|
||||
|
||||
def __init__(self, *checkers: Union[T_PermissionChecker, Handler]) -> None:
|
||||
"""
|
||||
:参数:
|
||||
|
||||
* ``*checkers: Callable[[Bot, Event], Awaitable[bool]]``: **异步** PermissionChecker
|
||||
* ``*checkers: Union[T_PermissionChecker, Handler]``: PermissionChecker
|
||||
"""
|
||||
self.checkers = set(checkers)
|
||||
|
||||
self.checkers = set(
|
||||
checker if isinstance(checker, Handler) else Handler(
|
||||
checker, allow_types=self.HANDLER_PARAM_TYPES)
|
||||
for checker in checkers)
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -52,10 +58,16 @@ class Permission:
|
||||
|
||||
:类型:
|
||||
|
||||
* ``Set[Callable[[Bot, Event], Awaitable[bool]]]``
|
||||
* ``Set[Handler]``
|
||||
"""
|
||||
|
||||
async def __call__(self, bot: "Bot", event: "Event") -> bool:
|
||||
async def __call__(
|
||||
self,
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[Dict[Callable[..., Any],
|
||||
Any]] = None) -> bool:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -65,6 +77,8 @@ class Permission:
|
||||
|
||||
* ``bot: Bot``: Bot 对象
|
||||
* ``event: Event``: Event 对象
|
||||
* ``stack: Optional[AsyncExitStack]``: 异步上下文栈
|
||||
* ``dependency_cache: Optional[Dict[Callable[..., Any], Any]]``: 依赖缓存
|
||||
|
||||
:返回:
|
||||
|
||||
@ -73,7 +87,11 @@ class Permission:
|
||||
if not self.checkers:
|
||||
return True
|
||||
results = await asyncio.gather(
|
||||
*map(lambda c: c(bot, event), self.checkers))
|
||||
*(checker(bot=bot,
|
||||
event=event,
|
||||
_stack=stack,
|
||||
_dependency_cache=dependency_cache)
|
||||
for checker in self.checkers))
|
||||
return any(results)
|
||||
|
||||
def __and__(self, other) -> NoReturn:
|
||||
@ -82,31 +100,27 @@ class Permission:
|
||||
def __or__(
|
||||
self, other: Optional[Union["Permission",
|
||||
T_PermissionChecker]]) -> "Permission":
|
||||
checkers = self.checkers.copy()
|
||||
if other is None:
|
||||
return self
|
||||
elif isinstance(other, Permission):
|
||||
checkers |= other.checkers
|
||||
elif asyncio.iscoroutinefunction(other):
|
||||
checkers.add(other) # type: ignore
|
||||
return Permission(*self.checkers, *other.checkers)
|
||||
else:
|
||||
checkers.add(run_sync(other))
|
||||
return Permission(*checkers)
|
||||
return Permission(*self.checkers, other)
|
||||
|
||||
|
||||
async def _message(bot: "Bot", event: "Event") -> bool:
|
||||
async def _message(event: Event) -> bool:
|
||||
return event.get_type() == "message"
|
||||
|
||||
|
||||
async def _notice(bot: "Bot", event: "Event") -> bool:
|
||||
async def _notice(event: Event) -> bool:
|
||||
return event.get_type() == "notice"
|
||||
|
||||
|
||||
async def _request(bot: "Bot", event: "Event") -> bool:
|
||||
async def _request(event: Event) -> bool:
|
||||
return event.get_type() == "request"
|
||||
|
||||
|
||||
async def _metaevent(bot: "Bot", event: "Event") -> bool:
|
||||
async def _metaevent(event: Event) -> bool:
|
||||
return event.get_type() == "meta_event"
|
||||
|
||||
|
||||
@ -140,14 +154,14 @@ def USER(*user: str, perm: Optional[Permission] = None):
|
||||
* ``perm: Optional[Permission]``: 需要同时满足的权限
|
||||
"""
|
||||
|
||||
async def _user(bot: "Bot", event: "Event") -> bool:
|
||||
async def _user(bot: Bot, event: Event) -> bool:
|
||||
return bool(event.get_session_id() in user and
|
||||
(perm is None or await perm(bot, event)))
|
||||
|
||||
return Permission(_user)
|
||||
|
||||
|
||||
async def _superuser(bot: "Bot", event: "Event") -> bool:
|
||||
async def _superuser(bot: Bot, event: Event) -> bool:
|
||||
return (event.get_type() == "message" and
|
||||
event.get_user_id() in bot.config.superusers)
|
||||
|
||||
|
@ -2,19 +2,16 @@ import re
|
||||
import sys
|
||||
import inspect
|
||||
from types import ModuleType
|
||||
from typing import (TYPE_CHECKING, Any, Set, Dict, List, Type, Tuple, Union,
|
||||
Optional)
|
||||
from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional
|
||||
|
||||
from nonebot.adapters import Event
|
||||
from nonebot.handler import Handler
|
||||
from nonebot.matcher import Matcher
|
||||
from .manager import _current_plugin
|
||||
from nonebot.permission import Permission
|
||||
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory
|
||||
from nonebot.rule import (Rule, ArgumentParser, regex, command, keyword,
|
||||
endswith, startswith, shell_command)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.rule import (PREFIX_KEY, RAW_CMD_KEY, Rule, ArgumentParser, regex,
|
||||
command, keyword, endswith, startswith, shell_command)
|
||||
|
||||
|
||||
def _store_matcher(matcher: Type[Matcher]) -> None:
|
||||
@ -376,16 +373,16 @@ def on_command(cmd: Union[str, Tuple[str, ...]],
|
||||
- ``Type[Matcher]``
|
||||
"""
|
||||
|
||||
async def _strip_cmd(bot: "Bot", event: "Event", state: T_State):
|
||||
async def _strip_cmd(event: Event, state: T_State):
|
||||
message = event.get_message()
|
||||
if len(message) < 1:
|
||||
return
|
||||
segment = message.pop(0)
|
||||
segment_text = str(segment).lstrip()
|
||||
if not segment_text.startswith(state["_prefix"]["raw_command"]):
|
||||
if not segment_text.startswith(state[PREFIX_KEY][RAW_CMD_KEY]):
|
||||
return
|
||||
new_message = message.__class__(
|
||||
segment_text[len(state["_prefix"]["raw_command"]):].lstrip())
|
||||
segment_text[len(state[PREFIX_KEY][RAW_CMD_KEY]):].lstrip())
|
||||
for new_segment in reversed(new_message):
|
||||
message.insert(0, new_segment)
|
||||
|
||||
@ -433,12 +430,11 @@ def on_shell_command(cmd: Union[str, Tuple[str, ...]],
|
||||
- ``Type[Matcher]``
|
||||
"""
|
||||
|
||||
async def _strip_cmd(bot: "Bot", event: "Event", state: T_State):
|
||||
async def _strip_cmd(event: Event, state: T_State):
|
||||
message = event.get_message()
|
||||
segment = message.pop(0)
|
||||
new_message = message.__class__(
|
||||
str(segment)
|
||||
[len(state["_prefix"]["raw_command"]):].strip()) # type: ignore
|
||||
str(segment)[len(state[PREFIX_KEY][RAW_CMD_KEY]):].strip())
|
||||
for new_segment in reversed(new_message):
|
||||
message.insert(0, new_segment)
|
||||
|
||||
|
@ -3,14 +3,14 @@ from functools import reduce
|
||||
from nonebot.rule import to_me
|
||||
from nonebot.plugin import on_command
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot.adapters.cqhttp import (Bot, Message, MessageEvent, MessageSegment,
|
||||
from nonebot.adapters.cqhttp import (Message, MessageEvent, MessageSegment,
|
||||
unescape)
|
||||
|
||||
say = on_command("say", to_me(), permission=SUPERUSER)
|
||||
|
||||
|
||||
@say.handle()
|
||||
async def say_unescape(bot: Bot, event: MessageEvent):
|
||||
async def say_unescape(event: MessageEvent):
|
||||
|
||||
def _unescape(message: Message, segment: MessageSegment):
|
||||
if segment.is_text():
|
||||
@ -18,12 +18,12 @@ async def say_unescape(bot: Bot, event: MessageEvent):
|
||||
return message.append(segment)
|
||||
|
||||
message = reduce(_unescape, event.get_message(), Message()) # type: ignore
|
||||
await bot.send(message=message, event=event)
|
||||
await say.send(message=message)
|
||||
|
||||
|
||||
echo = on_command("echo", to_me())
|
||||
|
||||
|
||||
@echo.handle()
|
||||
async def echo_escape(bot: Bot, event: MessageEvent):
|
||||
await bot.send(message=event.get_message(), event=event)
|
||||
async def echo_escape(event: MessageEvent):
|
||||
await echo.send(message=event.get_message())
|
||||
|
@ -1,8 +1,6 @@
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict
|
||||
|
||||
from nonebot.typing import T_State
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.adapters import Event
|
||||
from nonebot.message import (IgnoredException, run_preprocessor,
|
||||
run_postprocessor)
|
||||
|
||||
@ -10,7 +8,7 @@ _running_matcher: Dict[str, int] = {}
|
||||
|
||||
|
||||
@run_preprocessor
|
||||
async def preprocess(matcher: Matcher, bot: Bot, event: Event, state: T_State):
|
||||
async def preprocess(event: Event):
|
||||
try:
|
||||
session_id = event.get_session_id()
|
||||
except Exception:
|
||||
@ -24,8 +22,7 @@ async def preprocess(matcher: Matcher, bot: Bot, event: Event, state: T_State):
|
||||
|
||||
|
||||
@run_postprocessor
|
||||
async def postprocess(matcher: Matcher, exception: Optional[Exception],
|
||||
bot: Bot, event: Event, state: T_State):
|
||||
async def postprocess(event: Event):
|
||||
try:
|
||||
session_id = event.get_session_id()
|
||||
except Exception:
|
||||
|
@ -1 +0,0 @@
|
||||
|
||||
|
183
nonebot/rule.py
183
nonebot/rule.py
@ -2,10 +2,10 @@ r"""
|
||||
规则
|
||||
====
|
||||
|
||||
每个事件响应器 ``Matcher`` 拥有一个匹配规则 ``Rule`` ,其中是 **异步** ``RuleChecker`` 的集合,只有当所有 ``RuleChecker`` 检查结果为 ``True`` 时继续运行。
|
||||
每个事件响应器 ``Matcher`` 拥有一个匹配规则 ``Rule`` ,其中是 ``RuleChecker`` 的集合,只有当所有 ``RuleChecker`` 检查结果为 ``True`` 时继续运行。
|
||||
|
||||
\:\:\:tip 提示
|
||||
``RuleChecker`` 既可以是 async function 也可以是 sync function,但在最终会被 ``nonebot.utils.run_sync`` 转换为 async function
|
||||
``RuleChecker`` 既可以是 async function 也可以是 sync function
|
||||
\:\:\:
|
||||
"""
|
||||
|
||||
@ -14,20 +14,36 @@ import shlex
|
||||
import asyncio
|
||||
from itertools import product
|
||||
from argparse import Namespace
|
||||
from contextlib import AsyncExitStack
|
||||
from typing_extensions import TypedDict
|
||||
from argparse import ArgumentParser as ArgParser
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Tuple, Union, Callable, NoReturn,
|
||||
Optional, Sequence, Awaitable)
|
||||
from typing import (Any, Dict, List, Type, Tuple, Union, Callable, NoReturn,
|
||||
Optional, Sequence)
|
||||
|
||||
from pygtrie import CharTrie
|
||||
|
||||
from nonebot import get_driver
|
||||
from nonebot.log import logger
|
||||
from nonebot.utils import run_sync
|
||||
from nonebot.handler import Handler
|
||||
from nonebot import params, get_driver
|
||||
from nonebot.exception import ParserExit
|
||||
from nonebot.typing import T_State, T_RuleChecker
|
||||
from nonebot.adapters import Bot, Event, MessageSegment
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.adapters import Bot, Event
|
||||
PREFIX_KEY = "_prefix"
|
||||
SUFFIX_KEY = "_suffix"
|
||||
CMD_KEY = "command"
|
||||
RAW_CMD_KEY = "raw_command"
|
||||
CMD_RESULT = TypedDict("CMD_RESULT", {
|
||||
"command": Optional[Tuple[str, ...]],
|
||||
"raw_command": Optional[str]
|
||||
})
|
||||
|
||||
SHELL_ARGS = "_args"
|
||||
SHELL_ARGV = "_argv"
|
||||
|
||||
REGEX_MATCHED = "_matched"
|
||||
REGEX_GROUP = "_matched_groups"
|
||||
REGEX_DICT = "_matched_dict"
|
||||
|
||||
|
||||
class Rule:
|
||||
@ -47,16 +63,22 @@ class Rule:
|
||||
"""
|
||||
__slots__ = ("checkers",)
|
||||
|
||||
def __init__(
|
||||
self, *checkers: Callable[["Bot", "Event", T_State],
|
||||
Awaitable[bool]]) -> None:
|
||||
HANDLER_PARAM_TYPES = [
|
||||
params.BotParam, params.EventParam, params.StateParam,
|
||||
params.DefaultParam
|
||||
]
|
||||
|
||||
def __init__(self, *checkers: Union[T_RuleChecker, Handler]) -> None:
|
||||
"""
|
||||
:参数:
|
||||
|
||||
* ``*checkers: Callable[[Bot, Event, T_State], Awaitable[bool]]``: **异步** RuleChecker
|
||||
* ``*checkers: Union[T_RuleChecker, Handler]``: RuleChecker
|
||||
|
||||
"""
|
||||
self.checkers = set(checkers)
|
||||
self.checkers = set(
|
||||
checker if isinstance(checker, Handler) else Handler(
|
||||
checker, allow_types=self.HANDLER_PARAM_TYPES)
|
||||
for checker in checkers)
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -64,11 +86,17 @@ class Rule:
|
||||
|
||||
:类型:
|
||||
|
||||
* ``Set[Callable[[Bot, Event, T_State], Awaitable[bool]]]``
|
||||
* ``Set[Handler]``
|
||||
"""
|
||||
|
||||
async def __call__(self, bot: "Bot", event: "Event",
|
||||
state: T_State) -> bool:
|
||||
async def __call__(
|
||||
self,
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
state: T_State,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[Dict[Callable[..., Any],
|
||||
Any]] = None) -> bool:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -79,26 +107,31 @@ class Rule:
|
||||
* ``bot: Bot``: Bot 对象
|
||||
* ``event: Event``: Event 对象
|
||||
* ``state: T_State``: 当前 State
|
||||
* ``stack: Optional[AsyncExitStack]``: 异步上下文栈
|
||||
* ``dependency_cache: Optional[Dict[Callable[..., Any], Any]]``: 依赖缓存
|
||||
|
||||
:返回:
|
||||
|
||||
- ``bool``
|
||||
"""
|
||||
if not self.checkers:
|
||||
return True
|
||||
results = await asyncio.gather(
|
||||
*map(lambda c: c(bot, event, state), self.checkers))
|
||||
*(checker(bot=bot,
|
||||
event=event,
|
||||
state=state,
|
||||
_stack=stack,
|
||||
_dependency_cache=dependency_cache)
|
||||
for checker in self.checkers))
|
||||
return all(results)
|
||||
|
||||
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
|
||||
checkers = self.checkers.copy()
|
||||
if other is None:
|
||||
return self
|
||||
elif isinstance(other, Rule):
|
||||
checkers |= other.checkers
|
||||
elif asyncio.iscoroutinefunction(other):
|
||||
checkers.add(other) # type: ignore
|
||||
return Rule(*self.checkers, *other.checkers)
|
||||
else:
|
||||
checkers.add(run_sync(other))
|
||||
return Rule(*checkers)
|
||||
return Rule(*self.checkers, other)
|
||||
|
||||
def __or__(self, other) -> NoReturn:
|
||||
raise RuntimeError("Or operation between rules is not allowed.")
|
||||
@ -123,58 +156,28 @@ class TrieRule:
|
||||
cls.suffix[suffix[::-1]] = value
|
||||
|
||||
@classmethod
|
||||
def get_value(cls, bot: "Bot", event: "Event",
|
||||
state: T_State) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
def get_value(cls, bot: Bot, event: Event,
|
||||
state: T_State) -> Tuple[CMD_RESULT, CMD_RESULT]:
|
||||
prefix = CMD_RESULT(command=None, raw_command=None)
|
||||
suffix = CMD_RESULT(command=None, raw_command=None)
|
||||
state[PREFIX_KEY] = prefix
|
||||
state[SUFFIX_KEY] = suffix
|
||||
if event.get_type() != "message":
|
||||
state["_prefix"] = {"raw_command": None, "command": None}
|
||||
state["_suffix"] = {"raw_command": None, "command": None}
|
||||
return {
|
||||
"raw_command": None,
|
||||
"command": None
|
||||
}, {
|
||||
"raw_command": None,
|
||||
"command": None
|
||||
}
|
||||
return prefix, suffix
|
||||
|
||||
prefix = None
|
||||
suffix = None
|
||||
message = event.get_message()
|
||||
message_seg = message[0]
|
||||
message_seg: MessageSegment = message[0]
|
||||
if message_seg.is_text():
|
||||
prefix = cls.prefix.longest_prefix(str(message_seg).lstrip())
|
||||
message_seg_r = message[-1]
|
||||
pf = cls.prefix.longest_prefix(str(message_seg).lstrip())
|
||||
prefix[RAW_CMD_KEY] = pf.key
|
||||
prefix[CMD_KEY] = pf.value
|
||||
message_seg_r: MessageSegment = message[-1]
|
||||
if message_seg_r.is_text():
|
||||
suffix = cls.suffix.longest_prefix(
|
||||
str(message_seg_r).rstrip()[::-1])
|
||||
sf = cls.suffix.longest_prefix(str(message_seg_r).rstrip()[::-1])
|
||||
suffix[RAW_CMD_KEY] = sf.key
|
||||
suffix[CMD_KEY] = sf.value
|
||||
|
||||
state["_prefix"] = {
|
||||
"raw_command": prefix.key,
|
||||
"command": prefix.value
|
||||
} if prefix else {
|
||||
"raw_command": None,
|
||||
"command": None
|
||||
}
|
||||
state["_suffix"] = {
|
||||
"raw_command": suffix.key,
|
||||
"command": suffix.value
|
||||
} if suffix else {
|
||||
"raw_command": None,
|
||||
"command": None
|
||||
}
|
||||
|
||||
return ({
|
||||
"raw_command": prefix.key,
|
||||
"command": prefix.value
|
||||
} if prefix else {
|
||||
"raw_command": None,
|
||||
"command": None
|
||||
}, {
|
||||
"raw_command": suffix.key,
|
||||
"command": suffix.value
|
||||
} if suffix else {
|
||||
"raw_command": None,
|
||||
"command": None
|
||||
})
|
||||
return prefix, suffix
|
||||
|
||||
|
||||
def startswith(msg: Union[str, Tuple[str, ...]],
|
||||
@ -195,7 +198,7 @@ def startswith(msg: Union[str, Tuple[str, ...]],
|
||||
f"^(?:{'|'.join(re.escape(prefix) for prefix in msg)})",
|
||||
re.IGNORECASE if ignorecase else 0)
|
||||
|
||||
async def _startswith(bot: "Bot", event: "Event", state: T_State) -> bool:
|
||||
async def _startswith(bot: Bot, event: Event, state: T_State) -> bool:
|
||||
if event.get_type() != "message":
|
||||
return False
|
||||
text = event.get_plaintext()
|
||||
@ -222,7 +225,7 @@ def endswith(msg: Union[str, Tuple[str, ...]],
|
||||
f"(?:{'|'.join(re.escape(prefix) for prefix in msg)})$",
|
||||
re.IGNORECASE if ignorecase else 0)
|
||||
|
||||
async def _endswith(bot: "Bot", event: "Event", state: T_State) -> bool:
|
||||
async def _endswith(bot: Bot, event: Event, state: T_State) -> bool:
|
||||
if event.get_type() != "message":
|
||||
return False
|
||||
text = event.get_plaintext()
|
||||
@ -242,7 +245,7 @@ def keyword(*keywords: str) -> Rule:
|
||||
* ``*keywords: str``: 关键词
|
||||
"""
|
||||
|
||||
async def _keyword(bot: "Bot", event: "Event", state: T_State) -> bool:
|
||||
async def _keyword(event: Event) -> bool:
|
||||
if event.get_type() != "message":
|
||||
return False
|
||||
text = event.get_plaintext()
|
||||
@ -290,8 +293,8 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
|
||||
for start, sep in product(command_start, command_sep):
|
||||
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
|
||||
|
||||
async def _command(bot: "Bot", event: "Event", state: T_State) -> bool:
|
||||
return state["_prefix"]["command"] in commands
|
||||
async def _command(state: T_State) -> bool:
|
||||
return state[PREFIX_KEY][CMD_KEY] in commands
|
||||
|
||||
return Rule(_command)
|
||||
|
||||
@ -310,7 +313,7 @@ class ArgumentParser(ArgParser):
|
||||
old_message += message
|
||||
setattr(self, "message", old_message)
|
||||
|
||||
def exit(self, status=0, message=None):
|
||||
def exit(self, status: int = 0, message: Optional[str] = None):
|
||||
raise ParserExit(status=status,
|
||||
message=message or getattr(self, "message", None))
|
||||
|
||||
@ -376,19 +379,18 @@ def shell_command(*cmds: Union[str, Tuple[str, ...]],
|
||||
for start, sep in product(command_start, command_sep):
|
||||
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
|
||||
|
||||
async def _shell_command(bot: "Bot", event: "Event",
|
||||
state: T_State) -> bool:
|
||||
if state["_prefix"]["command"] in commands:
|
||||
async def _shell_command(event: Event, state: T_State) -> bool:
|
||||
if state[PREFIX_KEY][CMD_KEY] in commands:
|
||||
message = str(event.get_message())
|
||||
strip_message = message[len(state["_prefix"]["raw_command"]
|
||||
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]
|
||||
):].lstrip()
|
||||
state["argv"] = shlex.split(strip_message)
|
||||
state[SHELL_ARGV] = shlex.split(strip_message)
|
||||
if parser:
|
||||
try:
|
||||
args = parser.parse_args(state["argv"])
|
||||
state["args"] = args
|
||||
args = parser.parse_args(state[SHELL_ARGV])
|
||||
state[SHELL_ARGS] = args
|
||||
except ParserExit as e:
|
||||
state["args"] = e
|
||||
state[SHELL_ARGS] = e
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@ -417,14 +419,14 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
|
||||
|
||||
pattern = re.compile(regex, flags)
|
||||
|
||||
async def _regex(bot: "Bot", event: "Event", state: T_State) -> bool:
|
||||
async def _regex(event: Event, state: T_State) -> bool:
|
||||
if event.get_type() != "message":
|
||||
return False
|
||||
matched = pattern.search(str(event.get_message()))
|
||||
if matched:
|
||||
state["_matched"] = matched.group()
|
||||
state["_matched_groups"] = matched.groups()
|
||||
state["_matched_dict"] = matched.groupdict()
|
||||
state[REGEX_MATCHED] = matched.group()
|
||||
state[REGEX_GROUP] = matched.groups()
|
||||
state[REGEX_DICT] = matched.groupdict()
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@ -432,6 +434,10 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
|
||||
return Rule(_regex)
|
||||
|
||||
|
||||
async def _to_me(event: Event) -> bool:
|
||||
return event.is_tome()
|
||||
|
||||
|
||||
def to_me() -> Rule:
|
||||
"""
|
||||
:说明:
|
||||
@ -443,7 +449,4 @@ def to_me() -> Rule:
|
||||
* 无
|
||||
"""
|
||||
|
||||
async def _to_me(bot: "Bot", event: "Event", state: T_State) -> bool:
|
||||
return event.is_tome()
|
||||
|
||||
return Rule(_to_me)
|
||||
|
@ -17,16 +17,15 @@
|
||||
.. _typing:
|
||||
https://docs.python.org/3/library/typing.html
|
||||
"""
|
||||
from collections.abc import Callable as BaseCallable
|
||||
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Union, TypeVar, Callable,
|
||||
NoReturn, Optional, Awaitable)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.permission import Permission
|
||||
|
||||
T_Wrapped = TypeVar("T_Wrapped", bound=BaseCallable)
|
||||
T_Wrapped = TypeVar("T_Wrapped", bound=Callable)
|
||||
|
||||
|
||||
def overrides(InterfaceClass: object):
|
||||
@ -90,77 +89,109 @@ T_CalledAPIHook = Callable[
|
||||
``bot.call_api`` 后执行的函数,参数分别为 bot, exception, api, data, result
|
||||
"""
|
||||
|
||||
T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]]
|
||||
T_EventPreProcessor = Callable[..., Union[None, Awaitable[None]]]
|
||||
"""
|
||||
:类型: ``Callable[[Bot, Event, T_State], Awaitable[None]]``
|
||||
:类型: ``Callable[..., Union[None, Awaitable[None]]]``
|
||||
|
||||
:依赖参数:
|
||||
|
||||
* ``BotParam``: Bot 对象
|
||||
* ``EventParam``: Event 对象
|
||||
* ``StateParam``: State 对象
|
||||
|
||||
:说明:
|
||||
|
||||
事件预处理函数 EventPreProcessor 类型
|
||||
"""
|
||||
T_EventPostProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]]
|
||||
T_EventPostProcessor = Callable[..., Union[None, Awaitable[None]]]
|
||||
"""
|
||||
:类型: ``Callable[[Bot, Event, T_State], Awaitable[None]]``
|
||||
:类型: ``Callable[..., Union[None, Awaitable[None]]]``
|
||||
|
||||
:依赖参数:
|
||||
|
||||
* ``BotParam``: Bot 对象
|
||||
* ``EventParam``: Event 对象
|
||||
* ``StateParam``: State 对象
|
||||
|
||||
:说明:
|
||||
|
||||
事件预处理函数 EventPostProcessor 类型
|
||||
"""
|
||||
T_RunPreProcessor = Callable[["Matcher", "Bot", "Event", T_State],
|
||||
Awaitable[None]]
|
||||
T_RunPreProcessor = Callable[..., Union[None, Awaitable[None]]]
|
||||
"""
|
||||
:类型: ``Callable[[Matcher, Bot, Event, T_State], Awaitable[None]]``
|
||||
:类型: ``Callable[..., Union[None, Awaitable[None]]]``
|
||||
|
||||
:依赖参数:
|
||||
|
||||
* ``BotParam``: Bot 对象
|
||||
* ``EventParam``: Event 对象
|
||||
* ``StateParam``: State 对象
|
||||
* ``MatcherParam``: Matcher 对象
|
||||
|
||||
:说明:
|
||||
|
||||
事件响应器运行前预处理函数 RunPreProcessor 类型
|
||||
"""
|
||||
T_RunPostProcessor = Callable[
|
||||
["Matcher", Optional[Exception], "Bot", "Event", T_State], Awaitable[None]]
|
||||
T_RunPostProcessor = Callable[..., Union[None, Awaitable[None]]]
|
||||
"""
|
||||
:类型: ``Callable[[Matcher, Optional[Exception], Bot, Event, T_State], Awaitable[None]]``
|
||||
:类型: ``Callable[..., Union[None, Awaitable[None]]]``
|
||||
|
||||
:依赖参数:
|
||||
|
||||
* ``BotParam``: Bot 对象
|
||||
* ``EventParam``: Event 对象
|
||||
* ``StateParam``: State 对象
|
||||
* ``MatcherParam``: Matcher 对象
|
||||
* ``ExceptionParam``: 异常对象(可能为 None)
|
||||
|
||||
:说明:
|
||||
|
||||
事件响应器运行前预处理函数 RunPostProcessor 类型,第二个参数为运行时产生的错误(如果存在)
|
||||
"""
|
||||
|
||||
T_RuleChecker = Callable[["Bot", "Event", T_State], Union[bool,
|
||||
Awaitable[bool]]]
|
||||
T_RuleChecker = Callable[..., Union[bool, Awaitable[bool]]]
|
||||
"""
|
||||
:类型: ``Callable[[Bot, Event, T_State], Union[bool, Awaitable[bool]]]``
|
||||
:类型: ``Callable[..., Union[bool, Awaitable[bool]]]``
|
||||
|
||||
:依赖参数:
|
||||
|
||||
* ``BotParam``: Bot 对象
|
||||
* ``EventParam``: Event 对象
|
||||
* ``StateParam``: State 对象
|
||||
|
||||
:说明:
|
||||
|
||||
RuleChecker 即判断是否响应事件的处理函数。
|
||||
"""
|
||||
T_PermissionChecker = Callable[["Bot", "Event"], Union[bool, Awaitable[bool]]]
|
||||
T_PermissionChecker = Callable[..., Union[bool, Awaitable[bool]]]
|
||||
"""
|
||||
:类型: ``Callable[[Bot, Event], Union[bool, Awaitable[bool]]]``
|
||||
:类型: ``Callable[..., Union[bool, Awaitable[bool]]]``
|
||||
|
||||
:依赖参数:
|
||||
|
||||
* ``BotParam``: Bot 对象
|
||||
* ``EventParam``: Event 对象
|
||||
|
||||
:说明:
|
||||
|
||||
RuleChecker 即判断是否响应消息的处理函数。
|
||||
"""
|
||||
|
||||
T_Handler = Union[Callable[[Any, Any, Any, Any], Union[Awaitable[None],
|
||||
Awaitable[NoReturn]]],
|
||||
Callable[[Any, Any, Any], Union[Awaitable[None],
|
||||
Awaitable[NoReturn]]],
|
||||
Callable[[Any, Any], Union[Awaitable[None],
|
||||
Awaitable[NoReturn]]],
|
||||
Callable[[Any], Union[Awaitable[None], Awaitable[NoReturn]]]]
|
||||
T_Handler = Callable[..., Any]
|
||||
"""
|
||||
:类型:
|
||||
|
||||
* ``Callable[[Bot, Event, T_State], Union[Awaitable[None], Awaitable[NoReturn]]]``
|
||||
* ``Callable[[Bot, Event], Union[Awaitable[None], Awaitable[NoReturn]]]``
|
||||
* ``Callable[[Bot, T_State], Union[Awaitable[None], Awaitable[NoReturn]]]``
|
||||
* ``Callable[[Bot], Union[Awaitable[None], Awaitable[NoReturn]]]``
|
||||
:类型: ``Callable[..., Any]``
|
||||
|
||||
:说明:
|
||||
|
||||
Handler 即事件的处理函数。
|
||||
Handler 处理函数。
|
||||
"""
|
||||
T_DependencyCache = Dict[T_Handler, Any]
|
||||
"""
|
||||
:类型: ``Dict[T_Handler, Any]``
|
||||
|
||||
:说明:
|
||||
|
||||
依赖缓存, 用于存储依赖函数的返回值
|
||||
"""
|
||||
T_ArgsParser = Callable[["Bot", "Event", T_State], Union[Awaitable[None],
|
||||
Awaitable[NoReturn]]]
|
||||
|
153
nonebot/utils.py
153
nonebot/utils.py
@ -1,13 +1,23 @@
|
||||
import re
|
||||
import json
|
||||
import asyncio
|
||||
import inspect
|
||||
import collections
|
||||
import dataclasses
|
||||
from functools import wraps, partial
|
||||
from typing import Any, Callable, Optional, Awaitable
|
||||
from contextlib import asynccontextmanager
|
||||
from typing_extensions import GenericAlias # type: ignore
|
||||
from typing_extensions import ParamSpec, get_args, get_origin
|
||||
from typing import (Any, Type, Deque, Tuple, Union, TypeVar, Callable, Optional,
|
||||
Awaitable, AsyncGenerator, ContextManager)
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.typing import overrides
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def escape_tag(s: str) -> str:
|
||||
"""
|
||||
@ -26,7 +36,48 @@ def escape_tag(s: str) -> str:
|
||||
return re.sub(r"</?((?:[fb]g\s)?[^<>\s]*)>", r"\\\g<0>", s)
|
||||
|
||||
|
||||
def run_sync(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
|
||||
def generic_check_issubclass(
|
||||
cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any],
|
||||
...]]) -> bool:
|
||||
try:
|
||||
return issubclass(cls, class_or_tuple)
|
||||
except TypeError:
|
||||
if get_origin(cls) is Union:
|
||||
for type_ in get_args(cls):
|
||||
if type_ is not type(None) and not generic_check_issubclass(
|
||||
type_, class_or_tuple):
|
||||
return False
|
||||
return True
|
||||
elif isinstance(cls, GenericAlias):
|
||||
origin = get_origin(cls)
|
||||
return bool(origin and issubclass(origin, class_or_tuple))
|
||||
raise
|
||||
|
||||
|
||||
def is_coroutine_callable(func: Callable[..., Any]) -> bool:
|
||||
if inspect.isroutine(func):
|
||||
return inspect.iscoroutinefunction(func)
|
||||
if inspect.isclass(func):
|
||||
return False
|
||||
func_ = getattr(func, "__call__", None)
|
||||
return inspect.iscoroutinefunction(func_)
|
||||
|
||||
|
||||
def is_gen_callable(func: Callable[..., Any]) -> bool:
|
||||
if inspect.isgeneratorfunction(func):
|
||||
return True
|
||||
func_ = getattr(func, "__call__", None)
|
||||
return inspect.isgeneratorfunction(func_)
|
||||
|
||||
|
||||
def is_async_gen_callable(func: Callable[..., Any]) -> bool:
|
||||
if inspect.isasyncgenfunction(func):
|
||||
return True
|
||||
func_ = getattr(func, "__call__", None)
|
||||
return inspect.isasyncgenfunction(func_)
|
||||
|
||||
|
||||
def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -34,15 +85,15 @@ def run_sync(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
|
||||
|
||||
:参数:
|
||||
|
||||
* ``func: Callable[..., Any]``: 被装饰的同步函数
|
||||
* ``func: Callable[P, R]``: 被装饰的同步函数
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Callable[..., Awaitable[Any]]``
|
||||
- ``Callable[P, Awaitable[R]]``
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def _wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
loop = asyncio.get_running_loop()
|
||||
pfunc = partial(func, *args, **kwargs)
|
||||
result = await loop.run_in_executor(None, pfunc)
|
||||
@ -51,6 +102,98 @@ def run_sync(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
|
||||
return _wrapper
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def run_sync_ctx_manager(
|
||||
cm: ContextManager[T],) -> AsyncGenerator[T, None]:
|
||||
try:
|
||||
yield await run_sync(cm.__enter__)()
|
||||
except Exception as e:
|
||||
ok = await run_sync(cm.__exit__)(type(e), e, None)
|
||||
if not ok:
|
||||
raise e
|
||||
else:
|
||||
await run_sync(cm.__exit__)(None, None, None)
|
||||
|
||||
|
||||
def get_name(obj: Any) -> str:
|
||||
if inspect.isfunction(obj) or inspect.isclass(obj):
|
||||
return obj.__name__
|
||||
return obj.__class__.__name__
|
||||
|
||||
|
||||
class CacheLock:
|
||||
|
||||
def __init__(self):
|
||||
self._waiters: Optional[Deque[asyncio.Future]] = None
|
||||
self._locked = False
|
||||
|
||||
def __repr__(self):
|
||||
extra = "locked" if self._locked else "unlocked"
|
||||
if self._waiters:
|
||||
extra = f"{extra}, waiters: {len(self._waiters)}"
|
||||
return f"<{self.__class__.__name__} [{extra}]>"
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.acquire()
|
||||
return None
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
self.release()
|
||||
|
||||
def locked(self):
|
||||
return self._locked
|
||||
|
||||
async def acquire(self):
|
||||
if (not self._locked and (self._waiters is None or
|
||||
all(w.cancelled() for w in self._waiters))):
|
||||
self._locked = True
|
||||
return True
|
||||
|
||||
if self._waiters is None:
|
||||
self._waiters = collections.deque()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
self._waiters.append(future)
|
||||
|
||||
# Finally block should be called before the CancelledError
|
||||
# handling as we don't want CancelledError to call
|
||||
# _wake_up_first() and attempt to wake up itself.
|
||||
try:
|
||||
try:
|
||||
await future
|
||||
finally:
|
||||
self._waiters.remove(future)
|
||||
except asyncio.CancelledError:
|
||||
if not self._locked:
|
||||
self._wake_up_first()
|
||||
raise
|
||||
|
||||
self._locked = True
|
||||
return True
|
||||
|
||||
def release(self):
|
||||
if self._locked:
|
||||
self._locked = False
|
||||
self._wake_up_first()
|
||||
else:
|
||||
raise RuntimeError("Lock is not acquired.")
|
||||
|
||||
def _wake_up_first(self):
|
||||
if not self._waiters:
|
||||
return
|
||||
try:
|
||||
future = next(iter(self._waiters))
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
# .done() necessarily means that a waiter will wake up later on and
|
||||
# either take the lock, or, if it was cancelled and lock wasn't
|
||||
# taken already, will hit this again and wake up a new waiter.
|
||||
if not future.done():
|
||||
future.set_result(True)
|
||||
|
||||
|
||||
class DataclassEncoder(json.JSONEncoder):
|
||||
"""
|
||||
:说明:
|
||||
|
Reference in New Issue
Block a user