♻️ rewrite adapter abc class

This commit is contained in:
yanyongyu
2021-12-06 22:19:05 +08:00
parent 180aaadda9
commit d80c02ae46
7 changed files with 172 additions and 437 deletions

View File

@ -26,7 +26,7 @@ from nonebot.config import Env, Config
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
if TYPE_CHECKING:
from nonebot.adapters import Bot
from nonebot.adapters import Bot, Adapter
class Driver(abc.ABC):
@ -34,9 +34,9 @@ class Driver(abc.ABC):
Driver 基类。
"""
_adapters: Dict[str, Type["Bot"]] = {}
_adapters: Dict[str, "Adapter"] = {}
"""
:类型: ``Dict[str, Type[Bot]]``
:类型: ``Dict[str, Adapter]``
:说明: 已注册的适配器列表
"""
_bot_connection_hook: Set[T_BotConnectionHook] = set()
@ -85,7 +85,7 @@ class Driver(abc.ABC):
"""
return self._clients
def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs):
def register_adapter(self, adapter: Type["Adapter"], **kwargs):
"""
:说明:
@ -97,13 +97,13 @@ class Driver(abc.ABC):
* ``adapter: Type[Bot]``: 适配器 Class
* ``**kwargs``: 其他传递给适配器的参数
"""
name = adapter.get_name()
if name in self._adapters:
logger.opt(colors=True).debug(
f'Adapter "<y>{escape_tag(name)}</y>" already exists'
)
return
self._adapters[name] = adapter
adapter.register(self, self.config, **kwargs)
self._adapters[name] = adapter(self, **kwargs)
logger.opt(colors=True).debug(
f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"'
)
@ -213,34 +213,11 @@ class ForwardDriver(Driver):
"""
@abc.abstractmethod
def setup_http_polling(
self,
setup: Union["HTTPPollingSetup", Callable[[], Awaitable["HTTPPollingSetup"]]],
) -> None:
"""
:说明:
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
:参数:
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
"""
async def request(self, setup: "HTTPRequest") -> Any:
raise NotImplementedError
@abc.abstractmethod
def setup_websocket(
self, setup: Union["WebSocketSetup", Callable[[], Awaitable["WebSocketSetup"]]]
) -> None:
"""
:说明:
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
:参数:
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
"""
async def websocket(self, setup: "HTTPConnection") -> Any:
raise NotImplementedError
@ -261,7 +238,16 @@ class ReverseDriver(Driver):
"""驱动 ASGI 对象"""
raise NotImplementedError
@abc.abstractmethod
def setup_http_server(self, setup: "HTTPServerSetup") -> None:
raise NotImplementedError
@abc.abstractmethod
def setup_websocket_server(self, setup: "WebSocketServerSetup") -> None:
raise NotImplementedError
# TODO: repack dataclass
@dataclass
class HTTPConnection(abc.ABC):
http_version: str
@ -401,36 +387,13 @@ class WebSocket(HTTPConnection, abc.ABC):
@dataclass
class HTTPPollingSetup:
adapter: str
"""协议适配器名称"""
self_id: str
"""机器人 ID"""
url: str
"""URL"""
class HTTPServerSetup:
path: str
method: str
"""HTTP method"""
body: bytes
"""HTTP body"""
headers: Dict[str, str]
"""HTTP headers"""
http_version: str
"""HTTP version"""
poll_interval: float
"""HTTP 轮询间隔"""
handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]]
@dataclass
class WebSocketSetup:
adapter: str
"""协议适配器名称"""
self_id: str
"""机器人 ID"""
url: str
"""URL"""
headers: Dict[str, str] = field(default_factory=dict)
"""HTTP headers"""
reconnect: bool = True
"""WebSocket 是否重连"""
reconnect_interval: float = 3.0
"""WebSocket 重连间隔"""
class WebSocketServerSetup:
path: str
handle_func: Callable[[WebSocket], Awaitable[Any]]