♻️ rewrite quart driver

This commit is contained in:
yanyongyu
2021-12-20 15:46:23 +08:00
parent c49059f9d3
commit ea8f7717b9
3 changed files with 64 additions and 129 deletions

View File

@ -27,9 +27,9 @@ from nonebot.config import Env
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot.config import Config as NoneBotConfig from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as GenericRequest from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import Response as BaseResponse
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import Response as GenericResponse
from nonebot.drivers import ( from nonebot.drivers import (
HTTPVersion, HTTPVersion,
ForwardDriver, ForwardDriver,
@ -247,7 +247,7 @@ class Driver(ReverseDriver):
request: Request, request: Request,
setup: HTTPServerSetup, setup: HTTPServerSetup,
): ):
http_request = GenericRequest( http_request = BaseRequest(
request.method, request.method,
str(request.url), str(request.url),
headers=request.headers.items(), headers=request.headers.items(),
@ -260,7 +260,7 @@ class Driver(ReverseDriver):
return Response(response.content, response.status_code, dict(response.headers)) return Response(response.content, response.status_code, dict(response.headers))
async def _handle_ws(self, websocket: WebSocket, setup: WebSocketServerSetup): async def _handle_ws(self, websocket: WebSocket, setup: WebSocketServerSetup):
request = GenericRequest( request = BaseRequest(
"GET", "GET",
str(websocket.url), str(websocket.url),
headers=websocket.headers.items(), headers=websocket.headers.items(),
@ -293,7 +293,7 @@ class FullDriver(ForwardDriver, Driver):
return "fastapi_full" return "fastapi_full"
@overrides(ForwardDriver) @overrides(ForwardDriver)
async def request(self, setup: "GenericRequest") -> Any: async def request(self, setup: "BaseRequest") -> Any:
async with httpx.AsyncClient( async with httpx.AsyncClient(
http2=setup.version == HTTPVersion.H2, follow_redirects=True http2=setup.version == HTTPVersion.H2, follow_redirects=True
) as client: ) as client:
@ -304,7 +304,7 @@ class FullDriver(ForwardDriver, Driver):
headers=tuple(setup.headers.items()), headers=tuple(setup.headers.items()),
timeout=30.0, timeout=30.0,
) )
return GenericResponse( return BaseResponse(
response.status_code, response.status_code,
headers=response.headers, headers=response.headers,
content=response.content, content=response.content,
@ -312,31 +312,24 @@ class FullDriver(ForwardDriver, Driver):
) )
@overrides(ForwardDriver) @overrides(ForwardDriver)
async def websocket(self, setup: "GenericRequest") -> Any: async def websocket(self, setup: "BaseRequest") -> Any:
ws = await Connect(str(setup.url), extra_headers=setup.headers.items()) ws = await Connect(str(setup.url), extra_headers=setup.headers.items())
return WebSocketsWS(request=setup, websocket=ws) return WebSocketsWS(request=setup, websocket=ws)
class WebSocketsWS(BaseWebSocket): class WebSocketsWS(BaseWebSocket):
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
def __init__(self, *, request: GenericRequest, websocket: WebSocketClientProtocol): def __init__(self, *, request: BaseRequest, websocket: WebSocketClientProtocol):
super().__init__(request=request) super().__init__(request=request)
self.websocket = websocket self.websocket = websocket
@property @property
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
def closed(self) -> bool: def closed(self) -> bool:
# if isinstance(self.websocket, WebSocket):
# return (
# self.websocket.client_state == WebSocketState.DISCONNECTED
# or self.websocket.application_state == WebSocketState.DISCONNECTED
# )
return self.websocket.closed return self.websocket.closed
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def accept(self): async def accept(self):
# if isinstance(self.websocket, WebSocket):
# await self.websocket.accept()
raise NotImplementedError raise NotImplementedError
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
@ -345,8 +338,6 @@ class WebSocketsWS(BaseWebSocket):
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def receive(self) -> str: async def receive(self) -> str:
# if isinstance(self.websocket, WebSocket):
# return await self.websocket.receive_text()
msg = await self.websocket.recv() msg = await self.websocket.recv()
if isinstance(msg, bytes): if isinstance(msg, bytes):
raise TypeError("WebSocket received unexpected frame type: bytes") raise TypeError("WebSocket received unexpected frame type: bytes")
@ -354,8 +345,6 @@ class WebSocketsWS(BaseWebSocket):
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes: async def receive_bytes(self) -> bytes:
# if isinstance(self.websocket, WebSocket):
# return await self.websocket.receive_bytes()
msg = await self.websocket.recv() msg = await self.websocket.recv()
if isinstance(msg, str): if isinstance(msg, str):
raise TypeError("WebSocket received unexpected frame type: str") raise TypeError("WebSocket received unexpected frame type: str")
@ -363,20 +352,16 @@ class WebSocketsWS(BaseWebSocket):
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def send(self, data: str) -> None: async def send(self, data: str) -> None:
# if isinstance(self.websocket, WebSocket):
# await self.websocket.send({"type": "websocket.send", "text": data})
await self.websocket.send(data) await self.websocket.send(data)
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None: async def send_bytes(self, data: bytes) -> None:
# if isinstance(self.websocket, WebSocket):
# await self.websocket.send({"type": "websocket.send", "bytes": data})
await self.websocket.send(data) await self.websocket.send(data)
class FastAPIWebSocket(BaseWebSocket): class FastAPIWebSocket(BaseWebSocket):
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
def __init__(self, *, request: GenericRequest, websocket: WebSocket): def __init__(self, *, request: BaseRequest, websocket: WebSocket):
super().__init__(request=request) super().__init__(request=request)
self.websocket = websocket self.websocket = websocket

View File

@ -8,8 +8,7 @@ Quart 驱动适配
https://pgjones.gitlab.io/quart/index.html https://pgjones.gitlab.io/quart/index.html
""" """
import asyncio from functools import partial
from dataclasses import dataclass
from typing import List, TypeVar, Callable, Optional, Coroutine from typing import List, TypeVar, Callable, Optional, Coroutine
import uvicorn import uvicorn
@ -20,8 +19,9 @@ from nonebot.log import logger
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot.config import Config as NoneBotConfig from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import HTTPRequest, ReverseDriver from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
try: try:
from quart import request as _request from quart import request as _request
@ -98,11 +98,6 @@ class Config(BaseSettings):
class Driver(ReverseDriver): class Driver(ReverseDriver):
""" """
Quart 驱动框架 Quart 驱动框架
:上报地址:
* ``/{adapter name}/http``: HTTP POST 上报
* ``/{adapter name}/ws``: WebSocket 上报
""" """
def __init__(self, env: Env, config: NoneBotConfig): def __init__(self, env: Env, config: NoneBotConfig):
@ -111,12 +106,6 @@ class Driver(ReverseDriver):
self.quart_config = Config(**config.dict()) self.quart_config = Config(**config.dict())
self._server_app = Quart(self.__class__.__qualname__) self._server_app = Quart(self.__class__.__qualname__)
self._server_app.add_url_rule(
"/<adapter>/http", methods=["POST"], view_func=self._handle_http
)
self._server_app.add_websocket(
"/<adapter>/ws", view_func=self._handle_ws_reverse
)
@property @property
@overrides(ReverseDriver) @overrides(ReverseDriver)
@ -142,6 +131,21 @@ class Driver(ReverseDriver):
"""Quart 使用的 logger""" """Quart 使用的 logger"""
return self._server_app.logger return self._server_app.logger
@overrides(ReverseDriver)
def setup_http_server(self, setup: HTTPServerSetup):
self._server_app.add_url_rule(
setup.path.path,
methods=[setup.method],
view_func=partial(self._handle_http, setup=setup),
)
@overrides(ReverseDriver)
def setup_websocket_server(self, setup: WebSocketServerSetup) -> None:
self._server_app.add_websocket(
setup.path.path,
view_func=partial(self._handle_ws, setup=setup),
)
@overrides(ReverseDriver) @overrides(ReverseDriver)
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable: def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: `Startup and Shutdown`_ """参考文档: `Startup and Shutdown`_
@ -199,128 +203,75 @@ class Driver(ReverseDriver):
**kwargs, **kwargs,
) )
async def _handle_http(self, adapter: str): async def _handle_http(self, setup: HTTPServerSetup) -> Response:
request: Request = _request request: Request = _request
data: bytes = await request.get_data() # type: ignore
if adapter not in self._adapters: http_request = BaseRequest(
logger.warning(
f"Unknown adapter {adapter}. " "Please register the adapter before use."
)
raise exceptions.NotFound()
BotClass = self._adapters[adapter]
http_request = HTTPRequest(
request.http_version,
request.scheme,
request.path,
request.query_string,
dict(request.headers),
request.method, request.method,
data, request.url,
headers=request.headers.items(),
cookies=list(request.cookies.items()),
content=await request.get_data(
cache=False, as_text=False, parse_form_data=False
),
version=request.http_version,
) )
self_id, response = await BotClass.check_permission(self, http_request) response = await setup.handle_func(http_request)
if not self_id:
raise exceptions.Unauthorized(
description=(response and response.body or b"").decode()
)
if self_id in self._clients:
logger.warning(
"There's already a reverse websocket connection,"
"so the event may be handled twice."
)
bot = BotClass(self_id, http_request)
asyncio.create_task(bot.handle_message(data))
return Response( return Response(
response and response.body or "", response and response.status or 200 response.content or "",
response.status_code or 200,
headers=dict(response.headers),
) )
async def _handle_ws_reverse(self, adapter: str): async def _handle_ws(self, setup: WebSocketServerSetup) -> None:
websocket: QuartWebSocket = _websocket websocket: QuartWebSocket = _websocket
ws = WebSocket(
websocket.http_version, http_request = BaseRequest(
websocket.scheme, websocket.method,
websocket.path, websocket.url,
websocket.query_string, headers=websocket.headers.items(),
dict(websocket.headers), cookies=list(websocket.cookies.items()),
websocket, version=websocket.http_version,
) )
if adapter not in self._adapters: ws = WebSocket(request=http_request, websocket=websocket)
logger.warning(
f"Unknown adapter {adapter}. Please register the adapter before use."
)
raise exceptions.NotFound()
BotClass = self._adapters[adapter] await setup.handle_func(ws)
self_id, response = await BotClass.check_permission(self, ws)
if not self_id:
raise exceptions.Unauthorized(
description=(response and response.body or b"").decode()
)
if self_id in self._clients:
logger.opt(colors=True).warning(
"There's already a websocket connection, "
f"<y>{escape_tag(adapter.upper())} Bot {escape_tag(self_id)}</y> ignored."
)
raise exceptions.Forbidden(description="Client already exists.")
bot = BotClass(self_id, ws)
await ws.accept()
logger.opt(colors=True).info(
f"WebSocket Connection from <y>{escape_tag(adapter.upper())} "
f"Bot {escape_tag(self_id)}</y> Accepted!"
)
self._bot_connect(bot)
try:
while not ws.closed:
try:
data = await ws.receive()
except asyncio.CancelledError:
logger.warning("WebSocket disconnected by peer.")
break
except Exception as e:
logger.opt(exception=e).error(
"Error when receiving data from websocket."
)
break
asyncio.create_task(bot.handle_message(data.encode()))
finally:
self._bot_disconnect(bot)
@dataclass
class WebSocket(BaseWebSocket): class WebSocket(BaseWebSocket):
websocket: QuartWebSocket = None # type: ignore def __init__(self, *, request: BaseRequest, websocket: QuartWebSocket):
super().__init__(request=request)
self.websocket = websocket
@property @property
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
def closed(self): def closed(self):
# FIXME raise NotImplementedError
return False
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def accept(self): async def accept(self):
await self.websocket.accept() await self.websocket.accept()
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def close(self): async def close(self, code: int = 1000):
# FIXME await self.websocket.close(code)
pass
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def receive(self) -> str: async def receive(self) -> str:
return await self.websocket.receive() # type: ignore msg = await self.websocket.receive()
if isinstance(msg, bytes):
raise TypeError("WebSocket received unexpected frame type: bytes")
return msg
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes: async def receive_bytes(self) -> bytes:
return await self.websocket.receive() # type: ignore msg = await self.websocket.receive()
if isinstance(msg, str):
raise TypeError("WebSocket received unexpected frame type: str")
return msg
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def send(self, data: str): async def send(self, data: str):

View File

@ -61,7 +61,6 @@ def Depends(
* ``dependency: Optional[Callable[..., Any]] = None``: 依赖函数。默认为参数的类型注释。 * ``dependency: Optional[Callable[..., Any]] = None``: 依赖函数。默认为参数的类型注释。
* ``use_cache: bool = True``: 是否使用缓存。默认为 ``True``。 * ``use_cache: bool = True``: 是否使用缓存。默认为 ``True``。
* ``allow_types: Optional[List[Type[Param]]] = None``: 允许的参数类型。默认为 ``None``。
.. code-block:: python .. code-block:: python