♻️ reorganize internal tree

This commit is contained in:
yanyongyu
2022-02-06 17:08:11 +08:00
parent 65dc9a908b
commit 118519e15d
16 changed files with 52 additions and 25 deletions

View File

@ -0,0 +1,6 @@
from .bot import Bot as Bot
from .event import Event as Event
from .adapter import Adapter as Adapter
from .message import Message as Message
from .message import MessageSegment as MessageSegment
from .template import MessageTemplate as MessageTemplate

View File

@ -0,0 +1,106 @@
import abc
from contextlib import asynccontextmanager
from typing import Any, Dict, AsyncGenerator
from nonebot.config import Config
from nonebot.internal.driver import (
Driver,
Request,
Response,
WebSocket,
ForwardDriver,
ReverseDriver,
HTTPServerSetup,
WebSocketServerSetup,
)
from .bot import Bot
class Adapter(abc.ABC):
"""协议适配器基类。
通常,在 Adapter 中编写协议通信相关代码,如: 建立通信连接、处理接收与发送 data 等。
参数:
driver: {ref}`nonebot.drivers.Driver` 实例
kwargs: 其他由 {ref}`nonebot.drivers.Driver.register_adapter` 传入的额外参数
"""
def __init__(self, driver: Driver, **kwargs: Any):
self.driver: Driver = driver
"""{ref}`nonebot.drivers.Driver` 实例"""
self.bots: Dict[str, Bot] = {}
"""本协议适配器已建立连接的 {ref}`nonebot.adapters.Bot` 实例"""
@classmethod
@abc.abstractmethod
def get_name(cls) -> str:
"""当前协议适配器的名称"""
raise NotImplementedError
@property
def config(self) -> Config:
"""全局 NoneBot 配置"""
return self.driver.config
def bot_connect(self, bot: Bot) -> None:
"""告知 NoneBot 建立了一个新的 {ref}`nonebot.adapters.Bot` 连接。
当有新的 {ref}`nonebot.adapters.Bot` 实例连接建立成功时调用。
参数:
bot: {ref}`nonebot.adapters.Bot` 实例
"""
self.driver._bot_connect(bot)
self.bots[bot.self_id] = bot
def bot_disconnect(self, bot: Bot) -> None:
"""告知 NoneBot {ref}`nonebot.adapters.Bot` 连接已断开。
当有 {ref}`nonebot.adapters.Bot` 实例连接断开时调用。
参数:
bot: {ref}`nonebot.adapters.Bot` 实例
"""
self.driver._bot_disconnect(bot)
self.bots.pop(bot.self_id, None)
def setup_http_server(self, setup: HTTPServerSetup):
"""设置一个 HTTP 服务器路由配置"""
if not isinstance(self.driver, ReverseDriver):
raise TypeError("Current driver does not support http server")
self.driver.setup_http_server(setup)
def setup_websocket_server(self, setup: WebSocketServerSetup):
"""设置一个 WebSocket 服务器路由配置"""
if not isinstance(self.driver, ReverseDriver):
raise TypeError("Current driver does not support websocket server")
self.driver.setup_websocket_server(setup)
async def request(self, setup: Request) -> Response:
"""进行一个 HTTP 客户端请求"""
if not isinstance(self.driver, ForwardDriver):
raise TypeError("Current driver does not support http client")
return await self.driver.request(setup)
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
"""建立一个 WebSocket 客户端连接请求"""
if not isinstance(self.driver, ForwardDriver):
raise TypeError("Current driver does not support websocket client")
async with self.driver.websocket(setup) as ws:
yield ws
@abc.abstractmethod
async def _call_api(self, bot: Bot, api: str, **data: Any) -> Any:
"""`Adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
参数:
api: API 名称
data: API 数据
"""
raise NotImplementedError
__autodoc__ = {"Adapter._call_api": True}

View File

