mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-28 08:41:29 +00:00
🚧 add support of reverse post and forward ws for mirai adapter
This commit is contained in:
@ -1,19 +1,19 @@
|
||||
from datetime import datetime, timedelta
|
||||
from io import BytesIO
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, NoReturn, Optional, Tuple
|
||||
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from nonebot.adapters import Bot as BaseBot
|
||||
from nonebot.adapters import Event as BaseEvent
|
||||
from nonebot.config import Config
|
||||
from nonebot.drivers import Driver, WebSocket
|
||||
from nonebot.exception import RequestDenied
|
||||
from nonebot.exception import ActionFailed as BaseActionFailed
|
||||
from nonebot.exception import RequestDenied
|
||||
from nonebot.log import logger
|
||||
from nonebot.message import handle_event
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.utils import escape_tag
|
||||
|
||||
from .config import Config as MiraiConfig
|
||||
from .event import Event, FriendMessage, GroupMessage, TempMessage
|
||||
@ -41,7 +41,8 @@ class SessionManager:
|
||||
@staticmethod
|
||||
def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
code = data.get('code', 0)
|
||||
logger.debug(f'Mirai API returned data: {data}')
|
||||
logger.opt(colors=True).debug('Mirai API returned data: '
|
||||
f'<y>{escape_tag(str(data))}</y>')
|
||||
if code != 0:
|
||||
raise ActionFailed(code, message=data['msg'])
|
||||
return data
|
||||
@ -85,10 +86,10 @@ class SessionManager:
|
||||
@classmethod
|
||||
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
|
||||
auth_key: str):
|
||||
if self_id in cls.sessions:
|
||||
manager = cls.get(self_id)
|
||||
if manager is not None:
|
||||
return manager
|
||||
session = cls.get(self_id)
|
||||
if session is not None:
|
||||
return session
|
||||
|
||||
client = httpx.AsyncClient(base_url=f'http://{host}:{port}')
|
||||
response = await client.post('/auth', json={'authKey': auth_key})
|
||||
response.raise_for_status()
|
||||
@ -102,10 +103,13 @@ class SessionManager:
|
||||
})
|
||||
assert response.json()['code'] == 0
|
||||
cls.sessions[self_id] = session_key, datetime.now(), client
|
||||
|
||||
return cls(session_key, client)
|
||||
|
||||
@classmethod
|
||||
def get(cls, self_id: int):
|
||||
if self_id not in cls.sessions:
|
||||
return None
|
||||
key, time, client = cls.sessions[self_id]
|
||||
if datetime.now() - time > cls.session_expiry:
|
||||
return None
|
||||
@ -114,6 +118,7 @@ class SessionManager:
|
||||
|
||||
class MiraiBot(BaseBot):
|
||||
|
||||
@overrides(BaseBot)
|
||||
def __init__(self,
|
||||
connection_type: str,
|
||||
self_id: str,
|
||||
@ -179,17 +184,20 @@ class MiraiBot(BaseBot):
|
||||
@overrides(BaseBot)
|
||||
async def send(self,
|
||||
event: Event,
|
||||
message: MessageChain,
|
||||
at_sender: bool = False,
|
||||
**kwargs):
|
||||
message: Union[MessageChain, MessageSegment, str],
|
||||
at_sender: bool = False):
|
||||
if isinstance(message, MessageSegment):
|
||||
message = MessageChain(message)
|
||||
elif isinstance(message, str):
|
||||
message = MessageChain(MessageSegment.plain(message))
|
||||
if isinstance(event, FriendMessage):
|
||||
return await self.send_friend_message(target=event.sender.id,
|
||||
message_chain=message)
|
||||
elif isinstance(event, GroupMessage):
|
||||
return await self.send_group_message(
|
||||
group=event.sender.group.id,
|
||||
message_chain=message if not at_sender else
|
||||
(MessageSegment.at(target=event.sender.id) + message))
|
||||
if at_sender:
|
||||
message = MessageSegment.at(event.sender.id) + message
|
||||
return await self.send_group_message(group=event.sender.group.id,
|
||||
message_chain=message)
|
||||
elif isinstance(event, TempMessage):
|
||||
return await self.send_temp_message(qq=event.sender.id,
|
||||
group=event.sender.group.id,
|
||||
|
Reference in New Issue
Block a user