Feature: WS 支持 ping interval/timeout 配置 (#3964)

Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
This commit is contained in:
StarHeart
2026-04-20 19:39:28 +08:00
committed by GitHub
parent 6b1c616860
commit 2b77b122af
4 changed files with 126 additions and 13 deletions

View File

@@ -46,6 +46,7 @@ from nonebot.internal.driver import (
Timeout, Timeout,
TimeoutTypes, TimeoutTypes,
) )
from nonebot.log import logger
from nonebot.utils import UNSET, UnsetType, exclude_unset from nonebot.utils import UNSET, UnsetType, exclude_unset
try: try:
@@ -324,6 +325,16 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
) )
) )
heartbeat = None
if setup.ping_interval is not UNSET:
heartbeat = setup.ping_interval
if isinstance(setup.timeout, Timeout) and setup.timeout.ping is not UNSET:
logger.warning(
"aiohttp driver does not expose a separate ping timeout; "
"the configured ping timeout will be ignored."
)
async with aiohttp.ClientSession(version=version, trust_env=True) as session: async with aiohttp.ClientSession(version=version, trust_env=True) as session:
async with session.ws_connect( async with session.ws_connect(
setup.url, setup.url,
@@ -331,6 +342,8 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
timeout=timeout, timeout=timeout,
headers=setup.headers, headers=setup.headers,
proxy=setup.proxy, proxy=setup.proxy,
autoping=heartbeat is not None,
heartbeat=heartbeat,
) as ws: ) as ws:
yield WebSocket(request=setup, session=session, websocket=ws) yield WebSocket(request=setup, session=session, websocket=ws)

View File

