add websocket close exception

This commit is contained in:
yanyongyu
2021-12-26 14:20:09 +08:00
parent e64f399370
commit 8093c5d154
3 changed files with 58 additions and 2 deletions

View File

@ -1,15 +1,18 @@
import logging
from functools import wraps
from typing import AsyncGenerator
from contextlib import asynccontextmanager
from nonebot.typing import overrides
from nonebot.log import LoguruHandler
from nonebot.drivers import Request, Response
from nonebot.exception import WebSocketClosed
from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ForwardMixin, combine_driver
try:
from websockets.exceptions import ConnectionClosed
from websockets.legacy.client import Connect, WebSocketClientProtocol
except ImportError:
raise ImportError(
@ -20,6 +23,20 @@ logger = logging.Logger("websockets.client", "INFO")
logger.addHandler(LoguruHandler())
def catch_closed(func):
@wraps(func)
async def decorator(*args, **kwargs):
try:
return await func(*args, **kwargs)
except ConnectionClosed as e:
if e.rcvd_then_sent:
raise WebSocketClosed(e.rcvd.code, e.rcvd.reason)
else:
raise WebSocketClosed(e.sent.code, e.sent.reason)
return decorator
class Mixin(ForwardMixin):
@property
@overrides(ForwardMixin)
@ -62,6 +79,7 @@ class WebSocket(BaseWebSocket):
await self.websocket.close(code, reason)
@overrides(BaseWebSocket)
@catch_closed
async def receive(self) -> str:
msg = await self.websocket.recv()
if isinstance(msg, bytes):
@ -69,6 +87,7 @@ class WebSocket(BaseWebSocket):
return msg
@overrides(BaseWebSocket)
@catch_closed
async def receive_bytes(self) -> bytes:
msg = await self.websocket.recv()
if isinstance(msg, str):