🐛 Fix: 修正 http/websocket client timeout (#3923)

Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
StarHeart
2026-03-31 20:56:12 +08:00
committed by GitHub
parent cf8127ee4d
commit cbe6eee868
10 changed files with 443 additions and 72 deletions

View File

@@ -9,6 +9,7 @@ FrontMatter:
description: nonebot.drivers 模块 description: nonebot.drivers 模块
""" """
from nonebot.internal.driver import DEFAULT_TIMEOUT as DEFAULT_TIMEOUT
from nonebot.internal.driver import URL as URL from nonebot.internal.driver import URL as URL
from nonebot.internal.driver import ASGIMixin as ASGIMixin from nonebot.internal.driver import ASGIMixin as ASGIMixin
from nonebot.internal.driver import Cookies as Cookies from nonebot.internal.driver import Cookies as Cookies
@@ -31,6 +32,7 @@ from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup
from nonebot.internal.driver import combine_driver as combine_driver from nonebot.internal.driver import combine_driver as combine_driver
__autodoc__ = { __autodoc__ = {
"DEFAULT_TIMEOUT": True,
"URL": True, "URL": True,
"Cookies": True, "Cookies": True,
"Request": True, "Request": True,

View File

@@ -38,6 +38,7 @@ from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.none import Driver as NoneDriver from nonebot.drivers.none import Driver as NoneDriver
from nonebot.exception import WebSocketClosed from nonebot.exception import WebSocketClosed
from nonebot.internal.driver import ( from nonebot.internal.driver import (
DEFAULT_TIMEOUT,
Cookies, Cookies,
CookieTypes, CookieTypes,
HeaderTypes, HeaderTypes,
@@ -45,6 +46,7 @@ from nonebot.internal.driver import (
Timeout, Timeout,
TimeoutTypes, TimeoutTypes,
) )
from nonebot.utils import UNSET, UnsetType, exclude_unset
try: try:
import aiohttp import aiohttp
@@ -63,7 +65,7 @@ class Session(HTTPClientSession):
headers: HeaderTypes = None, headers: HeaderTypes = None,
cookies: CookieTypes = None, cookies: CookieTypes = None,
version: str | HTTPVersion = HTTPVersion.H11, version: str | HTTPVersion = HTTPVersion.H11,
timeout: TimeoutTypes = None, timeout: TimeoutTypes | UnsetType = UNSET,
proxy: str | None = None, proxy: str | None = None,
): ):
self._client: aiohttp.ClientSession | None = None self._client: aiohttp.ClientSession | None = None
@@ -85,15 +87,32 @@ class Session(HTTPClientSession):
else: else:
raise RuntimeError(f"Unsupported HTTP version: {version}") raise RuntimeError(f"Unsupported HTTP version: {version}")
_timeout = None
if isinstance(timeout, Timeout): if isinstance(timeout, Timeout):
self._timeout = aiohttp.ClientTimeout( timeout_kwargs: dict[str, float | None] = exclude_unset(
total=timeout.total, {
connect=timeout.connect, "total": timeout.total,
sock_read=timeout.read, "connect": timeout.connect,
"sock_read": timeout.read,
}
) )
else: if timeout_kwargs:
self._timeout = aiohttp.ClientTimeout(timeout) _timeout = aiohttp.ClientTimeout(**timeout_kwargs) # type: ignore
elif timeout is not UNSET:
_timeout = aiohttp.ClientTimeout(connect=timeout, sock_read=timeout)
if _timeout is None:
_timeout = aiohttp.ClientTimeout(
**exclude_unset(
{
"total": DEFAULT_TIMEOUT.total,
"connect": DEFAULT_TIMEOUT.connect,
"sock_read": DEFAULT_TIMEOUT.read,
}
)
)
self._timeout = _timeout
self._proxy = proxy self._proxy = proxy
@property @property
@@ -102,6 +121,25 @@ class Session(HTTPClientSession):
raise RuntimeError("Session is not initialized") raise RuntimeError("Session is not initialized")
return self._client return self._client
def _get_timeout(self, timeout: TimeoutTypes | UnsetType) -> aiohttp.ClientTimeout:
_timeout = None
if isinstance(timeout, Timeout):
timeout_kwargs: dict[str, float | None] = exclude_unset(
{
"total": timeout.total,
"connect": timeout.connect,
"sock_read": timeout.read,
}
)
if timeout_kwargs:
_timeout = aiohttp.ClientTimeout(**timeout_kwargs) # type: ignore
elif timeout is not UNSET:
_timeout = aiohttp.ClientTimeout(connect=timeout, sock_read=timeout)
if _timeout is None:
return self._timeout
return _timeout
@override @override
async def request(self, setup: Request) -> Response: async def request(self, setup: Request) -> Response:
if self._params: if self._params:
@@ -121,15 +159,6 @@ class Session(HTTPClientSession):
if cookie.value is not None if cookie.value is not None
) )
if isinstance(setup.timeout, Timeout):
timeout = aiohttp.ClientTimeout(
total=setup.timeout.total,
connect=setup.timeout.connect,
sock_read=setup.timeout.read,
)
else:
timeout = aiohttp.ClientTimeout(setup.timeout)
async with await self.client.request( async with await self.client.request(
setup.method, setup.method,
url, url,
@@ -138,7 +167,7 @@ class Session(HTTPClientSession):
cookies=cookies, cookies=cookies,
headers=setup.headers, headers=setup.headers,
proxy=setup.proxy or self._proxy, proxy=setup.proxy or self._proxy,
timeout=timeout, timeout=self._get_timeout(setup.timeout),
) as response: ) as response:
return Response( return Response(
response.status, response.status,
@@ -171,15 +200,6 @@ class Session(HTTPClientSession):
if cookie.value is not None if cookie.value is not None
) )
if isinstance(setup.timeout, Timeout):
timeout = aiohttp.ClientTimeout(
total=setup.timeout.total,
connect=setup.timeout.connect,
sock_read=setup.timeout.read,
)
else:
timeout = aiohttp.ClientTimeout(setup.timeout)
async with self.client.request( async with self.client.request(
setup.method, setup.method,
url, url,
@@ -188,7 +208,7 @@ class Session(HTTPClientSession):
cookies=cookies, cookies=cookies,
headers=setup.headers, headers=setup.headers,
proxy=setup.proxy or self._proxy, proxy=setup.proxy or self._proxy,
timeout=timeout, timeout=self._get_timeout(setup.timeout),
) as response: ) as response:
response_headers = response.headers.copy() response_headers = response.headers.copy()
# aiohttp does not guarantee fixed-size chunks; re-chunk to exact size # aiohttp does not guarantee fixed-size chunks; re-chunk to exact size
@@ -270,13 +290,39 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
else: else:
raise RuntimeError(f"Unsupported HTTP version: {setup.version}") raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
timeout = None
if isinstance(setup.timeout, Timeout): if isinstance(setup.timeout, Timeout):
timeout = aiohttp.ClientWSTimeout( timeout_kwargs: dict[str, float | None] = exclude_unset(
ws_receive=setup.timeout.read, # type: ignore {
ws_close=setup.timeout.total, # type: ignore "ws_receive": setup.timeout.read,
"ws_close": (
setup.timeout.total
if setup.timeout.close is UNSET
else setup.timeout.close
),
}
)
if timeout_kwargs:
timeout = aiohttp.ClientWSTimeout(**timeout_kwargs)
elif setup.timeout is not UNSET:
timeout = aiohttp.ClientWSTimeout(
ws_receive=setup.timeout, # type: ignore
ws_close=setup.timeout, # type: ignore
)
if timeout is None:
timeout = aiohttp.ClientWSTimeout(
**exclude_unset(
{
"ws_receive": DEFAULT_TIMEOUT.read,
"ws_close": (
DEFAULT_TIMEOUT.total
if DEFAULT_TIMEOUT.close is UNSET
else DEFAULT_TIMEOUT.close
),
}
)
) )
else:
timeout = aiohttp.ClientWSTimeout(ws_close=setup.timeout or 10.0) # type: ignore
async with aiohttp.ClientSession(version=version, trust_env=True) as session: async with aiohttp.ClientSession(version=version, trust_env=True) as session:
async with session.ws_connect( async with session.ws_connect(
@@ -295,7 +341,7 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
headers: HeaderTypes = None, headers: HeaderTypes = None,
cookies: CookieTypes = None, cookies: CookieTypes = None,
version: str | HTTPVersion = HTTPVersion.H11, version: str | HTTPVersion = HTTPVersion.H11,
timeout: TimeoutTypes = None, timeout: TimeoutTypes | UnsetType = UNSET,
proxy: str | None = None, proxy: str | None = None,
) -> Session: ) -> Session:
return Session( return Session(

View File

@@ -34,6 +34,7 @@ from nonebot.drivers import (
) )
from nonebot.drivers.none import Driver as NoneDriver from nonebot.drivers.none import Driver as NoneDriver
from nonebot.internal.driver import ( from nonebot.internal.driver import (
DEFAULT_TIMEOUT,
Cookies, Cookies,
CookieTypes, CookieTypes,
HeaderTypes, HeaderTypes,
@@ -41,6 +42,7 @@ from nonebot.internal.driver import (
Timeout, Timeout,
TimeoutTypes, TimeoutTypes,
) )
from nonebot.utils import UNSET, UnsetType, exclude_unset
try: try:
import httpx import httpx
@@ -59,7 +61,7 @@ class Session(HTTPClientSession):
headers: HeaderTypes = None, headers: HeaderTypes = None,
cookies: CookieTypes = None, cookies: CookieTypes = None,
version: str | HTTPVersion = HTTPVersion.H11, version: str | HTTPVersion = HTTPVersion.H11,
timeout: TimeoutTypes = None, timeout: TimeoutTypes | UnsetType = UNSET,
proxy: str | None = None, proxy: str | None = None,
): ):
self._client: httpx.AsyncClient | None = None self._client: httpx.AsyncClient | None = None
@@ -73,15 +75,34 @@ class Session(HTTPClientSession):
self._cookies = Cookies(cookies) self._cookies = Cookies(cookies)
self._version = HTTPVersion(version) self._version = HTTPVersion(version)
_timeout = None
if isinstance(timeout, Timeout): if isinstance(timeout, Timeout):
self._timeout = httpx.Timeout( avg_timeout = timeout.total and timeout.total / 4
timeout=timeout.total, timeout_kwargs: dict[str, float | None] = exclude_unset(
connect=timeout.connect, {
read=timeout.read, "timeout": avg_timeout,
"connect": timeout.connect,
"read": timeout.read,
}
) )
else: if timeout_kwargs:
self._timeout = httpx.Timeout(timeout) _timeout = httpx.Timeout(**timeout_kwargs)
elif timeout is not UNSET:
_timeout = httpx.Timeout(timeout)
if _timeout is None:
avg_timeout = DEFAULT_TIMEOUT.total and DEFAULT_TIMEOUT.total / 4
_timeout = httpx.Timeout(
**exclude_unset(
{
"timeout": avg_timeout,
"connect": DEFAULT_TIMEOUT.connect,
"read": DEFAULT_TIMEOUT.read,
}
)
)
self._timeout = _timeout
self._proxy = proxy self._proxy = proxy
@property @property
@@ -90,17 +111,28 @@ class Session(HTTPClientSession):
raise RuntimeError("Session is not initialized") raise RuntimeError("Session is not initialized")
return self._client return self._client
def _get_timeout(self, timeout: TimeoutTypes | UnsetType) -> httpx.Timeout:
_timeout = None
if isinstance(timeout, Timeout):
avg_timeout = timeout.total and timeout.total / 4
timeout_kwargs: dict[str, float | None] = exclude_unset(
{
"timeout": avg_timeout,
"connect": timeout.connect,
"read": timeout.read,
}
)
if timeout_kwargs:
_timeout = httpx.Timeout(**timeout_kwargs)
elif timeout is not UNSET:
_timeout = httpx.Timeout(timeout)
if _timeout is None:
return self._timeout
return _timeout
@override @override
async def request(self, setup: Request) -> Response: async def request(self, setup: Request) -> Response:
if isinstance(setup.timeout, Timeout):
timeout = httpx.Timeout(
timeout=setup.timeout.total,
connect=setup.timeout.connect,
read=setup.timeout.read,
)
else:
timeout = httpx.Timeout(setup.timeout)
response = await self.client.request( response = await self.client.request(
setup.method, setup.method,
str(setup.url), str(setup.url),
@@ -112,7 +144,7 @@ class Session(HTTPClientSession):
params=setup.url.raw_query_string, params=setup.url.raw_query_string,
headers=tuple(setup.headers.items()), headers=tuple(setup.headers.items()),
cookies=setup.cookies.jar, cookies=setup.cookies.jar,
timeout=timeout, timeout=self._get_timeout(setup.timeout),
) )
return Response( return Response(
response.status_code, response.status_code,
@@ -128,15 +160,6 @@ class Session(HTTPClientSession):
*, *,
chunk_size: int = 1024, chunk_size: int = 1024,
) -> AsyncGenerator[Response, None]: ) -> AsyncGenerator[Response, None]:
if isinstance(setup.timeout, Timeout):
timeout = httpx.Timeout(
timeout=setup.timeout.total,
connect=setup.timeout.connect,
read=setup.timeout.read,
)
else:
timeout = httpx.Timeout(setup.timeout)
async with self.client.stream( async with self.client.stream(
setup.method, setup.method,
str(setup.url), str(setup.url),
@@ -148,7 +171,7 @@ class Session(HTTPClientSession):
params=setup.url.raw_query_string, params=setup.url.raw_query_string,
headers=tuple(setup.headers.items()), headers=tuple(setup.headers.items()),
cookies=setup.cookies.jar, cookies=setup.cookies.jar,
timeout=timeout, timeout=self._get_timeout(setup.timeout),
) as response: ) as response:
response_headers = response.headers.multi_items() response_headers = response.headers.multi_items()
async for chunk in response.aiter_bytes(chunk_size=chunk_size): async for chunk in response.aiter_bytes(chunk_size=chunk_size):

