mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-31 15:06:42 +00:00 
			
		
		
		
	🚧 finish forward websocket receive
This commit is contained in:
		| @@ -1,22 +1,89 @@ | |||||||
| from pprint import pprint | import asyncio | ||||||
| from typing import Optional | import json | ||||||
|  | from ipaddress import IPv4Address | ||||||
|  | from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set, | ||||||
|  |                     TypeVar) | ||||||
|  |  | ||||||
|  | import httpx | ||||||
|  | import websockets | ||||||
|  |  | ||||||
| from nonebot.adapters import Bot as BaseBot | from nonebot.adapters import Bot as BaseBot | ||||||
| from nonebot.adapters import Event as BaseEvent | from nonebot.adapters import Event as BaseEvent | ||||||
| from nonebot.drivers import Driver, WebSocket | from nonebot.drivers import Driver | ||||||
|  | from nonebot.drivers import WebSocket as BaseWebSocket | ||||||
|  | from nonebot.exception import RequestDenied | ||||||
|  | from nonebot.log import logger | ||||||
| from nonebot.message import handle_event | from nonebot.message import handle_event | ||||||
| from nonebot.typing import overrides | from nonebot.typing import overrides | ||||||
|  |  | ||||||
|  | from .config import Config | ||||||
| from .event import Event | from .event import Event | ||||||
|  |  | ||||||
|  | WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]] | ||||||
|  | WebsocketHandler_T = TypeVar('WebsocketHandler_T', | ||||||
|  |                              bound=WebsocketHandlerFunction) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class WebSocket(BaseWebSocket): | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     async def new(cls, *, host: IPv4Address, port: int, | ||||||
|  |                   session_key: str) -> "WebSocket": | ||||||
|  |         listen_address = httpx.URL(f'ws://{host}:{port}/all', | ||||||
|  |                                    params={'sessionKey': session_key}) | ||||||
|  |         websocket = await websockets.connect(uri=str(listen_address)) | ||||||
|  |         return cls(websocket) | ||||||
|  |  | ||||||
|  |     @overrides(BaseWebSocket) | ||||||
|  |     def __init__(self, websocket: websockets.WebSocketClientProtocol): | ||||||
|  |         self.event_handlers: Set[WebsocketHandlerFunction] = set() | ||||||
|  |         super().__init__(websocket) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     @overrides(BaseWebSocket) | ||||||
|  |     def websocket(self) -> websockets.WebSocketClientProtocol: | ||||||
|  |         return self._websocket | ||||||
|  |  | ||||||
|  |     @overrides(BaseWebSocket) | ||||||
|  |     async def send(self, data: Dict[str, Any]): | ||||||
|  |         return await self.websocket.send(json.dumps(data)) | ||||||
|  |  | ||||||
|  |     @overrides(BaseWebSocket) | ||||||
|  |     async def receive(self) -> Dict[str, Any]: | ||||||
|  |         received = await self.websocket.recv() | ||||||
|  |         return json.loads(received) | ||||||
|  |  | ||||||
|  |     async def _dispatcher(self): | ||||||
|  |         while not self.websocket.closed: | ||||||
|  |             try: | ||||||
|  |                 data = await self.receive() | ||||||
|  |             except websockets.ConnectionClosedOK: | ||||||
|  |                 break | ||||||
|  |             except Exception as e: | ||||||
|  |                 logger.exception(f'Websocket client listened {self.websocket} ' | ||||||
|  |                                  f'failed to receive data: {e}') | ||||||
|  |                 continue | ||||||
|  |             asyncio.ensure_future( | ||||||
|  |                 asyncio.gather(*map(lambda f: f(data), self.event_handlers), | ||||||
|  |                                return_exceptions=True)) | ||||||
|  |  | ||||||
|  |     @overrides(BaseWebSocket) | ||||||
|  |     async def accept(self): | ||||||
|  |         asyncio.ensure_future(self._dispatcher()) | ||||||
|  |  | ||||||
|  |     @overrides(BaseWebSocket) | ||||||
|  |     async def close(self): | ||||||
|  |         await self.websocket.close() | ||||||
|  |  | ||||||
|  |     def handle(self, callable: WebsocketHandler_T) -> WebsocketHandler_T: | ||||||
|  |         self.event_handlers.add(callable) | ||||||
|  |         return callable | ||||||
|  |  | ||||||
|  |  | ||||||
| class MiraiBot(BaseBot): | class MiraiBot(BaseBot): | ||||||
|  |  | ||||||
|     def __init__(self, |     def __init__(self, connection_type: str, self_id: str, *, | ||||||
|                  connection_type: str, |                  websocket: WebSocket): | ||||||
|                  self_id: str, |  | ||||||
|                  *, |  | ||||||
|                  websocket: Optional["WebSocket"] = None): |  | ||||||
|         super().__init__(connection_type, self_id, websocket=websocket) |         super().__init__(connection_type, self_id, websocket=websocket) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
| @@ -27,8 +94,55 @@ class MiraiBot(BaseBot): | |||||||
|     @classmethod |     @classmethod | ||||||
|     @overrides(BaseBot) |     @overrides(BaseBot) | ||||||
|     async def check_permission(cls, driver: "Driver", connection_type: str, |     async def check_permission(cls, driver: "Driver", connection_type: str, | ||||||
|                                headers: dict, body: Optional[dict]) -> str: |                                headers: dict, body: Optional[dict]) -> NoReturn: | ||||||
|         return '' |         raise RequestDenied( | ||||||
|  |             status_code=501, | ||||||
|  |             reason=f'Connection {connection_type} not implented') | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     @overrides(BaseBot) | ||||||
|  |     def register(cls, driver: "Driver", config: "Config", qq: int): | ||||||
|  |         config = Config.parse_obj(config.dict()) | ||||||
|  |         assert config.auth_key and config.host and config.port, f'Current config {config!r} is invalid' | ||||||
|  |  | ||||||
|  |         super().register(driver, config)  # type: ignore | ||||||
|  |  | ||||||
|  |         @driver.on_startup | ||||||
|  |         async def _startup(): | ||||||
|  |             async with httpx.AsyncClient( | ||||||
|  |                     base_url=f'http://{config.host}:{config.port}') as client: | ||||||
|  |                 response = await client.get('/about') | ||||||
|  |                 info = response.json() | ||||||
|  |                 logger.debug(f'Mirai API returned info: {info}') | ||||||
|  |                 response = await client.post('/auth', | ||||||
|  |                                              json={'authKey': config.auth_key}) | ||||||
|  |                 status = response.json() | ||||||
|  |                 assert status['code'] == 0 | ||||||
|  |                 session_key = status['session'] | ||||||
|  |                 response = await client.post('/verify', | ||||||
|  |                                              json={ | ||||||
|  |                                                  'sessionKey': session_key, | ||||||
|  |                                                  'qq': qq | ||||||
|  |                                              }) | ||||||
|  |                 assert response.json()['code'] == 0 | ||||||
|  |  | ||||||
|  |             websocket = await WebSocket.new( | ||||||
|  |                 host=config.host,  # type: ignore | ||||||
|  |                 port=config.port,  # type: ignore | ||||||
|  |                 session_key=session_key) | ||||||
|  |             bot = cls(connection_type='forward_ws', | ||||||
|  |                       self_id=str(qq), | ||||||
|  |                       websocket=websocket) | ||||||
|  |             websocket.handle(bot.handle_message) | ||||||
|  |             driver._clients[str(qq)] = bot | ||||||
|  |             await websocket.accept() | ||||||
|  |  | ||||||
|  |         @driver.on_shutdown | ||||||
|  |         async def _shutdown(): | ||||||
|  |             bot = driver._clients.pop(str(qq), None) | ||||||
|  |             if bot is None: | ||||||
|  |                 return | ||||||
|  |             await bot.websocket.close()  #type:ignore | ||||||
|  |  | ||||||
|     @overrides(BaseBot) |     @overrides(BaseBot) | ||||||
|     async def handle_message(self, message: dict): |     async def handle_message(self, message: dict): | ||||||
|   | |||||||
							
								
								
									
										13
									
								
								nonebot/adapters/mirai/config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								nonebot/adapters/mirai/config.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | |||||||
