🎨 change typing for formatter

This commit is contained in:
yanyongyu 2021-08-27 14:46:15 +08:00
parent f0bc47ec5e
commit 58d10abd32
2 changed files with 37 additions and 46 deletions

View File

@ -8,19 +8,19 @@
import abc import abc
import asyncio import asyncio
from copy import deepcopy from copy import deepcopy
from dataclasses import asdict, dataclass, field
from functools import partial from functools import partial
from typing import (Any, Dict, Generic, Iterable, List, Mapping, Optional, Set, from typing_extensions import Protocol
Tuple, Type, TypeVar, Union) from dataclasses import asdict, dataclass, field
from typing import (Any, Set, List, Dict, Type, Tuple, Union, TypeVar, Mapping,
Generic, Optional, Iterable)
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import Protocol
from nonebot.config import Config
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse
from nonebot.log import logger from nonebot.log import logger
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook from nonebot.config import Config
from nonebot.utils import DataclassEncoder from nonebot.utils import DataclassEncoder
from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse
from ._formatter import MessageFormatter from ._formatter import MessageFormatter
@ -332,7 +332,9 @@ class Message(List[TMS], abc.ABC):
self.extend(self._construct(message)) self.extend(self._construct(message))
@classmethod @classmethod
def template(cls: Type[TM], format_string: str) -> MessageFormatter[TM]: def template(
cls: Type[TM],
format_string: str) -> MessageFormatter[TM, TMS]: # type: ignore
return MessageFormatter(cls, format_string) return MessageFormatter(cls, format_string)
@classmethod @classmethod

View File

@ -1,59 +1,49 @@
import functools import functools
import operator import operator
from string import Formatter from string import Formatter
from typing import (Any, Generic, List, Mapping, Protocol, Sequence, Set, Tuple, from typing import (Any, Set, List, Type, Tuple, Union, TypeVar, Mapping,
Type, TypeVar, Union, TYPE_CHECKING) Generic, Sequence, TYPE_CHECKING)
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.adapters import Message from . import Message, MessageSegment
TM = TypeVar("TM", bound="Message")
TMS = TypeVar("TMS", bound="MessageSegment")
TAddable = Union[str, TM, TMS]
class AddAble(Protocol): class MessageFormatter(Formatter, Generic[TM, TMS]):
def __add__(self, __s: Any) -> "AddAble": def __init__(self, factory: Type[TM], template: str) -> None:
...
def __str__(self) -> str:
...
AddAble_T = TypeVar("AddAble_T", bound=AddAble)
MessageResult_T = TypeVar("MessageResult_T", bound="Message", covariant=True)
class MessageFormatter(Formatter, Generic[MessageResult_T]):
def __init__(self, factory: Type[MessageResult_T], template: str) -> None:
super().__init__()
self.template = template self.template = template
self.factory = factory self.factory = factory
def format(self, *args: AddAble, **kwargs: AddAble) -> MessageResult_T: def format(self, *args: TAddable[TM, TMS], **kwargs: TAddable[TM,
msg: AddAble = super().format(self.template, *args, **kwargs) TMS]) -> TM:
return msg if isinstance(msg, self.factory) else self.factory( msg = self.vformat(self.template, args, kwargs)
msg) # type: ignore return msg if isinstance(msg, self.factory) else self.factory(msg)
def vformat(self, format_string: str, args: Sequence[AddAble], def vformat(self, format_string: str, args: Sequence[TAddable[TM, TMS]],
kwargs: Mapping[str, AddAble]): kwargs: Mapping[str, TAddable[TM, TMS]]) -> TM:
result, arg_index, used_args = self._vformat(format_string, args, used_args = set()
kwargs, set(), 2) result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
self.check_unused_args(list(used_args), args, kwargs) self.check_unused_args(list(used_args), args, kwargs)
return result return result
def _vformat( def _vformat(
self, self,
format_string: str, format_string: str,
args: Sequence[Any], args: Sequence[TAddable[TM, TMS]],
kwargs: Mapping[str, Any], kwargs: Mapping[str, TAddable[TM, TMS]],
used_args: Set[Union[int, str]], used_args: Set[Union[int, str]],
recursion_depth: int, recursion_depth: int,
auto_arg_index: int = 0, auto_arg_index: int = 0,
) -> Tuple[AddAble, int, Set[Union[int, str]]]: ) -> Tuple[TM, int]:
if recursion_depth < 0: if recursion_depth < 0:
raise ValueError("Max string recursion exceeded") raise ValueError("Max string recursion exceeded")
results: List[AddAble] = [] results: List[TAddable[TM, TMS]] = []
for (literal_text, field_name, format_spec, for (literal_text, field_name, format_spec,
conversion) in self.parse(format_string): conversion) in self.parse(format_string):
@ -95,24 +85,23 @@ class MessageFormatter(Formatter, Generic[MessageResult_T]):
obj = self.convert_field(obj, conversion) if conversion else obj obj = self.convert_field(obj, conversion) if conversion else obj
# expand the format spec, if needed # expand the format spec, if needed
format_control, auto_arg_index, formatted_args = self._vformat( format_control, auto_arg_index = self._vformat(
format_spec, format_spec,
args, args,
kwargs, kwargs,
used_args.copy(), used_args,
recursion_depth - 1, recursion_depth - 1,
auto_arg_index, auto_arg_index,
) )
used_args |= formatted_args
# format the object and append to the result # format the object and append to the result
formatted_text = self.format_field(obj, str(format_control)) formatted_text = self.format_field(obj, str(format_control))
results.append(formatted_text) results.append(formatted_text)
return functools.reduce(operator.add, results or return self.factory(functools.reduce(operator.add, results or
[""]), auto_arg_index, used_args [""])), auto_arg_index
def format_field(self, value: AddAble_T, def format_field(self, value: TAddable[TM, TMS],
format_spec: str) -> Union[AddAble_T, str]: format_spec: str) -> TAddable[TM, TMS]:
return super().format_field(value, return super().format_field(value,
format_spec) if format_spec else value format_spec) if format_spec else value