mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-31 15:06:42 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			354 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			354 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import abc
 | |
| from collections.abc import AsyncGenerator
 | |
| from contextlib import AsyncExitStack, asynccontextmanager
 | |
| from types import TracebackType
 | |
| from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
 | |
| from typing_extensions import Self, TypeAlias
 | |
| 
 | |
| from anyio import CancelScope, create_task_group
 | |
| from anyio.abc import TaskGroup
 | |
| from exceptiongroup import BaseExceptionGroup, catch
 | |
| 
 | |
| from nonebot.config import Config, Env
 | |
| from nonebot.dependencies import Dependent
 | |
| from nonebot.exception import SkippedException
 | |
| from nonebot.internal.params import BotParam, DefaultParam, DependParam
 | |
| from nonebot.log import logger
 | |
| from nonebot.typing import (
 | |
|     T_BotConnectionHook,
 | |
|     T_BotDisconnectionHook,
 | |
|     T_DependencyCache,
 | |
| )
 | |
| from nonebot.utils import escape_tag, flatten_exception_group, run_coro_with_catch
 | |
| 
 | |
| from ._lifespan import LIFESPAN_FUNC, Lifespan
 | |
| from .model import (
 | |
|     CookieTypes,
 | |
|     HeaderTypes,
 | |
|     HTTPServerSetup,
 | |
|     HTTPVersion,
 | |
|     QueryTypes,
 | |
|     Request,
 | |
|     Response,
 | |
|     WebSocket,
 | |
|     WebSocketServerSetup,
 | |
| )
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from nonebot.internal.adapter import Adapter, Bot
 | |
| 
 | |
| 
 | |
# Parameter kinds accepted when bot connect/disconnect hooks are parsed into
# ``Dependent`` objects (used as ``allow_types`` by ``Driver.on_bot_connect``
# and ``Driver.on_bot_disconnect``).
BOT_HOOK_PARAMS = [
    DependParam,
    BotParam,
    DefaultParam,
]
 | |
| 
 | |
| 
 | |