|  | from ipaddress import IPv4Address | ||||||
|  | from typing import Optional | ||||||
|  |  | ||||||
|  | from pydantic import BaseModel, Extra, Field | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Config(BaseModel): | ||||||
|  |     auth_key: Optional[str] = Field(None, alias='mirai_auth_key') | ||||||
|  |     host: Optional[IPv4Address] = Field(None, alias='mirai_host') | ||||||
|  |     port: Optional[int] = Field(None, alias='mirai_port') | ||||||
|  |  | ||||||
|  |     class Config: | ||||||
|  |         extra = Extra.ignore | ||||||
| @@ -62,7 +62,7 @@ class Driver(abc.ABC): | |||||||
|         :说明: 已连接的 Bot |         :说明: 已连接的 Bot | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|     def register_adapter(self, name: str, adapter: Type["Bot"]): |     def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs): | ||||||
|         """ |         """ | ||||||
|         :说明: |         :说明: | ||||||
|  |  | ||||||
| @@ -74,7 +74,7 @@ class Driver(abc.ABC): | |||||||
|           * ``adapter: Type[Bot]``: 适配器 Class |           * ``adapter: Type[Bot]``: 适配器 Class | ||||||
|         """ |         """ | ||||||
|         self._adapters[name] = adapter |         self._adapters[name] = adapter | ||||||
|         adapter.register(self, self.config) |         adapter.register(self, self.config, **kwargs) | ||||||
|         logger.opt( |         logger.opt( | ||||||
|             colors=True).debug(f'Succeeded to load adapter "<y>{name}</y>"') |             colors=True).debug(f'Succeeded to load adapter "<y>{name}</y>"') | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user