diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index 8eea6ea2..8efdf36e 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -27,9 +27,9 @@ from nonebot.config import Env from nonebot.typing import overrides from nonebot.utils import escape_tag from nonebot.config import Config as NoneBotConfig -from nonebot.drivers import Request as GenericRequest +from nonebot.drivers import Request as BaseRequest +from nonebot.drivers import Response as BaseResponse from nonebot.drivers import WebSocket as BaseWebSocket -from nonebot.drivers import Response as GenericResponse from nonebot.drivers import ( HTTPVersion, ForwardDriver, @@ -247,7 +247,7 @@ class Driver(ReverseDriver): request: Request, setup: HTTPServerSetup, ): - http_request = GenericRequest( + http_request = BaseRequest( request.method, str(request.url), headers=request.headers.items(), @@ -260,7 +260,7 @@ class Driver(ReverseDriver): return Response(response.content, response.status_code, dict(response.headers)) async def _handle_ws(self, websocket: WebSocket, setup: WebSocketServerSetup): - request = GenericRequest( + request = BaseRequest( "GET", str(websocket.url), headers=websocket.headers.items(), @@ -293,7 +293,7 @@ class FullDriver(ForwardDriver, Driver): return "fastapi_full" @overrides(ForwardDriver) - async def request(self, setup: "GenericRequest") -> Any: + async def request(self, setup: "BaseRequest") -> Any: async with httpx.AsyncClient( http2=setup.version == HTTPVersion.H2, follow_redirects=True ) as client: @@ -304,7 +304,7 @@ class FullDriver(ForwardDriver, Driver): headers=tuple(setup.headers.items()), timeout=30.0, ) - return GenericResponse( + return BaseResponse( response.status_code, headers=response.headers, content=response.content, @@ -312,31 +312,24 @@ class FullDriver(ForwardDriver, Driver): ) @overrides(ForwardDriver) - async def websocket(self, setup: "GenericRequest") -> Any: + async def websocket(self, setup: "BaseRequest") -> Any: ws = await Connect(str(setup.url), extra_headers=setup.headers.items()) return WebSocketsWS(request=setup, websocket=ws) class WebSocketsWS(BaseWebSocket): @overrides(BaseWebSocket) - def __init__(self, *, request: GenericRequest, websocket: WebSocketClientProtocol): + def __init__(self, *, request: BaseRequest, websocket: WebSocketClientProtocol): super().__init__(request=request) self.websocket = websocket @property @overrides(BaseWebSocket) def closed(self) -> bool: - # if isinstance(self.websocket, WebSocket): - # return ( - # self.websocket.client_state == WebSocketState.DISCONNECTED - # or self.websocket.application_state == WebSocketState.DISCONNECTED - # ) return self.websocket.closed @overrides(BaseWebSocket) async def accept(self): - # if isinstance(self.websocket, WebSocket): - # await self.websocket.accept() raise NotImplementedError @overrides(BaseWebSocket) @@ -345,8 +338,6 @@ class WebSocketsWS(BaseWebSocket): @overrides(BaseWebSocket) async def receive(self) -> str: - # if isinstance(self.websocket, WebSocket): - # return await self.websocket.receive_text() msg = await self.websocket.recv() if isinstance(msg, bytes): raise TypeError("WebSocket received unexpected frame type: bytes") @@ -354,8 +345,6 @@ class WebSocketsWS(BaseWebSocket): @overrides(BaseWebSocket) async def receive_bytes(self) -> bytes: - # if isinstance(self.websocket, WebSocket): - # return await self.websocket.receive_bytes() msg = await self.websocket.recv() if isinstance(msg, str): raise TypeError("WebSocket received unexpected frame type: str") @@ -363,20 +352,16 @@ class WebSocketsWS(BaseWebSocket): @overrides(BaseWebSocket) async def send(self, data: str) -> None: - # if isinstance(self.websocket, WebSocket): - # await self.websocket.send({"type": "websocket.send", "text": data}) await self.websocket.send(data) @overrides(BaseWebSocket) async def send_bytes(self, data: bytes) -> None: - # if isinstance(self.websocket, WebSocket): - # await self.websocket.send({"type": "websocket.send", "bytes": data}) await self.websocket.send(data) class FastAPIWebSocket(BaseWebSocket): @overrides(BaseWebSocket) - def __init__(self, *, request: GenericRequest, websocket: WebSocket): + def __init__(self, *, request: BaseRequest, websocket: WebSocket): super().__init__(request=request) self.websocket = websocket diff --git a/nonebot/drivers/quart.py b/nonebot/drivers/quart.py index 79055eb9..7527957e 100644 --- a/nonebot/drivers/quart.py +++ b/nonebot/drivers/quart.py @@ -8,8 +8,7 @@ Quart 驱动适配 https://pgjones.gitlab.io/quart/index.html """ -import asyncio -from dataclasses import dataclass +from functools import partial from typing import List, TypeVar, Callable, Optional, Coroutine import uvicorn @@ -20,8 +19,9 @@ from nonebot.log import logger from nonebot.typing import overrides from nonebot.utils import escape_tag from nonebot.config import Config as NoneBotConfig -from nonebot.drivers import HTTPRequest, ReverseDriver +from nonebot.drivers import Request as BaseRequest from nonebot.drivers import WebSocket as BaseWebSocket +from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup try: from quart import request as _request @@ -98,11 +98,6 @@ class Config(BaseSettings): class Driver(ReverseDriver): """ Quart 驱动框架 - - :上报地址: - - * ``/{adapter name}/http``: HTTP POST 上报 - * ``/{adapter name}/ws``: WebSocket 上报 """ def __init__(self, env: Env, config: NoneBotConfig): @@ -111,12 +106,6 @@ class Driver(ReverseDriver): self.quart_config = Config(**config.dict()) self._server_app = Quart(self.__class__.__qualname__) - self._server_app.add_url_rule( - "//http", methods=["POST"], view_func=self._handle_http - ) - self._server_app.add_websocket( - "//ws", view_func=self._handle_ws_reverse - ) @property @overrides(ReverseDriver) @@ -142,6 +131,21 @@ class Driver(ReverseDriver): """Quart 使用的 logger""" return self._server_app.logger + @overrides(ReverseDriver) + def setup_http_server(self, setup: HTTPServerSetup): + self._server_app.add_url_rule( + setup.path.path, + methods=[setup.method], + view_func=partial(self._handle_http, setup=setup), + ) + + @overrides(ReverseDriver) + def setup_websocket_server(self, setup: WebSocketServerSetup) -> None: + self._server_app.add_websocket( + setup.path.path, + view_func=partial(self._handle_ws, setup=setup), + ) + @overrides(ReverseDriver) def on_startup(self, func: _AsyncCallable) -> _AsyncCallable: """参考文档: `Startup and Shutdown`_ @@ -199,128 +203,75 @@ class Driver(ReverseDriver): **kwargs, ) - async def _handle_http(self, adapter: str): + async def _handle_http(self, setup: HTTPServerSetup) -> Response: request: Request = _request - data: bytes = await request.get_data() # type: ignore - if adapter not in self._adapters: - logger.warning( - f"Unknown adapter {adapter}. " "Please register the adapter before use." - ) - raise exceptions.NotFound() - - BotClass = self._adapters[adapter] - http_request = HTTPRequest( - request.http_version, - request.scheme, - request.path, - request.query_string, - dict(request.headers), + http_request = BaseRequest( request.method, - data, + request.url, + headers=request.headers.items(), + cookies=list(request.cookies.items()), + content=await request.get_data( + cache=False, as_text=False, parse_form_data=False + ), + version=request.http_version, ) - self_id, response = await BotClass.check_permission(self, http_request) + response = await setup.handle_func(http_request) - if not self_id: - raise exceptions.Unauthorized( - description=(response and response.body or b"").decode() - ) - if self_id in self._clients: - logger.warning( - "There's already a reverse websocket connection," - "so the event may be handled twice." - ) - bot = BotClass(self_id, http_request) - asyncio.create_task(bot.handle_message(data)) return Response( - response and response.body or "", response and response.status or 200 + response.content or "", + response.status_code or 200, + headers=dict(response.headers), ) - async def _handle_ws_reverse(self, adapter: str): + async def _handle_ws(self, setup: WebSocketServerSetup) -> None: websocket: QuartWebSocket = _websocket - ws = WebSocket( - websocket.http_version, - websocket.scheme, - websocket.path, - websocket.query_string, - dict(websocket.headers), - websocket, + + http_request = BaseRequest( + websocket.method, + websocket.url, + headers=websocket.headers.items(), + cookies=list(websocket.cookies.items()), + version=websocket.http_version, ) - if adapter not in self._adapters: - logger.warning( - f"Unknown adapter {adapter}. Please register the adapter before use." - ) - raise exceptions.NotFound() + ws = WebSocket(request=http_request, websocket=websocket) - BotClass = self._adapters[adapter] - self_id, response = await BotClass.check_permission(self, ws) - - if not self_id: - raise exceptions.Unauthorized( - description=(response and response.body or b"").decode() - ) - - if self_id in self._clients: - logger.opt(colors=True).warning( - "There's already a websocket connection, " - f"{escape_tag(adapter.upper())} Bot {escape_tag(self_id)} ignored." - ) - raise exceptions.Forbidden(description="Client already exists.") - - bot = BotClass(self_id, ws) - await ws.accept() - logger.opt(colors=True).info( - f"WebSocket Connection from {escape_tag(adapter.upper())} " - f"Bot {escape_tag(self_id)} Accepted!" - ) - self._bot_connect(bot) - - try: - while not ws.closed: - try: - data = await ws.receive() - except asyncio.CancelledError: - logger.warning("WebSocket disconnected by peer.") - break - except Exception as e: - logger.opt(exception=e).error( - "Error when receiving data from websocket." - ) - break - - asyncio.create_task(bot.handle_message(data.encode())) - finally: - self._bot_disconnect(bot) + await setup.handle_func(ws) -@dataclass class WebSocket(BaseWebSocket): - websocket: QuartWebSocket = None # type: ignore + def __init__(self, *, request: BaseRequest, websocket: QuartWebSocket): + super().__init__(request=request) + self.websocket = websocket @property @overrides(BaseWebSocket) def closed(self): - # FIXME - return False + raise NotImplementedError @overrides(BaseWebSocket) async def accept(self): await self.websocket.accept() @overrides(BaseWebSocket) - async def close(self): - # FIXME - pass + async def close(self, code: int = 1000): + await self.websocket.close(code) @overrides(BaseWebSocket) async def receive(self) -> str: - return await self.websocket.receive() # type: ignore + msg = await self.websocket.receive() + if isinstance(msg, bytes): + raise TypeError("WebSocket received unexpected frame type: bytes") + return msg @overrides(BaseWebSocket) async def receive_bytes(self) -> bytes: - return await self.websocket.receive() # type: ignore + msg = await self.websocket.receive() + if isinstance(msg, str): + raise TypeError("WebSocket received unexpected frame type: str") + return msg @overrides(BaseWebSocket) async def send(self, data: str): diff --git a/nonebot/params.py b/nonebot/params.py index a4a6c5f5..bab4ed7f 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -61,7 +61,6 @@ def Depends( * ``dependency: Optional[Callable[..., Any]] = None``: 依赖函数。默认为参数的类型注释。 * ``use_cache: bool = True``: 是否使用缓存。默认为 ``True``。 - * ``allow_types: Optional[List[Type[Param]]] = None``: 允许的参数类型。默认为 ``None``。 .. code-block:: python