fastapi driver support forward connect

This commit is contained in:
yanyongyu 2021-07-23 00:51:19 +08:00
parent 54a7e4808c
commit cf747f954c
5 changed files with 244 additions and 33 deletions

View File

@ -212,6 +212,9 @@ class Driver(ForwardDriver):
BotClass = self._adapters[setup.adapter] BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request) bot = BotClass(setup.self_id, request)
self._bot_connect(bot) self._bot_connect(bot)
logger.opt(colors=True).info(
f"Start http polling for <y>{setup.adapter.upper()} "
f"Bot {setup.self_id}</y>")
headers = request.headers headers = request.headers
timeout = aiohttp.ClientTimeout(30) timeout = aiohttp.ClientTimeout(30)
@ -289,11 +292,13 @@ class Driver(ForwardDriver):
) )
try: try:
async with session.ws_connect(url) as ws: async with session.ws_connect(url) as ws:
logger.opt(colors=True).info(
f"WebSocket Connection to <y>{setup.adapter.upper()} "
f"Bot {setup.self_id}</y> succeeded!")
request = WebSocket( request = WebSocket(
setup.http_version, url.scheme, url.path, setup.http_version, url.scheme, url.path,
url.raw_query_string.encode("latin-1"), { url.raw_query_string.encode("latin-1"), headers,
**setup.headers, "host": host ws)
}, ws)
BotClass = self._adapters[setup.adapter] BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request) bot = BotClass(setup.self_id, request)

View File

