mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-06-09 05:45:51 +00:00
♿ change websocket client to context manager
This commit is contained in:
parent
00c2ee8490
commit
7b204d72e6
@ -8,7 +8,17 @@
|
|||||||
import abc
|
import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Set,
|
||||||
|
Dict,
|
||||||
|
Type,
|
||||||
|
Callable,
|
||||||
|
Awaitable,
|
||||||
|
AsyncGenerator,
|
||||||
|
)
|
||||||
|
|
||||||
from ._model import URL as URL
|
from ._model import URL as URL
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
@ -215,8 +225,10 @@ class ForwardMixin(abc.ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def websocket(self, setup: Request) -> WebSocket:
|
@asynccontextmanager
|
||||||
|
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
yield # used for static type checking's generator detection
|
||||||
|
|
||||||
|
|
||||||
class ForwardDriver(Driver, ForwardMixin):
|
class ForwardDriver(Driver, ForwardMixin):
|
||||||
|
@ -5,6 +5,9 @@ AIOHTTP 驱动适配
|
|||||||
本驱动仅支持客户端连接
|
本驱动仅支持客户端连接
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
from nonebot.drivers import Request, Response
|
from nonebot.drivers import Request, Response
|
||||||
from nonebot.drivers._block_driver import BlockDriver
|
from nonebot.drivers._block_driver import BlockDriver
|
||||||
@ -59,7 +62,8 @@ class Mixin(ForwardMixin):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
@overrides(ForwardMixin)
|
@overrides(ForwardMixin)
|
||||||
async def websocket(self, setup: Request) -> "WebSocket":
|
@asynccontextmanager
|
||||||
|
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
|
||||||
if setup.version == HTTPVersion.H10:
|
if setup.version == HTTPVersion.H10:
|
||||||
version = aiohttp.HttpVersion10
|
version = aiohttp.HttpVersion10
|
||||||
elif setup.version == HTTPVersion.H11:
|
elif setup.version == HTTPVersion.H11:
|
||||||
@ -68,15 +72,15 @@ class Mixin(ForwardMixin):
|
|||||||
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
|
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
|
||||||
|
|
||||||
session = aiohttp.ClientSession(version=version, trust_env=True)
|
session = aiohttp.ClientSession(version=version, trust_env=True)
|
||||||
ws = await session.ws_connect(
|
async with session.ws_connect(
|
||||||
setup.url,
|
setup.url,
|
||||||
method=setup.method,
|
method=setup.method,
|
||||||
timeout=setup.timeout or 10,
|
timeout=setup.timeout or 10,
|
||||||
headers=setup.headers,
|
headers=setup.headers,
|
||||||
proxy=setup.proxy,
|
proxy=setup.proxy,
|
||||||
)
|
) as ws:
|
||||||
websocket = WebSocket(request=setup, session=session, websocket=ws)
|
websocket = WebSocket(request=setup, session=session, websocket=ws)
|
||||||
return websocket
|
yield websocket
|
||||||
|
|
||||||
|
|
||||||
class WebSocket(BaseWebSocket):
|
class WebSocket(BaseWebSocket):
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
from typing import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
from nonebot.drivers._block_driver import BlockDriver
|
from nonebot.drivers._block_driver import BlockDriver
|
||||||
from nonebot.drivers import (
|
from nonebot.drivers import (
|
||||||
@ -48,8 +51,10 @@ class Mixin(ForwardMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@overrides(ForwardMixin)
|
@overrides(ForwardMixin)
|
||||||
async def websocket(self, setup: Request) -> WebSocket:
|
@asynccontextmanager
|
||||||
return await super(Mixin, self).websocket(setup)
|
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
|
||||||
|
async with super(Mixin, self).websocket(setup) as ws:
|
||||||
|
yield ws
|
||||||
|
|
||||||
|
|
||||||
Driver = combine_driver(BlockDriver, Mixin)
|
Driver = combine_driver(BlockDriver, Mixin)
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
from nonebot.log import LoguruHandler
|
from nonebot.log import LoguruHandler
|
||||||
@ -29,13 +31,15 @@ class Mixin(ForwardMixin):
|
|||||||
return await super(Mixin, self).request(setup)
|
return await super(Mixin, self).request(setup)
|
||||||
|
|
||||||
@overrides(ForwardMixin)
|
@overrides(ForwardMixin)
|
||||||
async def websocket(self, setup: Request) -> "WebSocket":
|
@asynccontextmanager
|
||||||
ws = await Connect(
|
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
|
||||||
|
connection = Connect(
|
||||||
str(setup.url),
|
str(setup.url),
|
||||||
extra_headers=setup.headers.items(),
|
extra_headers=setup.headers.items(),
|
||||||
open_timeout=setup.timeout,
|
open_timeout=setup.timeout,
|
||||||
)
|
)
|
||||||
return WebSocket(request=setup, websocket=ws)
|
async with connection as ws:
|
||||||
|
yield WebSocket(request=setup, websocket=ws)
|
||||||
|
|
||||||
|
|
||||||
class WebSocket(BaseWebSocket):
|
class WebSocket(BaseWebSocket):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user