♻️ rewrite quart driver

This commit is contained in:
yanyongyu
2021-12-20 15:46:23 +08:00
parent c49059f9d3
commit ea8f7717b9
3 changed files with 64 additions and 129 deletions

View File

@ -27,9 +27,9 @@ from nonebot.config import Env
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as GenericRequest
from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import Response as BaseResponse
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import Response as GenericResponse
from nonebot.drivers import (
HTTPVersion,
ForwardDriver,
@ -247,7 +247,7 @@ class Driver(ReverseDriver):
request: Request,
setup: HTTPServerSetup,
):
http_request = GenericRequest(
http_request = BaseRequest(
request.method,
str(request.url),
headers=request.headers.items(),
@ -260,7 +260,7 @@ class Driver(ReverseDriver):
return Response(response.content, response.status_code, dict(response.headers))
async def _handle_ws(self, websocket: WebSocket, setup: WebSocketServerSetup):
request = GenericRequest(
request = BaseRequest(
"GET",
str(websocket.url),
headers=websocket.headers.items(),
@ -293,7 +293,7 @@ class FullDriver(ForwardDriver, Driver):
return "fastapi_full"
@overrides(ForwardDriver)
async def request(self, setup: "GenericRequest") -> Any:
async def request(self, setup: "BaseRequest") -> Any:
async with httpx.AsyncClient(
http2=setup.version == HTTPVersion.H2, follow_redirects=True
) as client:
@ -304,7 +304,7 @@ class FullDriver(ForwardDriver, Driver):
headers=tuple(setup.headers.items()),
timeout=30.0,
)
return GenericResponse(
return BaseResponse(
response.status_code,
headers=response.headers,
content=response.content,
@ -312,31 +312,24 @@ class FullDriver(ForwardDriver, Driver):
)
@overrides(ForwardDriver)
async def websocket(self, setup: "GenericRequest") -> Any:
async def websocket(self, setup: "BaseRequest") -> Any:
ws = await Connect(str(setup.url), extra_headers=setup.headers.items())
return WebSocketsWS(request=setup, websocket=ws)
class WebSocketsWS(BaseWebSocket):
@overrides(BaseWebSocket)
def __init__(self, *, request: GenericRequest, websocket: WebSocketClientProtocol):
def __init__(self, *, request: BaseRequest, websocket: WebSocketClientProtocol):
super().__init__(request=request)
self.websocket = websocket
@property
@overrides(BaseWebSocket)
def closed(self) -> bool:
# if isinstance(self.websocket, WebSocket):
# return (
# self.websocket.client_state == WebSocketState.DISCONNECTED
# or self.websocket.application_state == WebSocketState.DISCONNECTED
# )
return self.websocket.closed
@overrides(BaseWebSocket)
async def accept(self):
# if isinstance(self.websocket, WebSocket):
# await self.websocket.accept()
raise NotImplementedError
@overrides(BaseWebSocket)
@ -345,8 +338,6 @@ class WebSocketsWS(BaseWebSocket):
@overrides(BaseWebSocket)
async def receive(self) -> str:
# if isinstance(self.websocket, WebSocket):
# return await self.websocket.receive_text()
msg = await self.websocket.recv()
if isinstance(msg, bytes):
raise TypeError("WebSocket received unexpected frame type: bytes")
@ -354,8 +345,6 @@ class WebSocketsWS(BaseWebSocket):
@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
# if isinstance(self.websocket, WebSocket):
# return await self.websocket.receive_bytes()
msg = await self.websocket.recv()
if isinstance(msg, str):
raise TypeError("WebSocket received unexpected frame type: str")
@ -363,20 +352,16 @@ class WebSocketsWS(BaseWebSocket):
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
# if isinstance(self.websocket, WebSocket):
# await self.websocket.send({"type": "websocket.send", "text": data})
await self.websocket.send(data)
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
# if isinstance(self.websocket, WebSocket):
# await self.websocket.send({"type": "websocket.send", "bytes": data})
await self.websocket.send(data)
class FastAPIWebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
def __init__(self, *, request: GenericRequest, websocket: WebSocket):
def __init__(self, *, request: BaseRequest, websocket: WebSocket):
super().__init__(request=request)
self.websocket = websocket