@ -11,19 +11,46 @@ FastAPI 驱动适配
import asyncio import asyncio
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Callable from typing import List, Dict, Union, Optional, Callable
import httpx
import uvicorn import uvicorn
from pydantic import BaseSettings from pydantic import BaseSettings
from fastapi.responses import Response from fastapi.responses import Response
from websockets.exceptions import ConnectionClosed
from fastapi import status, Request, FastAPI, HTTPException from fastapi import status, Request, FastAPI, HTTPException
from websockets.legacy.client import Connect, WebSocketClientProtocol
from starlette.websockets import (WebSocketState, WebSocketDisconnect, WebSocket from starlette.websockets import (WebSocketState, WebSocketDisconnect, WebSocket
as FastAPIWebSocket) as FastAPIWebSocket)
from nonebot.log import logger from nonebot.log import logger
from nonebot.adapters import Bot
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.config import Env, Config as NoneBotConfig from nonebot.config import Env, Config as NoneBotConfig
from nonebot.drivers import ReverseDriver, HTTPRequest, WebSocket as BaseWebSocket from nonebot.drivers import ReverseDriver, ForwardDriver
from nonebot.drivers import HTTPRequest, WebSocket as BaseWebSocket
@dataclass
class HTTPPollingSetup:
adapter: str
self_id: str
url: str
method: str
body: bytes
headers: Dict[str, str]
http_version: str
poll_interval: float
@dataclass
class WebSocketSetup:
adapter: str
self_id: str
url: str
headers: Dict[str, str]
http_version: str
reconnect_interval: float
class Config(BaseSettings): class Config(BaseSettings):
@ -75,7 +102,7 @@ class Config(BaseSettings):
extra = "ignore" extra = "ignore"
class Driver(ReverseDriver): class Driver(ReverseDriver, ForwardDriver):
""" """
FastAPI 驱动框架 FastAPI 驱动框架
@ -90,7 +117,11 @@ class Driver(ReverseDriver):
def __init__(self, env: Env, config: NoneBotConfig): def __init__(self, env: Env, config: NoneBotConfig):
super().__init__(env, config) super().__init__(env, config)
self.fastapi_config = Config(**config.dict()) self.fastapi_config: Config = Config(**config.dict())
self.http_pollings: List[HTTPPollingSetup] = []
self.websockets: List[WebSocketSetup] = []
self.shutdown: asyncio.Event = asyncio.Event()
self.connections: List[asyncio.Task] = []
self._server_app = FastAPI( self._server_app = FastAPI(
debug=config.debug, debug=config.debug,
@ -104,6 +135,9 @@ class Driver(ReverseDriver):
self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse) self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse)
self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse) self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse)
self.on_startup(self._run_forward)
self.on_shutdown(self._shutdown_forward)
@property @property
@overrides(ReverseDriver) @overrides(ReverseDriver)
def type(self) -> str: def type(self) -> str:
@ -138,6 +172,32 @@ class Driver(ReverseDriver):
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_""" """参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_"""
return self.server_app.on_event("shutdown")(func) return self.server_app.on_event("shutdown")(func)
@overrides(ForwardDriver)
def setup_http_polling(self,
adapter: str,
self_id: str,
url: str,
polling_interval: float = 3.,
method: str = "GET",
body: bytes = b"",
headers: Dict[str, str] = {},
http_version: str = "1.1") -> None:
self.http_pollings.append(
HTTPPollingSetup(adapter, self_id, url, method, body, headers,
http_version, polling_interval))
@overrides(ForwardDriver)
def setup_websocket(self,
adapter: str,
self_id: str,
url: str,
reconnect_interval: float = 3.,
headers: Dict[str, str] = {},
http_version: str = "1.1") -> None:
self.websockets.append(
WebSocketSetup(adapter, self_id, url, headers, http_version,
reconnect_interval))
@overrides(ReverseDriver) @overrides(ReverseDriver)
def run(self, def run(self,
host: Optional[str] = None, host: Optional[str] = None,
@ -166,14 +226,27 @@ class Driver(ReverseDriver):
}, },
}, },
} }
uvicorn.run(app or self.server_app, uvicorn.run(
host=host or str(self.config.host), app or self.server_app, # type: ignore
port=port or self.config.port, host=host or str(self.config.host),
reload=bool(app) and self.config.debug, port=port or self.config.port,
reload_dirs=self.fastapi_config.fastapi_reload_dirs or None, reload=bool(app) and self.config.debug,
debug=self.config.debug, reload_dirs=self.fastapi_config.fastapi_reload_dirs or None,
log_config=LOGGING_CONFIG, debug=self.config.debug,
**kwargs) log_config=LOGGING_CONFIG,
**kwargs)
def _run_forward(self):
for setup in self.http_pollings:
self.connections.append(asyncio.create_task(self._http_loop(setup)))
for setup in self.websockets:
self.connections.append(asyncio.create_task(self._ws_loop(setup)))
def _shutdown_forward(self):
self.shutdown.set()
for task in self.connections:
if not task.done():
task.cancel()
async def _handle_http(self, adapter: str, request: Request): async def _handle_http(self, adapter: str, request: Request):
data = await request.body() data = await request.body()
@ -263,37 +336,166 @@ class Driver(ReverseDriver):
finally: finally:
self._bot_disconnect(bot) self._bot_disconnect(bot)
async def _http_loop(self, setup: HTTPPollingSetup):
url = httpx.URL(setup.url)
if not url.netloc:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
return
request = HTTPRequest(
setup.http_version, url.scheme, url.path, url.query, {
**setup.headers, "host": url.netloc.decode("ascii")
}, setup.method, setup.body)
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
logger.opt(colors=True).info(
f"Start http polling for <y>{setup.adapter.upper()} "
f"Bot {setup.self_id}</y>")
headers = request.headers
http2: bool = False
if request.http_version == "2":
http2 = True
try:
async with httpx.AsyncClient(headers=headers,
timeout=30.,
http2=http2) as session:
while not self.shutdown.is_set():
logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} request {url}"
)
try:
response = await session.request(request.method,
url,
content=request.body)
response.raise_for_status()
data = response.read()
asyncio.create_task(bot.handle_message(data))
except httpx.HTTPError as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error occurred while requesting {url}. "
"Try to reconnect...</bg #f8bbd0></r>")
await asyncio.sleep(setup.poll_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred "
"while http polling</bg #f8bbd0></r>")
finally:
self._bot_disconnect(bot)
async def _ws_loop(self, setup: WebSocketSetup):
url = httpx.URL(setup.url)
if not url.netloc:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
return
headers = {**setup.headers, "host": url.netloc.decode("ascii")}
bot: Optional[Bot] = None
try:
while True:
logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
)
try:
connection = Connect(setup.url)
async with connection as ws:
logger.opt(colors=True).info(
f"WebSocket Connection to <y>{setup.adapter.upper()} "
f"Bot {setup.self_id}</y> succeeded!")
request = WebSocket(setup.http_version, url.scheme,
url.path, url.query, headers, ws)
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
while not self.shutdown.is_set():
# use try except instead of "request.closed" because of queued message
try:
msg = await request.receive_bytes()
asyncio.create_task(bot.handle_message(msg))
except ConnectionClosed:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>WebSocket connection closed by peer. "
"Try to reconnect...</bg #f8bbd0></r>")
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error while connecting to {url}. "
"Try to reconnect...</bg #f8bbd0></r>")
finally:
if bot:
self._bot_disconnect(bot)
bot = None
await asyncio.sleep(setup.reconnect_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred "
"while websocket loop</bg #f8bbd0></r>")
@dataclass @dataclass
class WebSocket(BaseWebSocket): class WebSocket(BaseWebSocket):
websocket: FastAPIWebSocket = None # type: ignore websocket: Union[FastAPIWebSocket,
WebSocketClientProtocol] = None # type: ignore
@property @property
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
def closed(self): def closed(self) -> bool:
return (self.websocket.client_state == WebSocketState.DISCONNECTED or if isinstance(self.websocket, FastAPIWebSocket):
return (
self.websocket.client_state == WebSocketState.DISCONNECTED or
self.websocket.application_state == WebSocketState.DISCONNECTED) self.websocket.application_state == WebSocketState.DISCONNECTED)
else:
return self.websocket.closed
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def accept(self): async def accept(self):
await self.websocket.accept() if isinstance(self.websocket, FastAPIWebSocket):
await self.websocket.accept()
else:
raise NotImplementedError
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE): async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE):
await self.websocket.close(code=code) await self.websocket.close(code)
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def receive(self) -> str: async def receive(self) -> str:
return await self.websocket.receive_text() if isinstance(self.websocket, FastAPIWebSocket):
return await self.websocket.receive_text()
else:
msg = await self.websocket.recv()
return msg.decode("utf-8") if isinstance(msg, bytes) else msg
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes: async def receive_bytes(self) -> bytes:
return await self.websocket.receive_bytes() if isinstance(self.websocket, FastAPIWebSocket):
return await self.websocket.receive_bytes()
else:
msg = await self.websocket.recv()
return msg.encode("utf-8") if isinstance(msg, str) else msg
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def send(self, data: str) -> None: async def send(self, data: str) -> None:
await self.websocket.send({"type": "websocket.send", "text": data}) if isinstance(self.websocket, FastAPIWebSocket):
await self.websocket.send({"type": "websocket.send", "text": data})
else:
await self.websocket.send(data)
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None: async def send_bytes(self, data: bytes) -> None:
await self.websocket.send({"type": "websocket.send", "bytes": data}) if isinstance(self.websocket, FastAPIWebSocket):
await self.websocket.send({"type": "websocket.send", "bytes": data})
else:
await self.websocket.send(data)

View File

@ -140,14 +140,15 @@ class Driver(ReverseDriver):
}, },
}, },
} }
uvicorn.run(app or self.server_app, uvicorn.run(
host=host or str(self.config.host), app or self.server_app, # type: ignore
port=port or self.config.port, host=host or str(self.config.host),
reload=bool(app) and self.config.debug, port=port or self.config.port,
reload_dirs=self.quart_config.quart_reload_dirs or None, reload=bool(app) and self.config.debug,
debug=self.config.debug, reload_dirs=self.quart_config.quart_reload_dirs or None,
log_config=LOGGING_CONFIG, debug=self.config.debug,
**kwargs) log_config=LOGGING_CONFIG,
**kwargs)
async def _handle_http(self, adapter: str): async def _handle_http(self, adapter: str):
request: Request = _request request: Request = _request

View File

@ -14,6 +14,9 @@ sidebar: auto
- 修复 `type_updater` `permission_updater` 未传递的错误 - 修复 `type_updater` `permission_updater` 未传递的错误
- 修复 `type_updater` `permission_updater` 参数 `state` 错误 - 修复 `type_updater` `permission_updater` 参数 `state` 错误
- 修复使用 `state_factory` 后导致无法在 session 内传递 `state` - 修复使用 `state_factory` 后导致无法在 session 内传递 `state`
- 新增正向 Driver(Client) 支持
- 新增 `aiohttp` 正向 Driver
- `fastapi` Driver 新增正向支持
## v2.0.0a13.post1 ## v2.0.0a13.post1

View File

@ -1,4 +1,4 @@
DRIVER=nonebot.drivers.aiohttp:Driver DRIVER=nonebot.drivers.fastapi:Driver
HOST=0.0.0.0 HOST=0.0.0.0
PORT=2333 PORT=2333
DEBUG=true DEBUG=true