Feature: 支持 HTTP 客户端会话 (#2627)

This commit is contained in:
Ju4tCode
2024-04-05 21:11:05 +08:00
committed by GitHub
parent 53e2a86dd9
commit 485aa62755
7 changed files with 420 additions and 65 deletions

View File

@ -23,6 +23,7 @@ from nonebot.internal.driver import ReverseDriver as ReverseDriver
from nonebot.internal.driver import combine_driver as combine_driver
from nonebot.internal.driver import HTTPClientMixin as HTTPClientMixin
from nonebot.internal.driver import HTTPServerSetup as HTTPServerSetup
from nonebot.internal.driver import HTTPClientSession as HTTPClientSession
from nonebot.internal.driver import WebSocketClientMixin as WebSocketClientMixin
from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup

View File

@ -17,15 +17,19 @@ FrontMatter:
from typing_extensions import override
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, AsyncGenerator
from typing import TYPE_CHECKING, Union, Optional, AsyncGenerator
from multidict import CIMultiDict
from nonebot.drivers import Request, Response
from nonebot.exception import WebSocketClosed
from nonebot.drivers import URL, Request, Response
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.internal.driver import Cookies, QueryTypes, CookieTypes, HeaderTypes
from nonebot.drivers import (
HTTPVersion,
HTTPClientMixin,
HTTPClientSession,
WebSocketClientMixin,
combine_driver,
)
@ -39,6 +43,105 @@ except ModuleNotFoundError as e: # pragma: no cover
) from e
class Session(HTTPClientSession):
@override
def __init__(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
):
self._client: Optional[aiohttp.ClientSession] = None
self._params = URL.build(query=params).query if params is not None else None
self._headers = CIMultiDict(headers) if headers is not None else None
self._cookies = tuple(
(cookie.name, cookie.value)
for cookie in Cookies(cookies)
if cookie.value is not None
)
version = HTTPVersion(version)
if version == HTTPVersion.H10:
self._version = aiohttp.HttpVersion10
elif version == HTTPVersion.H11:
self._version = aiohttp.HttpVersion11
else:
raise RuntimeError(f"Unsupported HTTP version: {version}")
self._timeout = timeout
self._proxy = proxy
@property
def client(self) -> aiohttp.ClientSession:
if self._client is None:
raise RuntimeError("Session is not initialized")
return self._client
@override
async def request(self, setup: Request) -> Response:
if self._params:
params = self._params.copy()
params.update(setup.url.query)
url = setup.url.with_query(params)
else:
url = setup.url
data = setup.data
if setup.files:
data = aiohttp.FormData(data or {}, quote_fields=False)
for name, file in setup.files:
data.add_field(name, file[1], content_type=file[2], filename=file[0])
cookies = (
(cookie.name, cookie.value)
for cookie in setup.cookies
if cookie.value is not None
)
timeout = aiohttp.ClientTimeout(setup.timeout)
async with await self.client.request(
setup.method,
url,
data=setup.content or data,
json=setup.json,
cookies=cookies,
headers=setup.headers,
proxy=setup.proxy or self._proxy,
timeout=timeout,
) as response:
return Response(
response.status,
headers=response.headers.copy(),
content=await response.read(),
request=setup,
)
@override
async def setup(self) -> None:
self._client = aiohttp.ClientSession(
cookies=self._cookies,
headers=self._headers,
version=self._version,
timeout=self._timeout,
trust_env=True,
)
await self._client.__aenter__()
@override
async def close(self) -> None:
try:
if self._client is not None:
await self._client.close()
finally:
self._client = None
class Mixin(HTTPClientMixin, WebSocketClientMixin):
"""AIOHTTP Mixin"""
@ -49,42 +152,8 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
@override
async def request(self, setup: Request) -> Response:
if setup.version == HTTPVersion.H10:
version = aiohttp.HttpVersion10
elif setup.version == HTTPVersion.H11:
version = aiohttp.HttpVersion11
else:
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
timeout = aiohttp.ClientTimeout(setup.timeout)
data = setup.data
if setup.files:
data = aiohttp.FormData(data or {}, quote_fields=False)
for name, file in setup.files:
data.add_field(name, file[1], content_type=file[2], filename=file[0])
cookies = {
cookie.name: cookie.value for cookie in setup.cookies if cookie.value
}
async with aiohttp.ClientSession(
cookies=cookies, version=version, trust_env=True
) as session:
async with session.request(
setup.method,
setup.url,
data=setup.content or data,
json=setup.json,
headers=setup.headers,
timeout=timeout,
proxy=setup.proxy,
) as response:
return Response(
response.status,
headers=response.headers.copy(),
content=await response.read(),
request=setup,
)
async with self.get_session() as session:
return await session.request(setup)
@override
@asynccontextmanager
@ -106,6 +175,25 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
) as ws:
yield WebSocket(request=setup, session=session, websocket=ws)
@override
def get_session(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
) -> Session:
return Session(
params=params,
headers=headers,
cookies=cookies,
version=version,
timeout=timeout,
proxy=proxy,
)
class WebSocket(BaseWebSocket):
"""AIOHTTP Websocket Wrapper"""

