mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-11-04 08:56:42 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			250 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			250 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import json
 | 
						|
import itertools
 | 
						|
 | 
						|
from dataclasses import dataclass
 | 
						|
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Mapping, Iterable
 | 
						|
 | 
						|
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
 | 
						|
from nonebot.typing import overrides
 | 
						|
 | 
						|
 | 
						|
class MessageSegment(BaseMessageSegment["Message"]):
    """
    Feishu protocol adapter for ``MessageSegment``.

    Each segment is a ``(type, data)`` pair. The static factory methods below
    build segments of the individual types; see the protocol's message
    segment types or the source for details.
    """

    @classmethod
    @overrides(BaseMessageSegment)
    def get_message_class(cls) -> Type["Message"]:
        """Return the ``Message`` class that pairs with this segment class."""
        return Message

    @property
    def segment_text(self) -> dict:
        """Placeholder text used by ``__str__`` for non-textual segment types."""
        return {
            "image": "[图片]",
            "file": "[文件]",
            "audio": "[音频]",
            "media": "[视频]",
            "sticker": "[表情包]",
            "interactive": "[消息卡片]",
            "hongbao": "[红包]",
            "share_calendar_event": "[日程卡片]",
            "share_chat": "[群名片]",
            "share_user": "[个人名片]",
            "system": "[系统消息]",
            "location": "[位置]",
            "video_chat": "[视频通话]",
        }

    def __str__(self) -> str:
        """
        Render the segment as plain text.

        * ``text`` / ``hongbao`` / ``a`` segments yield their ``text`` field.
        * ``at`` segments yield ``@<user_name>``; when no ``user_name`` is
          stored the ``user_id`` is used instead.
        * every other type yields its placeholder from ``segment_text`` (or an
          empty string when the type has no placeholder).
        """
        if self.type in ["text", "hongbao", "a"]:
            return str(self.data["text"])
        elif self.type == "at":
            # Segments built through MessageSegment.at() only carry "user_id",
            # so fall back to it instead of raising KeyError on "user_name".
            mentioned = self.data.get("user_name",
                                      self.data.get("user_id", ""))
            return f"@{mentioned}"
        else:
            return self.segment_text.get(self.type, "")

    @overrides(BaseMessageSegment)
    def __add__(self, other) -> "Message":
        """Return ``self + other`` as a ``Message``; a ``str`` becomes a text segment."""
        return Message(self) + (MessageSegment.text(other) if isinstance(
            other, str) else other)

    @overrides(BaseMessageSegment)
    def __radd__(self, other) -> "Message":
        """Return ``other + self`` as a ``Message``; a ``str`` becomes a text segment."""
        return (MessageSegment.text(other)
                if isinstance(other, str) else Message(other)) + self

    @overrides(BaseMessageSegment)
    def is_text(self) -> bool:
        """Return ``True`` only for segments whose type is ``"text"``."""
        return self.type == "text"

    # Received messages
    @staticmethod
    def at(user_id: str) -> "MessageSegment":
        """Build an ``at`` segment holding ``user_id``."""
        return MessageSegment("at", {"user_id": user_id})

    # Outgoing messages
    @staticmethod
    def text(text: str) -> "MessageSegment":
        """Build a ``text`` segment holding ``text``."""
        return MessageSegment("text", {"text": text})

    @staticmethod
    def post(title: str, content: list) -> "MessageSegment":
        """Build a ``post`` segment holding ``title`` and ``content``."""
        return MessageSegment("post", {"title": title, "content": content})

    @staticmethod
    def image(image_key: str) -> "MessageSegment":
        """Build an ``image`` segment holding ``image_key``."""
        return MessageSegment("image", {"image_key": image_key})

    @staticmethod
    def interactive(title: str, elements: list) -> "MessageSegment":
        """Build an ``interactive`` segment holding ``title`` and ``elements``."""
        return MessageSegment("interactive", {
            "title": title,
            "elements": elements
        })

    @staticmethod
    def share_chat(chat_id: str) -> "MessageSegment":
        """Build a ``share_chat`` segment holding ``chat_id``."""
        return MessageSegment("share_chat", {"chat_id": chat_id})

    @staticmethod
    def share_user(user_id: str) -> "MessageSegment":
        """Build a ``share_user`` segment holding ``user_id``."""
        return MessageSegment("share_user", {"user_id": user_id})

    @staticmethod
    def audio(file_key: str, duration: int) -> "MessageSegment":
        """Build an ``audio`` segment holding ``file_key`` and ``duration``."""
        return MessageSegment("audio", {
            "file_key": file_key,
            "duration": duration
        })

    @staticmethod
    def media(file_key: str, image_key: str, file_name: str,
              duration: int) -> "MessageSegment":
        """Build a ``media`` segment from its file key, image key, file name and duration."""
        return MessageSegment(
            "media", {
                "file_key": file_key,
                "image_key": image_key,
                "file_name": file_name,
                "duration": duration
            })

    @staticmethod
    def file(file_key: str, file_name: str) -> "MessageSegment":
        """Build a ``file`` segment holding ``file_key`` and ``file_name``."""
        return MessageSegment("file", {
            "file_key": file_key,
            "file_name": file_name
        })

    @staticmethod
    def sticker(file_key: str) -> "MessageSegment":
        """Build a ``sticker`` segment holding ``file_key``."""
        return MessageSegment("sticker", {"file_key": file_key})
 | 
						|
 | 
						|
 | 
						|
