change websocket client to context manager

This commit is contained in:
yanyongyu 2021-12-26 13:42:13 +08:00
parent 00c2ee8490
commit 7b204d72e6
4 changed files with 37 additions and 12 deletions

View File

@ -8,7 +8,17 @@
import abc import abc
import asyncio import asyncio
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable from contextlib import asynccontextmanager
from typing import (
TYPE_CHECKING,
Any,
Set,
Dict,
Type,
Callable,
Awaitable,
AsyncGenerator,
)
from ._model import URL as URL from ._model import URL as URL
from nonebot.log import logger from nonebot.log import logger
@ -215,8 +225,10 @@ class ForwardMixin(abc.ABC):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def websocket(self, setup: Request) -> WebSocket: @asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
raise NotImplementedError raise NotImplementedError
yield # used for static type checking's generator detection
class ForwardDriver(Driver, ForwardMixin): class ForwardDriver(Driver, ForwardMixin):

View File

@ -5,6 +5,9 @@ AIOHTTP 驱动适配
本驱动仅支持客户端连接 本驱动仅支持客户端连接
""" """
from typing import AsyncGenerator
from contextlib import asynccontextmanager
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.drivers import Request, Response from nonebot.drivers import Request, Response
from nonebot.drivers._block_driver import BlockDriver from nonebot.drivers._block_driver import BlockDriver
@ -59,7 +62,8 @@ class Mixin(ForwardMixin):
return res return res
@overrides(ForwardMixin) @overrides(ForwardMixin)
async def websocket(self, setup: Request) -> "WebSocket": @asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
if setup.version == HTTPVersion.H10: if setup.version == HTTPVersion.H10:
version = aiohttp.HttpVersion10 version = aiohttp.HttpVersion10
elif setup.version == HTTPVersion.H11: elif setup.version == HTTPVersion.H11:
@ -68,15 +72,15 @@ class Mixin(ForwardMixin):
raise RuntimeError(f"Unsupported HTTP version: {setup.version}") raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
session = aiohttp.ClientSession(version=version, trust_env=True) session = aiohttp.ClientSession(version=version, trust_env=True)
ws = await session.ws_connect( async with session.ws_connect(
setup.url, setup.url,
method=setup.method, method=setup.method,
timeout=setup.timeout or 10, timeout=setup.timeout or 10,
headers=setup.headers, headers=setup.headers,
proxy=setup.proxy, proxy=setup.proxy,
) ) as ws:
websocket = WebSocket(request=setup, session=session, websocket=ws) websocket = WebSocket(request=setup, session=session, websocket=ws)
return websocket yield websocket
class WebSocket(BaseWebSocket): class WebSocket(BaseWebSocket):

View File

@ -1,3 +1,6 @@
from typing import AsyncGenerator
from contextlib import asynccontextmanager
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.drivers._block_driver import BlockDriver from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import ( from nonebot.drivers import (
@ -48,8 +51,10 @@ class Mixin(ForwardMixin):
) )
@overrides(ForwardMixin) @overrides(ForwardMixin)
async def websocket(self, setup: Request) -> WebSocket: @asynccontextmanager
return await super(Mixin, self).websocket(setup) async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
async with super(Mixin, self).websocket(setup) as ws:
yield ws
Driver = combine_driver(BlockDriver, Mixin) Driver = combine_driver(BlockDriver, Mixin)

View File

@ -1,4 +1,6 @@
import logging import logging
from typing import AsyncGenerator
from contextlib import asynccontextmanager
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.log import LoguruHandler from nonebot.log import LoguruHandler
@ -29,13 +31,15 @@ class Mixin(ForwardMixin):
return await super(Mixin, self).request(setup) return await super(Mixin, self).request(setup)
@overrides(ForwardMixin) @overrides(ForwardMixin)
async def websocket(self, setup: Request) -> "WebSocket": @asynccontextmanager
ws = await Connect( async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
connection = Connect(
str(setup.url), str(setup.url),
extra_headers=setup.headers.items(), extra_headers=setup.headers.items(),
open_timeout=setup.timeout, open_timeout=setup.timeout,
) )
return WebSocket(request=setup, websocket=ws) async with connection as ws:
yield WebSocket(request=setup, websocket=ws)
class WebSocket(BaseWebSocket): class WebSocket(BaseWebSocket):