🐛 Fix: 修正 http/websocket client timeout (#3923)

Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
StarHeart
2026-03-31 20:56:12 +08:00
committed by GitHub
parent cf8127ee4d
commit cbe6eee868
10 changed files with 443 additions and 72 deletions

View File

@@ -9,6 +9,7 @@ FrontMatter:
description: nonebot.drivers 模块
"""
from nonebot.internal.driver import DEFAULT_TIMEOUT as DEFAULT_TIMEOUT
from nonebot.internal.driver import URL as URL
from nonebot.internal.driver import ASGIMixin as ASGIMixin
from nonebot.internal.driver import Cookies as Cookies
@@ -31,6 +32,7 @@ from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup
from nonebot.internal.driver import combine_driver as combine_driver
__autodoc__ = {
"DEFAULT_TIMEOUT": True,
"URL": True,
"Cookies": True,
"Request": True,

View File

@@ -38,6 +38,7 @@ from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.exception import WebSocketClosed
from nonebot.internal.driver import (
DEFAULT_TIMEOUT,
Cookies,
CookieTypes,
HeaderTypes,
@@ -45,6 +46,7 @@ from nonebot.internal.driver import (
Timeout,
TimeoutTypes,
)
from nonebot.utils import UNSET, UnsetType, exclude_unset
try:
import aiohttp
@@ -63,7 +65,7 @@ class Session(HTTPClientSession):
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: str | HTTPVersion = HTTPVersion.H11,
timeout: TimeoutTypes = None,
timeout: TimeoutTypes | UnsetType = UNSET,
proxy: str | None = None,
):
self._client: aiohttp.ClientSession | None = None
@@ -85,15 +87,32 @@ class Session(HTTPClientSession):
else:
raise RuntimeError(f"Unsupported HTTP version: {version}")
_timeout = None
if isinstance(timeout, Timeout):
self._timeout = aiohttp.ClientTimeout(
total=timeout.total,
connect=timeout.connect,
sock_read=timeout.read,
timeout_kwargs: dict[str, float | None] = exclude_unset(
{
"total": timeout.total,
"connect": timeout.connect,
"sock_read": timeout.read,
}
)
else:
self._timeout = aiohttp.ClientTimeout(timeout)
if timeout_kwargs:
_timeout = aiohttp.ClientTimeout(**timeout_kwargs) # type: ignore
elif timeout is not UNSET:
_timeout = aiohttp.ClientTimeout(connect=timeout, sock_read=timeout)
if _timeout is None:
_timeout = aiohttp.ClientTimeout(
**exclude_unset(
{
"total": DEFAULT_TIMEOUT.total,
"connect": DEFAULT_TIMEOUT.connect,
"sock_read": DEFAULT_TIMEOUT.read,
}
)
)
self._timeout = _timeout
self._proxy = proxy
@property
@@ -102,6 +121,25 @@ class Session(HTTPClientSession):
raise RuntimeError("Session is not initialized")
return self._client
def _get_timeout(self, timeout: TimeoutTypes | UnsetType) -> aiohttp.ClientTimeout:
_timeout = None
if isinstance(timeout, Timeout):
timeout_kwargs: dict[str, float | None] = exclude_unset(
{
"total": timeout.total,
"connect": timeout.connect,
"sock_read": timeout.read,
}
)
if timeout_kwargs:
_timeout = aiohttp.ClientTimeout(**timeout_kwargs) # type: ignore
elif timeout is not UNSET:
_timeout = aiohttp.ClientTimeout(connect=timeout, sock_read=timeout)
if _timeout is None:
return self._timeout
return _timeout
@override
async def request(self, setup: Request) -> Response:
if self._params:
@@ -121,15 +159,6 @@ class Session(HTTPClientSession):
if cookie.value is not None
)
if isinstance(setup.timeout, Timeout):
timeout = aiohttp.ClientTimeout(
total=setup.timeout.total,
connect=setup.timeout.connect,
sock_read=setup.timeout.read,
)
else:
timeout = aiohttp.ClientTimeout(setup.timeout)
async with await self.client.request(
setup.method,
url,
@@ -138,7 +167,7 @@ class Session(HTTPClientSession):
cookies=cookies,
headers=setup.headers,
proxy=setup.proxy or self._proxy,
timeout=timeout,
timeout=self._get_timeout(setup.timeout),
) as response:
return Response(
response.status,
@@ -171,15 +200,6 @@ class Session(HTTPClientSession):
if cookie.value is not None
)
if isinstance(setup.timeout, Timeout):
timeout = aiohttp.ClientTimeout(
total=setup.timeout.total,
connect=setup.timeout.connect,
sock_read=setup.timeout.read,
)
else:
timeout = aiohttp.ClientTimeout(setup.timeout)
async with self.client.request(
setup.method,
url,
@@ -188,7 +208,7 @@ class Session(HTTPClientSession):
cookies=cookies,
headers=setup.headers,
proxy=setup.proxy or self._proxy,
timeout=timeout,
timeout=self._get_timeout(setup.timeout),
) as response:
response_headers = response.headers.copy()
# aiohttp does not guarantee fixed-size chunks; re-chunk to exact size
@@ -270,13 +290,39 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
else:
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
timeout = None
if isinstance(setup.timeout, Timeout):
timeout = aiohttp.ClientWSTimeout(
ws_receive=setup.timeout.read, # type: ignore
ws_close=setup.timeout.total, # type: ignore
timeout_kwargs: dict[str, float | None] = exclude_unset(
{
"ws_receive": setup.timeout.read,
"ws_close": (
setup.timeout.total
if setup.timeout.close is UNSET
else setup.timeout.close
),
}
)
if timeout_kwargs:
timeout = aiohttp.ClientWSTimeout(**timeout_kwargs)
elif setup.timeout is not UNSET:
timeout = aiohttp.ClientWSTimeout(
ws_receive=setup.timeout, # type: ignore
ws_close=setup.timeout, # type: ignore
)
if timeout is None:
timeout = aiohttp.ClientWSTimeout(
**exclude_unset(
{
"ws_receive": DEFAULT_TIMEOUT.read,
"ws_close": (
DEFAULT_TIMEOUT.total
if DEFAULT_TIMEOUT.close is UNSET
else DEFAULT_TIMEOUT.close
),
}
)
)
else:
timeout = aiohttp.ClientWSTimeout(ws_close=setup.timeout or 10.0) # type: ignore
async with aiohttp.ClientSession(version=version, trust_env=True) as session:
async with session.ws_connect(
@@ -295,7 +341,7 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: str | HTTPVersion = HTTPVersion.H11,
timeout: TimeoutTypes = None,
timeout: TimeoutTypes | UnsetType = UNSET,
proxy: str | None = None,
) -> Session:
return Session(

View File

@@ -34,6 +34,7 @@ from nonebot.drivers import (
)
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.internal.driver import (
DEFAULT_TIMEOUT,
Cookies,
CookieTypes,
HeaderTypes,
@@ -41,6 +42,7 @@ from nonebot.internal.driver import (
Timeout,
TimeoutTypes,
)
from nonebot.utils import UNSET, UnsetType, exclude_unset
try:
import httpx
@@ -59,7 +61,7 @@ class Session(HTTPClientSession):
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: str | HTTPVersion = HTTPVersion.H11,
timeout: TimeoutTypes = None,
timeout: TimeoutTypes | UnsetType = UNSET,
proxy: str | None = None,
):
self._client: httpx.AsyncClient | None = None
@@ -73,15 +75,34 @@ class Session(HTTPClientSession):
self._cookies = Cookies(cookies)
self._version = HTTPVersion(version)
_timeout = None
if isinstance(timeout, Timeout):
self._timeout = httpx.Timeout(
timeout=timeout.total,
connect=timeout.connect,
read=timeout.read,
avg_timeout = timeout.total and timeout.total / 4
timeout_kwargs: dict[str, float | None] = exclude_unset(
{
"timeout": avg_timeout,
"connect": timeout.connect,
"read": timeout.read,
}
)
else:
self._timeout = httpx.Timeout(timeout)
if timeout_kwargs:
_timeout = httpx.Timeout(**timeout_kwargs)
elif timeout is not UNSET:
_timeout = httpx.Timeout(timeout)
if _timeout is None:
avg_timeout = DEFAULT_TIMEOUT.total and DEFAULT_TIMEOUT.total / 4
_timeout = httpx.Timeout(
**exclude_unset(
{
"timeout": avg_timeout,
"connect": DEFAULT_TIMEOUT.connect,
"read": DEFAULT_TIMEOUT.read,
}
)
)
self._timeout = _timeout
self._proxy = proxy
@property
@@ -90,17 +111,28 @@ class Session(HTTPClientSession):
raise RuntimeError("Session is not initialized")
return self._client
def _get_timeout(self, timeout: TimeoutTypes | UnsetType) -> httpx.Timeout:
_timeout = None
if isinstance(timeout, Timeout):
avg_timeout = timeout.total and timeout.total / 4
timeout_kwargs: dict[str, float | None] = exclude_unset(
{
"timeout": avg_timeout,
"connect": timeout.connect,
"read": timeout.read,
}
)
if timeout_kwargs:
_timeout = httpx.Timeout(**timeout_kwargs)
elif timeout is not UNSET:
_timeout = httpx.Timeout(timeout)
if _timeout is None:
return self._timeout
return _timeout
@override
async def request(self, setup: Request) -> Response:
if isinstance(setup.timeout, Timeout):
timeout = httpx.Timeout(
timeout=setup.timeout.total,
connect=setup.timeout.connect,
read=setup.timeout.read,
)
else:
timeout = httpx.Timeout(setup.timeout)
response = await self.client.request(
setup.method,
str(setup.url),
@@ -112,7 +144,7 @@ class Session(HTTPClientSession):
params=setup.url.raw_query_string,
headers=tuple(setup.headers.items()),
cookies=setup.cookies.jar,
timeout=timeout,
timeout=self._get_timeout(setup.timeout),
)
return Response(
response.status_code,
@@ -128,15 +160,6 @@ class Session(HTTPClientSession):
*,
chunk_size: int = 1024,
) -> AsyncGenerator[Response, None]:
if isinstance(setup.timeout, Timeout):
timeout = httpx.Timeout(
timeout=setup.timeout.total,
connect=setup.timeout.connect,
read=setup.timeout.read,
)
else:
timeout = httpx.Timeout(setup.timeout)
async with self.client.stream(
setup.method,
str(setup.url),
@@ -148,7 +171,7 @@ class Session(HTTPClientSession):
params=setup.url.raw_query_string,
headers=tuple(setup.headers.items()),
cookies=setup.cookies.jar,
timeout=timeout,
timeout=self._get_timeout(setup.timeout),
) as response:
response_headers = response.headers.multi_items()
async for chunk in response.aiter_bytes(chunk_size=chunk_size):

View File

@@ -25,11 +25,18 @@ from types import CoroutineType
from typing import TYPE_CHECKING, Any, TypeVar
from typing_extensions import ParamSpec, override
from nonebot.drivers import Request, Timeout, WebSocketClientMixin, combine_driver
from nonebot.drivers import (
DEFAULT_TIMEOUT,
Request,
Timeout,
WebSocketClientMixin,
combine_driver,
)
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.exception import WebSocketClosed
from nonebot.log import LoguruHandler
from nonebot.utils import UNSET, exclude_unset
try:
from websockets import ClientConnection, ConnectionClosed, connect
@@ -70,16 +77,36 @@ class Mixin(WebSocketClientMixin):
@override
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
timeout_kwargs: dict[str, float | None] = {}
if isinstance(setup.timeout, Timeout):
timeout = setup.timeout.total or setup.timeout.connect or setup.timeout.read
else:
timeout = setup.timeout
open_timeout = (
setup.timeout.connect or setup.timeout.read or setup.timeout.total
)
timeout_kwargs = exclude_unset(
{"open_timeout": open_timeout, "close_timeout": setup.timeout.close}
)
elif setup.timeout is not UNSET:
timeout_kwargs = {
"open_timeout": setup.timeout,
"close_timeout": setup.timeout,
}
if not timeout_kwargs:
open_timeout = (
DEFAULT_TIMEOUT.connect or DEFAULT_TIMEOUT.read or DEFAULT_TIMEOUT.total
)
timeout_kwargs = exclude_unset(
{
"open_timeout": open_timeout,
"close_timeout": DEFAULT_TIMEOUT.close,
}
)
connection = connect(
str(setup.url),
additional_headers={**setup.headers, **setup.cookies.as_header(setup)},
proxy=setup.proxy if setup.proxy is not None else True,
open_timeout=timeout,
**timeout_kwargs, # type: ignore
)
async with connection as ws:
yield WebSocket(request=setup, websocket=ws)

View File

@@ -9,6 +9,7 @@ from .abstract import ReverseDriver as ReverseDriver
from .abstract import ReverseMixin as ReverseMixin
from .abstract import WebSocketClientMixin as WebSocketClientMixin
from .combine import combine_driver as combine_driver
from .model import DEFAULT_TIMEOUT as DEFAULT_TIMEOUT
from .model import URL as URL
from .model import ContentTypes as ContentTypes
from .model import Cookies as Cookies

View File

@@ -19,7 +19,13 @@ from nonebot.typing import (
T_BotDisconnectionHook,
T_DependencyCache,
)
from nonebot.utils import escape_tag, flatten_exception_group, run_coro_with_catch
from nonebot.utils import (
UNSET,
UnsetType,
escape_tag,
flatten_exception_group,
run_coro_with_catch,
)
from ._lifespan import LIFESPAN_FUNC, Lifespan
from .model import (
@@ -246,7 +252,7 @@ class HTTPClientSession(abc.ABC):
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: str | HTTPVersion = HTTPVersion.H11,
timeout: TimeoutTypes = None,
timeout: TimeoutTypes | UnsetType = UNSET,
proxy: str | None = None,
):
raise NotImplementedError

View File

@@ -9,14 +9,20 @@ import urllib.request
from multidict import CIMultiDict
from yarl import URL as URL
from nonebot.utils import UNSET, UnsetType
@dataclass
class Timeout:
"""Request 超时配置。"""
total: float | None = None
connect: float | None = None
read: float | None = None
total: float | None | UnsetType = UNSET
connect: float | None | UnsetType = UNSET
read: float | None | UnsetType = UNSET
close: float | None | UnsetType = UNSET
DEFAULT_TIMEOUT = Timeout(total=None, connect=5.0, read=30.0, close=10.0)
RawURL: TypeAlias = tuple[bytes, bytes, int | None, bytes]
@@ -68,7 +74,7 @@ class Request:
json: Any = None,
files: FilesTypes = None,
version: str | HTTPVersion = HTTPVersion.H11,
timeout: TimeoutTypes = None,
timeout: TimeoutTypes | UnsetType = UNSET,
proxy: str | None = None,
):
# method
@@ -80,7 +86,7 @@ class Request:
# http version
self.version: HTTPVersion = HTTPVersion(version)
# timeout
self.timeout: TimeoutTypes = timeout
self.timeout: TimeoutTypes | UnsetType = timeout
# proxy
self.proxy: str | None = proxy

View File

@@ -19,6 +19,7 @@ from collections.abc import (
import contextlib
from contextlib import AbstractContextManager, asynccontextmanager
import dataclasses
from enum import Enum
from functools import partial, wraps
import importlib
import inspect
@@ -27,8 +28,12 @@ from pathlib import Path
import re
from typing import (
Any,
Final,
Generic,
Literal,
TypeAlias,
TypeVar,
final,
get_args,
get_origin,
overload,
@@ -49,6 +54,8 @@ from nonebot.typing import (
type_has_args,
)
from .compat import custom_validation
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
@@ -57,6 +64,54 @@ V = TypeVar("V")
E = TypeVar("E", bound=BaseException)
@final
@custom_validation
class Unset(Enum):
_UNSET = "<UNSET>"
def __repr__(self) -> str:
return "<UNSET>"
def __str__(self) -> str:
return self.__repr__()
def __bool__(self) -> Literal[False]:
return False
def __copy__(self):
return self._UNSET
def __deepcopy__(self, memo: dict[int, Any]):
return self._UNSET
@classmethod
def __get_validators__(cls):
yield cls._validate
@classmethod
def _validate(cls, value: Any):
if value is not cls._UNSET:
raise ValueError(f"{value!r} is not UNSET")
return value
UnsetType: TypeAlias = Literal[Unset._UNSET]
UNSET: Final[UnsetType] = Unset._UNSET
def exclude_unset(data: Any) -> Any:
if isinstance(data, dict):
return data.__class__(
(k, exclude_unset(v)) for k, v in data.items() if v is not UNSET
)
elif isinstance(data, list):
return data.__class__(exclude_unset(i) for i in data if i is not UNSET)
elif data is UNSET:
return None
return data
def escape_tag(s: str) -> str:
"""用于记录带颜色日志时转义 `<tag>` 类型特殊标签

View File

@@ -26,6 +26,7 @@ from nonebot.drivers.aiohttp import Session as AiohttpSession
from nonebot.drivers.aiohttp import WebSocket as AiohttpWebSocket
from nonebot.exception import WebSocketClosed
from nonebot.params import Depends
from nonebot.utils import UNSET
from utils import FakeAdapter
@@ -706,6 +707,177 @@ async def test_aiohttp_websocket_close_frame(msg_type: str) -> None:
await ws.receive()
def test_timeout_unset_vs_none():
# default: all fields are UNSET
t = Timeout()
assert t.total is UNSET
assert t.connect is UNSET
assert t.read is UNSET
assert t.close is UNSET
# explicitly set to None
t = Timeout(close=None)
assert t.close is None
assert t.close is not UNSET
# explicitly set to a value
t = Timeout(total=5.0, close=None)
assert t.total == 5.0
assert t.close is None
assert t.connect is UNSET
assert t.read is UNSET
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver",
[
pytest.param("nonebot.drivers.httpx:Driver", id="httpx"),
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
],
indirect=True,
)
async def test_http_client_timeout(driver: Driver, server_url: URL):
"""HTTP requests work with fully unset, partial, and None timeout fields."""
assert isinstance(driver, HTTPClientMixin)
# timeout not set, default timeout should apply
request = Request("POST", server_url, content="test")
response = await driver.request(request)
assert response.status_code == 200
async for resp in driver.stream_request(request, chunk_size=1024):
assert resp.status_code == 200
# timeout is float or none
request = Request("POST", server_url, content="test", timeout=10.0)
response = await driver.request(request)
assert response.status_code == 200
async for resp in driver.stream_request(request, chunk_size=1024):
assert resp.status_code == 200
# all fields unset, default timeout should apply
request = Request("POST", server_url, content="test", timeout=Timeout())
response = await driver.request(request)
assert response.status_code == 200
async for resp in driver.stream_request(request, chunk_size=1024):
assert resp.status_code == 200
# only total set
request = Request("POST", server_url, content="test", timeout=Timeout(total=10.0))
response = await driver.request(request)
assert response.status_code == 200
async for resp in driver.stream_request(request, chunk_size=1024):
assert resp.status_code == 200
# explicit None (no timeout)
request = Request(
"POST",
server_url,
content="test",
timeout=Timeout(total=None, connect=None, read=None),
)
response = await driver.request(request)
assert response.status_code == 200
async for resp in driver.stream_request(request, chunk_size=1024):
assert resp.status_code == 200
# session with timeout not set
session = driver.get_session()
async with session:
request = Request("POST", server_url, content="test")
response = await session.request(request)
assert response.status_code == 200
# session with float or none timeout
session = driver.get_session(timeout=10.0)
async with session:
request = Request("POST", server_url, content="test")
response = await session.request(request)
assert response.status_code == 200
# session with fully unset timeout
session = driver.get_session(timeout=Timeout())
async with session:
request = Request("POST", server_url, content="test")
response = await session.request(request)
assert response.status_code == 200
# session with timeout
session = driver.get_session(timeout=Timeout(total=10.0, connect=5.0, read=5.0))
async with session:
request = Request("POST", server_url, content="test")
response = await session.request(request)
assert response.status_code == 200
# session with timeout override
session = driver.get_session(timeout=Timeout(total=10.0))
async with session:
request = Request(
"POST", server_url, content="test", timeout=Timeout(total=20.0)
)
response = await session.request(request)
assert response.status_code == 200
# session with timeout float override
session = driver.get_session(timeout=Timeout(total=10.0))
async with session:
request = Request("POST", server_url, content="test", timeout=20.0)
response = await session.request(request)
assert response.status_code == 200
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver",
[
pytest.param("nonebot.drivers.websockets:Driver", id="websockets"),
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
],
indirect=True,
)
async def test_websocket_client_timeout(driver: Driver, server_url: URL):
"""WebSocket connections work with fully unset, partial, and None timeout fields."""
assert isinstance(driver, WebSocketClientMixin)
ws_url = server_url.with_scheme("ws")
# timeout not set, default timeout should apply
request = Request("GET", ws_url)
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
# timeout is float or none
request = Request("GET", ws_url, timeout=10.0)
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
# all fields unset, default timeout should apply
request = Request("GET", ws_url, timeout=Timeout())
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
# close explicitly set to None (no close timeout)
request = Request("GET", ws_url, timeout=Timeout(close=None))
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
@pytest.mark.parametrize(
("driver", "driver_type"),
[

View File

@@ -1,9 +1,19 @@
import copy
import json
import pickle
from typing import ClassVar, Dict, List, Literal, TypeVar, Union # noqa: UP035
from pydantic import ValidationError
import pytest
from nonebot.compat import type_validate_python
from nonebot.utils import (
UNSET,
DataclassEncoder,
Unset,
UnsetType,
escape_tag,
exclude_unset,
generic_check_issubclass,
is_async_gen_callable,
is_coroutine_callable,
@@ -12,6 +22,29 @@ from nonebot.utils import (
from utils import FakeMessage, FakeMessageSegment
def test_unset():
assert isinstance(UNSET, Unset)
assert bool(UNSET) is False
assert copy.copy(UNSET) is UNSET
assert copy.deepcopy(UNSET) is UNSET
assert pickle.loads(pickle.dumps(UNSET)) is UNSET
assert type_validate_python(UnsetType, UNSET) is UNSET
with pytest.raises(ValidationError):
type_validate_python(UnsetType, 123)
def test_exclude_unset():
assert exclude_unset({"a": 1, "b": UNSET, "c": None, "d": {"x": UNSET}}) == {
"a": 1,
"c": None,
"d": {},
}
assert exclude_unset([1, UNSET, None, {"x": UNSET}]) == [1, None, {}]
assert exclude_unset(UNSET) is None
assert exclude_unset(123) == 123
def test_loguru_escape_tag():
assert escape_tag("<red>red</red>") == r"\<red>red\</red>"
assert escape_tag("<fg #fff>white</fg #fff>") == r"\<fg #fff>white\</fg #fff>"