websocket api

This commit is contained in:
yanyongyu
2020-08-13 15:23:04 +08:00
parent 0e73d4ce20
commit e7f9b2c229
10 changed files with 141 additions and 35 deletions

View File

@ -2,27 +2,33 @@
# -*- coding: utf-8 -*-
import abc
from functools import reduce
from functools import reduce, partial
from dataclasses import dataclass, field
from nonebot.config import Config
from nonebot.typing import Dict, Union, Optional, Iterable, WebSocket
from nonebot.typing import Driver, WebSocket
from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable
class BaseBot(abc.ABC):
@abc.abstractmethod
def __init__(self,
driver: Driver,
connection_type: str,
config: Config,
self_id: int,
self_id: str,
*,
websocket: WebSocket = None):
self.driver = driver
self.connection_type = connection_type
self.config = config
self.self_id = self_id
self.websocket = websocket
def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]:
return partial(self.call_api, name)
@property
@abc.abstractmethod
def type(self) -> str:
@ -37,6 +43,7 @@ class BaseBot(abc.ABC):
raise NotImplementedError
# TODO: improve event
class BaseEvent(abc.ABC):
def __init__(self, raw_event: dict):

View File

@ -2,14 +2,17 @@
# -*- coding: utf-8 -*-
import re
import sys
import asyncio
import httpx
from nonebot.config import Config
from nonebot.message import handle_event
from nonebot.exception import ApiNotAvailable
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
from nonebot.typing import Union, Tuple, Iterable, Optional, overrides, WebSocket
def escape(s: str, *, escape_comma: bool = True) -> str:
@ -38,18 +41,60 @@ def _b2s(b: bool) -> str:
return str(b).lower()
def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any:
if isinstance(result, dict):
if result.get("status") == "failed":
raise ActionFailed(retcode=result.get("retcode"))
return result.get("data")
class ResultStore:
_seq = 1
_futures: Dict[int, asyncio.Future] = {}
@classmethod
def get_seq(cls) -> int:
s = cls._seq
cls._seq = (cls._seq + 1) % sys.maxsize
return s
@classmethod
def add_result(cls, result: Dict[str, Any]):
if isinstance(result.get("echo"), dict) and \
isinstance(result["echo"].get("seq"), int):
future = cls._futures.get(result["echo"]["seq"])
if future:
future.set_result(result)
@classmethod
async def fetch(cls, seq: int, timeout: float) -> Dict[str, Any]:
future = asyncio.get_event_loop().create_future()
cls._futures[seq] = future
try:
return await asyncio.wait_for(future, timeout)
except asyncio.TimeoutError:
raise NetworkError("WebSocket API call timeout")
finally:
del cls._futures[seq]
class Bot(BaseBot):
def __init__(self,
driver: Driver,
connection_type: str,
config: Config,
self_id: int,
self_id: str,
*,
websocket: WebSocket = None):
if connection_type not in ["http", "websocket"]:
raise ValueError("Unsupported connection type")
super().__init__(connection_type, config, self_id, websocket=websocket)
super().__init__(driver,
connection_type,
config,
self_id,
websocket=websocket)
@property
@overrides(BaseBot)
@ -61,16 +106,29 @@ class Bot(BaseBot):
if not message:
return
# TODO: convert message into event
event = Event(message)
await handle_event(self, event)
@overrides(BaseBot)
async def call_api(self, api: str, data: dict):
# TODO: Call API
async def call_api(self, api: str, **data) -> Union[Any, NoReturn]:
if "self_id" in data:
self_id = str(data.pop("self_id"))
bot = self.driver.bots[self_id]
return await bot.call_api(api, **data)
if self.type == "websocket":
pass
seq = ResultStore.get_seq()
await self.websocket.send({
"action": api,
"params": data,
"echo": {
"seq": seq
}
})
return _handle_api_result(await ResultStore.fetch(
seq, self.config.api_timeout))
elif self.type == "http":
api_root = self.config.api_root.get(self.self_id)
if not api_root:
@ -82,14 +140,19 @@ class Bot(BaseBot):
if self.config.access_token:
headers["Authorization"] = "Bearer " + self.config.access_token
async with httpx.AsyncClient() as client:
response = await client.post(api_root + api)
try:
async with httpx.AsyncClient(headers=headers) as client:
response = await client.post(api_root + api, json=data)
if 200 <= response.status_code < 300:
# TODO: handle http api response
return ...
raise httpx.HTTPError(
"<HttpFailed {0.status_code} for url: {0.url}>", response)
if 200 <= response.status_code < 300:
result = response.json()
return _handle_api_result(result)
raise NetworkError(f"HTTP request received unexpected "
f"status code: {response.status_code}")
except httpx.InvalidURL:
raise NetworkError("API root url invalid")
except httpx.HTTPError:
raise NetworkError("HTTP request failed")
class Event(BaseEvent):