mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-17 19:40:44 +00:00
♻️ rewrite quart driver
This commit is contained in:
@ -27,9 +27,9 @@ from nonebot.config import Env
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.utils import escape_tag
|
||||
from nonebot.config import Config as NoneBotConfig
|
||||
from nonebot.drivers import Request as GenericRequest
|
||||
from nonebot.drivers import Request as BaseRequest
|
||||
from nonebot.drivers import Response as BaseResponse
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers import Response as GenericResponse
|
||||
from nonebot.drivers import (
|
||||
HTTPVersion,
|
||||
ForwardDriver,
|
||||
@ -247,7 +247,7 @@ class Driver(ReverseDriver):
|
||||
request: Request,
|
||||
setup: HTTPServerSetup,
|
||||
):
|
||||
http_request = GenericRequest(
|
||||
http_request = BaseRequest(
|
||||
request.method,
|
||||
str(request.url),
|
||||
headers=request.headers.items(),
|
||||
@ -260,7 +260,7 @@ class Driver(ReverseDriver):
|
||||
return Response(response.content, response.status_code, dict(response.headers))
|
||||
|
||||
async def _handle_ws(self, websocket: WebSocket, setup: WebSocketServerSetup):
|
||||
request = GenericRequest(
|
||||
request = BaseRequest(
|
||||
"GET",
|
||||
str(websocket.url),
|
||||
headers=websocket.headers.items(),
|
||||
@ -293,7 +293,7 @@ class FullDriver(ForwardDriver, Driver):
|
||||
return "fastapi_full"
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
async def request(self, setup: "GenericRequest") -> Any:
|
||||
async def request(self, setup: "BaseRequest") -> Any:
|
||||
async with httpx.AsyncClient(
|
||||
http2=setup.version == HTTPVersion.H2, follow_redirects=True
|
||||
) as client:
|
||||
@ -304,7 +304,7 @@ class FullDriver(ForwardDriver, Driver):
|
||||
headers=tuple(setup.headers.items()),
|
||||
timeout=30.0,
|
||||
)
|
||||
return GenericResponse(
|
||||
return BaseResponse(
|
||||
response.status_code,
|
||||
headers=response.headers,
|
||||
content=response.content,
|
||||
@ -312,31 +312,24 @@ class FullDriver(ForwardDriver, Driver):
|
||||
)
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
async def websocket(self, setup: "GenericRequest") -> Any:
|
||||
async def websocket(self, setup: "BaseRequest") -> Any:
|
||||
ws = await Connect(str(setup.url), extra_headers=setup.headers.items())
|
||||
return WebSocketsWS(request=setup, websocket=ws)
|
||||
|
||||
|
||||
class WebSocketsWS(BaseWebSocket):
|
||||
@overrides(BaseWebSocket)
|
||||
def __init__(self, *, request: GenericRequest, websocket: WebSocketClientProtocol):
|
||||
def __init__(self, *, request: BaseRequest, websocket: WebSocketClientProtocol):
|
||||
super().__init__(request=request)
|
||||
self.websocket = websocket
|
||||
|
||||
@property
|
||||
@overrides(BaseWebSocket)
|
||||
def closed(self) -> bool:
|
||||
# if isinstance(self.websocket, WebSocket):
|
||||
# return (
|
||||
# self.websocket.client_state == WebSocketState.DISCONNECTED
|
||||
# or self.websocket.application_state == WebSocketState.DISCONNECTED
|
||||
# )
|
||||
return self.websocket.closed
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def accept(self):
|
||||
# if isinstance(self.websocket, WebSocket):
|
||||
# await self.websocket.accept()
|
||||
raise NotImplementedError
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@ -345,8 +338,6 @@ class WebSocketsWS(BaseWebSocket):
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def receive(self) -> str:
|
||||
# if isinstance(self.websocket, WebSocket):
|
||||
# return await self.websocket.receive_text()
|
||||
msg = await self.websocket.recv()
|
||||
if isinstance(msg, bytes):
|
||||
raise TypeError("WebSocket received unexpected frame type: bytes")
|
||||
@ -354,8 +345,6 @@ class WebSocketsWS(BaseWebSocket):
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def receive_bytes(self) -> bytes:
|
||||
# if isinstance(self.websocket, WebSocket):
|
||||
# return await self.websocket.receive_bytes()
|
||||
msg = await self.websocket.recv()
|
||||
if isinstance(msg, str):
|
||||
raise TypeError("WebSocket received unexpected frame type: str")
|
||||
@ -363,20 +352,16 @@ class WebSocketsWS(BaseWebSocket):
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def send(self, data: str) -> None:
|
||||
# if isinstance(self.websocket, WebSocket):
|
||||
# await self.websocket.send({"type": "websocket.send", "text": data})
|
||||
await self.websocket.send(data)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def send_bytes(self, data: bytes) -> None:
|
||||
# if isinstance(self.websocket, WebSocket):
|
||||
# await self.websocket.send({"type": "websocket.send", "bytes": data})
|
||||
await self.websocket.send(data)
|
||||
|
||||
|
||||
class FastAPIWebSocket(BaseWebSocket):
|
||||
@overrides(BaseWebSocket)
|
||||
def __init__(self, *, request: GenericRequest, websocket: WebSocket):
|
||||
def __init__(self, *, request: BaseRequest, websocket: WebSocket):
|
||||
super().__init__(request=request)
|
||||
self.websocket = websocket
|
||||
|
||||
|
Reference in New Issue
Block a user