support custom response

This commit is contained in:
StarHeartHunt
2021-06-10 21:52:20 +08:00
committed by yanyongyu
parent ca31ec5fe3
commit c0d78449be
25 changed files with 365 additions and 7542 deletions

View File

@ -8,23 +8,22 @@ FastAPI 驱动适配
https://fastapi.tiangolo.com/
"""
import json
import asyncio
import logging
from dataclasses import dataclass
from typing import List, Optional, Callable
import uvicorn
from pydantic import BaseSettings
from fastapi.responses import Response
from fastapi import status, Request, FastAPI, HTTPException
from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket
from starlette.websockets import (WebSocketState, WebSocketDisconnect, WebSocket
as FastAPIWebSocket)
from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.utils import DataclassEncoder
from nonebot.exception import RequestDenied
from nonebot.config import Env, Config as NoneBotConfig
from nonebot.drivers import ReverseDriver, WebSocket as BaseWebSocket
from nonebot.drivers import ReverseDriver, HTTPRequest, WebSocket as BaseWebSocket
class Config(BaseSettings):
@ -179,11 +178,6 @@ class Driver(ReverseDriver):
@overrides(ReverseDriver)
async def _handle_http(self, adapter: str, request: Request):
data = await request.body()
data_dict = json.loads(data.decode())
if not isinstance(data_dict, dict):
logger.warning("Data received is invalid")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
if adapter not in self._adapters:
logger.warning(
@ -194,27 +188,34 @@ class Driver(ReverseDriver):
# 创建 Bot 对象
BotClass = self._adapters[adapter]
headers = dict(request.headers)
try:
x_self_id = await BotClass.check_permission(self, "http", headers,
data)
except RequestDenied as e:
raise HTTPException(status_code=e.status_code,
detail=e.reason) from None
http_request = HTTPRequest(request.scope["http_version"],
request.url.scheme, request.url.path,
request.scope["query_string"],
dict(request.headers), request.method, data)
x_self_id, response = await BotClass.check_permission(
self, http_request)
if not x_self_id:
raise HTTPException(response and response.status or 401,
response.body)
if x_self_id in self._clients:
logger.warning("There's already a reverse websocket connection,"
"so the event may be handled twice.")
bot = BotClass("http", x_self_id)
bot = BotClass(x_self_id, http_request)
asyncio.create_task(bot.handle_message(data_dict))
return Response("", 204)
asyncio.create_task(bot.handle_message(data))
return Response(response and response.body,
response and response.status or 200)
@overrides(ReverseDriver)
async def _handle_ws_reverse(self, adapter: str,
websocket: FastAPIWebSocket):
ws = WebSocket(websocket)
ws = WebSocket(websocket.scope.get("http_version",
"1.1"), websocket.url.scheme,
websocket.url.path, websocket.scope["query_string"],
dict(websocket.headers), websocket)
if adapter not in self._adapters:
logger.warning(
@ -225,11 +226,9 @@ class Driver(ReverseDriver):
# Create Bot Object
BotClass = self._adapters[adapter]
headers = dict(websocket.headers)
try:
x_self_id = await BotClass.check_permission(self, "websocket",
headers, None)
except RequestDenied:
x_self_id, _ = await BotClass.check_permission(self, ws)
if not x_self_id:
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
@ -240,7 +239,7 @@ class Driver(ReverseDriver):
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
bot = BotClass("websocket", x_self_id, websocket=ws)
bot = BotClass(x_self_id, ws)
await ws.accept()
logger.opt(colors=True).info(
@ -251,54 +250,51 @@ class Driver(ReverseDriver):
try:
while not ws.closed:
data = await ws.receive()
try:
data = await ws.receive()
except WebSocketDisconnect:
logger.error("WebSocket disconnected by peer.")
break
except Exception as e:
logger.opt(exception=e).error(
"Error when receiving data from websocket.")
break
if not data:
continue
asyncio.create_task(bot.handle_message(data))
asyncio.create_task(bot.handle_message(data.encode()))
finally:
self._bot_disconnect(bot)
@dataclass
class WebSocket(BaseWebSocket):
def __init__(self, websocket: FastAPIWebSocket):
super().__init__(websocket)
self._closed = False
websocket: FastAPIWebSocket = None # type: ignore
@property
@overrides(BaseWebSocket)
def closed(self):
return self._closed
return (self.websocket.client_state == WebSocketState.DISCONNECTED or
self.websocket.application_state == WebSocketState.DISCONNECTED)
@overrides(BaseWebSocket)
async def accept(self):
await self.websocket.accept()
self._closed = False
@overrides(BaseWebSocket)
async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE):
await self.websocket.close(code=code)
self._closed = True
@overrides(BaseWebSocket)
async def receive(self) -> Optional[dict]:
data = None
try:
data = await self.websocket.receive_json()
if not isinstance(data, dict):
data = None
raise ValueError
except ValueError:
logger.warning("Received an invalid json message.")
except WebSocketDisconnect:
self._closed = True
logger.error("WebSocket disconnected by peer.")
return data
async def receive(self) -> str:
return await self.websocket.receive_text()
@overrides(BaseWebSocket)
async def send(self, data: dict) -> None:
text = json.dumps(data, cls=DataclassEncoder)
await self.websocket.send({"type": "websocket.send", "text": text})
async def receive_bytes(self) -> bytes:
return await self.websocket.receive_bytes()
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
await self.websocket.send({"type": "websocket.send", "text": data})
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
await self.websocket.send({"type": "websocket.send", "bytes": data})