class Driver(abc.ABC):
    """Base class for drivers.

    A driver controls framework startup and shutdown, adapter registration,
    and bot lifecycle management.

    Parameters:
        env: Env object carrying environment information
        config: Config object carrying configuration
    """

    _adapters: ClassVar[dict[str, "Adapter"]] = {}
    """Registered adapters, keyed by adapter name (class-level, shared)."""
    _bot_connection_hook: ClassVar[set[Dependent[Any]]] = set()
    """Hooks run when a bot connection is established."""
    _bot_disconnection_hook: ClassVar[set[Dependent[Any]]] = set()
    """Hooks run when a bot connection is closed."""

    def __init__(self, env: Env, config: Config):
        self.env: str = env.environment
        """Environment name."""
        self.config: Config = config
        """Global configuration object."""
        self._bots: dict[str, "Bot"] = {}
        self._lifespan = Lifespan()

    def __repr__(self) -> str:
        return (
            f"Driver(type={self.type!r}, "
            f"adapters={len(self._adapters)}, bots={len(self._bots)})"
        )

    @property
    def bots(self) -> dict[str, "Bot"]:
        """All currently connected bots, keyed by ``self_id``."""
        return self._bots

    @property
    def task_group(self) -> TaskGroup:
        """Task group owned by the driver's lifespan; bot hooks are scheduled here."""
        return self._lifespan.task_group

    def register_adapter(self, adapter: type["Adapter"], **kwargs) -> None:
        """Register a protocol adapter.

        If an adapter with the same name is already registered, this is a
        no-op (a debug message is logged).

        Parameters:
            adapter: adapter class
            kwargs: extra arguments passed to the adapter constructor
        """
        name = adapter.get_name()
        if name in self._adapters:
            logger.opt(colors=True).debug(
                f'Adapter "<y>{escape_tag(name)}</y>" already exists'
            )
            return
        self._adapters[name] = adapter(self, **kwargs)
        logger.opt(colors=True).debug(
            f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"'
        )

    @property
    @abc.abstractmethod
    def type(self) -> str:
        """Driver type name."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def logger(self):
        """Driver-specific logger."""
        raise NotImplementedError

    @abc.abstractmethod
    def run(self, *args, **kwargs):
        """Start the driver framework.

        Subclasses override this; the base implementation only logs the
        names of the loaded adapters.
        """
        logger.opt(colors=True).success(
            f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
        )

    def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
        """Register a function to run at startup."""
        return self._lifespan.on_startup(func)

    def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
        """Register a function to run at shutdown."""
        return self._lifespan.on_shutdown(func)

    @classmethod
    def on_bot_connect(cls, func: T_BotConnectionHook) -> T_BotConnectionHook:
        """Decorate a function so it runs when a bot connects.

        Hook parameters:

        - bot: the Bot object that just connected
        """
        cls._bot_connection_hook.add(
            Dependent[Any].parse(call=func, allow_types=BOT_HOOK_PARAMS)
        )
        return func

    @classmethod
    def on_bot_disconnect(cls, func: T_BotDisconnectionHook) -> T_BotDisconnectionHook:
        """Decorate a function so it runs when a bot disconnects.

        Hook parameters:

        - bot: the Bot object that disconnected
        """
        cls._bot_disconnection_hook.add(
            Dependent[Any].parse(call=func, allow_types=BOT_HOOK_PARAMS)
        )
        return func

    def _bot_connect(self, bot: "Bot") -> None:
        """Register ``bot`` after its connection succeeds and schedule connect hooks.

        Raises:
            RuntimeError: if a bot with the same ``self_id`` is already registered.
        """
        if bot.self_id in self._bots:
            raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
        self._bots[bot.self_id] = bot

        if not self._bot_connection_hook:
            return

        self._schedule_bot_hooks(
            bot,
            self._bot_connection_hook,
            "Error when running WebSocketConnection hook:",
        )

    def _bot_disconnect(self, bot: "Bot") -> None:
        """Unregister ``bot`` after its connection closes and schedule disconnect hooks."""
        # Tolerate bots that were never (or are no longer) registered.
        self._bots.pop(bot.self_id, None)

        if not self._bot_disconnection_hook:
            return

        self._schedule_bot_hooks(
            bot,
            self._bot_disconnection_hook,
            "Error when running WebSocketDisConnection hook:",
        )

    def _schedule_bot_hooks(
        self, bot: "Bot", hooks: set[Dependent[Any]], error_message: str
    ) -> None:
        """Schedule all ``hooks`` for ``bot`` to run concurrently in ``task_group``.

        All hooks share one dependency cache and one exit stack. The run is
        shielded from outer cancellation so the hooks are not interrupted.
        ``SkippedException`` is handed to ``run_coro_with_catch`` as an exception
        to catch. Any other ``Exception`` is logged with ``error_message``.

        Parameters:
            bot: bot passed to every hook
            hooks: hook set to run; read when the scheduled task executes
            error_message: text of the error log line for failing hooks
        """

        def handle_exception(exc_group: BaseExceptionGroup) -> None:
            for exc in flatten_exception_group(exc_group):
                logger.opt(colors=True, exception=exc).error(
                    f"<r><bg #f8bbd0>{error_message}</bg #f8bbd0></r>"
                )

        async def _run_hook(bot: "Bot") -> None:
            dependency_cache: T_DependencyCache = {}
            # shield cancellation to ensure bot hooks are always run
            with CancelScope(shield=True), catch({Exception: handle_exception}):
                # The exit stack must be the OUTER context manager: the task
                # group exits first (waiting for every hook to finish), and only
                # then does the stack unwind the dependency contexts the hooks
                # entered. The reverse order closes the stack before the hooks
                # even start, leaking their cleanup.
                async with AsyncExitStack() as stack, create_task_group() as tg:
                    for hook in hooks:
                        tg.start_soon(
                            run_coro_with_catch,
                            hook(
                                bot=bot, stack=stack, dependency_cache=dependency_cache
                            ),
                            (SkippedException,),
                        )

        self.task_group.start_soon(_run_hook, bot)
 | |
| 
 | |
| 
 | |
class Mixin(abc.ABC):
    """Base class for mixins that can be combined with other drivers.

    Concrete mixins must provide the ``type`` property.
    """

    @property
    @abc.abstractmethod
    def type(self) -> str:
        """Type name of this mixin driver."""
        raise NotImplementedError
 | |
| 
 | |
| 
 | |
class ForwardMixin(Mixin):
    """Base class for client-side (forward connection) mixins."""
 | |
| 
 | |
| 
 | |
class ReverseMixin(Mixin):
    """Base class for server-side (reverse connection) mixins."""
 | |
| 
 | |
| 
 | |
class HTTPClientSession(abc.ABC):
    """Base class for HTTP client sessions.

    A session is an async context manager: entering it awaits ``setup()``
    and leaving it awaits ``close()``.
    """

    @abc.abstractmethod
    def __init__(
        self,
        params: QueryTypes = None,
        headers: HeaderTypes = None,
        cookies: CookieTypes = None,
        version: Union[str, HTTPVersion] = HTTPVersion.H11,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
    ):
        raise NotImplementedError

    # --- lifecycle -------------------------------------------------------

    @abc.abstractmethod
    async def setup(self) -> None:
        """Initialize the session."""
        raise NotImplementedError

    @abc.abstractmethod
    async def close(self) -> None:
        """Close the session."""
        raise NotImplementedError

    # --- requests --------------------------------------------------------

    @abc.abstractmethod
    async def request(self, setup: Request) -> Response:
        """Send an HTTP request."""
        raise NotImplementedError

    # --- async context manager protocol ---------------------------------

    async def __aenter__(self) -> Self:
        await self.setup()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        # Close unconditionally; returning None lets any exception propagate.
        await self.close()
 | |
| 
 | |
| 
 | |
class HTTPClientMixin(ForwardMixin):
    """Base class for HTTP client mixins.

    Concrete drivers implement both one-off requests (``request``) and
    reusable sessions (``get_session``).
    """

    @abc.abstractmethod
    def get_session(
        self,
        params: QueryTypes = None,
        headers: HeaderTypes = None,
        cookies: CookieTypes = None,
        version: Union[str, HTTPVersion] = HTTPVersion.H11,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
    ) -> HTTPClientSession:
        """Create an HTTP session."""
        raise NotImplementedError

    @abc.abstractmethod
    async def request(self, setup: Request) -> Response:
        """Send a single HTTP request."""
        raise NotImplementedError
 | |
| 
 | |
| 
 | |
class WebSocketClientMixin(ForwardMixin):
    """Base class for WebSocket client mixins."""

    @abc.abstractmethod
    @asynccontextmanager
    async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
        """Open a WebSocket connection (used as an async context manager)."""
        raise NotImplementedError
        # Unreachable, but makes this function an async generator so that
        # static type checkers and ``asynccontextmanager`` treat it correctly.
        yield
 | |
| 
 | |
| 
 | |
class ASGIMixin(ReverseMixin):
    """Base class for ASGI servers.

    Wraps a backend web framework so that adapters can register HTTP and
    WebSocket routes on it.
    """

    @abc.abstractmethod
    def setup_http_server(self, setup: "HTTPServerSetup") -> None:
        """Register an HTTP server route."""
        raise NotImplementedError

    @abc.abstractmethod
    def setup_websocket_server(self, setup: "WebSocketServerSetup") -> None:
        """Register a WebSocket server route."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def server_app(self) -> Any:
        """The driver's app object."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def asgi(self) -> Any:
        """The driver's ASGI object."""
        raise NotImplementedError
 | |
| 
 | |
| 
 | |
# Backward-compatible aliases kept for older code; prefer the mixins.
ForwardDriver: TypeAlias = ForwardMixin
"""Driver that supports client-side requests.

**Deprecated**: use {ref}`nonebot.drivers.ForwardMixin` or one of its subclasses instead.
"""

ReverseDriver: TypeAlias = ReverseMixin
"""Driver that supports server-side requests.

**Deprecated**: use {ref}`nonebot.drivers.ReverseMixin` or one of its subclasses instead.
"""
 |