mirror of
https://github.com/nonebot/nonebot2.git
synced 2026-04-17 14:22:25 +00:00
Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
173 lines
4.7 KiB
Python
173 lines
4.7 KiB
Python
"""[websockets](https://websockets.readthedocs.io/) 驱动适配
|
|
|
|
```bash
|
|
nb driver install websockets
|
|
# 或者
|
|
pip install nonebot2[websockets]
|
|
```
|
|
|
|
:::tip 提示
|
|
本驱动仅支持客户端 WebSocket 连接
|
|
:::
|
|
|
|
FrontMatter:
|
|
mdx:
|
|
format: md
|
|
sidebar_position: 4
|
|
description: nonebot.drivers.websockets 模块
|
|
"""
|
|
|
|
from collections.abc import AsyncGenerator, Callable
|
|
from contextlib import asynccontextmanager
|
|
from functools import wraps
|
|
import logging
|
|
from types import CoroutineType
|
|
from typing import TYPE_CHECKING, Any, TypeVar
|
|
from typing_extensions import ParamSpec, override
|
|
|
|
from nonebot.drivers import (
|
|
DEFAULT_TIMEOUT,
|
|
Request,
|
|
Timeout,
|
|
WebSocketClientMixin,
|
|
combine_driver,
|
|
)
|
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
|
from nonebot.drivers.none import Driver as NoneDriver
|
|
from nonebot.exception import WebSocketClosed
|
|
from nonebot.log import LoguruHandler
|
|
from nonebot.utils import UNSET, exclude_unset
|
|
|
|
try:
|
|
from websockets import ClientConnection, ConnectionClosed, connect
|
|
except ModuleNotFoundError as e: # pragma: no cover
|
|
raise ImportError(
|
|
"Please install websockets first to use this driver. "
|
|
"Install with pip: `pip install nonebot2[websockets]`"
|
|
) from e
|
|
|
|
# Type variables that let ``catch_closed`` preserve the wrapped coroutine's
# signature: ``P`` captures its parameters, ``T`` its awaited result type.
T = TypeVar("T")
P = ParamSpec("P")

# Logger named after websockets' client logger, with a LoguruHandler attached.
# NOTE: built with the ``logging.Logger`` constructor rather than
# ``logging.getLogger``, so this is a standalone instance that is not
# registered in the logging module's global logger registry.
logger = logging.Logger(name="websockets.client", level=logging.INFO)
logger.addHandler(LoguruHandler())
|
|
|
|
|
|
def catch_closed(
    func: Callable[P, "CoroutineType[Any, Any, T]"],
) -> Callable[P, "CoroutineType[Any, Any, T]"]:
    """Translate websockets' ``ConnectionClosed`` into NoneBot's ``WebSocketClosed``.

    Wraps an async callable so that a ``ConnectionClosed`` raised while it
    runs is re-raised as ``WebSocketClosed(code, reason)``, chained to the
    original exception. All other exceptions propagate unchanged.

    Args:
        func: the coroutine function to wrap.

    Returns:
        A wrapper with the same signature (metadata copied via ``wraps``).
    """

    @wraps(func)
    async def decorator(*args: P.args, **kwargs: P.kwargs) -> T:
        try:
            return await func(*args, **kwargs)
        except ConnectionClosed as e:
            # ``e.code`` / ``e.reason`` are deprecated in websockets >= 13.1 and
            # emit a DeprecationWarning on every access; read the received
            # close frame instead. When no close frame was received, use
            # 1006 (abnormal closure) and an empty reason, which is what the
            # deprecated properties fall back to.
            rcvd = e.rcvd
            code = 1006 if rcvd is None else rcvd.code
            reason = "" if rcvd is None else rcvd.reason
            raise WebSocketClosed(code, reason) from e

    return decorator