@ -0,0 +1,162 @@
import abc
import asyncio
from functools import partial
from typing_extensions import Protocol
from typing import TYPE_CHECKING, Any, Set, Union, Optional
from nonebot.log import logger
from nonebot.config import Config
from nonebot.exception import MockApiException
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook
if TYPE_CHECKING:
from .event import Event
from .adapter import Adapter
from .message import Message, MessageSegment
class _ApiCall(Protocol):
async def __call__(self, **kwargs: Any) -> Any:
...
class Bot(abc.ABC):
"""Bot 基类。
用于处理上报消息,并提供 API 调用接口。
参数:
adapter: 协议适配器实例
self_id: 机器人 ID
"""
_calling_api_hook: Set[T_CallingAPIHook] = set()
"""call_api 时执行的函数"""
_called_api_hook: Set[T_CalledAPIHook] = set()
"""call_api 后执行的函数"""
def __init__(self, adapter: "Adapter", self_id: str):
self.adapter: "Adapter" = adapter
"""协议适配器实例"""
self.self_id: str = self_id
"""机器人 ID"""
def __getattr__(self, name: str) -> _ApiCall:
return partial(self.call_api, name)
@property
def type(self) -> str:
"""协议适配器名称"""
return self.adapter.get_name()
@property
def config(self) -> Config:
"""全局 NoneBot 配置"""
return self.adapter.config
async def call_api(self, api: str, **data: Any) -> Any:
"""调用机器人 API 接口,可以通过该函数或直接通过 bot 属性进行调用
参数:
api: API 名称
data: API 数据
用法:
```python
await bot.call_api("send_msg", message="hello world")
await bot.send_msg(message="hello world")
```
"""
result: Any = None
skip_calling_api: bool = False
exception: Optional[Exception] = None
coros = list(map(lambda x: x(self, api, data), self._calling_api_hook))
if coros:
try:
logger.debug("Running CallingAPI hooks...")
await asyncio.gather(*coros)
except MockApiException as e:
skip_calling_api = True
result = e.result
logger.debug(
f"Calling API {api} is cancelled. Return {result} instead."
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
if not skip_calling_api:
try:
result = await self.adapter._call_api(self, api, **data)
except Exception as e:
exception = e
coros = list(
map(lambda x: x(self, exception, api, data, result), self._called_api_hook)
)
if coros:
try:
logger.debug("Running CalledAPI hooks...")
await asyncio.gather(*coros)
except MockApiException as e:
result = e.result
logger.debug(
f"Calling API {api} result is mocked. Return {result} instead."
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CalledAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
if exception:
raise exception
return result
@abc.abstractmethod
async def send(
self,
event: "Event",
message: Union[str, "Message", "MessageSegment"],
**kwargs: Any,
) -> Any:
"""调用机器人基础发送消息接口
参数:
event: 上报事件
message: 要发送的消息
kwargs: 任意额外参数
"""
raise NotImplementedError
@classmethod
def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook:
"""调用 api 预处理。
钩子函数参数:
- bot: 当前 bot 对象
- api: 调用的 api 名称
- data: api 调用的参数字典
"""
cls._calling_api_hook.add(func)
return func
@classmethod
def on_called_api(cls, func: T_CalledAPIHook) -> T_CalledAPIHook:
"""调用 api 后处理。
钩子函数参数:
- bot: 当前 bot 对象
- exception: 调用 api 时发生的错误
- api: 调用的 api 名称
- data: api 调用的参数字典
- result: api 调用的返回
"""
cls._called_api_hook.add(func)
return func

View File

@ -0,0 +1,70 @@
import abc
from pydantic import BaseModel
from nonebot.utils import DataclassEncoder
from .message import Message
class Event(abc.ABC, BaseModel):
"""Event 基类。提供获取关键信息的方法,其余信息可直接获取。"""
class Config:
extra = "allow"
json_encoders = {Message: DataclassEncoder}
@abc.abstractmethod
def get_type(self) -> str:
"""获取事件类型的方法,类型通常为 NoneBot 内置的四种类型。"""
raise NotImplementedError
@abc.abstractmethod
def get_event_name(self) -> str:
"""获取事件名称的方法。"""
raise NotImplementedError
@abc.abstractmethod
def get_event_description(self) -> str:
"""获取事件描述的方法,通常为事件具体内容。"""
raise NotImplementedError
def __str__(self) -> str:
return f"[{self.get_event_name()}]: {self.get_event_description()}"
def get_log_string(self) -> str:
"""获取事件日志信息的方法。
通常你不需要修改这个方法,只有当希望 NoneBot 隐藏该事件日志时,可以抛出 `NoLogException` 异常。
异常:
NoLogException
"""
return f"[{self.get_event_name()}]: {self.get_event_description()}"
@abc.abstractmethod
def get_user_id(self) -> str:
"""获取事件主体 id 的方法,通常是用户 id 。"""
raise NotImplementedError
@abc.abstractmethod
def get_session_id(self) -> str:
"""获取会话 id 的方法,用于判断当前事件属于哪一个会话,通常是用户 id、群组 id 组合。"""
raise NotImplementedError
@abc.abstractmethod
def get_message(self) -> "Message":
"""获取事件消息内容的方法。"""
raise NotImplementedError
def get_plaintext(self) -> str:
"""获取消息纯文本的方法。
通常不需要修改,默认通过 `get_message().extract_plain_text` 获取。
"""
return self.get_message().extract_plain_text()
@abc.abstractmethod
def is_tome(self) -> bool:
"""获取事件是否与机器人有关的方法。"""
raise NotImplementedError

View File

@ -0,0 +1,341 @@
import abc
from copy import deepcopy
from dataclasses import field, asdict, dataclass
from typing import (
Any,
Dict,
List,
Type,
Tuple,
Union,
Generic,
TypeVar,
Iterable,
Optional,
overload,
)
from pydantic import parse_obj_as
from .template import MessageTemplate
T = TypeVar("T")
TMS = TypeVar("TMS", bound="MessageSegment")
TM = TypeVar("TM", bound="Message")
@dataclass
class MessageSegment(abc.ABC, Generic[TM]):
"""消息段基类"""
type: str
"""消息段类型"""
data: Dict[str, Any] = field(default_factory=dict)
"""消息段数据"""
@classmethod
@abc.abstractmethod
def get_message_class(cls) -> Type[TM]:
"""获取消息数组类型"""
raise NotImplementedError
@abc.abstractmethod
def __str__(self) -> str:
"""该消息段所代表的 str在命令匹配部分使用"""
raise NotImplementedError
def __len__(self) -> int:
return len(str(self))
def __ne__(self: T, other: T) -> bool:
return not self == other
def __add__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
return self.get_message_class()(self) + other
def __radd__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
return self.get_message_class()(other) + self
@classmethod
def __get_validators__(cls):
yield cls._validate
@classmethod
def _validate(cls, value):
if isinstance(value, cls):
return value
if not isinstance(value, dict):
raise ValueError(f"Expected dict for MessageSegment, got {type(value)}")
return cls(**value)
def get(self, key: str, default: Any = None):
return asdict(self).get(key, default)
def keys(self):
return asdict(self).keys()
def values(self):
return asdict(self).values()
def items(self):
return asdict(self).items()
def copy(self: T) -> T:
return deepcopy(self)
@abc.abstractmethod
def is_text(self) -> bool:
"""当前消息段是否为纯文本"""
raise NotImplementedError
class Message(List[TMS], abc.ABC):
"""消息数组
参数:
message: 消息内容
"""
def __init__(
self,
message: Union[str, None, Iterable[TMS], TMS] = None,
):
super().__init__()
if message is None:
return
elif isinstance(message, str):
self.extend(self._construct(message))
elif isinstance(message, MessageSegment):
self.append(message)
elif isinstance(message, Iterable):
self.extend(message)
else:
self.extend(self._construct(message)) # pragma: no cover
@classmethod
def template(cls: Type[TM], format_string: Union[str, TM]) -> MessageTemplate[TM]:
"""创建消息模板。
用法和 `str.format` 大致相同, 但是可以输出消息对象, 并且支持以 `Message` 对象作为消息模板
并且提供了拓展的格式化控制符, 可以用适用于该消息类型的 `MessageSegment` 的工厂方法创建消息
用法:
```python
>>> Message.template("{} {}").format("hello", "world") # 基础演示
Message(MessageSegment(type='text', data={'text': 'hello world'}))
>>> Message.template("{} {}").format(MessageSegment.image("file///..."), "world") # 支持消息段等对象
Message(MessageSegment(type='image', data={'file': 'file///...'}), MessageSegment(type='text', data={'text': 'world'}))
>>> Message.template( # 支持以Message对象作为消息模板
... MessageSegment.text('test {event.user_id}') + MessageSegment.face(233) +
... MessageSegment.text('test {event.message}')).format(event={'user_id':123456, 'message':'hello world'})
Message(MessageSegment(type='text', data={'text': 'test 123456'}),
MessageSegment(type='face', data={'face': 233}),
MessageSegment(type='text', data={'text': 'test hello world'}))
>>> Message.template("{link:image}").format(link='https://...') # 支持拓展格式化控制符
Message(MessageSegment(type='image', data={'file': 'https://...'}))
```
参数:
format_string: 格式化字符串
返回:
消息格式化器
"""
return MessageTemplate(format_string, cls)
@classmethod
@abc.abstractmethod
def get_segment_class(cls) -> Type[TMS]:
"""获取消息段类型"""
raise NotImplementedError
def __str__(self) -> str:
return "".join(str(seg) for seg in self)
@classmethod
def __get_validators__(cls):
yield cls._validate
@classmethod
def _validate(cls, value):
if isinstance(value, cls):
return value
elif isinstance(value, Message):
raise ValueError(f"Type {type(value)} can not be converted to {cls}")
elif isinstance(value, str):
pass
elif isinstance(value, dict):
value = parse_obj_as(cls.get_segment_class(), value)
elif isinstance(value, Iterable):
value = [parse_obj_as(cls.get_segment_class(), v) for v in value]
else:
raise ValueError(
f"Expected str, dict or iterable for Message, got {type(value)}"
)
return cls(value)
@staticmethod
@abc.abstractmethod
def _construct(msg: str) -> Iterable[TMS]:
"""构造消息数组"""
raise NotImplementedError
def __add__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM:
result = self.copy()
result += other
return result
def __radd__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM:
result = self.__class__(other)
return result + self
def __iadd__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM:
if isinstance(other, str):
self.extend(self._construct(other))
elif isinstance(other, MessageSegment):
self.append(other)
elif isinstance(other, Iterable):
self.extend(other)
else:
raise ValueError(f"Unsupported type: {type(other)}") # pragma: no cover
return self
@overload
def __getitem__(self: TM, __args: str) -> TM:
"""
参数:
__args: 消息段类型
返回:
所有类型为 `__args` 的消息段
"""
@overload
def __getitem__(self, __args: Tuple[str, int]) -> TMS:
"""
参数:
__args: 消息段类型和索引
返回:
类型为 `__args[0]` 的消息段第 `__args[1]` 个
"""
@overload
def __getitem__(self: TM, __args: Tuple[str, slice]) -> TM:
"""
参数:
__args: 消息段类型和切片
返回:
类型为 `__args[0]` 的消息段切片 `__args[1]`
"""
@overload
def __getitem__(self, __args: int) -> TMS:
"""
参数:
__args: 索引
返回:
第 `__args` 个消息段
"""
@overload
def __getitem__(self: TM, __args: slice) -> TM:
"""
参数:
__args: 切片
返回:
消息切片 `__args`
"""
def __getitem__(
self: TM,
args: Union[
str,
Tuple[str, int],
Tuple[str, slice],
int,
slice,
],
) -> Union[TMS, TM]:
arg1, arg2 = args if isinstance(args, tuple) else (args, None)
if isinstance(arg1, int) and arg2 is None:
return super().__getitem__(arg1)
elif isinstance(arg1, slice) and arg2 is None:
return self.__class__(super().__getitem__(arg1))
elif isinstance(arg1, str) and arg2 is None:
return self.__class__(seg for seg in self if seg.type == arg1)
elif isinstance(arg1, str) and isinstance(arg2, int):
return [seg for seg in self if seg.type == arg1][arg2]
elif isinstance(arg1, str) and isinstance(arg2, slice):
return self.__class__([seg for seg in self if seg.type == arg1][arg2])
else:
raise ValueError("Incorrect arguments to slice") # pragma: no cover
def index(self, value: Union[TMS, str], *args) -> int:
if isinstance(value, str):
first_segment = next((seg for seg in self if seg.type == value), None)
if first_segment is None:
raise ValueError(f"Segment with type {value} is not in message")
return super().index(first_segment, *args)
return super().index(value, *args)
def get(self: TM, type_: str, count: Optional[int] = None) -> TM:
if count is None:
return self[type_]
iterator, filtered = (
seg for seg in self if seg.type == type_
), self.__class__()
for _ in range(count):
seg = next(iterator, None)
if seg is None:
break
filtered.append(seg)
return filtered
def count(self, value: Union[TMS, str]) -> int:
return len(self[value]) if isinstance(value, str) else super().count(value)
def append(self: TM, obj: Union[str, TMS]) -> TM:
"""添加一个消息段到消息数组末尾。
参数:
obj: 要添加的消息段
"""
if isinstance(obj, MessageSegment):
super().append(obj)
elif isinstance(obj, str):
self.extend(self._construct(obj))
else:
raise ValueError(f"Unexpected type: {type(obj)} {obj}") # pragma: no cover
return self
def extend(self: TM, obj: Union[TM, Iterable[TMS]]) -> TM:
"""拼接一个消息数组或多个消息段到消息数组末尾。
参数:
obj: 要添加的消息数组
"""
for segment in obj:
self.append(segment)
return self
def copy(self: TM) -> TM:
return deepcopy(self)
def extract_plain_text(self) -> str:
"""提取消息内纯文本消息"""
return "".join(str(seg) for seg in self if seg.is_text())
__autodoc__ = {
"MessageSegment.__str__": True,
"MessageSegment.__add__": True,
"Message.__getitem__": True,
"Message._construct": True,
}

View File

@ -0,0 +1,181 @@
import inspect
import functools
from string import Formatter
from typing import (
TYPE_CHECKING,
Any,
Set,
Dict,
List,
Type,
Tuple,
Union,
Generic,
Mapping,
TypeVar,
Callable,
Optional,
Sequence,
cast,
overload,
)
if TYPE_CHECKING:
from .message import Message, MessageSegment
TM = TypeVar("TM", bound="Message")
TF = TypeVar("TF", str, "Message")
FormatSpecFunc = Callable[[Any], str]
FormatSpecFunc_T = TypeVar("FormatSpecFunc_T", bound=FormatSpecFunc)
class MessageTemplate(Formatter, Generic[TF]):
"""消息模板格式化实现类。
参数:
template: 模板
factory: 消息构造类型,默认为 `str`
"""
@overload
def __init__(
self: "MessageTemplate[str]", template: str, factory: Type[str] = str
) -> None:
...
@overload
def __init__(
self: "MessageTemplate[TM]", template: Union[str, TM], factory: Type[TM]
) -> None:
...
def __init__(self, template, factory=str) -> None:
self.template: TF = template
self.factory: Type[TF] = factory
self.format_specs: Dict[str, FormatSpecFunc] = {}
def add_format_spec(
self, spec: FormatSpecFunc_T, name: Optional[str] = None
) -> FormatSpecFunc_T:
name = name or spec.__name__
if name in self.format_specs:
raise ValueError(f"Format spec {name} already exists!")
self.format_specs[name] = spec
return spec
def format(self, *args: Any, **kwargs: Any) -> TF:
"""根据模板和参数生成消息对象"""
msg = self.factory()
if isinstance(self.template, str):
msg += self.vformat(self.template, args, kwargs)
elif isinstance(self.template, self.factory):
template = cast("Message[MessageSegment]", self.template)
for seg in template:
msg += self.vformat(str(seg), args, kwargs) if seg.is_text() else seg
else:
raise TypeError("template must be a string or instance of Message!")
return msg # type:ignore
def vformat(
self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]
) -> TF:
used_args = set()
result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
self.check_unused_args(list(used_args), args, kwargs)
return result
def _vformat(
self,
format_string: str,
args: Sequence[Any],
kwargs: Mapping[str, Any],
used_args: Set[Union[int, str]],
recursion_depth: int,
auto_arg_index: int = 0,
) -> Tuple[TF, int]:
if recursion_depth < 0:
raise ValueError("Max string recursion exceeded")
results: List[Any] = []
for (literal_text, field_name, format_spec, conversion) in self.parse(
format_string
):
# output the literal text
if literal_text:
results.append(literal_text)
# if there's a field, output it
if field_name is not None:
# this is some markup, find the object and do
# the formatting
# handle arg indexing when empty field_names are given.
if field_name == "":
if auto_arg_index is False:
raise ValueError(
"cannot switch from manual field specification to "
"automatic field numbering"
)
field_name = str(auto_arg_index)
auto_arg_index += 1
elif field_name.isdigit():
if auto_arg_index:
raise ValueError(
"cannot switch from manual field specification to "
"automatic field numbering"
)
# disable auto arg incrementing, if it gets
# used later on, then an exception will be raised
auto_arg_index = False
# given the field_name, find the object it references
# and the argument it came from
obj, arg_used = self.get_field(field_name, args, kwargs)
used_args.add(arg_used)
assert format_spec is not None
# do any conversion on the resulting object
obj = self.convert_field(obj, conversion) if conversion else obj
# expand the format spec, if needed
format_control, auto_arg_index = self._vformat(
format_spec,
args,
kwargs,
used_args,
recursion_depth - 1,
auto_arg_index,
)
# format the object and append to the result
formatted_text = self.format_field(obj, str(format_control))
results.append(formatted_text)
return (
self.factory(functools.reduce(self._add, results or [""])),
auto_arg_index,
)
def format_field(self, value: Any, format_spec: str) -> Any:
formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec)
if formatter is None and not issubclass(self.factory, str):
segment_class: Type["MessageSegment"] = self.factory.get_segment_class()
method = getattr(segment_class, format_spec, None)
if inspect.ismethod(method):
formatter = getattr(segment_class, format_spec)
return (
super().format_field(value, format_spec)
if formatter is None
else formatter(value)
)
def _add(self, a: Any, b: Any) -> Any:
try:
return a + b
except TypeError:
return a + str(b)