mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-26 12:36:40 +00:00 
			
		
		
		
	✨ Feature: 重构驱动器 lifespan 方法 (#1860)
This commit is contained in:
		
							
								
								
									
										45
									
								
								nonebot/drivers/_lifespan.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								nonebot/drivers/_lifespan.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | |||||||
|  | from typing import Any, List, Union, Callable, Awaitable, cast | ||||||
|  |  | ||||||
|  | from nonebot.utils import run_sync, is_coroutine_callable | ||||||
|  |  | ||||||
|  | SYNC_LIFESPAN_FUNC = Callable[[], Any] | ||||||
|  | ASYNC_LIFESPAN_FUNC = Callable[[], Awaitable[Any]] | ||||||
|  | LIFESPAN_FUNC = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Lifespan: | ||||||
|  |     def __init__(self) -> None: | ||||||
|  |         self._startup_funcs: List[LIFESPAN_FUNC] = [] | ||||||
|  |         self._shutdown_funcs: List[LIFESPAN_FUNC] = [] | ||||||
|  |  | ||||||
|  |     def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: | ||||||
|  |         self._startup_funcs.append(func) | ||||||
|  |         return func | ||||||
|  |  | ||||||
|  |     def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: | ||||||
|  |         self._shutdown_funcs.append(func) | ||||||
|  |         return func | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     async def _run_lifespan_func( | ||||||
|  |         funcs: List[LIFESPAN_FUNC], | ||||||
|  |     ) -> None: | ||||||
|  |         for func in funcs: | ||||||
|  |             if is_coroutine_callable(func): | ||||||
|  |                 await cast(ASYNC_LIFESPAN_FUNC, func)() | ||||||
|  |             else: | ||||||
|  |                 await run_sync(cast(SYNC_LIFESPAN_FUNC, func))() | ||||||
|  |  | ||||||
|  |     async def startup(self) -> None: | ||||||
|  |         if self._startup_funcs: | ||||||
|  |             await self._run_lifespan_func(self._startup_funcs) | ||||||
|  |  | ||||||
|  |     async def shutdown(self) -> None: | ||||||
|  |         if self._shutdown_funcs: | ||||||
|  |             await self._run_lifespan_func(self._shutdown_funcs) | ||||||
|  |  | ||||||
|  |     async def __aenter__(self) -> None: | ||||||
|  |         await self.startup() | ||||||
|  |  | ||||||
|  |     async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: | ||||||
|  |         await self.shutdown() | ||||||
| @@ -27,7 +27,7 @@ from nonebot.drivers import HTTPVersion, ForwardMixin, ForwardDriver, combine_dr | |||||||
|  |  | ||||||
| try: | try: | ||||||
|     import aiohttp |     import aiohttp | ||||||
| except ImportError as e:  # pragma: no cover | except ModuleNotFoundError as e:  # pragma: no cover | ||||||
|     raise ImportError( |     raise ImportError( | ||||||
|         "Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`" |         "Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`" | ||||||
|     ) from e |     ) from e | ||||||
|   | |||||||
| @@ -19,7 +19,7 @@ FrontMatter: | |||||||
| import logging | import logging | ||||||
| import contextlib | import contextlib | ||||||
| from functools import wraps | from functools import wraps | ||||||
| from typing import Any, Dict, List, Tuple, Union, Callable, Optional | from typing import Any, Dict, List, Tuple, Union, Optional | ||||||
|  |  | ||||||
| from pydantic import BaseSettings | from pydantic import BaseSettings | ||||||
|  |  | ||||||
| @@ -32,12 +32,14 @@ from nonebot.drivers import Request as BaseRequest | |||||||
| from nonebot.drivers import WebSocket as BaseWebSocket | from nonebot.drivers import WebSocket as BaseWebSocket | ||||||
| from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup | from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup | ||||||
|  |  | ||||||
|  | from ._lifespan import LIFESPAN_FUNC, Lifespan | ||||||
|  |  | ||||||
| try: | try: | ||||||
|     import uvicorn |     import uvicorn | ||||||
|     from fastapi.responses import Response |     from fastapi.responses import Response | ||||||
|     from fastapi import FastAPI, Request, UploadFile, status |     from fastapi import FastAPI, Request, UploadFile, status | ||||||
|     from starlette.websockets import WebSocket, WebSocketState, WebSocketDisconnect |     from starlette.websockets import WebSocket, WebSocketState, WebSocketDisconnect | ||||||
| except ImportError as e:  # pragma: no cover | except ModuleNotFoundError as e:  # pragma: no cover | ||||||
|     raise ImportError( |     raise ImportError( | ||||||
|         "Please install FastAPI by using `pip install nonebot2[fastapi]`" |         "Please install FastAPI by using `pip install nonebot2[fastapi]`" | ||||||
|     ) from e |     ) from e | ||||||
| @@ -92,7 +94,10 @@ class Driver(ReverseDriver): | |||||||
|  |  | ||||||
|         self.fastapi_config: Config = Config(**config.dict()) |         self.fastapi_config: Config = Config(**config.dict()) | ||||||
|  |  | ||||||
|  |         self._lifespan = Lifespan() | ||||||
|  |  | ||||||
|         self._server_app = FastAPI( |         self._server_app = FastAPI( | ||||||
|  |             lifespan=self._lifespan_manager, | ||||||
|             openapi_url=self.fastapi_config.fastapi_openapi_url, |             openapi_url=self.fastapi_config.fastapi_openapi_url, | ||||||
|             docs_url=self.fastapi_config.fastapi_docs_url, |             docs_url=self.fastapi_config.fastapi_docs_url, | ||||||
|             redoc_url=self.fastapi_config.fastapi_redoc_url, |             redoc_url=self.fastapi_config.fastapi_redoc_url, | ||||||
| @@ -148,14 +153,20 @@ class Driver(ReverseDriver): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     @overrides(ReverseDriver) |     @overrides(ReverseDriver) | ||||||
|     def on_startup(self, func: Callable) -> Callable: |     def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: | ||||||
|         """参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_""" |         return self._lifespan.on_startup(func) | ||||||
|         return self.server_app.on_event("startup")(func) |  | ||||||
|  |  | ||||||
|     @overrides(ReverseDriver) |     @overrides(ReverseDriver) | ||||||
|     def on_shutdown(self, func: Callable) -> Callable: |     def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: | ||||||
|         """参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#shutdown-event>`_""" |         return self._lifespan.on_shutdown(func) | ||||||
|         return self.server_app.on_event("shutdown")(func) |  | ||||||
|  |     @contextlib.asynccontextmanager | ||||||
|  |     async def _lifespan_manager(self, app: FastAPI): | ||||||
|  |         await self._lifespan.startup() | ||||||
|  |         try: | ||||||
|  |             yield | ||||||
|  |         finally: | ||||||
|  |             await self._lifespan.shutdown() | ||||||
|  |  | ||||||
|     @overrides(ReverseDriver) |     @overrides(ReverseDriver) | ||||||
|     def run( |     def run( | ||||||
|   | |||||||
| @@ -31,7 +31,7 @@ from nonebot.drivers import ( | |||||||
|  |  | ||||||
| try: | try: | ||||||
|     import httpx |     import httpx | ||||||
| except ImportError as e:  # pragma: no cover | except ModuleNotFoundError as e:  # pragma: no cover | ||||||
|     raise ImportError( |     raise ImportError( | ||||||
|         "Please install httpx by using `pip install nonebot2[httpx]`" |         "Please install httpx by using `pip install nonebot2[httpx]`" | ||||||
|     ) from e |     ) from e | ||||||
|   | |||||||
| @@ -13,7 +13,6 @@ FrontMatter: | |||||||
| import signal | import signal | ||||||
| import asyncio | import asyncio | ||||||
| import threading | import threading | ||||||
| from typing import Set, Union, Callable, Awaitable, cast |  | ||||||
|  |  | ||||||
| from nonebot.log import logger | from nonebot.log import logger | ||||||
| from nonebot.consts import WINDOWS | from nonebot.consts import WINDOWS | ||||||
| @@ -22,7 +21,8 @@ from nonebot.config import Env, Config | |||||||
| from nonebot.drivers import Driver as BaseDriver | from nonebot.drivers import Driver as BaseDriver | ||||||
| from nonebot.utils import run_sync, is_coroutine_callable | from nonebot.utils import run_sync, is_coroutine_callable | ||||||
|  |  | ||||||
| HOOK_FUNC = Union[Callable[[], None], Callable[[], Awaitable[None]]] | from ._lifespan import LIFESPAN_FUNC, Lifespan | ||||||
|  |  | ||||||
| HANDLED_SIGNALS = ( | HANDLED_SIGNALS = ( | ||||||
|     signal.SIGINT,  # Unix signal 2. Sent by Ctrl+C. |     signal.SIGINT,  # Unix signal 2. Sent by Ctrl+C. | ||||||
|     signal.SIGTERM,  # Unix signal 15. Sent by `kill <pid>`. |     signal.SIGTERM,  # Unix signal 15. Sent by `kill <pid>`. | ||||||
| @@ -36,8 +36,9 @@ class Driver(BaseDriver): | |||||||
|  |  | ||||||
|     def __init__(self, env: Env, config: Config): |     def __init__(self, env: Env, config: Config): | ||||||
|         super().__init__(env, config) |         super().__init__(env, config) | ||||||
|         self.startup_funcs: Set[HOOK_FUNC] = set() |  | ||||||
|         self.shutdown_funcs: Set[HOOK_FUNC] = set() |         self._lifespan = Lifespan() | ||||||
|  |  | ||||||
|         self.should_exit: asyncio.Event = asyncio.Event() |         self.should_exit: asyncio.Event = asyncio.Event() | ||||||
|         self.force_exit: bool = False |         self.force_exit: bool = False | ||||||
|  |  | ||||||
| @@ -54,20 +55,18 @@ class Driver(BaseDriver): | |||||||
|         return logger |         return logger | ||||||
|  |  | ||||||
|     @overrides(BaseDriver) |     @overrides(BaseDriver) | ||||||
|     def on_startup(self, func: HOOK_FUNC) -> HOOK_FUNC: |     def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: | ||||||
|         """ |         """ | ||||||
|         注册一个启动时执行的函数 |         注册一个启动时执行的函数 | ||||||
|         """ |         """ | ||||||
|         self.startup_funcs.add(func) |         return self._lifespan.on_startup(func) | ||||||
|         return func |  | ||||||
|  |  | ||||||
|     @overrides(BaseDriver) |     @overrides(BaseDriver) | ||||||
|     def on_shutdown(self, func: HOOK_FUNC) -> HOOK_FUNC: |     def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: | ||||||
|         """ |         """ | ||||||
|         注册一个停止时执行的函数 |         注册一个停止时执行的函数 | ||||||
|         """ |         """ | ||||||
|         self.shutdown_funcs.add(func) |         return self._lifespan.on_shutdown(func) | ||||||
|         return func |  | ||||||
|  |  | ||||||
|     @overrides(BaseDriver) |     @overrides(BaseDriver) | ||||||
|     def run(self, *args, **kwargs): |     def run(self, *args, **kwargs): | ||||||
| @@ -85,21 +84,13 @@ class Driver(BaseDriver): | |||||||
|         await self._shutdown() |         await self._shutdown() | ||||||
|  |  | ||||||
|     async def _startup(self): |     async def _startup(self): | ||||||
|         # run startup |         try: | ||||||
|         cors = [ |             await self._lifespan.startup() | ||||||
|             cast(Callable[..., Awaitable[None]], startup)() |         except Exception as e: | ||||||
|             if is_coroutine_callable(startup) |             logger.opt(colors=True, exception=e).error( | ||||||
|             else run_sync(startup)() |                 "<r><bg #f8bbd0>Error when running startup function. " | ||||||
|             for startup in self.startup_funcs |                 "Ignored!</bg #f8bbd0></r>" | ||||||
|         ] |             ) | ||||||
|         if cors: |  | ||||||
|             try: |  | ||||||
|                 await asyncio.gather(*cors) |  | ||||||
|             except Exception as e: |  | ||||||
|                 logger.opt(colors=True, exception=e).error( |  | ||||||
|                     "<r><bg #f8bbd0>Error when running startup function. " |  | ||||||
|                     "Ignored!</bg #f8bbd0></r>" |  | ||||||
|                 ) |  | ||||||
|  |  | ||||||
|         logger.info("Application startup completed.") |         logger.info("Application startup completed.") | ||||||
|  |  | ||||||
| @@ -110,21 +101,14 @@ class Driver(BaseDriver): | |||||||
|         logger.info("Shutting down") |         logger.info("Shutting down") | ||||||
|  |  | ||||||
|         logger.info("Waiting for application shutdown.") |         logger.info("Waiting for application shutdown.") | ||||||
|         # run shutdown |  | ||||||
|         cors = [ |         try: | ||||||
|             cast(Callable[..., Awaitable[None]], shutdown)() |             await self._lifespan.shutdown() | ||||||
|             if is_coroutine_callable(shutdown) |         except Exception as e: | ||||||
|             else run_sync(shutdown)() |             logger.opt(colors=True, exception=e).error( | ||||||
|             for shutdown in self.shutdown_funcs |                 "<r><bg #f8bbd0>Error when running shutdown function. " | ||||||
|         ] |                 "Ignored!</bg #f8bbd0></r>" | ||||||
|         if cors: |             ) | ||||||
|             try: |  | ||||||
|                 await asyncio.gather(*cors) |  | ||||||
|             except Exception as e: |  | ||||||
|                 logger.opt(colors=True, exception=e).error( |  | ||||||
|                     "<r><bg #f8bbd0>Error when running shutdown function. " |  | ||||||
|                     "Ignored!</bg #f8bbd0></r>" |  | ||||||
|                 ) |  | ||||||
|  |  | ||||||
|         for task in asyncio.all_tasks(): |         for task in asyncio.all_tasks(): | ||||||
|             if task is not asyncio.current_task() and not task.done(): |             if task is not asyncio.current_task() and not task.done(): | ||||||
|   | |||||||
| @@ -37,7 +37,7 @@ try: | |||||||
|     from quart import Quart, Request, Response |     from quart import Quart, Request, Response | ||||||
|     from quart.datastructures import FileStorage |     from quart.datastructures import FileStorage | ||||||
|     from quart import Websocket as QuartWebSocket |     from quart import Websocket as QuartWebSocket | ||||||
| except ImportError as e:  # pragma: no cover | except ModuleNotFoundError as e:  # pragma: no cover | ||||||
|     raise ImportError( |     raise ImportError( | ||||||
|         "Please install Quart by using `pip install nonebot2[quart]`" |         "Please install Quart by using `pip install nonebot2[quart]`" | ||||||
|     ) from e |     ) from e | ||||||
|   | |||||||
| @@ -30,7 +30,7 @@ from nonebot.drivers import ForwardMixin, ForwardDriver, combine_driver | |||||||
| try: | try: | ||||||
|     from websockets.exceptions import ConnectionClosed |     from websockets.exceptions import ConnectionClosed | ||||||
|     from websockets.legacy.client import Connect, WebSocketClientProtocol |     from websockets.legacy.client import Connect, WebSocketClientProtocol | ||||||
| except ImportError as e:  # pragma: no cover | except ModuleNotFoundError as e:  # pragma: no cover | ||||||
|     raise ImportError( |     raise ImportError( | ||||||
|         "Please install websockets by using `pip install nonebot2[websockets]`" |         "Please install websockets by using `pip install nonebot2[websockets]`" | ||||||
|     ) from e |     ) from e | ||||||
|   | |||||||
| @@ -12,6 +12,7 @@ from nonebot.params import Depends | |||||||
| from nonebot import _resolve_combine_expr | from nonebot import _resolve_combine_expr | ||||||
| from nonebot.dependencies import Dependent | from nonebot.dependencies import Dependent | ||||||
| from nonebot.exception import WebSocketClosed | from nonebot.exception import WebSocketClosed | ||||||
|  | from nonebot.drivers._lifespan import Lifespan | ||||||
| from nonebot.drivers import ( | from nonebot.drivers import ( | ||||||
|     URL, |     URL, | ||||||
|     Driver, |     Driver, | ||||||
| @@ -36,6 +37,39 @@ def load_driver(request: pytest.FixtureRequest) -> Driver: | |||||||
|     return DriverClass(Env(environment=global_driver.env), global_driver.config) |     return DriverClass(Env(environment=global_driver.env), global_driver.config) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_lifespan(): | ||||||
|  |     lifespan = Lifespan() | ||||||
|  |  | ||||||
|  |     start_log = [] | ||||||
|  |     shutdown_log = [] | ||||||
|  |  | ||||||
|  |     @lifespan.on_startup | ||||||
|  |     async def _startup1(): | ||||||
|  |         assert start_log == [] | ||||||
|  |         start_log.append(1) | ||||||
|  |  | ||||||
|  |     @lifespan.on_startup | ||||||
|  |     async def _startup2(): | ||||||
|  |         assert start_log == [1] | ||||||
|  |         start_log.append(2) | ||||||
|  |  | ||||||
|  |     @lifespan.on_shutdown | ||||||
|  |     async def _shutdown1(): | ||||||
|  |         assert shutdown_log == [] | ||||||
|  |         shutdown_log.append(1) | ||||||
|  |  | ||||||
|  |     @lifespan.on_shutdown | ||||||
|  |     async def _shutdown2(): | ||||||
|  |         assert shutdown_log == [1] | ||||||
|  |         shutdown_log.append(2) | ||||||
|  |  | ||||||
|  |     async with lifespan: | ||||||
|  |         assert start_log == [1, 2] | ||||||
|  |  | ||||||
|  |     assert shutdown_log == [1, 2] | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.asyncio | @pytest.mark.asyncio | ||||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||||
|     "driver", |     "driver", | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user