implement at parser

This commit is contained in:
StarHeartHunt
2021-07-08 22:30:39 +08:00
parent 15ab958a70
commit 67770ffa6f
3 changed files with 18 additions and 6 deletions

View File

@ -293,7 +293,7 @@ class Bot(BaseBot):
at_sender = at_sender and bool(event.get_user_id()) at_sender = at_sender and bool(event.get_user_id())
if at_sender and receive_id_type != "union_id": if at_sender and receive_id_type != "union_id":
msg += MessageSegment.at(event.get_user_id(), "StarHeart") + " " msg = MessageSegment.at(event.get_user_id(), "StarHeart") + " " + msg
msg_type, content = MessageSerializer(msg).serialize() msg_type, content = MessageSerializer(msg).serialize()

View File

@ -97,10 +97,9 @@ class EventMessage(BaseModel):
@root_validator(pre=True) @root_validator(pre=True)
def parse_message(cls, values: dict): def parse_message(cls, values: dict):
#TODO:解析mentions替换message的user_id传入deserializer
values["content"] = MessageDeserializer( values["content"] = MessageDeserializer(
values["message_type"], values["message_type"], json.loads(values["content"]),
json.loads(values["content"])).deserialize() values.get("mentions")).deserialize()
return values return values
@ -154,7 +153,7 @@ class MessageEvent(Event):
@overrides(Event) @overrides(Event)
def get_user_id(self) -> str: def get_user_id(self) -> str:
return self.event.sender.sender_id.user_id return self.event.sender.sender_id.open_id
@overrides(Event) @overrides(Event)
def get_session_id(self) -> str: def get_session_id(self) -> str:

View File

@ -2,7 +2,7 @@ import json
import itertools import itertools
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Type, Union, Mapping, Iterable from typing import Any, Dict, List, Optional, Tuple, Type, Union, Mapping, Iterable
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
from nonebot.typing import overrides from nonebot.typing import overrides
@ -214,15 +214,28 @@ class MessageDeserializer:
""" """
type: str type: str
data: Dict[str, Any] data: Dict[str, Any]
mentions: Optional[List[dict]]
def deserialize(self) -> Message: def deserialize(self) -> Message:
dict_mention = {}
if self.type == "post": if self.type == "post":
if self.mentions:
for mention in self.mentions:
dict_mention[mention["key"]] = mention
msg = Message() msg = Message()
if self.data["title"] != "": if self.data["title"] != "":
msg += MessageSegment("text", {'text': self.data["title"]}) msg += MessageSegment("text", {'text': self.data["title"]})
for seg in itertools.chain(*self.data["content"]): for seg in itertools.chain(*self.data["content"]):
tag = seg.pop("tag") tag = seg.pop("tag")
if tag == "at":
seg["user_name"] = dict_mention[seg["user_id"]]["name"]
seg["user_id"] = dict_mention[
seg["user_id"]]["id"]["open_id"]
msg += MessageSegment(tag if tag != "img" else "image", seg) msg += MessageSegment(tag if tag != "img" else "image", seg)
return msg._merge() return msg._merge()
else: else: