mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-09-06 03:56:45 +00:00
♻️ rewrite quart driver
This commit is contained in:
@ -27,9 +27,9 @@ from nonebot.config import Env
|
|||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
from nonebot.utils import escape_tag
|
from nonebot.utils import escape_tag
|
||||||
from nonebot.config import Config as NoneBotConfig
|
from nonebot.config import Config as NoneBotConfig
|
||||||
from nonebot.drivers import Request as GenericRequest
|
from nonebot.drivers import Request as BaseRequest
|
||||||
|
from nonebot.drivers import Response as BaseResponse
|
||||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
from nonebot.drivers import Response as GenericResponse
|
|
||||||
from nonebot.drivers import (
|
from nonebot.drivers import (
|
||||||
HTTPVersion,
|
HTTPVersion,
|
||||||
ForwardDriver,
|
ForwardDriver,
|
||||||
@ -247,7 +247,7 @@ class Driver(ReverseDriver):
|
|||||||
request: Request,
|
request: Request,
|
||||||
setup: HTTPServerSetup,
|
setup: HTTPServerSetup,
|
||||||
):
|
):
|
||||||
http_request = GenericRequest(
|
http_request = BaseRequest(
|
||||||
request.method,
|
request.method,
|
||||||
str(request.url),
|
str(request.url),
|
||||||
headers=request.headers.items(),
|
headers=request.headers.items(),
|
||||||
@ -260,7 +260,7 @@ class Driver(ReverseDriver):
|
|||||||
return Response(response.content, response.status_code, dict(response.headers))
|
return Response(response.content, response.status_code, dict(response.headers))
|
||||||
|
|
||||||
async def _handle_ws(self, websocket: WebSocket, setup: WebSocketServerSetup):
|
async def _handle_ws(self, websocket: WebSocket, setup: WebSocketServerSetup):
|
||||||
request = GenericRequest(
|
request = BaseRequest(
|
||||||
"GET",
|
"GET",
|
||||||
str(websocket.url),
|
str(websocket.url),
|
||||||
headers=websocket.headers.items(),
|
headers=websocket.headers.items(),
|
||||||
@ -293,7 +293,7 @@ class FullDriver(ForwardDriver, Driver):
|
|||||||
return "fastapi_full"
|
return "fastapi_full"
|
||||||
|
|
||||||
@overrides(ForwardDriver)
|
@overrides(ForwardDriver)
|
||||||
async def request(self, setup: "GenericRequest") -> Any:
|
async def request(self, setup: "BaseRequest") -> Any:
|
||||||
async with httpx.AsyncClient(
|
async with httpx.AsyncClient(
|
||||||
http2=setup.version == HTTPVersion.H2, follow_redirects=True
|
http2=setup.version == HTTPVersion.H2, follow_redirects=True
|
||||||
) as client:
|
) as client:
|
||||||
@ -304,7 +304,7 @@ class FullDriver(ForwardDriver, Driver):
|
|||||||
headers=tuple(setup.headers.items()),
|
headers=tuple(setup.headers.items()),
|
||||||
timeout=30.0,
|
timeout=30.0,
|
||||||
)
|
)
|
||||||
return GenericResponse(
|
return BaseResponse(
|
||||||
response.status_code,
|
response.status_code,
|
||||||
headers=response.headers,
|
headers=response.headers,
|
||||||
content=response.content,
|
content=response.content,
|
||||||
@ -312,31 +312,24 @@ class FullDriver(ForwardDriver, Driver):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@overrides(ForwardDriver)
|
@overrides(ForwardDriver)
|
||||||
async def websocket(self, setup: "GenericRequest") -> Any:
|
async def websocket(self, setup: "BaseRequest") -> Any:
|
||||||
ws = await Connect(str(setup.url), extra_headers=setup.headers.items())
|
ws = await Connect(str(setup.url), extra_headers=setup.headers.items())
|
||||||
return WebSocketsWS(request=setup, websocket=ws)
|
return WebSocketsWS(request=setup, websocket=ws)
|
||||||
|
|
||||||
|
|
||||||
class WebSocketsWS(BaseWebSocket):
|
class WebSocketsWS(BaseWebSocket):
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
def __init__(self, *, request: GenericRequest, websocket: WebSocketClientProtocol):
|
def __init__(self, *, request: BaseRequest, websocket: WebSocketClientProtocol):
|
||||||
super().__init__(request=request)
|
super().__init__(request=request)
|
||||||
self.websocket = websocket
|
self.websocket = websocket
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
def closed(self) -> bool:
|
def closed(self) -> bool:
|
||||||
# if isinstance(self.websocket, WebSocket):
|
|
||||||
# return (
|
|
||||||
# self.websocket.client_state == WebSocketState.DISCONNECTED
|
|
||||||
# or self.websocket.application_state == WebSocketState.DISCONNECTED
|
|
||||||
# )
|
|
||||||
return self.websocket.closed
|
return self.websocket.closed
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def accept(self):
|
async def accept(self):
|
||||||
# if isinstance(self.websocket, WebSocket):
|
|
||||||
# await self.websocket.accept()
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
@ -345,8 +338,6 @@ class WebSocketsWS(BaseWebSocket):
|
|||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def receive(self) -> str:
|
async def receive(self) -> str:
|
||||||
# if isinstance(self.websocket, WebSocket):
|
|
||||||
# return await self.websocket.receive_text()
|
|
||||||
msg = await self.websocket.recv()
|
msg = await self.websocket.recv()
|
||||||
if isinstance(msg, bytes):
|
if isinstance(msg, bytes):
|
||||||
raise TypeError("WebSocket received unexpected frame type: bytes")
|
raise TypeError("WebSocket received unexpected frame type: bytes")
|
||||||
@ -354,8 +345,6 @@ class WebSocketsWS(BaseWebSocket):
|
|||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def receive_bytes(self) -> bytes:
|
async def receive_bytes(self) -> bytes:
|
||||||
# if isinstance(self.websocket, WebSocket):
|
|
||||||
# return await self.websocket.receive_bytes()
|
|
||||||
msg = await self.websocket.recv()
|
msg = await self.websocket.recv()
|
||||||
if isinstance(msg, str):
|
if isinstance(msg, str):
|
||||||
raise TypeError("WebSocket received unexpected frame type: str")
|
raise TypeError("WebSocket received unexpected frame type: str")
|
||||||
@ -363,20 +352,16 @@ class WebSocketsWS(BaseWebSocket):
|
|||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def send(self, data: str) -> None:
|
async def send(self, data: str) -> None:
|
||||||
# if isinstance(self.websocket, WebSocket):
|
|
||||||
# await self.websocket.send({"type": "websocket.send", "text": data})
|
|
||||||
await self.websocket.send(data)
|
await self.websocket.send(data)
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def send_bytes(self, data: bytes) -> None:
|
async def send_bytes(self, data: bytes) -> None:
|
||||||
# if isinstance(self.websocket, WebSocket):
|
|
||||||
# await self.websocket.send({"type": "websocket.send", "bytes": data})
|
|
||||||
await self.websocket.send(data)
|
await self.websocket.send(data)
|
||||||
|
|
||||||
|
|
||||||
class FastAPIWebSocket(BaseWebSocket):
|
class FastAPIWebSocket(BaseWebSocket):
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
def __init__(self, *, request: GenericRequest, websocket: WebSocket):
|
def __init__(self, *, request: BaseRequest, websocket: WebSocket):
|
||||||
super().__init__(request=request)
|
super().__init__(request=request)
|
||||||
self.websocket = websocket
|
self.websocket = websocket
|
||||||
|
|
||||||
|
@ -8,8 +8,7 @@ Quart 驱动适配
|
|||||||
https://pgjones.gitlab.io/quart/index.html
|
https://pgjones.gitlab.io/quart/index.html
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
from functools import partial
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, TypeVar, Callable, Optional, Coroutine
|
from typing import List, TypeVar, Callable, Optional, Coroutine
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
@ -20,8 +19,9 @@ from nonebot.log import logger
|
|||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
from nonebot.utils import escape_tag
|
from nonebot.utils import escape_tag
|
||||||
from nonebot.config import Config as NoneBotConfig
|
from nonebot.config import Config as NoneBotConfig
|
||||||
from nonebot.drivers import HTTPRequest, ReverseDriver
|
from nonebot.drivers import Request as BaseRequest
|
||||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
|
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from quart import request as _request
|
from quart import request as _request
|
||||||
@ -98,11 +98,6 @@ class Config(BaseSettings):
|
|||||||
class Driver(ReverseDriver):
|
class Driver(ReverseDriver):
|
||||||
"""
|
"""
|
||||||
Quart 驱动框架
|
Quart 驱动框架
|
||||||
|
|
||||||
:上报地址:
|
|
||||||
|
|
||||||
* ``/{adapter name}/http``: HTTP POST 上报
|
|
||||||
* ``/{adapter name}/ws``: WebSocket 上报
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: Env, config: NoneBotConfig):
|
def __init__(self, env: Env, config: NoneBotConfig):
|
||||||
@ -111,12 +106,6 @@ class Driver(ReverseDriver):
|
|||||||
self.quart_config = Config(**config.dict())
|
self.quart_config = Config(**config.dict())
|
||||||
|
|
||||||
self._server_app = Quart(self.__class__.__qualname__)
|
self._server_app = Quart(self.__class__.__qualname__)
|
||||||
self._server_app.add_url_rule(
|
|
||||||
"/<adapter>/http", methods=["POST"], view_func=self._handle_http
|
|
||||||
)
|
|
||||||
self._server_app.add_websocket(
|
|
||||||
"/<adapter>/ws", view_func=self._handle_ws_reverse
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@overrides(ReverseDriver)
|
@overrides(ReverseDriver)
|
||||||
@ -142,6 +131,21 @@ class Driver(ReverseDriver):
|
|||||||
"""Quart 使用的 logger"""
|
"""Quart 使用的 logger"""
|
||||||
return self._server_app.logger
|
return self._server_app.logger
|
||||||
|
|
||||||
|
@overrides(ReverseDriver)
|
||||||
|
def setup_http_server(self, setup: HTTPServerSetup):
|
||||||
|
self._server_app.add_url_rule(
|
||||||
|
setup.path.path,
|
||||||
|
methods=[setup.method],
|
||||||
|
view_func=partial(self._handle_http, setup=setup),
|
||||||
|
)
|
||||||
|
|
||||||
|
@overrides(ReverseDriver)
|
||||||
|
def setup_websocket_server(self, setup: WebSocketServerSetup) -> None:
|
||||||
|
self._server_app.add_websocket(
|
||||||
|
setup.path.path,
|
||||||
|
view_func=partial(self._handle_ws, setup=setup),
|
||||||
|
)
|
||||||
|
|
||||||
@overrides(ReverseDriver)
|
@overrides(ReverseDriver)
|
||||||
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
|
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
|
||||||
"""参考文档: `Startup and Shutdown`_
|
"""参考文档: `Startup and Shutdown`_
|
||||||
@ -199,128 +203,75 @@ class Driver(ReverseDriver):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_http(self, adapter: str):
|
async def _handle_http(self, setup: HTTPServerSetup) -> Response:
|
||||||
request: Request = _request
|
request: Request = _request
|
||||||
data: bytes = await request.get_data() # type: ignore
|
|
||||||
|
|
||||||
if adapter not in self._adapters:
|
http_request = BaseRequest(
|
||||||
logger.warning(
|
|
||||||
f"Unknown adapter {adapter}. " "Please register the adapter before use."
|
|
||||||
)
|
|
||||||
raise exceptions.NotFound()
|
|
||||||
|
|
||||||
BotClass = self._adapters[adapter]
|
|
||||||
http_request = HTTPRequest(
|
|
||||||
request.http_version,
|
|
||||||
request.scheme,
|
|
||||||
request.path,
|
|
||||||
request.query_string,
|
|
||||||
dict(request.headers),
|
|
||||||
request.method,
|
request.method,
|
||||||
data,
|
request.url,
|
||||||
|
headers=request.headers.items(),
|
||||||
|
cookies=list(request.cookies.items()),
|
||||||
|
content=await request.get_data(
|
||||||
|
cache=False, as_text=False, parse_form_data=False
|
||||||
|
),
|
||||||
|
version=request.http_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
self_id, response = await BotClass.check_permission(self, http_request)
|
response = await setup.handle_func(http_request)
|
||||||
|
|
||||||
if not self_id:
|
|
||||||
raise exceptions.Unauthorized(
|
|
||||||
description=(response and response.body or b"").decode()
|
|
||||||
)
|
|
||||||
if self_id in self._clients:
|
|
||||||
logger.warning(
|
|
||||||
"There's already a reverse websocket connection,"
|
|
||||||
"so the event may be handled twice."
|
|
||||||
)
|
|
||||||
bot = BotClass(self_id, http_request)
|
|
||||||
asyncio.create_task(bot.handle_message(data))
|
|
||||||
return Response(
|
return Response(
|
||||||
response and response.body or "", response and response.status or 200
|
response.content or "",
|
||||||
|
response.status_code or 200,
|
||||||
|
headers=dict(response.headers),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_ws_reverse(self, adapter: str):
|
async def _handle_ws(self, setup: WebSocketServerSetup) -> None:
|
||||||
websocket: QuartWebSocket = _websocket
|
websocket: QuartWebSocket = _websocket
|
||||||
ws = WebSocket(
|
|
||||||
websocket.http_version,
|
http_request = BaseRequest(
|
||||||
websocket.scheme,
|
websocket.method,
|
||||||
websocket.path,
|
websocket.url,
|
||||||
websocket.query_string,
|
headers=websocket.headers.items(),
|
||||||
dict(websocket.headers),
|
cookies=list(websocket.cookies.items()),
|
||||||
websocket,
|
version=websocket.http_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
if adapter not in self._adapters:
|
ws = WebSocket(request=http_request, websocket=websocket)
|
||||||
logger.warning(
|
|
||||||
f"Unknown adapter {adapter}. Please register the adapter before use."
|
|
||||||
)
|
|
||||||
raise exceptions.NotFound()
|
|
||||||
|
|
||||||
BotClass = self._adapters[adapter]
|
await setup.handle_func(ws)
|
||||||
self_id, response = await BotClass.check_permission(self, ws)
|
|
||||||
|
|
||||||
if not self_id:
|
|
||||||
raise exceptions.Unauthorized(
|
|
||||||
description=(response and response.body or b"").decode()
|
|
||||||
)
|
|
||||||
|
|
||||||
if self_id in self._clients:
|
|
||||||
logger.opt(colors=True).warning(
|
|
||||||
"There's already a websocket connection, "
|
|
||||||
f"<y>{escape_tag(adapter.upper())} Bot {escape_tag(self_id)}</y> ignored."
|
|
||||||
)
|
|
||||||
raise exceptions.Forbidden(description="Client already exists.")
|
|
||||||
|
|
||||||
bot = BotClass(self_id, ws)
|
|
||||||
await ws.accept()
|
|
||||||
logger.opt(colors=True).info(
|
|
||||||
f"WebSocket Connection from <y>{escape_tag(adapter.upper())} "
|
|
||||||
f"Bot {escape_tag(self_id)}</y> Accepted!"
|
|
||||||
)
|
|
||||||
self._bot_connect(bot)
|
|
||||||
|
|
||||||
try:
|
|
||||||
while not ws.closed:
|
|
||||||
try:
|
|
||||||
data = await ws.receive()
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.warning("WebSocket disconnected by peer.")
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(exception=e).error(
|
|
||||||
"Error when receiving data from websocket."
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
asyncio.create_task(bot.handle_message(data.encode()))
|
|
||||||
finally:
|
|
||||||
self._bot_disconnect(bot)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class WebSocket(BaseWebSocket):
|
class WebSocket(BaseWebSocket):
|
||||||
websocket: QuartWebSocket = None # type: ignore
|
def __init__(self, *, request: BaseRequest, websocket: QuartWebSocket):
|
||||||
|
super().__init__(request=request)
|
||||||
|
self.websocket = websocket
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
def closed(self):
|
def closed(self):
|
||||||
# FIXME
|
raise NotImplementedError
|
||||||
return False
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def accept(self):
|
async def accept(self):
|
||||||
await self.websocket.accept()
|
await self.websocket.accept()
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def close(self):
|
async def close(self, code: int = 1000):
|
||||||
# FIXME
|
await self.websocket.close(code)
|
||||||
pass
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def receive(self) -> str:
|
async def receive(self) -> str:
|
||||||
return await self.websocket.receive() # type: ignore
|
msg = await self.websocket.receive()
|
||||||
|
if isinstance(msg, bytes):
|
||||||
|
raise TypeError("WebSocket received unexpected frame type: bytes")
|
||||||
|
return msg
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def receive_bytes(self) -> bytes:
|
async def receive_bytes(self) -> bytes:
|
||||||
return await self.websocket.receive() # type: ignore
|
msg = await self.websocket.receive()
|
||||||
|
if isinstance(msg, str):
|
||||||
|
raise TypeError("WebSocket received unexpected frame type: str")
|
||||||
|
return msg
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def send(self, data: str):
|
async def send(self, data: str):
|
||||||
|
@ -61,7 +61,6 @@ def Depends(
|
|||||||
|
|
||||||
* ``dependency: Optional[Callable[..., Any]] = None``: 依赖函数。默认为参数的类型注释。
|
* ``dependency: Optional[Callable[..., Any]] = None``: 依赖函数。默认为参数的类型注释。
|
||||||
* ``use_cache: bool = True``: 是否使用缓存。默认为 ``True``。
|
* ``use_cache: bool = True``: 是否使用缓存。默认为 ``True``。
|
||||||
* ``allow_types: Optional[List[Type[Param]]] = None``: 允许的参数类型。默认为 ``None``。
|
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user