🎨 improve typing

This commit is contained in:
yanyongyu
2021-06-14 19:52:35 +08:00
parent e9bc98e74d
commit ddd96271b0
4 changed files with 121 additions and 117 deletions

View File

@@ -6,7 +6,9 @@ try:
del pkg_resources
except ImportError:
import pkgutil
__path__: Iterable[str] = pkgutil.extend_path(__path__, __name__)
__path__: Iterable[str] = pkgutil.extend_path(
__path__, # type: ignore
__name__)
del pkgutil
except Exception:
pass

View File

@@ -7,28 +7,25 @@
import abc
import asyncio
from copy import copy
from functools import reduce, partial
from copy import deepcopy
from functools import partial
from typing_extensions import Protocol
from dataclasses import dataclass, field
from typing import (Any, Set, List, Dict, Tuple, Union, TypeVar, Mapping,
Optional, Iterable, Awaitable, TYPE_CHECKING)
from typing import (Any, Set, List, Dict, Type, Tuple, Union, TypeVar, Mapping,
Generic, Optional, Iterable)
from pydantic import BaseModel
from nonebot.log import logger
from nonebot.utils import DataclassEncoder
from nonebot.drivers import HTTPConnection, HTTPResponse
from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook
if TYPE_CHECKING:
from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
from nonebot.utils import DataclassEncoder
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse
from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook
class _ApiCall(Protocol):
def __call__(self, **kwargs: Any) -> Awaitable[Any]:
async def __call__(self, **kwargs: Any) -> Any:
...
@@ -37,9 +34,9 @@ class Bot(abc.ABC):
Bot 基类。用于处理上报消息,并提供 API 调用接口。
"""
driver: "Driver"
driver: Driver
"""Driver 对象"""
config: "Config"
config: Config
"""Config 配置对象"""
_calling_api_hook: Set[T_CallingAPIHook] = set()
"""
@@ -56,9 +53,8 @@ class Bot(abc.ABC):
"""
:参数:
* ``connection_type: str``: http 或者 websocket
* ``self_id: str``: 机器人 ID
* ``websocket: Optional[WebSocket]``: Websocket 连接对象
* ``request: HTTPConnection``: request 连接对象
"""
self.self_id: str = self_id
"""机器人 ID"""
@@ -75,7 +71,7 @@ class Bot(abc.ABC):
raise NotImplementedError
@classmethod
def register(cls, driver: "Driver", config: "Config"):
def register(cls, driver: Driver, config: Config):
"""
:说明:
@@ -87,7 +83,7 @@ class Bot(abc.ABC):
@classmethod
@abc.abstractmethod
async def check_permission(
cls, driver: "Driver", request: HTTPConnection
cls, driver: Driver, request: HTTPConnection
) -> Tuple[Optional[str], Optional[HTTPResponse]]:
"""
:说明:
@@ -97,18 +93,12 @@ class Bot(abc.ABC):
:参数:
* ``driver: Driver``: Driver 对象
* ``connection_type: str``: 连接类型
* ``headers: dict``: 请求头
* ``body: Optional[bytes]``: 请求数据WebSocket 连接该部分为 None
* ``request: HTTPConnection``: request 请求详情
:返回:
- ``str``: 连接唯一标识符,``None`` 代表连接不合法
- ``HTTPResponse``: HTTP 上报响应
:异常:
- ``RequestDenied``: 请求非法
- ``Optional[str]``: 连接唯一标识符,``None`` 代表连接不合法
- ``Optional[HTTPResponse]``: HTTP 上报响应
"""
raise NotImplementedError
@@ -210,21 +200,45 @@ class Bot(abc.ABC):
@classmethod
def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook:
"""
:说明:
调用 api 预处理。
:参数:
* ``bot: Bot``: 当前 bot 对象
* ``api: str``: 调用的 api 名称
* ``data: Dict[str, Any]``: api 调用的参数字典
"""
cls._calling_api_hook.add(func)
return func
@classmethod
def on_called_api(cls, func: T_CalledAPIHook) -> T_CalledAPIHook:
"""
:说明:
调用 api 后处理。
:参数:
* ``bot: Bot``: 当前 bot 对象
* ``exception: Optional[Exception]``: 调用 api 时发生的错误
* ``api: str``: 调用的 api 名称
* ``data: Dict[str, Any]``: api 调用的参数字典
* ``result: Any``: api 调用的返回
"""
cls._called_api_hook.add(func)
return func
T_Message = TypeVar("T_Message", bound="Message")
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment")
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment[Message]")
@dataclass
class MessageSegment(abc.ABC, Mapping):
class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
"""消息段基类"""
type: str
"""
@@ -237,6 +251,11 @@ class MessageSegment(abc.ABC, Mapping):
- 说明: 消息段数据
"""
@abc.abstractmethod
@classmethod
def get_message_class(cls) -> Type[T_Message]:
raise NotImplementedError
@abc.abstractmethod
def __str__(self) -> str:
"""该消息段所代表的 str在命令匹配部分使用"""
@@ -248,46 +267,27 @@ class MessageSegment(abc.ABC, Mapping):
def __ne__(self: T_MessageSegment, other: T_MessageSegment) -> bool:
return not self == other
@abc.abstractmethod
def __add__(self: T_MessageSegment, other: Union[str, T_MessageSegment,
T_Message]) -> T_Message:
"""你需要在这里实现不同消息段的合并:
比如:
if isinstance(other, str):
...
elif isinstance(other, MessageSegment):
...
注意:需要返回一个新生成的对象
"""
raise NotImplementedError
def __add__(self, other: Union[str, Mapping,
Iterable[Mapping]]) -> T_Message:
return self.get_message_class()(self) + other
@abc.abstractmethod
def __radd__(
self: T_MessageSegment, other: Union[str, dict, list, T_MessageSegment,
T_Message]) -> "T_Message":
"""你需要在这里实现不同消息段的合并:
比如:
if isinstance(other, str):
...
elif isinstance(other, MessageSegment):
...
注意:需要返回一个新生成的对象
"""
raise NotImplementedError
def __radd__(self, other: Union[str, Mapping,
Iterable[Mapping]]) -> T_Message:
return self.get_message_class()(other) + self
def __getitem__(self, key):
return getattr(self, key)
def __getitem__(self, key: str):
return self.data[key]
def __setitem__(self, key, value):
return setattr(self, key, value)
def __setitem__(self, key: str, value: Any):
self.data[key] = value
def __iter__(self):
yield from self.data.__iter__()
def __contains__(self, key: object) -> bool:
def __contains__(self, key: Any) -> bool:
return key in self.data
def get(self, key: str, default=None):
def get(self, key: str, default: Any = None):
return getattr(self, key, default)
def keys(self):
@@ -300,7 +300,7 @@ class MessageSegment(abc.ABC, Mapping):
return self.data.items()
def copy(self: T_MessageSegment) -> T_MessageSegment:
return copy(self)
return deepcopy(self)
@abc.abstractmethod
def is_text(self) -> bool:
@@ -310,7 +310,7 @@ class MessageSegment(abc.ABC, Mapping):
class Message(List[T_MessageSegment], abc.ABC):
"""消息数组"""
def __init__(self,
def __init__(self: T_Message,
message: Union[str, None, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message, Any] = None,
*args,
@@ -330,8 +330,13 @@ class Message(List[T_MessageSegment], abc.ABC):
else:
self.extend(self._construct(message))
@abc.abstractmethod
@classmethod
def get_segment_class(cls) -> Type[T_MessageSegment]:
raise NotImplementedError
def __str__(self):
return ''.join((str(seg) for seg in self))
return "".join(str(seg) for seg in self)
@classmethod
def __get_validators__(cls):
@@ -348,30 +353,31 @@ class Message(List[T_MessageSegment], abc.ABC):
) -> Iterable[T_MessageSegment]:
raise NotImplementedError
def __add__(self: T_Message, other: Union[str, T_MessageSegment,
T_Message]) -> T_Message:
result = self.__class__(self)
if isinstance(other, str):
result.extend(self._construct(other))
elif isinstance(other, MessageSegment):
result.append(other)
elif isinstance(other, Message):
result.extend(other)
def __add__(
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message]
) -> T_Message:
result = self.copy()
result += other
return result
def __radd__(self: T_Message, other: Union[str, T_MessageSegment,
T_Message]) -> T_Message:
def __radd__(
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message]
) -> T_Message:
result = self.__class__(other)
return result.__add__(self)
return result + self
def __iadd__(self: T_Message, other: Union[str, T_MessageSegment,
T_Message]) -> T_Message:
if isinstance(other, str):
self.extend(self._construct(other))
elif isinstance(other, MessageSegment):
def __iadd__(
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message]
) -> T_Message:
if isinstance(other, MessageSegment):
self.append(other)
elif isinstance(other, Message):
self.extend(other)
else:
self.extend(self._construct(other))
return self
def append(self: T_Message, obj: Union[str, T_MessageSegment]) -> T_Message:
@@ -385,7 +391,7 @@ class Message(List[T_MessageSegment], abc.ABC):
* ``obj: Union[str, MessageSegment]``: 要添加的消息段
"""
if isinstance(obj, MessageSegment):
super().append(obj)
super(Message, self).append(obj)
elif isinstance(obj, str):
self.extend(self._construct(obj))
else:
@@ -407,33 +413,17 @@ class Message(List[T_MessageSegment], abc.ABC):
self.append(segment)
return self
def reduce(self: T_Message) -> None:
"""
:说明:
def copy(self: T_Message) -> T_Message:
return deepcopy(self)
缩减消息数组,即按 MessageSegment 的实现拼接相邻消息段
"""
index = 0
while index < len(self):
if index > 0 and self[index -
1].is_text() and self[index].is_text():
self[index - 1] += self[index]
del self[index]
else:
index += 1
def extract_plain_text(self: T_Message) -> str:
def extract_plain_text(self) -> str:
"""
:说明:
提取消息内纯文本消息
"""
def _concat(x: str, y: T_MessageSegment) -> str:
return f"{x} {y}" if y.is_text() else x
plain_text = reduce(_concat, self, "")
return plain_text[1:] if plain_text else plain_text
return "".join(str(seg) for seg in self if seg.is_text())
class Event(abc.ABC, BaseModel):

View File

@@ -50,8 +50,8 @@ class Filter:
def __call__(self, record):
module = sys.modules.get(record["name"])
if module:
plugin_name = getattr(module, "__plugin_name__", record["name"])
record["name"] = plugin_name
module_name = getattr(module, "__module_name__", record["name"])
record["name"] = module_name
record["name"] = record["name"].split(".")[0]
levelno = logger.level(self.level).no if isinstance(self.level,
str) else self.level

View File

@@ -11,14 +11,14 @@ from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessa
from .utils import log, escape, unescape, _b2s
class MessageSegment(BaseMessageSegment):
class MessageSegment(BaseMessageSegment["Message"]):
"""
CQHTTP 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。
"""
@overrides(BaseMessageSegment)
def __init__(self, type: str, data: Dict[str, Any]) -> None:
super().__init__(type=type, data=data)
@classmethod
def get_message_class(cls):
return Message
@overrides(BaseMessageSegment)
def __str__(self) -> str:
@@ -37,7 +37,8 @@ class MessageSegment(BaseMessageSegment):
@overrides(BaseMessageSegment)
def __add__(self, other) -> "Message":
return Message(self) + other
return Message(self) + (MessageSegment.text(other) if isinstance(
other, str) else other)
@overrides(BaseMessageSegment)
def __radd__(self, other) -> "Message":
@@ -234,10 +235,25 @@ class Message(BaseMessage[MessageSegment]):
CQHTTP 协议 Message 适配。
"""
def __radd__(self, other: Union[str, MessageSegment,
"Message"]) -> "Message":
result = MessageSegment.text(other) if isinstance(other, str) else other
return super(Message, self).__radd__(result)
@classmethod
def get_segment_class(cls):
return MessageSegment
@overrides(BaseMessage)
def __add__(
self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment,
"Message"]
) -> "Message":
return super(Message, self).__add__(
MessageSegment.text(other) if isinstance(other, str) else other)
@overrides(BaseMessage)
def __radd__(
self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment,
"Message"]
) -> "Message":
return super(Message, self).__radd__(
MessageSegment.text(other) if isinstance(other, str) else other)
@staticmethod
@overrides(BaseMessage)
@@ -280,10 +296,6 @@ class Message(BaseMessage[MessageSegment]):
}
yield MessageSegment(type_, data)
@overrides(BaseMessage)
def extract_plain_text(self) -> str:
def _concat(x: str, y: MessageSegment) -> str:
return f"{x} {y.data['text']}" if y.is_text() else x
plain_text = reduce(_concat, self, "")
return plain_text[1:] if plain_text else plain_text
return "".join(seg.data["text"] for seg in self if seg.is_text())