♻️ reorganize class and add bot hook di

This commit is contained in:
yanyongyu
2022-02-06 14:52:50 +08:00
parent b8456b12ad
commit fd11e2696b
24 changed files with 1747 additions and 1633 deletions

View File

106
nonebot/internal/adapter.py Normal file
View File

@ -0,0 +1,106 @@
import abc
from contextlib import asynccontextmanager
from typing import Any, Dict, AsyncGenerator
from nonebot.config import Config
from nonebot.drivers import (
Driver,
Request,
Response,
WebSocket,
ForwardDriver,
ReverseDriver,
HTTPServerSetup,
WebSocketServerSetup,
)
from .bot import Bot
class Adapter(abc.ABC):
"""协议适配器基类。
通常,在 Adapter 中编写协议通信相关代码,如: 建立通信连接、处理接收与发送 data 等。
参数:
driver: {ref}`nonebot.drivers.Driver` 实例
kwargs: 其他由 {ref}`nonebot.drivers.Driver.register_adapter` 传入的额外参数
"""
def __init__(self, driver: Driver, **kwargs: Any):
self.driver: Driver = driver
"""{ref}`nonebot.drivers.Driver` 实例"""
self.bots: Dict[str, Bot] = {}
"""本协议适配器已建立连接的 {ref}`nonebot.adapters.Bot` 实例"""
@classmethod
@abc.abstractmethod
def get_name(cls) -> str:
"""当前协议适配器的名称"""
raise NotImplementedError
@property
def config(self) -> Config:
"""全局 NoneBot 配置"""
return self.driver.config
def bot_connect(self, bot: Bot) -> None:
"""告知 NoneBot 建立了一个新的 {ref}`nonebot.adapters.Bot` 连接。
当有新的 {ref}`nonebot.adapters.Bot` 实例连接建立成功时调用。
参数:
bot: {ref}`nonebot.adapters.Bot` 实例
"""
self.driver._bot_connect(bot)
self.bots[bot.self_id] = bot
def bot_disconnect(self, bot: Bot) -> None:
"""告知 NoneBot {ref}`nonebot.adapters.Bot` 连接已断开。
当有 {ref}`nonebot.adapters.Bot` 实例连接断开时调用。
参数:
bot: {ref}`nonebot.adapters.Bot` 实例
"""
self.driver._bot_disconnect(bot)
self.bots.pop(bot.self_id, None)
def setup_http_server(self, setup: HTTPServerSetup):
"""设置一个 HTTP 服务器路由配置"""
if not isinstance(self.driver, ReverseDriver):
raise TypeError("Current driver does not support http server")
self.driver.setup_http_server(setup)
def setup_websocket_server(self, setup: WebSocketServerSetup):
"""设置一个 WebSocket 服务器路由配置"""
if not isinstance(self.driver, ReverseDriver):
raise TypeError("Current driver does not support websocket server")
self.driver.setup_websocket_server(setup)
async def request(self, setup: Request) -> Response:
"""进行一个 HTTP 客户端请求"""
if not isinstance(self.driver, ForwardDriver):
raise TypeError("Current driver does not support http client")
return await self.driver.request(setup)
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
"""建立一个 WebSocket 客户端连接请求"""
if not isinstance(self.driver, ForwardDriver):
raise TypeError("Current driver does not support websocket client")
async with self.driver.websocket(setup) as ws:
yield ws
@abc.abstractmethod
async def _call_api(self, bot: Bot, api: str, **data: Any) -> Any:
"""`Adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
参数:
api: API 名称
data: API 数据
"""
raise NotImplementedError
__autodoc__ = {"Adapter._call_api": True}

162
nonebot/internal/bot.py Normal file
View File