View File

@@ -25,11 +25,18 @@ from types import CoroutineType
from typing import TYPE_CHECKING, Any, TypeVar from typing import TYPE_CHECKING, Any, TypeVar
from typing_extensions import ParamSpec, override from typing_extensions import ParamSpec, override
from nonebot.drivers import Request, Timeout, WebSocketClientMixin, combine_driver from nonebot.drivers import (
DEFAULT_TIMEOUT,
Request,
Timeout,
WebSocketClientMixin,
combine_driver,
)
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.none import Driver as NoneDriver from nonebot.drivers.none import Driver as NoneDriver
from nonebot.exception import WebSocketClosed from nonebot.exception import WebSocketClosed
from nonebot.log import LoguruHandler from nonebot.log import LoguruHandler
from nonebot.utils import UNSET, exclude_unset
try: try:
from websockets import ClientConnection, ConnectionClosed, connect from websockets import ClientConnection, ConnectionClosed, connect
@@ -70,16 +77,36 @@ class Mixin(WebSocketClientMixin):
@override @override
@asynccontextmanager @asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]: async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
timeout_kwargs: dict[str, float | None] = {}
if isinstance(setup.timeout, Timeout): if isinstance(setup.timeout, Timeout):
timeout = setup.timeout.total or setup.timeout.connect or setup.timeout.read open_timeout = (
else: setup.timeout.connect or setup.timeout.read or setup.timeout.total
timeout = setup.timeout )
timeout_kwargs = exclude_unset(
{"open_timeout": open_timeout, "close_timeout": setup.timeout.close}
)
elif setup.timeout is not UNSET:
timeout_kwargs = {
"open_timeout": setup.timeout,
"close_timeout": setup.timeout,
}
if not timeout_kwargs:
open_timeout = (
DEFAULT_TIMEOUT.connect or DEFAULT_TIMEOUT.read or DEFAULT_TIMEOUT.total
)
timeout_kwargs = exclude_unset(
{
"open_timeout": open_timeout,
"close_timeout": DEFAULT_TIMEOUT.close,
}
)
connection = connect( connection = connect(
str(setup.url), str(setup.url),
additional_headers={**setup.headers, **setup.cookies.as_header(setup)}, additional_headers={**setup.headers, **setup.cookies.as_header(setup)},
proxy=setup.proxy if setup.proxy is not None else True, proxy=setup.proxy if setup.proxy is not None else True,
open_timeout=timeout, **timeout_kwargs, # type: ignore
) )
async with connection as ws: async with connection as ws:
yield WebSocket(request=setup, websocket=ws) yield WebSocket(request=setup, websocket=ws)

