diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index b77e374e..d386f572 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -323,7 +323,11 @@ class WebSocket(BaseWebSocket): async def _receive(self) -> aiohttp.WSMessage: msg = await self.websocket.receive() - if msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): + if msg.type in ( + aiohttp.WSMsgType.CLOSE, + aiohttp.WSMsgType.CLOSING, + aiohttp.WSMsgType.CLOSED, + ): raise WebSocketClosed(self.websocket.close_code or 1006) return msg diff --git a/tests/test_driver.py b/tests/test_driver.py index cf364daa..48f78690 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -2,6 +2,7 @@ from http.cookies import SimpleCookie import json from typing import Any, Optional +from aiohttp import ClientSession, ClientWebSocketResponse, WSMessage, WSMsgType import anyio from nonebug import App import pytest @@ -21,6 +22,7 @@ from nonebot.drivers import ( WebSocketClientMixin, WebSocketServerSetup, ) +from nonebot.drivers.aiohttp import WebSocket as AiohttpWebSocket from nonebot.exception import WebSocketClosed from nonebot.params import Depends from utils import FakeAdapter @@ -627,6 +629,42 @@ async def test_websocket_client(driver: Driver, server_url: URL): await anyio.sleep(1) +@pytest.mark.anyio +@pytest.mark.parametrize( + ("msg_type"), + [ + pytest.param("CLOSE", id="aiohttp-close"), + pytest.param("CLOSING", id="aiohttp-closing"), + pytest.param("CLOSED", id="aiohttp-closed"), + ], +) +async def test_aiohttp_websocket_close_frame(msg_type: str) -> None: + class DummyWS(ClientWebSocketResponse): + def __init__(self) -> None: + pass + + @property + def close_code(self) -> None: + return None + + @property + def closed(self) -> bool: + return True + + async def receive(self, timeout: Optional[float] = None) -> WSMessage: # noqa: ASYNC109 + return WSMessage(type=WSMsgType[msg_type], data=None, extra=None) + + async with ClientSession() as session: + ws = AiohttpWebSocket( + request=Request("GET", "ws://example.com"), + session=session, + websocket=DummyWS(), + ) + + with pytest.raises(WebSocketClosed, match=r"code=1006"): + await ws.receive() + + @pytest.mark.parametrize( ("driver", "driver_type"), [