@ -0,0 +1,162 @@
import abc
import asyncio
from functools import partial
from typing_extensions import Protocol
from typing import TYPE_CHECKING, Any, Set, Union, Optional
from nonebot.log import logger
from nonebot.config import Config
from nonebot.exception import MockApiException
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook
if TYPE_CHECKING:
from .event import Event
from .adapter import Adapter
from .message import Message, MessageSegment
class _ApiCall(Protocol):
async def __call__(self, **kwargs: Any) -> Any:
...
class Bot(abc.ABC):
"""Bot 基类。
用于处理上报消息,并提供 API 调用接口。
参数:
adapter: 协议适配器实例
self_id: 机器人 ID
"""
_calling_api_hook: Set[T_CallingAPIHook] = set()
"""call_api 时执行的函数"""
_called_api_hook: Set[T_CalledAPIHook] = set()
"""call_api 后执行的函数"""
def __init__(self, adapter: "Adapter", self_id: str):
self.adapter: "Adapter" = adapter
"""协议适配器实例"""
self.self_id: str = self_id
"""机器人 ID"""
def __getattr__(self, name: str) -> _ApiCall:
return partial(self.call_api, name)
@property
def type(self) -> str:
"""协议适配器名称"""
return self.adapter.get_name()
@property
def config(self) -> Config:
"""全局 NoneBot 配置"""
return self.adapter.config
async def call_api(self, api: str, **data: Any) -> Any:
"""调用机器人 API 接口,可以通过该函数或直接通过 bot 属性进行调用
参数:
api: API 名称
data: API 数据
用法:
```python
await bot.call_api("send_msg", message="hello world")
await bot.send_msg(message="hello world")
```
"""
result: Any = None
skip_calling_api: bool = False
exception: Optional[Exception] = None
coros = list(map(lambda x: x(self, api, data), self._calling_api_hook))
if coros:
try:
logger.debug("Running CallingAPI hooks...")
await asyncio.gather(*coros)
except MockApiException as e:
skip_calling_api = True
result = e.result
logger.debug(
f"Calling API {api} is cancelled. Return {result} instead."
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
if not skip_calling_api:
try:
result = await self.adapter._call_api(self, api, **data)
except Exception as e:
exception = e
coros = list(
map(lambda x: x(self, exception, api, data, result), self._called_api_hook)
)
if coros:
try:
logger.debug("Running CalledAPI hooks...")
await asyncio.gather(*coros)
except MockApiException as e:
result = e.result
logger.debug(
f"Calling API {api} result is mocked. Return {result} instead."
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CalledAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
if exception:
raise exception
return result
@abc.abstractmethod
async def send(
self,
event: "Event",
message: Union[str, "Message", "MessageSegment"],
**kwargs: Any,
) -> Any:
"""调用机器人基础发送消息接口
参数:
event: 上报事件
message: 要发送的消息
kwargs: 任意额外参数
"""
raise NotImplementedError
@classmethod
def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook:
"""调用 api 预处理。
钩子函数参数:
- bot: 当前 bot 对象
- api: 调用的 api 名称
- data: api 调用的参数字典
"""
cls._calling_api_hook.add(func)
return func
@classmethod
def on_called_api(cls, func: T_CalledAPIHook) -> T_CalledAPIHook:
"""调用 api 后处理。
钩子函数参数:
- bot: 当前 bot 对象
- exception: 调用 api 时发生的错误
- api: 调用的 api 名称
- data: api 调用的参数字典
- result: api 调用的返回
"""
cls._called_api_hook.add(func)
return func

233
nonebot/internal/driver.py Normal file
View File

@ -0,0 +1,233 @@
import abc
import asyncio
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
from nonebot.log import logger
from nonebot.utils import escape_tag
from nonebot.config import Env, Config
from nonebot.dependencies import Dependent
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
from .params import BotParam, DependParam, DefaultParam
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup
if TYPE_CHECKING:
from .bot import Bot
from .adapter import Adapter
BOT_HOOK_PARAMS = [DependParam, BotParam, DefaultParam]
class Driver(abc.ABC):
"""Driver 基类。
参数:
env: 包含环境信息的 Env 对象
config: 包含配置信息的 Config 对象
"""
_adapters: Dict[str, "Adapter"] = {}
"""已注册的适配器列表"""
_bot_connection_hook: Set[Dependent[Any]] = set()
"""Bot 连接建立时执行的函数"""
_bot_disconnection_hook: Set[Dependent[Any]] = set()
"""Bot 连接断开时执行的函数"""
def __init__(self, env: Env, config: Config):
self.env: str = env.environment
"""环境名称"""
self.config: Config = config
"""全局配置对象"""
self._clients: Dict[str, "Bot"] = {}
@property
def bots(self) -> Dict[str, "Bot"]:
"""获取当前所有已连接的 Bot"""
return self._clients
def register_adapter(self, adapter: Type["Adapter"], **kwargs) -> None:
"""注册一个协议适配器
参数:
adapter: 适配器类
kwargs: 其他传递给适配器的参数
"""
name = adapter.get_name()
if name in self._adapters:
logger.opt(colors=True).debug(
f'Adapter "<y>{escape_tag(name)}</y>" already exists'
)
return
self._adapters[name] = adapter(self, **kwargs)
logger.opt(colors=True).debug(
f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"'
)
@property
@abc.abstractmethod
def type(self) -> str:
"""驱动类型名称"""
raise NotImplementedError
@property
@abc.abstractmethod
def logger(self):
"""驱动专属 logger 日志记录器"""
raise NotImplementedError
@abc.abstractmethod
def run(self, *args, **kwargs):
"""
启动驱动框架
"""
logger.opt(colors=True).debug(
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
)
@abc.abstractmethod
def on_startup(self, func: Callable) -> Callable:
"""注册一个在驱动器启动时执行的函数"""
raise NotImplementedError
@abc.abstractmethod
def on_shutdown(self, func: Callable) -> Callable:
"""注册一个在驱动器停止时执行的函数"""
raise NotImplementedError
def on_bot_connect(self, func: T_BotConnectionHook) -> T_BotConnectionHook:
"""装饰一个函数使他在 bot 连接成功时执行。
钩子函数参数:
- bot: 当前连接上的 Bot 对象
"""
self._bot_connection_hook.add(
Dependent[Any].parse(call=func, allow_types=BOT_HOOK_PARAMS)
)
return func
def on_bot_disconnect(self, func: T_BotDisconnectionHook) -> T_BotDisconnectionHook:
"""装饰一个函数使他在 bot 连接断开时执行。
钩子函数参数:
- bot: 当前连接上的 Bot 对象
"""
self._bot_disconnection_hook.add(
Dependent[Any].parse(call=func, allow_types=BOT_HOOK_PARAMS)
)
return func
def _bot_connect(self, bot: "Bot") -> None:
"""在连接成功后,调用该函数来注册 bot 对象"""
if bot.self_id in self._clients:
raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
self._clients[bot.self_id] = bot
async def _run_hook(bot: "Bot") -> None:
coros = list(map(lambda x: x(bot=bot), self._bot_connection_hook))
if coros:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running WebSocketConnection hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
asyncio.create_task(_run_hook(bot))
def _bot_disconnect(self, bot: "Bot") -> None:
"""在连接断开后,调用该函数来注销 bot 对象"""
if bot.self_id in self._clients:
del self._clients[bot.self_id]
async def _run_hook(bot: "Bot") -> None:
coros = list(map(lambda x: x(bot=bot), self._bot_disconnection_hook))
if coros:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running WebSocketDisConnection hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
asyncio.create_task(_run_hook(bot))
class ForwardMixin(abc.ABC):
"""客户端混入基类。"""
@property
@abc.abstractmethod
def type(self) -> str:
"""客户端驱动类型名称"""
raise NotImplementedError
@abc.abstractmethod
async def request(self, setup: Request) -> Response:
"""发送一个 HTTP 请求"""
raise NotImplementedError
@abc.abstractmethod
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
"""发起一个 WebSocket 连接"""
raise NotImplementedError
yield # used for static type checking's generator detection
class ForwardDriver(Driver, ForwardMixin):
"""客户端基类。将客户端框架封装,以满足适配器使用。"""
class ReverseDriver(Driver):
"""服务端基类。将后端框架封装,以满足适配器使用。"""
@property
@abc.abstractmethod
def server_app(self) -> Any:
"""驱动 APP 对象"""
raise NotImplementedError
@property
@abc.abstractmethod
def asgi(self) -> Any:
"""驱动 ASGI 对象"""
raise NotImplementedError
@abc.abstractmethod
def setup_http_server(self, setup: "HTTPServerSetup") -> None:
"""设置一个 HTTP 服务器路由配置"""
raise NotImplementedError
@abc.abstractmethod
def setup_websocket_server(self, setup: "WebSocketServerSetup") -> None:
"""设置一个 WebSocket 服务器路由配置"""
raise NotImplementedError
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
"""将一个驱动器和多个混入类合并。"""
# check first
assert issubclass(driver, Driver), "`driver` must be subclass of Driver"
assert all(
map(lambda m: issubclass(m, ForwardMixin), mixins)
), "`mixins` must be subclass of ForwardMixin"
if not mixins:
return driver
class CombinedDriver(*mixins, driver, ForwardDriver): # type: ignore
@property
def type(self) -> str:
return (
driver.type.__get__(self)
+ "+"
+ "+".join(map(lambda x: x.type.__get__(self), mixins))
)
return CombinedDriver

70
nonebot/internal/event.py Normal file
View File

@ -0,0 +1,70 @@
import abc
from pydantic import BaseModel
from nonebot.utils import DataclassEncoder
from .message import Message
class Event(abc.ABC, BaseModel):
"""Event 基类。提供获取关键信息的方法,其余信息可直接获取。"""
class Config:
extra = "allow"
json_encoders = {Message: DataclassEncoder}
@abc.abstractmethod
def get_type(self) -> str:
"""获取事件类型的方法,类型通常为 NoneBot 内置的四种类型。"""
raise NotImplementedError
@abc.abstractmethod
def get_event_name(self) -> str:
"""获取事件名称的方法。"""
raise NotImplementedError
@abc.abstractmethod
def get_event_description(self) -> str:
"""获取事件描述的方法,通常为事件具体内容。"""
raise NotImplementedError
def __str__(self) -> str:
return f"[{self.get_event_name()}]: {self.get_event_description()}"
def get_log_string(self) -> str:
"""获取事件日志信息的方法。
通常你不需要修改这个方法,只有当希望 NoneBot 隐藏该事件日志时,可以抛出 `NoLogException` 异常。
异常:
NoLogException
"""
return f"[{self.get_event_name()}]: {self.get_event_description()}"
@abc.abstractmethod
def get_user_id(self) -> str:
"""获取事件主体 id 的方法,通常是用户 id 。"""
raise NotImplementedError
@abc.abstractmethod
def get_session_id(self) -> str:
"""获取会话 id 的方法,用于判断当前事件属于哪一个会话,通常是用户 id、群组 id 组合。"""
raise NotImplementedError
@abc.abstractmethod
def get_message(self) -> "Message":
"""获取事件消息内容的方法。"""
raise NotImplementedError
def get_plaintext(self) -> str:
"""获取消息纯文本的方法。
通常不需要修改,默认通过 `get_message().extract_plain_text` 获取。
"""
return self.get_message().extract_plain_text()
@abc.abstractmethod
def is_tome(self) -> bool:
"""获取事件是否与机器人有关的方法。"""
raise NotImplementedError

724
nonebot/internal/matcher.py Normal file
View File

@ -0,0 +1,724 @@
from types import ModuleType
from datetime import datetime
from contextvars import ContextVar
from collections import defaultdict
from contextlib import AsyncExitStack
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Type,
Union,
TypeVar,
Callable,
NoReturn,
Optional,
)
from nonebot.log import logger
from nonebot.dependencies import Dependent
from nonebot.consts import (
ARG_KEY,
RECEIVE_KEY,
REJECT_TARGET,
LAST_RECEIVE_KEY,
REJECT_CACHE_TARGET,
)
from nonebot.typing import (
Any,
T_State,
T_Handler,
T_TypeUpdater,
T_DependencyCache,
T_PermissionUpdater,
)
from nonebot.exception import (
TypeMisMatch,
PausedException,
StopPropagation,
SkippedException,
FinishedException,
RejectedException,
)
from .bot import Bot
from .rule import Rule
from .event import Event
from .template import MessageTemplate
from .permission import USER, Permission
from .message import Message, MessageSegment
from .params import (
Depends,
ArgParam,
BotParam,
EventParam,
StateParam,
DependParam,
DefaultParam,
MatcherParam,
)
if TYPE_CHECKING:
from nonebot.plugin import Plugin
T = TypeVar("T")
matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list)
"""用于存储当前所有的事件响应器"""
current_bot: ContextVar[Bot] = ContextVar("current_bot")
current_event: ContextVar[Event] = ContextVar("current_event")
current_matcher: ContextVar["Matcher"] = ContextVar("current_matcher")
current_handler: ContextVar[Dependent] = ContextVar("current_handler")
class MatcherMeta(type):
if TYPE_CHECKING:
module: Optional[str]
plugin_name: Optional[str]
module_name: Optional[str]
module_prefix: Optional[str]
type: str
rule: Rule
permission: Permission
handlers: List[T_Handler]
priority: int
block: bool
temp: bool
expire_time: Optional[datetime]
def __repr__(self) -> str:
return (
f"<Matcher from {self.module_name or 'unknown'}, "
f"type={self.type}, priority={self.priority}, "
f"temp={self.temp}>"
)
def __str__(self) -> str:
return repr(self)
class Matcher(metaclass=MatcherMeta):
"""事件响应器类"""
plugin: Optional["Plugin"] = None
"""事件响应器所在插件"""
module: Optional[ModuleType] = None
"""事件响应器所在插件模块"""
plugin_name: Optional[str] = None
"""事件响应器所在插件名"""
module_name: Optional[str] = None
"""事件响应器所在点分割插件模块路径"""
type: str = ""
"""事件响应器类型"""
rule: Rule = Rule()
"""事件响应器匹配规则"""
permission: Permission = Permission()
"""事件响应器触发权限"""
handlers: List[Dependent[Any]] = []
"""事件响应器拥有的事件处理函数列表"""
priority: int = 1
"""事件响应器优先级"""
block: bool = False
"""事件响应器是否阻止事件传播"""
temp: bool = False
"""事件响应器是否为临时"""
expire_time: Optional[datetime] = None
"""事件响应器过期时间点"""
_default_state: T_State = {}
"""事件响应器默认状态"""
_default_type_updater: Optional[Dependent[str]] = None
"""事件响应器类型更新函数"""
_default_permission_updater: Optional[Dependent[Permission]] = None
"""事件响应器权限更新函数"""
HANDLER_PARAM_TYPES = [
DependParam,
BotParam,
EventParam,
StateParam,
ArgParam,
MatcherParam,
DefaultParam,
]
def __init__(self):
self.handlers = self.handlers.copy()
self.state = self._default_state.copy()
def __repr__(self) -> str:
return (
f"<Matcher from {self.module_name or 'unknown'}, type={self.type}, "
f"priority={self.priority}, temp={self.temp}>"
)
def __str__(self) -> str:
return repr(self)
@classmethod
def new(
cls,
type_: str = "",
rule: Optional[Rule] = None,
permission: Optional[Permission] = None,
handlers: Optional[List[Union[T_Handler, Dependent[Any]]]] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
*,
plugin: Optional["Plugin"] = None,
module: Optional[ModuleType] = None,
expire_time: Optional[datetime] = None,
default_state: Optional[T_State] = None,
default_type_updater: Optional[Union[T_TypeUpdater, Dependent[str]]] = None,
default_permission_updater: Optional[
Union[T_PermissionUpdater, Dependent[Permission]]
] = None,
) -> Type["Matcher"]:
"""
创建一个新的事件响应器,并存储至 `matchers <#matchers>`_
参数:
type_: 事件响应器类型,与 `event.get_type()` 一致时触发,空字符串表示任意
rule: 匹配规则
permission: 权限
handlers: 事件处理函数列表
temp: 是否为临时事件响应器,即触发一次后删除
priority: 响应优先级
block: 是否阻止事件向更低优先级的响应器传播
plugin: 事件响应器所在插件
module: 事件响应器所在模块
default_state: 默认状态 `state`
expire_time: 事件响应器最终有效时间点,过时即被删除
返回:
Type[Matcher]: 新的事件响应器类
"""
NewMatcher = type(
"Matcher",
(Matcher,),
{
"plugin": plugin,
"module": module,
"plugin_name": plugin and plugin.name,
"module_name": module and module.__name__,
"type": type_,
"rule": rule or Rule(),
"permission": permission or Permission(),
"handlers": [
handler
if isinstance(handler, Dependent)
else Dependent[Any].parse(
call=handler, allow_types=cls.HANDLER_PARAM_TYPES
)
for handler in handlers
]
if handlers
else [],
"temp": temp,
"expire_time": expire_time,
"priority": priority,
"block": block,
"_default_state": default_state or {},
"_default_type_updater": (
default_type_updater
if isinstance(default_type_updater, Dependent)
else default_type_updater
and Dependent[str].parse(
call=default_type_updater, allow_types=cls.HANDLER_PARAM_TYPES
)
),
"_default_permission_updater": (
default_permission_updater
if isinstance(default_permission_updater, Dependent)
else default_permission_updater
and Dependent[Permission].parse(
call=default_permission_updater,
allow_types=cls.HANDLER_PARAM_TYPES,
)
),
},
)
logger.trace(f"Define new matcher {NewMatcher}")
matchers[priority].append(NewMatcher)
return NewMatcher
@classmethod
async def check_perm(
cls,
bot: Bot,
event: Event,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
) -> bool:
"""检查是否满足触发权限
参数:
bot: Bot 对象
event: 上报事件
stack: 异步上下文栈
dependency_cache: 依赖缓存
返回:
是否满足权限
"""
event_type = event.get_type()
return event_type == (cls.type or event_type) and await cls.permission(
bot, event, stack, dependency_cache
)
@classmethod
async def check_rule(
cls,
bot: Bot,
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
) -> bool:
"""检查是否满足匹配规则
参数:
bot: Bot 对象
event: 上报事件
state: 当前状态
stack: 异步上下文栈
dependency_cache: 依赖缓存
返回:
是否满足匹配规则
"""
event_type = event.get_type()
return event_type == (cls.type or event_type) and await cls.rule(
bot, event, state, stack, dependency_cache
)
@classmethod
def type_updater(cls, func: T_TypeUpdater) -> T_TypeUpdater:
"""装饰一个函数来更改当前事件响应器的默认响应事件类型更新函数
参数:
func: 响应事件类型更新函数
"""
cls._default_type_updater = Dependent[str].parse(
call=func, allow_types=cls.HANDLER_PARAM_TYPES
)
return func
@classmethod
def permission_updater(cls, func: T_PermissionUpdater) -> T_PermissionUpdater:
"""装饰一个函数来更改当前事件响应器的默认会话权限更新函数
参数:
func: 会话权限更新函数
"""
cls._default_permission_updater = Dependent[Permission].parse(
call=func, allow_types=cls.HANDLER_PARAM_TYPES
)
return func
@classmethod
def append_handler(
cls, handler: T_Handler, parameterless: Optional[List[Any]] = None
) -> Dependent[Any]:
handler_ = Dependent[Any].parse(
call=handler,
parameterless=parameterless,
allow_types=cls.HANDLER_PARAM_TYPES,
)
cls.handlers.append(handler_)
return handler_
@classmethod
def handle(
cls, parameterless: Optional[List[Any]] = None
) -> Callable[[T_Handler], T_Handler]:
"""装饰一个函数来向事件响应器直接添加一个处理函数
参数:
parameterless: 非参数类型依赖列表
"""
def _decorator(func: T_Handler) -> T_Handler:
cls.append_handler(func, parameterless=parameterless)
return func
return _decorator
@classmethod
def receive(
cls, id: str = "", parameterless: Optional[List[Any]] = None
) -> Callable[[T_Handler], T_Handler]:
"""装饰一个函数来指示 NoneBot 在接收用户新的一条消息后继续运行该函数
参数:
id: 消息 ID
parameterless: 非参数类型依赖列表
"""
async def _receive(event: Event, matcher: "Matcher") -> Union[None, NoReturn]:
matcher.set_target(RECEIVE_KEY.format(id=id))
if matcher.get_target() == RECEIVE_KEY.format(id=id):
matcher.set_receive(id, event)
return
if matcher.get_receive(id, ...) is not ...:
return
await matcher.reject()
_parameterless = [Depends(_receive), *(parameterless or [])]
def _decorator(func: T_Handler) -> T_Handler:
if cls.handlers and cls.handlers[-1].call is func:
func_handler = cls.handlers[-1]
for depend in reversed(_parameterless):
func_handler.prepend_parameterless(depend)
else:
cls.append_handler(func, parameterless=_parameterless)
return func
return _decorator
@classmethod
def got(
cls,
key: str,
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
parameterless: Optional[List[Any]] = None,
) -> Callable[[T_Handler], T_Handler]:
"""装饰一个函数来指示 NoneBot 获取一个参数 `key`
当要获取的 `key` 不存在时接收用户新的一条消息再运行该函数,如果 `key` 已存在则直接继续运行
参数:
key: 参数名
prompt: 在参数不存在时向用户发送的消息
parameterless: 非参数类型依赖列表
"""
async def _key_getter(event: Event, matcher: "Matcher"):
matcher.set_target(ARG_KEY.format(key=key))
if matcher.get_target() == ARG_KEY.format(key=key):
matcher.set_arg(key, event.get_message())
return
if matcher.get_arg(key, ...) is not ...:
return
await matcher.reject(prompt)
_parameterless = [
Depends(_key_getter),
*(parameterless or []),
]
def _decorator(func: T_Handler) -> T_Handler:
if cls.handlers and cls.handlers[-1].call is func:
func_handler = cls.handlers[-1]
for depend in reversed(_parameterless):
func_handler.prepend_parameterless(depend)
else:
cls.append_handler(func, parameterless=_parameterless)
return func
return _decorator
@classmethod
async def send(
cls,
message: Union[str, Message, MessageSegment, MessageTemplate],
**kwargs: Any,
) -> Any:
"""发送一条消息给当前交互用户
参数:
message: 消息内容
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,请参考对应 adapter 的 bot 对象 api
"""
bot = current_bot.get()
event = current_event.get()
state = current_matcher.get().state
if isinstance(message, MessageTemplate):
_message = message.format(**state)
else:
_message = message
return await bot.send(event=event, message=_message, **kwargs)
@classmethod
async def finish(
cls,
message: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
**kwargs,
) -> NoReturn:
"""发送一条消息给当前交互用户并结束当前事件响应器
参数:
message: 消息内容
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,请参考对应 adapter 的 bot 对象 api
"""
if message is not None:
await cls.send(message, **kwargs)
raise FinishedException
@classmethod
async def pause(
cls,
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
**kwargs,
) -> NoReturn:
"""发送一条消息给当前交互用户并暂停事件响应器,在接收用户新的一条消息后继续下一个处理函数
参数:
prompt: 消息内容
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,请参考对应 adapter 的 bot 对象 api
"""
if prompt is not None:
await cls.send(prompt, **kwargs)
raise PausedException
@classmethod
async def reject(
cls,
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
**kwargs,
) -> NoReturn:
"""最近使用 `got` / `receive` 接收的消息不符合预期,
发送一条消息给当前交互用户并将当前事件处理流程中断在当前位置,在接收用户新的一个事件后从头开始执行当前处理函数
参数:
prompt: 消息内容
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,请参考对应 adapter 的 bot 对象 api
"""
if prompt is not None:
await cls.send(prompt, **kwargs)
raise RejectedException
@classmethod
async def reject_arg(
cls,
key: str,
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
**kwargs,
) -> NoReturn:
"""最近使用 `got` 接收的消息不符合预期,
发送一条消息给当前交互用户并将当前事件处理流程中断在当前位置,在接收用户新的一条消息后从头开始执行当前处理函数
参数:
key: 参数名
prompt: 消息内容
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,请参考对应 adapter 的 bot 对象 api
"""
matcher = current_matcher.get()
matcher.set_target(ARG_KEY.format(key=key))
if prompt is not None:
await cls.send(prompt, **kwargs)
raise RejectedException
@classmethod
async def reject_receive(
cls,
id: str = "",
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
**kwargs,
) -> NoReturn:
"""最近使用 `receive` 接收的消息不符合预期,
发送一条消息给当前交互用户并将当前事件处理流程中断在当前位置,在接收用户新的一个事件后从头开始执行当前处理函数
参数:
id: 消息 id
prompt: 消息内容
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,请参考对应 adapter 的 bot 对象 api
"""
matcher = current_matcher.get()
matcher.set_target(RECEIVE_KEY.format(id=id))
if prompt is not None:
await cls.send(prompt, **kwargs)
raise RejectedException
@classmethod
def skip(cls) -> NoReturn:
"""跳过当前事件处理函数,继续下一个处理函数
通常在事件处理函数的依赖中使用。
"""
raise SkippedException
def get_receive(self, id: str, default: T = None) -> Union[Event, T]:
"""获取一个 `receive` 事件
如果没有找到对应的事件,返回 `default` 值
"""
return self.state.get(RECEIVE_KEY.format(id=id), default)
def set_receive(self, id: str, event: Event) -> None:
"""设置一个 `receive` 事件"""
self.state[RECEIVE_KEY.format(id=id)] = event
self.state[LAST_RECEIVE_KEY] = event
def get_last_receive(self, default: T = None) -> Union[Event, T]:
"""获取最近一次 `receive` 事件
如果没有事件,返回 `default` 值
"""
return self.state.get(LAST_RECEIVE_KEY, default)
def get_arg(self, key: str, default: T = None) -> Union[Message, T]:
"""获取一个 `got` 消息
如果没有找到对应的消息,返回 `default` 值
"""
return self.state.get(ARG_KEY.format(key=key), default)
def set_arg(self, key: str, message: Message) -> None:
"""设置一个 `got` 消息"""
self.state[ARG_KEY.format(key=key)] = message
def set_target(self, target: str, cache: bool = True) -> None:
if cache:
self.state[REJECT_CACHE_TARGET] = target
else:
self.state[REJECT_TARGET] = target
def get_target(self, default: T = None) -> Union[str, T]:
return self.state.get(REJECT_TARGET, default)
def stop_propagation(self):
"""阻止事件传播"""
self.block = True
async def update_type(self, bot: Bot, event: Event) -> str:
updater = self.__class__._default_type_updater
if not updater:
return "message"
return await updater(bot=bot, event=event, state=self.state, matcher=self)
async def update_permission(self, bot: Bot, event: Event) -> Permission:
updater = self.__class__._default_permission_updater
if not updater:
return USER(event.get_session_id(), perm=self.permission)
return await updater(bot=bot, event=event, state=self.state, matcher=self)
async def resolve_reject(self):
handler = current_handler.get()
self.handlers.insert(0, handler)
if REJECT_CACHE_TARGET in self.state:
self.state[REJECT_TARGET] = self.state[REJECT_CACHE_TARGET]
async def simple_run(
self,
bot: Bot,
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
):
logger.trace(
f"Matcher {self} run with incoming args: "
f"bot={bot}, event={event}, state={state}"
)
b_t = current_bot.set(bot)
e_t = current_event.set(event)
m_t = current_matcher.set(self)
try:
# Refresh preprocess state
self.state.update(state)
while self.handlers:
handler = self.handlers.pop(0)
current_handler.set(handler)
logger.debug(f"Running handler {handler}")
try:
await handler(
matcher=self,
bot=bot,
event=event,
state=self.state,
stack=stack,
dependency_cache=dependency_cache,
)
except TypeMisMatch as e:
logger.debug(
f"Handler {handler} param {e.param.name} value {e.value} "
f"mismatch type {e.param._type_display()}, skipped"
)
except SkippedException as e:
logger.debug(f"Handler {handler} skipped")
except StopPropagation:
self.block = True
finally:
logger.info(f"Matcher {self} running complete")
current_bot.reset(b_t)
current_event.reset(e_t)
current_matcher.reset(m_t)
# 运行handlers
async def run(
self,
bot: Bot,
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
):
try:
await self.simple_run(bot, event, state, stack, dependency_cache)
except RejectedException:
await self.resolve_reject()
type_ = await self.update_type(bot, event)
permission = await self.update_permission(bot, event)
Matcher.new(
type_,
Rule(),
permission,
self.handlers,
temp=True,
priority=0,
block=True,
plugin=self.plugin,
module=self.module,
expire_time=datetime.now() + bot.config.session_expire_timeout,
default_state=self.state,
default_type_updater=self.__class__._default_type_updater,
default_permission_updater=self.__class__._default_permission_updater,
)
except PausedException:
type_ = await self.update_type(bot, event)
permission = await self.update_permission(bot, event)
Matcher.new(
type_,
Rule(),
permission,
self.handlers,
temp=True,
priority=0,
block=True,
plugin=self.plugin,
module=self.module,
expire_time=datetime.now() + bot.config.session_expire_timeout,
default_state=self.state,
default_type_updater=self.__class__._default_type_updater,
default_permission_updater=self.__class__._default_permission_updater,
)
except FinishedException:
pass
__autodoc__ = {
"MatcherMeta": False,
"Matcher.get_target": False,
"Matcher.set_target": False,
"Matcher.update_type": False,
"Matcher.update_permission": False,
"Matcher.resolve_reject": False,
"Matcher.simple_run": False,
}

