mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-28 00:31:14 +00:00
💥 change forward setup api
This commit is contained in:
@ -8,7 +8,7 @@
|
||||
import abc
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Set, Dict, Type, Optional, Callable, TYPE_CHECKING
|
||||
from typing import Any, Set, Dict, Type, Union, Optional, Callable, Awaitable, TYPE_CHECKING
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.config import Env, Config
|
||||
@ -193,27 +193,40 @@ class Driver(abc.ABC):
|
||||
|
||||
|
||||
class ForwardDriver(Driver):
|
||||
"""
|
||||
Forward Driver 基类。将客户端框架封装,以满足适配器使用。
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def setup_http_polling(self,
|
||||
adapter: str,
|
||||
self_id: str,
|
||||
url: str,
|
||||
polling_interval: float = 3.,
|
||||
method: str = "GET",
|
||||
body: bytes = b"",
|
||||
headers: Dict[str, str] = {},
|
||||
http_version: str = "1.1") -> None:
|
||||
def setup_http_polling(
|
||||
self, setup: Union["HTTPPollingSetup",
|
||||
Callable[[], Awaitable["HTTPPollingSetup"]]]
|
||||
) -> None:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
|
||||
|
||||
:参数:
|
||||
|
||||
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def setup_websocket(self,
|
||||
adapter: str,
|
||||
self_id: str,
|
||||
url: str,
|
||||
reconnect_interval: float = 3.,
|
||||
headers: Dict[str, str] = {},
|
||||
http_version: str = "1.1") -> None:
|
||||
def setup_websocket(
|
||||
self, setup: Union["WebSocketSetup",
|
||||
Callable[[], Awaitable["WebSocketSetup"]]]
|
||||
) -> None:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
|
||||
|
||||
:参数:
|
||||
|
||||
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -369,3 +382,37 @@ class WebSocket(HTTPConnection, abc.ABC):
|
||||
async def send_bytes(self, data: bytes):
|
||||
"""发送一条 WebSocket binary 信息"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class HTTPPollingSetup:
|
||||
adapter: str
|
||||
"""协议适配器名称"""
|
||||
self_id: str
|
||||
"""机器人 ID"""
|
||||
url: str
|
||||
"""URL"""
|
||||
method: str
|
||||
"""HTTP method"""
|
||||
body: bytes
|
||||
"""HTTP body"""
|
||||
headers: Dict[str, str]
|
||||
"""HTTP headers"""
|
||||
http_version: str
|
||||
"""HTTP version"""
|
||||
poll_interval: float
|
||||
"""HTTP 轮询间隔"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocketSetup:
|
||||
adapter: str
|
||||
"""协议适配器名称"""
|
||||
self_id: str
|
||||
"""机器人 ID"""
|
||||
url: str
|
||||
"""URL"""
|
||||
headers: Dict[str, str] = field(default_factory=dict)
|
||||
"""HTTP headers"""
|
||||
reconnect_interval: float = 3.
|
||||
"""WebSocket 重连间隔"""
|
||||
|
@ -1,11 +1,15 @@
|
||||
"""
|
||||
AIOHTTP 驱动适配
|
||||
================
|
||||
|
||||
本驱动仅支持客户端连接
|
||||
"""
|
||||
|
||||
import signal
|
||||
import asyncio
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Set, List, Dict, Optional, Callable, Awaitable
|
||||
from typing import Set, List, cast, Union, Optional, Callable, Awaitable
|
||||
|
||||
import aiohttp
|
||||
from yarl import URL
|
||||
@ -14,46 +18,31 @@ from nonebot.log import logger
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.drivers import ForwardDriver, HTTPRequest, WebSocket as BaseWebSocket
|
||||
from nonebot.drivers import (ForwardDriver, HTTPPollingSetup, WebSocketSetup,
|
||||
HTTPRequest, WebSocket as BaseWebSocket)
|
||||
|
||||
STARTUP_FUNC = Callable[[], Awaitable[None]]
|
||||
SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
|
||||
HTTPPOLLING_SETUP = Union[HTTPPollingSetup,
|
||||
Callable[[], Awaitable[HTTPPollingSetup]]]
|
||||
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
|
||||
HANDLED_SIGNALS = (
|
||||
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
||||
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HTTPPollingSetup:
|
||||
adapter: str
|
||||
self_id: str
|
||||
url: str
|
||||
method: str
|
||||
body: bytes
|
||||
headers: Dict[str, str]
|
||||
http_version: str
|
||||
poll_interval: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocketSetup:
|
||||
adapter: str
|
||||
self_id: str
|
||||
url: str
|
||||
headers: Dict[str, str]
|
||||
http_version: str
|
||||
reconnect_interval: float
|
||||
|
||||
|
||||
class Driver(ForwardDriver):
|
||||
"""
|
||||
AIOHTTP 驱动框架
|
||||
"""
|
||||
|
||||
def __init__(self, env: Env, config: Config):
|
||||
super().__init__(env, config)
|
||||
self.startup_funcs: Set[STARTUP_FUNC] = set()
|
||||
self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set()
|
||||
self.http_pollings: List[HTTPPollingSetup] = []
|
||||
self.websockets: List[WebSocketSetup] = []
|
||||
self.http_pollings: List[HTTPPOLLING_SETUP] = []
|
||||
self.websockets: List[WEBSOCKET_SETUP] = []
|
||||
self.connections: List[asyncio.Task] = []
|
||||
self.should_exit: asyncio.Event = asyncio.Event()
|
||||
self.force_exit: bool = False
|
||||
@ -67,46 +56,66 @@ class Driver(ForwardDriver):
|
||||
@property
|
||||
@overrides(ForwardDriver)
|
||||
def logger(self):
|
||||
"""aiohttp driver 使用的 logger"""
|
||||
return logger
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def on_startup(self, func: Callable) -> Callable:
|
||||
def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个启动时执行的函数
|
||||
|
||||
:参数:
|
||||
|
||||
* ``func: Callable[[], Awaitable[None]]``
|
||||
"""
|
||||
self.startup_funcs.add(func)
|
||||
return func
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def on_shutdown(self, func: Callable) -> Callable:
|
||||
def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个停止时执行的函数
|
||||
|
||||
:参数:
|
||||
|
||||
* ``func: Callable[[], Awaitable[None]]``
|
||||
"""
|
||||
self.shutdown_funcs.add(func)
|
||||
return func
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def setup_http_polling(self,
|
||||
adapter: str,
|
||||
self_id: str,
|
||||
url: str,
|
||||
polling_interval: float = 3.,
|
||||
method: str = "GET",
|
||||
body: bytes = b"",
|
||||
headers: Dict[str, str] = {},
|
||||
http_version: str = "1.1") -> None:
|
||||
self.http_pollings.append(
|
||||
HTTPPollingSetup(adapter, self_id, url, method, body, headers,
|
||||
http_version, polling_interval))
|
||||
def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
|
||||
|
||||
:参数:
|
||||
|
||||
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
|
||||
"""
|
||||
self.http_pollings.append(setup)
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def setup_websocket(self,
|
||||
adapter: str,
|
||||
self_id: str,
|
||||
url: str,
|
||||
reconnect_interval: float = 3.,
|
||||
headers: Dict[str, str] = {},
|
||||
http_version: str = "1.1") -> None:
|
||||
self.websockets.append(
|
||||
WebSocketSetup(adapter, self_id, url, headers, http_version,
|
||||
reconnect_interval))
|
||||
def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
|
||||
|
||||
:参数:
|
||||
|
||||
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
|
||||
"""
|
||||
self.websockets.append(setup)
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def run(self, *args, **kwargs):
|
||||
"""启动 aiohttp driver"""
|
||||
super().run(*args, **kwargs)
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(self.serve())
|
||||
@ -197,59 +206,88 @@ class Driver(ForwardDriver):
|
||||
else:
|
||||
self.should_exit.set()
|
||||
|
||||
async def _http_loop(self, setup: HTTPPollingSetup):
|
||||
url = URL(setup.url)
|
||||
if not url.is_absolute() or not url.host:
|
||||
logger.opt(colors=True).error(
|
||||
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
|
||||
return
|
||||
host = f"{url.host}:{url.port}" if url.port else url.host
|
||||
request = HTTPRequest(setup.http_version, url.scheme, url.path,
|
||||
url.raw_query_string.encode("latin-1"), {
|
||||
**setup.headers, "host": host
|
||||
}, setup.method, setup.body)
|
||||
async def _http_loop(self, setup: HTTPPOLLING_SETUP):
|
||||
|
||||
async def _build_request(
|
||||
setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
|
||||
url = URL(setup.url)
|
||||
if not url.is_absolute() or not url.host:
|
||||
logger.opt(colors=True).error(
|
||||
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
|
||||
return
|
||||
host = f"{url.host}:{url.port}" if url.port else url.host
|
||||
return HTTPRequest(setup.http_version, url.scheme, url.path,
|
||||
url.raw_query_string.encode("latin-1"), {
|
||||
**setup.headers, "host": host
|
||||
}, setup.method, setup.body)
|
||||
|
||||
bot: Optional[Bot] = None
|
||||
request: Optional[HTTPRequest] = None
|
||||
setup_: Optional[HTTPPollingSetup] = None
|
||||
|
||||
BotClass = self._adapters[setup.adapter]
|
||||
bot = BotClass(setup.self_id, request)
|
||||
self._bot_connect(bot)
|
||||
logger.opt(colors=True).info(
|
||||
f"Start http polling for <y>{setup.adapter.upper()} "
|
||||
f"Bot {setup.self_id}</y>")
|
||||
|
||||
headers = request.headers
|
||||
timeout = aiohttp.ClientTimeout(30)
|
||||
version: aiohttp.HttpVersion
|
||||
if request.http_version == "1.0":
|
||||
version = aiohttp.HttpVersion10
|
||||
elif request.http_version == "1.1":
|
||||
version = aiohttp.HttpVersion11
|
||||
else:
|
||||
logger.opt(colors=True).error(
|
||||
"<r><bg #f8bbd0>Unsupported HTTP Version "
|
||||
f"{request.http_version}</bg #f8bbd0></r>")
|
||||
return
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(headers=headers,
|
||||
timeout=timeout,
|
||||
version=version) as session:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
while not self.should_exit.is_set():
|
||||
if not bot:
|
||||
if callable(setup):
|
||||
setup_ = await setup()
|
||||
else:
|
||||
setup_ = setup
|
||||
request = await _build_request(setup_)
|
||||
if not request:
|
||||
return
|
||||
|
||||
BotClass = self._adapters[setup.adapter]
|
||||
bot = BotClass(setup.self_id, request)
|
||||
self._bot_connect(bot)
|
||||
elif callable(setup):
|
||||
setup_ = await setup()
|
||||
request = await _build_request(setup_)
|
||||
if not request:
|
||||
await asyncio.sleep(setup_.poll_interval)
|
||||
continue
|
||||
bot.request = request
|
||||
|
||||
request = cast(HTTPRequest, request)
|
||||
setup_ = cast(HTTPPollingSetup, setup_)
|
||||
|
||||
headers = request.headers
|
||||
timeout = aiohttp.ClientTimeout(30)
|
||||
version: aiohttp.HttpVersion
|
||||
if request.http_version == "1.0":
|
||||
version = aiohttp.HttpVersion10
|
||||
elif request.http_version == "1.1":
|
||||
version = aiohttp.HttpVersion11
|
||||
else:
|
||||
logger.opt(colors=True).error(
|
||||
"<r><bg #f8bbd0>Unsupported HTTP Version "
|
||||
f"{request.http_version}</bg #f8bbd0></r>")
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Bot {setup.self_id} from adapter {setup.adapter} request {url}"
|
||||
f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}"
|
||||
)
|
||||
|
||||
try:
|
||||
async with session.request(
|
||||
request.method, url,
|
||||
data=request.body) as response:
|
||||
async with session.request(request.method,
|
||||
setup_.url,
|
||||
data=request.body,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
version=version) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.read()
|
||||
asyncio.create_task(bot.handle_message(data))
|
||||
except aiohttp.ClientResponseError as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
f"<r><bg #f8bbd0>Error occurred while requesting {url}. "
|
||||
f"<r><bg #f8bbd0>Error occurred while requesting {setup_.url}. "
|
||||
"Try to reconnect...</bg #f8bbd0></r>")
|
||||
|
||||
await asyncio.sleep(setup.poll_interval)
|
||||
await asyncio.sleep(setup_.poll_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
@ -258,50 +296,48 @@ class Driver(ForwardDriver):
|
||||
"<r><bg #f8bbd0>Unexpected exception occurred "
|
||||
"while http polling</bg #f8bbd0></r>")
|
||||
finally:
|
||||
self._bot_disconnect(bot)
|
||||
|
||||
async def _ws_loop(self, setup: WebSocketSetup):
|
||||
url = URL(setup.url)
|
||||
if not url.is_absolute() or not url.host:
|
||||
logger.opt(colors=True).error(
|
||||
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
|
||||
return
|
||||
host = f"{url.host}:{url.port}" if url.port else url.host
|
||||
|
||||
headers = {**setup.headers, "host": host}
|
||||
timeout = aiohttp.ClientTimeout(30)
|
||||
version: aiohttp.HttpVersion
|
||||
if setup.http_version == "1.0":
|
||||
version = aiohttp.HttpVersion10
|
||||
elif setup.http_version == "1.1":
|
||||
version = aiohttp.HttpVersion11
|
||||
else:
|
||||
logger.opt(colors=True).error(
|
||||
"<r><bg #f8bbd0>Unsupported HTTP Version "
|
||||
f"{setup.http_version}</bg #f8bbd0></r>")
|
||||
return
|
||||
if bot:
|
||||
self._bot_disconnect(bot)
|
||||
|
||||
async def _ws_loop(self, setup: WEBSOCKET_SETUP):
|
||||
bot: Optional[Bot] = None
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(headers=headers,
|
||||
timeout=timeout,
|
||||
version=version) as session:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
while True:
|
||||
if callable(setup):
|
||||
setup_ = await setup()
|
||||
else:
|
||||
setup_ = setup
|
||||
|
||||
url = URL(setup_.url)
|
||||
if not url.is_absolute() or not url.host:
|
||||
logger.opt(colors=True).error(
|
||||
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>"
|
||||
)
|
||||
await asyncio.sleep(setup_.reconnect_interval)
|
||||
continue
|
||||
|
||||
host = f"{url.host}:{url.port}" if url.port else url.host
|
||||
headers = {**setup_.headers, "host": host}
|
||||
|
||||
logger.debug(
|
||||
f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
|
||||
f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}"
|
||||
)
|
||||
try:
|
||||
async with session.ws_connect(url) as ws:
|
||||
async with session.ws_connect(url,
|
||||
headers=headers,
|
||||
timeout=30.) as ws:
|
||||
logger.opt(colors=True).info(
|
||||
f"WebSocket Connection to <y>{setup.adapter.upper()} "
|
||||
f"Bot {setup.self_id}</y> succeeded!")
|
||||
f"WebSocket Connection to <y>{setup_.adapter.upper()} "
|
||||
f"Bot {setup_.self_id}</y> succeeded!")
|
||||
request = WebSocket(
|
||||
setup.http_version, url.scheme, url.path,
|
||||
"1.1", url.scheme, url.path,
|
||||
url.raw_query_string.encode("latin-1"), headers,
|
||||
ws)
|
||||
|
||||
BotClass = self._adapters[setup.adapter]
|
||||
bot = BotClass(setup.self_id, request)
|
||||
BotClass = self._adapters[setup_.adapter]
|
||||
bot = BotClass(setup_.self_id, request)
|
||||
self._bot_connect(bot)
|
||||
while not self.should_exit.is_set():
|
||||
msg = await ws.receive()
|
||||
@ -330,7 +366,7 @@ class Driver(ForwardDriver):
|
||||
if bot:
|
||||
self._bot_disconnect(bot)
|
||||
bot = None
|
||||
await asyncio.sleep(setup.reconnect_interval)
|
||||
await asyncio.sleep(setup_.reconnect_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
@ -2,6 +2,8 @@
|
||||
FastAPI 驱动适配
|
||||
================
|
||||
|
||||
本驱动同时支持服务端以及客户端连接
|
||||
|
||||
后端使用方法请参考: `FastAPI 文档`_
|
||||
|
||||
.. _FastAPI 文档:
|
||||
@ -11,7 +13,7 @@ FastAPI 驱动适配
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Union, Optional, Callable
|
||||
from typing import List, cast, Union, Optional, Callable, Awaitable
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
@ -27,30 +29,13 @@ from nonebot.log import logger
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.config import Env, Config as NoneBotConfig
|
||||
from nonebot.drivers import ReverseDriver, ForwardDriver
|
||||
from nonebot.drivers import HTTPRequest, WebSocket as BaseWebSocket
|
||||
from nonebot.drivers import (ReverseDriver, ForwardDriver, HTTPPollingSetup,
|
||||
WebSocketSetup, HTTPRequest, WebSocket as
|
||||
BaseWebSocket)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HTTPPollingSetup:
|
||||
adapter: str
|
||||
self_id: str
|
||||
url: str
|
||||
method: str
|
||||
body: bytes
|
||||
headers: Dict[str, str]
|
||||
http_version: str
|
||||
poll_interval: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocketSetup:
|
||||
adapter: str
|
||||
self_id: str
|
||||
url: str
|
||||
headers: Dict[str, str]
|
||||
http_version: str
|
||||
reconnect_interval: float
|
||||
HTTPPOLLING_SETUP = Union[HTTPPollingSetup,
|
||||
Callable[[], Awaitable[HTTPPollingSetup]]]
|
||||
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
@ -118,8 +103,8 @@ class Driver(ReverseDriver, ForwardDriver):
|
||||
super().__init__(env, config)
|
||||
|
||||
self.fastapi_config: Config = Config(**config.dict())
|
||||
self.http_pollings: List[HTTPPollingSetup] = []
|
||||
self.websockets: List[WebSocketSetup] = []
|
||||
self.http_pollings: List[HTTPPOLLING_SETUP] = []
|
||||
self.websockets: List[WEBSOCKET_SETUP] = []
|
||||
self.shutdown: asyncio.Event = asyncio.Event()
|
||||
self.connections: List[asyncio.Task] = []
|
||||
|
||||
@ -173,30 +158,30 @@ class Driver(ReverseDriver, ForwardDriver):
|
||||
return self.server_app.on_event("shutdown")(func)
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def setup_http_polling(self,
|
||||
adapter: str,
|
||||
self_id: str,
|
||||
url: str,
|
||||
polling_interval: float = 3.,
|
||||
method: str = "GET",
|
||||
body: bytes = b"",
|
||||
headers: Dict[str, str] = {},
|
||||
http_version: str = "1.1") -> None:
|
||||
self.http_pollings.append(
|
||||
HTTPPollingSetup(adapter, self_id, url, method, body, headers,
|
||||
http_version, polling_interval))
|
||||
def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
|
||||
|
||||
:参数:
|
||||
|
||||
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
|
||||
"""
|
||||
self.http_pollings.append(setup)
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def setup_websocket(self,
|
||||
adapter: str,
|
||||
self_id: str,
|
||||
url: str,
|
||||
reconnect_interval: float = 3.,
|
||||
headers: Dict[str, str] = {},
|
||||
http_version: str = "1.1") -> None:
|
||||
self.websockets.append(
|
||||
WebSocketSetup(adapter, self_id, url, headers, http_version,
|
||||
reconnect_interval))
|
||||
def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
|
||||
|
||||
:参数:
|
||||
|
||||
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
|
||||
"""
|
||||
self.websockets.append(setup)
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
def run(self,
|
||||
@ -336,50 +321,72 @@ class Driver(ReverseDriver, ForwardDriver):
|
||||
finally:
|
||||
self._bot_disconnect(bot)
|
||||
|
||||
async def _http_loop(self, setup: HTTPPollingSetup):
|
||||
url = httpx.URL(setup.url)
|
||||
if not url.netloc:
|
||||
logger.opt(colors=True).error(
|
||||
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
|
||||
return
|
||||
request = HTTPRequest(
|
||||
setup.http_version, url.scheme, url.path, url.query, {
|
||||
**setup.headers, "host": url.netloc.decode("ascii")
|
||||
}, setup.method, setup.body)
|
||||
async def _http_loop(self, setup: HTTPPOLLING_SETUP):
|
||||
|
||||
async def _build_request(
|
||||
setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
|
||||
url = httpx.URL(setup.url)
|
||||
if not url.netloc:
|
||||
logger.opt(colors=True).error(
|
||||
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
|
||||
return
|
||||
return HTTPRequest(
|
||||
setup.http_version, url.scheme, url.path, url.query, {
|
||||
**setup.headers, "host": url.netloc.decode("ascii")
|
||||
}, setup.method, setup.body)
|
||||
|
||||
bot: Optional[Bot] = None
|
||||
request: Optional[HTTPRequest] = None
|
||||
setup_: Optional[HTTPPollingSetup] = None
|
||||
|
||||
BotClass = self._adapters[setup.adapter]
|
||||
bot = BotClass(setup.self_id, request)
|
||||
self._bot_connect(bot)
|
||||
logger.opt(colors=True).info(
|
||||
f"Start http polling for <y>{setup.adapter.upper()} "
|
||||
f"Bot {setup.self_id}</y>")
|
||||
|
||||
headers = request.headers
|
||||
http2: bool = False
|
||||
if request.http_version == "2":
|
||||
http2 = True
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(headers=headers,
|
||||
timeout=30.,
|
||||
http2=http2) as session:
|
||||
async with httpx.AsyncClient(http2=True) as session:
|
||||
while not self.shutdown.is_set():
|
||||
if not bot:
|
||||
if callable(setup):
|
||||
setup_ = await setup()
|
||||
else:
|
||||
setup_ = setup
|
||||
request = await _build_request(setup_)
|
||||
if not request:
|
||||
return
|
||||
BotClass = self._adapters[setup.adapter]
|
||||
bot = BotClass(setup.self_id, request)
|
||||
self._bot_connect(bot)
|
||||
elif callable(setup):
|
||||
setup_ = await setup()
|
||||
request = await _build_request(setup_)
|
||||
if not request:
|
||||
await asyncio.sleep(setup_.poll_interval)
|
||||
continue
|
||||
bot.request = request
|
||||
|
||||
setup_ = cast(HTTPPollingSetup, setup_)
|
||||
request = cast(HTTPRequest, request)
|
||||
headers = request.headers
|
||||
|
||||
logger.debug(
|
||||
f"Bot {setup.self_id} from adapter {setup.adapter} request {url}"
|
||||
f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}"
|
||||
)
|
||||
try:
|
||||
response = await session.request(request.method,
|
||||
url,
|
||||
content=request.body)
|
||||
setup_.url,
|
||||
content=request.body,
|
||||
headers=headers,
|
||||
timeout=30.)
|
||||
response.raise_for_status()
|
||||
data = response.read()
|
||||
asyncio.create_task(bot.handle_message(data))
|
||||
except httpx.HTTPError as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
f"<r><bg #f8bbd0>Error occurred while requesting {url}. "
|
||||
f"<r><bg #f8bbd0>Error occurred while requesting {setup_.url}. "
|
||||
"Try to reconnect...</bg #f8bbd0></r>")
|
||||
|
||||
await asyncio.sleep(setup.poll_interval)
|
||||
await asyncio.sleep(setup_.poll_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
@ -388,34 +395,41 @@ class Driver(ReverseDriver, ForwardDriver):
|
||||
"<r><bg #f8bbd0>Unexpected exception occurred "
|
||||
"while http polling</bg #f8bbd0></r>")
|
||||
finally:
|
||||
self._bot_disconnect(bot)
|
||||
|
||||
async def _ws_loop(self, setup: WebSocketSetup):
|
||||
url = httpx.URL(setup.url)
|
||||
if not url.netloc:
|
||||
logger.opt(colors=True).error(
|
||||
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
|
||||
return
|
||||
|
||||
headers = {**setup.headers, "host": url.netloc.decode("ascii")}
|
||||
if bot:
|
||||
self._bot_disconnect(bot)
|
||||
|
||||
async def _ws_loop(self, setup: WEBSOCKET_SETUP):
|
||||
bot: Optional[Bot] = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
if callable(setup):
|
||||
setup_ = await setup()
|
||||
else:
|
||||
setup_ = setup
|
||||
|
||||
url = httpx.URL(setup_.url)
|
||||
if not url.netloc:
|
||||
logger.opt(colors=True).error(
|
||||
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>"
|
||||
)
|
||||
return
|
||||
|
||||
headers = {**setup_.headers, "host": url.netloc.decode("ascii")}
|
||||
logger.debug(
|
||||
f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
|
||||
f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}"
|
||||
)
|
||||
try:
|
||||
connection = Connect(setup.url)
|
||||
connection = Connect(setup_.url)
|
||||
async with connection as ws:
|
||||
logger.opt(colors=True).info(
|
||||
f"WebSocket Connection to <y>{setup.adapter.upper()} "
|
||||
f"Bot {setup.self_id}</y> succeeded!")
|
||||
request = WebSocket(setup.http_version, url.scheme,
|
||||
url.path, url.query, headers, ws)
|
||||
f"WebSocket Connection to <y>{setup_.adapter.upper()} "
|
||||
f"Bot {setup_.self_id}</y> succeeded!")
|
||||
request = WebSocket("1.1", url.scheme, url.path,
|
||||
url.query, headers, ws)
|
||||
|
||||
BotClass = self._adapters[setup.adapter]
|
||||
bot = BotClass(setup.self_id, request)
|
||||
BotClass = self._adapters[setup_.adapter]
|
||||
bot = BotClass(setup_.self_id, request)
|
||||
self._bot_connect(bot)
|
||||
while not self.shutdown.is_set():
|
||||
# use try except instead of "request.closed" because of queued message
|
||||
@ -434,7 +448,7 @@ class Driver(ReverseDriver, ForwardDriver):
|
||||
if bot:
|
||||
self._bot_disconnect(bot)
|
||||
bot = None
|
||||
await asyncio.sleep(setup.reconnect_interval)
|
||||
await asyncio.sleep(setup_.reconnect_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
Reference in New Issue
Block a user