♻️ reorganize internal tree

This commit is contained in:
yanyongyu
2022-02-06 17:08:11 +08:00
parent 65dc9a908b
commit 118519e15d
16 changed files with 52 additions and 25 deletions

View File

@ -0,0 +1,25 @@
from .model import URL as URL
from .model import RawURL as RawURL
from .driver import Driver as Driver
from .model import Cookies as Cookies
from .model import Request as Request
from .model import FileType as FileType
from .model import Response as Response
from .model import DataTypes as DataTypes
from .model import FileTypes as FileTypes
from .model import WebSocket as WebSocket
from .model import FilesTypes as FilesTypes
from .model import QueryTypes as QueryTypes
from .model import CookieTypes as CookieTypes
from .model import FileContent as FileContent
from .model import HTTPVersion as HTTPVersion
from .model import HeaderTypes as HeaderTypes
from .model import SimpleQuery as SimpleQuery
from .model import ContentTypes as ContentTypes
from .driver import ForwardMixin as ForwardMixin
from .model import QueryVariable as QueryVariable
from .driver import ForwardDriver as ForwardDriver
from .driver import ReverseDriver as ReverseDriver
from .driver import combine_driver as combine_driver
from .model import HTTPServerSetup as HTTPServerSetup
from .model import WebSocketServerSetup as WebSocketServerSetup

View File