341
nonebot/internal/message.py Normal file
View File

@ -0,0 +1,341 @@
import abc
from copy import deepcopy
from dataclasses import field, asdict, dataclass
from typing import (
Any,
Dict,
List,
Type,
Tuple,
Union,
Generic,
TypeVar,
Iterable,
Optional,
overload,
)
from pydantic import parse_obj_as
from .template import MessageTemplate
T = TypeVar("T")
TMS = TypeVar("TMS", bound="MessageSegment")
TM = TypeVar("TM", bound="Message")
@dataclass
class MessageSegment(abc.ABC, Generic[TM]):
"""消息段基类"""
type: str
"""消息段类型"""
data: Dict[str, Any] = field(default_factory=dict)
"""消息段数据"""
@classmethod
@abc.abstractmethod
def get_message_class(cls) -> Type[TM]:
"""获取消息数组类型"""
raise NotImplementedError
@abc.abstractmethod
def __str__(self) -> str:
"""该消息段所代表的 str在命令匹配部分使用"""
raise NotImplementedError
def __len__(self) -> int:
return len(str(self))
def __ne__(self: T, other: T) -> bool:
return not self == other
def __add__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
return self.get_message_class()(self) + other
def __radd__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
return self.get_message_class()(other) + self
@classmethod
def __get_validators__(cls):
yield cls._validate
@classmethod
def _validate(cls, value):
if isinstance(value, cls):
return value
if not isinstance(value, dict):
raise ValueError(f"Expected dict for MessageSegment, got {type(value)}")
return cls(**value)
def get(self, key: str, default: Any = None):
return asdict(self).get(key, default)
def keys(self):
return asdict(self).keys()
def values(self):
return asdict(self).values()
def items(self):
return asdict(self).items()
def copy(self: T) -> T:
return deepcopy(self)
@abc.abstractmethod
def is_text(self) -> bool:
"""当前消息段是否为纯文本"""
raise NotImplementedError
class Message(List[TMS], abc.ABC):
"""消息数组
参数:
message: 消息内容
"""
def __init__(
self,
message: Union[str, None, Iterable[TMS], TMS] = None,
):
super().__init__()
if message is None:
return
elif isinstance(message, str):
self.extend(self._construct(message))
elif isinstance(message, MessageSegment):
self.append(message)
elif isinstance(message, Iterable):
self.extend(message)
else:
self.extend(self._construct(message)) # pragma: no cover
@classmethod
def template(cls: Type[TM], format_string: Union[str, TM]) -> MessageTemplate[TM]:
"""创建消息模板。
用法和 `str.format` 大致相同, 但是可以输出消息对象, 并且支持以 `Message` 对象作为消息模板
并且提供了拓展的格式化控制符, 可以用适用于该消息类型的 `MessageSegment` 的工厂方法创建消息
用法:
```python
>>> Message.template("{} {}").format("hello", "world") # 基础演示
Message(MessageSegment(type='text', data={'text': 'hello world'}))
>>> Message.template("{} {}").format(MessageSegment.image("file///..."), "world") # 支持消息段等对象
Message(MessageSegment(type='image', data={'file': 'file///...'}), MessageSegment(type='text', data={'text': 'world'}))
>>> Message.template( # 支持以Message对象作为消息模板
... MessageSegment.text('test {event.user_id}') + MessageSegment.face(233) +
... MessageSegment.text('test {event.message}')).format(event={'user_id':123456, 'message':'hello world'})
Message(MessageSegment(type='text', data={'text': 'test 123456'}),
MessageSegment(type='face', data={'face': 233}),
MessageSegment(type='text', data={'text': 'test hello world'}))
>>> Message.template("{link:image}").format(link='https://...') # 支持拓展格式化控制符
Message(MessageSegment(type='image', data={'file': 'https://...'}))
```
参数:
format_string: 格式化字符串
返回:
消息格式化器
"""
return MessageTemplate(format_string, cls)
@classmethod
@abc.abstractmethod
def get_segment_class(cls) -> Type[TMS]:
"""获取消息段类型"""
raise NotImplementedError
def __str__(self) -> str:
return "".join(str(seg) for seg in self)
@classmethod
def __get_validators__(cls):
yield cls._validate
@classmethod
def _validate(cls, value):
if isinstance(value, cls):
return value
elif isinstance(value, Message):
raise ValueError(f"Type {type(value)} can not be converted to {cls}")
elif isinstance(value, str):
pass
elif isinstance(value, dict):
value = parse_obj_as(cls.get_segment_class(), value)
elif isinstance(value, Iterable):
value = [parse_obj_as(cls.get_segment_class(), v) for v in value]
else:
raise ValueError(
f"Expected str, dict or iterable for Message, got {type(value)}"
)
return cls(value)
@staticmethod
@abc.abstractmethod
def _construct(msg: str) -> Iterable[TMS]:
"""构造消息数组"""
raise NotImplementedError
def __add__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM:
result = self.copy()
result += other
return result
def __radd__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM:
result = self.__class__(other)
return result + self
def __iadd__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM:
if isinstance(other, str):
self.extend(self._construct(other))
elif isinstance(other, MessageSegment):
self.append(other)
elif isinstance(other, Iterable):
self.extend(other)
else:
raise ValueError(f"Unsupported type: {type(other)}") # pragma: no cover
return self
@overload
def __getitem__(self: TM, __args: str) -> TM:
"""
参数:
__args: 消息段类型
返回:
所有类型为 `__args` 的消息段
"""
@overload
def __getitem__(self, __args: Tuple[str, int]) -> TMS:
"""
参数:
__args: 消息段类型和索引
返回:
类型为 `__args[0]` 的消息段第 `__args[1]` 个
"""
@overload
def __getitem__(self: TM, __args: Tuple[str, slice]) -> TM:
"""
参数:
__args: 消息段类型和切片
返回:
类型为 `__args[0]` 的消息段切片 `__args[1]`
"""
@overload
def __getitem__(self, __args: int) -> TMS:
"""
参数:
__args: 索引
返回:
第 `__args` 个消息段
"""
@overload
def __getitem__(self: TM, __args: slice) -> TM:
"""
参数:
__args: 切片
返回:
消息切片 `__args`
"""
def __getitem__(
self: TM,
args: Union[
str,
Tuple[str, int],
Tuple[str, slice],
int,
slice,
],
) -> Union[TMS, TM]:
arg1, arg2 = args if isinstance(args, tuple) else (args, None)
if isinstance(arg1, int) and arg2 is None:
return super().__getitem__(arg1)
elif isinstance(arg1, slice) and arg2 is None:
return self.__class__(super().__getitem__(arg1))
elif isinstance(arg1, str) and arg2 is None:
return self.__class__(seg for seg in self if seg.type == arg1)
elif isinstance(arg1, str) and isinstance(arg2, int):
return [seg for seg in self if seg.type == arg1][arg2]
elif isinstance(arg1, str) and isinstance(arg2, slice):
return self.__class__([seg for seg in self if seg.type == arg1][arg2])
else:
raise ValueError("Incorrect arguments to slice") # pragma: no cover
def index(self, value: Union[TMS, str], *args) -> int:
if isinstance(value, str):
first_segment = next((seg for seg in self if seg.type == value), None)
if first_segment is None:
raise ValueError(f"Segment with type {value} is not in message")
return super().index(first_segment, *args)
return super().index(value, *args)
def get(self: TM, type_: str, count: Optional[int] = None) -> TM:
if count is None:
return self[type_]
iterator, filtered = (
seg for seg in self if seg.type == type_
), self.__class__()
for _ in range(count):
seg = next(iterator, None)
if seg is None:
break
filtered.append(seg)
return filtered
def count(self, value: Union[TMS, str]) -> int:
return len(self[value]) if isinstance(value, str) else super().count(value)
def append(self: TM, obj: Union[str, TMS]) -> TM:
"""添加一个消息段到消息数组末尾。
参数:
obj: 要添加的消息段
"""
if isinstance(obj, MessageSegment):
super().append(obj)
elif isinstance(obj, str):
self.extend(self._construct(obj))
else:
raise ValueError(f"Unexpected type: {type(obj)} {obj}") # pragma: no cover
return self
def extend(self: TM, obj: Union[TM, Iterable[TMS]]) -> TM:
"""拼接一个消息数组或多个消息段到消息数组末尾。
参数:
obj: 要添加的消息数组
"""
for segment in obj:
self.append(segment)
return self
def copy(self: TM) -> TM:
return deepcopy(self)
def extract_plain_text(self) -> str:
"""提取消息内纯文本消息"""
return "".join(str(seg) for seg in self if seg.is_text())
__autodoc__ = {
"MessageSegment.__str__": True,
"MessageSegment.__add__": True,
"Message.__getitem__": True,
"Message._construct": True,
}

