🚧 add support of reverse post and forward ws for mirai adapter

This commit is contained in:
Mix
2021-01-31 16:02:59 +08:00
parent 73be9151b0
commit 3f56da9245
5 changed files with 57 additions and 90 deletions

View File

@ -7,50 +7,21 @@ from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set,
import httpx
import websockets
from nonebot.adapters import Bot as BaseBot
from nonebot.adapters import Event as BaseEvent
from nonebot.config import Config
from nonebot.drivers import Driver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.exception import RequestDenied
from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.typing import overrides
from .bot import MiraiBot, SessionManager
from .config import Config as MiraiConfig
from .event import Event
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
WebsocketHandler_T = TypeVar('WebsocketHandler_T',
bound=WebsocketHandlerFunction)
async def _ws_authorization(client: httpx.AsyncClient, *, auth_key: str,
qq: int) -> str:
async def request(method: str, *, path: str, **kwargs) -> Dict[str, Any]:
response = await client.request(method, path, **kwargs)
response.raise_for_status()
return response.json()
about = await request('GET', path='/about')
logger.opt(colors=True).debug('Mirai API HTTP backend version: '
f'<g><b>{about["data"]["version"]}</b></g>')
status = await request('POST', path='/auth', json={'authKey': auth_key})
assert status['code'] == 0
session_key = status['session']
verify = await request('POST',
path='/verify',
json={
'sessionKey': session_key,
'qq': qq
})
assert verify['code'] == 0, verify['msg']
return session_key
class WebSocket(BaseWebSocket):
@classmethod
@ -59,6 +30,7 @@ class WebSocket(BaseWebSocket):
listen_address = httpx.URL(f'ws://{host}:{port}/all',
params={'sessionKey': session_key})
websocket = await websockets.connect(uri=str(listen_address))
await (await websocket.ping())
return cls(websocket)
@overrides(BaseWebSocket)
@ -116,25 +88,24 @@ class WebSocket(BaseWebSocket):
return callable
class MiraiWebsocketBot(BaseBot):
class MiraiWebsocketBot(MiraiBot):
@overrides(MiraiBot)
def __init__(self, connection_type: str, self_id: str, *,
websocket: WebSocket):
super().__init__(connection_type, self_id, websocket=websocket)
websocket.handle(self.handle_message)
self.driver._bot_connect(self)
@property
@overrides(BaseBot)
@overrides(MiraiBot)
def type(self) -> str:
return "mirai"
return "mirai-ws"
@property
def alive(self) -> bool:
return not self.websocket.closed
@classmethod
@overrides(BaseBot)
@overrides(MiraiBot)
async def check_permission(cls, driver: "Driver", connection_type: str,
headers: dict, body: Optional[dict]) -> NoReturn:
raise RequestDenied(
@ -142,7 +113,7 @@ class MiraiWebsocketBot(BaseBot):
reason=f'Connection {connection_type} not implented')
@classmethod
@overrides(BaseBot)
@overrides(MiraiBot)
def register(cls, driver: "Driver", config: "Config", qq: int):
cls.mirai_config = MiraiConfig(**config.dict())
cls.active = True
@ -152,32 +123,33 @@ class MiraiWebsocketBot(BaseBot):
super().register(driver, config)
async def _bot_connection():
async with httpx.AsyncClient(
base_url=
f'http://{cls.mirai_config.host}:{cls.mirai_config.port}'
) as client:
session_key = await _ws_authorization(
client,
auth_key=cls.mirai_config.auth_key, # type: ignore
qq=qq) # type: ignore
session: SessionManager = await SessionManager.new(
qq,
host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, # type: ignore
auth_key=cls.mirai_config.auth_key # type: ignore
)
websocket = await WebSocket.new(
host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, # type: ignore
session_key=session_key)
session_key=session.session_key)
bot = cls(connection_type='forward_ws',
self_id=str(qq),
websocket=websocket)
websocket.handle(bot.handle_message)
driver._clients[str(qq)] = bot
await websocket.accept()
return bot
async def _connection_ensure():
if str(qq) not in driver._clients:
await _bot_connection()
elif not driver._clients[str(qq)].alive:
driver._clients.pop(str(qq), None)
await _bot_connection()
self_id = str(qq)
if self_id not in driver._clients:
bot = await _bot_connection()
driver._bot_connect(bot)
else:
bot = driver._clients[self_id]
if not bot.alive:
driver._bot_disconnect(bot)
return
@driver.on_startup
async def _startup():
@ -202,19 +174,3 @@ class MiraiWebsocketBot(BaseBot):
if bot is None:
return
await bot.websocket.close() #type:ignore
@overrides(BaseBot)
async def handle_message(self, message: dict):
event = Event.new(message)
await handle_event(self, event)
@overrides(BaseBot)
async def call_api(self, api: str, **data):
return super().call_api(api, **data)
@overrides(BaseBot)
async def send(self, event: "BaseEvent", message: str, **kwargs):
return super().send(event, message, **kwargs)
def __del__(self):
self.driver._bot_disconnect(self)