♻️ rewrite driver request and response class

This commit is contained in:
yanyongyu
2021-12-17 23:20:19 +08:00
parent c0f321116a
commit ec9e159ef6
4 changed files with 296 additions and 159 deletions

View File

@ -7,22 +7,13 @@
import abc
import asyncio
from dataclasses import field, dataclass
from typing import (
TYPE_CHECKING,
Any,
Set,
Dict,
Type,
Union,
Callable,
Optional,
Awaitable,
)
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable
from nonebot.log import logger
from nonebot.utils import escape_tag
from nonebot.config import Env, Config
from ._model import URL, Request, Response, WebSocket
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
if TYPE_CHECKING:
@ -213,11 +204,11 @@ class ForwardDriver(Driver):
"""
@abc.abstractmethod
async def request(self, setup: "HTTPRequest") -> Any:
async def request(self, setup: "Request") -> Any:
raise NotImplementedError
@abc.abstractmethod
async def websocket(self, setup: "HTTPConnection") -> Any:
async def websocket(self, setup: "Request") -> Any:
raise NotImplementedError
@ -247,153 +238,16 @@ class ReverseDriver(Driver):
raise NotImplementedError
# TODO: repack dataclass
@dataclass
class HTTPConnection(abc.ABC):
http_version: str
"""One of ``"1.0"``, ``"1.1"`` or ``"2"``."""
scheme: str
"""URL scheme portion (likely ``"http"`` or ``"https"``)."""
path: str
"""
HTTP request target excluding any query string,
with percent-encoded sequences and UTF-8 byte sequences
decoded into characters.
"""
query_string: bytes = b""
""" URL portion after the ``?``, percent-encoded."""
headers: Dict[str, str] = field(default_factory=dict)
"""A dict of name-value pairs,
where name is the header name, and value is the header value.
Order of header values must be preserved from the original HTTP request;
order of header names is not important.
Header names must be lowercased.
"""
@property
@abc.abstractmethod
def type(self) -> str:
"""Connection type."""
raise NotImplementedError
@dataclass
class HTTPRequest(HTTPConnection):
"""HTTP 请求封装。参考 `asgi http scope`_。
.. _asgi http scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
"""
method: str = "GET"
"""The HTTP method name, uppercased."""
body: bytes = b""
"""Body of the request.
Optional; if missing defaults to ``b""``.
"""
@property
def type(self) -> str:
"""Always ``http``"""
return "http"
@dataclass
class HTTPResponse:
"""HTTP 响应封装。参考 `asgi http scope`_。
.. _asgi http scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
"""
status: int
"""HTTP status code."""
body: Optional[bytes] = None
"""HTTP body content.
Optional; if missing defaults to ``None``.
"""
headers: Dict[str, str] = field(default_factory=dict)
"""A dict of name-value pairs,
where name is the header name, and value is the header value.
Order must be preserved in the HTTP response.
Header names must be lowercased.
Optional; if missing defaults to an empty dict.
"""
@property
def type(self) -> str:
"""Always ``http``"""
return "http"
@dataclass
class WebSocket(HTTPConnection, abc.ABC):
"""WebSocket 连接封装。参考 `asgi websocket scope`_。
.. _asgi websocket scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope
"""
@property
def type(self) -> str:
"""Always ``websocket``"""
return "websocket"
@property
@abc.abstractmethod
def closed(self) -> bool:
"""
:类型: ``bool``
:说明: 连接是否已经关闭
"""
raise NotImplementedError
@abc.abstractmethod
async def accept(self):
"""接受 WebSocket 连接请求"""
raise NotImplementedError
@abc.abstractmethod
async def close(self, code: int):
"""关闭 WebSocket 连接请求"""
raise NotImplementedError
@abc.abstractmethod
async def receive(self) -> str:
"""接收一条 WebSocket text 信息"""
raise NotImplementedError
@abc.abstractmethod
async def receive_bytes(self) -> bytes:
"""接收一条 WebSocket binary 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send(self, data: str):
"""发送一条 WebSocket text 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send_bytes(self, data: bytes):
"""发送一条 WebSocket binary 信息"""
raise NotImplementedError
@dataclass
class HTTPServerSetup:
path: str
path: URL # path should not be absolute, check it by URL.is_absolute() == False
method: str
handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]]
name: str
handle_func: Callable[[Request], Awaitable[Response]]
@dataclass
class WebSocketServerSetup:
path: str
path: URL # path should not be absolute, check it by URL.is_absolute() == False
name: str
handle_func: Callable[[WebSocket], Awaitable[Any]]