338
nonebot/internal/model.py Normal file
View File

@ -0,0 +1,338 @@
import abc
from enum import Enum
from dataclasses import dataclass
from http.cookiejar import Cookie, CookieJar
from typing import (
IO,
Any,
Dict,
List,
Tuple,
Union,
Mapping,
Callable,
Iterator,
Optional,
Awaitable,
MutableMapping,
)
from yarl import URL as URL
from multidict import CIMultiDict
RawURL = Tuple[bytes, bytes, Optional[int], bytes]
SimpleQuery = Union[str, int, float]
QueryVariable = Union[SimpleQuery, List[SimpleQuery]]
QueryTypes = Union[
None, str, Mapping[str, QueryVariable], List[Tuple[str, QueryVariable]]
]
HeaderTypes = Union[
None,
CIMultiDict[str],
Dict[str, str],
List[Tuple[str, str]],
]
CookieTypes = Union[None, "Cookies", CookieJar, Dict[str, str], List[Tuple[str, str]]]
ContentTypes = Union[str, bytes, None]
DataTypes = Union[dict, None]
FileContent = Union[IO[bytes], bytes]
FileType = Tuple[Optional[str], FileContent, Optional[str]]
FileTypes = Union[
# file (or bytes)
FileContent,
# (filename, file (or bytes))
Tuple[Optional[str], FileContent],
# (filename, file (or bytes), content_type)
FileType,
]
FilesTypes = Union[Dict[str, FileTypes], List[Tuple[str, FileTypes]], None]
class HTTPVersion(Enum):
H10 = "1.0"
H11 = "1.1"
H2 = "2"
class Request:
def __init__(
self,
method: Union[str, bytes],
url: Union["URL", str, RawURL],
*,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
content: ContentTypes = None,
data: DataTypes = None,
json: Any = None,
files: FilesTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
):
# method
self.method: str = (
method.decode("ascii").upper()
if isinstance(method, bytes)
else method.upper()
)
# http version
self.version: HTTPVersion = HTTPVersion(version)
# timeout
self.timeout: Optional[float] = timeout
# proxy
self.proxy: Optional[str] = proxy
# url
if isinstance(url, tuple):
scheme, host, port, path = url
url = URL.build(
scheme=scheme.decode("ascii"),
host=host.decode("ascii"),
port=port,
path=path.decode("ascii"),
)
else:
url = URL(url)
if params is not None:
url = url.update_query(params)
self.url: URL = url
# headers
self.headers: CIMultiDict[str]
if headers is not None:
self.headers = CIMultiDict(headers)
else:
self.headers = CIMultiDict()
# cookies
self.cookies = Cookies(cookies)
# body
self.content: ContentTypes = content
self.data: DataTypes = data
self.json: Any = json
self.files: Optional[List[Tuple[str, FileType]]] = None
if files:
self.files = []
files_ = files.items() if isinstance(files, dict) else files
for name, file_info in files_:
if not isinstance(file_info, tuple):
self.files.append((name, (None, file_info, None)))
elif len(file_info) == 2:
self.files.append((name, (file_info[0], file_info[1], None)))
else:
self.files.append((name, file_info)) # type: ignore
def __repr__(self) -> str:
class_name = self.__class__.__name__
url = str(self.url)
return f"<{class_name}({self.method!r}, {url!r})>"
class Response:
def __init__(
self,
status_code: int,
*,
headers: HeaderTypes = None,
content: ContentTypes = None,
request: Optional[Request] = None,
):
# status code
self.status_code: int = status_code
# headers
self.headers: CIMultiDict[str]
if headers is not None:
self.headers = CIMultiDict(headers)
else:
self.headers = CIMultiDict()
# body
self.content: ContentTypes = content
# request
self.request: Optional[Request] = request
class WebSocket(abc.ABC):
def __init__(self, *, request: Request):
# request
self.request: Request = request
@property
@abc.abstractmethod
def closed(self) -> bool:
"""
连接是否已经关闭
"""
raise NotImplementedError
@abc.abstractmethod
async def accept(self) -> None:
"""接受 WebSocket 连接请求"""
raise NotImplementedError
@abc.abstractmethod
async def close(self, code: int = 1000, reason: str = "") -> None:
"""关闭 WebSocket 连接请求"""
raise NotImplementedError
@abc.abstractmethod
async def receive(self) -> str:
"""接收一条 WebSocket text 信息"""
raise NotImplementedError
@abc.abstractmethod
async def receive_bytes(self) -> bytes:
"""接收一条 WebSocket binary 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send(self, data: str) -> None:
"""发送一条 WebSocket text 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send_bytes(self, data: bytes) -> None:
"""发送一条 WebSocket binary 信息"""
raise NotImplementedError
class Cookies(MutableMapping):
def __init__(self, cookies: CookieTypes = None) -> None:
self.jar: CookieJar = cookies if isinstance(cookies, CookieJar) else CookieJar()
if cookies is not None and not isinstance(cookies, CookieJar):
if isinstance(cookies, dict):
for key, value in cookies.items():
self.set(key, value)
elif isinstance(cookies, list):
for key, value in cookies:
self.set(key, value)
elif isinstance(cookies, Cookies):
for cookie in cookies.jar:
self.jar.set_cookie(cookie)
else:
raise TypeError(f"Cookies must be dict or list, not {type(cookies)}")
def set(self, name: str, value: str, domain: str = "", path: str = "/") -> None:
cookie = Cookie(
version=0,
name=name,
value=value,
port=None,
port_specified=False,
domain=domain,
domain_specified=bool(domain),
domain_initial_dot=domain.startswith("."),
path=path,
path_specified=bool(path),
secure=False,
expires=None,
discard=True,
comment=None,
comment_url=None,
rest={},
rfc2109=False,
)
self.jar.set_cookie(cookie)
def get(
self,
name: str,
default: Optional[str] = None,
domain: str = None,
path: str = None,
) -> Optional[str]:
value: Optional[str] = None
for cookie in self.jar:
if (
cookie.name == name
and (domain is None or cookie.domain == domain)
and (path is None or cookie.path == path)
):
if value is not None:
message = f"Multiple cookies exist with name={name}"
raise ValueError(message)
value = cookie.value
return default if value is None else value
def delete(
self, name: str, domain: Optional[str] = None, path: Optional[str] = None
) -> None:
if domain is not None and path is not None:
return self.jar.clear(domain, path, name)
remove = [
cookie
for cookie in self.jar
if cookie.name == name
and (domain is None or cookie.domain == domain)
and (path is None or cookie.path == path)
]
for cookie in remove:
self.jar.clear(cookie.domain, cookie.path, cookie.name)
def clear(self, domain: Optional[str] = None, path: Optional[str] = None) -> None:
self.jar.clear(domain, path)
def update(self, cookies: CookieTypes = None) -> None:
cookies = Cookies(cookies)
for cookie in cookies.jar:
self.jar.set_cookie(cookie)
def __setitem__(self, name: str, value: str) -> None:
return self.set(name, value)
def __getitem__(self, name: str) -> str:
value = self.get(name)
if value is None:
raise KeyError(name)
return value
def __delitem__(self, name: str) -> None:
return self.delete(name)
def __len__(self) -> int:
return len(self.jar)
def __iter__(self) -> Iterator[Cookie]:
return (cookie for cookie in self.jar)
def __repr__(self) -> str:
cookies_repr = ", ".join(
[
f"<Cookie {cookie.name}={cookie.value} for {cookie.domain} />"
for cookie in self.jar
]
)
return f"<Cookies [{cookies_repr}]>"
@dataclass
class HTTPServerSetup:
"""HTTP 服务器路由配置。"""
path: URL # path should not be absolute, check it by URL.is_absolute() == False
method: str
name: str
handle_func: Callable[[Request], Awaitable[Response]]
@dataclass
class WebSocketServerSetup:
"""WebSocket 服务器路由配置。"""
path: URL # path should not be absolute, check it by URL.is_absolute() == False
name: str
handle_func: Callable[[WebSocket], Awaitable[Any]]

