mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-08-01 02:30:16 +00:00
✨ Feature: 细化 driver 职责类型 (#2296)
This commit is contained in:
@ -1,7 +1,19 @@
|
||||
import abc
|
||||
import asyncio
|
||||
from typing_extensions import TypeAlias
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Set,
|
||||
Dict,
|
||||
Type,
|
||||
Union,
|
||||
TypeVar,
|
||||
Callable,
|
||||
AsyncGenerator,
|
||||
overload,
|
||||
)
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.config import Env, Config
|
||||
@ -21,11 +33,15 @@ if TYPE_CHECKING:
|
||||
from nonebot.internal.adapter import Bot, Adapter
|
||||
|
||||
|
||||
D = TypeVar("D", bound="Driver")
|
||||
|
||||
BOT_HOOK_PARAMS = [DependParam, BotParam, DefaultParam]
|
||||
|
||||
|
||||
class Driver(abc.ABC):
|
||||
"""Driver 基类。
|
||||
"""驱动器基类。
|
||||
|
||||
驱动器控制框架的启动和停止,适配器的注册,以及机器人生命周期管理。
|
||||
|
||||
参数:
|
||||
env: 包含环境信息的 Env 对象
|
||||
@ -45,6 +61,7 @@ class Driver(abc.ABC):
|
||||
self.config: Config = config
|
||||
"""全局配置对象"""
|
||||
self._bots: Dict[str, "Bot"] = {}
|
||||
self._bot_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
@ -94,6 +111,8 @@ class Driver(abc.ABC):
|
||||
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
|
||||
)
|
||||
|
||||
self.on_shutdown(self._cleanup)
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_startup(self, func: Callable) -> Callable:
|
||||
"""注册一个在驱动器启动时执行的函数"""
|
||||
@ -156,7 +175,9 @@ class Driver(abc.ABC):
|
||||
"</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
asyncio.create_task(_run_hook(bot))
|
||||
task = asyncio.create_task(_run_hook(bot))
|
||||
task.add_done_callback(self._bot_tasks.discard)
|
||||
self._bot_tasks.add(task)
|
||||
|
||||
def _bot_disconnect(self, bot: "Bot") -> None:
|
||||
"""在连接断开后,调用该函数来注销 bot 对象"""
|
||||
@ -183,23 +204,49 @@ class Driver(abc.ABC):
|
||||
"</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
asyncio.create_task(_run_hook(bot))
|
||||
task = asyncio.create_task(_run_hook(bot))
|
||||
task.add_done_callback(self._bot_tasks.discard)
|
||||
self._bot_tasks.add(task)
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
"""清理驱动器资源"""
|
||||
if self._bot_tasks:
|
||||
logger.opt(colors=True).debug(
|
||||
"<y>Waiting for running bot connection hooks...</y>"
|
||||
)
|
||||
await asyncio.gather(*self._bot_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
class ForwardMixin(abc.ABC):
|
||||
"""客户端混入基类。"""
|
||||
class Mixin(abc.ABC):
|
||||
"""可与其他驱动器共用的混入基类。"""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def type(self) -> str:
|
||||
"""客户端驱动类型名称"""
|
||||
"""混入驱动类型名称"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ForwardMixin(Mixin):
|
||||
"""客户端混入基类。"""
|
||||
|
||||
|
||||
class ReverseMixin(Mixin):
|
||||
"""服务端混入基类。"""
|
||||
|
||||
|
||||
class HTTPClientMixin(ForwardMixin):
|
||||
"""HTTP 客户端混入基类。"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def request(self, setup: Request) -> Response:
|
||||
"""发送一个 HTTP 请求"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class WebSocketClientMixin(ForwardMixin):
|
||||
"""WebSocket 客户端混入基类。"""
|
||||
|
||||
@abc.abstractmethod
|
||||
@asynccontextmanager
|
||||
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
|
||||
@ -208,12 +255,11 @@ class ForwardMixin(abc.ABC):
|
||||
yield # used for static type checking's generator detection
|
||||
|
||||
|
||||
class ForwardDriver(Driver, ForwardMixin):
|
||||
"""客户端基类。将客户端框架封装,以满足适配器使用。"""
|
||||
class ASGIMixin(ReverseMixin):
|
||||
"""ASGI 服务端基类。
|
||||
|
||||
|
||||
class ReverseDriver(Driver):
|
||||
"""服务端基类。将后端框架封装,以满足适配器使用。"""
|
||||
将后端框架封装,以满足适配器使用。
|
||||
"""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
@ -238,18 +284,49 @@ class ReverseDriver(Driver):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
|
||||
ForwardDriver: TypeAlias = ForwardMixin
|
||||
"""支持客户端请求的驱动器。
|
||||
|
||||
**Deprecated**,请使用 {ref}`nonebot.drivers.ForwardMixin` 或其子类代替。
|
||||
"""
|
||||
|
||||
ReverseDriver: TypeAlias = ReverseMixin
|
||||
"""支持服务端请求的驱动器。
|
||||
|
||||
**Deprecated**,请使用 {ref}`nonebot.drivers.ReverseMixin` 或其子类代替。
|
||||
"""
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class CombinedDriver(Driver, Mixin):
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def combine_driver(driver: Type[D]) -> Type[D]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def combine_driver(driver: Type[D], *mixins: Type[Mixin]) -> Type["CombinedDriver"]:
|
||||
...
|
||||
|
||||
|
||||
def combine_driver(
|
||||
driver: Type[D], *mixins: Type[Mixin]
|
||||
) -> Union[Type[D], Type["CombinedDriver"]]:
|
||||
"""将一个驱动器和多个混入类合并。"""
|
||||
# check first
|
||||
assert issubclass(driver, Driver), "`driver` must be subclass of Driver"
|
||||
assert all(
|
||||
issubclass(m, ForwardMixin) for m in mixins
|
||||
), "`mixins` must be subclass of ForwardMixin"
|
||||
issubclass(m, Mixin) for m in mixins
|
||||
), "`mixins` must be subclass of Mixin"
|
||||
|
||||
if not mixins:
|
||||
return driver
|
||||
|
||||
def type_(self: ForwardDriver) -> str:
|
||||
def type_(self: "CombinedDriver") -> str:
|
||||
return (
|
||||
driver.type.__get__(self)
|
||||
+ "+"
|
||||
@ -257,5 +334,5 @@ def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Dr
|
||||
)
|
||||
|
||||
return type(
|
||||
"CombinedDriver", (*mixins, driver, ForwardDriver), {"type": property(type_)}
|
||||
"CombinedDriver", (*mixins, driver), {"type": property(type_)}
|
||||
) # type: ignore
|
||||
|
Reference in New Issue
Block a user