mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-09-06 20:16:47 +00:00
♻️ reorganize class and add bot hook di
This commit is contained in:
@ -7,259 +7,32 @@ FrontMatter:
|
||||
description: nonebot.drivers 模块
|
||||
"""
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Set,
|
||||
Dict,
|
||||
Type,
|
||||
Callable,
|
||||
Awaitable,
|
||||
AsyncGenerator,
|
||||
)
|
||||
from nonebot.internal.model import URL as URL
|
||||
from nonebot.internal.driver import Driver as Driver
|
||||
from nonebot.internal.model import Cookies as Cookies
|
||||
from nonebot.internal.model import Request as Request
|
||||
from nonebot.internal.model import Response as Response
|
||||
from nonebot.internal.model import WebSocket as WebSocket
|
||||
from nonebot.internal.model import HTTPVersion as HTTPVersion
|
||||
from nonebot.internal.driver import ForwardMixin as ForwardMixin
|
||||
from nonebot.internal.driver import ForwardDriver as ForwardDriver
|
||||
from nonebot.internal.driver import ReverseDriver as ReverseDriver
|
||||
from nonebot.internal.driver import combine_driver as combine_driver
|
||||
from nonebot.internal.model import HTTPServerSetup as HTTPServerSetup
|
||||
from nonebot.internal.model import WebSocketServerSetup as WebSocketServerSetup
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.utils import escape_tag
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
|
||||
|
||||
from ._model import URL as URL
|
||||
from ._model import Request as Request
|
||||
from ._model import Response as Response
|
||||
from ._model import WebSocket as WebSocket
|
||||
from ._model import HTTPVersion as HTTPVersion
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.adapters import Bot, Adapter
|
||||
|
||||
|
||||
class Driver(abc.ABC):
|
||||
"""Driver 基类。
|
||||
|
||||
参数:
|
||||
env: 包含环境信息的 Env 对象
|
||||
config: 包含配置信息的 Config 对象
|
||||
"""
|
||||
|
||||
_adapters: Dict[str, "Adapter"] = {}
|
||||
"""已注册的适配器列表"""
|
||||
_bot_connection_hook: Set[T_BotConnectionHook] = set()
|
||||
"""Bot 连接建立时执行的函数"""
|
||||
_bot_disconnection_hook: Set[T_BotDisconnectionHook] = set()
|
||||
"""Bot 连接断开时执行的函数"""
|
||||
|
||||
def __init__(self, env: Env, config: Config):
|
||||
self.env: str = env.environment
|
||||
"""环境名称"""
|
||||
self.config: Config = config
|
||||
"""全局配置对象"""
|
||||
self._clients: Dict[str, "Bot"] = {}
|
||||
|
||||
@property
|
||||
def bots(self) -> Dict[str, "Bot"]:
|
||||
"""获取当前所有已连接的 Bot"""
|
||||
return self._clients
|
||||
|
||||
def register_adapter(self, adapter: Type["Adapter"], **kwargs) -> None:
|
||||
"""注册一个协议适配器
|
||||
|
||||
参数:
|
||||
adapter: 适配器类
|
||||
kwargs: 其他传递给适配器的参数
|
||||
"""
|
||||
name = adapter.get_name()
|
||||
if name in self._adapters:
|
||||
logger.opt(colors=True).debug(
|
||||
f'Adapter "<y>{escape_tag(name)}</y>" already exists'
|
||||
)
|
||||
return
|
||||
self._adapters[name] = adapter(self, **kwargs)
|
||||
logger.opt(colors=True).debug(
|
||||
f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"'
|
||||
)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def type(self) -> str:
|
||||
"""驱动类型名称"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def logger(self):
|
||||
"""驱动专属 logger 日志记录器"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self, *args, **kwargs):
|
||||
"""
|
||||
启动驱动框架
|
||||
"""
|
||||
logger.opt(colors=True).debug(
|
||||
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_startup(self, func: Callable) -> Callable:
|
||||
"""注册一个在驱动器启动时执行的函数"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_shutdown(self, func: Callable) -> Callable:
|
||||
"""注册一个在驱动器停止时执行的函数"""
|
||||
raise NotImplementedError
|
||||
|
||||
def on_bot_connect(self, func: T_BotConnectionHook) -> T_BotConnectionHook:
|
||||
"""装饰一个函数使他在 bot 连接成功时执行。
|
||||
|
||||
钩子函数参数:
|
||||
|
||||
- bot: 当前连接上的 Bot 对象
|
||||
"""
|
||||
self._bot_connection_hook.add(func)
|
||||
return func
|
||||
|
||||
def on_bot_disconnect(self, func: T_BotDisconnectionHook) -> T_BotDisconnectionHook:
|
||||
"""装饰一个函数使他在 bot 连接断开时执行。
|
||||
|
||||
钩子函数参数:
|
||||
|
||||
- bot: 当前连接上的 Bot 对象
|
||||
"""
|
||||
self._bot_disconnection_hook.add(func)
|
||||
return func
|
||||
|
||||
def _bot_connect(self, bot: "Bot") -> None:
|
||||
"""在连接成功后,调用该函数来注册 bot 对象"""
|
||||
if bot.self_id in self._clients:
|
||||
raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
|
||||
self._clients[bot.self_id] = bot
|
||||
|
||||
async def _run_hook(bot: "Bot") -> None:
|
||||
coros = list(map(lambda x: x(bot), self._bot_connection_hook))
|
||||
if coros:
|
||||
try:
|
||||
await asyncio.gather(*coros)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running WebSocketConnection hook. "
|
||||
"Running cancelled!</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
asyncio.create_task(_run_hook(bot))
|
||||
|
||||
def _bot_disconnect(self, bot: "Bot") -> None:
|
||||
"""在连接断开后,调用该函数来注销 bot 对象"""
|
||||
if bot.self_id in self._clients:
|
||||
del self._clients[bot.self_id]
|
||||
|
||||
async def _run_hook(bot: "Bot") -> None:
|
||||
coros = list(map(lambda x: x(bot), self._bot_disconnection_hook))
|
||||
if coros:
|
||||
try:
|
||||
await asyncio.gather(*coros)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running WebSocketDisConnection hook. "
|
||||
"Running cancelled!</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
asyncio.create_task(_run_hook(bot))
|
||||
|
||||
|
||||
class ForwardMixin(abc.ABC):
|
||||
"""客户端混入基类。"""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def type(self) -> str:
|
||||
"""客户端驱动类型名称"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def request(self, setup: Request) -> Response:
|
||||
"""发送一个 HTTP 请求"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
@asynccontextmanager
|
||||
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
|
||||
"""发起一个 WebSocket 连接"""
|
||||
raise NotImplementedError
|
||||
yield # used for static type checking's generator detection
|
||||
|
||||
|
||||
class ForwardDriver(Driver, ForwardMixin):
|
||||
"""客户端基类。将客户端框架封装,以满足适配器使用。"""
|
||||
|
||||
|
||||
class ReverseDriver(Driver):
|
||||
"""服务端基类。将后端框架封装,以满足适配器使用。"""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def server_app(self) -> Any:
|
||||
"""驱动 APP 对象"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def asgi(self) -> Any:
|
||||
"""驱动 ASGI 对象"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def setup_http_server(self, setup: "HTTPServerSetup") -> None:
|
||||
"""设置一个 HTTP 服务器路由配置"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def setup_websocket_server(self, setup: "WebSocketServerSetup") -> None:
|
||||
"""设置一个 WebSocket 服务器路由配置"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
|
||||
"""将一个驱动器和多个混入类合并。"""
|
||||
# check first
|
||||
assert issubclass(driver, Driver), "`driver` must be subclass of Driver"
|
||||
assert all(
|
||||
map(lambda m: issubclass(m, ForwardMixin), mixins)
|
||||
), "`mixins` must be subclass of ForwardMixin"
|
||||
|
||||
if not mixins:
|
||||
return driver
|
||||
|
||||
class CombinedDriver(*mixins, driver, ForwardDriver): # type: ignore
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return (
|
||||
driver.type.__get__(self)
|
||||
+ "+"
|
||||
+ "+".join(map(lambda x: x.type.__get__(self), mixins))
|
||||
)
|
||||
|
||||
return CombinedDriver
|
||||
|
||||
|
||||
@dataclass
|
||||
class HTTPServerSetup:
|
||||
"""HTTP 服务器路由配置。"""
|
||||
|
||||
path: URL # path should not be absolute, check it by URL.is_absolute() == False
|
||||
method: str
|
||||
name: str
|
||||
handle_func: Callable[[Request], Awaitable[Response]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocketServerSetup:
|
||||
"""WebSocket 服务器路由配置。"""
|
||||
|
||||
path: URL # path should not be absolute, check it by URL.is_absolute() == False
|
||||
name: str
|
||||
handle_func: Callable[[WebSocket], Awaitable[Any]]
|
||||
__autodoc__ = {
|
||||
"URL": True,
|
||||
"Driver": True,
|
||||
"Cookies": True,
|
||||
"Request": True,
|
||||
"Response": True,
|
||||
"WebSocket": True,
|
||||
"HTTPVersion": True,
|
||||
"ForwardMixin": True,
|
||||
"ForwardDriver": True,
|
||||
"ReverseDriver": True,
|
||||
"combine_driver": True,
|
||||
"HTTPServerSetup": True,
|
||||
"WebSocketServerSetup": True,
|
||||
}
|
||||
|
Reference in New Issue
Block a user