View File

@@ -9,6 +9,7 @@ from .abstract import ReverseDriver as ReverseDriver
from .abstract import ReverseMixin as ReverseMixin from .abstract import ReverseMixin as ReverseMixin
from .abstract import WebSocketClientMixin as WebSocketClientMixin from .abstract import WebSocketClientMixin as WebSocketClientMixin
from .combine import combine_driver as combine_driver from .combine import combine_driver as combine_driver
from .model import DEFAULT_TIMEOUT as DEFAULT_TIMEOUT
from .model import URL as URL from .model import URL as URL
from .model import ContentTypes as ContentTypes from .model import ContentTypes as ContentTypes
from .model import Cookies as Cookies from .model import Cookies as Cookies

View File

@@ -19,7 +19,13 @@ from nonebot.typing import (
T_BotDisconnectionHook, T_BotDisconnectionHook,
T_DependencyCache, T_DependencyCache,
) )
from nonebot.utils import escape_tag, flatten_exception_group, run_coro_with_catch from nonebot.utils import (
UNSET,
UnsetType,
escape_tag,
flatten_exception_group,
run_coro_with_catch,
)
from ._lifespan import LIFESPAN_FUNC, Lifespan from ._lifespan import LIFESPAN_FUNC, Lifespan
from .model import ( from .model import (
@@ -246,7 +252,7 @@ class HTTPClientSession(abc.ABC):
headers: HeaderTypes = None, headers: HeaderTypes = None,
cookies: CookieTypes = None, cookies: CookieTypes = None,
version: str | HTTPVersion = HTTPVersion.H11, version: str | HTTPVersion = HTTPVersion.H11,
timeout: TimeoutTypes = None, timeout: TimeoutTypes | UnsetType = UNSET,
proxy: str | None = None, proxy: str | None = None,
): ):
raise NotImplementedError raise NotImplementedError

