From f11970132c2764e62bad68ddacd7e2a29b5633d1 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Mon, 20 Jun 2022 15:52:12 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20Fix:=20=E4=BF=AE=E5=A4=8D=20Messag?= =?UTF-8?q?eSegment=20=E5=9C=A8=E6=9C=89=E9=A2=9D=E5=A4=96=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E6=97=B6=E6=8A=A5=E9=94=99=20(#1055)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/internal/adapter/message.py | 6 +++++- pyproject.toml | 4 ++++ tests/test_adapters/test_message.py | 22 +++++++++------------- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/nonebot/internal/adapter/message.py b/nonebot/internal/adapter/message.py index 0926077e..d1fb981e 100644 --- a/nonebot/internal/adapter/message.py +++ b/nonebot/internal/adapter/message.py @@ -66,7 +66,11 @@ class MessageSegment(abc.ABC, Generic[TM]): return value if not isinstance(value, dict): raise ValueError(f"Expected dict for MessageSegment, got {type(value)}") - return cls(**value) + if "type" not in value: + raise ValueError( + f"Expected dict with 'type' for MessageSegment, got {value}" + ) + return cls(type=value["type"], data=value.get("data", {})) def get(self, key: str, default: Any = None): return asdict(self).get(key, default) diff --git a/pyproject.toml b/pyproject.toml index d42a4ba8..b1fc34ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,10 @@ all = ["quart", "aiohttp", "httpx", "websockets"] [tool.pytest.ini_options] asyncio_mode = "auto" addopts = "--cov=nonebot --cov-report=term-missing" +filterwarnings = [ + "error", + "ignore::DeprecationWarning", +] [tool.black] line-length = 88 diff --git a/tests/test_adapters/test_message.py b/tests/test_adapters/test_message.py index 800ecd90..5e730c6f 100644 --- a/tests/test_adapters/test_message.py +++ b/tests/test_adapters/test_message.py @@ -1,3 +1,4 @@ +import pytest from pydantic import ValidationError, parse_obj_as from utils import make_fake_message @@ -29,14 +30,15 @@ def test_segment_validate(): MessageSegment = Message.get_segment_class() assert parse_obj_as( - MessageSegment, {"type": "text", "data": {"text": "text"}} + MessageSegment, + {"type": "text", "data": {"text": "text"}, "extra": "should be ignored"}, ) == MessageSegment.text("text") - try: + with pytest.raises(ValidationError): parse_obj_as(MessageSegment, "some str") - assert False - except ValidationError: - assert True + + with pytest.raises(ValidationError): + parse_obj_as(MessageSegment, {"data": {}}) def test_segment(): @@ -129,11 +131,8 @@ def test_message_validate(): assert parse_obj_as(Message, Message([])) == Message([]) - try: + with pytest.raises(ValidationError): parse_obj_as(Message, Message_([])) - assert False - except ValidationError: - assert True assert parse_obj_as(Message, "text") == Message([MessageSegment.text("text")]) @@ -146,8 +145,5 @@ def test_message_validate(): [MessageSegment.text("text"), {"type": "text", "data": {"text": "text"}}], ) == Message([MessageSegment.text("text"), MessageSegment.text("text")]) - try: + with pytest.raises(ValidationError): parse_obj_as(Message, object()) - assert False - except ValidationError: - assert True