class Message(BaseMessage[MessageSegment]):
    """
    Feishu protocol adapter for ``Message``.

    A sequence of ``MessageSegment`` objects with helpers to build one from
    plain strings or mappings, merge adjacent text and extract plain text.
    """

    @classmethod
    @overrides(BaseMessage)
    def get_segment_class(cls) -> Type[MessageSegment]:
        """Return the segment class this message type is made of."""
        return MessageSegment

    @overrides(BaseMessage)
    def __add__(self, other: Union[str, Mapping,
                                   Iterable[Mapping]]) -> "Message":
        """Return ``self + other``; a ``str`` operand is wrapped in a text segment."""
        if isinstance(other, str):
            other = MessageSegment.text(other)
        return super().__add__(other)

    @overrides(BaseMessage)
    def __radd__(self, other: Union[str, Mapping,
                                    Iterable[Mapping]]) -> "Message":
        """Return ``other + self``; a ``str`` operand is wrapped in a text segment."""
        if isinstance(other, str):
            other = MessageSegment.text(other)
        return super().__radd__(other)

    @staticmethod
    @overrides(BaseMessage)
    def _construct(
        msg: Union[str, Mapping,
                   Iterable[Mapping]]) -> Iterable[MessageSegment]:
        """
        Yield the segments described by ``msg``.

        * a mapping yields one segment built from its ``"type"`` and ``"data"``
          entries (missing or falsy ``"data"`` becomes ``{}``);
        * a ``str`` yields one text segment;
        * any other iterable yields each item, converting items that are not
          already ``MessageSegment`` instances the same way as a mapping.
        """
        if isinstance(msg, Mapping):
            yield MessageSegment(msg["type"], msg.get("data") or {})
            return

        if isinstance(msg, str):
            yield MessageSegment.text(msg)
            return

        if isinstance(msg, Iterable):
            for item in msg:
                if not isinstance(item, MessageSegment):
                    item = MessageSegment(item["type"],
                                          item.get("data") or {})
                yield item

    def _merge(self) -> "Message":
        """Return a new ``Message`` in which runs of adjacent text segments are joined."""
        merged: List[MessageSegment] = []
        for current in self:
            # ``merged`` is empty only for the very first segment.
            follows_text = bool(merged) and merged[-1].type == "text"
            if current.type == "text" and follows_text:
                joined = merged[-1].data["text"] + current.data["text"]
                merged[-1] = MessageSegment("text", {"text": joined})
                continue
            merged.append(current)
        return Message(merged)

    @overrides(BaseMessage)
    def extract_plain_text(self) -> str:
        """Concatenate the ``text`` field of every text segment."""
        pieces = [seg.data["text"] for seg in self if seg.is_text()]
        return "".join(pieces)
 | 
						|
 | 
						|
 | 
						|