@ -0,0 +1,234 @@
import abc
import asyncio
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
from nonebot.log import logger
from nonebot.utils import escape_tag
from nonebot.config import Env, Config
from nonebot.dependencies import Dependent
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
from nonebot.internal.params import BotParam, DependParam, DefaultParam
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup
if TYPE_CHECKING:
from nonebot.internal.adapter import Bot, Adapter
BOT_HOOK_PARAMS = [DependParam, BotParam, DefaultParam]
class Driver(abc.ABC):
"""Driver 基类。
参数:
env: 包含环境信息的 Env 对象
config: 包含配置信息的 Config 对象
"""
_adapters: Dict[str, "Adapter"] = {}
"""已注册的适配器列表"""
_bot_connection_hook: Set[Dependent[Any]] = set()
"""Bot 连接建立时执行的函数"""
_bot_disconnection_hook: Set[Dependent[Any]] = set()
"""Bot 连接断开时执行的函数"""
def __init__(self, env: Env, config: Config):
self.env: str = env.environment
"""环境名称"""
self.config: Config = config
"""全局配置对象"""
self._clients: Dict[str, "Bot"] = {}
@property
def bots(self) -> Dict[str, "Bot"]:
"""获取当前所有已连接的 Bot"""
return self._clients
def register_adapter(self, adapter: Type["Adapter"], **kwargs) -> None:
"""注册一个协议适配器
参数:
adapter: 适配器类
kwargs: 其他传递给适配器的参数
"""
name = adapter.get_name()
if name in self._adapters:
logger.opt(colors=True).debug(
f'Adapter "<y>{escape_tag(name)}</y>" already exists'
)
return
self._adapters[name] = adapter(self, **kwargs)
logger.opt(colors=True).debug(
f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"'
)
@property
@abc.abstractmethod
def type(self) -> str:
"""驱动类型名称"""
raise NotImplementedError
@property
@abc.abstractmethod
def logger(self):
"""驱动专属 logger 日志记录器"""
raise NotImplementedError
@abc.abstractmethod
def run(self, *args, **kwargs):
"""
启动驱动框架
"""
logger.opt(colors=True).debug(
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
)
@abc.abstractmethod
def on_startup(self, func: Callable) -> Callable:
"""注册一个在驱动器启动时执行的函数"""
raise NotImplementedError
@abc.abstractmethod
def on_shutdown(self, func: Callable) -> Callable:
"""注册一个在驱动器停止时执行的函数"""
raise NotImplementedError
@classmethod
def on_bot_connect(cls, func: T_BotConnectionHook) -> T_BotConnectionHook:
"""装饰一个函数使他在 bot 连接成功时执行。
钩子函数参数:
- bot: 当前连接上的 Bot 对象
"""
cls._bot_connection_hook.add(
Dependent[Any].parse(call=func, allow_types=BOT_HOOK_PARAMS)
)
return func
@classmethod
def on_bot_disconnect(cls, func: T_BotDisconnectionHook) -> T_BotDisconnectionHook:
"""装饰一个函数使他在 bot 连接断开时执行。
钩子函数参数:
- bot: 当前连接上的 Bot 对象
"""
cls._bot_disconnection_hook.add(
Dependent[Any].parse(call=func, allow_types=BOT_HOOK_PARAMS)
)
return func
def _bot_connect(self, bot: "Bot") -> None:
"""在连接成功后,调用该函数来注册 bot 对象"""
if bot.self_id in self._clients:
raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
self._clients[bot.self_id] = bot
async def _run_hook(bot: "Bot") -> None:
coros = list(map(lambda x: x(bot=bot), self._bot_connection_hook))
if coros:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running WebSocketConnection hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
asyncio.create_task(_run_hook(bot))
def _bot_disconnect(self, bot: "Bot") -> None:
"""在连接断开后,调用该函数来注销 bot 对象"""
if bot.self_id in self._clients:
del self._clients[bot.self_id]
async def _run_hook(bot: "Bot") -> None:
coros = list(map(lambda x: x(bot=bot), self._bot_disconnection_hook))
if coros:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running WebSocketDisConnection hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
asyncio.create_task(_run_hook(bot))
class ForwardMixin(abc.ABC):
"""客户端混入基类。"""
@property
@abc.abstractmethod
def type(self) -> str:
"""客户端驱动类型名称"""
raise NotImplementedError
@abc.abstractmethod
async def request(self, setup: Request) -> Response:
"""发送一个 HTTP 请求"""
raise NotImplementedError
@abc.abstractmethod
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
"""发起一个 WebSocket 连接"""
raise NotImplementedError
yield # used for static type checking's generator detection
class ForwardDriver(Driver, ForwardMixin):
"""客户端基类。将客户端框架封装,以满足适配器使用。"""
class ReverseDriver(Driver):
"""服务端基类。将后端框架封装,以满足适配器使用。"""
@property
@abc.abstractmethod
def server_app(self) -> Any:
"""驱动 APP 对象"""
raise NotImplementedError
@property
@abc.abstractmethod
def asgi(self) -> Any:
"""驱动 ASGI 对象"""
raise NotImplementedError
@abc.abstractmethod
def setup_http_server(self, setup: "HTTPServerSetup") -> None:
"""设置一个 HTTP 服务器路由配置"""
raise NotImplementedError
@abc.abstractmethod
def setup_websocket_server(self, setup: "WebSocketServerSetup") -> None:
"""设置一个 WebSocket 服务器路由配置"""
raise NotImplementedError
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
"""将一个驱动器和多个混入类合并。"""
# check first
assert issubclass(driver, Driver), "`driver` must be subclass of Driver"
assert all(
map(lambda m: issubclass(m, ForwardMixin), mixins)
), "`mixins` must be subclass of ForwardMixin"
if not mixins:
return driver
class CombinedDriver(*mixins, driver, ForwardDriver): # type: ignore
@property
def type(self) -> str:
return (
driver.type.__get__(self)
+ "+"
+ "+".join(map(lambda x: x.type.__get__(self), mixins))
)
return CombinedDriver

View File

