diff --git a/nonebot/internal/adapter/template.py b/nonebot/internal/adapter/template.py index 22d2aebb..e04830a8 100644 --- a/nonebot/internal/adapter/template.py +++ b/nonebot/internal/adapter/template.py @@ -104,7 +104,7 @@ class MessageTemplate(Formatter, Generic[TF]): if recursion_depth < 0: raise ValueError("Max string recursion exceeded") - results: List[Any] = [] + results: List[Any] = [self.factory()] for (literal_text, field_name, format_spec, conversion) in self.parse( format_string @@ -162,10 +162,7 @@ class MessageTemplate(Formatter, Generic[TF]): formatted_text = self.format_field(obj, str(format_control)) results.append(formatted_text) - return ( - self.factory(functools.reduce(self._add, results or [""])), - auto_arg_index, - ) + return functools.reduce(self._add, results), auto_arg_index def format_field(self, value: Any, format_spec: str) -> Any: formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec) diff --git a/tests/test_adapters/test_template.py b/tests/test_adapters/test_template.py index 1814b04d..84856625 100644 --- a/tests/test_adapters/test_template.py +++ b/tests/test_adapters/test_template.py @@ -1,4 +1,4 @@ -from utils import make_fake_message +from utils import escape_text, make_fake_message def test_template_basis(): @@ -10,11 +10,8 @@ def test_template_basis(): def test_template_message(): - from nonebot.adapters import MessageTemplate - Message = make_fake_message() - - template = MessageTemplate("{a:custom}{b:text}{c:image}", Message) + template = Message.template("{a:custom}{b:text}{c:image}") @template.add_format_spec def custom(input: str) -> str: @@ -33,3 +30,12 @@ def test_template_message(): assert template.format_map(format_args) == formatted assert formatted.extract_plain_text() == "custom-custom!text" assert str(formatted) == "custom-custom!text[fake:image]" + + +def test_message_injection(): + Message = make_fake_message() + + template = Message.template("{name}Is Bad") + message = template.format(name="[fake:image]") + + assert message.extract_plain_text() == escape_text("[fake:image]Is Bad") diff --git a/tests/test_examples/test_weather.py b/tests/test_examples/test_weather.py index 08bab064..086e2c1c 100644 --- a/tests/test_examples/test_weather.py +++ b/tests/test_examples/test_weather.py @@ -29,7 +29,11 @@ async def test_weather(app: App): event = make_fake_event(_message=msg, _to_me=True)() ctx.receive_event(bot, event) - ctx.should_call_send(event, Message("你想查询的城市 南京 暂不支持,请重新输入!"), True) + ctx.should_call_send( + event, + Message.template("你想查询的城市 {} 暂不支持,请重新输入!").format("南京"), + True, + ) ctx.should_rejected() msg = Message("北京") @@ -53,7 +57,11 @@ async def test_weather(app: App): event = make_fake_event(_message=msg)() ctx.receive_event(bot, event) - ctx.should_call_send(event, Message("你想查询的城市 杭州 暂不支持,请重新输入!"), True) + ctx.should_call_send( + event, + Message.template("你想查询的城市 {} 暂不支持,请重新输入!").format("杭州"), + True, + ) ctx.should_rejected() msg = Message("北京") diff --git a/tests/utils.py b/tests/utils.py index ef54b69b..0cd94be4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,7 +6,14 @@ if TYPE_CHECKING: from nonebot.adapters import Event, Message -def make_fake_message() -> Type["Message"]: +def escape_text(s: str, *, escape_comma: bool = True) -> str: + s = s.replace("&", "&").replace("[", "[").replace("]", "]") + if escape_comma: + s = s.replace(",", ",") + return s + + +def make_fake_message(): from nonebot.adapters import Message, MessageSegment class FakeMessageSegment(MessageSegment): @@ -42,6 +49,10 @@ def make_fake_message() -> Type["Message"]: yield FakeMessageSegment(**seg) return + def __add__(self, other): + other = escape_text(other) if isinstance(other, str) else other + return super().__add__(other) + return FakeMessage