mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-30 22:46:40 +00:00 
			
		
		
		
	add websocket class and coolq message segment
This commit is contained in:
		| @@ -21,22 +21,46 @@ class BaseBot(object): | |||||||
| class BaseMessageSegment(dict): | class BaseMessageSegment(dict): | ||||||
|  |  | ||||||
|     def __init__(self, |     def __init__(self, | ||||||
|                  d: Optional[Dict[str, Any]] = None, |  | ||||||
|                  *, |  | ||||||
|                  type_: Optional[str] = None, |                  type_: Optional[str] = None, | ||||||
|                  data: Optional[Dict[str, str]] = None): |                  data: Optional[Dict[str, str]] = None): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         if isinstance(d, dict) and d.get('type'): |         if type_: | ||||||
|             self.update(d) |  | ||||||
|         elif type_: |  | ||||||
|             self.type = type_ |             self.type = type_ | ||||||
|             self.data = data |             self.data = data | ||||||
|         else: |         else: | ||||||
|             raise ValueError('the "type" field cannot be None or empty') |             raise ValueError('The "type" field cannot be empty') | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def __getitem__(self, item): | ||||||
|  |         if item not in ("type", "data"): | ||||||
|  |             raise KeyError(f'Key "{item}" is not allowed') | ||||||
|  |         return super().__getitem__(item) | ||||||
|  |  | ||||||
|  |     def __setitem__(self, key, value): | ||||||
|  |         if key not in ("type", "data"): | ||||||
|  |             raise KeyError(f'Key "{key}" is not allowed') | ||||||
|  |         return super().__setitem__(key, value) | ||||||
|  |  | ||||||
|  |     # TODO: __eq__ __add__ | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def type(self) -> str: | ||||||
|  |         return self["type"] | ||||||
|  |  | ||||||
|  |     @type.setter | ||||||
|  |     def type(self, value: str): | ||||||
|  |         self["type"] = value | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def data(self) -> Dict[str, str]: | ||||||
|  |         return self["data"] | ||||||
|  |  | ||||||
|  |     @data.setter | ||||||
|  |     def data(self, data: Optional[Dict[str, str]]): | ||||||
|  |         self["data"] = data or {} | ||||||
|  |  | ||||||
|  |  | ||||||
| class BaseMessage(list): | class BaseMessage(list): | ||||||
|  |  | ||||||
|   | |||||||
| @@ -5,13 +5,40 @@ import httpx | |||||||
|  |  | ||||||
| from nonebot.event import Event | from nonebot.event import Event | ||||||
| from nonebot.config import Config | from nonebot.config import Config | ||||||
| from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment |  | ||||||
| from nonebot.message import handle_event | from nonebot.message import handle_event | ||||||
|  | from nonebot.drivers import BaseWebSocket | ||||||
|  | from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def escape(s: str, *, escape_comma: bool = True) -> str: | ||||||
|  |     """ | ||||||
|  |     对字符串进行 CQ 码转义。 | ||||||
|  |  | ||||||
|  |     ``escape_comma`` 参数控制是否转义逗号(``,``)。 | ||||||
|  |     """ | ||||||
|  |     s = s.replace("&", "&") \ | ||||||
|  |         .replace("[", "[") \ | ||||||
|  |         .replace("]", "]") | ||||||
|  |     if escape_comma: | ||||||
|  |         s = s.replace(",", ",") | ||||||
|  |     return s | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def unescape(s: str) -> str: | ||||||
|  |     """对字符串进行 CQ 码去转义。""" | ||||||
|  |     return s.replace(",", ",") \ | ||||||
|  |         .replace("[", "[") \ | ||||||
|  |         .replace("]", "]") \ | ||||||
|  |         .replace("&", "&") | ||||||
|  |  | ||||||
|  |  | ||||||
| class Bot(BaseBot): | class Bot(BaseBot): | ||||||
|  |  | ||||||
|     def __init__(self, type_: str, config: Config, *, websocket=None): |     def __init__(self, | ||||||
|  |                  type_: str, | ||||||
|  |                  config: Config, | ||||||
|  |                  *, | ||||||
|  |                  websocket: BaseWebSocket = None): | ||||||
|         if type_ not in ["http", "websocket"]: |         if type_ not in ["http", "websocket"]: | ||||||
|             raise ValueError("Unsupported connection type") |             raise ValueError("Unsupported connection type") | ||||||
|         self.type = type_ |         self.type = type_ | ||||||
| @@ -33,7 +60,32 @@ class Bot(BaseBot): | |||||||
|  |  | ||||||
|  |  | ||||||
| class MessageSegment(BaseMessageSegment): | class MessageSegment(BaseMessageSegment): | ||||||
|     pass |  | ||||||
|  |     def __str__(self): | ||||||
|  |         type_ = self.type | ||||||
|  |         data = self.data.copy() | ||||||
|  |  | ||||||
|  |         # process special types | ||||||
|  |         if type_ == "text": | ||||||
|  |             return escape(data.get("text", ""), escape_comma=False) | ||||||
|  |         elif type_ == "at_all": | ||||||
|  |             type_ = "at" | ||||||
|  |             data = {"qq": "all"} | ||||||
|  |  | ||||||
|  |         params = ",".join([f"{k}={escape(str(v))}" for k, v in data.items()]) | ||||||
|  |         return f"[CQ:{type_}{',' if params else ''}{params}]" | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def at(user_id: int) -> "MessageSegment": | ||||||
|  |         return MessageSegment("at", {"qq": str(user_id)}) | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def at_all() -> "MessageSegment": | ||||||
|  |         return MessageSegment("at_all") | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def dice() -> "MessageSegment": | ||||||
|  |         return MessageSegment(type_="dice") | ||||||
|  |  | ||||||
|  |  | ||||||
| class Message(BaseMessage): | class Message(BaseMessage): | ||||||
|   | |||||||
| @@ -39,3 +39,29 @@ class BaseDriver(object): | |||||||
|  |  | ||||||
|     async def _handle_http_api(self): |     async def _handle_http_api(self): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BaseWebSocket(object): | ||||||
|  |  | ||||||
|  |     def __init__(self, websocket): | ||||||
|  |         self._websocket = websocket | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def websocket(self): | ||||||
|  |         return self._websocket | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def closed(self): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     async def accept(self): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     async def close(self): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     async def receive(self) -> dict: | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     async def send(self, data: dict): | ||||||
|  |         raise NotImplementedError | ||||||
|   | |||||||
| @@ -7,12 +7,13 @@ from typing import Optional | |||||||
| from ipaddress import IPv4Address | from ipaddress import IPv4Address | ||||||
|  |  | ||||||
| import uvicorn | import uvicorn | ||||||
|  | from fastapi.security import OAuth2PasswordBearer | ||||||
| from starlette.websockets import WebSocketDisconnect | from starlette.websockets import WebSocketDisconnect | ||||||
| from fastapi import Body, FastAPI, WebSocket | from fastapi import Body, FastAPI, WebSocket as FastAPIWebSocket | ||||||
|  |  | ||||||
| from nonebot.log import logger | from nonebot.log import logger | ||||||
| from nonebot.config import Config | from nonebot.config import Config | ||||||
| from nonebot.drivers import BaseDriver | from nonebot.drivers import BaseDriver, BaseWebSocket | ||||||
| from nonebot.adapters.coolq import Bot as CoolQBot | from nonebot.adapters.coolq import Bot as CoolQBot | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -86,7 +87,11 @@ class Driver(BaseDriver): | |||||||
|                     log_config=LOGGING_CONFIG, |                     log_config=LOGGING_CONFIG, | ||||||
|                     **kwargs) |                     **kwargs) | ||||||
|  |  | ||||||
|     async def _handle_http(self, adapter: str, data: dict = Body(...)): |     async def _handle_http(self, | ||||||
|  |                            adapter: str, | ||||||
|  |                            data: dict = Body(...), | ||||||
|  |                            access_token: str = OAuth2PasswordBearer( | ||||||
|  |                                "/", auto_error=False)): | ||||||
|         # TODO: Check authorization |         # TODO: Check authorization | ||||||
|         logger.debug(f"Received message: {data}") |         logger.debug(f"Received message: {data}") | ||||||
|         if adapter == "coolq": |         if adapter == "coolq": | ||||||
| @@ -94,20 +99,57 @@ class Driver(BaseDriver): | |||||||
|             await bot.handle_message(data) |             await bot.handle_message(data) | ||||||
|         return {"status": 200, "message": "success"} |         return {"status": 200, "message": "success"} | ||||||
|  |  | ||||||
|     async def _handle_ws_reverse(self, adapter: str, websocket: WebSocket): |     async def _handle_ws_reverse(self, | ||||||
|  |                                  adapter: str, | ||||||
|  |                                  websocket: FastAPIWebSocket, | ||||||
|  |                                  access_token: str = OAuth2PasswordBearer( | ||||||
|  |                                      "/", auto_error=False)): | ||||||
|  |         websocket = WebSocket(websocket) | ||||||
|  |  | ||||||
|         # TODO: Check authorization |         # TODO: Check authorization | ||||||
|         await websocket.accept() |         await websocket.accept() | ||||||
|         while True: |  | ||||||
|             try: |         while not websocket.closed: | ||||||
|                 data = await websocket.receive_json() |             data = await websocket.receive() | ||||||
|             except json.decoder.JSONDecodeError as e: |  | ||||||
|                 logger.exception(e) |             if not data: | ||||||
|                 continue |                 continue | ||||||
|             except WebSocketDisconnect: |  | ||||||
|                 logger.error("WebSocket Disconnect") |  | ||||||
|                 return |  | ||||||
|  |  | ||||||
|             logger.debug(f"Received message: {data}") |             logger.debug(f"Received message: {data}") | ||||||
|             if adapter == "coolq": |             if adapter == "coolq": | ||||||
|                 bot = CoolQBot("websocket", self.config, websocket=websocket) |                 bot = CoolQBot("websocket", self.config, websocket=websocket) | ||||||
|                 await bot.handle_message(data) |                 await bot.handle_message(data) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class WebSocket(BaseWebSocket): | ||||||
|  |  | ||||||
|  |     def __init__(self, websocket: FastAPIWebSocket): | ||||||
|  |         super().__init__(websocket) | ||||||
|  |         self._closed = None | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def closed(self): | ||||||
|  |         return self._closed | ||||||
|  |  | ||||||
|  |     async def accept(self): | ||||||
|  |         await self.websocket.accept() | ||||||
|  |         self._closed = False | ||||||
|  |  | ||||||
|  |     async def close(self): | ||||||
|  |         await self.websocket.close() | ||||||
|  |         self._closed = True | ||||||
|  |  | ||||||
|  |     async def receive(self) -> Optional[dict]: | ||||||
|  |         data = None | ||||||
|  |         try: | ||||||
|  |             data = await self.websocket.receive_json() | ||||||
|  |         except ValueError: | ||||||
|  |             logger.debug("Received an invalid json message.") | ||||||
|  |         except WebSocketDisconnect: | ||||||
|  |             self._closed = True | ||||||
|  |             logger.error("WebSocket disconnected by peer.") | ||||||
|  |  | ||||||
|  |         return data | ||||||
|  |  | ||||||
|  |     async def send(self, data: dict) -> None: | ||||||
|  |         await self.websocket.send_json(data) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user