🐛 Fix: 新增 Lifespan._on_ready() 供适配器使用 (#2483)

Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Bryan不可思议
2023-12-10 18:12:10 +08:00
committed by GitHub
parent 915274081d
commit 8f3f385cb6
7 changed files with 48 additions and 67 deletions

View File

@ -34,8 +34,6 @@ from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import HTTPServerSetup, WebSocketServerSetup
from ._lifespan import LIFESPAN_FUNC, Lifespan
try:
import uvicorn
from fastapi.responses import Response
@ -97,8 +95,6 @@ class Driver(BaseDriver, ASGIMixin):
self.fastapi_config: Config = Config(**config.dict())
self._lifespan = Lifespan()
self._server_app = FastAPI(
lifespan=self._lifespan_manager,
openapi_url=self.fastapi_config.fastapi_openapi_url,
@ -155,14 +151,6 @@ class Driver(BaseDriver, ASGIMixin):
name=setup.name,
)
@override
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_startup(func)
@override
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_shutdown(func)
@contextlib.asynccontextmanager
async def _lifespan_manager(self, app: FastAPI):
await self._lifespan.startup()

View File

@ -19,8 +19,6 @@ from nonebot.consts import WINDOWS
from nonebot.config import Env, Config
from nonebot.drivers import Driver as BaseDriver
from ._lifespan import LIFESPAN_FUNC, Lifespan
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
@ -35,8 +33,6 @@ class Driver(BaseDriver):
def __init__(self, env: Env, config: Config):
super().__init__(env, config)
self._lifespan = Lifespan()
self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False
@ -52,16 +48,6 @@ class Driver(BaseDriver):
"""none driver 使用的 logger"""
return logger
@override
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)
@override
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个停止时执行的函数"""
return self._lifespan.on_shutdown(func)
@override
def run(self, *args, **kwargs):
"""启动 none driver"""

View File

@ -18,18 +18,7 @@ FrontMatter:
import asyncio
from functools import wraps
from typing_extensions import override
from typing import (
Any,
Dict,
List,
Tuple,
Union,
TypeVar,
Callable,
Optional,
Coroutine,
cast,
)
from typing import Any, Dict, List, Tuple, Union, Optional, cast
from pydantic import BaseSettings
@ -57,8 +46,6 @@ except ModuleNotFoundError as e: # pragma: no cover
"Install with pip: `pip install nonebot2[quart]`"
) from e
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
def catch_closed(func):
@wraps(func)
@ -102,6 +89,8 @@ class Driver(BaseDriver, ASGIMixin):
self._server_app = Quart(
self.__class__.__qualname__, **self.quart_config.quart_extra
)
self._server_app.before_serving(self._lifespan.startup)
self._server_app.after_serving(self._lifespan.shutdown)
@property
@override
@ -150,16 +139,6 @@ class Driver(BaseDriver, ASGIMixin):
view_func=_handle,
)
@override
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
return self.server_app.before_serving(func) # type: ignore
@override
def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
return self.server_app.after_serving(func) # type: ignore
@override
def run(
self,

View File

@ -3,6 +3,7 @@ from contextlib import asynccontextmanager
from typing import Any, Dict, AsyncGenerator
from nonebot.config import Config
from nonebot.internal.driver._lifespan import LIFESPAN_FUNC
from nonebot.internal.driver import (
Driver,
Request,
@ -97,6 +98,9 @@ class Adapter(abc.ABC):
async with self.driver.websocket(setup) as ws:
yield ws
def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self.driver._lifespan.on_ready(func)
@abc.abstractmethod
async def _call_api(self, bot: Bot, api: str, **data: Any) -> Any:
"""`Adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。

View File

@ -11,6 +11,7 @@ LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
class Lifespan:
def __init__(self) -> None:
self._startup_funcs: List[LIFESPAN_FUNC] = []
self._ready_funcs: List[LIFESPAN_FUNC] = []
self._shutdown_funcs: List[LIFESPAN_FUNC] = []
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
@ -21,6 +22,10 @@ class Lifespan:
self._shutdown_funcs.append(func)
return func
def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._ready_funcs.append(func)
return func
@staticmethod
async def _run_lifespan_func(
funcs: List[LIFESPAN_FUNC],
@ -35,6 +40,9 @@ class Lifespan:
if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs)
if self._ready_funcs:
await self._run_lifespan_func(self._ready_funcs)
async def shutdown(self) -> None:
if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs)

View File

@ -2,7 +2,7 @@ import abc
import asyncio
from typing_extensions import TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
from typing import TYPE_CHECKING, Any, Set, Dict, Type, AsyncGenerator
from nonebot.log import logger
from nonebot.config import Env, Config
@ -16,6 +16,7 @@ from nonebot.typing import (
T_BotDisconnectionHook,
)
from ._lifespan import LIFESPAN_FUNC, Lifespan
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup
if TYPE_CHECKING:
@ -49,6 +50,7 @@ class Driver(abc.ABC):
"""全局配置对象"""
self._bots: Dict[str, "Bot"] = {}
self._bot_tasks: Set[asyncio.Task] = set()
self._lifespan = Lifespan()
def __repr__(self) -> str:
return (
@ -100,15 +102,13 @@ class Driver(abc.ABC):
self.on_shutdown(self._cleanup)
@abc.abstractmethod
def on_startup(self, func: Callable) -> Callable:
"""注册一个在驱动器启动时执行的函数"""
raise NotImplementedError
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)
@abc.abstractmethod
def on_shutdown(self, func: Callable) -> Callable:
"""注册一个在驱动器停止时执行的函数"""
raise NotImplementedError
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个停止时执行的函数"""
return self._lifespan.on_shutdown(func)
@classmethod
def on_bot_connect(cls, func: T_BotConnectionHook) -> T_BotConnectionHook: