Feature: 支持 HTTP 客户端会话 (#2627)

This commit is contained in:
Ju4tCode
2024-04-05 21:11:05 +08:00
committed by GitHub
parent 53e2a86dd9
commit 485aa62755
7 changed files with 420 additions and 65 deletions

View File

@ -17,15 +17,19 @@ FrontMatter:
from typing_extensions import override
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, AsyncGenerator
from typing import TYPE_CHECKING, Union, Optional, AsyncGenerator
from multidict import CIMultiDict
from nonebot.drivers import Request, Response
from nonebot.exception import WebSocketClosed
from nonebot.drivers import URL, Request, Response
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.internal.driver import Cookies, QueryTypes, CookieTypes, HeaderTypes
from nonebot.drivers import (
HTTPVersion,
HTTPClientMixin,
HTTPClientSession,
WebSocketClientMixin,
combine_driver,
)
@ -39,6 +43,105 @@ except ModuleNotFoundError as e: # pragma: no cover
) from e
class Session(HTTPClientSession):
@override
def __init__(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
):
self._client: Optional[aiohttp.ClientSession] = None
self._params = URL.build(query=params).query if params is not None else None
self._headers = CIMultiDict(headers) if headers is not None else None
self._cookies = tuple(
(cookie.name, cookie.value)
for cookie in Cookies(cookies)
if cookie.value is not None
)
version = HTTPVersion(version)
if version == HTTPVersion.H10:
self._version = aiohttp.HttpVersion10
elif version == HTTPVersion.H11:
self._version = aiohttp.HttpVersion11
else:
raise RuntimeError(f"Unsupported HTTP version: {version}")
self._timeout = timeout
self._proxy = proxy
@property
def client(self) -> aiohttp.ClientSession:
if self._client is None:
raise RuntimeError("Session is not initialized")
return self._client
@override
async def request(self, setup: Request) -> Response:
if self._params:
params = self._params.copy()
params.update(setup.url.query)
url = setup.url.with_query(params)
else:
url = setup.url
data = setup.data
if setup.files:
data = aiohttp.FormData(data or {}, quote_fields=False)
for name, file in setup.files:
data.add_field(name, file[1], content_type=file[2], filename=file[0])
cookies = (
(cookie.name, cookie.value)
for cookie in setup.cookies
if cookie.value is not None
)
timeout = aiohttp.ClientTimeout(setup.timeout)
async with await self.client.request(
setup.method,
url,
data=setup.content or data,
json=setup.json,
cookies=cookies,
headers=setup.headers,
proxy=setup.proxy or self._proxy,
timeout=timeout,
) as response:
return Response(
response.status,
headers=response.headers.copy(),
content=await response.read(),
request=setup,
)
@override
async def setup(self) -> None:
self._client = aiohttp.ClientSession(
cookies=self._cookies,
headers=self._headers,
version=self._version,
timeout=self._timeout,
trust_env=True,
)
await self._client.__aenter__()
@override
async def close(self) -> None:
try:
if self._client is not None:
await self._client.close()
finally:
self._client = None
class Mixin(HTTPClientMixin, WebSocketClientMixin):
"""AIOHTTP Mixin"""
@ -49,42 +152,8 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
@override
async def request(self, setup: Request) -> Response:
if setup.version == HTTPVersion.H10:
version = aiohttp.HttpVersion10
elif setup.version == HTTPVersion.H11:
version = aiohttp.HttpVersion11
else:
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
timeout = aiohttp.ClientTimeout(setup.timeout)
data = setup.data
if setup.files:
data = aiohttp.FormData(data or {}, quote_fields=False)
for name, file in setup.files:
data.add_field(name, file[1], content_type=file[2], filename=file[0])
cookies = {
cookie.name: cookie.value for cookie in setup.cookies if cookie.value
}
async with aiohttp.ClientSession(
cookies=cookies, version=version, trust_env=True
) as session:
async with session.request(
setup.method,
setup.url,
data=setup.content or data,
json=setup.json,
headers=setup.headers,
timeout=timeout,
proxy=setup.proxy,
) as response:
return Response(
response.status,
headers=response.headers.copy(),
content=await response.read(),
request=setup,
)
async with self.get_session() as session:
return await session.request(setup)
@override
@asynccontextmanager
@ -106,6 +175,25 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
) as ws:
yield WebSocket(request=setup, session=session, websocket=ws)
@override
def get_session(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
) -> Session:
return Session(
params=params,
headers=headers,
cookies=cookies,
version=version,
timeout=timeout,
proxy=proxy,
)
class WebSocket(BaseWebSocket):
"""AIOHTTP Websocket Wrapper"""