♻️ separate fastapi driver

This commit is contained in:
yanyongyu
2021-11-27 12:16:31 +08:00
parent 030237fe22
commit 37f55652d9
4 changed files with 99 additions and 57 deletions

View File

@ -134,9 +134,9 @@ class Config(BaseSettings):
extra = "ignore"
class Driver(ReverseDriver, ForwardDriver):
class Driver(ReverseDriver):
"""
FastAPI 驱动框架
FastAPI 驱动框架。包含反向 Server 功能。
:上报地址:
@ -147,13 +147,9 @@ class Driver(ReverseDriver, ForwardDriver):
"""
def __init__(self, env: Env, config: NoneBotConfig):
super().__init__(env, config)
super(Driver, self).__init__(env, config)
self.fastapi_config: Config = Config(**config.dict())
self.http_pollings: List[HTTPPOLLING_SETUP] = []
self.websockets: List[WEBSOCKET_SETUP] = []
self.shutdown: asyncio.Event = asyncio.Event()
self.connections: List[asyncio.Task] = []
self._server_app = FastAPI(
debug=config.debug,
@ -167,9 +163,6 @@ class Driver(ReverseDriver, ForwardDriver):
self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse)
self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse)
self.on_startup(self._run_forward)
self.on_shutdown(self._shutdown_forward)
@property
@overrides(ReverseDriver)
def type(self) -> str:
@ -204,32 +197,6 @@ class Driver(ReverseDriver, ForwardDriver):
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_"""
return self.server_app.on_event("shutdown")(func)
@overrides(ForwardDriver)
def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
"""
:说明:
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
:参数:
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
"""
self.http_pollings.append(setup)
@overrides(ForwardDriver)
def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
"""
:说明:
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
:参数:
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
"""
self.websockets.append(setup)
@overrides(ReverseDriver)
def run(
self,
@ -273,18 +240,6 @@ class Driver(ReverseDriver, ForwardDriver):
**kwargs,
)
def _run_forward(self):
for setup in self.http_pollings:
self.connections.append(asyncio.create_task(self._http_loop(setup)))
for setup in self.websockets:
self.connections.append(asyncio.create_task(self._ws_loop(setup)))
def _shutdown_forward(self):
self.shutdown.set()
for task in self.connections:
if not task.done():
task.cancel()
async def _handle_http(self, adapter: str, request: Request):
data = await request.body()
@ -386,6 +341,73 @@ class Driver(ReverseDriver, ForwardDriver):
finally:
self._bot_disconnect(bot)
class FullDriver(ForwardDriver, Driver):
"""
完整的 FastAPI 驱动框架,包含正向 Client 支持和反向 Server 支持。
:使用方法:
.. code-block:: dotenv
DRIVER=nonebot.drivers.fastapi:FullDriver
"""
def __init__(self, env: Env, config: NoneBotConfig):
super(FullDriver, self).__init__(env, config)
self.http_pollings: List[HTTPPOLLING_SETUP] = []
self.websockets: List[WEBSOCKET_SETUP] = []
self.shutdown: asyncio.Event = asyncio.Event()
self.connections: List[asyncio.Task] = []
self.on_startup(self._run_forward)
self.on_shutdown(self._shutdown_forward)
@property
@overrides(ForwardDriver)
def type(self) -> str:
"""驱动名称: ``fastapi_full``"""
return "fastapi_full"
@overrides(ForwardDriver)
def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
"""
:说明:
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
:参数:
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
"""
self.http_pollings.append(setup)
@overrides(ForwardDriver)
def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
"""
:说明:
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
:参数:
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
"""
self.websockets.append(setup)
def _run_forward(self):
for setup in self.http_pollings:
self.connections.append(asyncio.create_task(self._http_loop(setup)))
for setup in self.websockets:
self.connections.append(asyncio.create_task(self._ws_loop(setup)))
def _shutdown_forward(self):
self.shutdown.set()
for task in self.connections:
if not task.done():
task.cancel()
async def _http_loop(self, setup: HTTPPOLLING_SETUP):
async def _build_request(setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
url = httpx.URL(setup.url)