Feature: 细化 driver 职责类型 (#2296)

This commit is contained in:
Ju4tCode
2023-08-26 11:03:24 +08:00
committed by GitHub
parent 807a86371d
commit 2e635370bb
20 changed files with 632 additions and 284 deletions

View File

@ -1,4 +1,5 @@
from .model import URL as URL
from .driver import Mixin as Mixin
from .model import RawURL as RawURL
from .driver import Driver as Driver
from .model import Cookies as Cookies
@ -8,6 +9,7 @@ from .model import Response as Response
from .model import DataTypes as DataTypes
from .model import FileTypes as FileTypes
from .model import WebSocket as WebSocket
from .driver import ASGIMixin as ASGIMixin
from .model import FilesTypes as FilesTypes
from .model import QueryTypes as QueryTypes
from .model import CookieTypes as CookieTypes
@ -17,9 +19,12 @@ from .model import HeaderTypes as HeaderTypes
from .model import SimpleQuery as SimpleQuery
from .model import ContentTypes as ContentTypes
from .driver import ForwardMixin as ForwardMixin
from .driver import ReverseMixin as ReverseMixin
from .model import QueryVariable as QueryVariable
from .driver import ForwardDriver as ForwardDriver
from .driver import ReverseDriver as ReverseDriver
from .driver import combine_driver as combine_driver
from .model import HTTPServerSetup as HTTPServerSetup
from .driver import HTTPClientMixin as HTTPClientMixin
from .model import WebSocketServerSetup as WebSocketServerSetup
from .driver import WebSocketClientMixin as WebSocketClientMixin

View File

@ -1,7 +1,19 @@
import abc
import asyncio
from typing_extensions import TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
from typing import (
TYPE_CHECKING,
Any,
Set,
Dict,
Type,
Union,
TypeVar,
Callable,
AsyncGenerator,
overload,
)
from nonebot.log import logger
from nonebot.config import Env, Config
@ -21,11 +33,15 @@ if TYPE_CHECKING:
from nonebot.internal.adapter import Bot, Adapter
D = TypeVar("D", bound="Driver")
BOT_HOOK_PARAMS = [DependParam, BotParam, DefaultParam]
class Driver(abc.ABC):
"""Driver 基类。
"""驱动器基类。
驱动器控制框架的启动和停止,适配器的注册,以及机器人生命周期管理。
参数:
env: 包含环境信息的 Env 对象
@ -45,6 +61,7 @@ class Driver(abc.ABC):
self.config: Config = config
"""全局配置对象"""
self._bots: Dict[str, "Bot"] = {}
self._bot_tasks: Set[asyncio.Task] = set()
def __repr__(self) -> str:
return (
@ -94,6 +111,8 @@ class Driver(abc.ABC):
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
)
self.on_shutdown(self._cleanup)
@abc.abstractmethod
def on_startup(self, func: Callable) -> Callable:
"""注册一个在驱动器启动时执行的函数"""
@ -156,7 +175,9 @@ class Driver(abc.ABC):
"</bg #f8bbd0></r>"
)
asyncio.create_task(_run_hook(bot))
task = asyncio.create_task(_run_hook(bot))
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
def _bot_disconnect(self, bot: "Bot") -> None:
"""在连接断开后,调用该函数来注销 bot 对象"""
@ -183,23 +204,49 @@ class Driver(abc.ABC):
"</bg #f8bbd0></r>"
)
asyncio.create_task(_run_hook(bot))
task = asyncio.create_task(_run_hook(bot))
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
async def _cleanup(self) -> None:
"""清理驱动器资源"""
if self._bot_tasks:
logger.opt(colors=True).debug(
"<y>Waiting for running bot connection hooks...</y>"
)
await asyncio.gather(*self._bot_tasks, return_exceptions=True)
class ForwardMixin(abc.ABC):
"""客户端混入基类。"""
class Mixin(abc.ABC):
"""可与其他驱动器共用的混入基类。"""
@property
@abc.abstractmethod
def type(self) -> str:
"""客户端驱动类型名称"""
"""混入驱动类型名称"""
raise NotImplementedError
class ForwardMixin(Mixin):
"""客户端混入基类。"""
class ReverseMixin(Mixin):
"""服务端混入基类。"""
class HTTPClientMixin(ForwardMixin):
"""HTTP 客户端混入基类。"""
@abc.abstractmethod
async def request(self, setup: Request) -> Response:
"""发送一个 HTTP 请求"""
raise NotImplementedError
class WebSocketClientMixin(ForwardMixin):
"""WebSocket 客户端混入基类。"""
@abc.abstractmethod
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
@ -208,12 +255,11 @@ class ForwardMixin(abc.ABC):
yield # used for static type checking's generator detection
class ForwardDriver(Driver, ForwardMixin):
"""客户端基类。将客户端框架封装,以满足适配器使用。"""
class ASGIMixin(ReverseMixin):
"""ASGI 服务端基类。
class ReverseDriver(Driver):
"""服务端基类。将后端框架封装,以满足适配器使用。"""
将后端框架封装,以满足适配器使用。
"""
@property
@abc.abstractmethod
@ -238,18 +284,49 @@ class ReverseDriver(Driver):
raise NotImplementedError
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
ForwardDriver: TypeAlias = ForwardMixin
"""支持客户端请求的驱动器。
**Deprecated**,请使用 {ref}`nonebot.drivers.ForwardMixin` 或其子类代替。
"""
ReverseDriver: TypeAlias = ReverseMixin
"""支持服务端请求的驱动器。
**Deprecated**,请使用 {ref}`nonebot.drivers.ReverseMixin` 或其子类代替。
"""
if TYPE_CHECKING:
class CombinedDriver(Driver, Mixin):
...
@overload
def combine_driver(driver: Type[D]) -> Type[D]:
...
@overload
def combine_driver(driver: Type[D], *mixins: Type[Mixin]) -> Type["CombinedDriver"]:
...
def combine_driver(
driver: Type[D], *mixins: Type[Mixin]
) -> Union[Type[D], Type["CombinedDriver"]]:
"""将一个驱动器和多个混入类合并。"""
# check first
assert issubclass(driver, Driver), "`driver` must be subclass of Driver"
assert all(
issubclass(m, ForwardMixin) for m in mixins
), "`mixins` must be subclass of ForwardMixin"
issubclass(m, Mixin) for m in mixins
), "`mixins` must be subclass of Mixin"
if not mixins:
return driver
def type_(self: ForwardDriver) -> str:
def type_(self: "CombinedDriver") -> str:
return (
driver.type.__get__(self)
+ "+"
@ -257,5 +334,5 @@ def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Dr
)
return type(
"CombinedDriver", (*mixins, driver, ForwardDriver), {"type": property(type_)}
"CombinedDriver", (*mixins, driver), {"type": property(type_)}
) # type: ignore