mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-30 22:46:40 +00:00 
			
		
		
		
	⚡ improve radd support for messagesegment
This commit is contained in:
		| @@ -9,7 +9,7 @@ import abc | |||||||
| from typing_extensions import Literal | from typing_extensions import Literal | ||||||
| from functools import reduce, partial | from functools import reduce, partial | ||||||
| from dataclasses import dataclass, field | from dataclasses import dataclass, field | ||||||
| from typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable, TYPE_CHECKING | from typing import Any, Dict, Union, TypeVar, Optional, Callable, Iterable, Awaitable, TYPE_CHECKING | ||||||
|  |  | ||||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||||
|  |  | ||||||
| @@ -267,6 +267,10 @@ class Event(abc.ABC, BaseModel): | |||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |  | ||||||
|  | T_Message = TypeVar("T_Message", bound="Message") | ||||||
|  | T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment") | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass | @dataclass | ||||||
| class MessageSegment(abc.ABC): | class MessageSegment(abc.ABC): | ||||||
|     """消息段基类""" |     """消息段基类""" | ||||||
| @@ -282,19 +286,34 @@ class MessageSegment(abc.ABC): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def __str__(self) -> str: |     def __str__(self: T_MessageSegment) -> str: | ||||||
|         """该消息段所代表的 str,在命令匹配部分使用""" |         """该消息段所代表的 str,在命令匹配部分使用""" | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def __add__(self, other) -> "Message": |     def __add__(self: T_MessageSegment, other: Union[str, T_MessageSegment, | ||||||
|  |                                                      T_Message]) -> "T_Message": | ||||||
|         """你需要在这里实现不同消息段的合并: |         """你需要在这里实现不同消息段的合并: | ||||||
|         比如: |         比如: | ||||||
|             if isinstance(other, str): |             if isinstance(other, str): | ||||||
|                 ... |                 ... | ||||||
|             elif isinstance(other, MessageSegment): |             elif isinstance(other, MessageSegment): | ||||||
|                 ... |                 ... | ||||||
|         注意:不能返回 self,需要返回一个新生成的对象 |         注意:需要返回一个新生成的对象 | ||||||
|  |         """ | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     @abc.abstractmethod | ||||||
|  |     def __radd__( | ||||||
|  |         self: T_MessageSegment, other: Union[str, dict, list, T_MessageSegment, | ||||||
|  |                                              T_Message]) -> "T_Message": | ||||||
|  |         """你需要在这里实现不同消息段的合并: | ||||||
|  |         比如: | ||||||
|  |             if isinstance(other, str): | ||||||
|  |                 ... | ||||||
|  |             elif isinstance(other, MessageSegment): | ||||||
|  |                 ... | ||||||
|  |         注意:需要返回一个新生成的对象 | ||||||
|         """ |         """ | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
| @@ -316,17 +335,17 @@ class Message(list, abc.ABC): | |||||||
|     """消息数组""" |     """消息数组""" | ||||||
|  |  | ||||||
|     def __init__(self, |     def __init__(self, | ||||||
|                  message: Union[str, dict, list, BaseModel, MessageSegment, |                  message: Union[str, dict, list, T_MessageSegment, | ||||||
|                                 "Message"] = None, |                                 T_Message] = None, | ||||||
|                  *args, |                  *args, | ||||||
|                  **kwargs): |                  **kwargs): | ||||||
|         """ |         """ | ||||||
|         :参数: |         :参数: | ||||||
|  |  | ||||||
|           * ``message: Union[str, dict, list, BaseModel, MessageSegment, Message]``: 消息内容 |           * ``message: Union[str, dict, list, MessageSegment, Message]``: 消息内容 | ||||||
|         """ |         """ | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
|         if isinstance(message, (str, dict, list, BaseModel)): |         if isinstance(message, (str, dict, list)): | ||||||
|             self.extend(self._construct(message)) |             self.extend(self._construct(message)) | ||||||
|         elif isinstance(message, Message): |         elif isinstance(message, Message): | ||||||
|             self.extend(message) |             self.extend(message) | ||||||
| @@ -347,11 +366,12 @@ class Message(list, abc.ABC): | |||||||
|     @staticmethod |     @staticmethod | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def _construct( |     def _construct( | ||||||
|             msg: Union[str, dict, list, BaseModel]) -> Iterable[MessageSegment]: |             msg: Union[str, dict, list, | ||||||
|  |                        BaseModel]) -> Iterable[T_MessageSegment]: | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def __add__(self, other: Union[str, MessageSegment, |     def __add__(self: T_Message, other: Union[str, T_MessageSegment, | ||||||
|                                    "Message"]) -> "Message": |                                               T_Message]) -> T_Message: | ||||||
|         result = self.__class__(self) |         result = self.__class__(self) | ||||||
|         if isinstance(other, str): |         if isinstance(other, str): | ||||||
|             result.extend(self._construct(other)) |             result.extend(self._construct(other)) | ||||||
| @@ -361,11 +381,12 @@ class Message(list, abc.ABC): | |||||||
|             result.extend(other) |             result.extend(other) | ||||||
|         return result |         return result | ||||||
|  |  | ||||||
|     def __radd__(self, other: Union[str, MessageSegment, "Message"]): |     def __radd__(self: T_Message, other: Union[str, T_MessageSegment, | ||||||
|  |                                                T_Message]): | ||||||
|         result = self.__class__(other) |         result = self.__class__(other) | ||||||
|         return result.__add__(self) |         return result.__add__(self) | ||||||
|  |  | ||||||
|     def append(self, obj: Union[str, MessageSegment]) -> "Message": |     def append(self: T_Message, obj: Union[str, T_MessageSegment]) -> T_Message: | ||||||
|         """ |         """ | ||||||
|         :说明: |         :说明: | ||||||
|  |  | ||||||
| @@ -383,8 +404,8 @@ class Message(list, abc.ABC): | |||||||
|             raise ValueError(f"Unexpected type: {type(obj)} {obj}") |             raise ValueError(f"Unexpected type: {type(obj)} {obj}") | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|     def extend(self, obj: Union["Message", |     def extend(self: T_Message, | ||||||
|                                 Iterable[MessageSegment]]) -> "Message": |                obj: Union[T_Message, Iterable[T_MessageSegment]]) -> T_Message: | ||||||
|         """ |         """ | ||||||
|         :说明: |         :说明: | ||||||
|  |  | ||||||
| @@ -398,7 +419,7 @@ class Message(list, abc.ABC): | |||||||
|             self.append(segment) |             self.append(segment) | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|     def reduce(self) -> None: |     def reduce(self: T_Message) -> None: | ||||||
|         """ |         """ | ||||||
|         :说明: |         :说明: | ||||||
|  |  | ||||||
| @@ -413,14 +434,14 @@ class Message(list, abc.ABC): | |||||||
|             else: |             else: | ||||||
|                 index += 1 |                 index += 1 | ||||||
|  |  | ||||||
|     def extract_plain_text(self) -> str: |     def extract_plain_text(self: T_Message) -> str: | ||||||
|         """ |         """ | ||||||
|         :说明: |         :说明: | ||||||
|  |  | ||||||
|           提取消息内纯文本消息 |           提取消息内纯文本消息 | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         def _concat(x: str, y: MessageSegment) -> str: |         def _concat(x: str, y: T_MessageSegment) -> str: | ||||||
|             return f"{x} {y}" if y.is_text() else x |             return f"{x} {y}" if y.is_text() else x | ||||||
|  |  | ||||||
|         plain_text = reduce(_concat, self, "") |         plain_text = reduce(_concat, self, "") | ||||||
|   | |||||||
| @@ -35,6 +35,10 @@ class MessageSegment(BaseMessageSegment): | |||||||
|     def __add__(self, other) -> "Message": |     def __add__(self, other) -> "Message": | ||||||
|         return Message(self) + other |         return Message(self) + other | ||||||
|  |  | ||||||
|  |     @overrides(BaseMessageSegment) | ||||||
|  |     def __radd__(self, other) -> "Message": | ||||||
|  |         return Message(other) + self | ||||||
|  |  | ||||||
|     @overrides(BaseMessageSegment) |     @overrides(BaseMessageSegment) | ||||||
|     def is_text(self) -> bool: |     def is_text(self) -> bool: | ||||||
|         return self.type == "text" |         return self.type == "text" | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user