🎨 improve typing

This commit is contained in:
yanyongyu
2021-06-14 19:52:35 +08:00
parent e9bc98e74d
commit ddd96271b0
4 changed files with 121 additions and 117 deletions

View File

@ -11,14 +11,14 @@ from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessa
from .utils import log, escape, unescape, _b2s
class MessageSegment(BaseMessageSegment):
class MessageSegment(BaseMessageSegment["Message"]):
"""
CQHTTP 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。
"""
@overrides(BaseMessageSegment)
def __init__(self, type: str, data: Dict[str, Any]) -> None:
super().__init__(type=type, data=data)
@classmethod
def get_message_class(cls):
return Message
@overrides(BaseMessageSegment)
def __str__(self) -> str:
@ -37,7 +37,8 @@ class MessageSegment(BaseMessageSegment):
@overrides(BaseMessageSegment)
def __add__(self, other) -> "Message":
return Message(self) + other
return Message(self) + (MessageSegment.text(other) if isinstance(
other, str) else other)
@overrides(BaseMessageSegment)
def __radd__(self, other) -> "Message":
@ -234,10 +235,25 @@ class Message(BaseMessage[MessageSegment]):
CQHTTP 协议 Message 适配。
"""
def __radd__(self, other: Union[str, MessageSegment,
"Message"]) -> "Message":
result = MessageSegment.text(other) if isinstance(other, str) else other
return super(Message, self).__radd__(result)
@classmethod
def get_segment_class(cls):
return MessageSegment
@overrides(BaseMessage)
def __add__(
self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment,
"Message"]
) -> "Message":
return super(Message, self).__add__(
MessageSegment.text(other) if isinstance(other, str) else other)
@overrides(BaseMessage)
def __radd__(
self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment,
"Message"]
) -> "Message":
return super(Message, self).__radd__(
MessageSegment.text(other) if isinstance(other, str) else other)
@staticmethod
@overrides(BaseMessage)
@ -280,10 +296,6 @@ class Message(BaseMessage[MessageSegment]):
}
yield MessageSegment(type_, data)
@overrides(BaseMessage)
def extract_plain_text(self) -> str:
def _concat(x: str, y: MessageSegment) -> str:
return f"{x} {y.data['text']}" if y.is_text() else x
plain_text = reduce(_concat, self, "")
return plain_text[1:] if plain_text else plain_text
return "".join(seg.data["text"] for seg in self if seg.is_text())