From 3041650b4b5b66d41a14a6381d47d8c0828d21a4 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Sun, 16 Jan 2022 13:17:05 +0800 Subject: [PATCH 1/5] :sparkles: add advanced message slice support --- nonebot/adapters/_message.py | 48 +++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/nonebot/adapters/_message.py b/nonebot/adapters/_message.py index 52db50ee..e6bcc665 100644 --- a/nonebot/adapters/_message.py +++ b/nonebot/adapters/_message.py @@ -5,12 +5,14 @@ from typing import ( Any, Dict, List, + Tuple, Type, Union, Generic, Mapping, TypeVar, Iterable, + overload, ) from ._template import MessageTemplate @@ -87,7 +89,7 @@ class MessageSegment(Mapping, abc.ABC, Generic[TM]): raise NotImplementedError -class Message(List[TMS], abc.ABC): +class Message(List[TMS], Generic[TMS], abc.ABC): """消息数组""" def __init__( @@ -180,6 +182,50 @@ class Message(List[TMS], abc.ABC): self.extend(self._construct(other)) return self + @overload + def __getitem__(self: TM, __args: str) -> TM: + ... + + @overload + def __getitem__(self, __args: Tuple[str, int]) -> TMS: + ... + + @overload + def __getitem__(self: TM, __args: Tuple[str, slice]) -> TM: + ... + + @overload + def __getitem__(self, __args: int) -> TMS: + ... + + @overload + def __getitem__(self: TM, __args: slice) -> TM: + ... + + def __getitem__( + self: TM, + __args: Union[ + str, + Tuple[str, int], + Tuple[str, slice], + int, + slice, + ], + ) -> Union[TMS, TM]: + arg1, arg2 = __args if isinstance(__args, tuple) else (__args, None) + if isinstance(arg1, int) and arg2 is None: + return super().__getitem__(arg1) + elif isinstance(arg1, slice) and arg2 is None: + return self.__class__(super().__getitem__(arg1)) + elif isinstance(arg1, str) and arg2 is None: + return self.__class__(seg for seg in self if seg.type == arg1) + elif isinstance(arg1, str) and isinstance(arg2, int): + return [seg for seg in self if seg.type == arg1][arg2] + elif isinstance(arg1, str) and isinstance(arg2, slice): + return self.__class__([seg for seg in self if seg.type == arg1][arg2]) + else: + raise ValueError("Invalid arguments to __getitem__") + def append(self: TM, obj: Union[str, TMS]) -> TM: """ 添加一个消息段到消息数组末尾 From 39822378a7680220a5b5673c7b1f576367397cca Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 16 Jan 2022 05:26:55 +0000 Subject: [PATCH 2/5] :rotating_light: auto fix by pre-commit hooks --- nonebot/adapters/_message.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nonebot/adapters/_message.py b/nonebot/adapters/_message.py index e6bcc665..9b1996d9 100644 --- a/nonebot/adapters/_message.py +++ b/nonebot/adapters/_message.py @@ -5,8 +5,8 @@ from typing import ( Any, Dict, List, - Tuple, Type, + Tuple, Union, Generic, Mapping, From 1221baaa942faf8163e0443a2f2541fa4504b135 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Sun, 16 Jan 2022 17:13:26 +0800 Subject: [PATCH 3/5] :sparkles: Implement `.get` and `.index` methods for `Message` --- nonebot/adapters/_message.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/nonebot/adapters/_message.py b/nonebot/adapters/_message.py index 9b1996d9..d3f050a1 100644 --- a/nonebot/adapters/_message.py +++ b/nonebot/adapters/_message.py @@ -89,7 +89,7 @@ class MessageSegment(Mapping, abc.ABC, Generic[TM]): raise NotImplementedError -class Message(List[TMS], Generic[TMS], abc.ABC): +class Message(List[TMS], abc.ABC): """消息数组""" def __init__( @@ -204,7 +204,7 @@ class Message(List[TMS], Generic[TMS], abc.ABC): def __getitem__( self: TM, - __args: Union[ + args: Union[ str, Tuple[str, int], Tuple[str, slice], @@ -212,7 +212,7 @@ class Message(List[TMS], Generic[TMS], abc.ABC): slice, ], ) -> Union[TMS, TM]: - arg1, arg2 = __args if isinstance(__args, tuple) else (__args, None) + arg1, arg2 = args if isinstance(args, tuple) else (args, None) if isinstance(arg1, int) and arg2 is None: return super().__getitem__(arg1) elif isinstance(arg1, slice) and arg2 is None: @@ -224,7 +224,22 @@ class Message(List[TMS], Generic[TMS], abc.ABC): elif isinstance(arg1, str) and isinstance(arg2, slice): return self.__class__([seg for seg in self if seg.type == arg1][arg2]) else: - raise ValueError("Invalid arguments to __getitem__") + raise ValueError("Incorrect arguments to slice") + + def index(self, value: Union[TMS, str], *args) -> int: + if isinstance(value, str): + first_segment = next((seg for seg in self if seg.type == value), None) # type: ignore + return super().index(first_segment, *args) # type: ignore + return super().index(value, *args) + + def get(self: TM, type_: str, count: int) -> TM: + iterator = (seg for seg in self if seg.type == type_) + return self.__class__( + filter( + lambda seg: seg is not None, + (next(iterator) for _ in range(count)), + ) + ) def append(self: TM, obj: Union[str, TMS]) -> TM: """ From 3b4c4d30812031b889b2370a7bc868af454dab14 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Mon, 17 Jan 2022 00:28:36 +0800 Subject: [PATCH 4/5] :sparkles: :zap: Implement `.count` and optimize `.get` performance for message slice --- nonebot/adapters/_message.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/nonebot/adapters/_message.py b/nonebot/adapters/_message.py index d3f050a1..a1a44ba9 100644 --- a/nonebot/adapters/_message.py +++ b/nonebot/adapters/_message.py @@ -12,6 +12,7 @@ from typing import ( Mapping, TypeVar, Iterable, + Optional, overload, ) @@ -232,14 +233,20 @@ class Message(List[TMS], abc.ABC): return super().index(first_segment, *args) # type: ignore return super().index(value, *args) - def get(self: TM, type_: str, count: int) -> TM: - iterator = (seg for seg in self if seg.type == type_) - return self.__class__( - filter( - lambda seg: seg is not None, - (next(iterator) for _ in range(count)), - ) - ) + def get(self: TM, type_: str, count: Optional[int] = None) -> TM: + if count is None: + return self[type_] + + iterator, filtered = (seg for seg in self if seg.type == type_), [] + for _ in range(count): + seg = next(iterator, None) + if seg is None: + break + filtered.append(seg) + return self.__class__(filtered) + + def count(self, value: Union[TMS, str]) -> int: + return len(self[value]) if isinstance(value, str) else super().count(value) def append(self: TM, obj: Union[str, TMS]) -> TM: """ From b037be448537be9d430078289780cd1183fbc4c2 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Mon, 17 Jan 2022 00:29:09 +0800 Subject: [PATCH 5/5] :white_check_mark: add unit test for message slice --- tests/test_adapters/test_message.py | 54 ++++++++++++++++++++++++++++ tests/test_adapters/test_template.py | 17 --------- tests/utils.py | 11 ++++-- 3 files changed, 62 insertions(+), 20 deletions(-) create mode 100644 tests/test_adapters/test_message.py delete mode 100644 tests/test_adapters/test_template.py diff --git a/tests/test_adapters/test_message.py b/tests/test_adapters/test_message.py new file mode 100644 index 00000000..88d9e563 --- /dev/null +++ b/tests/test_adapters/test_message.py @@ -0,0 +1,54 @@ +from utils import make_fake_message + + +def test_message_template(): + from nonebot.adapters import MessageTemplate + + Message = make_fake_message() + + template = MessageTemplate("{a:custom}{b:text}{c:image}", Message) + + @template.add_format_spec + def custom(input: str) -> str: + return input + "-custom!" + + formatted = template.format(a="test", b="test", c="https://example.com/test") + assert formatted.extract_plain_text() == "test-custom!test" + assert str(formatted) == "test-custom!test[fake:image]" + + +def test_message_slice(): + + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + message = Message( + [ + MessageSegment.text("test"), + MessageSegment.image("test2"), + MessageSegment.image("test3"), + MessageSegment.text("test4"), + ] + ) + + assert message[0] == MessageSegment.text("test") + + assert message[0:2] == Message( + [MessageSegment.text("test"), MessageSegment.image("test2")] + ) + + assert message["image"] == Message( + [MessageSegment.image("test2"), MessageSegment.image("test3")] + ) + + assert message["image", 0] == MessageSegment.image("test2") + assert message["image", 0:2] == message["image"] + + assert message.index(message[0]) == 0 + assert message.index("image") == 1 + + assert message.get("image") == message["image"] + assert message.get("image", 114514) == message["image"] + assert message.get("image", 1) == Message([message["image", 0]]) + + assert message.count("image") == 2 diff --git a/tests/test_adapters/test_template.py b/tests/test_adapters/test_template.py deleted file mode 100644 index 3dbef541..00000000 --- a/tests/test_adapters/test_template.py +++ /dev/null @@ -1,17 +0,0 @@ -from utils import make_fake_message - - -def test_message_template(): - from nonebot.adapters import MessageTemplate - - Message = make_fake_message() - - template = MessageTemplate("{a:custom}{b:text}{c:image}", Message) - - @template.add_format_spec - def custom(input: str) -> str: - return input + "-custom!" - - formatted = template.format(a="test", b="test", c="https://example.com/test") - assert formatted.extract_plain_text() == "test-custom!test" - assert str(formatted) == "test-custom!test[fake:image]" diff --git a/tests/utils.py b/tests/utils.py index 09b2a987..711c7c05 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Type, Optional +from typing import TYPE_CHECKING, Type, Union, Mapping, Iterable, Optional from pydantic import create_model @@ -34,8 +34,13 @@ def make_fake_message() -> Type["Message"]: return FakeMessageSegment @staticmethod - def _construct(msg: str): - yield FakeMessageSegment.text(msg) + def _construct(msg: Union[str, Iterable[Mapping]]): + if isinstance(msg, str): + yield FakeMessageSegment.text(msg) + else: + for seg in msg: + yield FakeMessageSegment(**seg) + return return FakeMessage