mirror of
https://github.com/nonebot/nonebot2.git
synced 2026-04-15 04:57:25 +00:00
🐛 Fix: 修正 http/websocket client timeout (#3923)
Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -9,6 +9,7 @@ FrontMatter:
|
|||||||
description: nonebot.drivers 模块
|
description: nonebot.drivers 模块
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from nonebot.internal.driver import DEFAULT_TIMEOUT as DEFAULT_TIMEOUT
|
||||||
from nonebot.internal.driver import URL as URL
|
from nonebot.internal.driver import URL as URL
|
||||||
from nonebot.internal.driver import ASGIMixin as ASGIMixin
|
from nonebot.internal.driver import ASGIMixin as ASGIMixin
|
||||||
from nonebot.internal.driver import Cookies as Cookies
|
from nonebot.internal.driver import Cookies as Cookies
|
||||||
@@ -31,6 +32,7 @@ from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup
|
|||||||
from nonebot.internal.driver import combine_driver as combine_driver
|
from nonebot.internal.driver import combine_driver as combine_driver
|
||||||
|
|
||||||
__autodoc__ = {
|
__autodoc__ = {
|
||||||
|
"DEFAULT_TIMEOUT": True,
|
||||||
"URL": True,
|
"URL": True,
|
||||||
"Cookies": True,
|
"Cookies": True,
|
||||||
"Request": True,
|
"Request": True,
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from nonebot.drivers import WebSocket as BaseWebSocket
|
|||||||
from nonebot.drivers.none import Driver as NoneDriver
|
from nonebot.drivers.none import Driver as NoneDriver
|
||||||
from nonebot.exception import WebSocketClosed
|
from nonebot.exception import WebSocketClosed
|
||||||
from nonebot.internal.driver import (
|
from nonebot.internal.driver import (
|
||||||
|
DEFAULT_TIMEOUT,
|
||||||
Cookies,
|
Cookies,
|
||||||
CookieTypes,
|
CookieTypes,
|
||||||
HeaderTypes,
|
HeaderTypes,
|
||||||
@@ -45,6 +46,7 @@ from nonebot.internal.driver import (
|
|||||||
Timeout,
|
Timeout,
|
||||||
TimeoutTypes,
|
TimeoutTypes,
|
||||||
)
|
)
|
||||||
|
from nonebot.utils import UNSET, UnsetType, exclude_unset
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@@ -63,7 +65,7 @@ class Session(HTTPClientSession):
|
|||||||
headers: HeaderTypes = None,
|
headers: HeaderTypes = None,
|
||||||
cookies: CookieTypes = None,
|
cookies: CookieTypes = None,
|
||||||
version: str | HTTPVersion = HTTPVersion.H11,
|
version: str | HTTPVersion = HTTPVersion.H11,
|
||||||
timeout: TimeoutTypes = None,
|
timeout: TimeoutTypes | UnsetType = UNSET,
|
||||||
proxy: str | None = None,
|
proxy: str | None = None,
|
||||||
):
|
):
|
||||||
self._client: aiohttp.ClientSession | None = None
|
self._client: aiohttp.ClientSession | None = None
|
||||||
@@ -85,15 +87,32 @@ class Session(HTTPClientSession):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Unsupported HTTP version: {version}")
|
raise RuntimeError(f"Unsupported HTTP version: {version}")
|
||||||
|
|
||||||
|
_timeout = None
|
||||||
if isinstance(timeout, Timeout):
|
if isinstance(timeout, Timeout):
|
||||||
self._timeout = aiohttp.ClientTimeout(
|
timeout_kwargs: dict[str, float | None] = exclude_unset(
|
||||||
total=timeout.total,
|
{
|
||||||
connect=timeout.connect,
|
"total": timeout.total,
|
||||||
sock_read=timeout.read,
|
"connect": timeout.connect,
|
||||||
|
"sock_read": timeout.read,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
else:
|
if timeout_kwargs:
|
||||||
self._timeout = aiohttp.ClientTimeout(timeout)
|
_timeout = aiohttp.ClientTimeout(**timeout_kwargs) # type: ignore
|
||||||
|
elif timeout is not UNSET:
|
||||||
|
_timeout = aiohttp.ClientTimeout(connect=timeout, sock_read=timeout)
|
||||||
|
|
||||||
|
if _timeout is None:
|
||||||
|
_timeout = aiohttp.ClientTimeout(
|
||||||
|
**exclude_unset(
|
||||||
|
{
|
||||||
|
"total": DEFAULT_TIMEOUT.total,
|
||||||
|
"connect": DEFAULT_TIMEOUT.connect,
|
||||||
|
"sock_read": DEFAULT_TIMEOUT.read,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._timeout = _timeout
|
||||||
self._proxy = proxy
|
self._proxy = proxy
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -102,6 +121,25 @@ class Session(HTTPClientSession):
|
|||||||
raise RuntimeError("Session is not initialized")
|
raise RuntimeError("Session is not initialized")
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
|
def _get_timeout(self, timeout: TimeoutTypes | UnsetType) -> aiohttp.ClientTimeout:
|
||||||
|
_timeout = None
|
||||||
|
if isinstance(timeout, Timeout):
|
||||||
|
timeout_kwargs: dict[str, float | None] = exclude_unset(
|
||||||
|
{
|
||||||
|
"total": timeout.total,
|
||||||
|
"connect": timeout.connect,
|
||||||
|
"sock_read": timeout.read,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if timeout_kwargs:
|
||||||
|
_timeout = aiohttp.ClientTimeout(**timeout_kwargs) # type: ignore
|
||||||
|
elif timeout is not UNSET:
|
||||||
|
_timeout = aiohttp.ClientTimeout(connect=timeout, sock_read=timeout)
|
||||||
|
|
||||||
|
if _timeout is None:
|
||||||
|
return self._timeout
|
||||||
|
return _timeout
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def request(self, setup: Request) -> Response:
|
async def request(self, setup: Request) -> Response:
|
||||||
if self._params:
|
if self._params:
|
||||||
@@ -121,15 +159,6 @@ class Session(HTTPClientSession):
|
|||||||
if cookie.value is not None
|
if cookie.value is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(setup.timeout, Timeout):
|
|
||||||
timeout = aiohttp.ClientTimeout(
|
|
||||||
total=setup.timeout.total,
|
|
||||||
connect=setup.timeout.connect,
|
|
||||||
sock_read=setup.timeout.read,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
timeout = aiohttp.ClientTimeout(setup.timeout)
|
|
||||||
|
|
||||||
async with await self.client.request(
|
async with await self.client.request(
|
||||||
setup.method,
|
setup.method,
|
||||||
url,
|
url,
|
||||||
@@ -138,7 +167,7 @@ class Session(HTTPClientSession):
|
|||||||
cookies=cookies,
|
cookies=cookies,
|
||||||
headers=setup.headers,
|
headers=setup.headers,
|
||||||
proxy=setup.proxy or self._proxy,
|
proxy=setup.proxy or self._proxy,
|
||||||
timeout=timeout,
|
timeout=self._get_timeout(setup.timeout),
|
||||||
) as response:
|
) as response:
|
||||||
return Response(
|
return Response(
|
||||||
response.status,
|
response.status,
|
||||||
@@ -171,15 +200,6 @@ class Session(HTTPClientSession):
|
|||||||
if cookie.value is not None
|
if cookie.value is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(setup.timeout, Timeout):
|
|
||||||
timeout = aiohttp.ClientTimeout(
|
|
||||||
total=setup.timeout.total,
|
|
||||||
connect=setup.timeout.connect,
|
|
||||||
sock_read=setup.timeout.read,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
timeout = aiohttp.ClientTimeout(setup.timeout)
|
|
||||||
|
|
||||||
async with self.client.request(
|
async with self.client.request(
|
||||||
setup.method,
|
setup.method,
|
||||||
url,
|
url,
|
||||||
@@ -188,7 +208,7 @@ class Session(HTTPClientSession):
|
|||||||
cookies=cookies,
|
cookies=cookies,
|
||||||
headers=setup.headers,
|
headers=setup.headers,
|
||||||
proxy=setup.proxy or self._proxy,
|
proxy=setup.proxy or self._proxy,
|
||||||
timeout=timeout,
|
timeout=self._get_timeout(setup.timeout),
|
||||||
) as response:
|
) as response:
|
||||||
response_headers = response.headers.copy()
|
response_headers = response.headers.copy()
|
||||||
# aiohttp does not guarantee fixed-size chunks; re-chunk to exact size
|
# aiohttp does not guarantee fixed-size chunks; re-chunk to exact size
|
||||||
@@ -270,13 +290,39 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
|
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
|
||||||
|
|
||||||
|
timeout = None
|
||||||
if isinstance(setup.timeout, Timeout):
|
if isinstance(setup.timeout, Timeout):
|
||||||
timeout = aiohttp.ClientWSTimeout(
|
timeout_kwargs: dict[str, float | None] = exclude_unset(
|
||||||
ws_receive=setup.timeout.read, # type: ignore
|
{
|
||||||
ws_close=setup.timeout.total, # type: ignore
|
"ws_receive": setup.timeout.read,
|
||||||
|
"ws_close": (
|
||||||
|
setup.timeout.total
|
||||||
|
if setup.timeout.close is UNSET
|
||||||
|
else setup.timeout.close
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if timeout_kwargs:
|
||||||
|
timeout = aiohttp.ClientWSTimeout(**timeout_kwargs)
|
||||||
|
elif setup.timeout is not UNSET:
|
||||||
|
timeout = aiohttp.ClientWSTimeout(
|
||||||
|
ws_receive=setup.timeout, # type: ignore
|
||||||
|
ws_close=setup.timeout, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
if timeout is None:
|
||||||
|
timeout = aiohttp.ClientWSTimeout(
|
||||||
|
**exclude_unset(
|
||||||
|
{
|
||||||
|
"ws_receive": DEFAULT_TIMEOUT.read,
|
||||||
|
"ws_close": (
|
||||||
|
DEFAULT_TIMEOUT.total
|
||||||
|
if DEFAULT_TIMEOUT.close is UNSET
|
||||||
|
else DEFAULT_TIMEOUT.close
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
timeout = aiohttp.ClientWSTimeout(ws_close=setup.timeout or 10.0) # type: ignore
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession(version=version, trust_env=True) as session:
|
async with aiohttp.ClientSession(version=version, trust_env=True) as session:
|
||||||
async with session.ws_connect(
|
async with session.ws_connect(
|
||||||
@@ -295,7 +341,7 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
|
|||||||
headers: HeaderTypes = None,
|
headers: HeaderTypes = None,
|
||||||
cookies: CookieTypes = None,
|
cookies: CookieTypes = None,
|
||||||
version: str | HTTPVersion = HTTPVersion.H11,
|
version: str | HTTPVersion = HTTPVersion.H11,
|
||||||
timeout: TimeoutTypes = None,
|
timeout: TimeoutTypes | UnsetType = UNSET,
|
||||||
proxy: str | None = None,
|
proxy: str | None = None,
|
||||||
) -> Session:
|
) -> Session:
|
||||||
return Session(
|
return Session(
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from nonebot.drivers import (
|
|||||||
)
|
)
|
||||||
from nonebot.drivers.none import Driver as NoneDriver
|
from nonebot.drivers.none import Driver as NoneDriver
|
||||||
from nonebot.internal.driver import (
|
from nonebot.internal.driver import (
|
||||||
|
DEFAULT_TIMEOUT,
|
||||||
Cookies,
|
Cookies,
|
||||||
CookieTypes,
|
CookieTypes,
|
||||||
HeaderTypes,
|
HeaderTypes,
|
||||||
@@ -41,6 +42,7 @@ from nonebot.internal.driver import (
|
|||||||
Timeout,
|
Timeout,
|
||||||
TimeoutTypes,
|
TimeoutTypes,
|
||||||
)
|
)
|
||||||
|
from nonebot.utils import UNSET, UnsetType, exclude_unset
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import httpx
|
import httpx
|
||||||
@@ -59,7 +61,7 @@ class Session(HTTPClientSession):
|
|||||||
headers: HeaderTypes = None,
|
headers: HeaderTypes = None,
|
||||||
cookies: CookieTypes = None,
|
cookies: CookieTypes = None,
|
||||||
version: str | HTTPVersion = HTTPVersion.H11,
|
version: str | HTTPVersion = HTTPVersion.H11,
|
||||||
timeout: TimeoutTypes = None,
|
timeout: TimeoutTypes | UnsetType = UNSET,
|
||||||
proxy: str | None = None,
|
proxy: str | None = None,
|
||||||
):
|
):
|
||||||
self._client: httpx.AsyncClient | None = None
|
self._client: httpx.AsyncClient | None = None
|
||||||
@@ -73,15 +75,34 @@ class Session(HTTPClientSession):
|
|||||||
self._cookies = Cookies(cookies)
|
self._cookies = Cookies(cookies)
|
||||||
self._version = HTTPVersion(version)
|
self._version = HTTPVersion(version)
|
||||||
|
|
||||||
|
_timeout = None
|
||||||
if isinstance(timeout, Timeout):
|
if isinstance(timeout, Timeout):
|
||||||
self._timeout = httpx.Timeout(
|
avg_timeout = timeout.total and timeout.total / 4
|
||||||
timeout=timeout.total,
|
timeout_kwargs: dict[str, float | None] = exclude_unset(
|
||||||
connect=timeout.connect,
|
{
|
||||||
read=timeout.read,
|
"timeout": avg_timeout,
|
||||||
|
"connect": timeout.connect,
|
||||||
|
"read": timeout.read,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
else:
|
if timeout_kwargs:
|
||||||
self._timeout = httpx.Timeout(timeout)
|
_timeout = httpx.Timeout(**timeout_kwargs)
|
||||||
|
elif timeout is not UNSET:
|
||||||
|
_timeout = httpx.Timeout(timeout)
|
||||||
|
|
||||||
|
if _timeout is None:
|
||||||
|
avg_timeout = DEFAULT_TIMEOUT.total and DEFAULT_TIMEOUT.total / 4
|
||||||
|
_timeout = httpx.Timeout(
|
||||||
|
**exclude_unset(
|
||||||
|
{
|
||||||
|
"timeout": avg_timeout,
|
||||||
|
"connect": DEFAULT_TIMEOUT.connect,
|
||||||
|
"read": DEFAULT_TIMEOUT.read,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._timeout = _timeout
|
||||||
self._proxy = proxy
|
self._proxy = proxy
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -90,17 +111,28 @@ class Session(HTTPClientSession):
|
|||||||
raise RuntimeError("Session is not initialized")
|
raise RuntimeError("Session is not initialized")
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
|
def _get_timeout(self, timeout: TimeoutTypes | UnsetType) -> httpx.Timeout:
|
||||||
|
_timeout = None
|
||||||
|
if isinstance(timeout, Timeout):
|
||||||
|
avg_timeout = timeout.total and timeout.total / 4
|
||||||
|
timeout_kwargs: dict[str, float | None] = exclude_unset(
|
||||||
|
{
|
||||||
|
"timeout": avg_timeout,
|
||||||
|
"connect": timeout.connect,
|
||||||
|
"read": timeout.read,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if timeout_kwargs:
|
||||||
|
_timeout = httpx.Timeout(**timeout_kwargs)
|
||||||
|
elif timeout is not UNSET:
|
||||||
|
_timeout = httpx.Timeout(timeout)
|
||||||
|
|
||||||
|
if _timeout is None:
|
||||||
|
return self._timeout
|
||||||
|
return _timeout
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def request(self, setup: Request) -> Response:
|
async def request(self, setup: Request) -> Response:
|
||||||
if isinstance(setup.timeout, Timeout):
|
|
||||||
timeout = httpx.Timeout(
|
|
||||||
timeout=setup.timeout.total,
|
|
||||||
connect=setup.timeout.connect,
|
|
||||||
read=setup.timeout.read,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
timeout = httpx.Timeout(setup.timeout)
|
|
||||||
|
|
||||||
response = await self.client.request(
|
response = await self.client.request(
|
||||||
setup.method,
|
setup.method,
|
||||||
str(setup.url),
|
str(setup.url),
|
||||||
@@ -112,7 +144,7 @@ class Session(HTTPClientSession):
|
|||||||
params=setup.url.raw_query_string,
|
params=setup.url.raw_query_string,
|
||||||
headers=tuple(setup.headers.items()),
|
headers=tuple(setup.headers.items()),
|
||||||
cookies=setup.cookies.jar,
|
cookies=setup.cookies.jar,
|
||||||
timeout=timeout,
|
timeout=self._get_timeout(setup.timeout),
|
||||||
)
|
)
|
||||||
return Response(
|
return Response(
|
||||||
response.status_code,
|
response.status_code,
|
||||||
@@ -128,15 +160,6 @@ class Session(HTTPClientSession):
|
|||||||
*,
|
*,
|
||||||
chunk_size: int = 1024,
|
chunk_size: int = 1024,
|
||||||
) -> AsyncGenerator[Response, None]:
|
) -> AsyncGenerator[Response, None]:
|
||||||
if isinstance(setup.timeout, Timeout):
|
|
||||||
timeout = httpx.Timeout(
|
|
||||||
timeout=setup.timeout.total,
|
|
||||||
connect=setup.timeout.connect,
|
|
||||||
read=setup.timeout.read,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
timeout = httpx.Timeout(setup.timeout)
|
|
||||||
|
|
||||||
async with self.client.stream(
|
async with self.client.stream(
|
||||||
setup.method,
|
setup.method,
|
||||||
str(setup.url),
|
str(setup.url),
|
||||||
@@ -148,7 +171,7 @@ class Session(HTTPClientSession):
|
|||||||
params=setup.url.raw_query_string,
|
params=setup.url.raw_query_string,
|
||||||
headers=tuple(setup.headers.items()),
|
headers=tuple(setup.headers.items()),
|
||||||
cookies=setup.cookies.jar,
|
cookies=setup.cookies.jar,
|
||||||
timeout=timeout,
|
timeout=self._get_timeout(setup.timeout),
|
||||||
) as response:
|
) as response:
|
||||||
response_headers = response.headers.multi_items()
|
response_headers = response.headers.multi_items()
|
||||||
async for chunk in response.aiter_bytes(chunk_size=chunk_size):
|
async for chunk in response.aiter_bytes(chunk_size=chunk_size):
|
||||||
|
|||||||
@@ -25,11 +25,18 @@ from types import CoroutineType
|
|||||||
from typing import TYPE_CHECKING, Any, TypeVar
|
from typing import TYPE_CHECKING, Any, TypeVar
|
||||||
from typing_extensions import ParamSpec, override
|
from typing_extensions import ParamSpec, override
|
||||||
|
|
||||||
from nonebot.drivers import Request, Timeout, WebSocketClientMixin, combine_driver
|
from nonebot.drivers import (
|
||||||
|
DEFAULT_TIMEOUT,
|
||||||
|
Request,
|
||||||
|
Timeout,
|
||||||
|
WebSocketClientMixin,
|
||||||
|
combine_driver,
|
||||||
|
)
|
||||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
from nonebot.drivers.none import Driver as NoneDriver
|
from nonebot.drivers.none import Driver as NoneDriver
|
||||||
from nonebot.exception import WebSocketClosed
|
from nonebot.exception import WebSocketClosed
|
||||||
from nonebot.log import LoguruHandler
|
from nonebot.log import LoguruHandler
|
||||||
|
from nonebot.utils import UNSET, exclude_unset
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from websockets import ClientConnection, ConnectionClosed, connect
|
from websockets import ClientConnection, ConnectionClosed, connect
|
||||||
@@ -70,16 +77,36 @@ class Mixin(WebSocketClientMixin):
|
|||||||
@override
|
@override
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
|
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
|
||||||
|
timeout_kwargs: dict[str, float | None] = {}
|
||||||
if isinstance(setup.timeout, Timeout):
|
if isinstance(setup.timeout, Timeout):
|
||||||
timeout = setup.timeout.total or setup.timeout.connect or setup.timeout.read
|
open_timeout = (
|
||||||
else:
|
setup.timeout.connect or setup.timeout.read or setup.timeout.total
|
||||||
timeout = setup.timeout
|
)
|
||||||
|
timeout_kwargs = exclude_unset(
|
||||||
|
{"open_timeout": open_timeout, "close_timeout": setup.timeout.close}
|
||||||
|
)
|
||||||
|
elif setup.timeout is not UNSET:
|
||||||
|
timeout_kwargs = {
|
||||||
|
"open_timeout": setup.timeout,
|
||||||
|
"close_timeout": setup.timeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
if not timeout_kwargs:
|
||||||
|
open_timeout = (
|
||||||
|
DEFAULT_TIMEOUT.connect or DEFAULT_TIMEOUT.read or DEFAULT_TIMEOUT.total
|
||||||
|
)
|
||||||
|
timeout_kwargs = exclude_unset(
|
||||||
|
{
|
||||||
|
"open_timeout": open_timeout,
|
||||||
|
"close_timeout": DEFAULT_TIMEOUT.close,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
connection = connect(
|
connection = connect(
|
||||||
str(setup.url),
|
str(setup.url),
|
||||||
additional_headers={**setup.headers, **setup.cookies.as_header(setup)},
|
additional_headers={**setup.headers, **setup.cookies.as_header(setup)},
|
||||||
proxy=setup.proxy if setup.proxy is not None else True,
|
proxy=setup.proxy if setup.proxy is not None else True,
|
||||||
open_timeout=timeout,
|
**timeout_kwargs, # type: ignore
|
||||||
)
|
)
|
||||||
async with connection as ws:
|
async with connection as ws:
|
||||||
yield WebSocket(request=setup, websocket=ws)
|
yield WebSocket(request=setup, websocket=ws)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from .abstract import ReverseDriver as ReverseDriver
|
|||||||
from .abstract import ReverseMixin as ReverseMixin
|
from .abstract import ReverseMixin as ReverseMixin
|
||||||
from .abstract import WebSocketClientMixin as WebSocketClientMixin
|
from .abstract import WebSocketClientMixin as WebSocketClientMixin
|
||||||
from .combine import combine_driver as combine_driver
|
from .combine import combine_driver as combine_driver
|
||||||
|
from .model import DEFAULT_TIMEOUT as DEFAULT_TIMEOUT
|
||||||
from .model import URL as URL
|
from .model import URL as URL
|
||||||
from .model import ContentTypes as ContentTypes
|
from .model import ContentTypes as ContentTypes
|
||||||
from .model import Cookies as Cookies
|
from .model import Cookies as Cookies
|
||||||
|
|||||||
@@ -19,7 +19,13 @@ from nonebot.typing import (
|
|||||||
T_BotDisconnectionHook,
|
T_BotDisconnectionHook,
|
||||||
T_DependencyCache,
|
T_DependencyCache,
|
||||||
)
|
)
|
||||||
from nonebot.utils import escape_tag, flatten_exception_group, run_coro_with_catch
|
from nonebot.utils import (
|
||||||
|
UNSET,
|
||||||
|
UnsetType,
|
||||||
|
escape_tag,
|
||||||
|
flatten_exception_group,
|
||||||
|
run_coro_with_catch,
|
||||||
|
)
|
||||||
|
|
||||||
from ._lifespan import LIFESPAN_FUNC, Lifespan
|
from ._lifespan import LIFESPAN_FUNC, Lifespan
|
||||||
from .model import (
|
from .model import (
|
||||||
@@ -246,7 +252,7 @@ class HTTPClientSession(abc.ABC):
|
|||||||
headers: HeaderTypes = None,
|
headers: HeaderTypes = None,
|
||||||
cookies: CookieTypes = None,
|
cookies: CookieTypes = None,
|
||||||
version: str | HTTPVersion = HTTPVersion.H11,
|
version: str | HTTPVersion = HTTPVersion.H11,
|
||||||
timeout: TimeoutTypes = None,
|
timeout: TimeoutTypes | UnsetType = UNSET,
|
||||||
proxy: str | None = None,
|
proxy: str | None = None,
|
||||||
):
|
):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -9,14 +9,20 @@ import urllib.request
|
|||||||
from multidict import CIMultiDict
|
from multidict import CIMultiDict
|
||||||
from yarl import URL as URL
|
from yarl import URL as URL
|
||||||
|
|
||||||
|
from nonebot.utils import UNSET, UnsetType
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Timeout:
|
class Timeout:
|
||||||
"""Request 超时配置。"""
|
"""Request 超时配置。"""
|
||||||
|
|
||||||
total: float | None = None
|
total: float | None | UnsetType = UNSET
|
||||||
connect: float | None = None
|
connect: float | None | UnsetType = UNSET
|
||||||
read: float | None = None
|
read: float | None | UnsetType = UNSET
|
||||||
|
close: float | None | UnsetType = UNSET
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_TIMEOUT = Timeout(total=None, connect=5.0, read=30.0, close=10.0)
|
||||||
|
|
||||||
|
|
||||||
RawURL: TypeAlias = tuple[bytes, bytes, int | None, bytes]
|
RawURL: TypeAlias = tuple[bytes, bytes, int | None, bytes]
|
||||||
@@ -68,7 +74,7 @@ class Request:
|
|||||||
json: Any = None,
|
json: Any = None,
|
||||||
files: FilesTypes = None,
|
files: FilesTypes = None,
|
||||||
version: str | HTTPVersion = HTTPVersion.H11,
|
version: str | HTTPVersion = HTTPVersion.H11,
|
||||||
timeout: TimeoutTypes = None,
|
timeout: TimeoutTypes | UnsetType = UNSET,
|
||||||
proxy: str | None = None,
|
proxy: str | None = None,
|
||||||
):
|
):
|
||||||
# method
|
# method
|
||||||
@@ -80,7 +86,7 @@ class Request:
|
|||||||
# http version
|
# http version
|
||||||
self.version: HTTPVersion = HTTPVersion(version)
|
self.version: HTTPVersion = HTTPVersion(version)
|
||||||
# timeout
|
# timeout
|
||||||
self.timeout: TimeoutTypes = timeout
|
self.timeout: TimeoutTypes | UnsetType = timeout
|
||||||
# proxy
|
# proxy
|
||||||
self.proxy: str | None = proxy
|
self.proxy: str | None = proxy
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from collections.abc import (
|
|||||||
import contextlib
|
import contextlib
|
||||||
from contextlib import AbstractContextManager, asynccontextmanager
|
from contextlib import AbstractContextManager, asynccontextmanager
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
from enum import Enum
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
@@ -27,8 +28,12 @@ from pathlib import Path
|
|||||||
import re
|
import re
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
Final,
|
||||||
Generic,
|
Generic,
|
||||||
|
Literal,
|
||||||
|
TypeAlias,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
|
final,
|
||||||
get_args,
|
get_args,
|
||||||
get_origin,
|
get_origin,
|
||||||
overload,
|
overload,
|
||||||
@@ -49,6 +54,8 @@ from nonebot.typing import (
|
|||||||
type_has_args,
|
type_has_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .compat import custom_validation
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
@@ -57,6 +64,54 @@ V = TypeVar("V")
|
|||||||
E = TypeVar("E", bound=BaseException)
|
E = TypeVar("E", bound=BaseException)
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
|
@custom_validation
|
||||||
|
class Unset(Enum):
|
||||||
|
_UNSET = "<UNSET>"
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return "<UNSET>"
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.__repr__()
|
||||||
|
|
||||||
|
def __bool__(self) -> Literal[False]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __copy__(self):
|
||||||
|
return self._UNSET
|
||||||
|
|
||||||
|
def __deepcopy__(self, memo: dict[int, Any]):
|
||||||
|
return self._UNSET
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_validators__(cls):
|
||||||
|
yield cls._validate
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate(cls, value: Any):
|
||||||
|
if value is not cls._UNSET:
|
||||||
|
raise ValueError(f"{value!r} is not UNSET")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
UnsetType: TypeAlias = Literal[Unset._UNSET]
|
||||||
|
|
||||||
|
UNSET: Final[UnsetType] = Unset._UNSET
|
||||||
|
|
||||||
|
|
||||||
|
def exclude_unset(data: Any) -> Any:
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return data.__class__(
|
||||||
|
(k, exclude_unset(v)) for k, v in data.items() if v is not UNSET
|
||||||
|
)
|
||||||
|
elif isinstance(data, list):
|
||||||
|
return data.__class__(exclude_unset(i) for i in data if i is not UNSET)
|
||||||
|
elif data is UNSET:
|
||||||
|
return None
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def escape_tag(s: str) -> str:
|
def escape_tag(s: str) -> str:
|
||||||
"""用于记录带颜色日志时转义 `<tag>` 类型特殊标签
|
"""用于记录带颜色日志时转义 `<tag>` 类型特殊标签
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from nonebot.drivers.aiohttp import Session as AiohttpSession
|
|||||||
from nonebot.drivers.aiohttp import WebSocket as AiohttpWebSocket
|
from nonebot.drivers.aiohttp import WebSocket as AiohttpWebSocket
|
||||||
from nonebot.exception import WebSocketClosed
|
from nonebot.exception import WebSocketClosed
|
||||||
from nonebot.params import Depends
|
from nonebot.params import Depends
|
||||||
|
from nonebot.utils import UNSET
|
||||||
from utils import FakeAdapter
|
from utils import FakeAdapter
|
||||||
|
|
||||||
|
|
||||||
@@ -706,6 +707,177 @@ async def test_aiohttp_websocket_close_frame(msg_type: str) -> None:
|
|||||||
await ws.receive()
|
await ws.receive()
|
||||||
|
|
||||||
|
|
||||||
|
def test_timeout_unset_vs_none():
|
||||||
|
# default: all fields are UNSET
|
||||||
|
t = Timeout()
|
||||||
|
assert t.total is UNSET
|
||||||
|
assert t.connect is UNSET
|
||||||
|
assert t.read is UNSET
|
||||||
|
assert t.close is UNSET
|
||||||
|
|
||||||
|
# explicitly set to None
|
||||||
|
t = Timeout(close=None)
|
||||||
|
assert t.close is None
|
||||||
|
assert t.close is not UNSET
|
||||||
|
|
||||||
|
# explicitly set to a value
|
||||||
|
t = Timeout(total=5.0, close=None)
|
||||||
|
assert t.total == 5.0
|
||||||
|
assert t.close is None
|
||||||
|
assert t.connect is UNSET
|
||||||
|
assert t.read is UNSET
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"driver",
|
||||||
|
[
|
||||||
|
pytest.param("nonebot.drivers.httpx:Driver", id="httpx"),
|
||||||
|
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
|
||||||
|
],
|
||||||
|
indirect=True,
|
||||||
|
)
|
||||||
|
async def test_http_client_timeout(driver: Driver, server_url: URL):
|
||||||
|
"""HTTP requests work with fully unset, partial, and None timeout fields."""
|
||||||
|
assert isinstance(driver, HTTPClientMixin)
|
||||||
|
|
||||||
|
# timeout not set, default timeout should apply
|
||||||
|
request = Request("POST", server_url, content="test")
|
||||||
|
response = await driver.request(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
async for resp in driver.stream_request(request, chunk_size=1024):
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
# timeout is float or none
|
||||||
|
request = Request("POST", server_url, content="test", timeout=10.0)
|
||||||
|
response = await driver.request(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
async for resp in driver.stream_request(request, chunk_size=1024):
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
# all fields unset, default timeout should apply
|
||||||
|
request = Request("POST", server_url, content="test", timeout=Timeout())
|
||||||
|
response = await driver.request(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
async for resp in driver.stream_request(request, chunk_size=1024):
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
# only total set
|
||||||
|
request = Request("POST", server_url, content="test", timeout=Timeout(total=10.0))
|
||||||
|
response = await driver.request(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
async for resp in driver.stream_request(request, chunk_size=1024):
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
# explicit None (no timeout)
|
||||||
|
request = Request(
|
||||||
|
"POST",
|
||||||
|
server_url,
|
||||||
|
content="test",
|
||||||
|
timeout=Timeout(total=None, connect=None, read=None),
|
||||||
|
)
|
||||||
|
response = await driver.request(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
async for resp in driver.stream_request(request, chunk_size=1024):
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
# session with timeout not set
|
||||||
|
session = driver.get_session()
|
||||||
|
async with session:
|
||||||
|
request = Request("POST", server_url, content="test")
|
||||||
|
response = await session.request(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# session with float or none timeout
|
||||||
|
session = driver.get_session(timeout=10.0)
|
||||||
|
async with session:
|
||||||
|
request = Request("POST", server_url, content="test")
|
||||||
|
response = await session.request(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# session with fully unset timeout
|
||||||
|
session = driver.get_session(timeout=Timeout())
|
||||||
|
async with session:
|
||||||
|
request = Request("POST", server_url, content="test")
|
||||||
|
response = await session.request(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# session with timeout
|
||||||
|
session = driver.get_session(timeout=Timeout(total=10.0, connect=5.0, read=5.0))
|
||||||
|
async with session:
|
||||||
|
request = Request("POST", server_url, content="test")
|
||||||
|
response = await session.request(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# session with timeout override
|
||||||
|
session = driver.get_session(timeout=Timeout(total=10.0))
|
||||||
|
async with session:
|
||||||
|
request = Request(
|
||||||
|
"POST", server_url, content="test", timeout=Timeout(total=20.0)
|
||||||
|
)
|
||||||
|
response = await session.request(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# session with timeout float override
|
||||||
|
session = driver.get_session(timeout=Timeout(total=10.0))
|
||||||
|
async with session:
|
||||||
|
request = Request("POST", server_url, content="test", timeout=20.0)
|
||||||
|
response = await session.request(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"driver",
|
||||||
|
[
|
||||||
|
pytest.param("nonebot.drivers.websockets:Driver", id="websockets"),
|
||||||
|
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
|
||||||
|
],
|
||||||
|
indirect=True,
|
||||||
|
)
|
||||||
|
async def test_websocket_client_timeout(driver: Driver, server_url: URL):
|
||||||
|
"""WebSocket connections work with fully unset, partial, and None timeout fields."""
|
||||||
|
assert isinstance(driver, WebSocketClientMixin)
|
||||||
|
|
||||||
|
ws_url = server_url.with_scheme("ws")
|
||||||
|
|
||||||
|
# timeout not set, default timeout should apply
|
||||||
|
request = Request("GET", ws_url)
|
||||||
|
async with driver.websocket(request) as ws:
|
||||||
|
await ws.send("quit")
|
||||||
|
with pytest.raises(WebSocketClosed):
|
||||||
|
await ws.receive()
|
||||||
|
|
||||||
|
await anyio.sleep(1)
|
||||||
|
|
||||||
|
# timeout is float or none
|
||||||
|
request = Request("GET", ws_url, timeout=10.0)
|
||||||
|
async with driver.websocket(request) as ws:
|
||||||
|
await ws.send("quit")
|
||||||
|
with pytest.raises(WebSocketClosed):
|
||||||
|
await ws.receive()
|
||||||
|
|
||||||
|
await anyio.sleep(1)
|
||||||
|
|
||||||
|
# all fields unset, default timeout should apply
|
||||||
|
request = Request("GET", ws_url, timeout=Timeout())
|
||||||
|
async with driver.websocket(request) as ws:
|
||||||
|
await ws.send("quit")
|
||||||
|
with pytest.raises(WebSocketClosed):
|
||||||
|
await ws.receive()
|
||||||
|
|
||||||
|
await anyio.sleep(1)
|
||||||
|
|
||||||
|
# close explicitly set to None (no close timeout)
|
||||||
|
request = Request("GET", ws_url, timeout=Timeout(close=None))
|
||||||
|
async with driver.websocket(request) as ws:
|
||||||
|
await ws.send("quit")
|
||||||
|
with pytest.raises(WebSocketClosed):
|
||||||
|
await ws.receive()
|
||||||
|
|
||||||
|
await anyio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("driver", "driver_type"),
|
("driver", "driver_type"),
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -1,9 +1,19 @@
|
|||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
|
import pickle
|
||||||
from typing import ClassVar, Dict, List, Literal, TypeVar, Union # noqa: UP035
|
from typing import ClassVar, Dict, List, Literal, TypeVar, Union # noqa: UP035
|
||||||
|
|
||||||
|
from pydantic import ValidationError
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nonebot.compat import type_validate_python
|
||||||
from nonebot.utils import (
|
from nonebot.utils import (
|
||||||
|
UNSET,
|
||||||
DataclassEncoder,
|
DataclassEncoder,
|
||||||
|
Unset,
|
||||||
|
UnsetType,
|
||||||
escape_tag,
|
escape_tag,
|
||||||
|
exclude_unset,
|
||||||
generic_check_issubclass,
|
generic_check_issubclass,
|
||||||
is_async_gen_callable,
|
is_async_gen_callable,
|
||||||
is_coroutine_callable,
|
is_coroutine_callable,
|
||||||
@@ -12,6 +22,29 @@ from nonebot.utils import (
|
|||||||
from utils import FakeMessage, FakeMessageSegment
|
from utils import FakeMessage, FakeMessageSegment
|
||||||
|
|
||||||
|
|
||||||
|
def test_unset():
|
||||||
|
assert isinstance(UNSET, Unset)
|
||||||
|
assert bool(UNSET) is False
|
||||||
|
assert copy.copy(UNSET) is UNSET
|
||||||
|
assert copy.deepcopy(UNSET) is UNSET
|
||||||
|
assert pickle.loads(pickle.dumps(UNSET)) is UNSET
|
||||||
|
assert type_validate_python(UnsetType, UNSET) is UNSET
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
type_validate_python(UnsetType, 123)
|
||||||
|
|
||||||
|
|
||||||
|
def test_exclude_unset():
|
||||||
|
assert exclude_unset({"a": 1, "b": UNSET, "c": None, "d": {"x": UNSET}}) == {
|
||||||
|
"a": 1,
|
||||||
|
"c": None,
|
||||||
|
"d": {},
|
||||||
|
}
|
||||||
|
assert exclude_unset([1, UNSET, None, {"x": UNSET}]) == [1, None, {}]
|
||||||
|
assert exclude_unset(UNSET) is None
|
||||||
|
assert exclude_unset(123) == 123
|
||||||
|
|
||||||
|
|
||||||
def test_loguru_escape_tag():
|
def test_loguru_escape_tag():
|
||||||
assert escape_tag("<red>red</red>") == r"\<red>red\</red>"
|
assert escape_tag("<red>red</red>") == r"\<red>red\</red>"
|
||||||
assert escape_tag("<fg #fff>white</fg #fff>") == r"\<fg #fff>white\</fg #fff>"
|
assert escape_tag("<fg #fff>white</fg #fff>") == r"\<fg #fff>white\</fg #fff>"
|
||||||
|
|||||||
Reference in New Issue
Block a user