mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-26 04:26:39 +00:00 
			
		
		
		
	websocket api
This commit is contained in:
		| @@ -2,27 +2,33 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
|  | ||||
| import abc | ||||
| from functools import reduce | ||||
| from functools import reduce, partial | ||||
| from dataclasses import dataclass, field | ||||
|  | ||||
| from nonebot.config import Config | ||||
| from nonebot.typing import Dict, Union, Optional, Iterable, WebSocket | ||||
| from nonebot.typing import Driver, WebSocket | ||||
| from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable | ||||
|  | ||||
|  | ||||
| class BaseBot(abc.ABC): | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def __init__(self, | ||||
|                  driver: Driver, | ||||
|                  connection_type: str, | ||||
|                  config: Config, | ||||
|                  self_id: int, | ||||
|                  self_id: str, | ||||
|                  *, | ||||
|                  websocket: WebSocket = None): | ||||
|         self.driver = driver | ||||
|         self.connection_type = connection_type | ||||
|         self.config = config | ||||
|         self.self_id = self_id | ||||
|         self.websocket = websocket | ||||
|  | ||||
|     def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]: | ||||
|         return partial(self.call_api, name) | ||||
|  | ||||
|     @property | ||||
|     @abc.abstractmethod | ||||
|     def type(self) -> str: | ||||
| @@ -37,6 +43,7 @@ class BaseBot(abc.ABC): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|  | ||||
| # TODO: improve event | ||||
| class BaseEvent(abc.ABC): | ||||
|  | ||||
|     def __init__(self, raw_event: dict): | ||||
|   | ||||
| @@ -2,14 +2,17 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
|  | ||||
| import re | ||||
| import sys | ||||
| import asyncio | ||||
|  | ||||
| import httpx | ||||
|  | ||||
| from nonebot.config import Config | ||||
| from nonebot.message import handle_event | ||||
| from nonebot.exception import ApiNotAvailable | ||||
| from nonebot.typing import overrides, Driver, WebSocket, NoReturn | ||||
| from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional | ||||
| from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable | ||||
| from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment | ||||
| from nonebot.typing import Union, Tuple, Iterable, Optional, overrides, WebSocket | ||||
|  | ||||
|  | ||||
| def escape(s: str, *, escape_comma: bool = True) -> str: | ||||
| @@ -38,18 +41,60 @@ def _b2s(b: bool) -> str: | ||||
|     return str(b).lower() | ||||
|  | ||||
|  | ||||
| def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any: | ||||
|     if isinstance(result, dict): | ||||
|         if result.get("status") == "failed": | ||||
|             raise ActionFailed(retcode=result.get("retcode")) | ||||
|         return result.get("data") | ||||
|  | ||||
|  | ||||
| class ResultStore: | ||||
|     _seq = 1 | ||||
|     _futures: Dict[int, asyncio.Future] = {} | ||||
|  | ||||
|     @classmethod | ||||
|     def get_seq(cls) -> int: | ||||
|         s = cls._seq | ||||
|         cls._seq = (cls._seq + 1) % sys.maxsize | ||||
|         return s | ||||
|  | ||||
|     @classmethod | ||||
|     def add_result(cls, result: Dict[str, Any]): | ||||
|         if isinstance(result.get("echo"), dict) and \ | ||||
|                 isinstance(result["echo"].get("seq"), int): | ||||
|             future = cls._futures.get(result["echo"]["seq"]) | ||||
|             if future: | ||||
|                 future.set_result(result) | ||||
|  | ||||
|     @classmethod | ||||
|     async def fetch(cls, seq: int, timeout: float) -> Dict[str, Any]: | ||||
|         future = asyncio.get_event_loop().create_future() | ||||
|         cls._futures[seq] = future | ||||
|         try: | ||||
|             return await asyncio.wait_for(future, timeout) | ||||
|         except asyncio.TimeoutError: | ||||
|             raise NetworkError("WebSocket API call timeout") | ||||
|         finally: | ||||
|             del cls._futures[seq] | ||||
|  | ||||
|  | ||||
| class Bot(BaseBot): | ||||
|  | ||||
|     def __init__(self, | ||||
|                  driver: Driver, | ||||
|                  connection_type: str, | ||||
|                  config: Config, | ||||
|                  self_id: int, | ||||
|                  self_id: str, | ||||
|                  *, | ||||
|                  websocket: WebSocket = None): | ||||
|         if connection_type not in ["http", "websocket"]: | ||||
|             raise ValueError("Unsupported connection type") | ||||
|  | ||||
|         super().__init__(connection_type, config, self_id, websocket=websocket) | ||||
|         super().__init__(driver, | ||||
|                          connection_type, | ||||
|                          config, | ||||
|                          self_id, | ||||
|                          websocket=websocket) | ||||
|  | ||||
|     @property | ||||
|     @overrides(BaseBot) | ||||
| @@ -61,16 +106,29 @@ class Bot(BaseBot): | ||||
|         if not message: | ||||
|             return | ||||
|  | ||||
|         # TODO: convert message into event | ||||
|         event = Event(message) | ||||
|  | ||||
|         await handle_event(self, event) | ||||
|  | ||||
|     @overrides(BaseBot) | ||||
|     async def call_api(self, api: str, data: dict): | ||||
|         # TODO: Call API | ||||
|     async def call_api(self, api: str, **data) -> Union[Any, NoReturn]: | ||||
|         if "self_id" in data: | ||||
|             self_id = str(data.pop("self_id")) | ||||
|             bot = self.driver.bots[self_id] | ||||
|             return await bot.call_api(api, **data) | ||||
|  | ||||
|         if self.type == "websocket": | ||||
|             pass | ||||
|             seq = ResultStore.get_seq() | ||||
|             await self.websocket.send({ | ||||
|                 "action": api, | ||||
|                 "params": data, | ||||
|                 "echo": { | ||||
|                     "seq": seq | ||||
|                 } | ||||
|             }) | ||||
|             return _handle_api_result(await ResultStore.fetch( | ||||
|                 seq, self.config.api_timeout)) | ||||
|  | ||||
|         elif self.type == "http": | ||||
|             api_root = self.config.api_root.get(self.self_id) | ||||
|             if not api_root: | ||||
| @@ -82,14 +140,19 @@ class Bot(BaseBot): | ||||
|             if self.config.access_token: | ||||
|                 headers["Authorization"] = "Bearer " + self.config.access_token | ||||
|  | ||||
|             async with httpx.AsyncClient() as client: | ||||
|                 response = await client.post(api_root + api) | ||||
|             try: | ||||
|                 async with httpx.AsyncClient(headers=headers) as client: | ||||
|                     response = await client.post(api_root + api, json=data) | ||||
|  | ||||
|             if 200 <= response.status_code < 300: | ||||
|                 # TODO: handle http api response | ||||
|                 return ... | ||||
|             raise httpx.HTTPError( | ||||
|                 "<HttpFailed {0.status_code} for url: {0.url}>", response) | ||||
|                 if 200 <= response.status_code < 300: | ||||
|                     result = response.json() | ||||
|                     return _handle_api_result(result) | ||||
|                 raise NetworkError(f"HTTP request received unexpected " | ||||
|                                    f"status code: {response.status_code}") | ||||
|             except httpx.InvalidURL: | ||||
|                 raise NetworkError("API root url invalid") | ||||
|             except httpx.HTTPError: | ||||
|                 raise NetworkError("HTTP request failed") | ||||
|  | ||||
|  | ||||
| class Event(BaseEvent): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user