♻️ rewrite adapter abc class

This commit is contained in:
yanyongyu
2021-12-06 22:19:05 +08:00
parent 180aaadda9
commit d80c02ae46
7 changed files with 172 additions and 437 deletions

View File

@ -12,38 +12,35 @@ FastAPI 驱动适配
import asyncio
import logging
from functools import partial
from dataclasses import dataclass
from typing import List, Union, TypeVar, Callable, Optional, Awaitable, cast
from typing import Any, List, Union, Callable, Optional, Awaitable
import httpx
import uvicorn
from pydantic import BaseSettings
from fastapi.responses import Response
from websockets.exceptions import ConnectionClosed
from fastapi import FastAPI, Request, HTTPException, status
from starlette.websockets import WebSocketState
from fastapi import Depends, FastAPI, Request, status
from starlette.websockets import WebSocket as FastAPIWebSocket
from starlette.websockets import WebSocketState, WebSocketDisconnect
from websockets.legacy.client import Connect, WebSocketClientProtocol
from nonebot.config import Env
from nonebot.log import logger
from nonebot.adapters import Bot
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.drivers import WebSocket
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import (
HTTPRequest,
HTTPResponse,
ForwardDriver,
ReverseDriver,
WebSocketSetup,
HTTPPollingSetup,
HTTPConnection,
HTTPServerSetup,
WebSocketServerSetup,
)
S = TypeVar("S", bound=Union[HTTPPollingSetup, WebSocketSetup])
HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
class Config(BaseSettings):
"""
@ -136,16 +133,7 @@ class Config(BaseSettings):
class Driver(ReverseDriver):
"""
FastAPI 驱动框架。包含反向 Server 功能。
:上报地址:
* ``/{adapter name}/``: HTTP POST 上报
* ``/{adapter name}/http/``: HTTP POST 上报
* ``/{adapter name}/ws``: WebSocket 上报
* ``/{adapter name}/ws/``: WebSocket 上报
"""
"""FastAPI 驱动框架。包含反向 Server 功能。"""
def __init__(self, env: Env, config: NoneBotConfig):
super(Driver, self).__init__(env, config)
@ -159,11 +147,6 @@ class Driver(ReverseDriver):
redoc_url=self.fastapi_config.fastapi_redoc_url,
)
self._server_app.post("/{adapter}/")(self._handle_http)
self._server_app.post("/{adapter}/http")(self._handle_http)
self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse)
self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse)
@property
@overrides(ReverseDriver)
def type(self) -> str:
@ -188,6 +171,30 @@ class Driver(ReverseDriver):
"""fastapi 使用的 logger"""
return logging.getLogger("fastapi")
@overrides(ReverseDriver)
def setup_http_server(self, setup: HTTPServerSetup):
def _get_handle_func():
return setup.handle_func
self._server_app.add_api_route(
setup.path,
partial(self._handle_http, handle_func=Depends(_get_handle_func)),
methods=[setup.method],
)
@overrides(ReverseDriver)
def setup_websocket_server(self, setup: WebSocketServerSetup) -> None:
def _get_handle_func():
return setup.handle_func
self._server_app.add_api_websocket_route(
setup.path,
partial(
self._handle_ws,
handle_func=Depends(_get_handle_func),
),
)
@overrides(ReverseDriver)
def on_startup(self, func: Callable) -> Callable:
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_"""
@ -241,19 +248,11 @@ class Driver(ReverseDriver):
**kwargs,
)
async def _handle_http(self, adapter: str, request: Request):
data = await request.body()
if adapter not in self._adapters:
logger.warning(
f"Unknown adapter {adapter}. Please register the adapter before use."
)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="adapter not found"
)
# 创建 Bot 对象
BotClass = self._adapters[adapter]
async def _handle_http(
self,
request: Request,
handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]],
):
http_request = HTTPRequest(
request.scope["http_version"],
request.url.scheme,
@ -261,28 +260,17 @@ class Driver(ReverseDriver):
request.scope["query_string"],
dict(request.headers),
request.method,
data,
await request.body(),
)
x_self_id, response = await BotClass.check_permission(self, http_request)
if not x_self_id:
raise HTTPException(
response and response.status or 401,
response and response.body and response.body.decode("utf-8"),
)
response = await handle_func(http_request)
return Response(response.body, response.status, response.headers)
if x_self_id in self._clients:
logger.warning(
"There's already a reverse websocket connection,"
"so the event may be handled twice."
)
bot = BotClass(x_self_id, http_request)
asyncio.create_task(bot.handle_message(data))
return Response(response and response.body, response and response.status or 200)
async def _handle_ws_reverse(self, adapter: str, websocket: FastAPIWebSocket):
async def _handle_ws(
self,
websocket: FastAPIWebSocket,
handle_func: Callable[[WebSocket], Awaitable[Any]],
):
ws = WebSocket(
websocket.scope.get("http_version", "1.1"),
websocket.url.scheme,
@ -292,55 +280,7 @@ class Driver(ReverseDriver):
websocket,
)
if adapter not in self._adapters:
logger.warning(
f"Unknown adapter {adapter}. Please register the adapter before use."
)
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
# Create Bot Object
BotClass = self._adapters[adapter]
self_id, _ = await BotClass.check_permission(self, ws)
if not self_id:
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
if self_id in self._clients:
logger.opt(colors=True).warning(
"There's already a websocket connection, "
f"<y>{escape_tag(adapter.upper())} Bot {escape_tag(self_id)}</y> ignored."
)
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
bot = BotClass(self_id, ws)
await ws.accept()
logger.opt(colors=True).info(
f"WebSocket Connection from <y>{escape_tag(adapter.upper())} "
f"Bot {escape_tag(self_id)}</y> Accepted!"
)
self._bot_connect(bot)
try:
while not ws.closed:
try:
data = await ws.receive()
except WebSocketDisconnect:
logger.error("WebSocket disconnected by peer.")
break
except Exception as e:
logger.opt(exception=e).error(
"Error when receiving data from websocket."
)
break
asyncio.create_task(bot.handle_message(data.encode()))
finally:
self._bot_disconnect(bot)
await handle_func(ws)
class FullDriver(ForwardDriver, Driver):
@ -354,17 +294,6 @@ class FullDriver(ForwardDriver, Driver):
DRIVER=nonebot.drivers.fastapi:FullDriver
"""
def __init__(self, env: Env, config: NoneBotConfig):
super(FullDriver, self).__init__(env, config)
self.http_pollings: List[HTTPPOLLING_SETUP] = []
self.websockets: List[WEBSOCKET_SETUP] = []
self.shutdown: asyncio.Event = asyncio.Event()
self.connections: List[asyncio.Task] = []
self.on_startup(self._run_forward)
self.on_shutdown(self._shutdown_forward)
@property
@overrides(ForwardDriver)
def type(self) -> str:
@ -372,217 +301,25 @@ class FullDriver(ForwardDriver, Driver):
return "fastapi_full"
@overrides(ForwardDriver)
def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
"""
:说明:
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
:参数:
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
"""
self.http_pollings.append(setup)
async def request(self, setup: "HTTPRequest") -> Any:
async with httpx.AsyncClient(
http2=setup.http_version == "2", follow_redirects=True
) as client:
response = await client.request(
setup.method,
setup.url,
content=setup.body,
headers=setup.headers,
timeout=30.0,
)
return HTTPResponse(
response.status_code, response.content, response.headers
)
@overrides(ForwardDriver)
def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
"""
:说明:
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
:参数:
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
"""
self.websockets.append(setup)
def _run_forward(self):
for setup in self.http_pollings:
self.connections.append(asyncio.create_task(self._http_loop(setup)))
for setup in self.websockets:
self.connections.append(asyncio.create_task(self._ws_loop(setup)))
def _shutdown_forward(self):
self.shutdown.set()
for task in self.connections:
if not task.done():
task.cancel()
async def _prepare_setup(
self, setup: Union[S, Callable[[], Awaitable[S]]]
) -> Optional[S]:
try:
if callable(setup):
return await setup()
else:
return setup
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
return
def _build_http_request(self, setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
url = httpx.URL(setup.url)
if not url.netloc:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
)
return
return HTTPRequest(
setup.http_version,
url.scheme,
url.path,
url.query,
setup.headers,
setup.method,
setup.body,
)
async def _http_loop(self, _setup: HTTPPOLLING_SETUP):
http2: bool = False
bot: Optional[Bot] = None
request: Optional[HTTPRequest] = None
client: Optional[httpx.AsyncClient] = None
# FIXME: seperate const values from setup (self_id, adapter)
# logger.opt(colors=True).info(
# f"Start http polling for <y>{escape_tag(_setup.adapter.upper())} "
# f"Bot {escape_tag(_setup.self_id)}</y>"
# )
try:
while not self.shutdown.is_set():
setup = await self._prepare_setup(_setup)
if not setup:
await asyncio.sleep(3)
continue
request = self._build_http_request(setup)
if not request:
await asyncio.sleep(setup.poll_interval)
continue
if not client:
client = httpx.AsyncClient(http2=setup.http_version == "2", follow_redirects=True)
elif http2 != (setup.http_version == "2"):
await client.aclose()
client = httpx.AsyncClient(http2=setup.http_version == "2", follow_redirects=True)
http2 = setup.http_version == "2"
if not bot:
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
else:
bot.request = request
logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} request {setup.url}"
)
try:
response = await client.request(
request.method,
setup.url,
content=request.body,
headers=request.headers,
timeout=30.0,
)
response.raise_for_status()
data = response.read()
asyncio.create_task(bot.handle_message(data))
except httpx.HTTPError as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup.url)}. "
"Try to reconnect...</bg #f8bbd0></r>"
)
await asyncio.sleep(setup.poll_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred "
"while http polling</bg #f8bbd0></r>"
)
finally:
if bot:
self._bot_disconnect(bot)
if client:
await client.aclose()
async def _ws_loop(self, _setup: WEBSOCKET_SETUP):
bot: Optional[Bot] = None
try:
while True:
setup = await self._prepare_setup(_setup)
if not setup:
await asyncio.sleep(3)
continue
url = httpx.URL(setup.url)
if not url.netloc:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
)
return
logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
)
try:
connection = Connect(setup.url, extra_headers=setup.headers)
async with connection as ws:
logger.opt(colors=True).info(
f"WebSocket Connection to <y>{escape_tag(setup.adapter.upper())} "
f"Bot {escape_tag(setup.self_id)}</y> succeeded!"
)
request = WebSocket(
"1.1", url.scheme, url.path, url.query, setup.headers, ws
)
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
while not self.shutdown.is_set():
# use try except instead of "request.closed" because of queued message
try:
msg = await request.receive_bytes()
asyncio.create_task(bot.handle_message(msg))
except ConnectionClosed:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>WebSocket connection closed. "
"Try to reconnect...</bg #f8bbd0></r>"
)
break
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error while connecting to {url}. "
"Try to reconnect...</bg #f8bbd0></r>"
)
finally:
if bot:
self._bot_disconnect(bot)
bot = None
if not setup.reconnect:
logger.info(f"WebSocket reconnect disabled for bot {setup.self_id}")
break
await asyncio.sleep(setup.reconnect_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred "
"while websocket loop</bg #f8bbd0></r>"
)
async def websocket(self, setup: "HTTPConnection") -> Any:
ws = await Connect(setup.url, extra_headers=setup.headers)
return WebSocket("1.1", url.scheme, url.path, url.query, setup.headers, ws)
@dataclass