🎨 change typing for formatter

This commit is contained in:
yanyongyu
2021-08-27 14:46:15 +08:00
parent f0bc47ec5e
commit 58d10abd32
2 changed files with 37 additions and 46 deletions

View File

@ -1,59 +1,49 @@
import functools
import operator
from string import Formatter
from typing import (Any, Generic, List, Mapping, Protocol, Sequence, Set, Tuple,
Type, TypeVar, Union, TYPE_CHECKING)
from typing import (Any, Set, List, Type, Tuple, Union, TypeVar, Mapping,
Generic, Sequence, TYPE_CHECKING)
if TYPE_CHECKING:
from nonebot.adapters import Message
from . import Message, MessageSegment
TM = TypeVar("TM", bound="Message")
TMS = TypeVar("TMS", bound="MessageSegment")
TAddable = Union[str, TM, TMS]
class AddAble(Protocol):
class MessageFormatter(Formatter, Generic[TM, TMS]):
def __add__(self, __s: Any) -> "AddAble":
...
def __str__(self) -> str:
...
AddAble_T = TypeVar("AddAble_T", bound=AddAble)
MessageResult_T = TypeVar("MessageResult_T", bound="Message", covariant=True)
class MessageFormatter(Formatter, Generic[MessageResult_T]):
def __init__(self, factory: Type[MessageResult_T], template: str) -> None:
super().__init__()
def __init__(self, factory: Type[TM], template: str) -> None:
self.template = template
self.factory = factory
def format(self, *args: AddAble, **kwargs: AddAble) -> MessageResult_T:
msg: AddAble = super().format(self.template, *args, **kwargs)
return msg if isinstance(msg, self.factory) else self.factory(
msg) # type: ignore
def format(self, *args: TAddable[TM, TMS], **kwargs: TAddable[TM,
TMS]) -> TM:
msg = self.vformat(self.template, args, kwargs)
return msg if isinstance(msg, self.factory) else self.factory(msg)
def vformat(self, format_string: str, args: Sequence[AddAble],
kwargs: Mapping[str, AddAble]):
result, arg_index, used_args = self._vformat(format_string, args,
kwargs, set(), 2)
def vformat(self, format_string: str, args: Sequence[TAddable[TM, TMS]],
kwargs: Mapping[str, TAddable[TM, TMS]]) -> TM:
used_args = set()
result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
self.check_unused_args(list(used_args), args, kwargs)
return result
def _vformat(
self,
format_string: str,
args: Sequence[Any],
kwargs: Mapping[str, Any],
args: Sequence[TAddable[TM, TMS]],
kwargs: Mapping[str, TAddable[TM, TMS]],
used_args: Set[Union[int, str]],
recursion_depth: int,
auto_arg_index: int = 0,
) -> Tuple[AddAble, int, Set[Union[int, str]]]:
) -> Tuple[TM, int]:
if recursion_depth < 0:
raise ValueError("Max string recursion exceeded")
results: List[AddAble] = []
results: List[TAddable[TM, TMS]] = []
for (literal_text, field_name, format_spec,
conversion) in self.parse(format_string):
@ -95,24 +85,23 @@ class MessageFormatter(Formatter, Generic[MessageResult_T]):
obj = self.convert_field(obj, conversion) if conversion else obj
# expand the format spec, if needed
format_control, auto_arg_index, formatted_args = self._vformat(
format_control, auto_arg_index = self._vformat(
format_spec,
args,
kwargs,
used_args.copy(),
used_args,
recursion_depth - 1,
auto_arg_index,
)
used_args |= formatted_args
# format the object and append to the result
formatted_text = self.format_field(obj, str(format_control))
results.append(formatted_text)
return functools.reduce(operator.add, results or
[""]), auto_arg_index, used_args
return self.factory(functools.reduce(operator.add, results or
[""])), auto_arg_index
def format_field(self, value: AddAble_T,
format_spec: str) -> Union[AddAble_T, str]:
def format_field(self, value: TAddable[TM, TMS],
format_spec: str) -> TAddable[TM, TMS]:
return super().format_field(value,
format_spec) if format_spec else value