Feature: 细化 driver 职责类型 (#2296)

This commit is contained in:
Ju4tCode
2023-08-26 11:03:24 +08:00
committed by GitHub
parent 807a86371d
commit 2e635370bb
20 changed files with 632 additions and 284 deletions

View File

@ -1,7 +1,19 @@
import abc
import asyncio
from typing_extensions import TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
from typing import (
TYPE_CHECKING,
Any,
Set,
Dict,
Type,
Union,
TypeVar,
Callable,
AsyncGenerator,
overload,
)
from nonebot.log import logger
from nonebot.config import Env, Config
@ -21,11 +33,15 @@ if TYPE_CHECKING:
from nonebot.internal.adapter import Bot, Adapter
D = TypeVar("D", bound="Driver")
BOT_HOOK_PARAMS = [DependParam, BotParam, DefaultParam]
class Driver(abc.ABC):
"""Driver 基类。
"""驱动器基类。
驱动器控制框架的启动和停止,适配器的注册,以及机器人生命周期管理。
参数:
env: 包含环境信息的 Env 对象
@ -45,6 +61,7 @@ class Driver(abc.ABC):
self.config: Config = config
"""全局配置对象"""
self._bots: Dict[str, "Bot"] = {}
self._bot_tasks: Set[asyncio.Task] = set()
def __repr__(self) -> str:
return (
@ -94,6 +111,8 @@ class Driver(abc.ABC):
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
)
self.on_shutdown(self._cleanup)
@abc.abstractmethod
def on_startup(self, func: Callable) -> Callable:
"""注册一个在驱动器启动时执行的函数"""
@ -156,7 +175,9 @@ class Driver(abc.ABC):
"</bg #f8bbd0></r>"
)
asyncio.create_task(_run_hook(bot))
task = asyncio.create_task(_run_hook(bot))
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
def _bot_disconnect(self, bot: "Bot") -> None:
"""在连接断开后,调用该函数来注销 bot 对象"""
@ -183,23 +204,49 @@ class Driver(abc.ABC):
"</bg #f8bbd0></r>"
)
asyncio.create_task(_run_hook(bot))
task = asyncio.create_task(_run_hook(bot))
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
async def _cleanup(self) -> None:
"""清理驱动器资源"""
if self._bot_tasks:
logger.opt(colors=True).debug(
"<y>Waiting for running bot connection hooks...</y>"
)
await asyncio.gather(*self._bot_tasks, return_exceptions=True)
class ForwardMixin(abc.ABC):
"""客户端混入基类。"""
class Mixin(abc.ABC):
"""可与其他驱动器共用的混入基类。"""
@property
@abc.abstractmethod
def type(self) -> str:
"""客户端驱动类型名称"""
"""混入驱动类型名称"""
raise NotImplementedError
class ForwardMixin(Mixin):
"""客户端混入基类。"""
class ReverseMixin(Mixin):
"""服务端混入基类。"""
class HTTPClientMixin(ForwardMixin):
"""HTTP 客户端混入基类。"""
@abc.abstractmethod
async def request(self, setup: Request) -> Response:
"""发送一个 HTTP 请求"""
raise NotImplementedError
class WebSocketClientMixin(ForwardMixin):
"""WebSocket 客户端混入基类。"""
@abc.abstractmethod
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
@ -208,12 +255,11 @@ class ForwardMixin(abc.ABC):
yield # used for static type checking's generator detection
class ForwardDriver(Driver, ForwardMixin):
"""客户端基类。将客户端框架封装,以满足适配器使用。"""
class ASGIMixin(ReverseMixin):
"""ASGI 服务端基类。
class ReverseDriver(Driver):
"""服务端基类。将后端框架封装,以满足适配器使用。"""
将后端框架封装,以满足适配器使用。
"""
@property
@abc.abstractmethod
@ -238,18 +284,49 @@ class ReverseDriver(Driver):
raise NotImplementedError
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
ForwardDriver: TypeAlias = ForwardMixin
"""支持客户端请求的驱动器。
**Deprecated**,请使用 {ref}`nonebot.drivers.ForwardMixin` 或其子类代替。
"""
ReverseDriver: TypeAlias = ReverseMixin
"""支持服务端请求的驱动器。
**Deprecated**,请使用 {ref}`nonebot.drivers.ReverseMixin` 或其子类代替。
"""
if TYPE_CHECKING:
class CombinedDriver(Driver, Mixin):
...
@overload
def combine_driver(driver: Type[D]) -> Type[D]:
...
@overload
def combine_driver(driver: Type[D], *mixins: Type[Mixin]) -> Type["CombinedDriver"]:
...
def combine_driver(
driver: Type[D], *mixins: Type[Mixin]
) -> Union[Type[D], Type["CombinedDriver"]]:
"""将一个驱动器和多个混入类合并。"""
# check first
assert issubclass(driver, Driver), "`driver` must be subclass of Driver"
assert all(
issubclass(m, ForwardMixin) for m in mixins
), "`mixins` must be subclass of ForwardMixin"
issubclass(m, Mixin) for m in mixins
), "`mixins` must be subclass of Mixin"
if not mixins:
return driver
def type_(self: ForwardDriver) -> str:
def type_(self: "CombinedDriver") -> str:
return (
driver.type.__get__(self)
+ "+"
@ -257,5 +334,5 @@ def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Dr
)
return type(
"CombinedDriver", (*mixins, driver, ForwardDriver), {"type": property(type_)}
"CombinedDriver", (*mixins, driver), {"type": property(type_)}
) # type: ignore