diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index e8e5f2c3..5bd2e96f 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -3,6 +3,9 @@ import abc from functools import reduce +from dataclasses import dataclass + +# from pydantic.dataclasses import dataclass # dataclass with validation from nonebot.config import Config from nonebot.drivers import BaseWebSocket @@ -37,51 +40,65 @@ class BaseBot(abc.ABC): raise NotImplementedError -class BaseMessageSegment(dict): - - def __init__(self, - type_: Optional[str] = None, - data: Optional[Dict[str, str]] = None): - super().__init__() - if type_: - self.type = type_ - self.data = data - else: - raise ValueError('The "type" field cannot be empty') +@dataclass +class BaseMessageSegment(abc.ABC): + type: str + data: Dict[str, str] = {} + @abc.abstractmethod def __str__(self): raise NotImplementedError - def __getitem__(self, item): - if item not in ("type", "data"): - raise KeyError(f'Key "{item}" is not allowed') - return super().__getitem__(item) - - def __setitem__(self, key, value): - if key not in ("type", "data"): - raise KeyError(f'Key "{key}" is not allowed') - return super().__setitem__(key, value) - - # TODO: __eq__ __add__ - - @property - def type(self) -> str: - return self["type"] - - @type.setter - def type(self, value: str): - self["type"] = value - - @property - def data(self) -> Dict[str, str]: - return self["data"] - - @data.setter - def data(self, data: Optional[Dict[str, str]]): - self["data"] = data or {} + @abc.abstractmethod + def __add__(self, other): + raise NotImplementedError -class BaseMessage(list): +# class BaseMessageSegment(dict): + +# def __init__(self, +# type_: Optional[str] = None, +# data: Optional[Dict[str, str]] = None): +# super().__init__() +# if type_: +# self.type = type_ +# self.data = data +# else: +# raise ValueError('The "type" field cannot be empty') + +# def __str__(self): +# raise NotImplementedError + +# def __getitem__(self, item): +# if item not in ("type", "data"): +# raise KeyError(f'Key "{item}" is not allowed') +# return super().__getitem__(item) + +# def __setitem__(self, key, value): +# if key not in ("type", "data"): +# raise KeyError(f'Key "{key}" is not allowed') +# return super().__setitem__(key, value) + +# # TODO: __eq__ __add__ + +# @property +# def type(self) -> str: +# return self["type"] + +# @type.setter +# def type(self, value: str): +# self["type"] = value + +# @property +# def data(self) -> Dict[str, str]: +# return self["data"] + +# @data.setter +# def data(self, data: Optional[Dict[str, str]]): +# self["data"] = data or {} + + +class BaseMessage(list, abc.ABC): def __init__(self, message: Union[str, BaseMessageSegment, "BaseMessage"] = None, @@ -99,6 +116,7 @@ class BaseMessage(list): return ''.join((str(seg) for seg in self)) @staticmethod + @abc.abstractmethod def _construct(msg: str) -> Iterable[BaseMessageSegment]: raise NotImplementedError diff --git a/nonebot/adapters/cqhttp.py b/nonebot/adapters/cqhttp.py index 7e2fc827..89dafcb3 100644 --- a/nonebot/adapters/cqhttp.py +++ b/nonebot/adapters/cqhttp.py @@ -99,6 +99,7 @@ class Bot(BaseBot): class MessageSegment(BaseMessageSegment): + @overrides(BaseMessageSegment) def __str__(self): type_ = self.type data = self.data.copy() @@ -116,6 +117,10 @@ class MessageSegment(BaseMessageSegment): params = ",".join([f"{k}={escape(str(v))}" for k, v in data.items()]) return f"[CQ:{type_}{',' if params else ''}{params}]" + @overrides(BaseMessageSegment) + def __add__(self, other) -> "Message": + return Message(self) + other + @staticmethod def anonymous(ignore_failure: bool = False) -> "MessageSegment": return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)}) @@ -248,6 +253,7 @@ class MessageSegment(BaseMessageSegment): class Message(BaseMessage): @staticmethod + @overrides(BaseMessage) def _construct(msg: str) -> Iterable[MessageSegment]: def _iter_message() -> Iterable[Tuple[str, str]]: diff --git a/nonebot/typing.py b/nonebot/typing.py index cb798274..cfee1152 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from abc import ABC from types import ModuleType from typing import TYPE_CHECKING from typing import Any, Set, List, Dict, Type, Tuple, Mapping @@ -13,9 +12,9 @@ if TYPE_CHECKING: from nonebot.event import Event -def overrides(InterfaceClass: ABC): +def overrides(InterfaceClass: object): - def overrider(func): + def overrider(func: Callable) -> Callable: assert func.__name__ in dir( InterfaceClass), f"Error method: {func.__name__}" return func diff --git a/pyproject.toml b/pyproject.toml index 19f93c05..43973465 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ fastapi = "^0.58.1" uvicorn = "^0.11.5" pydantic = { extras = ["dotenv"], version = "^1.5.1" } apscheduler = { version = "^3.6.3", optional = true } -nonebot-test = { version = "^0.1.0", optional = true } +# nonebot-test = { version = "^0.1.0", optional = true } [tool.poetry.dev-dependencies] yapf = "^0.30.0"