mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-06-07 12:55:56 +00:00
🐛 Fix MessageTemplate
improper behavior when no format spec (#947)
* 🧪 Add a test to figure out bug in #938 * ♻️ 🐛 Refactor rich message template formatting, fix #938
This commit is contained in:
parent
f028575f2f
commit
95331bbb22
@ -49,7 +49,9 @@ class MessageTemplate(Formatter, Generic[TF]):
|
|||||||
) -> None:
|
) -> None:
|
||||||
...
|
...
|
||||||
|
|
||||||
def __init__(self, template, factory=str) -> None:
|
def __init__( # type:ignore
|
||||||
|
self, template, factory=str
|
||||||
|
) -> None: # TODO: fix type hint here
|
||||||
self.template: TF = template
|
self.template: TF = template
|
||||||
self.factory: Type[TF] = factory
|
self.factory: Type[TF] = factory
|
||||||
self.format_specs: Dict[str, FormatSpecFunc] = {}
|
self.format_specs: Dict[str, FormatSpecFunc] = {}
|
||||||
@ -72,25 +74,37 @@ class MessageTemplate(Formatter, Generic[TF]):
|
|||||||
return self._format([], mapping)
|
return self._format([], mapping)
|
||||||
|
|
||||||
def _format(self, args: Sequence[Any], kwargs: Mapping[str, Any]) -> TF:
|
def _format(self, args: Sequence[Any], kwargs: Mapping[str, Any]) -> TF:
|
||||||
msg = self.factory()
|
full_message = self.factory()
|
||||||
|
used_args, arg_index = set(), 0
|
||||||
|
|
||||||
if isinstance(self.template, str):
|
if isinstance(self.template, str):
|
||||||
msg += self.vformat(self.template, args, kwargs)
|
msg, arg_index = self._vformat(
|
||||||
|
self.template, args, kwargs, used_args, arg_index
|
||||||
|
)
|
||||||
|
full_message += msg
|
||||||
elif isinstance(self.template, self.factory):
|
elif isinstance(self.template, self.factory):
|
||||||
template = cast("Message[MessageSegment]", self.template)
|
template = cast("Message[MessageSegment]", self.template)
|
||||||
for seg in template:
|
for seg in template:
|
||||||
msg += self.vformat(str(seg), args, kwargs) if seg.is_text() else seg
|
if not seg.is_text():
|
||||||
|
full_message += seg
|
||||||
|
else:
|
||||||
|
msg, arg_index = self._vformat(
|
||||||
|
str(seg), args, kwargs, used_args, arg_index
|
||||||
|
)
|
||||||
|
full_message += msg
|
||||||
else:
|
else:
|
||||||
raise TypeError("template must be a string or instance of Message!")
|
raise TypeError("template must be a string or instance of Message!")
|
||||||
|
|
||||||
return msg # type:ignore
|
self.check_unused_args(list(used_args), args, kwargs)
|
||||||
|
return cast(TF, full_message)
|
||||||
|
|
||||||
def vformat(
|
def vformat(
|
||||||
self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]
|
self,
|
||||||
|
format_string: str,
|
||||||
|
args: Sequence[Any],
|
||||||
|
kwargs: Mapping[str, Any],
|
||||||
) -> TF:
|
) -> TF:
|
||||||
used_args = set()
|
raise NotImplementedError("`vformat` has merged into `_format`")
|
||||||
result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
|
|
||||||
self.check_unused_args(list(used_args), args, kwargs)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _vformat(
|
def _vformat(
|
||||||
self,
|
self,
|
||||||
@ -98,12 +112,8 @@ class MessageTemplate(Formatter, Generic[TF]):
|
|||||||
args: Sequence[Any],
|
args: Sequence[Any],
|
||||||
kwargs: Mapping[str, Any],
|
kwargs: Mapping[str, Any],
|
||||||
used_args: Set[Union[int, str]],
|
used_args: Set[Union[int, str]],
|
||||||
recursion_depth: int,
|
|
||||||
auto_arg_index: int = 0,
|
auto_arg_index: int = 0,
|
||||||
) -> Tuple[TF, int]:
|
) -> Tuple[TF, int]:
|
||||||
if recursion_depth < 0:
|
|
||||||
raise ValueError("Max string recursion exceeded")
|
|
||||||
|
|
||||||
results: List[Any] = [self.factory()]
|
results: List[Any] = [self.factory()]
|
||||||
|
|
||||||
for (literal_text, field_name, format_spec, conversion) in self.parse(
|
for (literal_text, field_name, format_spec, conversion) in self.parse(
|
||||||
@ -143,23 +153,13 @@ class MessageTemplate(Formatter, Generic[TF]):
|
|||||||
obj, arg_used = self.get_field(field_name, args, kwargs)
|
obj, arg_used = self.get_field(field_name, args, kwargs)
|
||||||
used_args.add(arg_used)
|
used_args.add(arg_used)
|
||||||
|
|
||||||
assert format_spec is not None
|
|
||||||
|
|
||||||
# do any conversion on the resulting object
|
# do any conversion on the resulting object
|
||||||
obj = self.convert_field(obj, conversion) if conversion else obj
|
obj = self.convert_field(obj, conversion) if conversion else obj
|
||||||
|
|
||||||
# expand the format spec, if needed
|
|
||||||
format_control, auto_arg_index = self._vformat(
|
|
||||||
format_spec,
|
|
||||||
args,
|
|
||||||
kwargs,
|
|
||||||
used_args,
|
|
||||||
recursion_depth - 1,
|
|
||||||
auto_arg_index,
|
|
||||||
)
|
|
||||||
|
|
||||||
# format the object and append to the result
|
# format the object and append to the result
|
||||||
formatted_text = self.format_field(obj, str(format_control))
|
formatted_text = (
|
||||||
|
self.format_field(obj, format_spec) if format_spec else obj
|
||||||
|
)
|
||||||
results.append(formatted_text)
|
results.append(formatted_text)
|
||||||
|
|
||||||
return functools.reduce(self._add, results), auto_arg_index
|
return functools.reduce(self._add, results), auto_arg_index
|
||||||
|
@ -32,6 +32,26 @@ def test_template_message():
|
|||||||
assert str(formatted) == "custom-custom!text[fake:image]"
|
assert str(formatted) == "custom-custom!text[fake:image]"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rich_template_message():
|
||||||
|
Message = make_fake_message()
|
||||||
|
MS = Message.get_segment_class()
|
||||||
|
|
||||||
|
pic1, pic2, pic3 = (
|
||||||
|
MS.image("file:///pic1.jpg"),
|
||||||
|
MS.image("file:///pic2.jpg"),
|
||||||
|
MS.image("file:///pic3.jpg"),
|
||||||
|
)
|
||||||
|
|
||||||
|
template = Message.template("{}{}" + pic2 + "{}")
|
||||||
|
|
||||||
|
result = template.format(pic1, "[fake:image]", pic3)
|
||||||
|
|
||||||
|
assert result["image"] == Message([pic1, pic2, pic3])
|
||||||
|
assert str(result) == (
|
||||||
|
"[fake:image]" + escape_text("[fake:image]") + "[fake:image]" + "[fake:image]"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_message_injection():
|
def test_message_injection():
|
||||||
Message = make_fake_message()
|
Message = make_fake_message()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user