@ -0,0 +1,338 @@
import abc
from enum import Enum
from dataclasses import dataclass
from http.cookiejar import Cookie, CookieJar
from typing import (
IO,
Any,
Dict,
List,
Tuple,
Union,
Mapping,
Callable,
Iterator,
Optional,
Awaitable,
MutableMapping,
)
from yarl import URL as URL
from multidict import CIMultiDict
RawURL = Tuple[bytes, bytes, Optional[int], bytes]
SimpleQuery = Union[str, int, float]
QueryVariable = Union[SimpleQuery, List[SimpleQuery]]
QueryTypes = Union[
None, str, Mapping[str, QueryVariable], List[Tuple[str, QueryVariable]]
]
HeaderTypes = Union[
None,
CIMultiDict[str],
Dict[str, str],
List[Tuple[str, str]],
]
CookieTypes = Union[None, "Cookies", CookieJar, Dict[str, str], List[Tuple[str, str]]]
ContentTypes = Union[str, bytes, None]
DataTypes = Union[dict, None]
FileContent = Union[IO[bytes], bytes]
FileType = Tuple[Optional[str], FileContent, Optional[str]]
FileTypes = Union[
# file (or bytes)
FileContent,
# (filename, file (or bytes))
Tuple[Optional[str], FileContent],
# (filename, file (or bytes), content_type)
FileType,
]
FilesTypes = Union[Dict[str, FileTypes], List[Tuple[str, FileTypes]], None]
class HTTPVersion(Enum):
H10 = "1.0"
H11 = "1.1"
H2 = "2"
class Request:
def __init__(
self,
method: Union[str, bytes],
url: Union["URL", str, RawURL],
*,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
content: ContentTypes = None,
data: DataTypes = None,
json: Any = None,
files: FilesTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
):
# method
self.method: str = (
method.decode("ascii").upper()
if isinstance(method, bytes)
else method.upper()
)
# http version
self.version: HTTPVersion = HTTPVersion(version)
# timeout
self.timeout: Optional[float] = timeout
# proxy
self.proxy: Optional[str] = proxy
# url
if isinstance(url, tuple):
scheme, host, port, path = url
url = URL.build(
scheme=scheme.decode("ascii"),
host=host.decode("ascii"),
port=port,
path=path.decode("ascii"),
)
else:
url = URL(url)
if params is not None:
url = url.update_query(params)
self.url: URL = url
# headers
self.headers: CIMultiDict[str]
if headers is not None:
self.headers = CIMultiDict(headers)
else:
self.headers = CIMultiDict()
# cookies
self.cookies = Cookies(cookies)
# body
self.content: ContentTypes = content
self.data: DataTypes = data
self.json: Any = json
self.files: Optional[List[Tuple[str, FileType]]] = None
if files:
self.files = []
files_ = files.items() if isinstance(files, dict) else files
for name, file_info in files_:
if not isinstance(file_info, tuple):
self.files.append((name, (None, file_info, None)))
elif len(file_info) == 2:
self.files.append((name, (file_info[0], file_info[1], None)))
else:
self.files.append((name, file_info)) # type: ignore
def __repr__(self) -> str:
class_name = self.__class__.__name__
url = str(self.url)
return f"<{class_name}({self.method!r}, {url!r})>"
class Response:
def __init__(
self,
status_code: int,
*,
headers: HeaderTypes = None,
content: ContentTypes = None,
request: Optional[Request] = None,
):
# status code
self.status_code: int = status_code
# headers
self.headers: CIMultiDict[str]
if headers is not None:
self.headers = CIMultiDict(headers)
else:
self.headers = CIMultiDict()
# body
self.content: ContentTypes = content
# request
self.request: Optional[Request] = request
class WebSocket(abc.ABC):
def __init__(self, *, request: Request):
# request
self.request: Request = request
@property
@abc.abstractmethod
def closed(self) -> bool:
"""
连接是否已经关闭
"""
raise NotImplementedError
@abc.abstractmethod
async def accept(self) -> None:
"""接受 WebSocket 连接请求"""
raise NotImplementedError
@abc.abstractmethod
async def close(self, code: int = 1000, reason: str = "") -> None:
"""关闭 WebSocket 连接请求"""
raise NotImplementedError
@abc.abstractmethod
async def receive(self) -> str:
"""接收一条 WebSocket text 信息"""
raise NotImplementedError
@abc.abstractmethod
async def receive_bytes(self) -> bytes:
"""接收一条 WebSocket binary 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send(self, data: str) -> None:
"""发送一条 WebSocket text 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send_bytes(self, data: bytes) -> None:
"""发送一条 WebSocket binary 信息"""
raise NotImplementedError
class Cookies(MutableMapping):
def __init__(self, cookies: CookieTypes = None) -> None:
self.jar: CookieJar = cookies if isinstance(cookies, CookieJar) else CookieJar()
if cookies is not None and not isinstance(cookies, CookieJar):
if isinstance(cookies, dict):
for key, value in cookies.items():
self.set(key, value)
elif isinstance(cookies, list):
for key, value in cookies:
self.set(key, value)
elif isinstance(cookies, Cookies):
for cookie in cookies.jar:
self.jar.set_cookie(cookie)
else:
raise TypeError(f"Cookies must be dict or list, not {type(cookies)}")
def set(self, name: str, value: str, domain: str = "", path: str = "/") -> None:
cookie = Cookie(
version=0,
name=name,
value=value,
port=None,
port_specified=False,
domain=domain,
domain_specified=bool(domain),
domain_initial_dot=domain.startswith("."),
path=path,
path_specified=bool(path),
secure=False,
expires=None,
discard=True,
comment=None,
comment_url=None,
rest={},
rfc2109=False,
)
self.jar.set_cookie(cookie)
def get(
self,
name: str,
default: Optional[str] = None,
domain: str = None,
path: str = None,
) -> Optional[str]:
value: Optional[str] = None
for cookie in self.jar:
if (
cookie.name == name
and (domain is None or cookie.domain == domain)
and (path is None or cookie.path == path)
):
if value is not None:
message = f"Multiple cookies exist with name={name}"
raise ValueError(message)
value = cookie.value
return default if value is None else value
def delete(
self, name: str, domain: Optional[str] = None, path: Optional[str] = None
) -> None:
if domain is not None and path is not None:
return self.jar.clear(domain, path, name)
remove = [
cookie
for cookie in self.jar
if cookie.name == name
and (domain is None or cookie.domain == domain)
and (path is None or cookie.path == path)
]
for cookie in remove:
self.jar.clear(cookie.domain, cookie.path, cookie.name)
def clear(self, domain: Optional[str] = None, path: Optional[str] = None) -> None:
self.jar.clear(domain, path)
def update(self, cookies: CookieTypes = None) -> None:
cookies = Cookies(cookies)
for cookie in cookies.jar:
self.jar.set_cookie(cookie)
def __setitem__(self, name: str, value: str) -> None:
return self.set(name, value)
def __getitem__(self, name: str) -> str:
value = self.get(name)
if value is None:
raise KeyError(name)
return value
def __delitem__(self, name: str) -> None:
return self.delete(name)
def __len__(self) -> int:
return len(self.jar)
def __iter__(self) -> Iterator[Cookie]:
return (cookie for cookie in self.jar)
def __repr__(self) -> str:
cookies_repr = ", ".join(
[
f"<Cookie {cookie.name}={cookie.value} for {cookie.domain} />"
for cookie in self.jar
]
)
return f"<Cookies [{cookies_repr}]>"
@dataclass
class HTTPServerSetup:
"""HTTP 服务器路由配置。"""
path: URL # path should not be absolute, check it by URL.is_absolute() == False
method: str
name: str
handle_func: Callable[[Request], Awaitable[Response]]
@dataclass
class WebSocketServerSetup:
"""WebSocket 服务器路由配置。"""
path: URL # path should not be absolute, check it by URL.is_absolute() == False
name: str
handle_func: Callable[[WebSocket], Awaitable[Any]]