|
|
|
|
|
|
class Mixin(WebSocketClientMixin):
    """Websockets Mixin

    Client-only WebSocket support backed by the ``websockets`` library.
    """

    @property
    @override
    def type(self) -> str:
        return "websockets"

    @staticmethod
    def _timeout_kwargs(timeout: Timeout) -> dict[str, float | None]:
        """Map a ``Timeout`` onto ``connect``'s timeout keyword arguments.

        ``open_timeout`` is ``connect or read or total`` and ``close_timeout``
        is ``close``. Entries whose value is UNSET are dropped by
        ``exclude_unset``.
        """
        open_timeout = timeout.connect or timeout.read or timeout.total
        return exclude_unset(
            {"open_timeout": open_timeout, "close_timeout": timeout.close}
        )

    @override
    @asynccontextmanager
    async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
        """Open a client WebSocket connection described by ``setup``.

        Timeout handling:

        - a ``Timeout`` is mapped field-wise: ``open_timeout`` is
          ``connect or read or total``, ``close_timeout`` is ``close``;
        - any other non-UNSET value (a number or ``None``) is used for both
          ``open_timeout`` and ``close_timeout``;
        - if that yields no timeout settings at all, ``DEFAULT_TIMEOUT`` is
          mapped the same way as a ``Timeout``.

        Yields:
            The connected ``WebSocket`` wrapper. The underlying connection is
            closed when the context exits.
        """
        timeout_kwargs: dict[str, float | None] = {}
        if isinstance(setup.timeout, Timeout):
            timeout_kwargs = self._timeout_kwargs(setup.timeout)
        elif setup.timeout is not UNSET:
            # A bare number (or None) applies to both opening and closing.
            timeout_kwargs = {
                "open_timeout": setup.timeout,
                "close_timeout": setup.timeout,
            }

        # Nothing usable was configured on the request: fall back to defaults.
        if not timeout_kwargs:
            timeout_kwargs = self._timeout_kwargs(DEFAULT_TIMEOUT)

        connection = connect(
            str(setup.url),
            additional_headers={**setup.headers, **setup.cookies.as_header(setup)},
            # ``True`` is websockets' default: use the system proxy
            # configuration, if any.
            proxy=setup.proxy if setup.proxy is not None else True,
            # Route websockets' own log records through the module-level,
            # loguru-backed ``logger``; without this it is never used.
            logger=logger,
            **timeout_kwargs,  # type: ignore
        )
        async with connection as ws:
            yield WebSocket(request=setup, websocket=ws)
|
|
|
|
|
|
class WebSocket(BaseWebSocket):
    """Websockets WebSocket Wrapper

    Adapts a ``websockets`` client connection to NoneBot's WebSocket
    interface. The receive methods are wrapped with ``catch_closed`` so a
    ``ConnectionClosed`` from the library surfaces as ``WebSocketClosed``.
    """

    @override
    def __init__(self, *, request: Request, websocket: ClientConnection):
        super().__init__(request=request)
        # Underlying websockets connection that every operation delegates to.
        self.websocket = websocket

    @property
    @override
    def closed(self) -> bool:
        """True once the underlying connection reports a close code."""
        close_code = self.websocket.close_code
        return close_code is not None

    @override
    async def accept(self):
        """Not supported: this driver only opens client-side connections."""
        raise NotImplementedError

    @override
    async def close(self, code: int = 1000, reason: str = ""):
        """Close the connection with the given close ``code`` and ``reason``."""
        await self.websocket.close(code, reason)

    @override
    @catch_closed
    async def receive(self) -> str | bytes:
        """Receive the next message, either text (``str``) or binary (``bytes``)."""
        message = await self.websocket.recv()
        return message

    @override
    @catch_closed
    async def receive_text(self) -> str:
        """Receive the next message, which is expected to be a text frame.

        Raises:
            TypeError: if a binary frame arrives instead.
        """
        frame = await self.websocket.recv()
        if not isinstance(frame, bytes):
            return frame
        raise TypeError("WebSocket received unexpected frame type: bytes")

    @override
    @catch_closed
    async def receive_bytes(self) -> bytes:
        """Receive the next message, which is expected to be a binary frame.

        Raises:
            TypeError: if a text frame arrives instead.
        """
        frame = await self.websocket.recv()
        if not isinstance(frame, str):
            return frame
        raise TypeError("WebSocket received unexpected frame type: str")

    @override
    async def send_text(self, data: str) -> None:
        """Send ``data`` as a text frame."""
        await self.websocket.send(data)

    @override
    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` as a binary frame."""
        await self.websocket.send(data)
|
|
|
|
|
|
if TYPE_CHECKING:
    # Static-typing view: declare Driver as a plain subclass of the Mixin and
    # NoneDriver so type checkers see both sets of members. Presumably needed
    # because combine_driver's return type is not precise enough -- confirm.

    class Driver(Mixin, NoneDriver): ...

else:
    # Runtime: the concrete driver class is produced by combine_driver from
    # NoneDriver plus the websockets client Mixin.
    Driver = combine_driver(NoneDriver, Mixin)
    """Websockets Driver"""
|