🎨 fix message typing error

This commit is contained in:
yanyongyu 2021-06-17 01:07:19 +08:00
parent 6749afe75e
commit b2f21ab974
2 changed files with 30 additions and 46 deletions

View File

@ -233,12 +233,14 @@ class Bot(abc.ABC):
return func return func
T_Message = TypeVar("T_Message", bound="Message") T = TypeVar("T")
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment[Message]") TMS = TypeVar("TMS")
TM = TypeVar("TM", bound="Message")
# TM = TypeVar("TM_co", bound="Message")
@dataclass @dataclass
class MessageSegment(Mapping, abc.ABC, Generic[T_Message]): class MessageSegment(Mapping, abc.ABC, Generic[TM]):
"""消息段基类""" """消息段基类"""
type: str type: str
""" """
@ -253,7 +255,7 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
def get_message_class(cls) -> Type[T_Message]: def get_message_class(cls) -> Type[TM]:
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
@ -264,15 +266,13 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
def __len__(self) -> int: def __len__(self) -> int:
return len(str(self)) return len(str(self))
def __ne__(self: T_MessageSegment, other: T_MessageSegment) -> bool: def __ne__(self: T, other: T) -> bool:
return not self == other return not self == other
def __add__(self, other: Union[str, Mapping, def __add__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
Iterable[Mapping]]) -> T_Message:
return self.get_message_class()(self) + other return self.get_message_class()(self) + other
def __radd__(self, other: Union[str, Mapping, def __radd__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
Iterable[Mapping]]) -> T_Message:
return self.get_message_class()(other) + self return self.get_message_class()(other) + self
def __getitem__(self, key: str): def __getitem__(self, key: str):
@ -299,7 +299,7 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
def items(self): def items(self):
return self.data.items() return self.data.items()
def copy(self: T_MessageSegment) -> T_MessageSegment: def copy(self: T) -> T:
return deepcopy(self) return deepcopy(self)
@abc.abstractmethod @abc.abstractmethod
@ -307,12 +307,12 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
raise NotImplementedError raise NotImplementedError
class Message(List[T_MessageSegment], abc.ABC): class Message(List[TMS], abc.ABC):
"""消息数组""" """消息数组"""
def __init__(self: T_Message, def __init__(self: TM,
message: Union[str, None, Mapping, Iterable[Mapping], message: Union[str, None, Mapping, Iterable[Mapping], TMS, TM,
T_MessageSegment, T_Message, Any] = None, Any] = None,
*args, *args,
**kwargs): **kwargs):
""" """
@ -332,7 +332,7 @@ class Message(List[T_MessageSegment], abc.ABC):
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
def get_segment_class(cls) -> Type[T_MessageSegment]: def get_segment_class(cls) -> Type[TMS]:
raise NotImplementedError raise NotImplementedError
def __str__(self): def __str__(self):
@ -349,29 +349,19 @@ class Message(List[T_MessageSegment], abc.ABC):
@staticmethod @staticmethod
@abc.abstractmethod @abc.abstractmethod
def _construct( def _construct(
msg: Union[str, Mapping, Iterable[Mapping], Any] msg: Union[str, Mapping, Iterable[Mapping], Any]) -> Iterable[TMS]:
) -> Iterable[T_MessageSegment]:
raise NotImplementedError raise NotImplementedError
def __add__( def __add__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message]
) -> T_Message:
result = self.copy() result = self.copy()
result += other result += other
return result return result
def __radd__( def __radd__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message]
) -> T_Message:
result = self.__class__(other) result = self.__class__(other)
return result + self return result + self
def __iadd__( def __iadd__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message]
) -> T_Message:
if isinstance(other, MessageSegment): if isinstance(other, MessageSegment):
self.append(other) self.append(other)
elif isinstance(other, Message): elif isinstance(other, Message):
@ -380,7 +370,7 @@ class Message(List[T_MessageSegment], abc.ABC):
self.extend(self._construct(other)) self.extend(self._construct(other))
return self return self
def append(self: T_Message, obj: Union[str, T_MessageSegment]) -> T_Message: def append(self: TM, obj: Union[str, TMS]) -> TM:
""" """
:说明: :说明:
@ -398,8 +388,7 @@ class Message(List[T_MessageSegment], abc.ABC):
raise ValueError(f"Unexpected type: {type(obj)} {obj}") raise ValueError(f"Unexpected type: {type(obj)} {obj}")
return self return self
def extend(self: T_Message, def extend(self: TM, obj: Union[TM, Iterable[TMS]]) -> TM:
obj: Union[T_Message, Iterable[T_MessageSegment]]) -> T_Message:
""" """
:说明: :说明:
@ -413,10 +402,10 @@ class Message(List[T_MessageSegment], abc.ABC):
self.append(segment) self.append(segment)
return self return self
def copy(self: T_Message) -> T_Message: def copy(self: TM) -> TM:
return deepcopy(self) return deepcopy(self)
def extract_plain_text(self) -> str: def extract_plain_text(self: "Message[MessageSegment]") -> str:
""" """
:说明: :说明:

View File

@ -2,8 +2,7 @@ import re
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from base64 import b64encode from base64 import b64encode
from functools import reduce from typing import Type, Union, Tuple, Mapping, Iterable, Optional
from typing import Any, List, Dict, Union, Tuple, Mapping, Iterable, Optional
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
@ -17,7 +16,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
""" """
@classmethod @classmethod
def get_message_class(cls): def get_message_class(cls) -> Type["Message"]:
return Message return Message
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
@ -236,22 +235,18 @@ class Message(BaseMessage[MessageSegment]):
""" """
@classmethod @classmethod
def get_segment_class(cls): def get_segment_class(cls) -> Type[MessageSegment]:
return MessageSegment return MessageSegment
@overrides(BaseMessage) @overrides(BaseMessage)
def __add__( def __add__(self, other: Union[str, Mapping,
self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment, Iterable[Mapping]]) -> "Message":
"Message"]
) -> "Message":
return super(Message, self).__add__( return super(Message, self).__add__(
MessageSegment.text(other) if isinstance(other, str) else other) MessageSegment.text(other) if isinstance(other, str) else other)
@overrides(BaseMessage) @overrides(BaseMessage)
def __radd__( def __radd__(self, other: Union[str, Mapping,
self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment, Iterable[Mapping]]) -> "Message":
"Message"]
) -> "Message":
return super(Message, self).__radd__( return super(Message, self).__radd__(
MessageSegment.text(other) if isinstance(other, str) else other) MessageSegment.text(other) if isinstance(other, str) else other)