🚸 add factory classmethods in MessageSegment at mirai adapter

This commit is contained in:
Mix
2021-01-30 21:51:51 +08:00
parent 95f27824ee
commit 73be9151b0
3 changed files with 118 additions and 20 deletions

View File

@ -10,6 +10,7 @@ from nonebot.adapters import Event as BaseEvent
from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
from nonebot.exception import RequestDenied
from nonebot.exception import ActionFailed as BaseActionFailed
from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.typing import overrides
@ -19,6 +20,17 @@ from .event import Event, FriendMessage, GroupMessage, TempMessage
from .message import MessageChain, MessageSegment
class ActionFailed(BaseActionFailed):
def __init__(self, code: int, message: str = ''):
super().__init__('mirai')
self.code = code
self.message = message
def __repr__(self):
return f"{self.__class__.__name__}(code={self.code}, message={self.message!r})"
class SessionManager:
sessions: Dict[int, Tuple[str, datetime, httpx.AsyncClient]] = {}
session_expiry: timedelta = timedelta(minutes=15)
@ -26,14 +38,22 @@ class SessionManager:
def __init__(self, session_key: str, client: httpx.AsyncClient):
self.session_key, self.client = session_key, client
@staticmethod
def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
code = data.get('code', 0)
logger.debug(f'Mirai API returned data: {data}')
if code != 0:
raise ActionFailed(code, message=data['msg'])
return data
async def post(self,
path: str,
*,
params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
params = {**(params or {}), 'sessionKey': self.session_key}
response = await self.client.post(path, json=params)
response = await self.client.post(path, json=params, timeout=3)
response.raise_for_status()
return response.json()
return self._raise_code(response.json())
async def request(self,
path: str,
@ -44,9 +64,10 @@ class SessionManager:
params={
**(params or {}), 'sessionKey':
self.session_key
})
},
timeout=3)
response.raise_for_status()
return response.json()
return self._raise_code(response.json())
async def upload(self, path: str, *, type: str,
file: Tuple[str, BytesIO]) -> Dict[str, Any]:
@ -59,7 +80,7 @@ class SessionManager:
files={file_type: file_io},
timeout=6)
response.raise_for_status()
return response.json()
return self._raise_code(response.json())
@classmethod
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
@ -152,7 +173,7 @@ class MiraiBot(BaseBot):
raise NotImplementedError
@overrides(BaseBot)
async def __getattr__(self, key: str) -> NoReturn:
def __getattr__(self, key: str) -> NoReturn:
raise NotImplementedError
@overrides(BaseBot)
@ -165,8 +186,10 @@ class MiraiBot(BaseBot):
return await self.send_friend_message(target=event.sender.id,
message_chain=message)
elif isinstance(event, GroupMessage):
return await self.send_group_message(target=event.sender.group.id,
message_chain=message)
return await self.send_group_message(
group=event.sender.group.id,
message_chain=message if not at_sender else
(MessageSegment.at(target=event.sender.id) + message))
elif isinstance(event, TempMessage):
return await self.send_temp_message(qq=event.sender.id,
group=event.sender.group.id,
@ -191,12 +214,15 @@ class MiraiBot(BaseBot):
'messageChain': message_chain.export()
})
async def send_group_message(self, target: int,
message_chain: MessageChain):
async def send_group_message(self,
group: int,
message_chain: MessageChain,
quote: Optional[int] = None):
return await self.api.post('sendGroupMessage',
params={
'target': target,
'messageChain': message_chain.export()
'group': group,
'messageChain': message_chain.export(),
'quote': quote
})
async def recall(self, target: int):