mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-09-07 04:26:45 +00:00
♻️ reorganize class and add bot hook di
This commit is contained in:
@ -5,21 +5,24 @@ FrontMatter:
|
||||
description: nonebot.params 模块
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import warnings
|
||||
from typing_extensions import Literal
|
||||
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
|
||||
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
||||
from typing import Any, Dict, List, Tuple, Optional
|
||||
|
||||
from pydantic.fields import Required, Undefined, ModelField
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.exception import TypeMisMatch
|
||||
from nonebot.adapters import Bot, Event, Message
|
||||
from nonebot.dependencies.utils import check_field_type
|
||||
from nonebot.dependencies import Param, Dependent, CustomConfig
|
||||
from nonebot.typing import T_State, T_Handler, T_DependencyCache
|
||||
from nonebot.typing import T_State
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.adapters import Event, Message
|
||||
from nonebot.internal.params import Arg as Arg
|
||||
from nonebot.internal.params import State as State
|
||||
from nonebot.internal.params import ArgStr as ArgStr
|
||||
from nonebot.internal.params import Depends as Depends
|
||||
from nonebot.internal.params import ArgParam as ArgParam
|
||||
from nonebot.internal.params import BotParam as BotParam
|
||||
from nonebot.internal.params import EventParam as EventParam
|
||||
from nonebot.internal.params import StateParam as StateParam
|
||||
from nonebot.internal.params import DependParam as DependParam
|
||||
from nonebot.internal.params import ArgPlainText as ArgPlainText
|
||||
from nonebot.internal.params import DefaultParam as DefaultParam
|
||||
from nonebot.internal.params import MatcherParam as MatcherParam
|
||||
from nonebot.internal.params import ExceptionParam as ExceptionParam
|
||||
from nonebot.consts import (
|
||||
CMD_KEY,
|
||||
PREFIX_KEY,
|
||||
@ -31,233 +34,6 @@ from nonebot.consts import (
|
||||
REGEX_GROUP,
|
||||
REGEX_MATCHED,
|
||||
)
|
||||
from nonebot.utils import (
|
||||
get_name,
|
||||
run_sync,
|
||||
is_gen_callable,
|
||||
run_sync_ctx_manager,
|
||||
is_async_gen_callable,
|
||||
is_coroutine_callable,
|
||||
generic_check_issubclass,
|
||||
)
|
||||
|
||||
|
||||
class DependsInner:
|
||||
def __init__(
|
||||
self,
|
||||
dependency: Optional[T_Handler] = None,
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
) -> None:
|
||||
self.dependency = dependency
|
||||
self.use_cache = use_cache
|
||||
|
||||
def __repr__(self) -> str:
|
||||
dep = get_name(self.dependency)
|
||||
cache = "" if self.use_cache else ", use_cache=False"
|
||||
return f"{self.__class__.__name__}({dep}{cache})"
|
||||
|
||||
|
||||
def Depends(
|
||||
dependency: Optional[T_Handler] = None,
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
) -> Any:
|
||||
"""子依赖装饰器
|
||||
|
||||
参数:
|
||||
dependency: 依赖函数。默认为参数的类型注释。
|
||||
use_cache: 是否使用缓存。默认为 `True`。
|
||||
|
||||
用法:
|
||||
```python
|
||||
def depend_func() -> Any:
|
||||
return ...
|
||||
|
||||
def depend_gen_func():
|
||||
try:
|
||||
yield ...
|
||||
finally:
|
||||
...
|
||||
|
||||
async def handler(param_name: Any = Depends(depend_func), gen: Any = Depends(depend_gen_func)):
|
||||
...
|
||||
```
|
||||
"""
|
||||
return DependsInner(dependency, use_cache=use_cache)
|
||||
|
||||
|
||||
class DependParam(Param):
|
||||
"""子依赖参数"""
|
||||
|
||||
@classmethod
|
||||
def _check_param(
|
||||
cls,
|
||||
dependent: Dependent,
|
||||
name: str,
|
||||
param: inspect.Parameter,
|
||||
) -> Optional["DependParam"]:
|
||||
if isinstance(param.default, DependsInner):
|
||||
dependency: T_Handler
|
||||
if param.default.dependency is None:
|
||||
assert param.annotation is not param.empty, "Dependency cannot be empty"
|
||||
dependency = param.annotation
|
||||
else:
|
||||
dependency = param.default.dependency
|
||||
sub_dependent = Dependent[Any].parse(
|
||||
call=dependency,
|
||||
allow_types=dependent.allow_types,
|
||||
)
|
||||
dependent.pre_checkers.extend(sub_dependent.pre_checkers)
|
||||
sub_dependent.pre_checkers.clear()
|
||||
return cls(
|
||||
Required, use_cache=param.default.use_cache, dependent=sub_dependent
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _check_parameterless(
|
||||
cls, dependent: "Dependent", value: Any
|
||||
) -> Optional["Param"]:
|
||||
if isinstance(value, DependsInner):
|
||||
assert value.dependency, "Dependency cannot be empty"
|
||||
dependent = Dependent[Any].parse(
|
||||
call=value.dependency, allow_types=dependent.allow_types
|
||||
)
|
||||
return cls(Required, use_cache=value.use_cache, dependent=dependent)
|
||||
|
||||
async def _solve(
|
||||
self,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[T_DependencyCache] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
use_cache: bool = self.extra["use_cache"]
|
||||
dependency_cache = {} if dependency_cache is None else dependency_cache
|
||||
|
||||
sub_dependent: Dependent = self.extra["dependent"]
|
||||
sub_dependent.call = cast(Callable[..., Any], sub_dependent.call)
|
||||
call = sub_dependent.call
|
||||
|
||||
# solve sub dependency with current cache
|
||||
sub_values = await sub_dependent.solve(
|
||||
stack=stack,
|
||||
dependency_cache=dependency_cache,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# run dependency function
|
||||
task: asyncio.Task[Any]
|
||||
if use_cache and call in dependency_cache:
|
||||
solved = await dependency_cache[call]
|
||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||
assert isinstance(
|
||||
stack, AsyncExitStack
|
||||
), "Generator dependency should be called in context"
|
||||
if is_gen_callable(call):
|
||||
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
|
||||
else:
|
||||
cm = asynccontextmanager(call)(**sub_values)
|
||||
task = asyncio.create_task(stack.enter_async_context(cm))
|
||||
dependency_cache[call] = task
|
||||
solved = await task
|
||||
elif is_coroutine_callable(call):
|
||||
task = asyncio.create_task(call(**sub_values))
|
||||
dependency_cache[call] = task
|
||||
solved = await task
|
||||
else:
|
||||
task = asyncio.create_task(run_sync(call)(**sub_values))
|
||||
dependency_cache[call] = task
|
||||
solved = await task
|
||||
|
||||
return solved
|
||||
|
||||
|
||||
class _BotChecker(Param):
|
||||
async def _solve(self, bot: Bot, **kwargs: Any) -> Any:
|
||||
field: ModelField = self.extra["field"]
|
||||
try:
|
||||
return check_field_type(field, bot)
|
||||
except TypeMisMatch:
|
||||
logger.debug(
|
||||
f"Bot type {type(bot)} not match "
|
||||
f"annotation {field._type_display()}, ignored"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
class BotParam(Param):
|
||||
"""{ref}`nonebot.adapters._bot.Bot` 参数"""
|
||||
|
||||
@classmethod
|
||||
def _check_param(
|
||||
cls, dependent: Dependent, name: str, param: inspect.Parameter
|
||||
) -> Optional["BotParam"]:
|
||||
if param.default == param.empty:
|
||||
if generic_check_issubclass(param.annotation, Bot):
|
||||
if param.annotation is not Bot:
|
||||
dependent.pre_checkers.append(
|
||||
_BotChecker(
|
||||
Required,
|
||||
field=ModelField(
|
||||
name=name,
|
||||
type_=param.annotation,
|
||||
class_validators=None,
|
||||
model_config=CustomConfig,
|
||||
default=None,
|
||||
required=True,
|
||||
),
|
||||
)
|
||||
)
|
||||
return cls(Required)
|
||||
elif param.annotation == param.empty and name == "bot":
|
||||
return cls(Required)
|
||||
|
||||
async def _solve(self, bot: Bot, **kwargs: Any) -> Any:
|
||||
return bot
|
||||
|
||||
|
||||
class _EventChecker(Param):
|
||||
async def _solve(self, event: Event, **kwargs: Any) -> Any:
|
||||
field: ModelField = self.extra["field"]
|
||||
try:
|
||||
return check_field_type(field, event)
|
||||
except TypeMisMatch:
|
||||
logger.debug(
|
||||
f"Event type {type(event)} not match "
|
||||
f"annotation {field._type_display()}, ignored"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
class EventParam(Param):
|
||||
"""{ref}`nonebot.adapters._event.Event` 参数"""
|
||||
|
||||
@classmethod
|
||||
def _check_param(
|
||||
cls, dependent: Dependent, name: str, param: inspect.Parameter
|
||||
) -> Optional["EventParam"]:
|
||||
if param.default == param.empty:
|
||||
if generic_check_issubclass(param.annotation, Event):
|
||||
if param.annotation is not Event:
|
||||
dependent.pre_checkers.append(
|
||||
_EventChecker(
|
||||
Required,
|
||||
field=ModelField(
|
||||
name=name,
|
||||
type_=param.annotation,
|
||||
class_validators=None,
|
||||
model_config=CustomConfig,
|
||||
default=None,
|
||||
required=True,
|
||||
),
|
||||
)
|
||||
)
|
||||
return cls(Required)
|
||||
elif param.annotation == param.empty and name == "event":
|
||||
return cls(Required)
|
||||
|
||||
async def _solve(self, event: Event, **kwargs: Any) -> Any:
|
||||
return event
|
||||
|
||||
|
||||
async def _event_type(event: Event) -> str:
|
||||
@ -265,7 +41,7 @@ async def _event_type(event: Event) -> str:
|
||||
|
||||
|
||||
def EventType() -> str:
|
||||
"""{ref}`nonebot.adapters._event.Event` 类型参数"""
|
||||
"""{ref}`nonebot.adapters.Event` 类型参数"""
|
||||
return Depends(_event_type)
|
||||
|
||||
|
||||
@ -274,7 +50,7 @@ async def _event_message(event: Event) -> Message:
|
||||
|
||||
|
||||
def EventMessage() -> Any:
|
||||
"""{ref}`nonebot.adapters._event.Event` 消息参数"""
|
||||
"""{ref}`nonebot.adapters.Event` 消息参数"""
|
||||
return Depends(_event_message)
|
||||
|
||||
|
||||
@ -283,7 +59,7 @@ async def _event_plain_text(event: Event) -> str:
|
||||
|
||||
|
||||
def EventPlainText() -> str:
|
||||
"""{ref}`nonebot.adapters._event.Event` 纯文本消息参数"""
|
||||
"""{ref}`nonebot.adapters.Event` 纯文本消息参数"""
|
||||
return Depends(_event_plain_text)
|
||||
|
||||
|
||||
@ -292,39 +68,10 @@ async def _event_to_me(event: Event) -> bool:
|
||||
|
||||
|
||||
def EventToMe() -> bool:
|
||||
"""{ref}`nonebot.adapters._event.Event` `to_me` 参数"""
|
||||
"""{ref}`nonebot.adapters.Event` `to_me` 参数"""
|
||||
return Depends(_event_to_me)
|
||||
|
||||
|
||||
class StateInner(T_State):
|
||||
...
|
||||
|
||||
|
||||
def State() -> T_State:
|
||||
"""**Deprecated**: 事件处理状态参数,请直接使用 {ref}`nonebot.typing.T_State`"""
|
||||
warnings.warn("State() is deprecated, use `T_State` instead", DeprecationWarning)
|
||||
return StateInner()
|
||||
|
||||
|
||||
class StateParam(Param):
|
||||
"""事件处理状态参数"""
|
||||
|
||||
@classmethod
|
||||
def _check_param(
|
||||
cls, dependent: Dependent, name: str, param: inspect.Parameter
|
||||
) -> Optional["StateParam"]:
|
||||
if isinstance(param.default, StateInner):
|
||||
return cls(Required)
|
||||
elif param.default == param.empty:
|
||||
if param.annotation is T_State:
|
||||
return cls(Required)
|
||||
elif param.annotation == param.empty and name == "state":
|
||||
return cls(Required)
|
||||
|
||||
async def _solve(self, state: T_State, **kwargs: Any) -> Any:
|
||||
return state
|
||||
|
||||
|
||||
def _command(state: T_State) -> Message:
|
||||
return state[PREFIX_KEY][CMD_KEY]
|
||||
|
||||
@ -397,22 +144,6 @@ def RegexDict() -> Dict[str, Any]:
|
||||
return Depends(_regex_dict, use_cache=False)
|
||||
|
||||
|
||||
class MatcherParam(Param):
|
||||
"""事件响应器实例参数"""
|
||||
|
||||
@classmethod
|
||||
def _check_param(
|
||||
cls, dependent: Dependent, name: str, param: inspect.Parameter
|
||||
) -> Optional["MatcherParam"]:
|
||||
if generic_check_issubclass(param.annotation, Matcher) or (
|
||||
param.annotation == param.empty and name == "matcher"
|
||||
):
|
||||
return cls(Required)
|
||||
|
||||
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
||||
return matcher
|
||||
|
||||
|
||||
def Received(id: Optional[str] = None, default: Any = None) -> Any:
|
||||
"""`receive` 事件参数"""
|
||||
|
||||
@ -431,85 +162,18 @@ def LastReceived(default: Any = None) -> Any:
|
||||
return Depends(_last_received, use_cache=False)
|
||||
|
||||
|
||||
class ArgInner:
|
||||
def __init__(
|
||||
self, key: Optional[str], type: Literal["message", "str", "plaintext"]
|
||||
) -> None:
|
||||
self.key = key
|
||||
self.type = type
|
||||
|
||||
|
||||
def Arg(key: Optional[str] = None) -> Any:
|
||||
"""`got` 的 Arg 参数消息"""
|
||||
return ArgInner(key, "message")
|
||||
|
||||
|
||||
def ArgStr(key: Optional[str] = None) -> str:
|
||||
"""`got` 的 Arg 参数消息文本"""
|
||||
return ArgInner(key, "str") # type: ignore
|
||||
|
||||
|
||||
def ArgPlainText(key: Optional[str] = None) -> str:
|
||||
"""`got` 的 Arg 参数消息纯文本"""
|
||||
return ArgInner(key, "plaintext") # type: ignore
|
||||
|
||||
|
||||
class ArgParam(Param):
|
||||
"""`got` 的 Arg 参数"""
|
||||
|
||||
@classmethod
|
||||
def _check_param(
|
||||
cls, dependent: Dependent, name: str, param: inspect.Parameter
|
||||
) -> Optional["ArgParam"]:
|
||||
if isinstance(param.default, ArgInner):
|
||||
return cls(Required, key=param.default.key or name, type=param.default.type)
|
||||
|
||||
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
||||
message = matcher.get_arg(self.extra["key"])
|
||||
if message is None:
|
||||
return message
|
||||
if self.extra["type"] == "message":
|
||||
return message
|
||||
elif self.extra["type"] == "str":
|
||||
return str(message)
|
||||
else:
|
||||
return message.extract_plain_text()
|
||||
|
||||
|
||||
class ExceptionParam(Param):
|
||||
"""`run_postprocessor` 的异常参数"""
|
||||
|
||||
@classmethod
|
||||
def _check_param(
|
||||
cls, dependent: Dependent, name: str, param: inspect.Parameter
|
||||
) -> Optional["ExceptionParam"]:
|
||||
if generic_check_issubclass(param.annotation, Exception) or (
|
||||
param.annotation == param.empty and name == "exception"
|
||||
):
|
||||
return cls(Required)
|
||||
|
||||
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
|
||||
return exception
|
||||
|
||||
|
||||
class DefaultParam(Param):
|
||||
"""默认值参数"""
|
||||
|
||||
@classmethod
|
||||
def _check_param(
|
||||
cls, dependent: Dependent, name: str, param: inspect.Parameter
|
||||
) -> Optional["DefaultParam"]:
|
||||
if param.default != param.empty:
|
||||
return cls(param.default)
|
||||
|
||||
async def _solve(self, **kwargs: Any) -> Any:
|
||||
return Undefined
|
||||
|
||||
|
||||
from nonebot.matcher import Matcher
|
||||
|
||||
__autodoc__ = {
|
||||
"DependsInner": False,
|
||||
"StateInner": False,
|
||||
"ArgInner": False,
|
||||
"Arg": True,
|
||||
"State": True,
|
||||
"ArgStr": True,
|
||||
"Depends": True,
|
||||
"ArgParam": True,
|
||||
"BotParam": True,
|
||||
"EventParam": True,
|
||||
"StateParam": True,
|
||||
"DependParam": True,
|
||||
"ArgPlainText": True,
|
||||
"DefaultParam": True,
|
||||
"MatcherParam": True,
|
||||
"ExceptionParam": True,
|
||||
}
|
||||
|
Reference in New Issue
Block a user