🐛 Fix: aiohttp 驱动未处理 WSMsgType.CLOSED 类型 (#3862)

This commit is contained in:
呵呵です
2026-02-14 19:29:11 +08:00
committed by GitHub
parent 7fbab3de79
commit 346eddda06
2 changed files with 43 additions and 1 deletions

View File

@@ -323,7 +323,11 @@ class WebSocket(BaseWebSocket):
async def _receive(self) -> aiohttp.WSMessage:
msg = await self.websocket.receive()
if msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
if msg.type in (
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
aiohttp.WSMsgType.CLOSED,
):
raise WebSocketClosed(self.websocket.close_code or 1006)
return msg

View File

@@ -2,6 +2,7 @@ from http.cookies import SimpleCookie
import json
from typing import Any, Optional
from aiohttp import ClientSession, ClientWebSocketResponse, WSMessage, WSMsgType
import anyio
from nonebug import App
import pytest
@@ -21,6 +22,7 @@ from nonebot.drivers import (
WebSocketClientMixin,
WebSocketServerSetup,
)
from nonebot.drivers.aiohttp import WebSocket as AiohttpWebSocket
from nonebot.exception import WebSocketClosed
from nonebot.params import Depends
from utils import FakeAdapter
@@ -627,6 +629,42 @@ async def test_websocket_client(driver: Driver, server_url: URL):
await anyio.sleep(1)
@pytest.mark.anyio
@pytest.mark.parametrize(
("msg_type"),
[
pytest.param("CLOSE", id="aiohttp-close"),
pytest.param("CLOSING", id="aiohttp-closing"),
pytest.param("CLOSED", id="aiohttp-closed"),
],
)
async def test_aiohttp_websocket_close_frame(msg_type: str) -> None:
class DummyWS(ClientWebSocketResponse):
def __init__(self) -> None:
pass
@property
def close_code(self) -> None:
return None
@property
def closed(self) -> bool:
return True
async def receive(self, timeout: Optional[float] = None) -> WSMessage: # noqa: ASYNC109
return WSMessage(type=WSMsgType[msg_type], data=None, extra=None)
async with ClientSession() as session:
ws = AiohttpWebSocket(
request=Request("GET", "ws://example.com"),
session=session,
websocket=DummyWS(),
)
with pytest.raises(WebSocketClosed, match=r"code=1006"):
await ws.receive()
@pytest.mark.parametrize(
("driver", "driver_type"),
[