378
nonebot/internal/params.py Normal file
View File

@ -0,0 +1,378 @@
import asyncio
import inspect
import warnings
from typing_extensions import Literal
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from pydantic.fields import Required, Undefined, ModelField
from nonebot.log import logger
from nonebot.exception import TypeMisMatch
from nonebot.dependencies.utils import check_field_type
from nonebot.dependencies import Param, Dependent, CustomConfig
from nonebot.typing import T_State, T_Handler, T_DependencyCache
from nonebot.utils import (
get_name,
run_sync,
is_gen_callable,
run_sync_ctx_manager,
is_async_gen_callable,
is_coroutine_callable,
generic_check_issubclass,
)
if TYPE_CHECKING:
from nonebot.matcher import Matcher
from nonebot.adapters import Bot, Event
class DependsInner:
def __init__(
self,
dependency: Optional[T_Handler] = None,
*,
use_cache: bool = True,
) -> None:
self.dependency = dependency
self.use_cache = use_cache
def __repr__(self) -> str:
dep = get_name(self.dependency)
cache = "" if self.use_cache else ", use_cache=False"
return f"{self.__class__.__name__}({dep}{cache})"
def Depends(
dependency: Optional[T_Handler] = None,
*,
use_cache: bool = True,
) -> Any:
"""子依赖装饰器
参数:
dependency: 依赖函数。默认为参数的类型注释。
use_cache: 是否使用缓存。默认为 `True`。
用法:
```python
def depend_func() -> Any:
return ...
def depend_gen_func():
try:
yield ...
finally:
...
async def handler(param_name: Any = Depends(depend_func), gen: Any = Depends(depend_gen_func)):
...
```
"""
return DependsInner(dependency, use_cache=use_cache)
class DependParam(Param):
"""子依赖参数"""
@classmethod
def _check_param(
cls,
dependent: Dependent,
name: str,
param: inspect.Parameter,
) -> Optional["DependParam"]:
if isinstance(param.default, DependsInner):
dependency: T_Handler
if param.default.dependency is None:
assert param.annotation is not param.empty, "Dependency cannot be empty"
dependency = param.annotation
else:
dependency = param.default.dependency
sub_dependent = Dependent[Any].parse(
call=dependency,
allow_types=dependent.allow_types,
)
dependent.pre_checkers.extend(sub_dependent.pre_checkers)
sub_dependent.pre_checkers.clear()
return cls(
Required, use_cache=param.default.use_cache, dependent=sub_dependent
)
@classmethod
def _check_parameterless(
cls, dependent: "Dependent", value: Any
) -> Optional["Param"]:
if isinstance(value, DependsInner):
assert value.dependency, "Dependency cannot be empty"
dependent = Dependent[Any].parse(
call=value.dependency, allow_types=dependent.allow_types
)
return cls(Required, use_cache=value.use_cache, dependent=dependent)
async def _solve(
self,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
**kwargs: Any,
) -> Any:
use_cache: bool = self.extra["use_cache"]
dependency_cache = {} if dependency_cache is None else dependency_cache
sub_dependent: Dependent = self.extra["dependent"]
sub_dependent.call = cast(Callable[..., Any], sub_dependent.call)
call = sub_dependent.call
# solve sub dependency with current cache
sub_values = await sub_dependent.solve(
stack=stack,
dependency_cache=dependency_cache,
**kwargs,
)
# run dependency function
task: asyncio.Task[Any]
if use_cache and call in dependency_cache:
solved = await dependency_cache[call]
elif is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance(
stack, AsyncExitStack
), "Generator dependency should be called in context"
if is_gen_callable(call):
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
else:
cm = asynccontextmanager(call)(**sub_values)
task = asyncio.create_task(stack.enter_async_context(cm))
dependency_cache[call] = task
solved = await task
elif is_coroutine_callable(call):
task = asyncio.create_task(call(**sub_values))
dependency_cache[call] = task
solved = await task
else:
task = asyncio.create_task(run_sync(call)(**sub_values))
dependency_cache[call] = task
solved = await task
return solved
class _BotChecker(Param):
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
field: ModelField = self.extra["field"]
try:
return check_field_type(field, bot)
except TypeMisMatch:
logger.debug(
f"Bot type {type(bot)} not match "
f"annotation {field._type_display()}, ignored"
)
raise
class BotParam(Param):
"""{ref}`nonebot.adapters.Bot` 参数"""
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["BotParam"]:
from nonebot.adapters import Bot
if param.default == param.empty:
if generic_check_issubclass(param.annotation, Bot):
if param.annotation is not Bot:
dependent.pre_checkers.append(
_BotChecker(
Required,
field=ModelField(
name=name,
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
),
)
)
return cls(Required)
elif param.annotation == param.empty and name == "bot":
return cls(Required)
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
return bot
class _EventChecker(Param):
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
field: ModelField = self.extra["field"]
try:
return check_field_type(field, event)
except TypeMisMatch:
logger.debug(
f"Event type {type(event)} not match "
f"annotation {field._type_display()}, ignored"
)
raise
class EventParam(Param):
"""{ref}`nonebot.adapters.Event` 参数"""
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["EventParam"]:
from nonebot.adapters import Event
if param.default == param.empty:
if generic_check_issubclass(param.annotation, Event):
if param.annotation is not Event:
dependent.pre_checkers.append(
_EventChecker(
Required,
field=ModelField(
name=name,
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
),
)
)
return cls(Required)
elif param.annotation == param.empty and name == "event":
return cls(Required)
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
return event
class StateInner(T_State):
...
def State() -> T_State:
"""**Deprecated**: 事件处理状态参数,请直接使用 {ref}`nonebot.typing.T_State`"""
warnings.warn("State() is deprecated, use `T_State` instead", DeprecationWarning)
return StateInner()
class StateParam(Param):
"""事件处理状态参数"""
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["StateParam"]:
if isinstance(param.default, StateInner):
return cls(Required)
elif param.default == param.empty:
if param.annotation is T_State:
return cls(Required)
elif param.annotation == param.empty and name == "state":
return cls(Required)
async def _solve(self, state: T_State, **kwargs: Any) -> Any:
return state
class MatcherParam(Param):
"""事件响应器实例参数"""
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["MatcherParam"]:
from nonebot.matcher import Matcher
if generic_check_issubclass(param.annotation, Matcher) or (
param.annotation == param.empty and name == "matcher"
):
return cls(Required)
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
return matcher
class ArgInner:
def __init__(
self, key: Optional[str], type: Literal["message", "str", "plaintext"]
) -> None:
self.key = key
self.type = type
def Arg(key: Optional[str] = None) -> Any:
"""`got` 的 Arg 参数消息"""
return ArgInner(key, "message")
def ArgStr(key: Optional[str] = None) -> str:
"""`got` 的 Arg 参数消息文本"""
return ArgInner(key, "str") # type: ignore
def ArgPlainText(key: Optional[str] = None) -> str:
"""`got` 的 Arg 参数消息纯文本"""
return ArgInner(key, "plaintext") # type: ignore
class ArgParam(Param):
"""`got` 的 Arg 参数"""
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["ArgParam"]:
if isinstance(param.default, ArgInner):
return cls(Required, key=param.default.key or name, type=param.default.type)
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
message = matcher.get_arg(self.extra["key"])
if message is None:
return message
if self.extra["type"] == "message":
return message
elif self.extra["type"] == "str":
return str(message)
else:
return message.extract_plain_text()
class ExceptionParam(Param):
"""`run_postprocessor` 的异常参数"""
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["ExceptionParam"]:
if generic_check_issubclass(param.annotation, Exception) or (
param.annotation == param.empty and name == "exception"
):
return cls(Required)
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
return exception
class DefaultParam(Param):
"""默认值参数"""
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["DefaultParam"]:
if param.default != param.empty:
return cls(param.default)
async def _solve(self, **kwargs: Any) -> Any:
return Undefined
__autodoc__ = {
"DependsInner": False,
"StateInner": False,
"ArgInner": False,
}

