mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-11-03 16:36:44 +00:00 
			
		
		
		
	🚸 add factory classmethods in MessageSegment at mirai adapter
This commit is contained in:
		@@ -10,6 +10,7 @@ from nonebot.adapters import Event as BaseEvent
 | 
			
		||||
from nonebot.config import Config
 | 
			
		||||
from nonebot.drivers import Driver, WebSocket
 | 
			
		||||
from nonebot.exception import RequestDenied
 | 
			
		||||
from nonebot.exception import ActionFailed as BaseActionFailed
 | 
			
		||||
from nonebot.log import logger
 | 
			
		||||
from nonebot.message import handle_event
 | 
			
		||||
from nonebot.typing import overrides
 | 
			
		||||
@@ -19,6 +20,17 @@ from .event import Event, FriendMessage, GroupMessage, TempMessage
 | 
			
		||||
from .message import MessageChain, MessageSegment
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ActionFailed(BaseActionFailed):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, code: int, message: str = ''):
 | 
			
		||||
        super().__init__('mirai')
 | 
			
		||||
        self.code = code
 | 
			
		||||
        self.message = message
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f"{self.__class__.__name__}(code={self.code}, message={self.message!r})"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SessionManager:
 | 
			
		||||
    sessions: Dict[int, Tuple[str, datetime, httpx.AsyncClient]] = {}
 | 
			
		||||
    session_expiry: timedelta = timedelta(minutes=15)
 | 
			
		||||
@@ -26,14 +38,22 @@ class SessionManager:
 | 
			
		||||
    def __init__(self, session_key: str, client: httpx.AsyncClient):
 | 
			
		||||
        self.session_key, self.client = session_key, client
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
 | 
			
		||||
        code = data.get('code', 0)
 | 
			
		||||
        logger.debug(f'Mirai API returned data: {data}')
 | 
			
		||||
        if code != 0:
 | 
			
		||||
            raise ActionFailed(code, message=data['msg'])
 | 
			
		||||
        return data
 | 
			
		||||
 | 
			
		||||
    async def post(self,
 | 
			
		||||
                   path: str,
 | 
			
		||||
                   *,
 | 
			
		||||
                   params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
 | 
			
		||||
        params = {**(params or {}), 'sessionKey': self.session_key}
 | 
			
		||||
        response = await self.client.post(path, json=params)
 | 
			
		||||
        response = await self.client.post(path, json=params, timeout=3)
 | 
			
		||||
        response.raise_for_status()
 | 
			
		||||
        return response.json()
 | 
			
		||||
        return self._raise_code(response.json())
 | 
			
		||||
 | 
			
		||||
    async def request(self,
 | 
			
		||||
                      path: str,
 | 
			
		||||
@@ -44,9 +64,10 @@ class SessionManager:
 | 
			
		||||
                                         params={
 | 
			
		||||
                                             **(params or {}), 'sessionKey':
 | 
			
		||||
                                                 self.session_key
 | 
			
		||||
                                         })
 | 
			
		||||
                                         },
 | 
			
		||||
                                         timeout=3)
 | 
			
		||||
        response.raise_for_status()
 | 
			
		||||
        return response.json()
 | 
			
		||||
        return self._raise_code(response.json())
 | 
			
		||||
 | 
			
		||||
    async def upload(self, path: str, *, type: str,
 | 
			
		||||
                     file: Tuple[str, BytesIO]) -> Dict[str, Any]:
 | 
			
		||||
@@ -59,7 +80,7 @@ class SessionManager:
 | 
			
		||||
                                          files={file_type: file_io},
 | 
			
		||||
                                          timeout=6)
 | 
			
		||||
        response.raise_for_status()
 | 
			
		||||
        return response.json()
 | 
			
		||||
        return self._raise_code(response.json())
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    async def new(cls, self_id: int, *, host: IPv4Address, port: int,
 | 
			
		||||
@@ -152,7 +173,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
    async def __getattr__(self, key: str) -> NoReturn:
 | 
			
		||||
    def __getattr__(self, key: str) -> NoReturn:
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
@@ -165,8 +186,10 @@ class MiraiBot(BaseBot):
 | 
			
		||||
            return await self.send_friend_message(target=event.sender.id,
 | 
			
		||||
                                                  message_chain=message)
 | 
			
		||||
        elif isinstance(event, GroupMessage):
 | 
			
		||||
            return await self.send_group_message(target=event.sender.group.id,
 | 
			
		||||
                                                 message_chain=message)
 | 
			
		||||
            return await self.send_group_message(
 | 
			
		||||
                group=event.sender.group.id,
 | 
			
		||||
                message_chain=message if not at_sender else
 | 
			
		||||
                (MessageSegment.at(target=event.sender.id) + message))
 | 
			
		||||
        elif isinstance(event, TempMessage):
 | 
			
		||||
            return await self.send_temp_message(qq=event.sender.id,
 | 
			
		||||
                                                group=event.sender.group.id,
 | 
			
		||||
@@ -191,12 +214,15 @@ class MiraiBot(BaseBot):
 | 
			
		||||
                                       'messageChain': message_chain.export()
 | 
			
		||||
                                   })
 | 
			
		||||
 | 
			
		||||
    async def send_group_message(self, target: int,
 | 
			
		||||
                                 message_chain: MessageChain):
 | 
			
		||||
    async def send_group_message(self,
 | 
			
		||||
                                 group: int,
 | 
			
		||||
                                 message_chain: MessageChain,
 | 
			
		||||
                                 quote: Optional[int] = None):
 | 
			
		||||
        return await self.api.post('sendGroupMessage',
 | 
			
		||||
                                   params={
 | 
			
		||||
                                       'target': target,
 | 
			
		||||
                                       'messageChain': message_chain.export()
 | 
			
		||||
                                       'group': group,
 | 
			
		||||
                                       'messageChain': message_chain.export(),
 | 
			
		||||
                                       'quote': quote
 | 
			
		||||
                                   })
 | 
			
		||||
 | 
			
		||||
    async def recall(self, target: int):
 | 
			
		||||
 
 | 
			
		||||
