mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-11-03 16:36:44 +00:00 
			
		
		
		
	💡 🚸 complete comments and optimize usage in mirai adapter
This commit is contained in:
		@@ -4,8 +4,18 @@ Mirai-API-HTTP 协议适配
 | 
			
		||||
 | 
			
		||||
协议详情请看: `mirai-api-http 文档`_ 
 | 
			
		||||
 | 
			
		||||
.. mirai-api-http 文档:
 | 
			
		||||
\:\:\: tip
 | 
			
		||||
该Adapter目前仍然处在早期实验性阶段, 并未经过充分测试
 | 
			
		||||
 | 
			
		||||
如果你在使用过程中遇到了任何问题, 请前往 `Issue页面`_ 为我们提供反馈
 | 
			
		||||
\:\:\:
 | 
			
		||||
 | 
			
		||||
.. _mirai-api-http 文档:
 | 
			
		||||
    https://github.com/project-mirai/mirai-api-http/tree/master/docs
 | 
			
		||||
 | 
			
		||||
.. _Issue页面
 | 
			
		||||
    https://github.com/nonebot/nonebot2/issues
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
from .bot import MiraiBot
 | 
			
		||||
 
 | 
			
		||||
@@ -1,34 +1,23 @@
 | 
			
		||||
from datetime import datetime, timedelta
 | 
			
		||||
from functools import wraps
 | 
			
		||||
from io import BytesIO
 | 
			
		||||
from ipaddress import IPv4Address
 | 
			
		||||
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
 | 
			
		||||
from typing import (Any, Dict, List, NoReturn, Optional, Tuple, Union)
 | 
			
		||||
 | 
			
		||||
import httpx
 | 
			
		||||
 | 
			
		||||
from nonebot.adapters import Bot as BaseBot
 | 
			
		||||
from nonebot.config import Config
 | 
			
		||||
from nonebot.drivers import Driver, WebSocket
 | 
			
		||||
from nonebot.exception import ActionFailed as BaseActionFailed
 | 
			
		||||
from nonebot.exception import RequestDenied
 | 
			
		||||
from nonebot.exception import ApiNotAvailable, RequestDenied
 | 
			
		||||
from nonebot.log import logger
 | 
			
		||||
from nonebot.message import handle_event
 | 
			
		||||
from nonebot.typing import overrides
 | 
			
		||||
from nonebot.utils import escape_tag
 | 
			
		||||
 | 
			
		||||
from .config import Config as MiraiConfig
 | 
			
		||||
from .event import Event, FriendMessage, GroupMessage, TempMessage
 | 
			
		||||
from .message import MessageChain, MessageSegment
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ActionFailed(BaseActionFailed):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, **kwargs):
 | 
			
		||||
        super().__init__('mirai')
 | 
			
		||||
        self.data = kwargs.copy()
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return self.__class__.__name__ + '(%s)' % ', '.join(
 | 
			
		||||
            map(lambda m: '%s=%r' % m, self.data.items()))
 | 
			
		||||
from .utils import catch_network_error, argument_validation
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SessionManager:
 | 
			
		||||
@@ -39,19 +28,11 @@ class SessionManager:
 | 
			
		||||
    def __init__(self, session_key: str, client: httpx.AsyncClient):
 | 
			
		||||
        self.session_key, self.client = session_key, client
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
 | 
			
		||||
        logger.opt(colors=True).debug(
 | 
			
		||||
            f'Mirai API returned data: <y>{escape_tag(str(data))}</y>')
 | 
			
		||||
        if isinstance(data, dict) and ('code' in data):
 | 
			
		||||
            if data['code'] != 0:
 | 
			
		||||
                raise ActionFailed(**data)
 | 
			
		||||
        return data
 | 
			
		||||
 | 
			
		||||
    @catch_network_error
 | 
			
		||||
    async def post(self,
 | 
			
		||||
                   path: str,
 | 
			
		||||
                   *,
 | 
			
		||||
                   params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
 | 
			
		||||
                   params: Optional[Dict[str, Any]] = None) -> Any:
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
 | 
			
		||||
