From cbe6eee868ce771d1f2d159f568525926851b03d Mon Sep 17 00:00:00 2001 From: StarHeart Date: Tue, 31 Mar 2026 20:56:12 +0800 Subject: [PATCH] =?UTF-8?q?:bug:=20Fix:=20=E4=BF=AE=E6=AD=A3=20http/websoc?= =?UTF-8?q?ket=20client=20timeout=20(#3923)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- nonebot/drivers/__init__.py | 2 + nonebot/drivers/aiohttp.py | 112 ++++++++++++------ nonebot/drivers/httpx.py | 77 ++++++++----- nonebot/drivers/websockets.py | 37 +++++- nonebot/internal/driver/__init__.py | 1 + nonebot/internal/driver/abstract.py | 10 +- nonebot/internal/driver/model.py | 16 ++- nonebot/utils.py | 55 +++++++++ tests/test_driver.py | 172 ++++++++++++++++++++++++++++ tests/test_utils.py | 33 ++++++ 10 files changed, 443 insertions(+), 72 deletions(-) diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index c7e6e82d..b54f72cc 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -9,6 +9,7 @@ FrontMatter: description: nonebot.drivers 模块 """ +from nonebot.internal.driver import DEFAULT_TIMEOUT as DEFAULT_TIMEOUT from nonebot.internal.driver import URL as URL from nonebot.internal.driver import ASGIMixin as ASGIMixin from nonebot.internal.driver import Cookies as Cookies @@ -31,6 +32,7 @@ from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup from nonebot.internal.driver import combine_driver as combine_driver __autodoc__ = { + "DEFAULT_TIMEOUT": True, "URL": True, "Cookies": True, "Request": True, diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index cb9aa810..f558e335 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -38,6 +38,7 @@ from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers.none import Driver as NoneDriver from nonebot.exception import WebSocketClosed from nonebot.internal.driver import ( + DEFAULT_TIMEOUT, Cookies, CookieTypes, HeaderTypes, @@ -45,6 +46,7 @@ from nonebot.internal.driver import ( Timeout, TimeoutTypes, ) +from nonebot.utils import UNSET, UnsetType, exclude_unset try: import aiohttp @@ -63,7 +65,7 @@ class Session(HTTPClientSession): headers: HeaderTypes = None, cookies: CookieTypes = None, version: str | HTTPVersion = HTTPVersion.H11, - timeout: TimeoutTypes = None, + timeout: TimeoutTypes | UnsetType = UNSET, proxy: str | None = None, ): self._client: aiohttp.ClientSession | None = None @@ -85,15 +87,32 @@ class Session(HTTPClientSession): else: raise RuntimeError(f"Unsupported HTTP version: {version}") + _timeout = None if isinstance(timeout, Timeout): - self._timeout = aiohttp.ClientTimeout( - total=timeout.total, - connect=timeout.connect, - sock_read=timeout.read, + timeout_kwargs: dict[str, float | None] = exclude_unset( + { + "total": timeout.total, + "connect": timeout.connect, + "sock_read": timeout.read, + } ) - else: - self._timeout = aiohttp.ClientTimeout(timeout) + if timeout_kwargs: + _timeout = aiohttp.ClientTimeout(**timeout_kwargs) # type: ignore + elif timeout is not UNSET: + _timeout = aiohttp.ClientTimeout(connect=timeout, sock_read=timeout) + if _timeout is None: + _timeout = aiohttp.ClientTimeout( + **exclude_unset( + { + "total": DEFAULT_TIMEOUT.total, + "connect": DEFAULT_TIMEOUT.connect, + "sock_read": DEFAULT_TIMEOUT.read, + } + ) + ) + + self._timeout = _timeout self._proxy = proxy @property @@ -102,6 +121,25 @@ class Session(HTTPClientSession): raise RuntimeError("Session is not initialized") return self._client + def _get_timeout(self, timeout: TimeoutTypes | UnsetType) -> aiohttp.ClientTimeout: + _timeout = None + if isinstance(timeout, Timeout): + timeout_kwargs: dict[str, float | None] = exclude_unset( + { + "total": timeout.total, + "connect": timeout.connect, + "sock_read": timeout.read, + } + ) + if timeout_kwargs: + _timeout = aiohttp.ClientTimeout(**timeout_kwargs) # type: ignore + elif timeout is not UNSET: + _timeout = aiohttp.ClientTimeout(connect=timeout, sock_read=timeout) + + if _timeout is None: + return self._timeout + return _timeout + @override async def request(self, setup: Request) -> Response: if self._params: @@ -121,15 +159,6 @@ class Session(HTTPClientSession): if cookie.value is not None ) - if isinstance(setup.timeout, Timeout): - timeout = aiohttp.ClientTimeout( - total=setup.timeout.total, - connect=setup.timeout.connect, - sock_read=setup.timeout.read, - ) - else: - timeout = aiohttp.ClientTimeout(setup.timeout) - async with await self.client.request( setup.method, url, @@ -138,7 +167,7 @@ class Session(HTTPClientSession): cookies=cookies, headers=setup.headers, proxy=setup.proxy or self._proxy, - timeout=timeout, + timeout=self._get_timeout(setup.timeout), ) as response: return Response( response.status, @@ -171,15 +200,6 @@ class Session(HTTPClientSession): if cookie.value is not None ) - if isinstance(setup.timeout, Timeout): - timeout = aiohttp.ClientTimeout( - total=setup.timeout.total, - connect=setup.timeout.connect, - sock_read=setup.timeout.read, - ) - else: - timeout = aiohttp.ClientTimeout(setup.timeout) - async with self.client.request( setup.method, url, @@ -188,7 +208,7 @@ class Session(HTTPClientSession): cookies=cookies, headers=setup.headers, proxy=setup.proxy or self._proxy, - timeout=timeout, + timeout=self._get_timeout(setup.timeout), ) as response: response_headers = response.headers.copy() # aiohttp does not guarantee fixed-size chunks; re-chunk to exact size @@ -270,13 +290,39 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin): else: raise RuntimeError(f"Unsupported HTTP version: {setup.version}") + timeout = None if isinstance(setup.timeout, Timeout): - timeout = aiohttp.ClientWSTimeout( - ws_receive=setup.timeout.read, # type: ignore - ws_close=setup.timeout.total, # type: ignore + timeout_kwargs: dict[str, float | None] = exclude_unset( + { + "ws_receive": setup.timeout.read, + "ws_close": ( + setup.timeout.total + if setup.timeout.close is UNSET + else setup.timeout.close + ), + } + ) + if timeout_kwargs: + timeout = aiohttp.ClientWSTimeout(**timeout_kwargs) + elif setup.timeout is not UNSET: + timeout = aiohttp.ClientWSTimeout( + ws_receive=setup.timeout, # type: ignore + ws_close=setup.timeout, # type: ignore + ) + + if timeout is None: + timeout = aiohttp.ClientWSTimeout( + **exclude_unset( + { + "ws_receive": DEFAULT_TIMEOUT.read, + "ws_close": ( + DEFAULT_TIMEOUT.total + if DEFAULT_TIMEOUT.close is UNSET + else DEFAULT_TIMEOUT.close + ), + } + ) ) - else: - timeout = aiohttp.ClientWSTimeout(ws_close=setup.timeout or 10.0) # type: ignore async with aiohttp.ClientSession(version=version, trust_env=True) as session: async with session.ws_connect( @@ -295,7 +341,7 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin): headers: HeaderTypes = None, cookies: CookieTypes = None, version: str | HTTPVersion = HTTPVersion.H11, - timeout: TimeoutTypes = None, + timeout: TimeoutTypes | UnsetType = UNSET, proxy: str | None = None, ) -> Session: return Session( diff --git a/nonebot/drivers/httpx.py b/nonebot/drivers/httpx.py index 70bec595..a8541ad0 100644 --- a/nonebot/drivers/httpx.py +++ b/nonebot/drivers/httpx.py @@ -34,6 +34,7 @@ from nonebot.drivers import ( ) from nonebot.drivers.none import Driver as NoneDriver from nonebot.internal.driver import ( + DEFAULT_TIMEOUT, Cookies, CookieTypes, HeaderTypes, @@ -41,6 +42,7 @@ from nonebot.internal.driver import ( Timeout, TimeoutTypes, ) +from nonebot.utils import UNSET, UnsetType, exclude_unset try: import httpx @@ -59,7 +61,7 @@ class Session(HTTPClientSession): headers: HeaderTypes = None, cookies: CookieTypes = None, version: str | HTTPVersion = HTTPVersion.H11, - timeout: TimeoutTypes = None, + timeout: TimeoutTypes | UnsetType = UNSET, proxy: str | None = None, ): self._client: httpx.AsyncClient | None = None @@ -73,15 +75,34 @@ class Session(HTTPClientSession): self._cookies = Cookies(cookies) self._version = HTTPVersion(version) + _timeout = None if isinstance(timeout, Timeout): - self._timeout = httpx.Timeout( - timeout=timeout.total, - connect=timeout.connect, - read=timeout.read, + avg_timeout = timeout.total and timeout.total / 4 + timeout_kwargs: dict[str, float | None] = exclude_unset( + { + "timeout": avg_timeout, + "connect": timeout.connect, + "read": timeout.read, + } ) - else: - self._timeout = httpx.Timeout(timeout) + if timeout_kwargs: + _timeout = httpx.Timeout(**timeout_kwargs) + elif timeout is not UNSET: + _timeout = httpx.Timeout(timeout) + if _timeout is None: + avg_timeout = DEFAULT_TIMEOUT.total and DEFAULT_TIMEOUT.total / 4 + _timeout = httpx.Timeout( + **exclude_unset( + { + "timeout": avg_timeout, + "connect": DEFAULT_TIMEOUT.connect, + "read": DEFAULT_TIMEOUT.read, + } + ) + ) + + self._timeout = _timeout self._proxy = proxy @property @@ -90,17 +111,28 @@ class Session(HTTPClientSession): raise RuntimeError("Session is not initialized") return self._client + def _get_timeout(self, timeout: TimeoutTypes | UnsetType) -> httpx.Timeout: + _timeout = None + if isinstance(timeout, Timeout): + avg_timeout = timeout.total and timeout.total / 4 + timeout_kwargs: dict[str, float | None] = exclude_unset( + { + "timeout": avg_timeout, + "connect": timeout.connect, + "read": timeout.read, + } + ) + if timeout_kwargs: + _timeout = httpx.Timeout(**timeout_kwargs) + elif timeout is not UNSET: + _timeout = httpx.Timeout(timeout) + + if _timeout is None: + return self._timeout + return _timeout + @override async def request(self, setup: Request) -> Response: - if isinstance(setup.timeout, Timeout): - timeout = httpx.Timeout( - timeout=setup.timeout.total, - connect=setup.timeout.connect, - read=setup.timeout.read, - ) - else: - timeout = httpx.Timeout(setup.timeout) - response = await self.client.request( setup.method, str(setup.url), @@ -112,7 +144,7 @@ class Session(HTTPClientSession): params=setup.url.raw_query_string, headers=tuple(setup.headers.items()), cookies=setup.cookies.jar, - timeout=timeout, + timeout=self._get_timeout(setup.timeout), ) return Response( response.status_code, @@ -128,15 +160,6 @@ class Session(HTTPClientSession): *, chunk_size: int = 1024, ) -> AsyncGenerator[Response, None]: - if isinstance(setup.timeout, Timeout): - timeout = httpx.Timeout( - timeout=setup.timeout.total, - connect=setup.timeout.connect, - read=setup.timeout.read, - ) - else: - timeout = httpx.Timeout(setup.timeout) - async with self.client.stream( setup.method, str(setup.url), @@ -148,7 +171,7 @@ class Session(HTTPClientSession): params=setup.url.raw_query_string, headers=tuple(setup.headers.items()), cookies=setup.cookies.jar, - timeout=timeout, + timeout=self._get_timeout(setup.timeout), ) as response: response_headers = response.headers.multi_items() async for chunk in response.aiter_bytes(chunk_size=chunk_size): diff --git a/nonebot/drivers/websockets.py b/nonebot/drivers/websockets.py index 3e6aa07b..325c0c34 100644 --- a/nonebot/drivers/websockets.py +++ b/nonebot/drivers/websockets.py @@ -25,11 +25,18 @@ from types import CoroutineType from typing import TYPE_CHECKING, Any, TypeVar from typing_extensions import ParamSpec, override -from nonebot.drivers import Request, Timeout, WebSocketClientMixin, combine_driver +from nonebot.drivers import ( + DEFAULT_TIMEOUT, + Request, + Timeout, + WebSocketClientMixin, + combine_driver, +) from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers.none import Driver as NoneDriver from nonebot.exception import WebSocketClosed from nonebot.log import LoguruHandler +from nonebot.utils import UNSET, exclude_unset try: from websockets import ClientConnection, ConnectionClosed, connect @@ -70,16 +77,36 @@ class Mixin(WebSocketClientMixin): @override @asynccontextmanager async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]: + timeout_kwargs: dict[str, float | None] = {} if isinstance(setup.timeout, Timeout): - timeout = setup.timeout.total or setup.timeout.connect or setup.timeout.read - else: - timeout = setup.timeout + open_timeout = ( + setup.timeout.connect or setup.timeout.read or setup.timeout.total + ) + timeout_kwargs = exclude_unset( + {"open_timeout": open_timeout, "close_timeout": setup.timeout.close} + ) + elif setup.timeout is not UNSET: + timeout_kwargs = { + "open_timeout": setup.timeout, + "close_timeout": setup.timeout, + } + + if not timeout_kwargs: + open_timeout = ( + DEFAULT_TIMEOUT.connect or DEFAULT_TIMEOUT.read or DEFAULT_TIMEOUT.total + ) + timeout_kwargs = exclude_unset( + { + "open_timeout": open_timeout, + "close_timeout": DEFAULT_TIMEOUT.close, + } + ) connection = connect( str(setup.url), additional_headers={**setup.headers, **setup.cookies.as_header(setup)}, proxy=setup.proxy if setup.proxy is not None else True, - open_timeout=timeout, + **timeout_kwargs, # type: ignore ) async with connection as ws: yield WebSocket(request=setup, websocket=ws) diff --git a/nonebot/internal/driver/__init__.py b/nonebot/internal/driver/__init__.py index e4b3f042..e250173e 100644 --- a/nonebot/internal/driver/__init__.py +++ b/nonebot/internal/driver/__init__.py @@ -9,6 +9,7 @@ from .abstract import ReverseDriver as ReverseDriver from .abstract import ReverseMixin as ReverseMixin from .abstract import WebSocketClientMixin as WebSocketClientMixin from .combine import combine_driver as combine_driver +from .model import DEFAULT_TIMEOUT as DEFAULT_TIMEOUT from .model import URL as URL from .model import ContentTypes as ContentTypes from .model import Cookies as Cookies diff --git a/nonebot/internal/driver/abstract.py b/nonebot/internal/driver/abstract.py index 7ead40a7..d68d1265 100644 --- a/nonebot/internal/driver/abstract.py +++ b/nonebot/internal/driver/abstract.py @@ -19,7 +19,13 @@ from nonebot.typing import ( T_BotDisconnectionHook, T_DependencyCache, ) -from nonebot.utils import escape_tag, flatten_exception_group, run_coro_with_catch +from nonebot.utils import ( + UNSET, + UnsetType, + escape_tag, + flatten_exception_group, + run_coro_with_catch, +) from ._lifespan import LIFESPAN_FUNC, Lifespan from .model import ( @@ -246,7 +252,7 @@ class HTTPClientSession(abc.ABC): headers: HeaderTypes = None, cookies: CookieTypes = None, version: str | HTTPVersion = HTTPVersion.H11, - timeout: TimeoutTypes = None, + timeout: TimeoutTypes | UnsetType = UNSET, proxy: str | None = None, ): raise NotImplementedError diff --git a/nonebot/internal/driver/model.py b/nonebot/internal/driver/model.py index 169d589d..53df0990 100644 --- a/nonebot/internal/driver/model.py +++ b/nonebot/internal/driver/model.py @@ -9,14 +9,20 @@ import urllib.request from multidict import CIMultiDict from yarl import URL as URL +from nonebot.utils import UNSET, UnsetType + @dataclass class Timeout: """Request 超时配置。""" - total: float | None = None - connect: float | None = None - read: float | None = None + total: float | None | UnsetType = UNSET + connect: float | None | UnsetType = UNSET + read: float | None | UnsetType = UNSET + close: float | None | UnsetType = UNSET + + +DEFAULT_TIMEOUT = Timeout(total=None, connect=5.0, read=30.0, close=10.0) RawURL: TypeAlias = tuple[bytes, bytes, int | None, bytes] @@ -68,7 +74,7 @@ class Request: json: Any = None, files: FilesTypes = None, version: str | HTTPVersion = HTTPVersion.H11, - timeout: TimeoutTypes = None, + timeout: TimeoutTypes | UnsetType = UNSET, proxy: str | None = None, ): # method @@ -80,7 +86,7 @@ class Request: # http version self.version: HTTPVersion = HTTPVersion(version) # timeout - self.timeout: TimeoutTypes = timeout + self.timeout: TimeoutTypes | UnsetType = timeout # proxy self.proxy: str | None = proxy diff --git a/nonebot/utils.py b/nonebot/utils.py index 95750b2c..872d92a2 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -19,6 +19,7 @@ from collections.abc import ( import contextlib from contextlib import AbstractContextManager, asynccontextmanager import dataclasses +from enum import Enum from functools import partial, wraps import importlib import inspect @@ -27,8 +28,12 @@ from pathlib import Path import re from typing import ( Any, + Final, Generic, + Literal, + TypeAlias, TypeVar, + final, get_args, get_origin, overload, @@ -49,6 +54,8 @@ from nonebot.typing import ( type_has_args, ) +from .compat import custom_validation + P = ParamSpec("P") R = TypeVar("R") T = TypeVar("T") @@ -57,6 +64,54 @@ V = TypeVar("V") E = TypeVar("E", bound=BaseException) +@final +@custom_validation +class Unset(Enum): + _UNSET = "" + + def __repr__(self) -> str: + return "" + + def __str__(self) -> str: + return self.__repr__() + + def __bool__(self) -> Literal[False]: + return False + + def __copy__(self): + return self._UNSET + + def __deepcopy__(self, memo: dict[int, Any]): + return self._UNSET + + @classmethod + def __get_validators__(cls): + yield cls._validate + + @classmethod + def _validate(cls, value: Any): + if value is not cls._UNSET: + raise ValueError(f"{value!r} is not UNSET") + return value + + +UnsetType: TypeAlias = Literal[Unset._UNSET] + +UNSET: Final[UnsetType] = Unset._UNSET + + +def exclude_unset(data: Any) -> Any: + if isinstance(data, dict): + return data.__class__( + (k, exclude_unset(v)) for k, v in data.items() if v is not UNSET + ) + elif isinstance(data, list): + return data.__class__(exclude_unset(i) for i in data if i is not UNSET) + elif data is UNSET: + return None + return data + + def escape_tag(s: str) -> str: """用于记录带颜色日志时转义 `` 类型特殊标签 diff --git a/tests/test_driver.py b/tests/test_driver.py index 1b7a2a33..3dd954b7 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -26,6 +26,7 @@ from nonebot.drivers.aiohttp import Session as AiohttpSession from nonebot.drivers.aiohttp import WebSocket as AiohttpWebSocket from nonebot.exception import WebSocketClosed from nonebot.params import Depends +from nonebot.utils import UNSET from utils import FakeAdapter @@ -706,6 +707,177 @@ async def test_aiohttp_websocket_close_frame(msg_type: str) -> None: await ws.receive() +def test_timeout_unset_vs_none(): + # default: all fields are UNSET + t = Timeout() + assert t.total is UNSET + assert t.connect is UNSET + assert t.read is UNSET + assert t.close is UNSET + + # explicitly set to None + t = Timeout(close=None) + assert t.close is None + assert t.close is not UNSET + + # explicitly set to a value + t = Timeout(total=5.0, close=None) + assert t.total == 5.0 + assert t.close is None + assert t.connect is UNSET + assert t.read is UNSET + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "driver", + [ + pytest.param("nonebot.drivers.httpx:Driver", id="httpx"), + pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"), + ], + indirect=True, +) +async def test_http_client_timeout(driver: Driver, server_url: URL): + """HTTP requests work with fully unset, partial, and None timeout fields.""" + assert isinstance(driver, HTTPClientMixin) + + # timeout not set, default timeout should apply + request = Request("POST", server_url, content="test") + response = await driver.request(request) + assert response.status_code == 200 + async for resp in driver.stream_request(request, chunk_size=1024): + assert resp.status_code == 200 + + # timeout is float or none + request = Request("POST", server_url, content="test", timeout=10.0) + response = await driver.request(request) + assert response.status_code == 200 + async for resp in driver.stream_request(request, chunk_size=1024): + assert resp.status_code == 200 + + # all fields unset, default timeout should apply + request = Request("POST", server_url, content="test", timeout=Timeout()) + response = await driver.request(request) + assert response.status_code == 200 + async for resp in driver.stream_request(request, chunk_size=1024): + assert resp.status_code == 200 + + # only total set + request = Request("POST", server_url, content="test", timeout=Timeout(total=10.0)) + response = await driver.request(request) + assert response.status_code == 200 + async for resp in driver.stream_request(request, chunk_size=1024): + assert resp.status_code == 200 + + # explicit None (no timeout) + request = Request( + "POST", + server_url, + content="test", + timeout=Timeout(total=None, connect=None, read=None), + ) + response = await driver.request(request) + assert response.status_code == 200 + async for resp in driver.stream_request(request, chunk_size=1024): + assert resp.status_code == 200 + + # session with timeout not set + session = driver.get_session() + async with session: + request = Request("POST", server_url, content="test") + response = await session.request(request) + assert response.status_code == 200 + + # session with float or none timeout + session = driver.get_session(timeout=10.0) + async with session: + request = Request("POST", server_url, content="test") + response = await session.request(request) + assert response.status_code == 200 + + # session with fully unset timeout + session = driver.get_session(timeout=Timeout()) + async with session: + request = Request("POST", server_url, content="test") + response = await session.request(request) + assert response.status_code == 200 + + # session with timeout + session = driver.get_session(timeout=Timeout(total=10.0, connect=5.0, read=5.0)) + async with session: + request = Request("POST", server_url, content="test") + response = await session.request(request) + assert response.status_code == 200 + + # session with timeout override + session = driver.get_session(timeout=Timeout(total=10.0)) + async with session: + request = Request( + "POST", server_url, content="test", timeout=Timeout(total=20.0) + ) + response = await session.request(request) + assert response.status_code == 200 + + # session with timeout float override + session = driver.get_session(timeout=Timeout(total=10.0)) + async with session: + request = Request("POST", server_url, content="test", timeout=20.0) + response = await session.request(request) + assert response.status_code == 200 + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "driver", + [ + pytest.param("nonebot.drivers.websockets:Driver", id="websockets"), + pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"), + ], + indirect=True, +) +async def test_websocket_client_timeout(driver: Driver, server_url: URL): + """WebSocket connections work with fully unset, partial, and None timeout fields.""" + assert isinstance(driver, WebSocketClientMixin) + + ws_url = server_url.with_scheme("ws") + + # timeout not set, default timeout should apply + request = Request("GET", ws_url) + async with driver.websocket(request) as ws: + await ws.send("quit") + with pytest.raises(WebSocketClosed): + await ws.receive() + + await anyio.sleep(1) + + # timeout is float or none + request = Request("GET", ws_url, timeout=10.0) + async with driver.websocket(request) as ws: + await ws.send("quit") + with pytest.raises(WebSocketClosed): + await ws.receive() + + await anyio.sleep(1) + + # all fields unset, default timeout should apply + request = Request("GET", ws_url, timeout=Timeout()) + async with driver.websocket(request) as ws: + await ws.send("quit") + with pytest.raises(WebSocketClosed): + await ws.receive() + + await anyio.sleep(1) + + # close explicitly set to None (no close timeout) + request = Request("GET", ws_url, timeout=Timeout(close=None)) + async with driver.websocket(request) as ws: + await ws.send("quit") + with pytest.raises(WebSocketClosed): + await ws.receive() + + await anyio.sleep(1) + + @pytest.mark.parametrize( ("driver", "driver_type"), [ diff --git a/tests/test_utils.py b/tests/test_utils.py index 9636d8d6..d7aed538 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,19 @@ +import copy import json +import pickle from typing import ClassVar, Dict, List, Literal, TypeVar, Union # noqa: UP035 +from pydantic import ValidationError +import pytest + +from nonebot.compat import type_validate_python from nonebot.utils import ( + UNSET, DataclassEncoder, + Unset, + UnsetType, escape_tag, + exclude_unset, generic_check_issubclass, is_async_gen_callable, is_coroutine_callable, @@ -12,6 +22,29 @@ from nonebot.utils import ( from utils import FakeMessage, FakeMessageSegment +def test_unset(): + assert isinstance(UNSET, Unset) + assert bool(UNSET) is False + assert copy.copy(UNSET) is UNSET + assert copy.deepcopy(UNSET) is UNSET + assert pickle.loads(pickle.dumps(UNSET)) is UNSET + assert type_validate_python(UnsetType, UNSET) is UNSET + + with pytest.raises(ValidationError): + type_validate_python(UnsetType, 123) + + +def test_exclude_unset(): + assert exclude_unset({"a": 1, "b": UNSET, "c": None, "d": {"x": UNSET}}) == { + "a": 1, + "c": None, + "d": {}, + } + assert exclude_unset([1, UNSET, None, {"x": UNSET}]) == [1, None, {}] + assert exclude_unset(UNSET) is None + assert exclude_unset(123) == 123 + + def test_loguru_escape_tag(): assert escape_tag("red") == r"\red\" assert escape_tag("white") == r"\white\"