From c82ceefc8b503476fc41d1c5494a91209d2084af Mon Sep 17 00:00:00 2001 From: Mix Date: Sat, 30 Jan 2021 19:11:17 +0800 Subject: [PATCH] :rewind: revert call method to http post, add api handle --- nonebot/adapters/mirai/__init__.py | 4 +- nonebot/adapters/mirai/bot.py | 233 ++++++++------------------- nonebot/adapters/mirai/bot_ws.py | 220 +++++++++++++++++++++++++ nonebot/adapters/mirai/event/base.py | 1 + 4 files changed, 293 insertions(+), 165 deletions(-) create mode 100644 nonebot/adapters/mirai/bot_ws.py diff --git a/nonebot/adapters/mirai/__init__.py b/nonebot/adapters/mirai/__init__.py index c832d378..991f30fd 100644 --- a/nonebot/adapters/mirai/__init__.py +++ b/nonebot/adapters/mirai/__init__.py @@ -1 +1,3 @@ -from .bot import MiraiBot \ No newline at end of file +from .bot import MiraiBot +from .event import * +from .message import MessageChain, MessageSegment diff --git a/nonebot/adapters/mirai/bot.py b/nonebot/adapters/mirai/bot.py index 08f6ea4e..338dd144 100644 --- a/nonebot/adapters/mirai/bot.py +++ b/nonebot/adapters/mirai/bot.py @@ -1,128 +1,74 @@ -import asyncio -import json +from typing import Any, Dict, Optional, Tuple +from datetime import datetime, timedelta from ipaddress import IPv4Address -from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set, - TypeVar) import httpx -import websockets from nonebot.adapters import Bot as BaseBot from nonebot.adapters import Event as BaseEvent from nonebot.config import Config -from nonebot.drivers import Driver -from nonebot.drivers import WebSocket as BaseWebSocket +from nonebot.drivers import Driver, WebSocket from nonebot.exception import RequestDenied from nonebot.log import logger from nonebot.message import handle_event from nonebot.typing import overrides from .config import Config as MiraiConfig -from .event import Event - -WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]] -WebsocketHandler_T = TypeVar('WebsocketHandler_T', - bound=WebsocketHandlerFunction) +from .event import Event, FriendMessage, TempMessage, GroupMessage -async def _ws_authorization(client: httpx.AsyncClient, *, auth_key: str, - qq: int) -> str: +class SessionManager: + sessions: Dict[int, Tuple[str, datetime, httpx.AsyncClient]] = {} + session_expiry: timedelta = timedelta(minutes=15) - async def request(method: str, *, path: str, **kwargs) -> Dict[str, Any]: - response = await client.request(method, path, **kwargs) + def __init__(self, session_key: str, client: httpx.AsyncClient): + self.session_key, self.client = session_key, client + + async def post(self, path: str, *, params: Optional[Dict[str, Any]] = None): + params = {**(params or {}), 'sessionKey': self.session_key} + response = await self.client.post(path, json=params) response.raise_for_status() return response.json() - about = await request('GET', path='/about') - logger.opt(colors=True).debug('Mirai API HTTP backend version: ' - f'{about["data"]["version"]}') - - status = await request('POST', path='/auth', json={'authKey': auth_key}) - assert status['code'] == 0 - session_key = status['session'] - - verify = await request('POST', - path='/verify', - json={ - 'sessionKey': session_key, - 'qq': qq - }) - assert verify['code'] == 0, verify['msg'] - return session_key - - -class WebSocket(BaseWebSocket): + @classmethod + async def new(cls, self_id: int, *, host: IPv4Address, port: int, + auth_key: str): + if self_id in cls.sessions: + manager = cls.get(self_id) + if manager is not None: + return manager + client = httpx.AsyncClient(base_url=f'http://{host}:{port}') + response = await client.post('/auth', json={'authKey': auth_key}) + response.raise_for_status() + auth = response.json() + assert auth['code'] == 0 + session_key = auth['session'] + response = await client.post('/verify', + json={ + 'sessionKey': session_key, + 'qq': self_id + }) + assert response.json()['code'] == 0 + cls.sessions[self_id] = session_key, datetime.now(), client + return cls(session_key, client) @classmethod - async def new(cls, *, host: IPv4Address, port: int, - session_key: str) -> "WebSocket": - listen_address = httpx.URL(f'ws://{host}:{port}/all', - params={'sessionKey': session_key}) - websocket = await websockets.connect(uri=str(listen_address)) - return cls(websocket) - - @overrides(BaseWebSocket) - def __init__(self, websocket: websockets.WebSocketClientProtocol): - self.event_handlers: Set[WebsocketHandlerFunction] = set() - super().__init__(websocket) - - @property - @overrides(BaseWebSocket) - def websocket(self) -> websockets.WebSocketClientProtocol: - return self._websocket - - @property - @overrides(BaseWebSocket) - def closed(self) -> bool: - return self.websocket.closed - - @overrides(BaseWebSocket) - async def send(self, data: Dict[str, Any]): - return await self.websocket.send(json.dumps(data)) - - @overrides(BaseWebSocket) - async def receive(self) -> Dict[str, Any]: - received = await self.websocket.recv() - return json.loads(received) - - async def _dispatcher(self): - while not self.closed: - try: - data = await self.receive() - except websockets.ConnectionClosedOK: - logger.debug(f'Websocket connection {self.websocket} closed') - break - except websockets.ConnectionClosedError: - logger.exception(f'Websocket connection {self.websocket} ' - 'connection closed abnormally:') - break - except json.JSONDecodeError as e: - logger.exception(f'Websocket client listened {self.websocket} ' - f'failed to decode data: {e}') - continue - asyncio.gather(*map(lambda f: f(data), self.event_handlers), - return_exceptions=True) - - @overrides(BaseWebSocket) - async def accept(self): - asyncio.create_task(self._dispatcher()) - - @overrides(BaseWebSocket) - async def close(self): - await self.websocket.close() - - def handle(self, callable: WebsocketHandler_T) -> WebsocketHandler_T: - self.event_handlers.add(callable) - return callable + def get(cls, self_id: int): + key, time, client = cls.sessions[self_id] + if datetime.now() - time > cls.session_expiry: + return None + return cls(key, client) class MiraiBot(BaseBot): - def __init__(self, connection_type: str, self_id: str, *, - websocket: WebSocket): + def __init__(self, + connection_type: str, + self_id: str, + *, + websocket: Optional[WebSocket] = None): super().__init__(connection_type, self_id, websocket=websocket) - websocket.handle(self.handle_message) - self.driver._bot_connect(self) + self.api = SessionManager.get(int(self_id)) @property @overrides(BaseBot) @@ -136,85 +82,44 @@ class MiraiBot(BaseBot): @classmethod @overrides(BaseBot) async def check_permission(cls, driver: "Driver", connection_type: str, - headers: dict, body: Optional[dict]) -> NoReturn: - raise RequestDenied( - status_code=501, - reason=f'Connection {connection_type} not implented') + headers: dict, body: Optional[dict]) -> str: + if connection_type == 'ws': + raise RequestDenied( + status_code=501, + reason='Websocket connection is not implemented') + self_id: Optional[str] = headers.get('bot') + if self_id is None: + raise RequestDenied(status_code=400, + reason='Header `Bot` is required.') + self_id = str(self_id).strip() + await SessionManager.new( + int(self_id), + host=cls.mirai_config.host, # type: ignore + port=cls.mirai_config.port, #type: ignore + auth_key=cls.mirai_config.auth_key) # type: ignore + return self_id @classmethod @overrides(BaseBot) - def register(cls, driver: "Driver", config: "Config", qq: int): + def register(cls, driver: "Driver", config: "Config"): cls.mirai_config = MiraiConfig(**config.dict()) - cls.active = True assert cls.mirai_config.auth_key is not None assert cls.mirai_config.host is not None assert cls.mirai_config.port is not None super().register(driver, config) - async def _bot_connection(): - async with httpx.AsyncClient( - base_url= - f'http://{cls.mirai_config.host}:{cls.mirai_config.port}' - ) as client: - session_key = await _ws_authorization( - client, - auth_key=cls.mirai_config.auth_key, # type: ignore - qq=qq) # type: ignore - - websocket = await WebSocket.new( - host=cls.mirai_config.host, # type: ignore - port=cls.mirai_config.port, # type: ignore - session_key=session_key) - bot = cls(connection_type='forward_ws', - self_id=str(qq), - websocket=websocket) - websocket.handle(bot.handle_message) - driver._clients[str(qq)] = bot - await websocket.accept() - - async def _connection_ensure(): - if str(qq) not in driver._clients: - await _bot_connection() - elif not driver._clients[str(qq)].alive: - driver._clients.pop(str(qq), None) - await _bot_connection() - - @driver.on_startup - async def _startup(): - - async def _checker(): - while cls.active: - try: - await _connection_ensure() - except Exception as e: - logger.opt(colors=True).warning( - 'Failed to create mirai connection to ' - f'{qq}, reason: {e}. ' - 'Will retry after 3 seconds') - await asyncio.sleep(3) - - asyncio.create_task(_checker()) - - @driver.on_shutdown - async def _shutdown(): - cls.active = False - bot = driver._clients.pop(str(qq), None) - if bot is None: - return - await bot.websocket.close() #type:ignore - @overrides(BaseBot) async def handle_message(self, message: dict): - event = Event.new(message) - await handle_event(self, event) + await handle_event(bot=self, + event=Event.new({ + **message, + 'self_id': self.self_id, + })) @overrides(BaseBot) async def call_api(self, api: str, **data): - return super().call_api(api, **data) + return await self.api.post('/' + api, params=data) @overrides(BaseBot) async def send(self, event: "BaseEvent", message: str, **kwargs): - return super().send(event, message, **kwargs) - - def __del__(self): - self.driver._bot_disconnect(self) + pass diff --git a/nonebot/adapters/mirai/bot_ws.py b/nonebot/adapters/mirai/bot_ws.py new file mode 100644 index 00000000..d9803c47 --- /dev/null +++ b/nonebot/adapters/mirai/bot_ws.py @@ -0,0 +1,220 @@ +import asyncio +import json +from ipaddress import IPv4Address +from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set, + TypeVar) + +import httpx +import websockets + +from nonebot.adapters import Bot as BaseBot +from nonebot.adapters import Event as BaseEvent +from nonebot.config import Config +from nonebot.drivers import Driver +from nonebot.drivers import WebSocket as BaseWebSocket +from nonebot.exception import RequestDenied +from nonebot.log import logger +from nonebot.message import handle_event +from nonebot.typing import overrides + +from .config import Config as MiraiConfig +from .event import Event + +WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]] +WebsocketHandler_T = TypeVar('WebsocketHandler_T', + bound=WebsocketHandlerFunction) + + +async def _ws_authorization(client: httpx.AsyncClient, *, auth_key: str, + qq: int) -> str: + + async def request(method: str, *, path: str, **kwargs) -> Dict[str, Any]: + response = await client.request(method, path, **kwargs) + response.raise_for_status() + return response.json() + + about = await request('GET', path='/about') + logger.opt(colors=True).debug('Mirai API HTTP backend version: ' + f'{about["data"]["version"]}') + + status = await request('POST', path='/auth', json={'authKey': auth_key}) + assert status['code'] == 0 + session_key = status['session'] + + verify = await request('POST', + path='/verify', + json={ + 'sessionKey': session_key, + 'qq': qq + }) + assert verify['code'] == 0, verify['msg'] + return session_key + + +class WebSocket(BaseWebSocket): + + @classmethod + async def new(cls, *, host: IPv4Address, port: int, + session_key: str) -> "WebSocket": + listen_address = httpx.URL(f'ws://{host}:{port}/all', + params={'sessionKey': session_key}) + websocket = await websockets.connect(uri=str(listen_address)) + return cls(websocket) + + @overrides(BaseWebSocket) + def __init__(self, websocket: websockets.WebSocketClientProtocol): + self.event_handlers: Set[WebsocketHandlerFunction] = set() + super().__init__(websocket) + + @property + @overrides(BaseWebSocket) + def websocket(self) -> websockets.WebSocketClientProtocol: + return self._websocket + + @property + @overrides(BaseWebSocket) + def closed(self) -> bool: + return self.websocket.closed + + @overrides(BaseWebSocket) + async def send(self, data: Dict[str, Any]): + return await self.websocket.send(json.dumps(data)) + + @overrides(BaseWebSocket) + async def receive(self) -> Dict[str, Any]: + received = await self.websocket.recv() + return json.loads(received) + + async def _dispatcher(self): + while not self.closed: + try: + data = await self.receive() + except websockets.ConnectionClosedOK: + logger.debug(f'Websocket connection {self.websocket} closed') + break + except websockets.ConnectionClosedError: + logger.exception(f'Websocket connection {self.websocket} ' + 'connection closed abnormally:') + break + except json.JSONDecodeError as e: + logger.exception(f'Websocket client listened {self.websocket} ' + f'failed to decode data: {e}') + continue + asyncio.gather(*map(lambda f: f(data), self.event_handlers), + return_exceptions=True) + + @overrides(BaseWebSocket) + async def accept(self): + asyncio.create_task(self._dispatcher()) + + @overrides(BaseWebSocket) + async def close(self): + await self.websocket.close() + + def handle(self, callable: WebsocketHandler_T) -> WebsocketHandler_T: + self.event_handlers.add(callable) + return callable + + +class MiraiWebsocketBot(BaseBot): + + def __init__(self, connection_type: str, self_id: str, *, + websocket: WebSocket): + super().__init__(connection_type, self_id, websocket=websocket) + websocket.handle(self.handle_message) + self.driver._bot_connect(self) + + @property + @overrides(BaseBot) + def type(self) -> str: + return "mirai" + + @property + def alive(self) -> bool: + return not self.websocket.closed + + @classmethod + @overrides(BaseBot) + async def check_permission(cls, driver: "Driver", connection_type: str, + headers: dict, body: Optional[dict]) -> NoReturn: + raise RequestDenied( + status_code=501, + reason=f'Connection {connection_type} not implented') + + @classmethod + @overrides(BaseBot) + def register(cls, driver: "Driver", config: "Config", qq: int): + cls.mirai_config = MiraiConfig(**config.dict()) + cls.active = True + assert cls.mirai_config.auth_key is not None + assert cls.mirai_config.host is not None + assert cls.mirai_config.port is not None + super().register(driver, config) + + async def _bot_connection(): + async with httpx.AsyncClient( + base_url= + f'http://{cls.mirai_config.host}:{cls.mirai_config.port}' + ) as client: + session_key = await _ws_authorization( + client, + auth_key=cls.mirai_config.auth_key, # type: ignore + qq=qq) # type: ignore + + websocket = await WebSocket.new( + host=cls.mirai_config.host, # type: ignore + port=cls.mirai_config.port, # type: ignore + session_key=session_key) + bot = cls(connection_type='forward_ws', + self_id=str(qq), + websocket=websocket) + websocket.handle(bot.handle_message) + driver._clients[str(qq)] = bot + await websocket.accept() + + async def _connection_ensure(): + if str(qq) not in driver._clients: + await _bot_connection() + elif not driver._clients[str(qq)].alive: + driver._clients.pop(str(qq), None) + await _bot_connection() + + @driver.on_startup + async def _startup(): + + async def _checker(): + while cls.active: + try: + await _connection_ensure() + except Exception as e: + logger.opt(colors=True).warning( + 'Failed to create mirai connection to ' + f'{qq}, reason: {e}. ' + 'Will retry after 3 seconds') + await asyncio.sleep(3) + + asyncio.create_task(_checker()) + + @driver.on_shutdown + async def _shutdown(): + cls.active = False + bot = driver._clients.pop(str(qq), None) + if bot is None: + return + await bot.websocket.close() #type:ignore + + @overrides(BaseBot) + async def handle_message(self, message: dict): + event = Event.new(message) + await handle_event(self, event) + + @overrides(BaseBot) + async def call_api(self, api: str, **data): + return super().call_api(api, **data) + + @overrides(BaseBot) + async def send(self, event: "BaseEvent", message: str, **kwargs): + return super().send(event, message, **kwargs) + + def __del__(self): + self.driver._bot_disconnect(self) diff --git a/nonebot/adapters/mirai/event/base.py b/nonebot/adapters/mirai/event/base.py index 7a6cae39..6fbb30ff 100644 --- a/nonebot/adapters/mirai/event/base.py +++ b/nonebot/adapters/mirai/event/base.py @@ -37,6 +37,7 @@ class PrivateSenderInfo(BaseModel): class Event(BaseEvent): + self_id: int type: str @classmethod