🚧 add support of reverse post and forward ws for mirai adapter

This commit is contained in:
Mix
2021-01-31 16:02:59 +08:00
parent 73be9151b0
commit 3f56da9245
5 changed files with 57 additions and 90 deletions

View File

@ -1,3 +1,4 @@
from .bot import MiraiBot from .bot import MiraiBot
from .bot_ws import MiraiWebsocketBot
from .event import * from .event import *
from .message import MessageChain, MessageSegment from .message import MessageChain, MessageSegment

View File

@ -1,19 +1,19 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from io import BytesIO from io import BytesIO
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import Any, Dict, List, NoReturn, Optional, Tuple from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
import httpx import httpx
from nonebot.adapters import Bot as BaseBot from nonebot.adapters import Bot as BaseBot
from nonebot.adapters import Event as BaseEvent
from nonebot.config import Config from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket from nonebot.drivers import Driver, WebSocket
from nonebot.exception import RequestDenied
from nonebot.exception import ActionFailed as BaseActionFailed from nonebot.exception import ActionFailed as BaseActionFailed
from nonebot.exception import RequestDenied
from nonebot.log import logger from nonebot.log import logger
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.utils import escape_tag
from .config import Config as MiraiConfig from .config import Config as MiraiConfig
from .event import Event, FriendMessage, GroupMessage, TempMessage from .event import Event, FriendMessage, GroupMessage, TempMessage
@ -41,7 +41,8 @@ class SessionManager:
@staticmethod @staticmethod
def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]: def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
code = data.get('code', 0) code = data.get('code', 0)
logger.debug(f'Mirai API returned data: {data}') logger.opt(colors=True).debug('Mirai API returned data: '
f'<y>{escape_tag(str(data))}</y>')
if code != 0: if code != 0:
raise ActionFailed(code, message=data['msg']) raise ActionFailed(code, message=data['msg'])
return data return data
@ -85,10 +86,10 @@ class SessionManager:
@classmethod @classmethod
async def new(cls, self_id: int, *, host: IPv4Address, port: int, async def new(cls, self_id: int, *, host: IPv4Address, port: int,
auth_key: str): auth_key: str):
if self_id in cls.sessions: session = cls.get(self_id)
manager = cls.get(self_id) if session is not None:
if manager is not None: return session
return manager
client = httpx.AsyncClient(base_url=f'http://{host}:{port}') client = httpx.AsyncClient(base_url=f'http://{host}:{port}')
response = await client.post('/auth', json={'authKey': auth_key}) response = await client.post('/auth', json={'authKey': auth_key})
response.raise_for_status() response.raise_for_status()
@ -102,10 +103,13 @@ class SessionManager:
}) })
assert response.json()['code'] == 0 assert response.json()['code'] == 0
cls.sessions[self_id] = session_key, datetime.now(), client cls.sessions[self_id] = session_key, datetime.now(), client
return cls(session_key, client) return cls(session_key, client)
@classmethod @classmethod
def get(cls, self_id: int): def get(cls, self_id: int):
if self_id not in cls.sessions:
return None
key, time, client = cls.sessions[self_id] key, time, client = cls.sessions[self_id]
if datetime.now() - time > cls.session_expiry: if datetime.now() - time > cls.session_expiry:
return None return None
@ -114,6 +118,7 @@ class SessionManager:
class MiraiBot(BaseBot): class MiraiBot(BaseBot):
@overrides(BaseBot)
def __init__(self, def __init__(self,
connection_type: str, connection_type: str,
self_id: str, self_id: str,
@ -179,17 +184,20 @@ class MiraiBot(BaseBot):
@overrides(BaseBot) @overrides(BaseBot)
async def send(self, async def send(self,
event: Event, event: Event,
message: MessageChain, message: Union[MessageChain, MessageSegment, str],
at_sender: bool = False, at_sender: bool = False):
**kwargs): if isinstance(message, MessageSegment):
message = MessageChain(message)
elif isinstance(message, str):
message = MessageChain(MessageSegment.plain(message))
if isinstance(event, FriendMessage): if isinstance(event, FriendMessage):
return await self.send_friend_message(target=event.sender.id, return await self.send_friend_message(target=event.sender.id,
message_chain=message) message_chain=message)
elif isinstance(event, GroupMessage): elif isinstance(event, GroupMessage):
return await self.send_group_message( if at_sender:
group=event.sender.group.id, message = MessageSegment.at(event.sender.id) + message
message_chain=message if not at_sender else return await self.send_group_message(group=event.sender.group.id,
(MessageSegment.at(target=event.sender.id) + message)) message_chain=message)
elif isinstance(event, TempMessage): elif isinstance(event, TempMessage):
return await self.send_temp_message(qq=event.sender.id, return await self.send_temp_message(qq=event.sender.id,
group=event.sender.group.id, group=event.sender.group.id,

View File

@ -7,50 +7,21 @@ from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set,
import httpx import httpx
import websockets import websockets
from nonebot.adapters import Bot as BaseBot
from nonebot.adapters import Event as BaseEvent
from nonebot.config import Config from nonebot.config import Config
from nonebot.drivers import Driver from nonebot.drivers import Driver
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.exception import RequestDenied from nonebot.exception import RequestDenied
from nonebot.log import logger from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.typing import overrides from nonebot.typing import overrides
from .bot import MiraiBot, SessionManager
from .config import Config as MiraiConfig from .config import Config as MiraiConfig
from .event import Event
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]] WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
WebsocketHandler_T = TypeVar('WebsocketHandler_T', WebsocketHandler_T = TypeVar('WebsocketHandler_T',
bound=WebsocketHandlerFunction) bound=WebsocketHandlerFunction)
async def _ws_authorization(client: httpx.AsyncClient, *, auth_key: str,
qq: int) -> str:
async def request(method: str, *, path: str, **kwargs) -> Dict[str, Any]:
response = await client.request(method, path, **kwargs)
response.raise_for_status()
return response.json()
about = await request('GET', path='/about')
logger.opt(colors=True).debug('Mirai API HTTP backend version: '
f'<g><b>{about["data"]["version"]}</b></g>')
status = await request('POST', path='/auth', json={'authKey': auth_key})
assert status['code'] == 0
session_key = status['session']
verify = await request('POST',
path='/verify',
json={
'sessionKey': session_key,
'qq': qq
})
assert verify['code'] == 0, verify['msg']
return session_key
class WebSocket(BaseWebSocket): class WebSocket(BaseWebSocket):
@classmethod @classmethod
@ -59,6 +30,7 @@ class WebSocket(BaseWebSocket):
listen_address = httpx.URL(f'ws://{host}:{port}/all', listen_address = httpx.URL(f'ws://{host}:{port}/all',
params={'sessionKey': session_key}) params={'sessionKey': session_key})
websocket = await websockets.connect(uri=str(listen_address)) websocket = await websockets.connect(uri=str(listen_address))
await (await websocket.ping())
return cls(websocket) return cls(websocket)
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
@ -116,25 +88,24 @@ class WebSocket(BaseWebSocket):
return callable return callable
class MiraiWebsocketBot(BaseBot): class MiraiWebsocketBot(MiraiBot):
@overrides(MiraiBot)
def __init__(self, connection_type: str, self_id: str, *, def __init__(self, connection_type: str, self_id: str, *,
websocket: WebSocket): websocket: WebSocket):
super().__init__(connection_type, self_id, websocket=websocket) super().__init__(connection_type, self_id, websocket=websocket)
websocket.handle(self.handle_message)
self.driver._bot_connect(self)
@property @property
@overrides(BaseBot) @overrides(MiraiBot)
def type(self) -> str: def type(self) -> str:
return "mirai" return "mirai-ws"
@property @property
def alive(self) -> bool: def alive(self) -> bool:
return not self.websocket.closed return not self.websocket.closed
@classmethod @classmethod
@overrides(BaseBot) @overrides(MiraiBot)
async def check_permission(cls, driver: "Driver", connection_type: str, async def check_permission(cls, driver: "Driver", connection_type: str,
headers: dict, body: Optional[dict]) -> NoReturn: headers: dict, body: Optional[dict]) -> NoReturn:
raise RequestDenied( raise RequestDenied(
@ -142,7 +113,7 @@ class MiraiWebsocketBot(BaseBot):
reason=f'Connection {connection_type} not implented') reason=f'Connection {connection_type} not implented')
@classmethod @classmethod
@overrides(BaseBot) @overrides(MiraiBot)
def register(cls, driver: "Driver", config: "Config", qq: int): def register(cls, driver: "Driver", config: "Config", qq: int):
cls.mirai_config = MiraiConfig(**config.dict()) cls.mirai_config = MiraiConfig(**config.dict())
cls.active = True cls.active = True
@ -152,32 +123,33 @@ class MiraiWebsocketBot(BaseBot):
super().register(driver, config) super().register(driver, config)
async def _bot_connection(): async def _bot_connection():
async with httpx.AsyncClient( session: SessionManager = await SessionManager.new(
base_url= qq,
f'http://{cls.mirai_config.host}:{cls.mirai_config.port}' host=cls.mirai_config.host, # type: ignore
) as client: port=cls.mirai_config.port, # type: ignore
session_key = await _ws_authorization( auth_key=cls.mirai_config.auth_key # type: ignore
client, )
auth_key=cls.mirai_config.auth_key, # type: ignore
qq=qq) # type: ignore
websocket = await WebSocket.new( websocket = await WebSocket.new(
host=cls.mirai_config.host, # type: ignore host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, # type: ignore port=cls.mirai_config.port, # type: ignore
session_key=session_key) session_key=session.session_key)
bot = cls(connection_type='forward_ws', bot = cls(connection_type='forward_ws',
self_id=str(qq), self_id=str(qq),
websocket=websocket) websocket=websocket)
websocket.handle(bot.handle_message) websocket.handle(bot.handle_message)
driver._clients[str(qq)] = bot
await websocket.accept() await websocket.accept()
return bot
async def _connection_ensure(): async def _connection_ensure():
if str(qq) not in driver._clients: self_id = str(qq)
await _bot_connection() if self_id not in driver._clients:
elif not driver._clients[str(qq)].alive: bot = await _bot_connection()
driver._clients.pop(str(qq), None) driver._bot_connect(bot)
await _bot_connection() else:
bot = driver._clients[self_id]
if not bot.alive:
driver._bot_disconnect(bot)
return
@driver.on_startup @driver.on_startup
async def _startup(): async def _startup():
@ -202,19 +174,3 @@ class MiraiWebsocketBot(BaseBot):
if bot is None: if bot is None:
return return
await bot.websocket.close() #type:ignore await bot.websocket.close() #type:ignore
@overrides(BaseBot)
async def handle_message(self, message: dict):
event = Event.new(message)
await handle_event(self, event)
@overrides(BaseBot)
async def call_api(self, api: str, **data):
return super().call_api(api, **data)
@overrides(BaseBot)
async def send(self, event: "BaseEvent", message: str, **kwargs):
return super().send(event, message, **kwargs)
def __del__(self):
self.driver._bot_disconnect(self)

View File

@ -86,7 +86,7 @@ class Event(BaseEvent):
@overrides(BaseEvent) @overrides(BaseEvent)
def get_event_description(self) -> str: def get_event_description(self) -> str:
return str(self.dict()) return str(self.normalize_dict())
@overrides(BaseEvent) @overrides(BaseEvent)
def get_message(self) -> BaseMessage: def get_message(self) -> BaseMessage:

View File

@ -135,10 +135,11 @@ class MessageSegment(BaseMessageSegment):
return cls(type=MessageType.POKE, name=name) return cls(type=MessageType.POKE, name=name)
class MessageChain(BaseMessage): class MessageChain(BaseMessage): #type:List[MessageSegment]
@overrides(BaseMessage) @overrides(BaseMessage)
def __init__(self, message: Union[List[Dict[str, Any]], MessageSegment], def __init__(self, message: Union[List[Dict[str, Any]],
Iterable[MessageSegment], MessageSegment],
**kwargs): **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if isinstance(message, MessageSegment): if isinstance(message, MessageSegment):
@ -152,15 +153,16 @@ class MessageChain(BaseMessage):
@overrides(BaseMessage) @overrides(BaseMessage)
def _construct( def _construct(
self, message: Iterable[Union[Dict[str, Any], MessageSegment]] self, message: Union[List[Dict[str, Any]], Iterable[MessageSegment]]
) -> List[MessageSegment]: ) -> List[MessageSegment]:
if isinstance(message, str): if isinstance(message, str):
raise ValueError( raise ValueError(
"String operation is not supported in mirai adapter") "String operation is not supported in mirai adapter")
return [ return [
*map( *map(
lambda segment: segment if isinstance(segment, MessageSegment) lambda x: x
else MessageSegment(**segment), message) if isinstance(x, MessageSegment) else MessageSegment(**x),
message)
] ]
def export(self) -> List[Dict[str, Any]]: def export(self) -> List[Dict[str, Any]]: