mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-22 10:36:41 +00:00 
			
		
		
		
	🐛 Fix: websockets 驱动器连接关闭 code 获取错误 (#2537)
This commit is contained in:
		| @@ -50,10 +50,7 @@ def catch_closed(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: | |||||||
|         try: |         try: | ||||||
|             return await func(*args, **kwargs) |             return await func(*args, **kwargs) | ||||||
|         except ConnectionClosed as e: |         except ConnectionClosed as e: | ||||||
|             if e.rcvd_then_sent: |             raise WebSocketClosed(e.code, e.reason) | ||||||
|                 raise WebSocketClosed(e.rcvd.code, e.rcvd.reason)  # type: ignore |  | ||||||
|             else: |  | ||||||
|                 raise WebSocketClosed(e.sent.code, e.sent.reason)  # type: ignore |  | ||||||
|  |  | ||||||
|     return decorator |     return decorator | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										6
									
								
								poetry.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										6
									
								
								poetry.lock
									
									
									
										generated
									
									
									
								
							| @@ -889,7 +889,7 @@ files = [ | |||||||
| name = "h11" | name = "h11" | ||||||
| version = "0.14.0" | version = "0.14.0" | ||||||
| description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" | description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" | ||||||
| optional = true | optional = false | ||||||
| python-versions = ">=3.7" | python-versions = ">=3.7" | ||||||
| files = [ | files = [ | ||||||
|     {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, |     {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, | ||||||
| @@ -2267,7 +2267,7 @@ dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] | |||||||
| name = "wsproto" | name = "wsproto" | ||||||
| version = "1.2.0" | version = "1.2.0" | ||||||
| description = "WebSockets state-machine based protocol implementation" | description = "WebSockets state-machine based protocol implementation" | ||||||
| optional = true | optional = false | ||||||
| python-versions = ">=3.7.0" | python-versions = ">=3.7.0" | ||||||
| files = [ | files = [ | ||||||
|     {file = "wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736"}, |     {file = "wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736"}, | ||||||
| @@ -2406,4 +2406,4 @@ websockets = ["websockets"] | |||||||
| [metadata] | [metadata] | ||||||
| lock-version = "2.0" | lock-version = "2.0" | ||||||
| python-versions = "^3.8" | python-versions = "^3.8" | ||||||
| content-hash = "e7bd1c1b070f1a46d94022047f2b76dbf90751f49086a099139f2ade4ad07a65" | content-hash = "ec064b0d1c22da40c55132f706fbf3802b8a5f8dcf647c2302ee0a2d248e3340" | ||||||
|   | |||||||
| @@ -52,6 +52,7 @@ ruff = ">=0.0.272,<1.0.0" | |||||||
|  |  | ||||||
| [tool.poetry.group.test.dependencies] | [tool.poetry.group.test.dependencies] | ||||||
| nonebug = "^0.3.0" | nonebug = "^0.3.0" | ||||||
|  | wsproto = "^1.2.0" | ||||||
| pytest-cov = "^4.0.0" | pytest-cov = "^4.0.0" | ||||||
| pytest-xdist = "^3.0.2" | pytest-xdist = "^3.0.2" | ||||||
| pytest-asyncio = "^0.23.2" | pytest-asyncio = "^0.23.2" | ||||||
|   | |||||||
| @@ -1,9 +1,15 @@ | |||||||
| import json | import json | ||||||
| import base64 | import base64 | ||||||
|  | import socket | ||||||
| from typing import Dict, List, Union, TypeVar | from typing import Dict, List, Union, TypeVar | ||||||
|  |  | ||||||
|  | from wsproto.events import Ping | ||||||
| from werkzeug import Request, Response | from werkzeug import Request, Response | ||||||
| from werkzeug.datastructures import MultiDict | from werkzeug.datastructures import MultiDict | ||||||
|  | from wsproto.frame_protocol import CloseReason | ||||||
|  | from wsproto.events import Request as WSRequest | ||||||
|  | from wsproto import WSConnection, ConnectionType | ||||||
|  | from wsproto.events import TextMessage, BytesMessage, CloseConnection, AcceptConnection | ||||||
|  |  | ||||||
| K = TypeVar("K") | K = TypeVar("K") | ||||||
| V = TypeVar("V") | V = TypeVar("V") | ||||||
| @@ -29,8 +35,7 @@ def flattern(d: "MultiDict[K, V]") -> Dict[K, Union[V, List[V]]]: | |||||||
|     return {k: v[0] if len(v) == 1 else v for k, v in d.to_dict(flat=False).items()} |     return {k: v[0] if len(v) == 1 else v for k, v in d.to_dict(flat=False).items()} | ||||||
|  |  | ||||||
|  |  | ||||||
| @Request.application | def http_echo(request: Request) -> Response: | ||||||
| def request_handler(request: Request) -> Response: |  | ||||||
|     try: |     try: | ||||||
|         _json = json.loads(request.data.decode("utf-8")) |         _json = json.loads(request.data.decode("utf-8")) | ||||||
|     except (ValueError, TypeError): |     except (ValueError, TypeError): | ||||||
| @@ -67,3 +72,65 @@ def request_handler(request: Request) -> Response: | |||||||
|         status=200, |         status=200, | ||||||
|         content_type="application/json", |         content_type="application/json", | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def websocket_echo(request: Request) -> Response: | ||||||
|  |     stream = request.environ["werkzeug.socket"] | ||||||
|  |  | ||||||
|  |     ws = WSConnection(ConnectionType.SERVER) | ||||||
|  |  | ||||||
|  |     in_data = b"GET %s HTTP/1.1\r\n" % request.path.encode("utf-8") | ||||||
|  |     for header, value in request.headers.items(): | ||||||
|  |         in_data += f"{header}: {value}\r\n".encode() | ||||||
|  |     in_data += b"\r\n" | ||||||
|  |  | ||||||
|  |     ws.receive_data(in_data) | ||||||
|  |  | ||||||
|  |     running: bool = True | ||||||
|  |     while True: | ||||||
|  |         out_data = b"" | ||||||
|  |  | ||||||
|  |         for event in ws.events(): | ||||||
|  |             if isinstance(event, WSRequest): | ||||||
|  |                 out_data += ws.send(AcceptConnection()) | ||||||
|  |             elif isinstance(event, CloseConnection): | ||||||
|  |                 out_data += ws.send(event.response()) | ||||||
|  |                 running = False | ||||||
|  |             elif isinstance(event, Ping): | ||||||
|  |                 out_data += ws.send(event.response()) | ||||||
|  |             elif isinstance(event, TextMessage): | ||||||
|  |                 if event.data == "quit": | ||||||
|  |                     out_data += ws.send( | ||||||
|  |                         CloseConnection(CloseReason.NORMAL_CLOSURE, "bye") | ||||||
|  |                     ) | ||||||
|  |                     running = False | ||||||
|  |                 else: | ||||||
|  |                     out_data += ws.send(TextMessage(data=event.data)) | ||||||
|  |             elif isinstance(event, BytesMessage): | ||||||
|  |                 if event.data == b"quit": | ||||||
|  |                     out_data += ws.send( | ||||||
|  |                         CloseConnection(CloseReason.NORMAL_CLOSURE, "bye") | ||||||
|  |                     ) | ||||||
|  |                     running = False | ||||||
|  |                 else: | ||||||
|  |                     out_data += ws.send(BytesMessage(data=event.data)) | ||||||
|  |  | ||||||
|  |         if out_data: | ||||||
|  |             stream.send(out_data) | ||||||
|  |  | ||||||
|  |         if not running: | ||||||
|  |             break | ||||||
|  |  | ||||||
|  |         in_data = stream.recv(4096) | ||||||
|  |         ws.receive_data(in_data) | ||||||
|  |  | ||||||
|  |     stream.shutdown(socket.SHUT_RDWR) | ||||||
|  |     return Response("", status=204) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @Request.application | ||||||
|  | def request_handler(request: Request) -> Response: | ||||||
|  |     if request.headers.get("Connection") == "Upgrade": | ||||||
|  |         return websocket_echo(request) | ||||||
|  |     else: | ||||||
|  |         return http_echo(request) | ||||||
|   | |||||||
| @@ -131,7 +131,7 @@ async def test_websocket_server(app: App, driver: Driver): | |||||||
|         assert data == b"ping" |         assert data == b"ping" | ||||||
|         await ws.send(b"pong") |         await ws.send(b"pong") | ||||||
|  |  | ||||||
|         with pytest.raises(WebSocketClosed): |         with pytest.raises(WebSocketClosed, match=r"code=1000"): | ||||||
|             await ws.receive() |             await ws.receive() | ||||||
|  |  | ||||||
|     ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws) |     ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws) | ||||||
| @@ -152,7 +152,7 @@ async def test_websocket_server(app: App, driver: Driver): | |||||||
|             await ws.send_bytes(b"ping") |             await ws.send_bytes(b"ping") | ||||||
|             assert await ws.receive_bytes() == b"pong" |             assert await ws.receive_bytes() == b"pong" | ||||||
|  |  | ||||||
|             await ws.close() |             await ws.close(code=1000) | ||||||
|  |  | ||||||
|     await asyncio.sleep(1) |     await asyncio.sleep(1) | ||||||
|  |  | ||||||
| @@ -315,9 +315,29 @@ async def test_http_client(driver: Driver, server_url: URL): | |||||||
|     ], |     ], | ||||||
|     indirect=True, |     indirect=True, | ||||||
| ) | ) | ||||||
| async def test_websocket_client(driver: Driver): | async def test_websocket_client(driver: Driver, server_url: URL): | ||||||
|     assert isinstance(driver, WebSocketClientMixin) |     assert isinstance(driver, WebSocketClientMixin) | ||||||
|  |  | ||||||
|  |     request = Request("GET", server_url.with_scheme("ws")) | ||||||
|  |     async with driver.websocket(request) as ws: | ||||||
|  |         await ws.send("test") | ||||||
|  |         assert await ws.receive() == "test" | ||||||
|  |  | ||||||
|  |         await ws.send(b"test") | ||||||
|  |         assert await ws.receive() == b"test" | ||||||
|  |  | ||||||
|  |         await ws.send_text("test") | ||||||
|  |         assert await ws.receive_text() == "test" | ||||||
|  |  | ||||||
|  |         await ws.send_bytes(b"test") | ||||||
|  |         assert await ws.receive_bytes() == b"test" | ||||||
|  |  | ||||||
|  |         await ws.send("quit") | ||||||
|  |         with pytest.raises(WebSocketClosed, match=r"code=1000"): | ||||||
|  |             await ws.receive() | ||||||
|  |  | ||||||
|  |     await asyncio.sleep(1) | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.asyncio | @pytest.mark.asyncio | ||||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user