support custom response

This commit is contained in:
StarHeartHunt
2021-06-10 21:52:20 +08:00
committed by yanyongyu
parent ca31ec5fe3
commit c0d78449be
25 changed files with 365 additions and 7542 deletions

View File

@ -11,13 +11,14 @@ from copy import copy
from functools import reduce, partial
from typing_extensions import Protocol
from dataclasses import dataclass, field
from typing import (Any, Set, List, Dict, Union, TypeVar, Mapping, Optional,
Iterable, Awaitable, TYPE_CHECKING)
from typing import (Any, Set, List, Dict, Tuple, Union, TypeVar, Mapping,
Optional, Iterable, Awaitable, TYPE_CHECKING)
from pydantic import BaseModel
from nonebot.log import logger
from nonebot.utils import DataclassEncoder
from nonebot.drivers import HTTPConnection, HTTPResponse
from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook
if TYPE_CHECKING:
@ -51,12 +52,7 @@ class Bot(abc.ABC):
:说明: call_api 后执行的函数
"""
@abc.abstractmethod
def __init__(self,
connection_type: str,
self_id: str,
*,
websocket: Optional["WebSocket"] = None):
def __init__(self, self_id: str, request: HTTPConnection):
"""
:参数:
@ -64,12 +60,10 @@ class Bot(abc.ABC):
* ``self_id: str``: 机器人 ID
* ``websocket: Optional[WebSocket]``: Websocket 连接对象
"""
self.connection_type = connection_type
"""连接类型"""
self.self_id = self_id
self.self_id: str = self_id
"""机器人 ID"""
self.websocket = websocket
"""Websocket 连接对象"""
self.request: HTTPConnection = request
"""连接信息"""
def __getattr__(self, name: str) -> _ApiCall:
return partial(self.call_api, name)
@ -92,8 +86,9 @@ class Bot(abc.ABC):
@classmethod
@abc.abstractmethod
async def check_permission(cls, driver: "Driver", connection_type: str,
headers: dict, body: Optional[bytes]) -> str:
async def check_permission(
cls, driver: "Driver", request: HTTPConnection
) -> Tuple[Optional[str], Optional[HTTPResponse]]:
"""
:说明:
@ -108,7 +103,8 @@ class Bot(abc.ABC):
:返回:
- ``str``: 连接唯一标识符
- ``str``: 连接唯一标识符``None`` 代表连接不合法
- ``HTTPResponse``: HTTP 上报响应
:异常:
@ -117,7 +113,7 @@ class Bot(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
async def handle_message(self, message: dict):
async def handle_message(self, message: bytes):
"""
:说明:
@ -125,7 +121,7 @@ class Bot(abc.ABC):
:参数:
* ``message: dict``: 收到的上报消息
* ``message: bytes``: 收到的上报消息
"""
raise NotImplementedError

View File

@ -7,8 +7,8 @@
import abc
import asyncio
from typing import (Any, Set, List, Dict, Type, Tuple, Optional, Callable,
MutableMapping, TYPE_CHECKING)
from dataclasses import dataclass, field
from typing import Set, Dict, Type, Optional, Callable, TYPE_CHECKING
from nonebot.log import logger
from nonebot.config import Env, Config
@ -47,12 +47,12 @@ class Driver(abc.ABC):
* ``env: Env``: 包含环境信息的 Env 对象
* ``config: Config``: 包含配置信息的 Config 对象
"""
self.env = env.environment
self.env: str = env.environment
"""
:类型: ``str``
:说明: 环境名称
"""
self.config = config
self.config: Config = config
"""
:类型: ``Config``
:说明: 配置对象
@ -231,143 +231,101 @@ class ReverseDriver(Driver):
raise NotImplementedError
class HTTPRequest:
@dataclass
class HTTPConnection(abc.ABC):
http_version: str
"""One of `"1.0"`, `"1.1"` or `"2"`."""
scheme: str
"""URL scheme portion (likely `"http"` or `"https"`)."""
path: str
"""
HTTP request target excluding any query string,
with percent-encoded sequences and UTF-8 byte sequences
decoded into characters.
"""
query_string: bytes = b""
""" URL portion after the `?`, percent-encoded."""
headers: Dict[str, str] = field(default_factory=dict)
"""A dict of name-value pairs,
where name is the header name, and value is the header value.
Order of header values must be preserved from the original HTTP request;
order of header names is not important.
Header names must be lowercased.
"""
@property
@abc.abstractmethod
def type(self) -> str:
"""Connection type."""
raise NotImplementedError
@dataclass
class HTTPRequest(HTTPConnection):
"""HTTP 请求封装。参考 `asgi http scope`_。
.. _asgi http scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
"""
method: str = "GET"
"""The HTTP method name, uppercased."""
body: bytes = b""
"""Body of the request.
def __init__(self, scope: MutableMapping[str, Any]):
self._scope = scope
Optional; if missing defaults to b"".
"""
@property
def type(self) -> str:
"""Always `http`"""
"""Always ``http``"""
return "http"
@property
def scope(self) -> MutableMapping[str, Any]:
"""Raw scope from asgi.
The connection scope information, a dictionary that
contains at least a `type` key specifying the protocol that is incoming.
"""
return self._scope
@property
def http_version(self) -> str:
"""One of `"1.0"`, `"1.1"` or `"2"`."""
raise self.scope["http_version"]
@property
def method(self) -> str:
"""The HTTP method name, uppercased."""
raise self.scope["method"]
@property
def schema(self) -> str:
"""
URL scheme portion (likely `"http"` or `"https"`).
Optional (but must not be empty); default is `"http"`.
"""
raise self.scope["schema"]
@property
def path(self) -> str:
"""
HTTP request target excluding any query string,
with percent-encoded sequences and UTF-8 byte sequences
decoded into characters.
"""
return self.scope["path"]
@property
def query_string(self) -> bytes:
""" URL portion after the `?`, percent-encoded."""
return self.scope["query_string"]
@property
def headers(self) -> List[Tuple[bytes, bytes]]:
"""An iterable of [name, value] two-item iterables,
where name is the header name, and value is the header value.
Order of header values must be preserved from the original HTTP request;
order of header names is not important.
Duplicates are possible and must be preserved in the message as received.
Header names must be lowercased.
"""
return list(self.scope["headers"])
@property
def body(self) -> bytes:
"""Body of the request.
Optional; if missing defaults to b"".
If more_body is set, treat as start of body and concatenate on further chunks.
"""
return self.scope["body"]
@dataclass
class HTTPResponse:
"""HTTP 响应封装。参考 `asgi http scope`_。
.. _asgi http scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
"""
status: int
"""HTTP status code."""
body: Optional[bytes] = None
"""HTTP body content.
def __init__(self,
status: int,
headers: List[Tuple[bytes, bytes]] = [],
body: Optional[bytes] = None):
self.status: int = status
"""HTTP status code."""
self.headers: List[Tuple[bytes, bytes]] = headers
"""An iterable of [name, value] two-item iterables,
where name is the header name,
and value is the header value.
Optional; if missing defaults to ``None``.
"""
headers: Dict[str, str] = field(default_factory=dict)
"""A dict of name-value pairs,
where name is the header name, and value is the header value.
Order must be preserved in the HTTP response.
Order must be preserved in the HTTP response.
Header names must be lowercased.
Header names must be lowercased.
Optional; if missing defaults to an empty list.
"""
self.body: Optional[bytes] = body
"""HTTP body content.
Optional; if missing defaults to `None`.
"""
Optional; if missing defaults to an empty dict.
"""
@property
def type(self) -> str:
"""Always `http`"""
"""Always ``http``"""
return "http"
class WebSocket:
@dataclass
class WebSocket(HTTPConnection, abc.ABC):
"""WebSocket 连接封装。参考 `asgi websocket scope`_。
.. _asgi websocket scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope
"""
@abc.abstractmethod
def __init__(self, websocket):
"""
:参数:
* ``websocket: Any``: WebSocket 连接对象
"""
self._websocket = websocket
@property
def websocket(self):
"""WebSocket 连接对象"""
return self._websocket
def type(self) -> str:
"""Always ``websocket``"""
return "websocket"
@property
@abc.abstractmethod
@ -389,11 +347,21 @@ class WebSocket:
raise NotImplementedError
@abc.abstractmethod
async def receive(self) -> dict:
"""接收一条 WebSocket 信息"""
async def receive(self) -> str:
"""接收一条 WebSocket text 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send(self, data: dict):
"""发送一条 WebSocket 信息"""
async def receive_bytes(self) -> bytes:
"""接收一条 WebSocket binary 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send(self, data: str):
"""发送一条 WebSocket text 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send_bytes(self, data: bytes):
"""发送一条 WebSocket text 信息"""
raise NotImplementedError

View File

@ -8,23 +8,22 @@ FastAPI 驱动适配
https://fastapi.tiangolo.com/
"""
import json
import asyncio
import logging
from dataclasses import dataclass
from typing import List, Optional, Callable
import uvicorn
from pydantic import BaseSettings
from fastapi.responses import Response
from fastapi import status, Request, FastAPI, HTTPException
from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket
from starlette.websockets import (WebSocketState, WebSocketDisconnect, WebSocket
as FastAPIWebSocket)
from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.utils import DataclassEncoder
from nonebot.exception import RequestDenied
from nonebot.config import Env, Config as NoneBotConfig
from nonebot.drivers import ReverseDriver, WebSocket as BaseWebSocket
from nonebot.drivers import ReverseDriver, HTTPRequest, WebSocket as BaseWebSocket
class Config(BaseSettings):
@ -179,11 +178,6 @@ class Driver(ReverseDriver):
@overrides(ReverseDriver)
async def _handle_http(self, adapter: str, request: Request):
data = await request.body()
data_dict = json.loads(data.decode())
if not isinstance(data_dict, dict):
logger.warning("Data received is invalid")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
if adapter not in self._adapters:
logger.warning(
@ -194,27 +188,34 @@ class Driver(ReverseDriver):
# 创建 Bot 对象
BotClass = self._adapters[adapter]
headers = dict(request.headers)
try:
x_self_id = await BotClass.check_permission(self, "http", headers,
data)
except RequestDenied as e:
raise HTTPException(status_code=e.status_code,
detail=e.reason) from None
http_request = HTTPRequest(request.scope["http_version"],
request.url.scheme, request.url.path,
request.scope["query_string"],
dict(request.headers), request.method, data)
x_self_id, response = await BotClass.check_permission(
self, http_request)
if not x_self_id:
raise HTTPException(response and response.status or 401,
response.body)
if x_self_id in self._clients:
logger.warning("There's already a reverse websocket connection,"
"so the event may be handled twice.")
bot = BotClass("http", x_self_id)
bot = BotClass(x_self_id, http_request)
asyncio.create_task(bot.handle_message(data_dict))
return Response("", 204)
asyncio.create_task(bot.handle_message(data))
return Response(response and response.body,
response and response.status or 200)
@overrides(ReverseDriver)
async def _handle_ws_reverse(self, adapter: str,
websocket: FastAPIWebSocket):
ws = WebSocket(websocket)
ws = WebSocket(websocket.scope.get("http_version",
"1.1"), websocket.url.scheme,
websocket.url.path, websocket.scope["query_string"],
dict(websocket.headers), websocket)
if adapter not in self._adapters:
logger.warning(
@ -225,11 +226,9 @@ class Driver(ReverseDriver):
# Create Bot Object
BotClass = self._adapters[adapter]
headers = dict(websocket.headers)
try:
x_self_id = await BotClass.check_permission(self, "websocket",
headers, None)
except RequestDenied:
x_self_id, _ = await BotClass.check_permission(self, ws)
if not x_self_id:
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
@ -240,7 +239,7 @@ class Driver(ReverseDriver):
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
bot = BotClass("websocket", x_self_id, websocket=ws)
bot = BotClass(x_self_id, ws)
await ws.accept()
logger.opt(colors=True).info(
@ -251,54 +250,51 @@ class Driver(ReverseDriver):
try:
while not ws.closed:
data = await ws.receive()
try:
data = await ws.receive()
except WebSocketDisconnect:
logger.error("WebSocket disconnected by peer.")
break
except Exception as e:
logger.opt(exception=e).error(
"Error when receiving data from websocket.")
break
if not data:
continue
asyncio.create_task(bot.handle_message(data))
asyncio.create_task(bot.handle_message(data.encode()))
finally:
self._bot_disconnect(bot)
@dataclass
class WebSocket(BaseWebSocket):
def __init__(self, websocket: FastAPIWebSocket):
super().__init__(websocket)
self._closed = False
websocket: FastAPIWebSocket = None # type: ignore
@property
@overrides(BaseWebSocket)
def closed(self):
return self._closed
return (self.websocket.client_state == WebSocketState.DISCONNECTED or
self.websocket.application_state == WebSocketState.DISCONNECTED)
@overrides(BaseWebSocket)
async def accept(self):
await self.websocket.accept()
self._closed = False
@overrides(BaseWebSocket)
async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE):
await self.websocket.close(code=code)
self._closed = True
@overrides(BaseWebSocket)
async def receive(self) -> Optional[dict]:
data = None
try:
data = await self.websocket.receive_json()
if not isinstance(data, dict):
data = None
raise ValueError
except ValueError:
logger.warning("Received an invalid json message.")
except WebSocketDisconnect:
self._closed = True
logger.error("WebSocket disconnected by peer.")
return data
async def receive(self) -> str:
return await self.websocket.receive_text()
@overrides(BaseWebSocket)
async def send(self, data: dict) -> None:
text = json.dumps(data, cls=DataclassEncoder)
await self.websocket.send({"type": "websocket.send", "text": text})
async def receive_bytes(self) -> bytes:
return await self.websocket.receive_bytes()
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
await self.websocket.send({"type": "websocket.send", "text": data})
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
await self.websocket.send({"type": "websocket.send", "bytes": data})

View File

@ -9,24 +9,22 @@ Quart 驱动适配
"""
import asyncio
from json.decoder import JSONDecodeError
from typing import Any, Callable, Coroutine, Dict, Optional, Type, TypeVar
from typing import List, TypeVar, Callable, Coroutine, Optional
import uvicorn
from pydantic import BaseSettings
from nonebot.config import Config as NoneBotConfig
from nonebot.config import Env
from nonebot.drivers import ReverseDriver, WebSocket as BaseWebSocket
from nonebot.exception import RequestDenied
from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.config import Env, Config as NoneBotConfig
from nonebot.drivers import ReverseDriver, HTTPRequest, WebSocket as BaseWebSocket
try:
from quart import Quart, Request, Response
from quart import Websocket as QuartWebSocket
from quart import exceptions
from quart import request as _request
from quart import websocket as _websocket
from quart import Quart, Request, Response
from quart import Websocket as QuartWebSocket
except ImportError:
raise ValueError(
'Please install Quart by using `pip install nonebot2[quart]`')
@ -34,6 +32,25 @@ except ImportError:
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
class Config(BaseSettings):
"""
Quart 驱动框架设置
"""
quart_reload_dirs: List[str] = []
"""
:类型:
``List[str]``
:说明:
``debug`` 模式下重载监控文件夹列表,默认为 uvicorn 默认值
"""
class Config:
extra = "ignore"
class Driver(ReverseDriver):
"""
Quart 驱动框架
@ -48,18 +65,20 @@ class Driver(ReverseDriver):
def __init__(self, env: Env, config: NoneBotConfig):
super().__init__(env, config)
self.quart_config = Config(**config.dict())
self._server_app = Quart(self.__class__.__qualname__)
self._server_app.add_url_rule('/<adapter>/http',
methods=['POST'],
self._server_app.add_url_rule("/<adapter>/http",
methods=["POST"],
view_func=self._handle_http)
self._server_app.add_websocket('/<adapter>/ws',
self._server_app.add_websocket("/<adapter>/ws",
view_func=self._handle_ws_reverse)
@property
@overrides(ReverseDriver)
def type(self) -> str:
"""驱动名称: ``quart``"""
return 'quart'
return "quart"
@property
@overrides(ReverseDriver)
@ -76,17 +95,21 @@ class Driver(ReverseDriver):
@property
@overrides(ReverseDriver)
def logger(self):
"""fastapi 使用的 logger"""
"""Quart 使用的 logger"""
return self._server_app.logger
@overrides(ReverseDriver)
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: `Startup and Shutdown <https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html>`_"""
"""参考文档: `Startup and Shutdown`_
.. _Startup and Shutdown:
https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html
"""
return self.server_app.before_serving(func) # type: ignore
@overrides(ReverseDriver)
def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: `Startup and Shutdown <https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html>`_"""
"""参考文档: `Startup and Shutdown`_"""
return self.server_app.after_serving(func) # type: ignore
@overrides(ReverseDriver)
@ -121,6 +144,7 @@ class Driver(ReverseDriver):
host=host or str(self.config.host),
port=port or self.config.port,
reload=bool(app) and self.config.debug,
reload_dirs=self.quart_config.quart_reload_dirs or None,
debug=self.config.debug,
log_config=LOGGING_CONFIG,
**kwargs)
@ -128,11 +152,7 @@ class Driver(ReverseDriver):
@overrides(ReverseDriver)
async def _handle_http(self, adapter: str):
request: Request = _request
try:
data: Dict[str, Any] = await request.get_json()
except Exception as e:
raise exceptions.BadRequest()
data: bytes = await request.get_data() # type: ignore
if adapter not in self._adapters:
logger.warning(f'Unknown adapter {adapter}. '
@ -140,25 +160,32 @@ class Driver(ReverseDriver):
raise exceptions.NotFound()
BotClass = self._adapters[adapter]
headers = {k: v for k, v in request.headers.items(lower=True)}
http_request = HTTPRequest(request.http_version, request.scheme,
request.path, request.query_string,
dict(request.headers), request.method, data)
try:
self_id = await BotClass.check_permission(self, 'http', headers,
data)
except RequestDenied as e:
raise exceptions.HTTPException(status_code=e.status_code,
description=e.reason,
name='Request Denied')
self_id, response = await BotClass.check_permission(self, http_request)
if not self_id:
raise exceptions.HTTPException(
response and response.status or 401,
description=(response and response.body or b"").decode(),
name="Request Denied")
if self_id in self._clients:
logger.warning("There's already a reverse websocket connection,"
"so the event may be handled twice.")
bot = BotClass('http', self_id)
bot = BotClass(self_id, http_request)
asyncio.create_task(bot.handle_message(data))
return Response('', 204)
return Response(response and response.body or "",
response and response.status or 200)
@overrides(ReverseDriver)
async def _handle_ws_reverse(self, adapter: str):
websocket: QuartWebSocket = _websocket
ws = WebSocket(websocket.http_version, websocket.scheme,
websocket.path, websocket.query_string,
dict(websocket.headers), websocket)
if adapter not in self._adapters:
logger.warning(
f'Unknown adapter {adapter}. Please register the adapter before use.'
@ -166,19 +193,23 @@ class Driver(ReverseDriver):
raise exceptions.NotFound()
BotClass = self._adapters[adapter]
headers = {k: v for k, v in websocket.headers.items(lower=True)}
try:
self_id = await BotClass.check_permission(self, 'websocket',
headers, None)
except RequestDenied as e:
raise exceptions.HTTPException(status_code=e.status_code,
description=e.reason,
name='Request Denied')
self_id, response = await BotClass.check_permission(self, ws)
if not self_id:
raise exceptions.HTTPException(
response and response.status or 401,
description=(response and response.body or b"").decode(),
name="Request Denied")
if self_id in self._clients:
logger.warning("There's already a reverse websocket connection,"
"so the event may be handled twice.")
ws = WebSocket(websocket)
bot = BotClass('websocket', self_id, websocket=ws)
logger.opt(colors=True).warning(
"There's already a reverse websocket connection, "
f"<y>{adapter.upper()} Bot {self_id}</y> ignored.")
raise exceptions.HTTPException(403,
description="Client already exists",
name="Request Denied")
bot = BotClass(self_id, ws)
await ws.accept()
logger.opt(colors=True).info(
f"WebSocket Connection from <y>{adapter.upper()} "
@ -187,52 +218,51 @@ class Driver(ReverseDriver):
try:
while not ws.closed:
data = await ws.receive()
if data is None:
continue
asyncio.create_task(bot.handle_message(data))
try:
data = await ws.receive()
except asyncio.CancelledError:
logger.warning("WebSocket disconnected by peer.")
break
except Exception as e:
logger.opt(exception=e).error(
"Error when receiving data from websocket.")
break
asyncio.create_task(bot.handle_message(data.encode()))
finally:
self._bot_disconnect(bot)
class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
def __init__(self, websocket: QuartWebSocket):
super().__init__(websocket)
self._closed = False
@property
@overrides(BaseWebSocket)
def websocket(self) -> QuartWebSocket:
return self._websocket
websocket: QuartWebSocket = None # type: ignore
@property
@overrides(BaseWebSocket)
def closed(self):
return self._closed
# FIXME
return False
@overrides(BaseWebSocket)
async def accept(self):
await self.websocket.accept()
self._closed = False
@overrides(BaseWebSocket)
async def close(self):
self._closed = True
# FIXME
pass
@overrides(BaseWebSocket)
async def receive(self) -> Optional[Dict[str, Any]]:
data: Optional[Dict[str, Any]] = None
try:
data = await self.websocket.receive_json()
except JSONDecodeError:
logger.warning('Received an invalid json message.')
except asyncio.CancelledError:
self._closed = True
logger.warning('WebSocket disconnected by peer.')
return data
async def receive(self) -> str:
return await self.websocket.receive() # type: ignore
@overrides(BaseWebSocket)
async def send(self, data: dict):
await self.websocket.send_json(data)
async def receive_bytes(self) -> bytes:
return await self.websocket.receive() # type: ignore
@overrides(BaseWebSocket)
async def send(self, data: str):
await self.websocket.send(data)
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes):
await self.websocket.send(data)

View File

@ -115,29 +115,6 @@ class StopPropagation(NoneBotException):
pass
class RequestDenied(NoneBotException):
"""
:说明:
Bot 连接请求不合法。
:参数:
* ``status_code: int``: HTTP 状态码
* ``reason: str``: 拒绝原因
"""
def __init__(self, status_code: int, reason: str):
self.status_code = status_code
self.reason = reason
def __repr__(self):
return f"<RequestDenied, status_code={self.status_code}, reason={self.reason}>"
def __str__(self):
return self.__repr__()
class AdapterException(NoneBotException):
"""
:说明: