diff --git a/nonebot/drivers/_lifespan.py b/nonebot/drivers/_lifespan.py new file mode 100644 index 00000000..194986fd --- /dev/null +++ b/nonebot/drivers/_lifespan.py @@ -0,0 +1,45 @@ +from typing import Any, List, Union, Callable, Awaitable, cast + +from nonebot.utils import run_sync, is_coroutine_callable + +SYNC_LIFESPAN_FUNC = Callable[[], Any] +ASYNC_LIFESPAN_FUNC = Callable[[], Awaitable[Any]] +LIFESPAN_FUNC = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC] + + +class Lifespan: + def __init__(self) -> None: + self._startup_funcs: List[LIFESPAN_FUNC] = [] + self._shutdown_funcs: List[LIFESPAN_FUNC] = [] + + def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: + self._startup_funcs.append(func) + return func + + def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: + self._shutdown_funcs.append(func) + return func + + @staticmethod + async def _run_lifespan_func( + funcs: List[LIFESPAN_FUNC], + ) -> None: + for func in funcs: + if is_coroutine_callable(func): + await cast(ASYNC_LIFESPAN_FUNC, func)() + else: + await run_sync(cast(SYNC_LIFESPAN_FUNC, func))() + + async def startup(self) -> None: + if self._startup_funcs: + await self._run_lifespan_func(self._startup_funcs) + + async def shutdown(self) -> None: + if self._shutdown_funcs: + await self._run_lifespan_func(self._shutdown_funcs) + + async def __aenter__(self) -> None: + await self.startup() + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.shutdown() diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index e4b6267d..c24bb972 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -27,7 +27,7 @@ from nonebot.drivers import HTTPVersion, ForwardMixin, ForwardDriver, combine_dr try: import aiohttp -except ImportError as e: # pragma: no cover +except ModuleNotFoundError as e: # pragma: no cover raise ImportError( "Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`" ) from e diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index aa76ff78..4f7ea2a6 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -19,7 +19,7 @@ FrontMatter: import logging import contextlib from functools import wraps -from typing import Any, Dict, List, Tuple, Union, Callable, Optional +from typing import Any, Dict, List, Tuple, Union, Optional from pydantic import BaseSettings @@ -32,12 +32,14 @@ from nonebot.drivers import Request as BaseRequest from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup +from ._lifespan import LIFESPAN_FUNC, Lifespan + try: import uvicorn from fastapi.responses import Response from fastapi import FastAPI, Request, UploadFile, status from starlette.websockets import WebSocket, WebSocketState, WebSocketDisconnect -except ImportError as e: # pragma: no cover +except ModuleNotFoundError as e: # pragma: no cover raise ImportError( "Please install FastAPI by using `pip install nonebot2[fastapi]`" ) from e @@ -92,7 +94,10 @@ class Driver(ReverseDriver): self.fastapi_config: Config = Config(**config.dict()) + self._lifespan = Lifespan() + self._server_app = FastAPI( + lifespan=self._lifespan_manager, openapi_url=self.fastapi_config.fastapi_openapi_url, docs_url=self.fastapi_config.fastapi_docs_url, redoc_url=self.fastapi_config.fastapi_redoc_url, @@ -148,14 +153,20 @@ class Driver(ReverseDriver): ) @overrides(ReverseDriver) - def on_startup(self, func: Callable) -> Callable: - """参考文档: `Events `_""" - return self.server_app.on_event("startup")(func) + def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: + return self._lifespan.on_startup(func) @overrides(ReverseDriver) - def on_shutdown(self, func: Callable) -> Callable: - """参考文档: `Events `_""" - return self.server_app.on_event("shutdown")(func) + def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: + return self._lifespan.on_shutdown(func) + + @contextlib.asynccontextmanager + async def _lifespan_manager(self, app: FastAPI): + await self._lifespan.startup() + try: + yield + finally: + await self._lifespan.shutdown() @overrides(ReverseDriver) def run( diff --git a/nonebot/drivers/httpx.py b/nonebot/drivers/httpx.py index eb55c2be..a0eac938 100644 --- a/nonebot/drivers/httpx.py +++ b/nonebot/drivers/httpx.py @@ -31,7 +31,7 @@ from nonebot.drivers import ( try: import httpx -except ImportError as e: # pragma: no cover +except ModuleNotFoundError as e: # pragma: no cover raise ImportError( "Please install httpx by using `pip install nonebot2[httpx]`" ) from e diff --git a/nonebot/drivers/none.py b/nonebot/drivers/none.py index 5fb357f3..029c849e 100644 --- a/nonebot/drivers/none.py +++ b/nonebot/drivers/none.py @@ -13,7 +13,6 @@ FrontMatter: import signal import asyncio import threading -from typing import Set, Union, Callable, Awaitable, cast from nonebot.log import logger from nonebot.consts import WINDOWS @@ -22,7 +21,8 @@ from nonebot.config import Env, Config from nonebot.drivers import Driver as BaseDriver from nonebot.utils import run_sync, is_coroutine_callable -HOOK_FUNC = Union[Callable[[], None], Callable[[], Awaitable[None]]] +from ._lifespan import LIFESPAN_FUNC, Lifespan + HANDLED_SIGNALS = ( signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. signal.SIGTERM, # Unix signal 15. Sent by `kill `. @@ -36,8 +36,9 @@ class Driver(BaseDriver): def __init__(self, env: Env, config: Config): super().__init__(env, config) - self.startup_funcs: Set[HOOK_FUNC] = set() - self.shutdown_funcs: Set[HOOK_FUNC] = set() + + self._lifespan = Lifespan() + self.should_exit: asyncio.Event = asyncio.Event() self.force_exit: bool = False @@ -54,20 +55,18 @@ class Driver(BaseDriver): return logger @overrides(BaseDriver) - def on_startup(self, func: HOOK_FUNC) -> HOOK_FUNC: + def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: """ 注册一个启动时执行的函数 """ - self.startup_funcs.add(func) - return func + return self._lifespan.on_startup(func) @overrides(BaseDriver) - def on_shutdown(self, func: HOOK_FUNC) -> HOOK_FUNC: + def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: """ 注册一个停止时执行的函数 """ - self.shutdown_funcs.add(func) - return func + return self._lifespan.on_shutdown(func) @overrides(BaseDriver) def run(self, *args, **kwargs): @@ -85,21 +84,13 @@ class Driver(BaseDriver): await self._shutdown() async def _startup(self): - # run startup - cors = [ - cast(Callable[..., Awaitable[None]], startup)() - if is_coroutine_callable(startup) - else run_sync(startup)() - for startup in self.startup_funcs - ] - if cors: - try: - await asyncio.gather(*cors) - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error when running startup function. " - "Ignored!" - ) + try: + await self._lifespan.startup() + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running startup function. " + "Ignored!" + ) logger.info("Application startup completed.") @@ -110,21 +101,14 @@ class Driver(BaseDriver): logger.info("Shutting down") logger.info("Waiting for application shutdown.") - # run shutdown - cors = [ - cast(Callable[..., Awaitable[None]], shutdown)() - if is_coroutine_callable(shutdown) - else run_sync(shutdown)() - for shutdown in self.shutdown_funcs - ] - if cors: - try: - await asyncio.gather(*cors) - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error when running shutdown function. " - "Ignored!" - ) + + try: + await self._lifespan.shutdown() + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running shutdown function. " + "Ignored!" + ) for task in asyncio.all_tasks(): if task is not asyncio.current_task() and not task.done(): diff --git a/nonebot/drivers/quart.py b/nonebot/drivers/quart.py index f16134e3..4c5d69cd 100644 --- a/nonebot/drivers/quart.py +++ b/nonebot/drivers/quart.py @@ -37,7 +37,7 @@ try: from quart import Quart, Request, Response from quart.datastructures import FileStorage from quart import Websocket as QuartWebSocket -except ImportError as e: # pragma: no cover +except ModuleNotFoundError as e: # pragma: no cover raise ImportError( "Please install Quart by using `pip install nonebot2[quart]`" ) from e diff --git a/nonebot/drivers/websockets.py b/nonebot/drivers/websockets.py index 808ff791..644fa063 100644 --- a/nonebot/drivers/websockets.py +++ b/nonebot/drivers/websockets.py @@ -30,7 +30,7 @@ from nonebot.drivers import ForwardMixin, ForwardDriver, combine_driver try: from websockets.exceptions import ConnectionClosed from websockets.legacy.client import Connect, WebSocketClientProtocol -except ImportError as e: # pragma: no cover +except ModuleNotFoundError as e: # pragma: no cover raise ImportError( "Please install websockets by using `pip install nonebot2[websockets]`" ) from e diff --git a/tests/test_driver.py b/tests/test_driver.py index d0b00c88..624cf00e 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -12,6 +12,7 @@ from nonebot.params import Depends from nonebot import _resolve_combine_expr from nonebot.dependencies import Dependent from nonebot.exception import WebSocketClosed +from nonebot.drivers._lifespan import Lifespan from nonebot.drivers import ( URL, Driver, @@ -36,6 +37,39 @@ def load_driver(request: pytest.FixtureRequest) -> Driver: return DriverClass(Env(environment=global_driver.env), global_driver.config) +@pytest.mark.asyncio +async def test_lifespan(): + lifespan = Lifespan() + + start_log = [] + shutdown_log = [] + + @lifespan.on_startup + async def _startup1(): + assert start_log == [] + start_log.append(1) + + @lifespan.on_startup + async def _startup2(): + assert start_log == [1] + start_log.append(2) + + @lifespan.on_shutdown + async def _shutdown1(): + assert shutdown_log == [] + shutdown_log.append(1) + + @lifespan.on_shutdown + async def _shutdown2(): + assert shutdown_log == [1] + shutdown_log.append(2) + + async with lifespan: + assert start_log == [1, 2] + + assert shutdown_log == [1, 2] + + @pytest.mark.asyncio @pytest.mark.parametrize( "driver",