mirror of
https://github.com/nonebot/nonebot2.git
synced 2026-04-14 20:47:18 +00:00
🐛 Fix: 修正 http/websocket client timeout (#3923)
Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -9,6 +9,7 @@ FrontMatter:
|
||||
description: nonebot.drivers 模块
|
||||
"""
|
||||
|
||||
from nonebot.internal.driver import DEFAULT_TIMEOUT as DEFAULT_TIMEOUT
|
||||
from nonebot.internal.driver import URL as URL
|
||||
from nonebot.internal.driver import ASGIMixin as ASGIMixin
|
||||
from nonebot.internal.driver import Cookies as Cookies
|
||||
@@ -31,6 +32,7 @@ from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup
|
||||
from nonebot.internal.driver import combine_driver as combine_driver
|
||||
|
||||
__autodoc__ = {
|
||||
"DEFAULT_TIMEOUT": True,
|
||||
"URL": True,
|
||||
"Cookies": True,
|
||||
"Request": True,
|
||||
|
||||
@@ -38,6 +38,7 @@ from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers.none import Driver as NoneDriver
|
||||
from nonebot.exception import WebSocketClosed
|
||||
from nonebot.internal.driver import (
|
||||
DEFAULT_TIMEOUT,
|
||||
Cookies,
|
||||
CookieTypes,
|
||||
HeaderTypes,
|
||||
@@ -45,6 +46,7 @@ from nonebot.internal.driver import (
|
||||
Timeout,
|
||||
TimeoutTypes,
|
||||
)
|
||||
from nonebot.utils import UNSET, UnsetType, exclude_unset
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
@@ -63,7 +65,7 @@ class Session(HTTPClientSession):
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieTypes = None,
|
||||
version: str | HTTPVersion = HTTPVersion.H11,
|
||||
timeout: TimeoutTypes = None,
|
||||
timeout: TimeoutTypes | UnsetType = UNSET,
|
||||
proxy: str | None = None,
|
||||
):
|
||||
self._client: aiohttp.ClientSession | None = None
|
||||
@@ -85,15 +87,32 @@ class Session(HTTPClientSession):
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported HTTP version: {version}")
|
||||
|
||||
_timeout = None
|
||||
if isinstance(timeout, Timeout):
|
||||
self._timeout = aiohttp.ClientTimeout(
|
||||
total=timeout.total,
|
||||
connect=timeout.connect,
|
||||
sock_read=timeout.read,
|
||||
timeout_kwargs: dict[str, float | None] = exclude_unset(
|
||||
{
|
||||
"total": timeout.total,
|
||||
"connect": timeout.connect,
|
||||
"sock_read": timeout.read,
|
||||
}
|
||||
)
|
||||
else:
|
||||
self._timeout = aiohttp.ClientTimeout(timeout)
|
||||
if timeout_kwargs:
|
||||
_timeout = aiohttp.ClientTimeout(**timeout_kwargs) # type: ignore
|
||||
elif timeout is not UNSET:
|
||||
_timeout = aiohttp.ClientTimeout(connect=timeout, sock_read=timeout)
|
||||
|
||||
if _timeout is None:
|
||||
_timeout = aiohttp.ClientTimeout(
|
||||
**exclude_unset(
|
||||
{
|
||||
"total": DEFAULT_TIMEOUT.total,
|
||||
"connect": DEFAULT_TIMEOUT.connect,
|
||||
"sock_read": DEFAULT_TIMEOUT.read,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
self._timeout = _timeout
|
||||
self._proxy = proxy
|
||||
|
||||
@property
|
||||
@@ -102,6 +121,25 @@ class Session(HTTPClientSession):
|
||||
raise RuntimeError("Session is not initialized")
|
||||
return self._client
|
||||
|
||||
def _get_timeout(self, timeout: TimeoutTypes | UnsetType) -> aiohttp.ClientTimeout:
|
||||
_timeout = None
|
||||
if isinstance(timeout, Timeout):
|
||||
timeout_kwargs: dict[str, float | None] = exclude_unset(
|
||||
{
|
||||
"total": timeout.total,
|
||||
"connect": timeout.connect,
|
||||
"sock_read": timeout.read,
|
||||
}
|
||||
)
|
||||
if timeout_kwargs:
|
||||
_timeout = aiohttp.ClientTimeout(**timeout_kwargs) # type: ignore
|
||||
elif timeout is not UNSET:
|
||||
_timeout = aiohttp.ClientTimeout(connect=timeout, sock_read=timeout)
|
||||
|
||||
if _timeout is None:
|
||||
return self._timeout
|
||||
return _timeout
|
||||
|
||||
@override
|
||||
async def request(self, setup: Request) -> Response:
|
||||
if self._params:
|
||||
@@ -121,15 +159,6 @@ class Session(HTTPClientSession):
|
||||
if cookie.value is not None
|
||||
)
|
||||
|
||||
if isinstance(setup.timeout, Timeout):
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=setup.timeout.total,
|
||||
connect=setup.timeout.connect,
|
||||
sock_read=setup.timeout.read,
|
||||
)
|
||||
else:
|
||||
timeout = aiohttp.ClientTimeout(setup.timeout)
|
||||
|
||||
async with await self.client.request(
|
||||
setup.method,
|
||||
url,
|
||||
@@ -138,7 +167,7 @@ class Session(HTTPClientSession):
|
||||
cookies=cookies,
|
||||
headers=setup.headers,
|
||||
proxy=setup.proxy or self._proxy,
|
||||
timeout=timeout,
|
||||
timeout=self._get_timeout(setup.timeout),
|
||||
) as response:
|
||||
return Response(
|
||||
response.status,
|
||||
@@ -171,15 +200,6 @@ class Session(HTTPClientSession):
|
||||
if cookie.value is not None
|
||||
)
|
||||
|
||||
if isinstance(setup.timeout, Timeout):
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=setup.timeout.total,
|
||||
connect=setup.timeout.connect,
|
||||
sock_read=setup.timeout.read,
|
||||
)
|
||||
else:
|
||||
timeout = aiohttp.ClientTimeout(setup.timeout)
|
||||
|
||||
async with self.client.request(
|
||||
setup.method,
|
||||
url,
|
||||
@@ -188,7 +208,7 @@ class Session(HTTPClientSession):
|
||||
cookies=cookies,
|
||||
headers=setup.headers,
|
||||
proxy=setup.proxy or self._proxy,
|
||||
timeout=timeout,
|
||||
timeout=self._get_timeout(setup.timeout),
|
||||
) as response:
|
||||
response_headers = response.headers.copy()
|
||||
# aiohttp does not guarantee fixed-size chunks; re-chunk to exact size
|
||||
@@ -270,13 +290,39 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
|
||||
|
||||
timeout = None
|
||||
if isinstance(setup.timeout, Timeout):
|
||||
timeout = aiohttp.ClientWSTimeout(
|
||||
ws_receive=setup.timeout.read, # type: ignore
|
||||
ws_close=setup.timeout.total, # type: ignore
|
||||
timeout_kwargs: dict[str, float | None] = exclude_unset(
|
||||
{
|
||||
"ws_receive": setup.timeout.read,
|
||||
"ws_close": (
|
||||
setup.timeout.total
|
||||
if setup.timeout.close is UNSET
|
||||
else setup.timeout.close
|
||||
),
|
||||
}
|
||||
)
|
||||
if timeout_kwargs:
|
||||
timeout = aiohttp.ClientWSTimeout(**timeout_kwargs)
|
||||
elif setup.timeout is not UNSET:
|
||||
timeout = aiohttp.ClientWSTimeout(
|
||||
ws_receive=setup.timeout, # type: ignore
|
||||
ws_close=setup.timeout, # type: ignore
|
||||
)
|
||||
|
||||
if timeout is None:
|
||||
timeout = aiohttp.ClientWSTimeout(
|
||||
**exclude_unset(
|
||||
{
|
||||
"ws_receive": DEFAULT_TIMEOUT.read,
|
||||
"ws_close": (
|
||||
DEFAULT_TIMEOUT.total
|
||||
if DEFAULT_TIMEOUT.close is UNSET
|
||||
else DEFAULT_TIMEOUT.close
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
else:
|
||||
timeout = aiohttp.ClientWSTimeout(ws_close=setup.timeout or 10.0) # type: ignore
|
||||
|
||||
async with aiohttp.ClientSession(version=version, trust_env=True) as session:
|
||||
async with session.ws_connect(
|
||||
@@ -295,7 +341,7 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieTypes = None,
|
||||
version: str | HTTPVersion = HTTPVersion.H11,
|
||||
timeout: TimeoutTypes = None,
|
||||
timeout: TimeoutTypes | UnsetType = UNSET,
|
||||
proxy: str | None = None,
|
||||
) -> Session:
|
||||
return Session(
|
||||
|
||||
@@ -34,6 +34,7 @@ from nonebot.drivers import (
|
||||
)
|
||||
from nonebot.drivers.none import Driver as NoneDriver
|
||||
from nonebot.internal.driver import (
|
||||
DEFAULT_TIMEOUT,
|
||||
Cookies,
|
||||
CookieTypes,
|
||||
HeaderTypes,
|
||||
@@ -41,6 +42,7 @@ from nonebot.internal.driver import (
|
||||
Timeout,
|
||||
TimeoutTypes,
|
||||
)
|
||||
from nonebot.utils import UNSET, UnsetType, exclude_unset
|
||||
|
||||
try:
|
||||
import httpx
|
||||
@@ -59,7 +61,7 @@ class Session(HTTPClientSession):
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieTypes = None,
|
||||
version: str | HTTPVersion = HTTPVersion.H11,
|
||||
timeout: TimeoutTypes = None,
|
||||
timeout: TimeoutTypes | UnsetType = UNSET,
|
||||
proxy: str | None = None,
|
||||
):
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
@@ -73,15 +75,34 @@ class Session(HTTPClientSession):
|
||||
self._cookies = Cookies(cookies)
|
||||
self._version = HTTPVersion(version)
|
||||
|
||||
_timeout = None
|
||||
if isinstance(timeout, Timeout):
|
||||
self._timeout = httpx.Timeout(
|
||||
timeout=timeout.total,
|
||||
connect=timeout.connect,
|
||||
read=timeout.read,
|
||||
avg_timeout = timeout.total and timeout.total / 4
|
||||
timeout_kwargs: dict[str, float | None] = exclude_unset(
|
||||
{
|
||||
"timeout": avg_timeout,
|
||||
"connect": timeout.connect,
|
||||
"read": timeout.read,
|
||||
}
|
||||
)
|
||||
else:
|
||||
self._timeout = httpx.Timeout(timeout)
|
||||
if timeout_kwargs:
|
||||
_timeout = httpx.Timeout(**timeout_kwargs)
|
||||
elif timeout is not UNSET:
|
||||
_timeout = httpx.Timeout(timeout)
|
||||
|
||||
if _timeout is None:
|
||||
avg_timeout = DEFAULT_TIMEOUT.total and DEFAULT_TIMEOUT.total / 4
|
||||
_timeout = httpx.Timeout(
|
||||
**exclude_unset(
|
||||
{
|
||||
"timeout": avg_timeout,
|
||||
"connect": DEFAULT_TIMEOUT.connect,
|
||||
"read": DEFAULT_TIMEOUT.read,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
self._timeout = _timeout
|
||||
self._proxy = proxy
|
||||
|
||||
@property
|
||||
@@ -90,17 +111,28 @@ class Session(HTTPClientSession):
|
||||
raise RuntimeError("Session is not initialized")
|
||||
return self._client
|
||||
|
||||
def _get_timeout(self, timeout: TimeoutTypes | UnsetType) -> httpx.Timeout:
|
||||
_timeout = None
|
||||
if isinstance(timeout, Timeout):
|
||||
avg_timeout = timeout.total and timeout.total / 4
|
||||
timeout_kwargs: dict[str, float | None] = exclude_unset(
|
||||
{
|
||||
"timeout": avg_timeout,
|
||||
"connect": timeout.connect,
|
||||
"read": timeout.read,
|
||||
}
|
||||
)
|
||||
if timeout_kwargs:
|
||||
_timeout = httpx.Timeout(**timeout_kwargs)
|
||||
elif timeout is not UNSET:
|
||||
_timeout = httpx.Timeout(timeout)
|
||||
|
||||
if _timeout is None:
|
||||
return self._timeout
|
||||
return _timeout
|
||||
|
||||
@override
|
||||
async def request(self, setup: Request) -> Response:
|
||||
if isinstance(setup.timeout, Timeout):
|
||||
timeout = httpx.Timeout(
|
||||
timeout=setup.timeout.total,
|
||||
connect=setup.timeout.connect,
|
||||
read=setup.timeout.read,
|
||||
)
|
||||
else:
|
||||
timeout = httpx.Timeout(setup.timeout)
|
||||
|
||||
response = await self.client.request(
|
||||
setup.method,
|
||||
str(setup.url),
|
||||
@@ -112,7 +144,7 @@ class Session(HTTPClientSession):
|
||||
params=setup.url.raw_query_string,
|
||||
headers=tuple(setup.headers.items()),
|
||||
cookies=setup.cookies.jar,
|
||||
timeout=timeout,
|
||||
timeout=self._get_timeout(setup.timeout),
|
||||
)
|
||||
return Response(
|
||||
response.status_code,
|
||||
@@ -128,15 +160,6 @@ class Session(HTTPClientSession):
|
||||
*,
|
||||
chunk_size: int = 1024,
|
||||
) -> AsyncGenerator[Response, None]:
|
||||
if isinstance(setup.timeout, Timeout):
|
||||
timeout = httpx.Timeout(
|
||||
timeout=setup.timeout.total,
|
||||
connect=setup.timeout.connect,
|
||||
read=setup.timeout.read,
|
||||
)
|
||||
else:
|
||||
timeout = httpx.Timeout(setup.timeout)
|
||||
|
||||
async with self.client.stream(
|
||||
setup.method,
|
||||
str(setup.url),
|
||||
@@ -148,7 +171,7 @@ class Session(HTTPClientSession):
|
||||
params=setup.url.raw_query_string,
|
||||
headers=tuple(setup.headers.items()),
|
||||
cookies=setup.cookies.jar,
|
||||
timeout=timeout,
|
||||
timeout=self._get_timeout(setup.timeout),
|
||||
) as response:
|
||||
response_headers = response.headers.multi_items()
|
||||
async for chunk in response.aiter_bytes(chunk_size=chunk_size):
|
||||
|
||||
@@ -25,11 +25,18 @@ from types import CoroutineType
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
from typing_extensions import ParamSpec, override
|
||||
|
||||
from nonebot.drivers import Request, Timeout, WebSocketClientMixin, combine_driver
|
||||
from nonebot.drivers import (
|
||||
DEFAULT_TIMEOUT,
|
||||
Request,
|
||||
Timeout,
|
||||
WebSocketClientMixin,
|
||||
combine_driver,
|
||||
)
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers.none import Driver as NoneDriver
|
||||
from nonebot.exception import WebSocketClosed
|
||||
from nonebot.log import LoguruHandler
|
||||
from nonebot.utils import UNSET, exclude_unset
|
||||
|
||||
try:
|
||||
from websockets import ClientConnection, ConnectionClosed, connect
|
||||
@@ -70,16 +77,36 @@ class Mixin(WebSocketClientMixin):
|
||||
@override
|
||||
@asynccontextmanager
|
||||
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
|
||||
timeout_kwargs: dict[str, float | None] = {}
|
||||
if isinstance(setup.timeout, Timeout):
|
||||
timeout = setup.timeout.total or setup.timeout.connect or setup.timeout.read
|
||||
else:
|
||||
timeout = setup.timeout
|
||||
open_timeout = (
|
||||
setup.timeout.connect or setup.timeout.read or setup.timeout.total
|
||||
)
|
||||
timeout_kwargs = exclude_unset(
|
||||
{"open_timeout": open_timeout, "close_timeout": setup.timeout.close}
|
||||
)
|
||||
elif setup.timeout is not UNSET:
|
||||
timeout_kwargs = {
|
||||
"open_timeout": setup.timeout,
|
||||
"close_timeout": setup.timeout,
|
||||
}
|
||||
|
||||
if not timeout_kwargs:
|
||||
open_timeout = (
|
||||
DEFAULT_TIMEOUT.connect or DEFAULT_TIMEOUT.read or DEFAULT_TIMEOUT.total
|
||||
)
|
||||
timeout_kwargs = exclude_unset(
|
||||
{
|
||||
"open_timeout": open_timeout,
|
||||
"close_timeout": DEFAULT_TIMEOUT.close,
|
||||
}
|
||||
)
|
||||
|
||||
connection = connect(
|
||||
str(setup.url),
|
||||
additional_headers={**setup.headers, **setup.cookies.as_header(setup)},
|
||||
proxy=setup.proxy if setup.proxy is not None else True,
|
||||
open_timeout=timeout,
|
||||
**timeout_kwargs, # type: ignore
|
||||
)
|
||||
async with connection as ws:
|
||||
yield WebSocket(request=setup, websocket=ws)
|
||||
|
||||
@@ -9,6 +9,7 @@ from .abstract import ReverseDriver as ReverseDriver
|
||||
from .abstract import ReverseMixin as ReverseMixin
|
||||
from .abstract import WebSocketClientMixin as WebSocketClientMixin
|
||||
from .combine import combine_driver as combine_driver
|
||||
from .model import DEFAULT_TIMEOUT as DEFAULT_TIMEOUT
|
||||
from .model import URL as URL
|
||||
from .model import ContentTypes as ContentTypes
|
||||
from .model import Cookies as Cookies
|
||||
|
||||
@@ -19,7 +19,13 @@ from nonebot.typing import (
|
||||
T_BotDisconnectionHook,
|
||||
T_DependencyCache,
|
||||
)
|
||||
from nonebot.utils import escape_tag, flatten_exception_group, run_coro_with_catch
|
||||
from nonebot.utils import (
|
||||
UNSET,
|
||||
UnsetType,
|
||||
escape_tag,
|
||||
flatten_exception_group,
|
||||
run_coro_with_catch,
|
||||
)
|
||||
|
||||
from ._lifespan import LIFESPAN_FUNC, Lifespan
|
||||
from .model import (
|
||||
@@ -246,7 +252,7 @@ class HTTPClientSession(abc.ABC):
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieTypes = None,
|
||||
version: str | HTTPVersion = HTTPVersion.H11,
|
||||
timeout: TimeoutTypes = None,
|
||||
timeout: TimeoutTypes | UnsetType = UNSET,
|
||||
proxy: str | None = None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -9,14 +9,20 @@ import urllib.request
|
||||
from multidict import CIMultiDict
|
||||
from yarl import URL as URL
|
||||
|
||||
from nonebot.utils import UNSET, UnsetType
|
||||
|
||||
|
||||
@dataclass
|
||||
class Timeout:
|
||||
"""Request 超时配置。"""
|
||||
|
||||
total: float | None = None
|
||||
connect: float | None = None
|
||||
read: float | None = None
|
||||
total: float | None | UnsetType = UNSET
|
||||
connect: float | None | UnsetType = UNSET
|
||||
read: float | None | UnsetType = UNSET
|
||||
close: float | None | UnsetType = UNSET
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = Timeout(total=None, connect=5.0, read=30.0, close=10.0)
|
||||
|
||||
|
||||
RawURL: TypeAlias = tuple[bytes, bytes, int | None, bytes]
|
||||
@@ -68,7 +74,7 @@ class Request:
|
||||
json: Any = None,
|
||||
files: FilesTypes = None,
|
||||
version: str | HTTPVersion = HTTPVersion.H11,
|
||||
timeout: TimeoutTypes = None,
|
||||
timeout: TimeoutTypes | UnsetType = UNSET,
|
||||
proxy: str | None = None,
|
||||
):
|
||||
# method
|
||||
@@ -80,7 +86,7 @@ class Request:
|
||||
# http version
|
||||
self.version: HTTPVersion = HTTPVersion(version)
|
||||
# timeout
|
||||
self.timeout: TimeoutTypes = timeout
|
||||
self.timeout: TimeoutTypes | UnsetType = timeout
|
||||
# proxy
|
||||
self.proxy: str | None = proxy
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from collections.abc import (
|
||||
import contextlib
|
||||
from contextlib import AbstractContextManager, asynccontextmanager
|
||||
import dataclasses
|
||||
from enum import Enum
|
||||
from functools import partial, wraps
|
||||
import importlib
|
||||
import inspect
|
||||
@@ -27,8 +28,12 @@ from pathlib import Path
|
||||
import re
|
||||
from typing import (
|
||||
Any,
|
||||
Final,
|
||||
Generic,
|
||||
Literal,
|
||||
TypeAlias,
|
||||
TypeVar,
|
||||
final,
|
||||
get_args,
|
||||
get_origin,
|
||||
overload,
|
||||
@@ -49,6 +54,8 @@ from nonebot.typing import (
|
||||
type_has_args,
|
||||
)
|
||||
|
||||
from .compat import custom_validation
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
@@ -57,6 +64,54 @@ V = TypeVar("V")
|
||||
E = TypeVar("E", bound=BaseException)
|
||||
|
||||
|
||||
@final
|
||||
@custom_validation
|
||||
class Unset(Enum):
|
||||
_UNSET = "<UNSET>"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<UNSET>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__repr__()
|
||||
|
||||
def __bool__(self) -> Literal[False]:
|
||||
return False
|
||||
|
||||
def __copy__(self):
|
||||
return self._UNSET
|
||||
|
||||
def __deepcopy__(self, memo: dict[int, Any]):
|
||||
return self._UNSET
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
yield cls._validate
|
||||
|
||||
@classmethod
|
||||
def _validate(cls, value: Any):
|
||||
if value is not cls._UNSET:
|
||||
raise ValueError(f"{value!r} is not UNSET")
|
||||
return value
|
||||
|
||||
|
||||
UnsetType: TypeAlias = Literal[Unset._UNSET]
|
||||
|
||||
UNSET: Final[UnsetType] = Unset._UNSET
|
||||
|
||||
|
||||
def exclude_unset(data: Any) -> Any:
|
||||
if isinstance(data, dict):
|
||||
return data.__class__(
|
||||
(k, exclude_unset(v)) for k, v in data.items() if v is not UNSET
|
||||
)
|
||||
elif isinstance(data, list):
|
||||
return data.__class__(exclude_unset(i) for i in data if i is not UNSET)
|
||||
elif data is UNSET:
|
||||
return None
|
||||
return data
|
||||
|
||||
|
||||
def escape_tag(s: str) -> str:
|
||||
"""用于记录带颜色日志时转义 `<tag>` 类型特殊标签
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from nonebot.drivers.aiohttp import Session as AiohttpSession
|
||||
from nonebot.drivers.aiohttp import WebSocket as AiohttpWebSocket
|
||||
from nonebot.exception import WebSocketClosed
|
||||
from nonebot.params import Depends
|
||||
from nonebot.utils import UNSET
|
||||
from utils import FakeAdapter
|
||||
|
||||
|
||||
@@ -706,6 +707,177 @@ async def test_aiohttp_websocket_close_frame(msg_type: str) -> None:
|
||||
await ws.receive()
|
||||
|
||||
|
||||
def test_timeout_unset_vs_none():
|
||||
# default: all fields are UNSET
|
||||
t = Timeout()
|
||||
assert t.total is UNSET
|
||||
assert t.connect is UNSET
|
||||
assert t.read is UNSET
|
||||
assert t.close is UNSET
|
||||
|
||||
# explicitly set to None
|
||||
t = Timeout(close=None)
|
||||
assert t.close is None
|
||||
assert t.close is not UNSET
|
||||
|
||||
# explicitly set to a value
|
||||
t = Timeout(total=5.0, close=None)
|
||||
assert t.total == 5.0
|
||||
assert t.close is None
|
||||
assert t.connect is UNSET
|
||||
assert t.read is UNSET
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize(
|
||||
"driver",
|
||||
[
|
||||
pytest.param("nonebot.drivers.httpx:Driver", id="httpx"),
|
||||
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
async def test_http_client_timeout(driver: Driver, server_url: URL):
|
||||
"""HTTP requests work with fully unset, partial, and None timeout fields."""
|
||||
assert isinstance(driver, HTTPClientMixin)
|
||||
|
||||
# timeout not set, default timeout should apply
|
||||
request = Request("POST", server_url, content="test")
|
||||
response = await driver.request(request)
|
||||
assert response.status_code == 200
|
||||
async for resp in driver.stream_request(request, chunk_size=1024):
|
||||
assert resp.status_code == 200
|
||||
|
||||
# timeout is float or none
|
||||
request = Request("POST", server_url, content="test", timeout=10.0)
|
||||
response = await driver.request(request)
|
||||
assert response.status_code == 200
|
||||
async for resp in driver.stream_request(request, chunk_size=1024):
|
||||
assert resp.status_code == 200
|
||||
|
||||
# all fields unset, default timeout should apply
|
||||
request = Request("POST", server_url, content="test", timeout=Timeout())
|
||||
response = await driver.request(request)
|
||||
assert response.status_code == 200
|
||||
async for resp in driver.stream_request(request, chunk_size=1024):
|
||||
assert resp.status_code == 200
|
||||
|
||||
# only total set
|
||||
request = Request("POST", server_url, content="test", timeout=Timeout(total=10.0))
|
||||
response = await driver.request(request)
|
||||
assert response.status_code == 200
|
||||
async for resp in driver.stream_request(request, chunk_size=1024):
|
||||
assert resp.status_code == 200
|
||||
|
||||
# explicit None (no timeout)
|
||||
request = Request(
|
||||
"POST",
|
||||
server_url,
|
||||
content="test",
|
||||
timeout=Timeout(total=None, connect=None, read=None),
|
||||
)
|
||||
response = await driver.request(request)
|
||||
assert response.status_code == 200
|
||||
async for resp in driver.stream_request(request, chunk_size=1024):
|
||||
assert resp.status_code == 200
|
||||
|
||||
# session with timeout not set
|
||||
session = driver.get_session()
|
||||
async with session:
|
||||
request = Request("POST", server_url, content="test")
|
||||
response = await session.request(request)
|
||||
assert response.status_code == 200
|
||||
|
||||
# session with float or none timeout
|
||||
session = driver.get_session(timeout=10.0)
|
||||
async with session:
|
||||
request = Request("POST", server_url, content="test")
|
||||
response = await session.request(request)
|
||||
assert response.status_code == 200
|
||||
|
||||
# session with fully unset timeout
|
||||
session = driver.get_session(timeout=Timeout())
|
||||
async with session:
|
||||
request = Request("POST", server_url, content="test")
|
||||
response = await session.request(request)
|
||||
assert response.status_code == 200
|
||||
|
||||
# session with timeout
|
||||
session = driver.get_session(timeout=Timeout(total=10.0, connect=5.0, read=5.0))
|
||||
async with session:
|
||||
request = Request("POST", server_url, content="test")
|
||||
response = await session.request(request)
|
||||
assert response.status_code == 200
|
||||
|
||||
# session with timeout override
|
||||
session = driver.get_session(timeout=Timeout(total=10.0))
|
||||
async with session:
|
||||
request = Request(
|
||||
"POST", server_url, content="test", timeout=Timeout(total=20.0)
|
||||
)
|
||||
response = await session.request(request)
|
||||
assert response.status_code == 200
|
||||
|
||||
# session with timeout float override
|
||||
session = driver.get_session(timeout=Timeout(total=10.0))
|
||||
async with session:
|
||||
request = Request("POST", server_url, content="test", timeout=20.0)
|
||||
response = await session.request(request)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize(
|
||||
"driver",
|
||||
[
|
||||
pytest.param("nonebot.drivers.websockets:Driver", id="websockets"),
|
||||
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
async def test_websocket_client_timeout(driver: Driver, server_url: URL):
|
||||
"""WebSocket connections work with fully unset, partial, and None timeout fields."""
|
||||
assert isinstance(driver, WebSocketClientMixin)
|
||||
|
||||
ws_url = server_url.with_scheme("ws")
|
||||
|
||||
# timeout not set, default timeout should apply
|
||||
request = Request("GET", ws_url)
|
||||
async with driver.websocket(request) as ws:
|
||||
await ws.send("quit")
|
||||
with pytest.raises(WebSocketClosed):
|
||||
await ws.receive()
|
||||
|
||||
await anyio.sleep(1)
|
||||
|
||||
# timeout is float or none
|
||||
request = Request("GET", ws_url, timeout=10.0)
|
||||
async with driver.websocket(request) as ws:
|
||||
await ws.send("quit")
|
||||
with pytest.raises(WebSocketClosed):
|
||||
await ws.receive()
|
||||
|
||||
await anyio.sleep(1)
|
||||
|
||||
# all fields unset, default timeout should apply
|
||||
request = Request("GET", ws_url, timeout=Timeout())
|
||||
async with driver.websocket(request) as ws:
|
||||
await ws.send("quit")
|
||||
with pytest.raises(WebSocketClosed):
|
||||
await ws.receive()
|
||||
|
||||
await anyio.sleep(1)
|
||||
|
||||
# close explicitly set to None (no close timeout)
|
||||
request = Request("GET", ws_url, timeout=Timeout(close=None))
|
||||
async with driver.websocket(request) as ws:
|
||||
await ws.send("quit")
|
||||
with pytest.raises(WebSocketClosed):
|
||||
await ws.receive()
|
||||
|
||||
await anyio.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("driver", "driver_type"),
|
||||
[
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
import copy
|
||||
import json
|
||||
import pickle
|
||||
from typing import ClassVar, Dict, List, Literal, TypeVar, Union # noqa: UP035
|
||||
|
||||
from pydantic import ValidationError
|
||||
import pytest
|
||||
|
||||
from nonebot.compat import type_validate_python
|
||||
from nonebot.utils import (
|
||||
UNSET,
|
||||
DataclassEncoder,
|
||||
Unset,
|
||||
UnsetType,
|
||||
escape_tag,
|
||||
exclude_unset,
|
||||
generic_check_issubclass,
|
||||
is_async_gen_callable,
|
||||
is_coroutine_callable,
|
||||
@@ -12,6 +22,29 @@ from nonebot.utils import (
|
||||
from utils import FakeMessage, FakeMessageSegment
|
||||
|
||||
|
||||
def test_unset():
|
||||
assert isinstance(UNSET, Unset)
|
||||
assert bool(UNSET) is False
|
||||
assert copy.copy(UNSET) is UNSET
|
||||
assert copy.deepcopy(UNSET) is UNSET
|
||||
assert pickle.loads(pickle.dumps(UNSET)) is UNSET
|
||||
assert type_validate_python(UnsetType, UNSET) is UNSET
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
type_validate_python(UnsetType, 123)
|
||||
|
||||
|
||||
def test_exclude_unset():
|
||||
assert exclude_unset({"a": 1, "b": UNSET, "c": None, "d": {"x": UNSET}}) == {
|
||||
"a": 1,
|
||||
"c": None,
|
||||
"d": {},
|
||||
}
|
||||
assert exclude_unset([1, UNSET, None, {"x": UNSET}]) == [1, None, {}]
|
||||
assert exclude_unset(UNSET) is None
|
||||
assert exclude_unset(123) == 123
|
||||
|
||||
|
||||
def test_loguru_escape_tag():
|
||||
assert escape_tag("<red>red</red>") == r"\<red>red\</red>"
|
||||
assert escape_tag("<fg #fff>white</fg #fff>") == r"\<fg #fff>white\</fg #fff>"
|
||||
|
||||
Reference in New Issue
Block a user