@@ -75,13 +56,13 @@ class SessionManager:
 | 
			
		||||
            timeout=3,
 | 
			
		||||
        )
 | 
			
		||||
        response.raise_for_status()
 | 
			
		||||
        return self._raise_code(response.json())
 | 
			
		||||
        return response.json()
 | 
			
		||||
 | 
			
		||||
    @catch_network_error
 | 
			
		||||
    async def request(self,
 | 
			
		||||
                      path: str,
 | 
			
		||||
                      *,
 | 
			
		||||
                      params: Optional[Dict[str,
 | 
			
		||||
                                            Any]] = None) -> Dict[str, Any]:
 | 
			
		||||
                      params: Optional[Dict[str, Any]] = None) -> Any:
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
 | 
			
		||||
@@ -91,10 +72,6 @@ class SessionManager:
 | 
			
		||||
 | 
			
		||||
          * ``path: str``: 对应API路径
 | 
			
		||||
          * ``params: Optional[Dict[str, Any]]``: 请求参数 (无需sessionKey)
 | 
			
		||||
 | 
			
		||||
        :返回:
 | 
			
		||||
 | 
			
		||||
          - ``Dict[str, Any]``: API 返回值
 | 
			
		||||
        """
 | 
			
		||||
        response = await self.client.get(
 | 
			
		||||
            path,
 | 
			
		||||
@@ -105,25 +82,34 @@ class SessionManager:
 | 
			
		||||
            timeout=3,
 | 
			
		||||
        )
 | 
			
		||||
        response.raise_for_status()
 | 
			
		||||
        return self._raise_code(response.json())
 | 
			
		||||
        return response.json()
 | 
			
		||||
 | 
			
		||||
    async def upload(self, path: str, *, type: str,
 | 
			
		||||
                     file: Tuple[str, BytesIO]) -> Dict[str, Any]:
 | 
			
		||||
    @catch_network_error
 | 
			
		||||
    async def upload(self, path: str, *, params: Dict[str, Any]) -> Any:
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
 | 
			
		||||
        file_type, file_io = file
 | 
			
		||||
        response = await self.client.post(path,
 | 
			
		||||
                                          data={
 | 
			
		||||
                                              'sessionKey': self.session_key,
 | 
			
		||||
                                              'type': type
 | 
			
		||||
                                          },
 | 
			
		||||
                                          files={file_type: file_io},
 | 
			
		||||
                                          timeout=6)
 | 
			
		||||
            以表单(``multipart/form-data``)形式主动提交API请求
 | 
			
		||||
 | 
			
		||||
        :参数:
 | 
			
		||||
 | 
			
		||||
            * ``path: str``: 对应API路径
 | 
			
		||||
            * ``params: Dict[str, Any]``: 请求参数 (无需sessionKey)
 | 
			
		||||
        """
 | 
			
		||||
        files = {k: v for k, v in params.items() if isinstance(v, BytesIO)}
 | 
			
		||||
        form = {k: v for k, v in params.items() if k not in files}
 | 
			
		||||
        response = await self.client.post(
 | 
			
		||||
            path,
 | 
			
		||||
            data=form,
 | 
			
		||||
            files=files,
 | 
			
		||||
            timeout=6,
 | 
			
		||||
        )
 | 
			
		||||
        response.raise_for_status()
 | 
			
		||||
        return self._raise_code(response.json())
 | 
			
		||||
        return response.json()
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    async def new(cls, self_id: int, *, host: IPv4Address, port: int,
 | 
			
		||||
                  auth_key: str):
 | 
			
		||||
                  auth_key: str) -> "SessionManager":
 | 
			
		||||
        session = cls.get(self_id)
 | 
			
		||||
        if session is not None:
 | 
			
		||||
            return session
 | 
			
		||||
