🐛 Fix MessageTemplate improper behavior when no format spec (#947)

* 🧪 Add a test to figure out bug in #938

* ♻️ 🐛 Refactor rich message template formatting, fix #938
This commit is contained in:
Mix
2022-04-30 09:59:23 +08:00
committed by GitHub
parent f028575f2f
commit 95331bbb22
2 changed files with 47 additions and 27 deletions

View File

@ -49,7 +49,9 @@ class MessageTemplate(Formatter, Generic[TF]):
) -> None:
...
def __init__(self, template, factory=str) -> None:
def __init__( # type:ignore
self, template, factory=str
) -> None: # TODO: fix type hint here
self.template: TF = template
self.factory: Type[TF] = factory
self.format_specs: Dict[str, FormatSpecFunc] = {}
@ -72,25 +74,37 @@ class MessageTemplate(Formatter, Generic[TF]):
return self._format([], mapping)
def _format(self, args: Sequence[Any], kwargs: Mapping[str, Any]) -> TF:
msg = self.factory()
full_message = self.factory()
used_args, arg_index = set(), 0
if isinstance(self.template, str):
msg += self.vformat(self.template, args, kwargs)
msg, arg_index = self._vformat(
self.template, args, kwargs, used_args, arg_index
)
full_message += msg
elif isinstance(self.template, self.factory):
template = cast("Message[MessageSegment]", self.template)
for seg in template:
msg += self.vformat(str(seg), args, kwargs) if seg.is_text() else seg
if not seg.is_text():
full_message += seg
else:
msg, arg_index = self._vformat(
str(seg), args, kwargs, used_args, arg_index
)
full_message += msg
else:
raise TypeError("template must be a string or instance of Message!")
return msg # type:ignore
self.check_unused_args(list(used_args), args, kwargs)
return cast(TF, full_message)
def vformat(
self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]
self,
format_string: str,
args: Sequence[Any],
kwargs: Mapping[str, Any],
) -> TF:
used_args = set()
result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
self.check_unused_args(list(used_args), args, kwargs)
return result
raise NotImplementedError("`vformat` has merged into `_format`")
def _vformat(
self,
@ -98,12 +112,8 @@ class MessageTemplate(Formatter, Generic[TF]):
args: Sequence[Any],
kwargs: Mapping[str, Any],
used_args: Set[Union[int, str]],
recursion_depth: int,
auto_arg_index: int = 0,
) -> Tuple[TF, int]:
if recursion_depth < 0:
raise ValueError("Max string recursion exceeded")
results: List[Any] = [self.factory()]
for (literal_text, field_name, format_spec, conversion) in self.parse(
@ -143,23 +153,13 @@ class MessageTemplate(Formatter, Generic[TF]):
obj, arg_used = self.get_field(field_name, args, kwargs)
used_args.add(arg_used)
assert format_spec is not None
# do any conversion on the resulting object
obj = self.convert_field(obj, conversion) if conversion else obj
# expand the format spec, if needed
format_control, auto_arg_index = self._vformat(
format_spec,
args,
kwargs,
used_args,
recursion_depth - 1,
auto_arg_index,
)
# format the object and append to the result
formatted_text = self.format_field(obj, str(format_control))
formatted_text = (
self.format_field(obj, format_spec) if format_spec else obj
)
results.append(formatted_text)
return functools.reduce(self._add, results), auto_arg_index