@@ -18,7 +18,7 @@ class MessageEvent(Event):
 | 
			
		||||
 | 
			
		||||
    @overrides(Event)
 | 
			
		||||
    def get_plaintext(self) -> str:
 | 
			
		||||
        return self.message_chain.__str__()
 | 
			
		||||
        return self.message_chain.extract_plain_text()
 | 
			
		||||
 | 
			
		||||
    @overrides(Event)
 | 
			
		||||
    def get_user_id(self) -> str:
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,5 @@
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from typing import Any, Dict, Iterable, List, Union
 | 
			
		||||
from typing import Any, Dict, Iterable, List, Optional, Union
 | 
			
		||||
 | 
			
		||||
from pydantic import validate_arguments
 | 
			
		||||
 | 
			
		||||
@@ -31,7 +31,8 @@ class MessageSegment(BaseMessageSegment):
 | 
			
		||||
    @overrides(BaseMessageSegment)
 | 
			
		||||
    @validate_arguments
 | 
			
		||||
    def __init__(self, type: MessageType, **data):
 | 
			
		||||
        super().__init__(type=type, data=data)
 | 
			
		||||
        super().__init__(type=type,
 | 
			
		||||
                         data={k: v for k, v in data.items() if v is not None})
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseMessageSegment)
 | 
			
		||||
    def __str__(self) -> str:
 | 
			
		||||
@@ -60,6 +61,79 @@ class MessageSegment(BaseMessageSegment):
 | 
			
		||||
    def as_dict(self) -> Dict[str, Any]:
 | 
			
		||||
        return {'type': self.type.value, **self.data}
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def source(cls, id: int, time: int):
 | 
			
		||||
        return cls(type=MessageType.SOURCE, id=id, time=time)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def quote(cls, id: int, group_id: int, sender_id: int, target_id: int,
 | 
			
		||||
              origin: "MessageChain"):
 | 
			
		||||
        return cls(type=MessageType.QUOTE,
 | 
			
		||||
                   id=id,
 | 
			
		||||
                   groupId=group_id,
 | 
			
		||||
                   senderId=sender_id,
 | 
			
		||||
                   targetId=target_id,
 | 
			
		||||
                   origin=origin.export())
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def at(cls, target: int):
 | 
			
		||||
        return cls(type=MessageType.AT, target=target)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def at_all(cls):
 | 
			
		||||
        return cls(type=MessageType.AT_ALL)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def face(cls, face_id: Optional[int] = None, name: Optional[str] = None):
 | 
			
		||||
        return cls(type=MessageType.FACE, faceId=face_id, name=name)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def plain(cls, text: str):
 | 
			
		||||
        return cls(type=MessageType.PLAIN, text=text)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def image(cls,
 | 
			
		||||
              image_id: Optional[str] = None,
 | 
			
		||||
              url: Optional[str] = None,
 | 
			
		||||
              path: Optional[str] = None):
 | 
			
		||||
        return cls(type=MessageType.IMAGE, imageId=image_id, url=url, path=path)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def flash_image(cls,
 | 
			
		||||
                    image_id: Optional[str] = None,
 | 
			
		||||
                    url: Optional[str] = None,
 | 
			
		||||
                    path: Optional[str] = None):
 | 
			
		||||
        return cls(type=MessageType.FLASH_IMAGE,
 | 
			
		||||
                   imageId=image_id,
 | 
			
		||||
                   url=url,
 | 
			
		||||
                   path=path)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def voice(cls,
 | 
			
		||||
              voice_id: Optional[str] = None,
 | 
			
		||||
              url: Optional[str] = None,
 | 
			
		||||
              path: Optional[str] = None):
 | 
			
		||||
        return cls(type=MessageType.FLASH_IMAGE,
 | 
			
		||||
                   imageId=voice_id,
 | 
			
		||||
                   url=url,
 | 
			
		||||
                   path=path)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def xml(cls, xml: str):
 | 
			
		||||
        return cls(type=MessageType.XML, xml=xml)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def json(cls, json: str):
 | 
			
		||||
        return cls(type=MessageType.JSON, json=json)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def app(cls, content: str):
 | 
			
		||||
        return cls(type=MessageType.APP, content=content)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def poke(cls, name: str):
 | 
			
		||||
        return cls(type=MessageType.POKE, name=name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MessageChain(BaseMessage):
 | 
			
		||||
 | 
			
		||||
@@ -90,11 +164,9 @@ class MessageChain(BaseMessage):
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    def export(self) -> List[Dict[str, Any]]:
 | 
			
		||||
        chain: List[Dict[str, Any]] = []
 | 
			
		||||
        for segment in self.copy():
 | 
			
		||||
            segment: MessageSegment
 | 
			
		||||
            chain.append({'type': segment.type.value, **segment.data})
 | 
			
		||||
        return chain
 | 
			
		||||
        return [
 | 
			
		||||
            *map(lambda segment: segment.as_dict(), self.copy())  # type: ignore
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        return f'<{self.__class__.__name__} {[*self.copy()]}>'
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user