mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-26 12:36:40 +00:00 
			
		
		
		
	✨ Feature: 重构驱动器 lifespan 方法 (#1860)
This commit is contained in:
		
							
								
								
									
										45
									
								
								nonebot/drivers/_lifespan.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								nonebot/drivers/_lifespan.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | ||||
| from typing import Any, List, Union, Callable, Awaitable, cast | ||||
|  | ||||
| from nonebot.utils import run_sync, is_coroutine_callable | ||||
|  | ||||
| SYNC_LIFESPAN_FUNC = Callable[[], Any] | ||||
| ASYNC_LIFESPAN_FUNC = Callable[[], Awaitable[Any]] | ||||
| LIFESPAN_FUNC = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC] | ||||
|  | ||||
|  | ||||
| class Lifespan: | ||||
|     def __init__(self) -> None: | ||||
|         self._startup_funcs: List[LIFESPAN_FUNC] = [] | ||||
|         self._shutdown_funcs: List[LIFESPAN_FUNC] = [] | ||||
|  | ||||
|     def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: | ||||
|         self._startup_funcs.append(func) | ||||
|         return func | ||||
|  | ||||
|     def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: | ||||
|         self._shutdown_funcs.append(func) | ||||
|         return func | ||||
|  | ||||
|     @staticmethod | ||||
|     async def _run_lifespan_func( | ||||
|         funcs: List[LIFESPAN_FUNC], | ||||
|     ) -> None: | ||||
|         for func in funcs: | ||||
|             if is_coroutine_callable(func): | ||||
|                 await cast(ASYNC_LIFESPAN_FUNC, func)() | ||||
|             else: | ||||
|                 await run_sync(cast(SYNC_LIFESPAN_FUNC, func))() | ||||
|  | ||||
|     async def startup(self) -> None: | ||||
|         if self._startup_funcs: | ||||
|             await self._run_lifespan_func(self._startup_funcs) | ||||
|  | ||||
|     async def shutdown(self) -> None: | ||||
|         if self._shutdown_funcs: | ||||
|             await self._run_lifespan_func(self._shutdown_funcs) | ||||
|  | ||||
|     async def __aenter__(self) -> None: | ||||
|         await self.startup() | ||||
|  | ||||
|     async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: | ||||
|         await self.shutdown() | ||||
| @@ -27,7 +27,7 @@ from nonebot.drivers import HTTPVersion, ForwardMixin, ForwardDriver, combine_dr | ||||
|  | ||||
| try: | ||||
|     import aiohttp | ||||
| except ImportError as e:  # pragma: no cover | ||||
| except ModuleNotFoundError as e:  # pragma: no cover | ||||
|     raise ImportError( | ||||
|         "Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`" | ||||
|     ) from e | ||||
|   | ||||
| @@ -19,7 +19,7 @@ FrontMatter: | ||||
| import logging | ||||
| import contextlib | ||||
| from functools import wraps | ||||
| from typing import Any, Dict, List, Tuple, Union, Callable, Optional | ||||
| from typing import Any, Dict, List, Tuple, Union, Optional | ||||
|  | ||||
| from pydantic import BaseSettings | ||||
|  | ||||
| @@ -32,12 +32,14 @@ from nonebot.drivers import Request as BaseRequest | ||||
| from nonebot.drivers import WebSocket as BaseWebSocket | ||||
| from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup | ||||
|  | ||||
| from ._lifespan import LIFESPAN_FUNC, Lifespan | ||||
|  | ||||
| try: | ||||
|     import uvicorn | ||||
|     from fastapi.responses import Response | ||||
|     from fastapi import FastAPI, Request, UploadFile, status | ||||
|     from starlette.websockets import WebSocket, WebSocketState, WebSocketDisconnect | ||||
| except ImportError as e:  # pragma: no cover | ||||
| except ModuleNotFoundError as e:  # pragma: no cover | ||||
|     raise ImportError( | ||||
|         "Please install FastAPI by using `pip install nonebot2[fastapi]`" | ||||
|     ) from e | ||||
| @@ -92,7 +94,10 @@ class Driver(ReverseDriver): | ||||
|  | ||||
|         self.fastapi_config: Config = Config(**config.dict()) | ||||
|  | ||||
|         self._lifespan = Lifespan() | ||||
|  | ||||
|         self._server_app = FastAPI( | ||||
|             lifespan=self._lifespan_manager, | ||||
|             openapi_url=self.fastapi_config.fastapi_openapi_url, | ||||
|             docs_url=self.fastapi_config.fastapi_docs_url, | ||||
|             redoc_url=self.fastapi_config.fastapi_redoc_url, | ||||
| @@ -148,14 +153,20 @@ class Driver(ReverseDriver): | ||||
|         ) | ||||
|  | ||||
|     @overrides(ReverseDriver) | ||||
|     def on_startup(self, func: Callable) -> Callable: | ||||
|         """参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_""" | ||||
|         return self.server_app.on_event("startup")(func) | ||||
|     def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: | ||||
|         return self._lifespan.on_startup(func) | ||||
|  | ||||
|     @overrides(ReverseDriver) | ||||
|     def on_shutdown(self, func: Callable) -> Callable: | ||||
|         """参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#shutdown-event>`_""" | ||||
|         return self.server_app.on_event("shutdown")(func) | ||||
|     def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: | ||||
|         return self._lifespan.on_shutdown(func) | ||||
|  | ||||
|     @contextlib.asynccontextmanager | ||||
|     async def _lifespan_manager(self, app: FastAPI): | ||||
|         await self._lifespan.startup() | ||||
|         try: | ||||
|             yield | ||||
|         finally: | ||||
|             await self._lifespan.shutdown() | ||||
|  | ||||
|     @overrides(ReverseDriver) | ||||
|     def run( | ||||
|   | ||||
| @@ -31,7 +31,7 @@ from nonebot.drivers import ( | ||||
|  | ||||
| try: | ||||
|     import httpx | ||||
| except ImportError as e:  # pragma: no cover | ||||
| except ModuleNotFoundError as e:  # pragma: no cover | ||||
|     raise ImportError( | ||||
|         "Please install httpx by using `pip install nonebot2[httpx]`" | ||||
|     ) from e | ||||
|   | ||||
| @@ -13,7 +13,6 @@ FrontMatter: | ||||
| import signal | ||||
| import asyncio | ||||
| import threading | ||||
| from typing import Set, Union, Callable, Awaitable, cast | ||||
|  | ||||
| from nonebot.log import logger | ||||
| from nonebot.consts import WINDOWS | ||||
| @@ -22,7 +21,8 @@ from nonebot.config import Env, Config | ||||
| from nonebot.drivers import Driver as BaseDriver | ||||
| from nonebot.utils import run_sync, is_coroutine_callable | ||||
|  | ||||
| HOOK_FUNC = Union[Callable[[], None], Callable[[], Awaitable[None]]] | ||||
| from ._lifespan import LIFESPAN_FUNC, Lifespan | ||||
|  | ||||
| HANDLED_SIGNALS = ( | ||||
|     signal.SIGINT,  # Unix signal 2. Sent by Ctrl+C. | ||||
|     signal.SIGTERM,  # Unix signal 15. Sent by `kill <pid>`. | ||||
| @@ -36,8 +36,9 @@ class Driver(BaseDriver): | ||||
|  | ||||
|     def __init__(self, env: Env, config: Config): | ||||
|         super().__init__(env, config) | ||||
|         self.startup_funcs: Set[HOOK_FUNC] = set() | ||||
|         self.shutdown_funcs: Set[HOOK_FUNC] = set() | ||||
|  | ||||
|         self._lifespan = Lifespan() | ||||
|  | ||||
|         self.should_exit: asyncio.Event = asyncio.Event() | ||||
|         self.force_exit: bool = False | ||||
|  | ||||
| @@ -54,20 +55,18 @@ class Driver(BaseDriver): | ||||
|         return logger | ||||
|  | ||||
|     @overrides(BaseDriver) | ||||
|     def on_startup(self, func: HOOK_FUNC) -> HOOK_FUNC: | ||||
|     def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: | ||||
|         """ | ||||
|         注册一个启动时执行的函数 | ||||
|         """ | ||||
|         self.startup_funcs.add(func) | ||||
|         return func | ||||
|         return self._lifespan.on_startup(func) | ||||
|  | ||||
|     @overrides(BaseDriver) | ||||
|     def on_shutdown(self, func: HOOK_FUNC) -> HOOK_FUNC: | ||||
|     def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: | ||||
|         """ | ||||
|         注册一个停止时执行的函数 | ||||
|         """ | ||||
|         self.shutdown_funcs.add(func) | ||||
|         return func | ||||
|         return self._lifespan.on_shutdown(func) | ||||
|  | ||||
|     @overrides(BaseDriver) | ||||
|     def run(self, *args, **kwargs): | ||||
| @@ -85,16 +84,8 @@ class Driver(BaseDriver): | ||||
|         await self._shutdown() | ||||
|  | ||||
|     async def _startup(self): | ||||
|         # run startup | ||||
|         cors = [ | ||||
|             cast(Callable[..., Awaitable[None]], startup)() | ||||
|             if is_coroutine_callable(startup) | ||||
|             else run_sync(startup)() | ||||
|             for startup in self.startup_funcs | ||||
|         ] | ||||
|         if cors: | ||||
|         try: | ||||
|                 await asyncio.gather(*cors) | ||||
|             await self._lifespan.startup() | ||||
|         except Exception as e: | ||||
|             logger.opt(colors=True, exception=e).error( | ||||
|                 "<r><bg #f8bbd0>Error when running startup function. " | ||||
| @@ -110,16 +101,9 @@ class Driver(BaseDriver): | ||||
|         logger.info("Shutting down") | ||||
|  | ||||
|         logger.info("Waiting for application shutdown.") | ||||
|         # run shutdown | ||||
|         cors = [ | ||||
|             cast(Callable[..., Awaitable[None]], shutdown)() | ||||
|             if is_coroutine_callable(shutdown) | ||||
|             else run_sync(shutdown)() | ||||
|             for shutdown in self.shutdown_funcs | ||||
|         ] | ||||
|         if cors: | ||||
|  | ||||
|         try: | ||||
|                 await asyncio.gather(*cors) | ||||
|             await self._lifespan.shutdown() | ||||
|         except Exception as e: | ||||
|             logger.opt(colors=True, exception=e).error( | ||||
|                 "<r><bg #f8bbd0>Error when running shutdown function. " | ||||
|   | ||||
| @@ -37,7 +37,7 @@ try: | ||||
|     from quart import Quart, Request, Response | ||||
|     from quart.datastructures import FileStorage | ||||
|     from quart import Websocket as QuartWebSocket | ||||
| except ImportError as e:  # pragma: no cover | ||||
| except ModuleNotFoundError as e:  # pragma: no cover | ||||
|     raise ImportError( | ||||
|         "Please install Quart by using `pip install nonebot2[quart]`" | ||||
|     ) from e | ||||
|   | ||||
| @@ -30,7 +30,7 @@ from nonebot.drivers import ForwardMixin, ForwardDriver, combine_driver | ||||
| try: | ||||
|     from websockets.exceptions import ConnectionClosed | ||||
|     from websockets.legacy.client import Connect, WebSocketClientProtocol | ||||
| except ImportError as e:  # pragma: no cover | ||||
| except ModuleNotFoundError as e:  # pragma: no cover | ||||
|     raise ImportError( | ||||
|         "Please install websockets by using `pip install nonebot2[websockets]`" | ||||
|     ) from e | ||||
|   | ||||
| @@ -12,6 +12,7 @@ from nonebot.params import Depends | ||||
| from nonebot import _resolve_combine_expr | ||||
| from nonebot.dependencies import Dependent | ||||
| from nonebot.exception import WebSocketClosed | ||||
| from nonebot.drivers._lifespan import Lifespan | ||||
| from nonebot.drivers import ( | ||||
|     URL, | ||||
|     Driver, | ||||
| @@ -36,6 +37,39 @@ def load_driver(request: pytest.FixtureRequest) -> Driver: | ||||
|     return DriverClass(Env(environment=global_driver.env), global_driver.config) | ||||
|  | ||||
|  | ||||
| @pytest.mark.asyncio | ||||
| async def test_lifespan(): | ||||
|     lifespan = Lifespan() | ||||
|  | ||||
|     start_log = [] | ||||
|     shutdown_log = [] | ||||
|  | ||||
|     @lifespan.on_startup | ||||
|     async def _startup1(): | ||||
|         assert start_log == [] | ||||
|         start_log.append(1) | ||||
|  | ||||
|     @lifespan.on_startup | ||||
|     async def _startup2(): | ||||
|         assert start_log == [1] | ||||
|         start_log.append(2) | ||||
|  | ||||
|     @lifespan.on_shutdown | ||||
|     async def _shutdown1(): | ||||
|         assert shutdown_log == [] | ||||
|         shutdown_log.append(1) | ||||
|  | ||||
|     @lifespan.on_shutdown | ||||
|     async def _shutdown2(): | ||||
|         assert shutdown_log == [1] | ||||
|         shutdown_log.append(2) | ||||
|  | ||||
|     async with lifespan: | ||||
|         assert start_log == [1, 2] | ||||
|  | ||||
|     assert shutdown_log == [1, 2] | ||||
|  | ||||
|  | ||||
| @pytest.mark.asyncio | ||||
| @pytest.mark.parametrize( | ||||
|     "driver", | ||||
|   | ||||
		Reference in New Issue
	
	Block a user