support custom response

This commit is contained in:
StarHeartHunt
2021-06-10 21:52:20 +08:00
committed by yanyongyu
parent ca31ec5fe3
commit c0d78449be
25 changed files with 365 additions and 7542 deletions

View File

@ -9,24 +9,22 @@ Quart 驱动适配
"""
import asyncio
from json.decoder import JSONDecodeError
from typing import Any, Callable, Coroutine, Dict, Optional, Type, TypeVar
from typing import List, TypeVar, Callable, Coroutine, Optional
import uvicorn
from pydantic import BaseSettings
from nonebot.config import Config as NoneBotConfig
from nonebot.config import Env
from nonebot.drivers import ReverseDriver, WebSocket as BaseWebSocket
from nonebot.exception import RequestDenied
from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.config import Env, Config as NoneBotConfig
from nonebot.drivers import ReverseDriver, HTTPRequest, WebSocket as BaseWebSocket
try:
from quart import Quart, Request, Response
from quart import Websocket as QuartWebSocket
from quart import exceptions
from quart import request as _request
from quart import websocket as _websocket
from quart import Quart, Request, Response
from quart import Websocket as QuartWebSocket
except ImportError:
raise ValueError(
'Please install Quart by using `pip install nonebot2[quart]`')
@ -34,6 +32,25 @@ except ImportError:
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
class Config(BaseSettings):
"""
Quart 驱动框架设置
"""
quart_reload_dirs: List[str] = []
"""
:类型:
``List[str]``
:说明:
``debug`` 模式下重载监控文件夹列表,默认为 uvicorn 默认值
"""
class Config:
extra = "ignore"
class Driver(ReverseDriver):
"""
Quart 驱动框架
@ -48,18 +65,20 @@ class Driver(ReverseDriver):
def __init__(self, env: Env, config: NoneBotConfig):
super().__init__(env, config)
self.quart_config = Config(**config.dict())
self._server_app = Quart(self.__class__.__qualname__)
self._server_app.add_url_rule('/<adapter>/http',
methods=['POST'],
self._server_app.add_url_rule("/<adapter>/http",
methods=["POST"],
view_func=self._handle_http)
self._server_app.add_websocket('/<adapter>/ws',
self._server_app.add_websocket("/<adapter>/ws",
view_func=self._handle_ws_reverse)
@property
@overrides(ReverseDriver)
def type(self) -> str:
"""驱动名称: ``quart``"""
return 'quart'
return "quart"
@property
@overrides(ReverseDriver)
@ -76,17 +95,21 @@ class Driver(ReverseDriver):
@property
@overrides(ReverseDriver)
def logger(self):
"""fastapi 使用的 logger"""
"""Quart 使用的 logger"""
return self._server_app.logger
@overrides(ReverseDriver)
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: `Startup and Shutdown <https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html>`_"""
"""参考文档: `Startup and Shutdown`_
.. _Startup and Shutdown:
https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html
"""
return self.server_app.before_serving(func) # type: ignore
@overrides(ReverseDriver)
def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: `Startup and Shutdown <https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html>`_"""
"""参考文档: `Startup and Shutdown`_"""
return self.server_app.after_serving(func) # type: ignore
@overrides(ReverseDriver)
@ -121,6 +144,7 @@ class Driver(ReverseDriver):
host=host or str(self.config.host),
port=port or self.config.port,
reload=bool(app) and self.config.debug,
reload_dirs=self.quart_config.quart_reload_dirs or None,
debug=self.config.debug,
log_config=LOGGING_CONFIG,
**kwargs)
@ -128,11 +152,7 @@ class Driver(ReverseDriver):
@overrides(ReverseDriver)
async def _handle_http(self, adapter: str):
request: Request = _request
try:
data: Dict[str, Any] = await request.get_json()
except Exception as e:
raise exceptions.BadRequest()
data: bytes = await request.get_data() # type: ignore
if adapter not in self._adapters:
logger.warning(f'Unknown adapter {adapter}. '
@ -140,25 +160,32 @@ class Driver(ReverseDriver):
raise exceptions.NotFound()
BotClass = self._adapters[adapter]
headers = {k: v for k, v in request.headers.items(lower=True)}
http_request = HTTPRequest(request.http_version, request.scheme,
request.path, request.query_string,
dict(request.headers), request.method, data)
try:
self_id = await BotClass.check_permission(self, 'http', headers,
data)
except RequestDenied as e:
raise exceptions.HTTPException(status_code=e.status_code,
description=e.reason,
name='Request Denied')
self_id, response = await BotClass.check_permission(self, http_request)
if not self_id:
raise exceptions.HTTPException(
response and response.status or 401,
description=(response and response.body or b"").decode(),
name="Request Denied")
if self_id in self._clients:
logger.warning("There's already a reverse websocket connection,"
"so the event may be handled twice.")
bot = BotClass('http', self_id)
bot = BotClass(self_id, http_request)
asyncio.create_task(bot.handle_message(data))
return Response('', 204)
return Response(response and response.body or "",
response and response.status or 200)
@overrides(ReverseDriver)
async def _handle_ws_reverse(self, adapter: str):
websocket: QuartWebSocket = _websocket
ws = WebSocket(websocket.http_version, websocket.scheme,
websocket.path, websocket.query_string,
dict(websocket.headers), websocket)
if adapter not in self._adapters:
logger.warning(
f'Unknown adapter {adapter}. Please register the adapter before use.'
@ -166,19 +193,23 @@ class Driver(ReverseDriver):
raise exceptions.NotFound()
BotClass = self._adapters[adapter]
headers = {k: v for k, v in websocket.headers.items(lower=True)}
try:
self_id = await BotClass.check_permission(self, 'websocket',
headers, None)
except RequestDenied as e:
raise exceptions.HTTPException(status_code=e.status_code,
description=e.reason,
name='Request Denied')
self_id, response = await BotClass.check_permission(self, ws)
if not self_id:
raise exceptions.HTTPException(
response and response.status or 401,
description=(response and response.body or b"").decode(),
name="Request Denied")
if self_id in self._clients:
logger.warning("There's already a reverse websocket connection,"
"so the event may be handled twice.")
ws = WebSocket(websocket)
bot = BotClass('websocket', self_id, websocket=ws)
logger.opt(colors=True).warning(
"There's already a reverse websocket connection, "
f"<y>{adapter.upper()} Bot {self_id}</y> ignored.")
raise exceptions.HTTPException(403,
description="Client already exists",
name="Request Denied")
bot = BotClass(self_id, ws)
await ws.accept()
logger.opt(colors=True).info(
f"WebSocket Connection from <y>{adapter.upper()} "
@ -187,52 +218,51 @@ class Driver(ReverseDriver):
try:
while not ws.closed:
data = await ws.receive()
if data is None:
continue
asyncio.create_task(bot.handle_message(data))
try:
data = await ws.receive()
except asyncio.CancelledError:
logger.warning("WebSocket disconnected by peer.")
break
except Exception as e:
logger.opt(exception=e).error(
"Error when receiving data from websocket.")
break
asyncio.create_task(bot.handle_message(data.encode()))
finally:
self._bot_disconnect(bot)
class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
def __init__(self, websocket: QuartWebSocket):
super().__init__(websocket)
self._closed = False
@property
@overrides(BaseWebSocket)
def websocket(self) -> QuartWebSocket:
return self._websocket
websocket: QuartWebSocket = None # type: ignore
@property
@overrides(BaseWebSocket)
def closed(self):
return self._closed
# FIXME
return False
@overrides(BaseWebSocket)
async def accept(self):
await self.websocket.accept()
self._closed = False
@overrides(BaseWebSocket)
async def close(self):
self._closed = True
# FIXME
pass
@overrides(BaseWebSocket)
async def receive(self) -> Optional[Dict[str, Any]]:
data: Optional[Dict[str, Any]] = None
try:
data = await self.websocket.receive_json()
except JSONDecodeError:
logger.warning('Received an invalid json message.')
except asyncio.CancelledError:
self._closed = True
logger.warning('WebSocket disconnected by peer.')
return data
async def receive(self) -> str:
return await self.websocket.receive() # type: ignore
@overrides(BaseWebSocket)
async def send(self, data: dict):
await self.websocket.send_json(data)
async def receive_bytes(self) -> bytes:
return await self.websocket.receive() # type: ignore
@overrides(BaseWebSocket)
async def send(self, data: str):
await self.websocket.send(data)
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes):
await self.websocket.send(data)