🎨 fix message typing error

This commit is contained in:
yanyongyu
2021-06-17 01:07:19 +08:00
parent 6749afe75e
commit b2f21ab974
2 changed files with 30 additions and 46 deletions

View File

@ -233,12 +233,14 @@ class Bot(abc.ABC):
return func
T_Message = TypeVar("T_Message", bound="Message")
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment[Message]")
T = TypeVar("T")
TMS = TypeVar("TMS")
TM = TypeVar("TM", bound="Message")
# TM = TypeVar("TM_co", bound="Message")
@dataclass
class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
class MessageSegment(Mapping, abc.ABC, Generic[TM]):
"""消息段基类"""
type: str
"""
@ -253,7 +255,7 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
@classmethod
@abc.abstractmethod
def get_message_class(cls) -> Type[T_Message]:
def get_message_class(cls) -> Type[TM]:
raise NotImplementedError
@abc.abstractmethod
@ -264,15 +266,13 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
def __len__(self) -> int:
return len(str(self))
def __ne__(self: T_MessageSegment, other: T_MessageSegment) -> bool:
def __ne__(self: T, other: T) -> bool:
return not self == other
def __add__(self, other: Union[str, Mapping,
Iterable[Mapping]]) -> T_Message:
def __add__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
return self.get_message_class()(self) + other
def __radd__(self, other: Union[str, Mapping,
Iterable[Mapping]]) -> T_Message:
def __radd__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
return self.get_message_class()(other) + self
def __getitem__(self, key: str):
@ -299,7 +299,7 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
def items(self):
return self.data.items()
def copy(self: T_MessageSegment) -> T_MessageSegment:
def copy(self: T) -> T:
return deepcopy(self)
@abc.abstractmethod
@ -307,12 +307,12 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
raise NotImplementedError
class Message(List[T_MessageSegment], abc.ABC):
class Message(List[TMS], abc.ABC):
"""消息数组"""
def __init__(self: T_Message,
message: Union[str, None, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message, Any] = None,
def __init__(self: TM,
message: Union[str, None, Mapping, Iterable[Mapping], TMS, TM,
Any] = None,
*args,
**kwargs):
"""
@ -332,7 +332,7 @@ class Message(List[T_MessageSegment], abc.ABC):
@classmethod
@abc.abstractmethod
def get_segment_class(cls) -> Type[T_MessageSegment]:
def get_segment_class(cls) -> Type[TMS]:
raise NotImplementedError
def __str__(self):
@ -349,29 +349,19 @@ class Message(List[T_MessageSegment], abc.ABC):
@staticmethod
@abc.abstractmethod
def _construct(
msg: Union[str, Mapping, Iterable[Mapping], Any]
) -> Iterable[T_MessageSegment]:
msg: Union[str, Mapping, Iterable[Mapping], Any]) -> Iterable[TMS]:
raise NotImplementedError
def __add__(
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message]
) -> T_Message:
def __add__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
result = self.copy()
result += other
return result
def __radd__(
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message]
) -> T_Message:
def __radd__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
result = self.__class__(other)
return result + self
def __iadd__(
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message]
) -> T_Message:
def __iadd__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
if isinstance(other, MessageSegment):
self.append(other)
elif isinstance(other, Message):
@ -380,7 +370,7 @@ class Message(List[T_MessageSegment], abc.ABC):
self.extend(self._construct(other))
return self
def append(self: T_Message, obj: Union[str, T_MessageSegment]) -> T_Message:
def append(self: TM, obj: Union[str, TMS]) -> TM:
"""
:说明:
@ -398,8 +388,7 @@ class Message(List[T_MessageSegment], abc.ABC):
raise ValueError(f"Unexpected type: {type(obj)} {obj}")
return self
def extend(self: T_Message,
obj: Union[T_Message, Iterable[T_MessageSegment]]) -> T_Message:
def extend(self: TM, obj: Union[TM, Iterable[TMS]]) -> TM:
"""
:说明:
@ -413,10 +402,10 @@ class Message(List[T_MessageSegment], abc.ABC):
self.append(segment)
return self
def copy(self: T_Message) -> T_Message:
def copy(self: TM) -> TM:
return deepcopy(self)
def extract_plain_text(self) -> str:
def extract_plain_text(self: "Message[MessageSegment]") -> str:
"""
:说明: