mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-31 06:56:39 +00:00 
			
		
		
		
	✨ fastapi driver support forward connect
This commit is contained in:
		| @@ -212,6 +212,9 @@ class Driver(ForwardDriver): | ||||
|         BotClass = self._adapters[setup.adapter] | ||||
|         bot = BotClass(setup.self_id, request) | ||||
|         self._bot_connect(bot) | ||||
|         logger.opt(colors=True).info( | ||||
|             f"Start http polling for <y>{setup.adapter.upper()} " | ||||
|             f"Bot {setup.self_id}</y>") | ||||
|  | ||||
|         headers = request.headers | ||||
|         timeout = aiohttp.ClientTimeout(30) | ||||
| @@ -289,11 +292,13 @@ class Driver(ForwardDriver): | ||||
|                     ) | ||||
|                     try: | ||||
|                         async with session.ws_connect(url) as ws: | ||||
|                             logger.opt(colors=True).info( | ||||
|                                 f"WebSocket Connection to <y>{setup.adapter.upper()} " | ||||
|                                 f"Bot {setup.self_id}</y> succeeded!") | ||||
|                             request = WebSocket( | ||||
|                                 setup.http_version, url.scheme, url.path, | ||||
|                                 url.raw_query_string.encode("latin-1"), { | ||||
|                                     **setup.headers, "host": host | ||||
|                                 }, ws) | ||||
|                                 url.raw_query_string.encode("latin-1"), headers, | ||||
|                                 ws) | ||||
|  | ||||
|                             BotClass = self._adapters[setup.adapter] | ||||
|                             bot = BotClass(setup.self_id, request) | ||||
|   | ||||
| @@ -11,19 +11,46 @@ FastAPI 驱动适配 | ||||
| import asyncio | ||||
| import logging | ||||
| from dataclasses import dataclass | ||||
| from typing import List, Optional, Callable | ||||
| from typing import List, Dict, Union, Optional, Callable | ||||
|  | ||||
| import httpx | ||||
| import uvicorn | ||||
| from pydantic import BaseSettings | ||||
| from fastapi.responses import Response | ||||
| from websockets.exceptions import ConnectionClosed | ||||
| from fastapi import status, Request, FastAPI, HTTPException | ||||
| from websockets.legacy.client import Connect, WebSocketClientProtocol | ||||
| from starlette.websockets import (WebSocketState, WebSocketDisconnect, WebSocket | ||||
|                                   as FastAPIWebSocket) | ||||
|  | ||||
| from nonebot.log import logger | ||||
| from nonebot.adapters import Bot | ||||
| from nonebot.typing import overrides | ||||
| from nonebot.config import Env, Config as NoneBotConfig | ||||
| from nonebot.drivers import ReverseDriver, HTTPRequest, WebSocket as BaseWebSocket | ||||
| from nonebot.drivers import ReverseDriver, ForwardDriver | ||||
| from nonebot.drivers import HTTPRequest, WebSocket as BaseWebSocket | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class HTTPPollingSetup: | ||||
|     adapter: str | ||||
|     self_id: str | ||||
|     url: str | ||||
|     method: str | ||||
|     body: bytes | ||||
|     headers: Dict[str, str] | ||||
|     http_version: str | ||||
|     poll_interval: float | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class WebSocketSetup: | ||||
|     adapter: str | ||||
|     self_id: str | ||||
|     url: str | ||||
|     headers: Dict[str, str] | ||||
|     http_version: str | ||||
|     reconnect_interval: float | ||||
|  | ||||
|  | ||||
| class Config(BaseSettings): | ||||
| @@ -75,7 +102,7 @@ class Config(BaseSettings): | ||||
|         extra = "ignore" | ||||
|  | ||||
|  | ||||
| class Driver(ReverseDriver): | ||||
| class Driver(ReverseDriver, ForwardDriver): | ||||
|     """ | ||||
|     FastAPI 驱动框架 | ||||
|  | ||||
| @@ -90,7 +117,11 @@ class Driver(ReverseDriver): | ||||
|     def __init__(self, env: Env, config: NoneBotConfig): | ||||
|         super().__init__(env, config) | ||||
|  | ||||
|         self.fastapi_config = Config(**config.dict()) | ||||
|         self.fastapi_config: Config = Config(**config.dict()) | ||||
|         self.http_pollings: List[HTTPPollingSetup] = [] | ||||
|         self.websockets: List[WebSocketSetup] = [] | ||||
|         self.shutdown: asyncio.Event = asyncio.Event() | ||||
|         self.connections: List[asyncio.Task] = [] | ||||
|  | ||||
|         self._server_app = FastAPI( | ||||
|             debug=config.debug, | ||||
| @@ -104,6 +135,9 @@ class Driver(ReverseDriver): | ||||
|         self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse) | ||||
|         self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse) | ||||
|  | ||||
|         self.on_startup(self._run_forward) | ||||
|         self.on_shutdown(self._shutdown_forward) | ||||
|  | ||||
|     @property | ||||
|     @overrides(ReverseDriver) | ||||
|     def type(self) -> str: | ||||
| @@ -138,6 +172,32 @@ class Driver(ReverseDriver): | ||||
|         """参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_""" | ||||
|         return self.server_app.on_event("shutdown")(func) | ||||
|  | ||||
|     @overrides(ForwardDriver) | ||||
|     def setup_http_polling(self, | ||||
|                            adapter: str, | ||||
|                            self_id: str, | ||||
|                            url: str, | ||||
|                            polling_interval: float = 3., | ||||
|                            method: str = "GET", | ||||
|                            body: bytes = b"", | ||||
|                            headers: Dict[str, str] = {}, | ||||
|                            http_version: str = "1.1") -> None: | ||||
|         self.http_pollings.append( | ||||
|             HTTPPollingSetup(adapter, self_id, url, method, body, headers, | ||||
|                              http_version, polling_interval)) | ||||
|  | ||||
|     @overrides(ForwardDriver) | ||||
|     def setup_websocket(self, | ||||
|                         adapter: str, | ||||
|                         self_id: str, | ||||
|                         url: str, | ||||
|                         reconnect_interval: float = 3., | ||||
|                         headers: Dict[str, str] = {}, | ||||
|                         http_version: str = "1.1") -> None: | ||||
|         self.websockets.append( | ||||
|             WebSocketSetup(adapter, self_id, url, headers, http_version, | ||||
|                            reconnect_interval)) | ||||
|  | ||||
|     @overrides(ReverseDriver) | ||||
|     def run(self, | ||||
|             host: Optional[str] = None, | ||||
| @@ -166,14 +226,27 @@ class Driver(ReverseDriver): | ||||
|                 }, | ||||
|             }, | ||||
|         } | ||||
|         uvicorn.run(app or self.server_app, | ||||
|                     host=host or str(self.config.host), | ||||
|                     port=port or self.config.port, | ||||
|                     reload=bool(app) and self.config.debug, | ||||
|                     reload_dirs=self.fastapi_config.fastapi_reload_dirs or None, | ||||
|                     debug=self.config.debug, | ||||
|                     log_config=LOGGING_CONFIG, | ||||
|                     **kwargs) | ||||
|         uvicorn.run( | ||||
|             app or self.server_app,  # type: ignore | ||||
|             host=host or str(self.config.host), | ||||
|             port=port or self.config.port, | ||||
|             reload=bool(app) and self.config.debug, | ||||
|             reload_dirs=self.fastapi_config.fastapi_reload_dirs or None, | ||||
|             debug=self.config.debug, | ||||
|             log_config=LOGGING_CONFIG, | ||||
|             **kwargs) | ||||
|  | ||||
|     def _run_forward(self): | ||||
|         for setup in self.http_pollings: | ||||
|             self.connections.append(asyncio.create_task(self._http_loop(setup))) | ||||
|         for setup in self.websockets: | ||||
|             self.connections.append(asyncio.create_task(self._ws_loop(setup))) | ||||
|  | ||||
|     def _shutdown_forward(self): | ||||
|         self.shutdown.set() | ||||
|         for task in self.connections: | ||||
|             if not task.done(): | ||||
|                 task.cancel() | ||||
|  | ||||
|     async def _handle_http(self, adapter: str, request: Request): | ||||
|         data = await request.body() | ||||
| @@ -263,37 +336,166 @@ class Driver(ReverseDriver): | ||||
|         finally: | ||||
|             self._bot_disconnect(bot) | ||||
|  | ||||
|     async def _http_loop(self, setup: HTTPPollingSetup): | ||||
|         url = httpx.URL(setup.url) | ||||
|         if not url.netloc: | ||||
|             logger.opt(colors=True).error( | ||||
|                 f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>") | ||||
|             return | ||||
|         request = HTTPRequest( | ||||
|             setup.http_version, url.scheme, url.path, url.query, { | ||||
|                 **setup.headers, "host": url.netloc.decode("ascii") | ||||
|             }, setup.method, setup.body) | ||||
|  | ||||
|         BotClass = self._adapters[setup.adapter] | ||||
|         bot = BotClass(setup.self_id, request) | ||||
|         self._bot_connect(bot) | ||||
|         logger.opt(colors=True).info( | ||||
|             f"Start http polling for <y>{setup.adapter.upper()} " | ||||
|             f"Bot {setup.self_id}</y>") | ||||
|  | ||||
|         headers = request.headers | ||||
|         http2: bool = False | ||||
|         if request.http_version == "2": | ||||
|             http2 = True | ||||
|  | ||||
|         try: | ||||
|             async with httpx.AsyncClient(headers=headers, | ||||
|                                          timeout=30., | ||||
|                                          http2=http2) as session: | ||||
|                 while not self.shutdown.is_set(): | ||||
|                     logger.debug( | ||||
|                         f"Bot {setup.self_id} from adapter {setup.adapter} request {url}" | ||||
|                     ) | ||||
|                     try: | ||||
|                         response = await session.request(request.method, | ||||
|                                                          url, | ||||
|                                                          content=request.body) | ||||
|                         response.raise_for_status() | ||||
|                         data = response.read() | ||||
|                         asyncio.create_task(bot.handle_message(data)) | ||||
|                     except httpx.HTTPError as e: | ||||
|                         logger.opt(colors=True, exception=e).error( | ||||
|                             f"<r><bg #f8bbd0>Error occurred while requesting {url}. " | ||||
|                             "Try to reconnect...</bg #f8bbd0></r>") | ||||
|  | ||||
|                     await asyncio.sleep(setup.poll_interval) | ||||
|  | ||||
|         except asyncio.CancelledError: | ||||
|             pass | ||||
|         except Exception as e: | ||||
|             logger.opt(colors=True, exception=e).error( | ||||
|                 "<r><bg #f8bbd0>Unexpected exception occurred " | ||||
|                 "while http polling</bg #f8bbd0></r>") | ||||
|         finally: | ||||
|             self._bot_disconnect(bot) | ||||
|  | ||||
|     async def _ws_loop(self, setup: WebSocketSetup): | ||||
|         url = httpx.URL(setup.url) | ||||
|         if not url.netloc: | ||||
|             logger.opt(colors=True).error( | ||||
|                 f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>") | ||||
|             return | ||||
|  | ||||
|         headers = {**setup.headers, "host": url.netloc.decode("ascii")} | ||||
|  | ||||
|         bot: Optional[Bot] = None | ||||
|         try: | ||||
|             while True: | ||||
|                 logger.debug( | ||||
|                     f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}" | ||||
|                 ) | ||||
|                 try: | ||||
|                     connection = Connect(setup.url) | ||||
|                     async with connection as ws: | ||||
|                         logger.opt(colors=True).info( | ||||
|                             f"WebSocket Connection to <y>{setup.adapter.upper()} " | ||||
|                             f"Bot {setup.self_id}</y> succeeded!") | ||||
|                         request = WebSocket(setup.http_version, url.scheme, | ||||
|                                             url.path, url.query, headers, ws) | ||||
|  | ||||
|                         BotClass = self._adapters[setup.adapter] | ||||
|                         bot = BotClass(setup.self_id, request) | ||||
|                         self._bot_connect(bot) | ||||
|                         while not self.shutdown.is_set(): | ||||
|                             # use try except instead of "request.closed" because of queued message | ||||
|                             try: | ||||
|                                 msg = await request.receive_bytes() | ||||
|                                 asyncio.create_task(bot.handle_message(msg)) | ||||
|                             except ConnectionClosed: | ||||
|                                 logger.opt(colors=True).error( | ||||
|                                     "<r><bg #f8bbd0>WebSocket connection closed by peer. " | ||||
|                                     "Try to reconnect...</bg #f8bbd0></r>") | ||||
|                 except Exception as e: | ||||
|                     logger.opt(colors=True, exception=e).error( | ||||
|                         f"<r><bg #f8bbd0>Error while connecting to {url}. " | ||||
|                         "Try to reconnect...</bg #f8bbd0></r>") | ||||
|                 finally: | ||||
|                     if bot: | ||||
|                         self._bot_disconnect(bot) | ||||
|                     bot = None | ||||
|                 await asyncio.sleep(setup.reconnect_interval) | ||||
|  | ||||
|         except asyncio.CancelledError: | ||||
|             pass | ||||
|         except Exception as e: | ||||
|             logger.opt(colors=True, exception=e).error( | ||||
|                 "<r><bg #f8bbd0>Unexpected exception occurred " | ||||
|                 "while websocket loop</bg #f8bbd0></r>") | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class WebSocket(BaseWebSocket): | ||||
|     websocket: FastAPIWebSocket = None  # type: ignore | ||||
|     websocket: Union[FastAPIWebSocket, | ||||
|                      WebSocketClientProtocol] = None  # type: ignore | ||||
|  | ||||
|     @property | ||||
|     @overrides(BaseWebSocket) | ||||
|     def closed(self): | ||||
|         return (self.websocket.client_state == WebSocketState.DISCONNECTED or | ||||
|     def closed(self) -> bool: | ||||
|         if isinstance(self.websocket, FastAPIWebSocket): | ||||
|             return ( | ||||
|                 self.websocket.client_state == WebSocketState.DISCONNECTED or | ||||
|                 self.websocket.application_state == WebSocketState.DISCONNECTED) | ||||
|         else: | ||||
|             return self.websocket.closed | ||||
|  | ||||
|     @overrides(BaseWebSocket) | ||||
|     async def accept(self): | ||||
|         await self.websocket.accept() | ||||
|         if isinstance(self.websocket, FastAPIWebSocket): | ||||
|             await self.websocket.accept() | ||||
|         else: | ||||
|             raise NotImplementedError | ||||
|  | ||||
|     @overrides(BaseWebSocket) | ||||
|     async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE): | ||||
|         await self.websocket.close(code=code) | ||||
|         await self.websocket.close(code) | ||||
|  | ||||
|     @overrides(BaseWebSocket) | ||||
|     async def receive(self) -> str: | ||||
|         return await self.websocket.receive_text() | ||||
|         if isinstance(self.websocket, FastAPIWebSocket): | ||||
|             return await self.websocket.receive_text() | ||||
|         else: | ||||
|             msg = await self.websocket.recv() | ||||
|             return msg.decode("utf-8") if isinstance(msg, bytes) else msg | ||||
|  | ||||
|     @overrides(BaseWebSocket) | ||||
|     async def receive_bytes(self) -> bytes: | ||||
|         return await self.websocket.receive_bytes() | ||||
|         if isinstance(self.websocket, FastAPIWebSocket): | ||||
|             return await self.websocket.receive_bytes() | ||||
|         else: | ||||
|             msg = await self.websocket.recv() | ||||
|             return msg.encode("utf-8") if isinstance(msg, str) else msg | ||||
|  | ||||
|     @overrides(BaseWebSocket) | ||||
|     async def send(self, data: str) -> None: | ||||
|         await self.websocket.send({"type": "websocket.send", "text": data}) | ||||
|         if isinstance(self.websocket, FastAPIWebSocket): | ||||
|             await self.websocket.send({"type": "websocket.send", "text": data}) | ||||
|         else: | ||||
|             await self.websocket.send(data) | ||||
|  | ||||
|     @overrides(BaseWebSocket) | ||||
|     async def send_bytes(self, data: bytes) -> None: | ||||
|         await self.websocket.send({"type": "websocket.send", "bytes": data}) | ||||
|         if isinstance(self.websocket, FastAPIWebSocket): | ||||
|             await self.websocket.send({"type": "websocket.send", "bytes": data}) | ||||
|         else: | ||||
|             await self.websocket.send(data) | ||||
|   | ||||
| @@ -140,14 +140,15 @@ class Driver(ReverseDriver): | ||||
|                 }, | ||||
|             }, | ||||
|         } | ||||
|         uvicorn.run(app or self.server_app, | ||||
|                     host=host or str(self.config.host), | ||||
|                     port=port or self.config.port, | ||||
|                     reload=bool(app) and self.config.debug, | ||||
|                     reload_dirs=self.quart_config.quart_reload_dirs or None, | ||||
|                     debug=self.config.debug, | ||||
|                     log_config=LOGGING_CONFIG, | ||||
|                     **kwargs) | ||||
|         uvicorn.run( | ||||
|             app or self.server_app,  # type: ignore | ||||
|             host=host or str(self.config.host), | ||||
|             port=port or self.config.port, | ||||
|             reload=bool(app) and self.config.debug, | ||||
|             reload_dirs=self.quart_config.quart_reload_dirs or None, | ||||
|             debug=self.config.debug, | ||||
|             log_config=LOGGING_CONFIG, | ||||
|             **kwargs) | ||||
|  | ||||
|     async def _handle_http(self, adapter: str): | ||||
|         request: Request = _request | ||||
|   | ||||
		Reference in New Issue
	
	Block a user