@@ -36,7 +36,7 @@ from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.none import Driver as NoneDriver from nonebot.drivers.none import Driver as NoneDriver
from nonebot.exception import WebSocketClosed from nonebot.exception import WebSocketClosed
from nonebot.log import LoguruHandler from nonebot.log import LoguruHandler
from nonebot.utils import UNSET, exclude_unset from nonebot.utils import UNSET, UnsetType, exclude_unset
try: try:
from websockets import ClientConnection, ConnectionClosed, connect from websockets import ClientConnection, ConnectionClosed, connect
@@ -77,14 +77,17 @@ class Mixin(WebSocketClientMixin):
@override @override
@asynccontextmanager @asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]: async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
timeout_kwargs: dict[str, float | None] = {} timeout_kwargs: dict[str, float | None | UnsetType] = {}
if isinstance(setup.timeout, Timeout): if isinstance(setup.timeout, Timeout):
open_timeout = ( open_timeout = (
setup.timeout.connect or setup.timeout.read or setup.timeout.total setup.timeout.connect or setup.timeout.read or setup.timeout.total
) )
timeout_kwargs = exclude_unset( timeout_kwargs = {
{"open_timeout": open_timeout, "close_timeout": setup.timeout.close} "open_timeout": open_timeout,
) "close_timeout": setup.timeout.close,
"ping_timeout": setup.timeout.ping,
}
elif setup.timeout is not UNSET: elif setup.timeout is not UNSET:
timeout_kwargs = { timeout_kwargs = {
"open_timeout": setup.timeout, "open_timeout": setup.timeout,
@@ -95,18 +98,24 @@ class Mixin(WebSocketClientMixin):
open_timeout = ( open_timeout = (
DEFAULT_TIMEOUT.connect or DEFAULT_TIMEOUT.read or DEFAULT_TIMEOUT.total DEFAULT_TIMEOUT.connect or DEFAULT_TIMEOUT.read or DEFAULT_TIMEOUT.total
) )
timeout_kwargs = exclude_unset( timeout_kwargs = {
{ "open_timeout": open_timeout,
"open_timeout": open_timeout, "close_timeout": DEFAULT_TIMEOUT.close,
"close_timeout": DEFAULT_TIMEOUT.close, "ping_timeout": DEFAULT_TIMEOUT.ping,
} }
)
kwargs = exclude_unset(
{
**timeout_kwargs,
"ping_interval": setup.ping_interval,
}
)
connection = connect( connection = connect(
str(setup.url), str(setup.url),
additional_headers={**setup.headers, **setup.cookies.as_header(setup)}, additional_headers={**setup.headers, **setup.cookies.as_header(setup)},
proxy=setup.proxy if setup.proxy is not None else True, proxy=setup.proxy if setup.proxy is not None else True,
**timeout_kwargs, # type: ignore **kwargs, # type: ignore
) )
async with connection as ws: async with connection as ws:
yield WebSocket(request=setup, websocket=ws) yield WebSocket(request=setup, websocket=ws)

View File

@@ -20,9 +20,10 @@ class Timeout:
connect: float | None | UnsetType = UNSET connect: float | None | UnsetType = UNSET
read: float | None | UnsetType = UNSET read: float | None | UnsetType = UNSET
close: float | None | UnsetType = UNSET close: float | None | UnsetType = UNSET
ping: float | None | UnsetType = UNSET
DEFAULT_TIMEOUT = Timeout(total=None, connect=5.0, read=30.0, close=10.0) DEFAULT_TIMEOUT = Timeout(total=None, connect=5.0, read=30.0, close=10.0, ping=20.0)
RawURL: TypeAlias = tuple[bytes, bytes, int | None, bytes] RawURL: TypeAlias = tuple[bytes, bytes, int | None, bytes]
@@ -52,6 +53,7 @@ FileTypes: TypeAlias = (
) )
FilesTypes: TypeAlias = dict[str, FileTypes] | list[tuple[str, FileTypes]] | None FilesTypes: TypeAlias = dict[str, FileTypes] | list[tuple[str, FileTypes]] | None
TimeoutTypes: TypeAlias = float | Timeout | None TimeoutTypes: TypeAlias = float | Timeout | None
PingIntervalTypes: TypeAlias = float | None
class HTTPVersion(Enum): class HTTPVersion(Enum):
@@ -76,6 +78,7 @@ class Request:
version: str | HTTPVersion = HTTPVersion.H11, version: str | HTTPVersion = HTTPVersion.H11,
timeout: TimeoutTypes | UnsetType = UNSET, timeout: TimeoutTypes | UnsetType = UNSET,
proxy: str | None = None, proxy: str | None = None,
ping_interval: PingIntervalTypes | UnsetType = UNSET,
): ):
# method # method
self.method: str = ( self.method: str = (
@@ -89,6 +92,8 @@ class Request:
self.timeout: TimeoutTypes | UnsetType = timeout self.timeout: TimeoutTypes | UnsetType = timeout
# proxy # proxy
self.proxy: str | None = proxy self.proxy: str | None = proxy
# ping interval
self.ping_interval: PingIntervalTypes | UnsetType = ping_interval
# url # url
if isinstance(url, tuple): if isinstance(url, tuple):

View File

@@ -878,6 +878,92 @@ async def test_websocket_client_timeout(driver: Driver, server_url: URL):
await anyio.sleep(1) await anyio.sleep(1)
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver",
[
pytest.param("nonebot.drivers.websockets:Driver", id="websockets"),
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
],
indirect=True,
)
async def test_websocket_client_ping_timeout(driver: Driver, server_url: URL):
"""WebSocket connections work with different ping_timeout settings."""
assert isinstance(driver, WebSocketClientMixin)
ws_url = server_url.with_scheme("ws")
# ping timeout not set (UNSET), falls back to DEFAULT_TIMEOUT.ping
request = Request("GET", ws_url, timeout=Timeout())
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
# ping timeout explicitly set to None (disable ping timeout)
request = Request("GET", ws_url, timeout=Timeout(ping=None))
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
# ping timeout set to a float value
request = Request("GET", ws_url, timeout=Timeout(ping=20.0))
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver",
[
pytest.param("nonebot.drivers.websockets:Driver", id="websockets"),
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
],
indirect=True,
)
async def test_websocket_client_ping_interval(driver: Driver, server_url: URL):
"""WebSocket connections work with different ping_interval settings."""
assert isinstance(driver, WebSocketClientMixin)
ws_url = server_url.with_scheme("ws")
# ping_interval not set (UNSET), default behavior
request = Request("GET", ws_url)
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
# ping_interval explicitly set to None (disable ping)
request = Request("GET", ws_url, ping_interval=None)
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
# ping_interval set to a float value
request = Request("GET", ws_url, ping_interval=20.0)
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("driver", "driver_type"), ("driver", "driver_type"),
[ [