View File

@@ -9,14 +9,20 @@ import urllib.request
from multidict import CIMultiDict from multidict import CIMultiDict
from yarl import URL as URL from yarl import URL as URL
from nonebot.utils import UNSET, UnsetType
@dataclass @dataclass
class Timeout: class Timeout:
"""Request 超时配置。""" """Request 超时配置。"""
total: float | None = None total: float | None | UnsetType = UNSET
connect: float | None = None connect: float | None | UnsetType = UNSET
read: float | None = None read: float | None | UnsetType = UNSET
close: float | None | UnsetType = UNSET
DEFAULT_TIMEOUT = Timeout(total=None, connect=5.0, read=30.0, close=10.0)
RawURL: TypeAlias = tuple[bytes, bytes, int | None, bytes] RawURL: TypeAlias = tuple[bytes, bytes, int | None, bytes]
@@ -68,7 +74,7 @@ class Request:
json: Any = None, json: Any = None,
files: FilesTypes = None, files: FilesTypes = None,
version: str | HTTPVersion = HTTPVersion.H11, version: str | HTTPVersion = HTTPVersion.H11,
timeout: TimeoutTypes = None, timeout: TimeoutTypes | UnsetType = UNSET,
proxy: str | None = None, proxy: str | None = None,
): ):
# method # method
@@ -80,7 +86,7 @@ class Request:
# http version # http version
self.version: HTTPVersion = HTTPVersion(version) self.version: HTTPVersion = HTTPVersion(version)
# timeout # timeout
self.timeout: TimeoutTypes = timeout self.timeout: TimeoutTypes | UnsetType = timeout
# proxy # proxy
self.proxy: str | None = proxy self.proxy: str | None = proxy