View File

@ -0,0 +1,133 @@
import asyncio
from contextlib import AsyncExitStack
from typing import Any, Set, Tuple, Union, NoReturn, Optional, Coroutine
from nonebot.adapters import Bot, Event
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.typing import T_DependencyCache, T_PermissionChecker
from .params import BotParam, EventParam, DependParam, DefaultParam
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]):
try:
return await coro
except SkippedException:
return False
class Permission:
"""{ref}`nonebot.matcher.Matcher` 权限类。
当事件传递时,在 {ref}`nonebot.matcher.Matcher` 运行前进行检查。
参数:
checkers: PermissionChecker
用法:
```python
Permission(async_function) | sync_function
# 等价于
Permission(async_function, sync_function)
```
"""
__slots__ = ("checkers",)
HANDLER_PARAM_TYPES = [
DependParam,
BotParam,
EventParam,
DefaultParam,
]
def __init__(self, *checkers: Union[T_PermissionChecker, Dependent[bool]]) -> None:
self.checkers: Set[Dependent[bool]] = set(
checker
if isinstance(checker, Dependent)
else Dependent[bool].parse(
call=checker, allow_types=self.HANDLER_PARAM_TYPES
)
for checker in checkers
)
"""存储 `PermissionChecker`"""
async def __call__(
self,
bot: Bot,
event: Event,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
) -> bool:
"""检查是否满足某个权限
参数:
bot: Bot 对象
event: Event 对象
stack: 异步上下文栈
dependency_cache: 依赖缓存
"""
if not self.checkers:
return True
results = await asyncio.gather(
*(
_run_coro_with_catch(
checker(
bot=bot,
event=event,
stack=stack,
dependency_cache=dependency_cache,
)
)
for checker in self.checkers
),
)
return any(results)
def __and__(self, other) -> NoReturn:
raise RuntimeError("And operation between Permissions is not allowed.")
def __or__(
self, other: Optional[Union["Permission", T_PermissionChecker]]
) -> "Permission":
if other is None:
return self
elif isinstance(other, Permission):
return Permission(*self.checkers, *other.checkers)
else:
return Permission(*self.checkers, other)
class User:
"""检查当前事件是否属于指定会话
参数:
users: 会话 ID 元组
perm: 需同时满足的权限
"""
__slots__ = ("users", "perm")
def __init__(
self, users: Tuple[str, ...], perm: Optional[Permission] = None
) -> None:
self.users = users
self.perm = perm
async def __call__(self, bot: Bot, event: Event) -> bool:
return bool(
event.get_session_id() in self.users
and (self.perm is None or await self.perm(bot, event))
)
def USER(*users: str, perm: Optional[Permission] = None):
"""匹配当前事件属于指定会话
参数:
user: 会话白名单
perm: 需要同时满足的权限
"""
return Permission(User(users, perm))

