mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-17 19:40:44 +00:00
💥 change forward setup api
This commit is contained in:
@ -2,6 +2,8 @@
|
||||
FastAPI 驱动适配
|
||||
================
|
||||
|
||||
本驱动同时支持服务端以及客户端连接
|
||||
|
||||
后端使用方法请参考: `FastAPI 文档`_
|
||||
|
||||
.. _FastAPI 文档:
|
||||
@ -11,7 +13,7 @@ FastAPI 驱动适配
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Union, Optional, Callable
|
||||
from typing import List, cast, Union, Optional, Callable, Awaitable
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
@ -27,30 +29,13 @@ from nonebot.log import logger
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.config import Env, Config as NoneBotConfig
|
||||
from nonebot.drivers import ReverseDriver, ForwardDriver
|
||||
from nonebot.drivers import HTTPRequest, WebSocket as BaseWebSocket
|
||||
from nonebot.drivers import (ReverseDriver, ForwardDriver, HTTPPollingSetup,
|
||||
WebSocketSetup, HTTPRequest, WebSocket as
|
||||
BaseWebSocket)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HTTPPollingSetup:
|
||||
adapter: str
|
||||
self_id: str
|
||||
url: str
|
||||
method: str
|
||||
body: bytes
|
||||
headers: Dict[str, str]
|
||||
http_version: str
|
||||
poll_interval: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocketSetup:
|
||||
adapter: str
|
||||
self_id: str
|
||||
url: str
|
||||
headers: Dict[str, str]
|
||||
http_version: str
|
||||
reconnect_interval: float
|
||||
HTTPPOLLING_SETUP = Union[HTTPPollingSetup,
|
||||
Callable[[], Awaitable[HTTPPollingSetup]]]
|
||||
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
@ -118,8 +103,8 @@ class Driver(ReverseDriver, ForwardDriver):
|
||||
super().__init__(env, config)
|
||||
|
||||
self.fastapi_config: Config = Config(**config.dict())
|
||||
self.http_pollings: List[HTTPPollingSetup] = []
|
||||
self.websockets: List[WebSocketSetup] = []
|
||||
self.http_pollings: List[HTTPPOLLING_SETUP] = []
|
||||
self.websockets: List[WEBSOCKET_SETUP] = []
|
||||
self.shutdown: asyncio.Event = asyncio.Event()
|
||||
self.connections: List[asyncio.Task] = []
|
||||
|
||||
@ -173,30 +158,30 @@ class Driver(ReverseDriver, ForwardDriver):
|
||||
return self.server_app.on_event("shutdown")(func)
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def setup_http_polling(self,
|
||||
adapter: str,
|
||||
self_id: str,
|
||||
url: str,
|
||||
polling_interval: float = 3.,
|
||||
method: str = "GET",
|
||||
body: bytes = b"",
|
||||
headers: Dict[str, str] = {},
|
||||
http_version: str = "1.1") -> None:
|
||||
self.http_pollings.append(
|
||||
HTTPPollingSetup(adapter, self_id, url, method, body, headers,
|
||||
http_version, polling_interval))
|
||||
def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
|
||||
|
||||
:参数:
|
||||
|
||||
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
|
||||
"""
|
||||
self.http_pollings.append(setup)
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def setup_websocket(self,
|
||||
adapter: str,
|
||||
self_id: str,
|
||||
url: str,
|
||||
reconnect_interval: float = 3.,
|
||||
headers: Dict[str, str] = {},
|
||||
http_version: str = "1.1") -> None:
|
||||
self.websockets.append(
|
||||
WebSocketSetup(adapter, self_id, url, headers, http_version,
|
||||
reconnect_interval))
|
||||
def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
|
||||
|
||||
:参数:
|
||||
|
||||
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
|
||||
"""
|
||||
self.websockets.append(setup)
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
def run(self,
|
||||
@ -336,50 +321,72 @@ class Driver(ReverseDriver, ForwardDriver):
|
||||
finally:
|
||||
self._bot_disconnect(bot)
|
||||
|
||||
async def _http_loop(self, setup: HTTPPollingSetup):
|
||||
url = httpx.URL(setup.url)
|
||||
if not url.netloc:
|
||||
logger.opt(colors=True).error(
|
||||
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
|
||||
return
|
||||
request = HTTPRequest(
|
||||
setup.http_version, url.scheme, url.path, url.query, {
|
||||
**setup.headers, "host": url.netloc.decode("ascii")
|
||||
}, setup.method, setup.body)
|
||||
async def _http_loop(self, setup: HTTPPOLLING_SETUP):
|
||||
|
||||
async def _build_request(
|
||||
setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
|
||||
url = httpx.URL(setup.url)
|
||||
if not url.netloc:
|
||||
logger.opt(colors=True).error(
|
||||
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
|
||||
return
|
||||
return HTTPRequest(
|
||||
setup.http_version, url.scheme, url.path, url.query, {
|
||||
**setup.headers, "host": url.netloc.decode("ascii")
|
||||
}, setup.method, setup.body)
|
||||
|
||||
bot: Optional[Bot] = None
|
||||
request: Optional[HTTPRequest] = None
|
||||
setup_: Optional[HTTPPollingSetup] = None
|
||||
|
||||
BotClass = self._adapters[setup.adapter]
|
||||
bot = BotClass(setup.self_id, request)
|
||||
self._bot_connect(bot)
|
||||
logger.opt(colors=True).info(
|
||||
f"Start http polling for <y>{setup.adapter.upper()} "
|
||||
f"Bot {setup.self_id}</y>")
|
||||
|
||||
headers = request.headers
|
||||
http2: bool = False
|
||||
if request.http_version == "2":
|
||||
http2 = True
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(headers=headers,
|
||||
timeout=30.,
|
||||
http2=http2) as session:
|
||||
async with httpx.AsyncClient(http2=True) as session:
|
||||
while not self.shutdown.is_set():
|
||||
if not bot:
|
||||
if callable(setup):
|
||||
setup_ = await setup()
|
||||
else:
|
||||
setup_ = setup
|
||||
request = await _build_request(setup_)
|
||||
if not request:
|
||||
return
|
||||
BotClass = self._adapters[setup.adapter]
|
||||
bot = BotClass(setup.self_id, request)
|
||||
self._bot_connect(bot)
|
||||
elif callable(setup):
|
||||
setup_ = await setup()
|
||||
request = await _build_request(setup_)
|
||||
if not request:
|
||||
await asyncio.sleep(setup_.poll_interval)
|
||||
continue
|
||||
bot.request = request
|
||||
|
||||
setup_ = cast(HTTPPollingSetup, setup_)
|
||||
request = cast(HTTPRequest, request)
|
||||
headers = request.headers
|
||||
|
||||
logger.debug(
|
||||
f"Bot {setup.self_id} from adapter {setup.adapter} request {url}"
|
||||
f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}"
|
||||
)
|
||||
try:
|
||||
response = await session.request(request.method,
|
||||
url,
|
||||
content=request.body)
|
||||
setup_.url,
|
||||
content=request.body,
|
||||
headers=headers,
|
||||
timeout=30.)
|
||||
response.raise_for_status()
|
||||
data = response.read()
|
||||
asyncio.create_task(bot.handle_message(data))
|
||||
except httpx.HTTPError as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
f"<r><bg #f8bbd0>Error occurred while requesting {url}. "
|
||||
f"<r><bg #f8bbd0>Error occurred while requesting {setup_.url}. "
|
||||
"Try to reconnect...</bg #f8bbd0></r>")
|
||||
|
||||
await asyncio.sleep(setup.poll_interval)
|
||||
await asyncio.sleep(setup_.poll_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
@ -388,34 +395,41 @@ class Driver(ReverseDriver, ForwardDriver):
|
||||
"<r><bg #f8bbd0>Unexpected exception occurred "
|
||||
"while http polling</bg #f8bbd0></r>")
|
||||
finally:
|
||||
self._bot_disconnect(bot)
|
||||
|
||||
async def _ws_loop(self, setup: WebSocketSetup):
|
||||
url = httpx.URL(setup.url)
|
||||
if not url.netloc:
|
||||
logger.opt(colors=True).error(
|
||||
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
|
||||
return
|
||||
|
||||
headers = {**setup.headers, "host": url.netloc.decode("ascii")}
|
||||
if bot:
|
||||
self._bot_disconnect(bot)
|
||||
|
||||
async def _ws_loop(self, setup: WEBSOCKET_SETUP):
|
||||
bot: Optional[Bot] = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
if callable(setup):
|
||||
setup_ = await setup()
|
||||
else:
|
||||
setup_ = setup
|
||||
|
||||
url = httpx.URL(setup_.url)
|
||||
if not url.netloc:
|
||||
logger.opt(colors=True).error(
|
||||
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>"
|
||||
)
|
||||
return
|
||||
|
||||
headers = {**setup_.headers, "host": url.netloc.decode("ascii")}
|
||||
logger.debug(
|
||||
f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
|
||||
f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}"
|
||||
)
|
||||
try:
|
||||
connection = Connect(setup.url)
|
||||
connection = Connect(setup_.url)
|
||||
async with connection as ws:
|
||||
logger.opt(colors=True).info(
|
||||
f"WebSocket Connection to <y>{setup.adapter.upper()} "
|
||||
f"Bot {setup.self_id}</y> succeeded!")
|
||||
request = WebSocket(setup.http_version, url.scheme,
|
||||
url.path, url.query, headers, ws)
|
||||
f"WebSocket Connection to <y>{setup_.adapter.upper()} "
|
||||
f"Bot {setup_.self_id}</y> succeeded!")
|
||||
request = WebSocket("1.1", url.scheme, url.path,
|
||||
url.query, headers, ws)
|
||||
|
||||
BotClass = self._adapters[setup.adapter]
|
||||
bot = BotClass(setup.self_id, request)
|
||||
BotClass = self._adapters[setup_.adapter]
|
||||
bot = BotClass(setup_.self_id, request)
|
||||
self._bot_connect(bot)
|
||||
while not self.shutdown.is_set():
|
||||
# use try except instead of "request.closed" because of queued message
|
||||
@ -434,7 +448,7 @@ class Driver(ReverseDriver, ForwardDriver):
|
||||
if bot:
|
||||
self._bot_disconnect(bot)
|
||||
bot = None
|
||||
await asyncio.sleep(setup.reconnect_interval)
|
||||
await asyncio.sleep(setup_.reconnect_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
Reference in New Issue
Block a user