mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-27 00:01:27 +00:00
websocket api
This commit is contained in:
@ -2,27 +2,33 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import abc
|
||||
from functools import reduce
|
||||
from functools import reduce, partial
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from nonebot.config import Config
|
||||
from nonebot.typing import Dict, Union, Optional, Iterable, WebSocket
|
||||
from nonebot.typing import Driver, WebSocket
|
||||
from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable
|
||||
|
||||
|
||||
class BaseBot(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(self,
|
||||
driver: Driver,
|
||||
connection_type: str,
|
||||
config: Config,
|
||||
self_id: int,
|
||||
self_id: str,
|
||||
*,
|
||||
websocket: WebSocket = None):
|
||||
self.driver = driver
|
||||
self.connection_type = connection_type
|
||||
self.config = config
|
||||
self.self_id = self_id
|
||||
self.websocket = websocket
|
||||
|
||||
def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]:
|
||||
return partial(self.call_api, name)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def type(self) -> str:
|
||||
@ -37,6 +43,7 @@ class BaseBot(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# TODO: improve event
|
||||
class BaseEvent(abc.ABC):
|
||||
|
||||
def __init__(self, raw_event: dict):
|
||||
|
@ -2,14 +2,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import re
|
||||
import sys
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
|
||||
from nonebot.config import Config
|
||||
from nonebot.message import handle_event
|
||||
from nonebot.exception import ApiNotAvailable
|
||||
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
|
||||
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
|
||||
from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable
|
||||
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
|
||||
from nonebot.typing import Union, Tuple, Iterable, Optional, overrides, WebSocket
|
||||
|
||||
|
||||
def escape(s: str, *, escape_comma: bool = True) -> str:
|
||||
@ -38,18 +41,60 @@ def _b2s(b: bool) -> str:
|
||||
return str(b).lower()
|
||||
|
||||
|
||||
def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any:
|
||||
if isinstance(result, dict):
|
||||
if result.get("status") == "failed":
|
||||
raise ActionFailed(retcode=result.get("retcode"))
|
||||
return result.get("data")
|
||||
|
||||
|
||||
class ResultStore:
|
||||
_seq = 1
|
||||
_futures: Dict[int, asyncio.Future] = {}
|
||||
|
||||
@classmethod
|
||||
def get_seq(cls) -> int:
|
||||
s = cls._seq
|
||||
cls._seq = (cls._seq + 1) % sys.maxsize
|
||||
return s
|
||||
|
||||
@classmethod
|
||||
def add_result(cls, result: Dict[str, Any]):
|
||||
if isinstance(result.get("echo"), dict) and \
|
||||
isinstance(result["echo"].get("seq"), int):
|
||||
future = cls._futures.get(result["echo"]["seq"])
|
||||
if future:
|
||||
future.set_result(result)
|
||||
|
||||
@classmethod
|
||||
async def fetch(cls, seq: int, timeout: float) -> Dict[str, Any]:
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
cls._futures[seq] = future
|
||||
try:
|
||||
return await asyncio.wait_for(future, timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise NetworkError("WebSocket API call timeout")
|
||||
finally:
|
||||
del cls._futures[seq]
|
||||
|
||||
|
||||
class Bot(BaseBot):
|
||||
|
||||
def __init__(self,
|
||||
driver: Driver,
|
||||
connection_type: str,
|
||||
config: Config,
|
||||
self_id: int,
|
||||
self_id: str,
|
||||
*,
|
||||
websocket: WebSocket = None):
|
||||
if connection_type not in ["http", "websocket"]:
|
||||
raise ValueError("Unsupported connection type")
|
||||
|
||||
super().__init__(connection_type, config, self_id, websocket=websocket)
|
||||
super().__init__(driver,
|
||||
connection_type,
|
||||
config,
|
||||
self_id,
|
||||
websocket=websocket)
|
||||
|
||||
@property
|
||||
@overrides(BaseBot)
|
||||
@ -61,16 +106,29 @@ class Bot(BaseBot):
|
||||
if not message:
|
||||
return
|
||||
|
||||
# TODO: convert message into event
|
||||
event = Event(message)
|
||||
|
||||
await handle_event(self, event)
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def call_api(self, api: str, data: dict):
|
||||
# TODO: Call API
|
||||
async def call_api(self, api: str, **data) -> Union[Any, NoReturn]:
|
||||
if "self_id" in data:
|
||||
self_id = str(data.pop("self_id"))
|
||||
bot = self.driver.bots[self_id]
|
||||
return await bot.call_api(api, **data)
|
||||
|
||||
if self.type == "websocket":
|
||||
pass
|
||||
seq = ResultStore.get_seq()
|
||||
await self.websocket.send({
|
||||
"action": api,
|
||||
"params": data,
|
||||
"echo": {
|
||||
"seq": seq
|
||||
}
|
||||
})
|
||||
return _handle_api_result(await ResultStore.fetch(
|
||||
seq, self.config.api_timeout))
|
||||
|
||||
elif self.type == "http":
|
||||
api_root = self.config.api_root.get(self.self_id)
|
||||
if not api_root:
|
||||
@ -82,14 +140,19 @@ class Bot(BaseBot):
|
||||
if self.config.access_token:
|
||||
headers["Authorization"] = "Bearer " + self.config.access_token
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(api_root + api)
|
||||
try:
|
||||
async with httpx.AsyncClient(headers=headers) as client:
|
||||
response = await client.post(api_root + api, json=data)
|
||||
|
||||
if 200 <= response.status_code < 300:
|
||||
# TODO: handle http api response
|
||||
return ...
|
||||
raise httpx.HTTPError(
|
||||
"<HttpFailed {0.status_code} for url: {0.url}>", response)
|
||||
if 200 <= response.status_code < 300:
|
||||
result = response.json()
|
||||
return _handle_api_result(result)
|
||||
raise NetworkError(f"HTTP request received unexpected "
|
||||
f"status code: {response.status_code}")
|
||||
except httpx.InvalidURL:
|
||||
raise NetworkError("API root url invalid")
|
||||
except httpx.HTTPError:
|
||||
raise NetworkError("HTTP request failed")
|
||||
|
||||
|
||||
class Event(BaseEvent):
|
||||
|
Reference in New Issue
Block a user