diff --git a/nonebot/adapters/mirai/__init__.py b/nonebot/adapters/mirai/__init__.py
index b209657e..75afaff8 100644
--- a/nonebot/adapters/mirai/__init__.py
+++ b/nonebot/adapters/mirai/__init__.py
@@ -4,8 +4,18 @@ Mirai-API-HTTP 协议适配
协议详情请看: `mirai-api-http 文档`_
-.. mirai-api-http 文档:
+\:\:\: tip
+该Adapter目前仍然处在早期实验性阶段, 并未经过充分测试
+
+如果你在使用过程中遇到了任何问题, 请前往 `Issue页面`_ 为我们提供反馈
+\:\:\:
+
+.. _mirai-api-http 文档:
https://github.com/project-mirai/mirai-api-http/tree/master/docs
+
+.. _Issue页面
+ https://github.com/nonebot/nonebot2/issues
+
"""
from .bot import MiraiBot
diff --git a/nonebot/adapters/mirai/bot.py b/nonebot/adapters/mirai/bot.py
index 74c4f602..ed0b9ae1 100644
--- a/nonebot/adapters/mirai/bot.py
+++ b/nonebot/adapters/mirai/bot.py
@@ -1,34 +1,23 @@
from datetime import datetime, timedelta
+from functools import wraps
from io import BytesIO
from ipaddress import IPv4Address
-from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
+from typing import (Any, Dict, List, NoReturn, Optional, Tuple, Union)
import httpx
from nonebot.adapters import Bot as BaseBot
from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
-from nonebot.exception import ActionFailed as BaseActionFailed
-from nonebot.exception import RequestDenied
+from nonebot.exception import ApiNotAvailable, RequestDenied
from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.typing import overrides
-from nonebot.utils import escape_tag
from .config import Config as MiraiConfig
from .event import Event, FriendMessage, GroupMessage, TempMessage
from .message import MessageChain, MessageSegment
-
-
-class ActionFailed(BaseActionFailed):
-
- def __init__(self, **kwargs):
- super().__init__('mirai')
- self.data = kwargs.copy()
-
- def __repr__(self):
- return self.__class__.__name__ + '(%s)' % ', '.join(
- map(lambda m: '%s=%r' % m, self.data.items()))
+from .utils import catch_network_error, argument_validation
class SessionManager:
@@ -39,19 +28,11 @@ class SessionManager:
def __init__(self, session_key: str, client: httpx.AsyncClient):
self.session_key, self.client = session_key, client
- @staticmethod
- def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
- logger.opt(colors=True).debug(
- f'Mirai API returned data: {escape_tag(str(data))}')
- if isinstance(data, dict) and ('code' in data):
- if data['code'] != 0:
- raise ActionFailed(**data)
- return data
-
+ @catch_network_error
async def post(self,
path: str,
*,
- params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+ params: Optional[Dict[str, Any]] = None) -> Any:
"""
:说明:
@@ -75,13 +56,13 @@ class SessionManager:
timeout=3,
)
response.raise_for_status()
- return self._raise_code(response.json())
+ return response.json()
+ @catch_network_error
async def request(self,
path: str,
*,
- params: Optional[Dict[str,
- Any]] = None) -> Dict[str, Any]:
+ params: Optional[Dict[str, Any]] = None) -> Any:
"""
:说明:
@@ -91,10 +72,6 @@ class SessionManager:
* ``path: str``: 对应API路径
* ``params: Optional[Dict[str, Any]]``: 请求参数 (无需sessionKey)
-
- :返回:
-
- - ``Dict[str, Any]``: API 返回值
"""
response = await self.client.get(
path,
@@ -105,25 +82,34 @@ class SessionManager:
timeout=3,
)
response.raise_for_status()
- return self._raise_code(response.json())
+ return response.json()
- async def upload(self, path: str, *, type: str,
- file: Tuple[str, BytesIO]) -> Dict[str, Any]:
+ @catch_network_error
+ async def upload(self, path: str, *, params: Dict[str, Any]) -> Any:
+ """
+ :说明:
- file_type, file_io = file
- response = await self.client.post(path,
- data={
- 'sessionKey': self.session_key,
- 'type': type
- },
- files={file_type: file_io},
- timeout=6)
+ 以表单(``multipart/form-data``)形式主动提交API请求
+
+ :参数:
+
+ * ``path: str``: 对应API路径
+ * ``params: Dict[str, Any]``: 请求参数 (无需sessionKey)
+ """
+ files = {k: v for k, v in params.items() if isinstance(v, BytesIO)}
+ form = {k: v for k, v in params.items() if k not in files}
+ response = await self.client.post(
+ path,
+ data=form,
+ files=files,
+ timeout=6,
+ )
response.raise_for_status()
- return self._raise_code(response.json())
+ return response.json()
@classmethod
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
- auth_key: str):
+ auth_key: str) -> "SessionManager":
session = cls.get(self_id)
if session is not None:
return session
@@ -145,7 +131,9 @@ class SessionManager:
return cls(session_key, client)
@classmethod
- def get(cls, self_id: int, check_expire: bool = True):
+ def get(cls,
+ self_id: int,
+ check_expire: bool = True) -> Optional["SessionManager"]:
if self_id not in cls.sessions:
return None
key, time, client = cls.sessions[self_id]
@@ -157,6 +145,13 @@ class SessionManager:
class MiraiBot(BaseBot):
"""
mirai-api-http 协议 Bot 适配。
+
+ \:\:\: warning
+ API中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则取代Mirai原有的驼峰命名
+
+ 部分字段可能与文档在符号上不一致
+ \:\:\:
+
"""
@overrides(BaseBot)
@@ -207,9 +202,9 @@ class MiraiBot(BaseBot):
@overrides(BaseBot)
def register(cls, driver: "Driver", config: "Config"):
cls.mirai_config = MiraiConfig(**config.dict())
- assert cls.mirai_config.auth_key is not None
- assert cls.mirai_config.host is not None
- assert cls.mirai_config.port is not None
+ if (cls.mirai_config.auth_key and cls.mirai_config.host and
+ cls.mirai_config.port) is None:
+ raise ApiNotAvailable('mirai')
super().register(driver, config)
@overrides(BaseBot)
@@ -222,7 +217,12 @@ class MiraiBot(BaseBot):
@overrides(BaseBot)
async def call_api(self, api: str, **data) -> NoReturn:
- """由于Mirai的HTTP API特殊性, 该API暂时无法实现"""
+ """
+ 由于Mirai的HTTP API特殊性, 该API暂时无法实现
+ \:\:\: tip
+ 你可以使用 ``MiraiBot.api`` 中提供的调用方法来代替
+ \:\:\:
+ """
raise NotImplementedError
@overrides(BaseBot)
@@ -231,6 +231,7 @@ class MiraiBot(BaseBot):
raise NotImplementedError
@overrides(BaseBot)
+ @argument_validation
async def send(self,
event: Event,
message: Union[MessageChain, MessageSegment, str],
@@ -245,10 +246,6 @@ class MiraiBot(BaseBot):
* ``event: Event``: Event对象
* ``message: Union[MessageChain, MessageSegment, str]``: 要发送的消息
* ``at_sender: bool``: 是否 @ 事件主题
-
- :返回:
-
- - ``Any``: API 调用返回数据
"""
if isinstance(message, MessageSegment):
message = MessageChain(message)
@@ -269,6 +266,7 @@ class MiraiBot(BaseBot):
else:
raise ValueError(f'Unsupported event type {event!r}.')
+ @argument_validation
async def send_friend_message(self, target: int,
message_chain: MessageChain):
"""
@@ -280,10 +278,6 @@ class MiraiBot(BaseBot):
* ``target: int``: 发送消息目标好友的 QQ 号
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
-
- :返回:
-
- - ``Any``: API 调用返回数据
"""
return await self.api.post('sendFriendMessage',
params={
@@ -291,6 +285,7 @@ class MiraiBot(BaseBot):
'messageChain': message_chain.export()
})
+ @argument_validation
async def send_temp_message(self, qq: int, group: int,
message_chain: MessageChain):
"""
@@ -303,10 +298,6 @@ class MiraiBot(BaseBot):
* ``qq: int``: 临时会话对象 QQ 号
* ``group: int``: 临时会话群号
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
-
- :返回:
-
- - ``Any``: API 调用返回数据
"""
return await self.api.post('sendTempMessage',
params={
@@ -315,6 +306,7 @@ class MiraiBot(BaseBot):
'messageChain': message_chain.export()
})
+ @argument_validation
async def send_group_message(self,
group: int,
message_chain: MessageChain,
@@ -329,10 +321,6 @@ class MiraiBot(BaseBot):
* ``group: int``: 发送消息目标群的群号
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
* ``quote: Optional[int]``: 引用一条消息的 message_id 进行回复
-
- :返回:
-
- - ``Any``: API 调用返回数据
"""
return await self.api.post('sendGroupMessage',
params={
@@ -341,6 +329,7 @@ class MiraiBot(BaseBot):
'quote': quote
})
+ @argument_validation
async def recall(self, target: int):
"""
:说明:
@@ -350,13 +339,10 @@ class MiraiBot(BaseBot):
:参数:
* ``target: int``: 需要撤回的消息的message_id
-
- :返回:
-
- - ``Any``: API 调用返回数据
"""
return await self.api.post('recall', params={'target': target})
+ @argument_validation
async def send_image_message(self, target: int, qq: int, group: int,
urls: List[str]) -> List[str]:
"""
@@ -384,8 +370,9 @@ class MiraiBot(BaseBot):
'qq': qq,
'group': group,
'urls': urls
- }) # type: ignore
+ })
+ @argument_validation
async def upload_image(self, type: str, img: BytesIO):
"""
:说明:
@@ -396,15 +383,14 @@ class MiraiBot(BaseBot):
* ``type: str``: "friend" 或 "group" 或 "temp"
* ``img: BytesIO``: 图片的BytesIO对象
-
- :返回:
-
- - ``Any``: API 调用返回数据
"""
return await self.api.upload('uploadImage',
- type=type,
- file=('img', img))
+ params={
+ 'type': type,
+ 'img': img
+ })
+ @argument_validation
async def upload_voice(self, type: str, voice: BytesIO):
"""
:说明:
@@ -415,15 +401,14 @@ class MiraiBot(BaseBot):
* ``type: str``: 当前仅支持 "group"
* ``voice: BytesIO``: 语音的BytesIO对象
-
- :返回:
-
- - ``Any``: API 调用返回数据
"""
return await self.api.upload('uploadVoice',
- type=type,
- file=('voice', voice))
+ params={
+ 'type': type,
+ 'voice': voice
+ })
+ @argument_validation
async def fetch_message(self, count: int = 10):
"""
:说明:
@@ -437,6 +422,7 @@ class MiraiBot(BaseBot):
"""
return await self.api.request('fetchMessage', params={'count': count})
+ @argument_validation
async def fetch_latest_message(self, count: int = 10):
"""
:说明:
@@ -451,6 +437,7 @@ class MiraiBot(BaseBot):
return await self.api.request('fetchLatestMessage',
params={'count': count})
+ @argument_validation
async def peek_message(self, count: int = 10):
"""
:说明:
@@ -464,6 +451,7 @@ class MiraiBot(BaseBot):
"""
return await self.api.request('peekMessage', params={'count': count})
+ @argument_validation
async def peek_latest_message(self, count: int = 10):
"""
:说明:
@@ -478,6 +466,7 @@ class MiraiBot(BaseBot):
return await self.api.request('peekLatestMessage',
params={'count': count})
+ @argument_validation
async def messsage_from_id(self, id: int):
"""
:说明:
@@ -491,6 +480,7 @@ class MiraiBot(BaseBot):
"""
return await self.api.request('messageFromId', params={'id': id})
+ @argument_validation
async def count_message(self):
"""
:说明:
@@ -499,6 +489,7 @@ class MiraiBot(BaseBot):
"""
return await self.api.request('countMessage')
+ @argument_validation
async def friend_list(self) -> List[Dict[str, Any]]:
"""
:说明:
@@ -509,8 +500,9 @@ class MiraiBot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的好友列表数据
"""
- return await self.api.request('friendList') # type: ignore
+ return await self.api.request('friendList')
+ @argument_validation
async def group_list(self) -> List[Dict[str, Any]]:
"""
:说明:
@@ -521,8 +513,9 @@ class MiraiBot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的群列表数据
"""
- return await self.api.request('groupList') # type: ignore
+ return await self.api.request('groupList')
+ @argument_validation
async def member_list(self, target: int) -> List[Dict[str, Any]]:
"""
:说明:
@@ -537,9 +530,9 @@ class MiraiBot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的群成员列表数据
"""
- return await self.api.request('memberList',
- params={'target': target}) # type: ignore
+ return await self.api.request('memberList', params={'target': target})
+ @argument_validation
async def mute(self, target: int, member_id: int, time: int):
"""
:说明:
@@ -559,6 +552,7 @@ class MiraiBot(BaseBot):
'time': time
})
+ @argument_validation
async def unmute(self, target: int, member_id: int):
"""
:说明:
@@ -576,6 +570,7 @@ class MiraiBot(BaseBot):
'memberId': member_id
})
+ @argument_validation
async def kick(self, target: int, member_id: int, msg: str):
"""
:说明:
@@ -595,6 +590,7 @@ class MiraiBot(BaseBot):
'msg': msg
})
+ @argument_validation
async def quit(self, target: int):
"""
:说明:
@@ -607,6 +603,7 @@ class MiraiBot(BaseBot):
"""
return await self.api.post('quit', params={'target': target})
+ @argument_validation
async def mute_all(self, target: int):
"""
:说明:
@@ -619,6 +616,7 @@ class MiraiBot(BaseBot):
"""
return await self.api.post('muteAll', params={'target': target})
+ @argument_validation
async def unmute_all(self, target: int):
"""
:说明:
@@ -631,6 +629,7 @@ class MiraiBot(BaseBot):
"""
return await self.api.post('unmuteAll', params={'target': target})
+ @argument_validation
async def group_config(self, target: int):
"""
:说明:
@@ -656,6 +655,7 @@ class MiraiBot(BaseBot):
"""
return await self.api.request('groupConfig', params={'target': target})
+ @argument_validation
async def modify_group_config(self, target: int, config: Dict[str, Any]):
"""
:说明:
@@ -673,6 +673,7 @@ class MiraiBot(BaseBot):
'config': config
})
+ @argument_validation
async def member_info(self, target: int, member_id: int):
"""
:说明:
@@ -699,6 +700,7 @@ class MiraiBot(BaseBot):
'memberId': member_id
})
+ @argument_validation
async def modify_member_info(self, target: int, member_id: int,
info: Dict[str, Any]):
"""
diff --git a/nonebot/adapters/mirai/utils.py b/nonebot/adapters/mirai/utils.py
new file mode 100644
index 00000000..0a4b4a1b
--- /dev/null
+++ b/nonebot/adapters/mirai/utils.py
@@ -0,0 +1,89 @@
+from functools import wraps
+from typing import Callable, Coroutine, TypeVar
+
+import httpx
+from pydantic import ValidationError, validate_arguments, Extra
+
+import nonebot.exception as exception
+from nonebot.log import logger
+from nonebot.utils import escape_tag
+
+_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
+_AnyCallable = TypeVar("_AnyCallable", bound=Callable)
+
+
+class ActionFailed(exception.ActionFailed):
+ """
+ :说明:
+
+ API 请求成功返回数据,但 API 操作失败。
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__('mirai')
+ self.data = kwargs.copy()
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(%s)' % ', '.join(
+ map(lambda m: '%s=%r' % m, self.data.items()))
+
+
+class InvalidArgument(exception.AdapterException):
+ """
+ :说明:
+
+ 调用API的参数出错
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__('mirai')
+
+
+def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
+ """
+ :说明:
+
+ 捕捉函数抛出的httpx网络异常并释放``NetworkError``异常
+ 处理返回数据, 在code不为0时释放``ActionFailed``异常
+
+ \:\:\: warning
+ 此装饰器只支持使用了httpx的异步函数
+ \:\:\:
+ """
+
+ @wraps(function)
+ async def wrapper(*args, **kwargs):
+ try:
+ data = await function(*args, **kwargs)
+ except httpx.HTTPError:
+ raise exception.NetworkError('mirai')
+ logger.opt(colors=True).debug('Mirai API returned data: '
+ f'{escape_tag(str(data))}')
+ if isinstance(data, dict):
+ if data.get('code', 0) != 0:
+ raise ActionFailed(**data)
+ return data
+
+ return wrapper # type: ignore
+
+
+def argument_validation(function: _AnyCallable) -> _AnyCallable:
+ """
+ :说明:
+
+ 通过函数签名中的类型注解来对传入参数进行运行时校验
+ 会在参数出错时释放``InvalidArgument``异常
+ """
+ function = validate_arguments(config={
+ 'arbitrary_types_allowed': True,
+ 'extra': Extra.forbid
+ })(function)
+
+ @wraps(function)
+ def wrapper(*args, **kwargs):
+ try:
+ return function(*args, **kwargs)
+ except ValidationError:
+ raise InvalidArgument
+
+ return wrapper # type: ignore