From 3dbd927a2a58438781e766328bf52282da144121 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Wed, 15 Jul 2020 20:39:59 +0800 Subject: [PATCH] add websocket class and coolq message segment --- nonebot/adapters/__init__.py | 36 ++++++++++++++++---- nonebot/adapters/coolq.py | 58 +++++++++++++++++++++++++++++-- nonebot/drivers/__init__.py | 26 ++++++++++++++ nonebot/drivers/fastapi.py | 66 +++++++++++++++++++++++++++++------- 4 files changed, 165 insertions(+), 21 deletions(-) diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 30d6a788..c12d96bb 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -21,22 +21,46 @@ class BaseBot(object): class BaseMessageSegment(dict): def __init__(self, - d: Optional[Dict[str, Any]] = None, - *, type_: Optional[str] = None, data: Optional[Dict[str, str]] = None): super().__init__() - if isinstance(d, dict) and d.get('type'): - self.update(d) - elif type_: + if type_: self.type = type_ self.data = data else: - raise ValueError('the "type" field cannot be None or empty') + raise ValueError('The "type" field cannot be empty') def __str__(self): raise NotImplementedError + def __getitem__(self, item): + if item not in ("type", "data"): + raise KeyError(f'Key "{item}" is not allowed') + return super().__getitem__(item) + + def __setitem__(self, key, value): + if key not in ("type", "data"): + raise KeyError(f'Key "{key}" is not allowed') + return super().__setitem__(key, value) + + # TODO: __eq__ __add__ + + @property + def type(self) -> str: + return self["type"] + + @type.setter + def type(self, value: str): + self["type"] = value + + @property + def data(self) -> Dict[str, str]: + return self["data"] + + @data.setter + def data(self, data: Optional[Dict[str, str]]): + self["data"] = data or {} + class BaseMessage(list): diff --git a/nonebot/adapters/coolq.py b/nonebot/adapters/coolq.py index 8e7630c1..f31feaab 100644 --- a/nonebot/adapters/coolq.py +++ b/nonebot/adapters/coolq.py @@ -5,13 +5,40 @@ import httpx from nonebot.event import Event from nonebot.config import Config -from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment from nonebot.message import handle_event +from nonebot.drivers import BaseWebSocket +from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment + + +def escape(s: str, *, escape_comma: bool = True) -> str: + """ + 对字符串进行 CQ 码转义。 + + ``escape_comma`` 参数控制是否转义逗号(``,``)。 + """ + s = s.replace("&", "&") \ + .replace("[", "[") \ + .replace("]", "]") + if escape_comma: + s = s.replace(",", ",") + return s + + +def unescape(s: str) -> str: + """对字符串进行 CQ 码去转义。""" + return s.replace(",", ",") \ + .replace("[", "[") \ + .replace("]", "]") \ + .replace("&", "&") class Bot(BaseBot): - def __init__(self, type_: str, config: Config, *, websocket=None): + def __init__(self, + type_: str, + config: Config, + *, + websocket: BaseWebSocket = None): if type_ not in ["http", "websocket"]: raise ValueError("Unsupported connection type") self.type = type_ @@ -33,7 +60,32 @@ class Bot(BaseBot): class MessageSegment(BaseMessageSegment): - pass + + def __str__(self): + type_ = self.type + data = self.data.copy() + + # process special types + if type_ == "text": + return escape(data.get("text", ""), escape_comma=False) + elif type_ == "at_all": + type_ = "at" + data = {"qq": "all"} + + params = ",".join([f"{k}={escape(str(v))}" for k, v in data.items()]) + return f"[CQ:{type_}{',' if params else ''}{params}]" + + @staticmethod + def at(user_id: int) -> "MessageSegment": + return MessageSegment("at", {"qq": str(user_id)}) + + @staticmethod + def at_all() -> "MessageSegment": + return MessageSegment("at_all") + + @staticmethod + def dice() -> "MessageSegment": + return MessageSegment(type_="dice") class Message(BaseMessage): diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index b8f6c207..49de914c 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -39,3 +39,29 @@ class BaseDriver(object): async def _handle_http_api(self): raise NotImplementedError + + +class BaseWebSocket(object): + + def __init__(self, websocket): + self._websocket = websocket + + @property + def websocket(self): + return self._websocket + + @property + def closed(self): + raise NotImplementedError + + async def accept(self): + raise NotImplementedError + + async def close(self): + raise NotImplementedError + + async def receive(self) -> dict: + raise NotImplementedError + + async def send(self, data: dict): + raise NotImplementedError diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index a30fc87e..32ab8693 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -7,12 +7,13 @@ from typing import Optional from ipaddress import IPv4Address import uvicorn +from fastapi.security import OAuth2PasswordBearer from starlette.websockets import WebSocketDisconnect -from fastapi import Body, FastAPI, WebSocket +from fastapi import Body, FastAPI, WebSocket as FastAPIWebSocket from nonebot.log import logger from nonebot.config import Config -from nonebot.drivers import BaseDriver +from nonebot.drivers import BaseDriver, BaseWebSocket from nonebot.adapters.coolq import Bot as CoolQBot @@ -86,7 +87,11 @@ class Driver(BaseDriver): log_config=LOGGING_CONFIG, **kwargs) - async def _handle_http(self, adapter: str, data: dict = Body(...)): + async def _handle_http(self, + adapter: str, + data: dict = Body(...), + access_token: str = OAuth2PasswordBearer( + "/", auto_error=False)): # TODO: Check authorization logger.debug(f"Received message: {data}") if adapter == "coolq": @@ -94,20 +99,57 @@ class Driver(BaseDriver): await bot.handle_message(data) return {"status": 200, "message": "success"} - async def _handle_ws_reverse(self, adapter: str, websocket: WebSocket): + async def _handle_ws_reverse(self, + adapter: str, + websocket: FastAPIWebSocket, + access_token: str = OAuth2PasswordBearer( + "/", auto_error=False)): + websocket = WebSocket(websocket) + # TODO: Check authorization await websocket.accept() - while True: - try: - data = await websocket.receive_json() - except json.decoder.JSONDecodeError as e: - logger.exception(e) + + while not websocket.closed: + data = await websocket.receive() + + if not data: continue - except WebSocketDisconnect: - logger.error("WebSocket Disconnect") - return logger.debug(f"Received message: {data}") if adapter == "coolq": bot = CoolQBot("websocket", self.config, websocket=websocket) await bot.handle_message(data) + + +class WebSocket(BaseWebSocket): + + def __init__(self, websocket: FastAPIWebSocket): + super().__init__(websocket) + self._closed = None + + @property + def closed(self): + return self._closed + + async def accept(self): + await self.websocket.accept() + self._closed = False + + async def close(self): + await self.websocket.close() + self._closed = True + + async def receive(self) -> Optional[dict]: + data = None + try: + data = await self.websocket.receive_json() + except ValueError: + logger.debug("Received an invalid json message.") + except WebSocketDisconnect: + self._closed = True + logger.error("WebSocket disconnected by peer.") + + return data + + async def send(self, data: dict) -> None: + await self.websocket.send_json(data)