95
nonebot/internal/rule.py Normal file
View File

@ -0,0 +1,95 @@
import asyncio
from contextlib import AsyncExitStack
from typing import Set, Union, NoReturn, Optional
from nonebot.adapters import Bot, Event
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.typing import T_State, T_RuleChecker, T_DependencyCache
from .params import BotParam, EventParam, StateParam, DependParam, DefaultParam
class Rule:
"""{ref}`nonebot.matcher.Matcher` 规则类。
当事件传递时,在 {ref}`nonebot.matcher.Matcher` 运行前进行检查。
参数:
*checkers: RuleChecker
用法:
```python
Rule(async_function) & sync_function
# 等价于
Rule(async_function, sync_function)
```
"""
__slots__ = ("checkers",)
HANDLER_PARAM_TYPES = [
DependParam,
BotParam,
EventParam,
StateParam,
DefaultParam,
]
def __init__(self, *checkers: Union[T_RuleChecker, Dependent[bool]]) -> None:
self.checkers: Set[Dependent[bool]] = set(
checker
if isinstance(checker, Dependent)
else Dependent[bool].parse(
call=checker, allow_types=self.HANDLER_PARAM_TYPES
)
for checker in checkers
)
"""存储 `RuleChecker`"""
async def __call__(
self,
bot: Bot,
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
) -> bool:
"""检查是否符合所有规则
参数:
bot: Bot 对象
event: Event 对象
state: 当前 State
stack: 异步上下文栈
dependency_cache: 依赖缓存
"""
if not self.checkers:
return True
try:
results = await asyncio.gather(
*(
checker(
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
)
for checker in self.checkers
)
)
except SkippedException:
return False
return all(results)
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
if other is None:
return self
elif isinstance(other, Rule):
return Rule(*self.checkers, *other.checkers)
else:
return Rule(*self.checkers, other)
def __or__(self, other) -> NoReturn:
raise RuntimeError("Or operation between rules is not allowed.")