@@ -145,7 +131,9 @@ class SessionManager:
 | 
			
		||||
        return cls(session_key, client)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get(cls, self_id: int, check_expire: bool = True):
 | 
			
		||||
    def get(cls,
 | 
			
		||||
            self_id: int,
 | 
			
		||||
            check_expire: bool = True) -> Optional["SessionManager"]:
 | 
			
		||||
        if self_id not in cls.sessions:
 | 
			
		||||
            return None
 | 
			
		||||
        key, time, client = cls.sessions[self_id]
 | 
			
		||||
@@ -157,6 +145,13 @@ class SessionManager:
 | 
			
		||||
class MiraiBot(BaseBot):
 | 
			
		||||
    """
 | 
			
		||||
    mirai-api-http 协议 Bot 适配。
 | 
			
		||||
 | 
			
		||||
    \:\:\: warning
 | 
			
		||||
    API中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则取代Mirai原有的驼峰命名
 | 
			
		||||
 | 
			
		||||
    部分字段可能与文档在符号上不一致
 | 
			
		||||
    \:\:\:
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
@@ -207,9 +202,9 @@ class MiraiBot(BaseBot):
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
    def register(cls, driver: "Driver", config: "Config"):
 | 
			
		||||
        cls.mirai_config = MiraiConfig(**config.dict())
 | 
			
		||||
        assert cls.mirai_config.auth_key is not None
 | 
			
		||||
        assert cls.mirai_config.host is not None
 | 
			
		||||
        assert cls.mirai_config.port is not None
 | 
			
		||||
        if (cls.mirai_config.auth_key and cls.mirai_config.host and
 | 
			
		||||
                cls.mirai_config.port) is None:
 | 
			
		||||
            raise ApiNotAvailable('mirai')
 | 
			
		||||
        super().register(driver, config)
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
@@ -222,7 +217,12 @@ class MiraiBot(BaseBot):
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
    async def call_api(self, api: str, **data) -> NoReturn:
 | 
			
		||||
        """由于Mirai的HTTP API特殊性, 该API暂时无法实现"""
 | 
			
		||||
        """
 | 
			
		||||
        由于Mirai的HTTP API特殊性, 该API暂时无法实现
 | 
			
		||||
        \:\:\: tip
 | 
			
		||||
        你可以使用 ``MiraiBot.api`` 中提供的调用方法来代替
 | 
			
		||||
        \:\:\:
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
@@ -231,6 +231,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    @overrides(BaseBot)
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def send(self,
 | 
			
		||||
                   event: Event,
 | 
			
		||||
                   message: Union[MessageChain, MessageSegment, str],
 | 
			
		||||
@@ -245,10 +246,6 @@ class MiraiBot(BaseBot):
 | 
			
		||||
          * ``event: Event``: Event对象
 | 
			
		||||
          * ``message: Union[MessageChain, MessageSegment, str]``: 要发送的消息
 | 
			
		||||
          * ``at_sender: bool``: 是否 @ 事件主题
 | 
			
		||||
 | 
			
		||||
        :返回:
 | 
			
		||||
 | 
			
		||||
          - ``Any``: API 调用返回数据
 | 
			
		||||
        """
 | 
			
		||||
        if isinstance(message, MessageSegment):
 | 
			
		||||
            message = MessageChain(message)
 | 
			
		||||
@@ -269,6 +266,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(f'Unsupported event type {event!r}.')
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def send_friend_message(self, target: int,
 | 
			
		||||
                                  message_chain: MessageChain):
 | 
			
		||||
        """
 | 
			
		||||
@@ -280,10 +278,6 @@ class MiraiBot(BaseBot):
 | 
			
		||||
 | 
			
		||||
          * ``target: int``: 发送消息目标好友的 QQ 号
 | 
			
		||||
          * ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
 | 
			
		||||
 | 
			
		||||
        :返回:
 | 
			
		||||
 | 
			
		||||
          - ``Any``: API 调用返回数据
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.post('sendFriendMessage',
 | 
			
		||||
                                   params={
 | 
			
		||||
@@ -291,6 +285,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
                                       'messageChain': message_chain.export()
 | 
			
		||||
                                   })
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def send_temp_message(self, qq: int, group: int,
 | 
			
		||||
                                message_chain: MessageChain):
 | 
			
		||||
        """
 | 
			
		||||
@@ -303,10 +298,6 @@ class MiraiBot(BaseBot):
 | 
			
		||||
          * ``qq: int``: 临时会话对象 QQ 号
 | 
			
		||||
          * ``group: int``: 临时会话群号
 | 
			
		||||
          * ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
 | 
			
		||||
 | 
			
		||||
        :返回:
 | 
			
		||||
 | 
			
		||||
          - ``Any``: API 调用返回数据
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.post('sendTempMessage',
 | 
			
		||||
                                   params={
 | 
			
		||||
@@ -315,6 +306,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
                                       'messageChain': message_chain.export()
 | 
			
		||||
                                   })
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def send_group_message(self,
 | 
			
		||||
                                 group: int,
 | 
			
		||||
                                 message_chain: MessageChain,
 | 
			
		||||
@@ -329,10 +321,6 @@ class MiraiBot(BaseBot):
 | 
			
		||||
          * ``group: int``: 发送消息目标群的群号
 | 
			
		||||
          * ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
 | 
			
		||||
          * ``quote: Optional[int]``: 引用一条消息的 message_id 进行回复
 | 
			
		||||
 | 
			
		||||
        :返回:
 | 
			
		||||
 | 
			
		||||
          - ``Any``: API 调用返回数据
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.post('sendGroupMessage',
 | 
			
		||||
                                   params={
 | 
			
		||||
@@ -341,6 +329,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
                                       'quote': quote
 | 
			
		||||
                                   })
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def recall(self, target: int):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -350,13 +339,10 @@ class MiraiBot(BaseBot):
 | 
			
		||||
        :参数:
 | 
			
		||||
 | 
			
		||||
          * ``target: int``: 需要撤回的消息的message_id
 | 
			
		||||
 | 
			
		||||
        :返回:
 | 
			
		||||
 | 
			
		||||
          - ``Any``: API 调用返回数据
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.post('recall', params={'target': target})
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def send_image_message(self, target: int, qq: int, group: int,
 | 
			
		||||
                                 urls: List[str]) -> List[str]:
 | 
			
		||||
        """
 | 
			
		||||
@@ -384,8 +370,9 @@ class MiraiBot(BaseBot):
 | 
			
		||||
                                       'qq': qq,
 | 
			
		||||
                                       'group': group,
 | 
			
		||||
                                       'urls': urls
 | 
			
		||||
                                   })  # type: ignore
 | 
			
		||||
                                   })
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def upload_image(self, type: str, img: BytesIO):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -396,15 +383,14 @@ class MiraiBot(BaseBot):
 | 
			
		||||
 | 
			
		||||
          * ``type: str``: "friend" 或 "group" 或 "temp"
 | 
			
		||||
          * ``img: BytesIO``: 图片的BytesIO对象
 | 
			
		||||
 | 
			
		||||
        :返回:
 | 
			
		||||
 | 
			
		||||
          - ``Any``: API 调用返回数据
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.upload('uploadImage',
 | 
			
		||||
                                     type=type,
 | 
			
		||||
                                     file=('img', img))
 | 
			
		||||
                                     params={
 | 
			
		||||
                                         'type': type,
 | 
			
		||||
                                         'img': img
 | 
			
		||||
                                     })
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def upload_voice(self, type: str, voice: BytesIO):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -415,15 +401,14 @@ class MiraiBot(BaseBot):
 | 
			
		||||
 | 
			
		||||
          * ``type: str``: 当前仅支持 "group"
 | 
			
		||||
          * ``voice: BytesIO``: 语音的BytesIO对象
 | 
			
		||||
 | 
			
		||||
        :返回:
 | 
			
		||||
 | 
			
		||||
          - ``Any``: API 调用返回数据
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.upload('uploadVoice',
 | 
			
		||||
                                     type=type,
 | 
			
		||||
                                     file=('voice', voice))
 | 
			
		||||
                                     params={
 | 
			
		||||
                                         'type': type,
 | 
			
		||||
                                         'voice': voice
 | 
			
		||||
                                     })
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def fetch_message(self, count: int = 10):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -437,6 +422,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.request('fetchMessage', params={'count': count})
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def fetch_latest_message(self, count: int = 10):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -451,6 +437,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
        return await self.api.request('fetchLatestMessage',
 | 
			
		||||
                                      params={'count': count})
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def peek_message(self, count: int = 10):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -464,6 +451,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.request('peekMessage', params={'count': count})
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def peek_latest_message(self, count: int = 10):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -478,6 +466,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
        return await self.api.request('peekLatestMessage',
 | 
			
		||||
                                      params={'count': count})
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def messsage_from_id(self, id: int):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -491,6 +480,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.request('messageFromId', params={'id': id})
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def count_message(self):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -499,6 +489,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.request('countMessage')
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def friend_list(self) -> List[Dict[str, Any]]:
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -509,8 +500,9 @@ class MiraiBot(BaseBot):
 | 
			
		||||
 | 
			
		||||
          - ``List[Dict[str, Any]]``: 返回的好友列表数据
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.request('friendList')  # type: ignore
 | 
			
		||||
        return await self.api.request('friendList')
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def group_list(self) -> List[Dict[str, Any]]:
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -521,8 +513,9 @@ class MiraiBot(BaseBot):
 | 
			
		||||
 | 
			
		||||
          - ``List[Dict[str, Any]]``: 返回的群列表数据
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.request('groupList')  # type: ignore
 | 
			
		||||
        return await self.api.request('groupList')
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def member_list(self, target: int) -> List[Dict[str, Any]]:
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -537,9 +530,9 @@ class MiraiBot(BaseBot):
 | 
			
		||||
 | 
			
		||||
          - ``List[Dict[str, Any]]``: 返回的群成员列表数据
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.request('memberList',
 | 
			
		||||
                                      params={'target': target})  # type: ignore
 | 
			
		||||
        return await self.api.request('memberList', params={'target': target})
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def mute(self, target: int, member_id: int, time: int):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -559,6 +552,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
                                       'time': time
 | 
			
		||||
                                   })
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def unmute(self, target: int, member_id: int):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -576,6 +570,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
                                       'memberId': member_id
 | 
			
		||||
                                   })
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def kick(self, target: int, member_id: int, msg: str):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -595,6 +590,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
                                       'msg': msg
 | 
			
		||||
                                   })
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def quit(self, target: int):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -607,6 +603,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.post('quit', params={'target': target})
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def mute_all(self, target: int):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -619,6 +616,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.post('muteAll', params={'target': target})
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def unmute_all(self, target: int):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -631,6 +629,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.post('unmuteAll', params={'target': target})
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def group_config(self, target: int):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -656,6 +655,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
        """
 | 
			
		||||
        return await self.api.request('groupConfig', params={'target': target})
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def modify_group_config(self, target: int, config: Dict[str, Any]):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -673,6 +673,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
                                       'config': config
 | 
			
		||||
                                   })
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def member_info(self, target: int, member_id: int):
 | 
			
		||||
        """
 | 
			
		||||
        :说明:
 | 
			
		||||
