Feature: 重构驱动器 lifespan 方法 (#1860)

This commit is contained in:
Ju4tCode
2023-03-29 15:59:54 +08:00
committed by GitHub
parent 0d0bc656c8
commit a8a76393a5
8 changed files with 126 additions and 52 deletions

View File

@ -13,7 +13,6 @@ FrontMatter:
import signal
import asyncio
import threading
from typing import Set, Union, Callable, Awaitable, cast
from nonebot.log import logger
from nonebot.consts import WINDOWS
@ -22,7 +21,8 @@ from nonebot.config import Env, Config
from nonebot.drivers import Driver as BaseDriver
from nonebot.utils import run_sync, is_coroutine_callable
HOOK_FUNC = Union[Callable[[], None], Callable[[], Awaitable[None]]]
from ._lifespan import LIFESPAN_FUNC, Lifespan
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
@ -36,8 +36,9 @@ class Driver(BaseDriver):
def __init__(self, env: Env, config: Config):
super().__init__(env, config)
self.startup_funcs: Set[HOOK_FUNC] = set()
self.shutdown_funcs: Set[HOOK_FUNC] = set()
self._lifespan = Lifespan()
self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False
@ -54,20 +55,18 @@ class Driver(BaseDriver):
return logger
@overrides(BaseDriver)
def on_startup(self, func: HOOK_FUNC) -> HOOK_FUNC:
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""
注册一个启动时执行的函数
"""
self.startup_funcs.add(func)
return func
return self._lifespan.on_startup(func)
@overrides(BaseDriver)
def on_shutdown(self, func: HOOK_FUNC) -> HOOK_FUNC:
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""
注册一个停止时执行的函数
"""
self.shutdown_funcs.add(func)
return func
return self._lifespan.on_shutdown(func)
@overrides(BaseDriver)
def run(self, *args, **kwargs):
@ -85,21 +84,13 @@ class Driver(BaseDriver):
await self._shutdown()
async def _startup(self):
# run startup
cors = [
cast(Callable[..., Awaitable[None]], startup)()
if is_coroutine_callable(startup)
else run_sync(startup)()
for startup in self.startup_funcs
]
if cors:
try:
await asyncio.gather(*cors)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running startup function. "
"Ignored!</bg #f8bbd0></r>"
)
try:
await self._lifespan.startup()
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running startup function. "
"Ignored!</bg #f8bbd0></r>"
)
logger.info("Application startup completed.")
@ -110,21 +101,14 @@ class Driver(BaseDriver):
logger.info("Shutting down")
logger.info("Waiting for application shutdown.")
# run shutdown
cors = [
cast(Callable[..., Awaitable[None]], shutdown)()
if is_coroutine_callable(shutdown)
else run_sync(shutdown)()
for shutdown in self.shutdown_funcs
]
if cors:
try:
await asyncio.gather(*cors)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>"
)
try:
await self._lifespan.shutdown()
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>"
)
for task in asyncio.all_tasks():
if task is not asyncio.current_task() and not task.done():