View File

@ -0,0 +1,181 @@
import inspect
import functools
from string import Formatter
from typing import (
TYPE_CHECKING,
Any,
Set,
Dict,
List,
Type,
Tuple,
Union,
Generic,
Mapping,
TypeVar,
Callable,
Optional,
Sequence,
cast,
overload,
)
if TYPE_CHECKING:
from .message import Message, MessageSegment
TM = TypeVar("TM", bound="Message")
TF = TypeVar("TF", str, "Message")
FormatSpecFunc = Callable[[Any], str]
FormatSpecFunc_T = TypeVar("FormatSpecFunc_T", bound=FormatSpecFunc)
class MessageTemplate(Formatter, Generic[TF]):
"""消息模板格式化实现类。
参数:
template: 模板
factory: 消息构造类型,默认为 `str`
"""
@overload
def __init__(
self: "MessageTemplate[str]", template: str, factory: Type[str] = str
) -> None:
...
@overload
def __init__(
self: "MessageTemplate[TM]", template: Union[str, TM], factory: Type[TM]
) -> None:
...
def __init__(self, template, factory=str) -> None:
self.template: TF = template
self.factory: Type[TF] = factory
self.format_specs: Dict[str, FormatSpecFunc] = {}
def add_format_spec(
self, spec: FormatSpecFunc_T, name: Optional[str] = None
) -> FormatSpecFunc_T:
name = name or spec.__name__
if name in self.format_specs:
raise ValueError(f"Format spec {name} already exists!")
self.format_specs[name] = spec
return spec
def format(self, *args: Any, **kwargs: Any) -> TF:
"""根据模板和参数生成消息对象"""
msg = self.factory()
if isinstance(self.template, str):
msg += self.vformat(self.template, args, kwargs)
elif isinstance(self.template, self.factory):
template = cast("Message[MessageSegment]", self.template)
for seg in template:
msg += self.vformat(str(seg), args, kwargs) if seg.is_text() else seg
else:
raise TypeError("template must be a string or instance of Message!")
return msg # type:ignore
def vformat(
self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]
) -> TF:
used_args = set()
result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
self.check_unused_args(list(used_args), args, kwargs)
return result
def _vformat(
self,
format_string: str,
args: Sequence[Any],
kwargs: Mapping[str, Any],
used_args: Set[Union[int, str]],
recursion_depth: int,
auto_arg_index: int = 0,
) -> Tuple[TF, int]:
if recursion_depth < 0:
raise ValueError("Max string recursion exceeded")
results: List[Any] = []
for (literal_text, field_name, format_spec, conversion) in self.parse(
format_string
):
# output the literal text
if literal_text:
results.append(literal_text)
# if there's a field, output it
if field_name is not None:
# this is some markup, find the object and do
# the formatting
# handle arg indexing when empty field_names are given.
if field_name == "":
if auto_arg_index is False:
raise ValueError(
"cannot switch from manual field specification to "
"automatic field numbering"
)
field_name = str(auto_arg_index)
auto_arg_index += 1
elif field_name.isdigit():
if auto_arg_index:
raise ValueError(
"cannot switch from manual field specification to "
"automatic field numbering"
)
# disable auto arg incrementing, if it gets
# used later on, then an exception will be raised
auto_arg_index = False
# given the field_name, find the object it references
# and the argument it came from
obj, arg_used = self.get_field(field_name, args, kwargs)
used_args.add(arg_used)
assert format_spec is not None
# do any conversion on the resulting object
obj = self.convert_field(obj, conversion) if conversion else obj
# expand the format spec, if needed
format_control, auto_arg_index = self._vformat(
format_spec,
args,
kwargs,
used_args,
recursion_depth - 1,
auto_arg_index,
)
# format the object and append to the result
formatted_text = self.format_field(obj, str(format_control))
results.append(formatted_text)
return (
self.factory(functools.reduce(self._add, results or [""])),
auto_arg_index,
)
def format_field(self, value: Any, format_spec: str) -> Any:
formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec)
if formatter is None and not issubclass(self.factory, str):
segment_class: Type["MessageSegment"] = self.factory.get_segment_class()
method = getattr(segment_class, format_spec, None)
if inspect.ismethod(method):
formatter = getattr(segment_class, format_spec)
return (
super().format_field(value, format_spec)
if formatter is None
else formatter(value)
)
def _add(self, a: Any, b: Any) -> Any:
try:
return a + b
except TypeError:
return a + str(b)