Feature: 细化内置驱动器请求参数中的超时控制颗粒度 (#3571)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
This commit is contained in:
Ailitonia
2025-07-15 10:29:18 +08:00
committed by GitHub
parent bc0682af8f
commit 4ec9bfb7d7
8 changed files with 115 additions and 19 deletions

View File

@ -24,6 +24,7 @@ from nonebot.internal.driver import Request as Request
from nonebot.internal.driver import Response as Response
from nonebot.internal.driver import ReverseDriver as ReverseDriver
from nonebot.internal.driver import ReverseMixin as ReverseMixin
from nonebot.internal.driver import Timeout as Timeout
from nonebot.internal.driver import WebSocket as WebSocket
from nonebot.internal.driver import WebSocketClientMixin as WebSocketClientMixin
from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup
@ -34,6 +35,7 @@ __autodoc__ = {
"Cookies": True,
"Request": True,
"Response": True,
"Timeout": True,
"WebSocket": True,
"HTTPVersion": True,
"Driver": True,

View File

@ -37,7 +37,14 @@ from nonebot.drivers import (
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.exception import WebSocketClosed
from nonebot.internal.driver import Cookies, CookieTypes, HeaderTypes, QueryTypes
from nonebot.internal.driver import (
Cookies,
CookieTypes,
HeaderTypes,
QueryTypes,
Timeout,
TimeoutTypes,
)
try:
import aiohttp
@ -56,7 +63,7 @@ class Session(HTTPClientSession):
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
):
self._client: Optional[aiohttp.ClientSession] = None
@ -78,7 +85,15 @@ class Session(HTTPClientSession):
else:
raise RuntimeError(f"Unsupported HTTP version: {version}")
self._timeout = timeout
if isinstance(timeout, Timeout):
self._timeout = aiohttp.ClientTimeout(
total=timeout.total,
connect=timeout.connect,
sock_read=timeout.read,
)
else:
self._timeout = aiohttp.ClientTimeout(timeout)
self._proxy = proxy
@property
@ -106,7 +121,14 @@ class Session(HTTPClientSession):
if cookie.value is not None
)
timeout = aiohttp.ClientTimeout(setup.timeout)
if isinstance(setup.timeout, Timeout):
timeout = aiohttp.ClientTimeout(
total=setup.timeout.total,
connect=setup.timeout.connect,
sock_read=setup.timeout.read,
)
else:
timeout = aiohttp.ClientTimeout(setup.timeout)
async with await self.client.request(
setup.method,
@ -149,7 +171,14 @@ class Session(HTTPClientSession):
if cookie.value is not None
)
timeout = aiohttp.ClientTimeout(setup.timeout)
if isinstance(setup.timeout, Timeout):
timeout = aiohttp.ClientTimeout(
total=setup.timeout.total,
connect=setup.timeout.connect,
sock_read=setup.timeout.read,
)
else:
timeout = aiohttp.ClientTimeout(setup.timeout)
async with self.client.request(
setup.method,
@ -226,7 +255,13 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
else:
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
timeout = aiohttp.ClientWSTimeout(ws_close=setup.timeout or 10.0) # type: ignore
if isinstance(setup.timeout, Timeout):
timeout = aiohttp.ClientWSTimeout(
ws_receive=setup.timeout.read, # type: ignore
ws_close=setup.timeout.total, # type: ignore
)
else:
timeout = aiohttp.ClientWSTimeout(ws_close=setup.timeout or 10.0) # type: ignore
async with aiohttp.ClientSession(version=version, trust_env=True) as session:
async with session.ws_connect(
@ -245,7 +280,7 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
) -> Session:
return Session(

View File

@ -33,7 +33,14 @@ from nonebot.drivers import (
combine_driver,
)
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.internal.driver import Cookies, CookieTypes, HeaderTypes, QueryTypes
from nonebot.internal.driver import (
Cookies,
CookieTypes,
HeaderTypes,
QueryTypes,
Timeout,
TimeoutTypes,
)
try:
import httpx
@ -52,7 +59,7 @@ class Session(HTTPClientSession):
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
):
self._client: Optional[httpx.AsyncClient] = None
@ -65,7 +72,16 @@ class Session(HTTPClientSession):
)
self._cookies = Cookies(cookies)
self._version = HTTPVersion(version)
self._timeout = timeout
if isinstance(timeout, Timeout):
self._timeout = httpx.Timeout(
timeout=timeout.total,
connect=timeout.connect,
read=timeout.read,
)
else:
self._timeout = httpx.Timeout(timeout)
self._proxy = proxy
@property
@ -76,6 +92,15 @@ class Session(HTTPClientSession):
@override
async def request(self, setup: Request) -> Response:
if isinstance(setup.timeout, Timeout):
timeout = httpx.Timeout(
timeout=setup.timeout.total,
connect=setup.timeout.connect,
read=setup.timeout.read,
)
else:
timeout = httpx.Timeout(setup.timeout)
response = await self.client.request(
setup.method,
str(setup.url),
@ -87,7 +112,7 @@ class Session(HTTPClientSession):
params=setup.url.raw_query_string,
headers=tuple(setup.headers.items()),
cookies=setup.cookies.jar,
timeout=setup.timeout,
timeout=timeout,
)
return Response(
response.status_code,
@ -103,6 +128,15 @@ class Session(HTTPClientSession):
*,
chunk_size: int = 1024,
) -> AsyncGenerator[Response, None]:
if isinstance(setup.timeout, Timeout):
timeout = httpx.Timeout(
timeout=setup.timeout.total,
connect=setup.timeout.connect,
read=setup.timeout.read,
)
else:
timeout = httpx.Timeout(setup.timeout)
async with self.client.stream(
setup.method,
str(setup.url),
@ -114,7 +148,7 @@ class Session(HTTPClientSession):
params=setup.url.raw_query_string,
headers=tuple(setup.headers.items()),
cookies=setup.cookies.jar,
timeout=setup.timeout,
timeout=timeout,
) as response:
response_headers = response.headers.multi_items()
async for chunk in response.aiter_bytes(chunk_size=chunk_size):
@ -183,7 +217,7 @@ class Mixin(HTTPClientMixin):
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
) -> Session:
return Session(

View File

@ -25,7 +25,7 @@ from types import CoroutineType
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union
from typing_extensions import ParamSpec, override
from nonebot.drivers import Request, WebSocketClientMixin, combine_driver
from nonebot.drivers import Request, Timeout, WebSocketClientMixin, combine_driver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.exception import WebSocketClosed
@ -73,10 +73,16 @@ class Mixin(WebSocketClientMixin):
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
if setup.proxy is not None:
logger.warning("proxy is not supported by websockets driver")
if isinstance(setup.timeout, Timeout):
timeout = setup.timeout.total or setup.timeout.connect or setup.timeout.read
else:
timeout = setup.timeout
connection = Connect(
str(setup.url),
extra_headers={**setup.headers, **setup.cookies.as_header(setup)},
open_timeout=setup.timeout,
open_timeout=timeout,
)
async with connection as ws:
yield WebSocket(request=setup, websocket=ws)

View File

@ -27,5 +27,7 @@ from .model import RawURL as RawURL
from .model import Request as Request
from .model import Response as Response
from .model import SimpleQuery as SimpleQuery
from .model import Timeout as Timeout
from .model import TimeoutTypes as TimeoutTypes
from .model import WebSocket as WebSocket
from .model import WebSocketServerSetup as WebSocketServerSetup

View File

@ -30,6 +30,7 @@ from .model import (
QueryTypes,
Request,
Response,
TimeoutTypes,
WebSocket,
WebSocketServerSetup,
)
@ -245,7 +246,7 @@ class HTTPClientSession(abc.ABC):
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
):
raise NotImplementedError
@ -315,7 +316,7 @@ class HTTPClientMixin(ForwardMixin):
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
) -> HTTPClientSession:
"""获取一个 HTTP 会话"""

View File

@ -42,6 +42,7 @@ FileTypes: TypeAlias = Union[
FileType,
]
FilesTypes: TypeAlias = Union[dict[str, FileTypes], list[tuple[str, FileTypes]], None]
TimeoutTypes: TypeAlias = Union[float, "Timeout", None]
class HTTPVersion(Enum):
@ -50,6 +51,15 @@ class HTTPVersion(Enum):
H2 = "2"
@dataclass
class Timeout:
"""Request 超时配置。"""
total: Optional[float] = None
connect: Optional[float] = None
read: Optional[float] = None
class Request:
def __init__(
self,
@ -64,7 +74,7 @@ class Request:
json: Any = None,
files: FilesTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
):
# method
@ -76,7 +86,7 @@ class Request:
# http version
self.version: HTTPVersion = HTTPVersion(version)
# timeout
self.timeout: Optional[float] = timeout
self.timeout: TimeoutTypes = timeout
# proxy
self.proxy: Optional[str] = proxy

View File

@ -16,6 +16,7 @@ from nonebot.drivers import (
HTTPServerSetup,
Request,
Response,
Timeout,
WebSocket,
WebSocketClientMixin,
WebSocketServerSetup,
@ -235,6 +236,7 @@ async def test_http_client(driver: Driver, server_url: URL):
headers={"X-Test": "test"},
cookies={"session": "test"},
content="test",
timeout=Timeout(total=4, connect=2, read=2),
)
response = await driver.request(request)
assert server_url.host is not None
@ -250,6 +252,7 @@ async def test_http_client(driver: Driver, server_url: URL):
headers={"X-Test": "test"},
cookies={"session": "test"},
content="test",
timeout=Timeout(total=4, connect=2, read=2),
)
assert request.url == request_raw_url.url, (
"request.url should be equal to request_raw_url.url"
@ -312,6 +315,7 @@ async def test_http_client(driver: Driver, server_url: URL):
headers={"X-Test": "stream"},
cookies={"session": "stream"},
content="stream_test" * 1024,
timeout=Timeout(total=4, connect=2, read=2),
)
chunks = []
async for resp in driver.stream_request(request, chunk_size=4):
@ -414,6 +418,7 @@ async def test_http_client_session(driver: Driver, server_url: URL):
headers={"X-Test": "test"},
cookies={"cookie": "test"},
content="test",
timeout=Timeout(total=4, connect=2, read=2),
)
response = await session.request(request)
assert response.status_code == 200
@ -499,6 +504,7 @@ async def test_http_client_session(driver: Driver, server_url: URL):
headers={"X-Test": "stream"},
cookies={"cookie": "stream"},
content="stream_test" * 1024,
timeout=Timeout(total=4, connect=2, read=2),
)
chunks = []
async for resp in session.stream_request(request, chunk_size=4):