mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-09-06 12:06:45 +00:00
♻️ reorganize internal tree
This commit is contained in:
6
nonebot/internal/adapter/__init__.py
Normal file
6
nonebot/internal/adapter/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from .bot import Bot as Bot
|
||||
from .event import Event as Event
|
||||
from .adapter import Adapter as Adapter
|
||||
from .message import Message as Message
|
||||
from .message import MessageSegment as MessageSegment
|
||||
from .template import MessageTemplate as MessageTemplate
|
106
nonebot/internal/adapter/adapter.py
Normal file
106
nonebot/internal/adapter/adapter.py
Normal file
@ -0,0 +1,106 @@
|
||||
import abc
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, AsyncGenerator
|
||||
|
||||
from nonebot.config import Config
|
||||
from nonebot.internal.driver import (
|
||||
Driver,
|
||||
Request,
|
||||
Response,
|
||||
WebSocket,
|
||||
ForwardDriver,
|
||||
ReverseDriver,
|
||||
HTTPServerSetup,
|
||||
WebSocketServerSetup,
|
||||
)
|
||||
|
||||
from .bot import Bot
|
||||
|
||||
|
||||
class Adapter(abc.ABC):
|
||||
"""协议适配器基类。
|
||||
|
||||
通常,在 Adapter 中编写协议通信相关代码,如: 建立通信连接、处理接收与发送 data 等。
|
||||
|
||||
参数:
|
||||
driver: {ref}`nonebot.drivers.Driver` 实例
|
||||
kwargs: 其他由 {ref}`nonebot.drivers.Driver.register_adapter` 传入的额外参数
|
||||
"""
|
||||
|
||||
def __init__(self, driver: Driver, **kwargs: Any):
|
||||
self.driver: Driver = driver
|
||||
"""{ref}`nonebot.drivers.Driver` 实例"""
|
||||
self.bots: Dict[str, Bot] = {}
|
||||
"""本协议适配器已建立连接的 {ref}`nonebot.adapters.Bot` 实例"""
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def get_name(cls) -> str:
|
||||
"""当前协议适配器的名称"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def config(self) -> Config:
|
||||
"""全局 NoneBot 配置"""
|
||||
return self.driver.config
|
||||
|
||||
def bot_connect(self, bot: Bot) -> None:
|
||||
"""告知 NoneBot 建立了一个新的 {ref}`nonebot.adapters.Bot` 连接。
|
||||
|
||||
当有新的 {ref}`nonebot.adapters.Bot` 实例连接建立成功时调用。
|
||||
|
||||
参数:
|
||||
bot: {ref}`nonebot.adapters.Bot` 实例
|
||||
"""
|
||||
self.driver._bot_connect(bot)
|
||||
self.bots[bot.self_id] = bot
|
||||
|
||||
def bot_disconnect(self, bot: Bot) -> None:
|
||||
"""告知 NoneBot {ref}`nonebot.adapters.Bot` 连接已断开。
|
||||
|
||||
当有 {ref}`nonebot.adapters.Bot` 实例连接断开时调用。
|
||||
|
||||
参数:
|
||||
bot: {ref}`nonebot.adapters.Bot` 实例
|
||||
"""
|
||||
self.driver._bot_disconnect(bot)
|
||||
self.bots.pop(bot.self_id, None)
|
||||
|
||||
def setup_http_server(self, setup: HTTPServerSetup):
|
||||
"""设置一个 HTTP 服务器路由配置"""
|
||||
if not isinstance(self.driver, ReverseDriver):
|
||||
raise TypeError("Current driver does not support http server")
|
||||
self.driver.setup_http_server(setup)
|
||||
|
||||
def setup_websocket_server(self, setup: WebSocketServerSetup):
|
||||
"""设置一个 WebSocket 服务器路由配置"""
|
||||
if not isinstance(self.driver, ReverseDriver):
|
||||
raise TypeError("Current driver does not support websocket server")
|
||||
self.driver.setup_websocket_server(setup)
|
||||
|
||||
async def request(self, setup: Request) -> Response:
|
||||
"""进行一个 HTTP 客户端请求"""
|
||||
if not isinstance(self.driver, ForwardDriver):
|
||||
raise TypeError("Current driver does not support http client")
|
||||
return await self.driver.request(setup)
|
||||
|
||||
@asynccontextmanager
|
||||
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
|
||||
"""建立一个 WebSocket 客户端连接请求"""
|
||||
if not isinstance(self.driver, ForwardDriver):
|
||||
raise TypeError("Current driver does not support websocket client")
|
||||
async with self.driver.websocket(setup) as ws:
|
||||
yield ws
|
||||
|
||||
@abc.abstractmethod
|
||||
async def _call_api(self, bot: Bot, api: str, **data: Any) -> Any:
|
||||
"""`Adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
|
||||
|
||||
参数:
|
||||
api: API 名称
|
||||
data: API 数据
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
__autodoc__ = {"Adapter._call_api": True}
|
162
nonebot/internal/adapter/bot.py
Normal file
162
nonebot/internal/adapter/bot.py
Normal file
@ -0,0 +1,162 @@
|
||||
import abc
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from typing_extensions import Protocol
|
||||
from typing import TYPE_CHECKING, Any, Set, Union, Optional
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.config import Config
|
||||
from nonebot.exception import MockApiException
|
||||
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .event import Event
|
||||
from .adapter import Adapter
|
||||
from .message import Message, MessageSegment
|
||||
|
||||
|
||||
class _ApiCall(Protocol):
|
||||
async def __call__(self, **kwargs: Any) -> Any:
|
||||
...
|
||||
|
||||
|
||||
class Bot(abc.ABC):
|
||||
"""Bot 基类。
|
||||
|
||||
用于处理上报消息,并提供 API 调用接口。
|
||||
|
||||
参数:
|
||||
adapter: 协议适配器实例
|
||||
self_id: 机器人 ID
|
||||
"""
|
||||
|
||||
_calling_api_hook: Set[T_CallingAPIHook] = set()
|
||||
"""call_api 时执行的函数"""
|
||||
_called_api_hook: Set[T_CalledAPIHook] = set()
|
||||
"""call_api 后执行的函数"""
|
||||
|
||||
def __init__(self, adapter: "Adapter", self_id: str):
|
||||
self.adapter: "Adapter" = adapter
|
||||
"""协议适配器实例"""
|
||||
self.self_id: str = self_id
|
||||
"""机器人 ID"""
|
||||
|
||||
def __getattr__(self, name: str) -> _ApiCall:
|
||||
return partial(self.call_api, name)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""协议适配器名称"""
|
||||
return self.adapter.get_name()
|
||||
|
||||
@property
|
||||
def config(self) -> Config:
|
||||
"""全局 NoneBot 配置"""
|
||||
return self.adapter.config
|
||||
|
||||
async def call_api(self, api: str, **data: Any) -> Any:
|
||||
"""调用机器人 API 接口,可以通过该函数或直接通过 bot 属性进行调用
|
||||
|
||||
参数:
|
||||
api: API 名称
|
||||
data: API 数据
|
||||
|
||||
用法:
|
||||
```python
|
||||
await bot.call_api("send_msg", message="hello world")
|
||||
await bot.send_msg(message="hello world")
|
||||
```
|
||||
"""
|
||||
|
||||
result: Any = None
|
||||
skip_calling_api: bool = False
|
||||
exception: Optional[Exception] = None
|
||||
|
||||
coros = list(map(lambda x: x(self, api, data), self._calling_api_hook))
|
||||
if coros:
|
||||
try:
|
||||
logger.debug("Running CallingAPI hooks...")
|
||||
await asyncio.gather(*coros)
|
||||
except MockApiException as e:
|
||||
skip_calling_api = True
|
||||
result = e.result
|
||||
logger.debug(
|
||||
f"Calling API {api} is cancelled. Return {result} instead."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
|
||||
"Running cancelled!</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
if not skip_calling_api:
|
||||
try:
|
||||
result = await self.adapter._call_api(self, api, **data)
|
||||
except Exception as e:
|
||||
exception = e
|
||||
|
||||
coros = list(
|
||||
map(lambda x: x(self, exception, api, data, result), self._called_api_hook)
|
||||
)
|
||||
if coros:
|
||||
try:
|
||||
logger.debug("Running CalledAPI hooks...")
|
||||
await asyncio.gather(*coros)
|
||||
except MockApiException as e:
|
||||
result = e.result
|
||||
logger.debug(
|
||||
f"Calling API {api} result is mocked. Return {result} instead."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running CalledAPI hook. "
|
||||
"Running cancelled!</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
if exception:
|
||||
raise exception
|
||||
return result
|
||||
|
||||
@abc.abstractmethod
|
||||
async def send(
|
||||
self,
|
||||
event: "Event",
|
||||
message: Union[str, "Message", "MessageSegment"],
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""调用机器人基础发送消息接口
|
||||
|
||||
参数:
|
||||
event: 上报事件
|
||||
message: 要发送的消息
|
||||
kwargs: 任意额外参数
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook:
|
||||
"""调用 api 预处理。
|
||||
|
||||
钩子函数参数:
|
||||
|
||||
- bot: 当前 bot 对象
|
||||
- api: 调用的 api 名称
|
||||
- data: api 调用的参数字典
|
||||
"""
|
||||
cls._calling_api_hook.add(func)
|
||||
return func
|
||||
|
||||
@classmethod
|
||||
def on_called_api(cls, func: T_CalledAPIHook) -> T_CalledAPIHook:
|
||||
"""调用 api 后处理。
|
||||
|
||||
钩子函数参数:
|
||||
|
||||
- bot: 当前 bot 对象
|
||||
- exception: 调用 api 时发生的错误
|
||||
- api: 调用的 api 名称
|
||||
- data: api 调用的参数字典
|
||||
- result: api 调用的返回
|
||||
"""
|
||||
cls._called_api_hook.add(func)
|
||||
return func
|
70
nonebot/internal/adapter/event.py
Normal file
70
nonebot/internal/adapter/event.py
Normal file
@ -0,0 +1,70 @@
|
||||
import abc
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from nonebot.utils import DataclassEncoder
|
||||
|
||||
from .message import Message
|
||||
|
||||
|
||||
class Event(abc.ABC, BaseModel):
|
||||
"""Event 基类。提供获取关键信息的方法,其余信息可直接获取。"""
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
json_encoders = {Message: DataclassEncoder}
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_type(self) -> str:
|
||||
"""获取事件类型的方法,类型通常为 NoneBot 内置的四种类型。"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_event_name(self) -> str:
|
||||
"""获取事件名称的方法。"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_event_description(self) -> str:
|
||||
"""获取事件描述的方法,通常为事件具体内容。"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"[{self.get_event_name()}]: {self.get_event_description()}"
|
||||
|
||||
def get_log_string(self) -> str:
|
||||
"""获取事件日志信息的方法。
|
||||
|
||||
通常你不需要修改这个方法,只有当希望 NoneBot 隐藏该事件日志时,可以抛出 `NoLogException` 异常。
|
||||
|
||||
异常:
|
||||
NoLogException
|
||||
"""
|
||||
return f"[{self.get_event_name()}]: {self.get_event_description()}"
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_user_id(self) -> str:
|
||||
"""获取事件主体 id 的方法,通常是用户 id 。"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_session_id(self) -> str:
|
||||
"""获取会话 id 的方法,用于判断当前事件属于哪一个会话,通常是用户 id、群组 id 组合。"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_message(self) -> "Message":
|
||||
"""获取事件消息内容的方法。"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_plaintext(self) -> str:
|
||||
"""获取消息纯文本的方法。
|
||||
|
||||
通常不需要修改,默认通过 `get_message().extract_plain_text` 获取。
|
||||
"""
|
||||
return self.get_message().extract_plain_text()
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_tome(self) -> bool:
|
||||
"""获取事件是否与机器人有关的方法。"""
|
||||
raise NotImplementedError
|
341
nonebot/internal/adapter/message.py
Normal file
341
nonebot/internal/adapter/message.py
Normal file
@ -0,0 +1,341 @@
|
||||
import abc
|
||||
from copy import deepcopy
|
||||
from dataclasses import field, asdict, dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Type,
|
||||
Tuple,
|
||||
Union,
|
||||
Generic,
|
||||
TypeVar,
|
||||
Iterable,
|
||||
Optional,
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import parse_obj_as
|
||||
|
||||
from .template import MessageTemplate
|
||||
|
||||
T = TypeVar("T")
|
||||
TMS = TypeVar("TMS", bound="MessageSegment")
|
||||
TM = TypeVar("TM", bound="Message")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSegment(abc.ABC, Generic[TM]):
|
||||
"""消息段基类"""
|
||||
|
||||
type: str
|
||||
"""消息段类型"""
|
||||
data: Dict[str, Any] = field(default_factory=dict)
|
||||
"""消息段数据"""
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def get_message_class(cls) -> Type[TM]:
|
||||
"""获取消息数组类型"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def __str__(self) -> str:
|
||||
"""该消息段所代表的 str,在命令匹配部分使用"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(str(self))
|
||||
|
||||
def __ne__(self: T, other: T) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __add__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
|
||||
return self.get_message_class()(self) + other
|
||||
|
||||
def __radd__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
|
||||
return self.get_message_class()(other) + self
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
yield cls._validate
|
||||
|
||||
@classmethod
|
||||
def _validate(cls, value):
|
||||
if isinstance(value, cls):
|
||||
return value
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(f"Expected dict for MessageSegment, got {type(value)}")
|
||||
return cls(**value)
|
||||
|
||||
def get(self, key: str, default: Any = None):
|
||||
return asdict(self).get(key, default)
|
||||
|
||||
def keys(self):
|
||||
return asdict(self).keys()
|
||||
|
||||
def values(self):
|
||||
return asdict(self).values()
|
||||
|
||||
def items(self):
|
||||
return asdict(self).items()
|
||||
|
||||
def copy(self: T) -> T:
|
||||
return deepcopy(self)
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_text(self) -> bool:
|
||||
"""当前消息段是否为纯文本"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Message(List[TMS], abc.ABC):
|
||||
"""消息数组
|
||||
|
||||
参数:
|
||||
message: 消息内容
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: Union[str, None, Iterable[TMS], TMS] = None,
|
||||
):
|
||||
super().__init__()
|
||||
if message is None:
|
||||
return
|
||||
elif isinstance(message, str):
|
||||
self.extend(self._construct(message))
|
||||
elif isinstance(message, MessageSegment):
|
||||
self.append(message)
|
||||
elif isinstance(message, Iterable):
|
||||
self.extend(message)
|
||||
else:
|
||||
self.extend(self._construct(message)) # pragma: no cover
|
||||
|
||||
@classmethod
|
||||
def template(cls: Type[TM], format_string: Union[str, TM]) -> MessageTemplate[TM]:
|
||||
"""创建消息模板。
|
||||
|
||||
用法和 `str.format` 大致相同, 但是可以输出消息对象, 并且支持以 `Message` 对象作为消息模板
|
||||
|
||||
并且提供了拓展的格式化控制符, 可以用适用于该消息类型的 `MessageSegment` 的工厂方法创建消息
|
||||
|
||||
用法:
|
||||
```python
|
||||
>>> Message.template("{} {}").format("hello", "world") # 基础演示
|
||||
Message(MessageSegment(type='text', data={'text': 'hello world'}))
|
||||
>>> Message.template("{} {}").format(MessageSegment.image("file///..."), "world") # 支持消息段等对象
|
||||
Message(MessageSegment(type='image', data={'file': 'file///...'}), MessageSegment(type='text', data={'text': 'world'}))
|
||||
>>> Message.template( # 支持以Message对象作为消息模板
|
||||
... MessageSegment.text('test {event.user_id}') + MessageSegment.face(233) +
|
||||
... MessageSegment.text('test {event.message}')).format(event={'user_id':123456, 'message':'hello world'})
|
||||
Message(MessageSegment(type='text', data={'text': 'test 123456'}),
|
||||
MessageSegment(type='face', data={'face': 233}),
|
||||
MessageSegment(type='text', data={'text': 'test hello world'}))
|
||||
>>> Message.template("{link:image}").format(link='https://...') # 支持拓展格式化控制符
|
||||
Message(MessageSegment(type='image', data={'file': 'https://...'}))
|
||||
```
|
||||
|
||||
参数:
|
||||
format_string: 格式化字符串
|
||||
|
||||
返回:
|
||||
消息格式化器
|
||||
"""
|
||||
return MessageTemplate(format_string, cls)
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def get_segment_class(cls) -> Type[TMS]:
|
||||
"""获取消息段类型"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "".join(str(seg) for seg in self)
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
yield cls._validate
|
||||
|
||||
@classmethod
|
||||
def _validate(cls, value):
|
||||
if isinstance(value, cls):
|
||||
return value
|
||||
elif isinstance(value, Message):
|
||||
raise ValueError(f"Type {type(value)} can not be converted to {cls}")
|
||||
elif isinstance(value, str):
|
||||
pass
|
||||
elif isinstance(value, dict):
|
||||
value = parse_obj_as(cls.get_segment_class(), value)
|
||||
elif isinstance(value, Iterable):
|
||||
value = [parse_obj_as(cls.get_segment_class(), v) for v in value]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected str, dict or iterable for Message, got {type(value)}"
|
||||
)
|
||||
return cls(value)
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def _construct(msg: str) -> Iterable[TMS]:
|
||||
"""构造消息数组"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __add__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM:
|
||||
result = self.copy()
|
||||
result += other
|
||||
return result
|
||||
|
||||
def __radd__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM:
|
||||
result = self.__class__(other)
|
||||
return result + self
|
||||
|
||||
def __iadd__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM:
|
||||
if isinstance(other, str):
|
||||
self.extend(self._construct(other))
|
||||
elif isinstance(other, MessageSegment):
|
||||
self.append(other)
|
||||
elif isinstance(other, Iterable):
|
||||
self.extend(other)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type: {type(other)}") # pragma: no cover
|
||||
return self
|
||||
|
||||
@overload
|
||||
def __getitem__(self: TM, __args: str) -> TM:
|
||||
"""
|
||||
参数:
|
||||
__args: 消息段类型
|
||||
|
||||
返回:
|
||||
所有类型为 `__args` 的消息段
|
||||
"""
|
||||
|
||||
@overload
|
||||
def __getitem__(self, __args: Tuple[str, int]) -> TMS:
|
||||
"""
|
||||
参数:
|
||||
__args: 消息段类型和索引
|
||||
|
||||
返回:
|
||||
类型为 `__args[0]` 的消息段第 `__args[1]` 个
|
||||
"""
|
||||
|
||||
@overload
|
||||
def __getitem__(self: TM, __args: Tuple[str, slice]) -> TM:
|
||||
"""
|
||||
参数:
|
||||
__args: 消息段类型和切片
|
||||
|
||||
返回:
|
||||
类型为 `__args[0]` 的消息段切片 `__args[1]`
|
||||
"""
|
||||
|
||||
@overload
|
||||
def __getitem__(self, __args: int) -> TMS:
|
||||
"""
|
||||
参数:
|
||||
__args: 索引
|
||||
|
||||
返回:
|
||||
第 `__args` 个消息段
|
||||
"""
|
||||
|
||||
@overload
|
||||
def __getitem__(self: TM, __args: slice) -> TM:
|
||||
"""
|
||||
参数:
|
||||
__args: 切片
|
||||
|
||||
返回:
|
||||
消息切片 `__args`
|
||||
"""
|
||||
|
||||
def __getitem__(
|
||||
self: TM,
|
||||
args: Union[
|
||||
str,
|
||||
Tuple[str, int],
|
||||
Tuple[str, slice],
|
||||
int,
|
||||
slice,
|
||||
],
|
||||
) -> Union[TMS, TM]:
|
||||
arg1, arg2 = args if isinstance(args, tuple) else (args, None)
|
||||
if isinstance(arg1, int) and arg2 is None:
|
||||
return super().__getitem__(arg1)
|
||||
elif isinstance(arg1, slice) and arg2 is None:
|
||||
return self.__class__(super().__getitem__(arg1))
|
||||
elif isinstance(arg1, str) and arg2 is None:
|
||||
return self.__class__(seg for seg in self if seg.type == arg1)
|
||||
elif isinstance(arg1, str) and isinstance(arg2, int):
|
||||
return [seg for seg in self if seg.type == arg1][arg2]
|
||||
elif isinstance(arg1, str) and isinstance(arg2, slice):
|
||||
return self.__class__([seg for seg in self if seg.type == arg1][arg2])
|
||||
else:
|
||||
raise ValueError("Incorrect arguments to slice") # pragma: no cover
|
||||
|
||||
def index(self, value: Union[TMS, str], *args) -> int:
|
||||
if isinstance(value, str):
|
||||
first_segment = next((seg for seg in self if seg.type == value), None)
|
||||
if first_segment is None:
|
||||
raise ValueError(f"Segment with type {value} is not in message")
|
||||
return super().index(first_segment, *args)
|
||||
return super().index(value, *args)
|
||||
|
||||
def get(self: TM, type_: str, count: Optional[int] = None) -> TM:
|
||||
if count is None:
|
||||
return self[type_]
|
||||
|
||||
iterator, filtered = (
|
||||
seg for seg in self if seg.type == type_
|
||||
), self.__class__()
|
||||
for _ in range(count):
|
||||
seg = next(iterator, None)
|
||||
if seg is None:
|
||||
break
|
||||
filtered.append(seg)
|
||||
return filtered
|
||||
|
||||
def count(self, value: Union[TMS, str]) -> int:
|
||||
return len(self[value]) if isinstance(value, str) else super().count(value)
|
||||
|
||||
def append(self: TM, obj: Union[str, TMS]) -> TM:
|
||||
"""添加一个消息段到消息数组末尾。
|
||||
|
||||
参数:
|
||||
obj: 要添加的消息段
|
||||
"""
|
||||
if isinstance(obj, MessageSegment):
|
||||
super().append(obj)
|
||||
elif isinstance(obj, str):
|
||||
self.extend(self._construct(obj))
|
||||
else:
|
||||
raise ValueError(f"Unexpected type: {type(obj)} {obj}") # pragma: no cover
|
||||
return self
|
||||
|
||||
def extend(self: TM, obj: Union[TM, Iterable[TMS]]) -> TM:
|
||||
"""拼接一个消息数组或多个消息段到消息数组末尾。
|
||||
|
||||
参数:
|
||||
obj: 要添加的消息数组
|
||||
"""
|
||||
for segment in obj:
|
||||
self.append(segment)
|
||||
return self
|
||||
|
||||
def copy(self: TM) -> TM:
|
||||
return deepcopy(self)
|
||||
|
||||
def extract_plain_text(self) -> str:
|
||||
"""提取消息内纯文本消息"""
|
||||
|
||||
return "".join(str(seg) for seg in self if seg.is_text())
|
||||
|
||||
|
||||
__autodoc__ = {
|
||||
"MessageSegment.__str__": True,
|
||||
"MessageSegment.__add__": True,
|
||||
"Message.__getitem__": True,
|
||||
"Message._construct": True,
|
||||
}
|
181
nonebot/internal/adapter/template.py
Normal file
181
nonebot/internal/adapter/template.py
Normal file
@ -0,0 +1,181 @@
|
||||
import inspect
|
||||
import functools
|
||||
from string import Formatter
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Set,
|
||||
Dict,
|
||||
List,
|
||||
Type,
|
||||
Tuple,
|
||||
Union,
|
||||
Generic,
|
||||
Mapping,
|
||||
TypeVar,
|
||||
Callable,
|
||||
Optional,
|
||||
Sequence,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .message import Message, MessageSegment
|
||||
|
||||
TM = TypeVar("TM", bound="Message")
|
||||
TF = TypeVar("TF", str, "Message")
|
||||
|
||||
FormatSpecFunc = Callable[[Any], str]
|
||||
FormatSpecFunc_T = TypeVar("FormatSpecFunc_T", bound=FormatSpecFunc)
|
||||
|
||||
|
||||
class MessageTemplate(Formatter, Generic[TF]):
|
||||
"""消息模板格式化实现类。
|
||||
|
||||
参数:
|
||||
template: 模板
|
||||
factory: 消息构造类型,默认为 `str`
|
||||
"""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self: "MessageTemplate[str]", template: str, factory: Type[str] = str
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self: "MessageTemplate[TM]", template: Union[str, TM], factory: Type[TM]
|
||||
) -> None:
|
||||
...
|
||||
|
||||
def __init__(self, template, factory=str) -> None:
|
||||
self.template: TF = template
|
||||
self.factory: Type[TF] = factory
|
||||
self.format_specs: Dict[str, FormatSpecFunc] = {}
|
||||
|
||||
def add_format_spec(
|
||||
self, spec: FormatSpecFunc_T, name: Optional[str] = None
|
||||
) -> FormatSpecFunc_T:
|
||||
name = name or spec.__name__
|
||||
if name in self.format_specs:
|
||||
raise ValueError(f"Format spec {name} already exists!")
|
||||
self.format_specs[name] = spec
|
||||
return spec
|
||||
|
||||
def format(self, *args: Any, **kwargs: Any) -> TF:
|
||||
"""根据模板和参数生成消息对象"""
|
||||
msg = self.factory()
|
||||
if isinstance(self.template, str):
|
||||
msg += self.vformat(self.template, args, kwargs)
|
||||
elif isinstance(self.template, self.factory):
|
||||
template = cast("Message[MessageSegment]", self.template)
|
||||
for seg in template:
|
||||
msg += self.vformat(str(seg), args, kwargs) if seg.is_text() else seg
|
||||
else:
|
||||
raise TypeError("template must be a string or instance of Message!")
|
||||
|
||||
return msg # type:ignore
|
||||
|
||||
def vformat(
|
||||
self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]
|
||||
) -> TF:
|
||||
used_args = set()
|
||||
result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
|
||||
self.check_unused_args(list(used_args), args, kwargs)
|
||||
return result
|
||||
|
||||
def _vformat(
|
||||
self,
|
||||
format_string: str,
|
||||
args: Sequence[Any],
|
||||
kwargs: Mapping[str, Any],
|
||||
used_args: Set[Union[int, str]],
|
||||
recursion_depth: int,
|
||||
auto_arg_index: int = 0,
|
||||
) -> Tuple[TF, int]:
|
||||
if recursion_depth < 0:
|
||||
raise ValueError("Max string recursion exceeded")
|
||||
|
||||
results: List[Any] = []
|
||||
|
||||
for (literal_text, field_name, format_spec, conversion) in self.parse(
|
||||
format_string
|
||||
):
|
||||
|
||||
# output the literal text
|
||||
if literal_text:
|
||||
results.append(literal_text)
|
||||
|
||||
# if there's a field, output it
|
||||
if field_name is not None:
|
||||
# this is some markup, find the object and do
|
||||
# the formatting
|
||||
|
||||
# handle arg indexing when empty field_names are given.
|
||||
if field_name == "":
|
||||
if auto_arg_index is False:
|
||||
raise ValueError(
|
||||
"cannot switch from manual field specification to "
|
||||
"automatic field numbering"
|
||||
)
|
||||
field_name = str(auto_arg_index)
|
||||
auto_arg_index += 1
|
||||
elif field_name.isdigit():
|
||||
if auto_arg_index:
|
||||
raise ValueError(
|
||||
"cannot switch from manual field specification to "
|
||||
"automatic field numbering"
|
||||
)
|
||||
# disable auto arg incrementing, if it gets
|
||||
# used later on, then an exception will be raised
|
||||
auto_arg_index = False
|
||||
|
||||
# given the field_name, find the object it references
|
||||
# and the argument it came from
|
||||
obj, arg_used = self.get_field(field_name, args, kwargs)
|
||||
used_args.add(arg_used)
|
||||
|
||||
assert format_spec is not None
|
||||
|
||||
# do any conversion on the resulting object
|
||||
obj = self.convert_field(obj, conversion) if conversion else obj
|
||||
|
||||
# expand the format spec, if needed
|
||||
format_control, auto_arg_index = self._vformat(
|
||||
format_spec,
|
||||
args,
|
||||
kwargs,
|
||||
used_args,
|
||||
recursion_depth - 1,
|
||||
auto_arg_index,
|
||||
)
|
||||
|
||||
# format the object and append to the result
|
||||
formatted_text = self.format_field(obj, str(format_control))
|
||||
results.append(formatted_text)
|
||||
|
||||
return (
|
||||
self.factory(functools.reduce(self._add, results or [""])),
|
||||
auto_arg_index,
|
||||
)
|
||||
|
||||
def format_field(self, value: Any, format_spec: str) -> Any:
|
||||
formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec)
|
||||
if formatter is None and not issubclass(self.factory, str):
|
||||
segment_class: Type["MessageSegment"] = self.factory.get_segment_class()
|
||||
method = getattr(segment_class, format_spec, None)
|
||||
if inspect.ismethod(method):
|
||||
formatter = getattr(segment_class, format_spec)
|
||||
return (
|
||||
super().format_field(value, format_spec)
|
||||
if formatter is None
|
||||
else formatter(value)
|
||||
)
|
||||
|
||||
def _add(self, a: Any, b: Any) -> Any:
|
||||
try:
|
||||
return a + b
|
||||
except TypeError:
|
||||
return a + str(b)
|
Reference in New Issue
Block a user