@dataclass
class MessageSerializer:
    """
    Serializer for Feishu protocol ``Message`` objects.
    """
    # The message to serialize; iterated as a sequence of segments that
    # expose ``.type`` and ``.data``.
    message: "Message"

    def serialize(self) -> Tuple[str, str]:
        """
        Turn the message into a ``(msg_type, json_content)`` pair.

        A message with a single segment is sent as that segment's own type,
        with its ``data`` dumped as JSON. A message with several segments is
        sent as a ``"post"`` whose ``content`` is a list of paragraphs: runs
        of image segments and runs of non-image segments are placed in
        separate paragraphs, and ``image`` segments are emitted with the tag
        ``"img"``.

        Raises:
            IndexError: if the message contains no segment at all.
        """
        segments = list(self.message)

        if len(segments) <= 1:
            only = segments[0]
            return only.type, json.dumps(only.data)

        paragraphs: List[list] = [[]]
        previous_type: str = ""
        for segment in segments:
            is_image = segment.type == "image"
            was_image = previous_type == "image"
            # Open a new paragraph whenever we cross an image/non-image
            # boundary. The non-empty check keeps a leading image from
            # producing an empty first paragraph.
            if paragraphs[-1] and is_image != was_image:
                paragraphs.append([])
            paragraphs[-1].append({
                "tag": "img" if is_image else segment.type,
                **segment.data
            })
            previous_type = segment.type

        post = {"title": "", "content": paragraphs}
        return "post", json.dumps({"zh_cn": post})
 | 
						|
 | 
						|
 | 
						|
@dataclass
class MessageDeserializer:
    """
    Deserializer for Feishu protocol ``Message`` objects.
    """
    # Message type; "post" and "text" get special handling, every other value
    # is wrapped as a single segment of that type.
    type: str
    # Decoded message content for ``type``.
    data: Dict[str, Any]
    # Optional mention records; each is read through its "key", "name" and
    # (for posts) ["id"]["open_id"] entries.
    mentions: Optional[List[dict]]

    def deserialize(self) -> Message:
        """
        Build a ``Message`` from the raw ``type`` / ``data`` / ``mentions``.

        * ``post``: the title (when non-empty) becomes a leading text segment,
          then every element of every paragraph in ``data["content"]`` becomes
          a segment (tag ``"img"`` is mapped to type ``"image"``). ``at``
          elements get ``user_name`` and ``user_id`` resolved from the
          matching mention. Adjacent text segments are merged.
        * ``text``: each mention key found in the text is replaced by
          ``@<name>`` and the mention table is stored under ``"mentions"``.
        * anything else: a single segment of ``type`` carrying ``data``.

        The instance's ``data`` is not modified, so the method can be called
        more than once.

        Raises:
            KeyError: if a post ``at`` element refers to a key that is not in
                ``mentions``.
        """
        mention_by_key = {
            mention["key"]: mention for mention in self.mentions or []
        }

        if self.type == "post":
            msg = Message()
            title = self.data.get("title")
            if title:
                msg += MessageSegment("text", {'text': title})

            for element in itertools.chain.from_iterable(
                    self.data["content"]):
                # Work on a copy so popping "tag" does not corrupt self.data.
                seg = dict(element)
                tag = seg.pop("tag")
                if tag == "at":
                    mention = mention_by_key[seg["user_id"]]
                    seg["user_name"] = mention["name"]
                    seg["user_id"] = mention["id"]["open_id"]

                msg += MessageSegment(tag if tag != "img" else "image", seg)

            return msg._merge()

        if self.type == "text":
            data = dict(self.data)
            text = data["text"]
            # Replace longer keys first: otherwise a key that is a prefix of
            # another one (e.g. "@_user_1" vs "@_user_10") would clobber it.
            for key in sorted(mention_by_key, key=len, reverse=True):
                text = text.replace(key, f"@{mention_by_key[key]['name']}")
            data["text"] = text
            data["mentions"] = mention_by_key

            return Message(MessageSegment(self.type, data))

        return Message(MessageSegment(self.type, self.data))
 |