mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-28 08:41:29 +00:00
💡 🚸 complete comments and optimize usage in mirai adapter
This commit is contained in:
@ -1,34 +1,23 @@
|
||||
from datetime import datetime, timedelta
|
||||
from functools import wraps
|
||||
from io import BytesIO
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
|
||||
from typing import (Any, Dict, List, NoReturn, Optional, Tuple, Union)
|
||||
|
||||
import httpx
|
||||
|
||||
from nonebot.adapters import Bot as BaseBot
|
||||
from nonebot.config import Config
|
||||
from nonebot.drivers import Driver, WebSocket
|
||||
from nonebot.exception import ActionFailed as BaseActionFailed
|
||||
from nonebot.exception import RequestDenied
|
||||
from nonebot.exception import ApiNotAvailable, RequestDenied
|
||||
from nonebot.log import logger
|
||||
from nonebot.message import handle_event
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.utils import escape_tag
|
||||
|
||||
from .config import Config as MiraiConfig
|
||||
from .event import Event, FriendMessage, GroupMessage, TempMessage
|
||||
from .message import MessageChain, MessageSegment
|
||||
|
||||
|
||||
class ActionFailed(BaseActionFailed):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__('mirai')
|
||||
self.data = kwargs.copy()
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(%s)' % ', '.join(
|
||||
map(lambda m: '%s=%r' % m, self.data.items()))
|
||||
from .utils import catch_network_error, argument_validation
|
||||
|
||||
|
||||
class SessionManager:
|
||||
@ -39,19 +28,11 @@ class SessionManager:
|
||||
def __init__(self, session_key: str, client: httpx.AsyncClient):
|
||||
self.session_key, self.client = session_key, client
|
||||
|
||||
@staticmethod
|
||||
def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
logger.opt(colors=True).debug(
|
||||
f'Mirai API returned data: <y>{escape_tag(str(data))}</y>')
|
||||
if isinstance(data, dict) and ('code' in data):
|
||||
if data['code'] != 0:
|
||||
raise ActionFailed(**data)
|
||||
return data
|
||||
|
||||
@catch_network_error
|
||||
async def post(self,
|
||||
path: str,
|
||||
*,
|
||||
params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
params: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -75,13 +56,13 @@ class SessionManager:
|
||||
timeout=3,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return self._raise_code(response.json())
|
||||
return response.json()
|
||||
|
||||
@catch_network_error
|
||||
async def request(self,
|
||||
path: str,
|
||||
*,
|
||||
params: Optional[Dict[str,
|
||||
Any]] = None) -> Dict[str, Any]:
|
||||
params: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -91,10 +72,6 @@ class SessionManager:
|
||||
|
||||
* ``path: str``: 对应API路径
|
||||
* ``params: Optional[Dict[str, Any]]``: 请求参数 (无需sessionKey)
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Dict[str, Any]``: API 返回值
|
||||
"""
|
||||
response = await self.client.get(
|
||||
path,
|
||||
@ -105,25 +82,34 @@ class SessionManager:
|
||||
timeout=3,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return self._raise_code(response.json())
|
||||
return response.json()
|
||||
|
||||
async def upload(self, path: str, *, type: str,
|
||||
file: Tuple[str, BytesIO]) -> Dict[str, Any]:
|
||||
@catch_network_error
|
||||
async def upload(self, path: str, *, params: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
file_type, file_io = file
|
||||
response = await self.client.post(path,
|
||||
data={
|
||||
'sessionKey': self.session_key,
|
||||
'type': type
|
||||
},
|
||||
files={file_type: file_io},
|
||||
timeout=6)
|
||||
以表单(``multipart/form-data``)形式主动提交API请求
|
||||
|
||||
:参数:
|
||||
|
||||
* ``path: str``: 对应API路径
|
||||
* ``params: Dict[str, Any]``: 请求参数 (无需sessionKey)
|
||||
"""
|
||||
files = {k: v for k, v in params.items() if isinstance(v, BytesIO)}
|
||||
form = {k: v for k, v in params.items() if k not in files}
|
||||
response = await self.client.post(
|
||||
path,
|
||||
data=form,
|
||||
files=files,
|
||||
timeout=6,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return self._raise_code(response.json())
|
||||
return response.json()
|
||||
|
||||
@classmethod
|
||||
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
|
||||
auth_key: str):
|
||||
auth_key: str) -> "SessionManager":
|
||||
session = cls.get(self_id)
|
||||
if session is not None:
|
||||
return session
|
||||
@ -145,7 +131,9 @@ class SessionManager:
|
||||
return cls(session_key, client)
|
||||
|
||||
@classmethod
|
||||
def get(cls, self_id: int, check_expire: bool = True):
|
||||
def get(cls,
|
||||
self_id: int,
|
||||
check_expire: bool = True) -> Optional["SessionManager"]:
|
||||
if self_id not in cls.sessions:
|
||||
return None
|
||||
key, time, client = cls.sessions[self_id]
|
||||
@ -157,6 +145,13 @@ class SessionManager:
|
||||
class MiraiBot(BaseBot):
|
||||
"""
|
||||
mirai-api-http 协议 Bot 适配。
|
||||
|
||||
\:\:\: warning
|
||||
API中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则取代Mirai原有的驼峰命名
|
||||
|
||||
部分字段可能与文档在符号上不一致
|
||||
\:\:\:
|
||||
|
||||
"""
|
||||
|
||||
@overrides(BaseBot)
|
||||
@ -207,9 +202,9 @@ class MiraiBot(BaseBot):
|
||||
@overrides(BaseBot)
|
||||
def register(cls, driver: "Driver", config: "Config"):
|
||||
cls.mirai_config = MiraiConfig(**config.dict())
|
||||
assert cls.mirai_config.auth_key is not None
|
||||
assert cls.mirai_config.host is not None
|
||||
assert cls.mirai_config.port is not None
|
||||
if (cls.mirai_config.auth_key and cls.mirai_config.host and
|
||||
cls.mirai_config.port) is None:
|
||||
raise ApiNotAvailable('mirai')
|
||||
super().register(driver, config)
|
||||
|
||||
@overrides(BaseBot)
|
||||
@ -222,7 +217,12 @@ class MiraiBot(BaseBot):
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def call_api(self, api: str, **data) -> NoReturn:
|
||||
"""由于Mirai的HTTP API特殊性, 该API暂时无法实现"""
|
||||
"""
|
||||
由于Mirai的HTTP API特殊性, 该API暂时无法实现
|
||||
\:\:\: tip
|
||||
你可以使用 ``MiraiBot.api`` 中提供的调用方法来代替
|
||||
\:\:\:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@overrides(BaseBot)
|
||||
@ -231,6 +231,7 @@ class MiraiBot(BaseBot):
|
||||
raise NotImplementedError
|
||||
|
||||
@overrides(BaseBot)
|
||||
@argument_validation
|
||||
async def send(self,
|
||||
event: Event,
|
||||
message: Union[MessageChain, MessageSegment, str],
|
||||
@ -245,10 +246,6 @@ class MiraiBot(BaseBot):
|
||||
* ``event: Event``: Event对象
|
||||
* ``message: Union[MessageChain, MessageSegment, str]``: 要发送的消息
|
||||
* ``at_sender: bool``: 是否 @ 事件主题
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
"""
|
||||
if isinstance(message, MessageSegment):
|
||||
message = MessageChain(message)
|
||||
@ -269,6 +266,7 @@ class MiraiBot(BaseBot):
|
||||
else:
|
||||
raise ValueError(f'Unsupported event type {event!r}.')
|
||||
|
||||
@argument_validation
|
||||
async def send_friend_message(self, target: int,
|
||||
message_chain: MessageChain):
|
||||
"""
|
||||
@ -280,10 +278,6 @@ class MiraiBot(BaseBot):
|
||||
|
||||
* ``target: int``: 发送消息目标好友的 QQ 号
|
||||
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
"""
|
||||
return await self.api.post('sendFriendMessage',
|
||||
params={
|
||||
@ -291,6 +285,7 @@ class MiraiBot(BaseBot):
|
||||
'messageChain': message_chain.export()
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def send_temp_message(self, qq: int, group: int,
|
||||
message_chain: MessageChain):
|
||||
"""
|
||||
@ -303,10 +298,6 @@ class MiraiBot(BaseBot):
|
||||
* ``qq: int``: 临时会话对象 QQ 号
|
||||
* ``group: int``: 临时会话群号
|
||||
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
"""
|
||||
return await self.api.post('sendTempMessage',
|
||||
params={
|
||||
@ -315,6 +306,7 @@ class MiraiBot(BaseBot):
|
||||
'messageChain': message_chain.export()
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def send_group_message(self,
|
||||
group: int,
|
||||
message_chain: MessageChain,
|
||||
@ -329,10 +321,6 @@ class MiraiBot(BaseBot):
|
||||
* ``group: int``: 发送消息目标群的群号
|
||||
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
|
||||
* ``quote: Optional[int]``: 引用一条消息的 message_id 进行回复
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
"""
|
||||
return await self.api.post('sendGroupMessage',
|
||||
params={
|
||||
@ -341,6 +329,7 @@ class MiraiBot(BaseBot):
|
||||
'quote': quote
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def recall(self, target: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -350,13 +339,10 @@ class MiraiBot(BaseBot):
|
||||
:参数:
|
||||
|
||||
* ``target: int``: 需要撤回的消息的message_id
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
"""
|
||||
return await self.api.post('recall', params={'target': target})
|
||||
|
||||
@argument_validation
|
||||
async def send_image_message(self, target: int, qq: int, group: int,
|
||||
urls: List[str]) -> List[str]:
|
||||
"""
|
||||
@ -384,8 +370,9 @@ class MiraiBot(BaseBot):
|
||||
'qq': qq,
|
||||
'group': group,
|
||||
'urls': urls
|
||||
}) # type: ignore
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def upload_image(self, type: str, img: BytesIO):
|
||||
"""
|
||||
:说明:
|
||||
@ -396,15 +383,14 @@ class MiraiBot(BaseBot):
|
||||
|
||||
* ``type: str``: "friend" 或 "group" 或 "temp"
|
||||
* ``img: BytesIO``: 图片的BytesIO对象
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
"""
|
||||
return await self.api.upload('uploadImage',
|
||||
type=type,
|
||||
file=('img', img))
|
||||
params={
|
||||
'type': type,
|
||||
'img': img
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def upload_voice(self, type: str, voice: BytesIO):
|
||||
"""
|
||||
:说明:
|
||||
@ -415,15 +401,14 @@ class MiraiBot(BaseBot):
|
||||
|
||||
* ``type: str``: 当前仅支持 "group"
|
||||
* ``voice: BytesIO``: 语音的BytesIO对象
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
"""
|
||||
return await self.api.upload('uploadVoice',
|
||||
type=type,
|
||||
file=('voice', voice))
|
||||
params={
|
||||
'type': type,
|
||||
'voice': voice
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def fetch_message(self, count: int = 10):
|
||||
"""
|
||||
:说明:
|
||||
@ -437,6 +422,7 @@ class MiraiBot(BaseBot):
|
||||
"""
|
||||
return await self.api.request('fetchMessage', params={'count': count})
|
||||
|
||||
@argument_validation
|
||||
async def fetch_latest_message(self, count: int = 10):
|
||||
"""
|
||||
:说明:
|
||||
@ -451,6 +437,7 @@ class MiraiBot(BaseBot):
|
||||
return await self.api.request('fetchLatestMessage',
|
||||
params={'count': count})
|
||||
|
||||
@argument_validation
|
||||
async def peek_message(self, count: int = 10):
|
||||
"""
|
||||
:说明:
|
||||
@ -464,6 +451,7 @@ class MiraiBot(BaseBot):
|
||||
"""
|
||||
return await self.api.request('peekMessage', params={'count': count})
|
||||
|
||||
@argument_validation
|
||||
async def peek_latest_message(self, count: int = 10):
|
||||
"""
|
||||
:说明:
|
||||
@ -478,6 +466,7 @@ class MiraiBot(BaseBot):
|
||||
return await self.api.request('peekLatestMessage',
|
||||
params={'count': count})
|
||||
|
||||
@argument_validation
|
||||
async def messsage_from_id(self, id: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -491,6 +480,7 @@ class MiraiBot(BaseBot):
|
||||
"""
|
||||
return await self.api.request('messageFromId', params={'id': id})
|
||||
|
||||
@argument_validation
|
||||
async def count_message(self):
|
||||
"""
|
||||
:说明:
|
||||
@ -499,6 +489,7 @@ class MiraiBot(BaseBot):
|
||||
"""
|
||||
return await self.api.request('countMessage')
|
||||
|
||||
@argument_validation
|
||||
async def friend_list(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
:说明:
|
||||
@ -509,8 +500,9 @@ class MiraiBot(BaseBot):
|
||||
|
||||
- ``List[Dict[str, Any]]``: 返回的好友列表数据
|
||||
"""
|
||||
return await self.api.request('friendList') # type: ignore
|
||||
return await self.api.request('friendList')
|
||||
|
||||
@argument_validation
|
||||
async def group_list(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
:说明:
|
||||
@ -521,8 +513,9 @@ class MiraiBot(BaseBot):
|
||||
|
||||
- ``List[Dict[str, Any]]``: 返回的群列表数据
|
||||
"""
|
||||
return await self.api.request('groupList') # type: ignore
|
||||
return await self.api.request('groupList')
|
||||
|
||||
@argument_validation
|
||||
async def member_list(self, target: int) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
:说明:
|
||||
@ -537,9 +530,9 @@ class MiraiBot(BaseBot):
|
||||
|
||||
- ``List[Dict[str, Any]]``: 返回的群成员列表数据
|
||||
"""
|
||||
return await self.api.request('memberList',
|
||||
params={'target': target}) # type: ignore
|
||||
return await self.api.request('memberList', params={'target': target})
|
||||
|
||||
@argument_validation
|
||||
async def mute(self, target: int, member_id: int, time: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -559,6 +552,7 @@ class MiraiBot(BaseBot):
|
||||
'time': time
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def unmute(self, target: int, member_id: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -576,6 +570,7 @@ class MiraiBot(BaseBot):
|
||||
'memberId': member_id
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def kick(self, target: int, member_id: int, msg: str):
|
||||
"""
|
||||
:说明:
|
||||
@ -595,6 +590,7 @@ class MiraiBot(BaseBot):
|
||||
'msg': msg
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def quit(self, target: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -607,6 +603,7 @@ class MiraiBot(BaseBot):
|
||||
"""
|
||||
return await self.api.post('quit', params={'target': target})
|
||||
|
||||
@argument_validation
|
||||
async def mute_all(self, target: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -619,6 +616,7 @@ class MiraiBot(BaseBot):
|
||||
"""
|
||||
return await self.api.post('muteAll', params={'target': target})
|
||||
|
||||
@argument_validation
|
||||
async def unmute_all(self, target: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -631,6 +629,7 @@ class MiraiBot(BaseBot):
|
||||
"""
|
||||
return await self.api.post('unmuteAll', params={'target': target})
|
||||
|
||||
@argument_validation
|
||||
async def group_config(self, target: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -656,6 +655,7 @@ class MiraiBot(BaseBot):
|
||||
"""
|
||||
return await self.api.request('groupConfig', params={'target': target})
|
||||
|
||||
@argument_validation
|
||||
async def modify_group_config(self, target: int, config: Dict[str, Any]):
|
||||
"""
|
||||
:说明:
|
||||
@ -673,6 +673,7 @@ class MiraiBot(BaseBot):
|
||||
'config': config
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def member_info(self, target: int, member_id: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -699,6 +700,7 @@ class MiraiBot(BaseBot):
|
||||
'memberId': member_id
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def modify_member_info(self, target: int, member_id: int,
|
||||
info: Dict[str, Any]):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user