🐛 Fix: 修正 http/websocket client timeout (#3923)

Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
StarHeart
2026-03-31 20:56:12 +08:00
committed by GitHub
parent cf8127ee4d
commit cbe6eee868
10 changed files with 443 additions and 72 deletions

View File

@@ -26,6 +26,7 @@ from nonebot.drivers.aiohttp import Session as AiohttpSession
from nonebot.drivers.aiohttp import WebSocket as AiohttpWebSocket
from nonebot.exception import WebSocketClosed
from nonebot.params import Depends
from nonebot.utils import UNSET
from utils import FakeAdapter
@@ -706,6 +707,177 @@ async def test_aiohttp_websocket_close_frame(msg_type: str) -> None:
await ws.receive()
def test_timeout_unset_vs_none():
# default: all fields are UNSET
t = Timeout()
assert t.total is UNSET
assert t.connect is UNSET
assert t.read is UNSET
assert t.close is UNSET
# explicitly set to None
t = Timeout(close=None)
assert t.close is None
assert t.close is not UNSET
# explicitly set to a value
t = Timeout(total=5.0, close=None)
assert t.total == 5.0
assert t.close is None
assert t.connect is UNSET
assert t.read is UNSET
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver",
[
pytest.param("nonebot.drivers.httpx:Driver", id="httpx"),
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
],
indirect=True,
)
async def test_http_client_timeout(driver: Driver, server_url: URL):
"""HTTP requests work with fully unset, partial, and None timeout fields."""
assert isinstance(driver, HTTPClientMixin)
# timeout not set, default timeout should apply
request = Request("POST", server_url, content="test")
response = await driver.request(request)
assert response.status_code == 200
async for resp in driver.stream_request(request, chunk_size=1024):
assert resp.status_code == 200
# timeout is float or none
request = Request("POST", server_url, content="test", timeout=10.0)
response = await driver.request(request)
assert response.status_code == 200
async for resp in driver.stream_request(request, chunk_size=1024):
assert resp.status_code == 200
# all fields unset, default timeout should apply
request = Request("POST", server_url, content="test", timeout=Timeout())
response = await driver.request(request)
assert response.status_code == 200
async for resp in driver.stream_request(request, chunk_size=1024):
assert resp.status_code == 200
# only total set
request = Request("POST", server_url, content="test", timeout=Timeout(total=10.0))
response = await driver.request(request)
assert response.status_code == 200
async for resp in driver.stream_request(request, chunk_size=1024):
assert resp.status_code == 200
# explicit None (no timeout)
request = Request(
"POST",
server_url,
content="test",
timeout=Timeout(total=None, connect=None, read=None),
)
response = await driver.request(request)
assert response.status_code == 200
async for resp in driver.stream_request(request, chunk_size=1024):
assert resp.status_code == 200
# session with timeout not set
session = driver.get_session()
async with session:
request = Request("POST", server_url, content="test")
response = await session.request(request)
assert response.status_code == 200
# session with float or none timeout
session = driver.get_session(timeout=10.0)
async with session:
request = Request("POST", server_url, content="test")
response = await session.request(request)
assert response.status_code == 200
# session with fully unset timeout
session = driver.get_session(timeout=Timeout())
async with session:
request = Request("POST", server_url, content="test")
response = await session.request(request)
assert response.status_code == 200
# session with timeout
session = driver.get_session(timeout=Timeout(total=10.0, connect=5.0, read=5.0))
async with session:
request = Request("POST", server_url, content="test")
response = await session.request(request)
assert response.status_code == 200
# session with timeout override
session = driver.get_session(timeout=Timeout(total=10.0))
async with session:
request = Request(
"POST", server_url, content="test", timeout=Timeout(total=20.0)
)
response = await session.request(request)
assert response.status_code == 200
# session with timeout float override
session = driver.get_session(timeout=Timeout(total=10.0))
async with session:
request = Request("POST", server_url, content="test", timeout=20.0)
response = await session.request(request)
assert response.status_code == 200
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver",
[
pytest.param("nonebot.drivers.websockets:Driver", id="websockets"),
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
],
indirect=True,
)
async def test_websocket_client_timeout(driver: Driver, server_url: URL):
"""WebSocket connections work with fully unset, partial, and None timeout fields."""
assert isinstance(driver, WebSocketClientMixin)
ws_url = server_url.with_scheme("ws")
# timeout not set, default timeout should apply
request = Request("GET", ws_url)
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
# timeout is float or none
request = Request("GET", ws_url, timeout=10.0)
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
# all fields unset, default timeout should apply
request = Request("GET", ws_url, timeout=Timeout())
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
# close explicitly set to None (no close timeout)
request = Request("GET", ws_url, timeout=Timeout(close=None))
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
@pytest.mark.parametrize(
("driver", "driver_type"),
[

View File

@@ -1,9 +1,19 @@
import copy
import json
import pickle
from typing import ClassVar, Dict, List, Literal, TypeVar, Union # noqa: UP035
from pydantic import ValidationError
import pytest
from nonebot.compat import type_validate_python
from nonebot.utils import (
UNSET,
DataclassEncoder,
Unset,
UnsetType,
escape_tag,
exclude_unset,
generic_check_issubclass,
is_async_gen_callable,
is_coroutine_callable,
@@ -12,6 +22,29 @@ from nonebot.utils import (
from utils import FakeMessage, FakeMessageSegment
def test_unset():
assert isinstance(UNSET, Unset)
assert bool(UNSET) is False
assert copy.copy(UNSET) is UNSET
assert copy.deepcopy(UNSET) is UNSET
assert pickle.loads(pickle.dumps(UNSET)) is UNSET
assert type_validate_python(UnsetType, UNSET) is UNSET
with pytest.raises(ValidationError):
type_validate_python(UnsetType, 123)
def test_exclude_unset():
assert exclude_unset({"a": 1, "b": UNSET, "c": None, "d": {"x": UNSET}}) == {
"a": 1,
"c": None,
"d": {},
}
assert exclude_unset([1, UNSET, None, {"x": UNSET}]) == [1, None, {}]
assert exclude_unset(UNSET) is None
assert exclude_unset(123) == 123
def test_loguru_escape_tag():
assert escape_tag("<red>red</red>") == r"\<red>red\</red>"
assert escape_tag("<fg #fff>white</fg #fff>") == r"\<fg #fff>white\</fg #fff>"