diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index d651517a..7a2e06a7 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -125,6 +125,51 @@ class Session(HTTPClientSession): request=setup, ) + @override + async def stream_request( + self, + setup: Request, + *, + chunk_size: int = 1024, + ) -> AsyncGenerator[Response, None]: + if self._params: + url = setup.url.with_query({**self._params, **setup.url.query}) + else: + url = setup.url + + data = setup.data + if setup.files: + data = aiohttp.FormData(data or {}, quote_fields=False) + for name, file in setup.files: + data.add_field(name, file[1], content_type=file[2], filename=file[0]) + + cookies = ( + (cookie.name, cookie.value) + for cookie in setup.cookies + if cookie.value is not None + ) + + timeout = aiohttp.ClientTimeout(setup.timeout) + + async with self.client.request( + setup.method, + url, + data=setup.content or data, + json=setup.json, + cookies=cookies, + headers=setup.headers, + proxy=setup.proxy or self._proxy, + timeout=timeout, + ) as response: + response_headers = response.headers.copy() + async for chunk in response.content.iter_chunked(chunk_size): + yield Response( + response.status, + headers=response_headers, + content=chunk, + request=setup, + ) + @override async def setup(self) -> None: if self._client is not None: @@ -160,6 +205,17 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin): async with self.get_session() as session: return await session.request(setup) + @override + async def stream_request( + self, + setup: Request, + *, + chunk_size: int = 1024, + ) -> AsyncGenerator[Response, None]: + async with self.get_session() as session: + async for response in session.stream_request(setup, chunk_size=chunk_size): + yield response + @override @asynccontextmanager async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]: diff --git a/nonebot/drivers/httpx.py b/nonebot/drivers/httpx.py index c6c015af..bca949a1 100644 --- a/nonebot/drivers/httpx.py +++ b/nonebot/drivers/httpx.py @@ -17,6 +17,7 @@ FrontMatter: description: nonebot.drivers.httpx 模块 """ +from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Optional, Union from typing_extensions import override @@ -95,6 +96,35 @@ class Session(HTTPClientSession): request=setup, ) + @override + async def stream_request( + self, + setup: Request, + *, + chunk_size: int = 1024, + ) -> AsyncGenerator[Response, None]: + async with self.client.stream( + setup.method, + str(setup.url), + content=setup.content, + data=setup.data, + files=setup.files, + json=setup.json, + # ensure the params priority + params=setup.url.raw_query_string, + headers=tuple(setup.headers.items()), + cookies=setup.cookies.jar, + timeout=setup.timeout, + ) as response: + response_headers = response.headers.multi_items() + async for chunk in response.aiter_bytes(chunk_size=chunk_size): + yield Response( + response.status_code, + headers=response_headers, + content=chunk, + request=setup, + ) + @override async def setup(self) -> None: if self._client is not None: @@ -133,6 +163,19 @@ class Mixin(HTTPClientMixin): ) as session: return await session.request(setup) + @override + async def stream_request( + self, + setup: Request, + *, + chunk_size: int = 1024, + ) -> AsyncGenerator[Response, None]: + async with self.get_session( + version=setup.version, proxy=setup.proxy + ) as session: + async for response in session.stream_request(setup, chunk_size=chunk_size): + yield response + @override def get_session( self, diff --git a/nonebot/internal/driver/abstract.py b/nonebot/internal/driver/abstract.py index abb2161e..d35e0011 100644 --- a/nonebot/internal/driver/abstract.py +++ b/nonebot/internal/driver/abstract.py @@ -255,6 +255,17 @@ class HTTPClientSession(abc.ABC): """发送一个 HTTP 请求""" raise NotImplementedError + @abc.abstractmethod + async def stream_request( + self, + setup: Request, + *, + chunk_size: int = 1024, + ) -> AsyncGenerator[Response, None]: + """发送一个 HTTP 流式请求""" + raise NotImplementedError + yield # used for static type checking's generator detection + @abc.abstractmethod async def setup(self) -> None: """初始化会话""" @@ -286,6 +297,17 @@ class HTTPClientMixin(ForwardMixin): """发送一个 HTTP 请求""" raise NotImplementedError + @abc.abstractmethod + async def stream_request( + self, + setup: Request, + *, + chunk_size: int = 1024, + ) -> AsyncGenerator[Response, None]: + """发送一个 HTTP 流式请求""" + raise NotImplementedError + yield # used for static type checking's generator detection + @abc.abstractmethod def get_session( self, diff --git a/tests/test_driver.py b/tests/test_driver.py index 0fced0b7..3094ea62 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -304,6 +304,78 @@ async def test_http_client(driver: Driver, server_url: URL): "test3": "test", }, "file parsing error" + # post stream request with query, headers, cookies and content + request = Request( + "POST", + server_url, + params={"param": "stream"}, + headers={"X-Test": "stream"}, + cookies={"session": "stream"}, + content="stream_test" * 1024, + ) + chunks = [] + async for resp in driver.stream_request(request, chunk_size=4): + assert resp.status_code == 200 + assert resp.content + chunks.append(resp.content) + assert all(len(chunk) == 4 for chunk in chunks[:-1]) + data = json.loads(b"".join(chunks)) + assert data["method"] == "POST" + assert data["args"] == {"param": "stream"} + assert data["headers"].get("X-Test") == "stream" + assert data["headers"].get("Cookie") == "session=stream" + assert data["data"] == "stream_test" * 1024 + + # post stream request with data body + request = Request("POST", server_url, data={"form": "test"}) + chunks = [] + async for resp in driver.stream_request(request, chunk_size=4): + assert resp.status_code == 200 + assert resp.content + chunks.append(resp.content) + assert all(len(chunk) == 4 for chunk in chunks[:-1]) + data = json.loads(b"".join(chunks)) + assert data["method"] == "POST" + assert data["form"] == {"form": "test"} + + # post stream request with json body + request = Request("POST", server_url, json={"json": "test"}) + chunks = [] + async for resp in driver.stream_request(request, chunk_size=4): + assert resp.status_code == 200 + assert resp.content + chunks.append(resp.content) + assert all(len(chunk) == 4 for chunk in chunks[:-1]) + data = json.loads(b"".join(chunks)) + assert data["method"] == "POST" + assert data["json"] == {"json": "test"} + + # post stream request with files and form data + request = Request( + "POST", + server_url, + data={"form": "test"}, + files=[ + ("test1", b"test"), + ("test2", ("test.txt", b"test")), + ("test3", ("test.txt", b"test", "text/plain")), + ], + ) + chunks = [] + async for resp in driver.stream_request(request, chunk_size=4): + assert response.status_code == 200 + assert response.content + chunks.append(resp.content) + assert all(len(chunk) == 4 for chunk in chunks[:-1]) + data = json.loads(b"".join(chunks)) + assert data["method"] == "POST" + assert data["form"] == {"form": "test"} + assert data["files"] == { + "test1": "test", + "test2": "test", + "test3": "test", + }, "file parsing error" + await anyio.sleep(1) @@ -419,6 +491,100 @@ async def test_http_client_session(driver: Driver, server_url: URL): "test3": "test", }, "file parsing error" + # post stream request with query, headers, cookies and content + request = Request( + "POST", + server_url, + params={"param": "stream"}, + headers={"X-Test": "stream"}, + cookies={"cookie": "stream"}, + content="stream_test" * 1024, + ) + chunks = [] + async for resp in session.stream_request(request, chunk_size=4): + assert resp.status_code == 200 + assert resp.content + chunks.append(resp.content) + assert all(len(chunk) == 4 for chunk in chunks[:-1]) + data = json.loads(b"".join(chunks)) + assert data["method"] == "POST" + assert data["args"] == {"session": "test", "param": "stream"} + assert data["headers"].get("X-Session") == "test" + assert data["headers"].get("X-Test") == "stream" + assert { + key: cookie.value + for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items() + } == {"session": "test", "cookie": "stream"} + assert data["data"] == "stream_test" * 1024 + + # post stream request with data body + request = Request("POST", server_url, data={"form": "test"}) + chunks = [] + async for resp in session.stream_request(request, chunk_size=4): + assert resp.status_code == 200 + assert resp.content + chunks.append(resp.content) + assert all(len(chunk) == 4 for chunk in chunks[:-1]) + data = json.loads(b"".join(chunks)) + assert data["method"] == "POST" + assert data["args"] == {"session": "test"} + assert data["headers"].get("X-Session") == "test" + assert { + key: cookie.value + for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items() + } == {"session": "test"} + assert data["form"] == {"form": "test"} + + # post stream request with json body + request = Request("POST", server_url, json={"json": "test"}) + chunks = [] + async for resp in session.stream_request(request, chunk_size=4): + assert resp.status_code == 200 + assert resp.content + chunks.append(resp.content) + assert all(len(chunk) == 4 for chunk in chunks[:-1]) + data = json.loads(b"".join(chunks)) + assert data["method"] == "POST" + assert data["args"] == {"session": "test"} + assert data["headers"].get("X-Session") == "test" + assert { + key: cookie.value + for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items() + } == {"session": "test"} + assert data["json"] == {"json": "test"} + + # post stream request with files and form data + request = Request( + "POST", + server_url, + data={"form": "test"}, + files=[ + ("test1", b"test"), + ("test2", ("test.txt", b"test")), + ("test3", ("test.txt", b"test", "text/plain")), + ], + ) + chunks = [] + async for resp in session.stream_request(request, chunk_size=4): + assert resp.status_code == 200 + assert resp.content + chunks.append(resp.content) + assert all(len(chunk) == 4 for chunk in chunks[:-1]) + data = json.loads(b"".join(chunks)) + assert data["method"] == "POST" + assert data["args"] == {"session": "test"} + assert data["headers"].get("X-Session") == "test" + assert { + key: cookie.value + for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items() + } == {"session": "test"} + assert data["form"] == {"form": "test"} + assert data["files"] == { + "test1": "test", + "test2": "test", + "test3": "test", + }, "file parsing error" + await anyio.sleep(1)