@@ -699,6 +700,7 @@ class MiraiBot(BaseBot):
 | 
			
		||||
                                          'memberId': member_id
 | 
			
		||||
                                      })
 | 
			
		||||
 | 
			
		||||
    @argument_validation
 | 
			
		||||
    async def modify_member_info(self, target: int, member_id: int,
 | 
			
		||||
                                 info: Dict[str, Any]):
 | 
			
		||||
        """
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										89
									
								
								nonebot/adapters/mirai/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								nonebot/adapters/mirai/utils.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,89 @@
 | 
			
		||||
from functools import wraps
 | 
			
		||||
from typing import Callable, Coroutine, TypeVar
 | 
			
		||||
 | 
			
		||||
import httpx
 | 
			
		||||
from pydantic import ValidationError, validate_arguments, Extra
 | 
			
		||||
 | 
			
		||||
import nonebot.exception as exception
 | 
			
		||||
from nonebot.log import logger
 | 
			
		||||
from nonebot.utils import escape_tag
 | 
			
		||||
 | 
			
		||||
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
 | 
			
		||||
_AnyCallable = TypeVar("_AnyCallable", bound=Callable)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ActionFailed(exception.ActionFailed):
 | 
			
		||||
    """
 | 
			
		||||
    :说明:
 | 
			
		||||
 | 
			
		||||
      API 请求成功返回数据,但 API 操作失败。
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, **kwargs):
 | 
			
		||||
        super().__init__('mirai')
 | 
			
		||||
        self.data = kwargs.copy()
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return self.__class__.__name__ + '(%s)' % ', '.join(
 | 
			
		||||
            map(lambda m: '%s=%r' % m, self.data.items()))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InvalidArgument(exception.AdapterException):
 | 
			
		||||
    """
 | 
			
		||||
    :说明:
 | 
			
		||||
    
 | 
			
		||||
      调用API的参数出错
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, **kwargs):
 | 
			
		||||
        super().__init__('mirai')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
 | 
			
		||||
    """
 | 
			
		||||
    :说明:
 | 
			
		||||
 | 
			
		||||
      捕捉函数抛出的httpx网络异常并释放``NetworkError``异常
 | 
			
		||||
      处理返回数据, 在code不为0时释放``ActionFailed``异常
 | 
			
		||||
 | 
			
		||||
    \:\:\: warning
 | 
			
		||||
    此装饰器只支持使用了httpx的异步函数
 | 
			
		||||
    \:\:\:
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @wraps(function)
 | 
			
		||||
    async def wrapper(*args, **kwargs):
 | 
			
		||||
        try:
 | 
			
		||||
            data = await function(*args, **kwargs)
 | 
			
		||||
        except httpx.HTTPError:
 | 
			
		||||
            raise exception.NetworkError('mirai')
 | 
			
		||||
        logger.opt(colors=True).debug('<b>Mirai API returned data:</b> '
 | 
			
		||||
                                      f'<y>{escape_tag(str(data))}</y>')
 | 
			
		||||
        if isinstance(data, dict):
 | 
			
		||||
            if data.get('code', 0) != 0:
 | 
			
		||||
                raise ActionFailed(**data)
 | 
			
		||||
        return data
 | 
			
		||||
 | 
			
		||||
    return wrapper  # type: ignore
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def argument_validation(function: _AnyCallable) -> _AnyCallable:
 | 
			
		||||
    """
 | 
			
		||||
    :说明:
 | 
			
		||||
 | 
			
		||||
      通过函数签名中的类型注解来对传入参数进行运行时校验
 | 
			
		||||
      会在参数出错时释放``InvalidArgument``异常
 | 
			
		||||
    """
 | 
			
		||||
    function = validate_arguments(config={
 | 
			
		||||
        'arbitrary_types_allowed': True,
 | 
			
		||||
        'extra': Extra.forbid
 | 
			
		||||
    })(function)
 | 
			
		||||
 | 
			
		||||
    @wraps(function)
 | 
			
		||||
    def wrapper(*args, **kwargs):
 | 
			
		||||
        try:
 | 
			
		||||
            return function(*args, **kwargs)
 | 
			
		||||
        except ValidationError:
 | 
			
		||||
            raise InvalidArgument
 | 
			
		||||
 | 
			
		||||
    return wrapper  # type: ignore
 | 
			
		||||
		Reference in New Issue
	
	Block a user