View File

@ -15,15 +15,20 @@ FrontMatter:
description: nonebot.drivers.httpx 模块
"""
from typing import TYPE_CHECKING
from typing_extensions import override
from typing import TYPE_CHECKING, Union, Optional
from multidict import CIMultiDict
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.internal.driver import Cookies, QueryTypes, CookieTypes, HeaderTypes
from nonebot.drivers import (
URL,
Request,
Response,
HTTPVersion,
HTTPClientMixin,
HTTPClientSession,
combine_driver,
)
@ -36,6 +41,77 @@ except ModuleNotFoundError as e: # pragma: no cover
) from e
class Session(HTTPClientSession):
@override
def __init__(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
):
self._client: Optional[httpx.AsyncClient] = None
self._params = (
tuple(URL.build(query=params).query.items()) if params is not None else None
)
self._headers = (
tuple(CIMultiDict(headers).items()) if headers is not None else None
)
self._cookies = Cookies(cookies)
self._version = HTTPVersion(version)
self._timeout = timeout
self._proxy = proxy
@property
def client(self) -> httpx.AsyncClient:
if self._client is None:
raise RuntimeError("Session is not initialized")
return self._client
@override
async def request(self, setup: Request) -> Response:
response = await self.client.request(
setup.method,
str(setup.url),
content=setup.content,
data=setup.data,
files=setup.files,
json=setup.json,
headers=tuple(setup.headers.items()),
cookies=setup.cookies.jar,
timeout=setup.timeout,
)
return Response(
response.status_code,
headers=response.headers.multi_items(),
content=response.content,
request=setup,
)
@override
async def setup(self) -> None:
self._client = httpx.AsyncClient(
params=self._params,
headers=self._headers,
cookies=self._cookies.jar,
http2=self._version == HTTPVersion.H2,
proxies=self._proxy,
follow_redirects=True,
)
await self._client.__aenter__()
@override
async def close(self) -> None:
try:
if self._client is not None:
await self._client.aclose()
finally:
self._client = None
class Mixin(HTTPClientMixin):
"""HTTPX Mixin"""
@ -46,28 +122,29 @@ class Mixin(HTTPClientMixin):
@override
async def request(self, setup: Request) -> Response:
async with httpx.AsyncClient(
cookies=setup.cookies.jar,
http2=setup.version == HTTPVersion.H2,
proxies=setup.proxy,
follow_redirects=True,
) as client:
response = await client.request(
setup.method,
str(setup.url),
content=setup.content,
data=setup.data,
json=setup.json,
files=setup.files,
headers=tuple(setup.headers.items()),
timeout=setup.timeout,
)
return Response(
response.status_code,
headers=response.headers.multi_items(),
content=response.content,
request=setup,
)
async with self.get_session(
version=setup.version, proxy=setup.proxy
) as session:
return await session.request(setup)
@override
def get_session(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
) -> Session:
return Session(
params=params,
headers=headers,
cookies=cookies,
version=version,
timeout=timeout,
proxy=proxy,
)
if TYPE_CHECKING: