mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-31 15:06:42 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			101 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			101 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from collections.abc import Awaitable, Iterable
 | |
| from types import TracebackType
 | |
| from typing import Any, Callable, Optional, Union, cast
 | |
| from typing_extensions import TypeAlias
 | |
| 
 | |
| import anyio
 | |
| from anyio.abc import TaskGroup
 | |
| from exceptiongroup import suppress
 | |
| 
 | |
| from nonebot.utils import is_coroutine_callable, run_sync
 | |
| 
 | |
| SYNC_LIFESPAN_FUNC: TypeAlias = Callable[[], Any]
 | |
| ASYNC_LIFESPAN_FUNC: TypeAlias = Callable[[], Awaitable[Any]]
 | |
| LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
 | |
| 
 | |
| 
 | |
| class Lifespan:
 | |
|     def __init__(self) -> None:
 | |
|         self._task_group: Optional[TaskGroup] = None
 | |
| 
 | |
|         self._startup_funcs: list[LIFESPAN_FUNC] = []
 | |
|         self._ready_funcs: list[LIFESPAN_FUNC] = []
 | |
|         self._shutdown_funcs: list[LIFESPAN_FUNC] = []
 | |
| 
 | |
|     @property
 | |
|     def task_group(self) -> TaskGroup:
 | |
|         if self._task_group is None:
 | |
|             raise RuntimeError("Lifespan not started")
 | |
|         return self._task_group
 | |
| 
 | |
|     @task_group.setter
 | |
|     def task_group(self, task_group: TaskGroup) -> None:
 | |
|         if self._task_group is not None:
 | |
|             raise RuntimeError("Lifespan already started")
 | |
|         self._task_group = task_group
 | |
| 
 | |
|     def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
 | |
|         self._startup_funcs.append(func)
 | |
|         return func
 | |
| 
 | |
|     def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
 | |
|         self._shutdown_funcs.append(func)
 | |
|         return func
 | |
| 
 | |
|     def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
 | |
|         self._ready_funcs.append(func)
 | |
|         return func
 | |
| 
 | |
|     @staticmethod
 | |
|     async def _run_lifespan_func(
 | |
|         funcs: Iterable[LIFESPAN_FUNC],
 | |
|     ) -> None:
 | |
|         for func in funcs:
 | |
|             if is_coroutine_callable(func):
 | |
|                 await cast(ASYNC_LIFESPAN_FUNC, func)()
 | |
|             else:
 | |
|                 await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()
 | |
| 
 | |
|     async def startup(self) -> None:
 | |
|         # create background task group
 | |
|         self.task_group = anyio.create_task_group()
 | |
|         await self.task_group.__aenter__()
 | |
| 
 | |
|         # run startup funcs
 | |
|         if self._startup_funcs:
 | |
|             await self._run_lifespan_func(self._startup_funcs)
 | |
| 
 | |
|         # run ready funcs
 | |
|         if self._ready_funcs:
 | |
|             await self._run_lifespan_func(self._ready_funcs)
 | |
| 
 | |
|     async def shutdown(
 | |
|         self,
 | |
|         *,
 | |
|         exc_type: Optional[type[BaseException]] = None,
 | |
|         exc_val: Optional[BaseException] = None,
 | |
|         exc_tb: Optional[TracebackType] = None,
 | |
|     ) -> None:
 | |
|         if self._shutdown_funcs:
 | |
|             # reverse shutdown funcs to ensure stack order
 | |
|             await self._run_lifespan_func(reversed(self._shutdown_funcs))
 | |
| 
 | |
|         # shutdown background task group
 | |
|         self.task_group.cancel_scope.cancel()
 | |
| 
 | |
|         with suppress(Exception):
 | |
|             await self.task_group.__aexit__(exc_type, exc_val, exc_tb)
 | |
| 
 | |
|         self._task_group = None
 | |
| 
 | |
|     async def __aenter__(self) -> None:
 | |
|         await self.startup()
 | |
| 
 | |
|     async def __aexit__(
 | |
|         self,
 | |
|         exc_type: Optional[type[BaseException]],
 | |
|         exc_val: Optional[BaseException],
 | |
|         exc_tb: Optional[TracebackType],
 | |
|     ) -> None:
 | |
|         await self.shutdown(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)
 |