View File

@@ -19,6 +19,7 @@ from collections.abc import (
import contextlib import contextlib
from contextlib import AbstractContextManager, asynccontextmanager from contextlib import AbstractContextManager, asynccontextmanager
import dataclasses import dataclasses
from enum import Enum
from functools import partial, wraps from functools import partial, wraps
import importlib import importlib
import inspect import inspect
@@ -27,8 +28,12 @@ from pathlib import Path
import re import re
from typing import ( from typing import (
Any, Any,
Final,
Generic, Generic,
Literal,
TypeAlias,
TypeVar, TypeVar,
final,
get_args, get_args,
get_origin, get_origin,
overload, overload,
@@ -49,6 +54,8 @@ from nonebot.typing import (
type_has_args, type_has_args,
) )
from .compat import custom_validation
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
T = TypeVar("T") T = TypeVar("T")
@@ -57,6 +64,54 @@ V = TypeVar("V")
E = TypeVar("E", bound=BaseException) E = TypeVar("E", bound=BaseException)
@final
@custom_validation
class Unset(Enum):
_UNSET = "<UNSET>"
def __repr__(self) -> str:
return "<UNSET>"
def __str__(self) -> str:
return self.__repr__()
def __bool__(self) -> Literal[False]:
return False
def __copy__(self):
return self._UNSET
def __deepcopy__(self, memo: dict[int, Any]):
return self._UNSET
@classmethod
def __get_validators__(cls):
yield cls._validate
@classmethod
def _validate(cls, value: Any):
if value is not cls._UNSET:
raise ValueError(f"{value!r} is not UNSET")
return value
UnsetType: TypeAlias = Literal[Unset._UNSET]
UNSET: Final[UnsetType] = Unset._UNSET
def exclude_unset(data: Any) -> Any:
if isinstance(data, dict):
return data.__class__(
(k, exclude_unset(v)) for k, v in data.items() if v is not UNSET
)
elif isinstance(data, list):
return data.__class__(exclude_unset(i) for i in data if i is not UNSET)
elif data is UNSET:
return None
return data
def escape_tag(s: str) -> str: def escape_tag(s: str) -> str:
"""用于记录带颜色日志时转义 `<tag>` 类型特殊标签 """用于记录带颜色日志时转义 `<tag>` 类型特殊标签

View File

@@ -26,6 +26,7 @@ from nonebot.drivers.aiohttp import Session as AiohttpSession
from nonebot.drivers.aiohttp import WebSocket as AiohttpWebSocket from nonebot.drivers.aiohttp import WebSocket as AiohttpWebSocket
from nonebot.exception import WebSocketClosed from nonebot.exception import WebSocketClosed
from nonebot.params import Depends from nonebot.params import Depends
from nonebot.utils import UNSET
from utils import FakeAdapter from utils import FakeAdapter
@@ -706,6 +707,177 @@ async def test_aiohttp_websocket_close_frame(msg_type: str) -> None:
await ws.receive() await ws.receive()
def test_timeout_unset_vs_none():
# default: all fields are UNSET
t = Timeout()
assert t.total is UNSET
assert t.connect is UNSET
assert t.read is UNSET
assert t.close is UNSET
# explicitly set to None
t = Timeout(close=None)
assert t.close is None
assert t.close is not UNSET
# explicitly set to a value
t = Timeout(total=5.0, close=None)
assert t.total == 5.0
assert t.close is None
assert t.connect is UNSET
assert t.read is UNSET
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver",
[
pytest.param("nonebot.drivers.httpx:Driver", id="httpx"),
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
],
indirect=True,
)
async def test_http_client_timeout(driver: Driver, server_url: URL):
"""HTTP requests work with fully unset, partial, and None timeout fields."""
assert isinstance(driver, HTTPClientMixin)
# timeout not set, default timeout should apply
request = Request("POST", server_url, content="test")
response = await driver.request(request)
assert response.status_code == 200
async for resp in driver.stream_request(request, chunk_size=1024):
assert resp.status_code == 200
# timeout is float or none
request = Request("POST", server_url, content="test", timeout=10.0)
response = await driver.request(request)
assert response.status_code == 200
async for resp in driver.stream_request(request, chunk_size=1024):
assert resp.status_code == 200
# all fields unset, default timeout should apply
request = Request("POST", server_url, content="test", timeout=Timeout())
response = await driver.request(request)
assert response.status_code == 200
async for resp in driver.stream_request(request, chunk_size=1024):
assert resp.status_code == 200
# only total set
request = Request("POST", server_url, content="test", timeout=Timeout(total=10.0))
response = await driver.request(request)
assert response.status_code == 200
async for resp in driver.stream_request(request, chunk_size=1024):
assert resp.status_code == 200
# explicit None (no timeout)
request = Request(
"POST",
server_url,
content="test",
timeout=Timeout(total=None, connect=None, read=None),
)
response = await driver.request(request)
assert response.status_code == 200
async for resp in driver.stream_request(request, chunk_size=1024):
assert resp.status_code == 200
# session with timeout not set
session = driver.get_session()
async with session:
request = Request("POST", server_url, content="test")
response = await session.request(request)
assert response.status_code == 200
# session with float or none timeout
session = driver.get_session(timeout=10.0)
async with session:
request = Request("POST", server_url, content="test")
response = await session.request(request)
assert response.status_code == 200
# session with fully unset timeout
session = driver.get_session(timeout=Timeout())
async with session:
request = Request("POST", server_url, content="test")
response = await session.request(request)
assert response.status_code == 200
# session with timeout
session = driver.get_session(timeout=Timeout(total=10.0, connect=5.0, read=5.0))
async with session:
request = Request("POST", server_url, content="test")
response = await session.request(request)
assert response.status_code == 200
# session with timeout override
session = driver.get_session(timeout=Timeout(total=10.0))
async with session:
request = Request(
"POST", server_url, content="test", timeout=Timeout(total=20.0)
)
response = await session.request(request)
assert response.status_code == 200
# session with timeout float override
session = driver.get_session(timeout=Timeout(total=10.0))
async with session:
request = Request("POST", server_url, content="test", timeout=20.0)
response = await session.request(request)
assert response.status_code == 200
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver",
[
pytest.param("nonebot.drivers.websockets:Driver", id="websockets"),
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
],
indirect=True,
)
async def test_websocket_client_timeout(driver: Driver, server_url: URL):
"""WebSocket connections work with fully unset, partial, and None timeout fields."""
assert isinstance(driver, WebSocketClientMixin)
ws_url = server_url.with_scheme("ws")
# timeout not set, default timeout should apply
request = Request("GET", ws_url)
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
# timeout is float or none
request = Request("GET", ws_url, timeout=10.0)
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
# all fields unset, default timeout should apply
request = Request("GET", ws_url, timeout=Timeout())
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
# close explicitly set to None (no close timeout)
request = Request("GET", ws_url, timeout=Timeout(close=None))
async with driver.websocket(request) as ws:
await ws.send("quit")
with pytest.raises(WebSocketClosed):
await ws.receive()
await anyio.sleep(1)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("driver", "driver_type"), ("driver", "driver_type"),
[ [

View File

@@ -1,9 +1,19 @@
import copy
import json import json
import pickle
from typing import ClassVar, Dict, List, Literal, TypeVar, Union # noqa: UP035 from typing import ClassVar, Dict, List, Literal, TypeVar, Union # noqa: UP035
from pydantic import ValidationError
import pytest
from nonebot.compat import type_validate_python
from nonebot.utils import ( from nonebot.utils import (
UNSET,
DataclassEncoder, DataclassEncoder,
Unset,
UnsetType,
escape_tag, escape_tag,
exclude_unset,
generic_check_issubclass, generic_check_issubclass,
is_async_gen_callable, is_async_gen_callable,
is_coroutine_callable, is_coroutine_callable,
@@ -12,6 +22,29 @@ from nonebot.utils import (
from utils import FakeMessage, FakeMessageSegment from utils import FakeMessage, FakeMessageSegment
def test_unset():
assert isinstance(UNSET, Unset)
assert bool(UNSET) is False
assert copy.copy(UNSET) is UNSET
assert copy.deepcopy(UNSET) is UNSET
assert pickle.loads(pickle.dumps(UNSET)) is UNSET
assert type_validate_python(UnsetType, UNSET) is UNSET
with pytest.raises(ValidationError):
type_validate_python(UnsetType, 123)
def test_exclude_unset():
assert exclude_unset({"a": 1, "b": UNSET, "c": None, "d": {"x": UNSET}}) == {
"a": 1,
"c": None,
"d": {},
}
assert exclude_unset([1, UNSET, None, {"x": UNSET}]) == [1, None, {}]
assert exclude_unset(UNSET) is None
assert exclude_unset(123) == 123
def test_loguru_escape_tag(): def test_loguru_escape_tag():
assert escape_tag("<red>red</red>") == r"\<red>red\</red>" assert escape_tag("<red>red</red>") == r"\<red>red\</red>"
assert escape_tag("<fg #fff>white</fg #fff>") == r"\<fg #fff>white\</fg #fff>" assert escape_tag("<fg #fff>white</fg #fff>") == r"\<fg #fff>white\</fg #fff>"