mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-06-14 08:08:16 +00:00
🎨 change typing for formatter
This commit is contained in:
parent
f0bc47ec5e
commit
58d10abd32
@ -8,19 +8,19 @@
|
|||||||
import abc
|
import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import asdict, dataclass, field
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import (Any, Dict, Generic, Iterable, List, Mapping, Optional, Set,
|
from typing_extensions import Protocol
|
||||||
Tuple, Type, TypeVar, Union)
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from typing import (Any, Set, List, Dict, Type, Tuple, Union, TypeVar, Mapping,
|
||||||
|
Generic, Optional, Iterable)
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import Protocol
|
|
||||||
|
|
||||||
from nonebot.config import Config
|
|
||||||
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook
|
from nonebot.config import Config
|
||||||
from nonebot.utils import DataclassEncoder
|
from nonebot.utils import DataclassEncoder
|
||||||
|
from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook
|
||||||
|
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse
|
||||||
|
|
||||||
from ._formatter import MessageFormatter
|
from ._formatter import MessageFormatter
|
||||||
|
|
||||||
@ -332,7 +332,9 @@ class Message(List[TMS], abc.ABC):
|
|||||||
self.extend(self._construct(message))
|
self.extend(self._construct(message))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def template(cls: Type[TM], format_string: str) -> MessageFormatter[TM]:
|
def template(
|
||||||
|
cls: Type[TM],
|
||||||
|
format_string: str) -> MessageFormatter[TM, TMS]: # type: ignore
|
||||||
return MessageFormatter(cls, format_string)
|
return MessageFormatter(cls, format_string)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -1,59 +1,49 @@
|
|||||||
import functools
|
import functools
|
||||||
import operator
|
import operator
|
||||||
from string import Formatter
|
from string import Formatter
|
||||||
from typing import (Any, Generic, List, Mapping, Protocol, Sequence, Set, Tuple,
|
from typing import (Any, Set, List, Type, Tuple, Union, TypeVar, Mapping,
|
||||||
Type, TypeVar, Union, TYPE_CHECKING)
|
Generic, Sequence, TYPE_CHECKING)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nonebot.adapters import Message
|
from . import Message, MessageSegment
|
||||||
|
|
||||||
|
TM = TypeVar("TM", bound="Message")
|
||||||
|
TMS = TypeVar("TMS", bound="MessageSegment")
|
||||||
|
TAddable = Union[str, TM, TMS]
|
||||||
|
|
||||||
|
|
||||||
class AddAble(Protocol):
|
class MessageFormatter(Formatter, Generic[TM, TMS]):
|
||||||
|
|
||||||
def __add__(self, __s: Any) -> "AddAble":
|
def __init__(self, factory: Type[TM], template: str) -> None:
|
||||||
...
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
AddAble_T = TypeVar("AddAble_T", bound=AddAble)
|
|
||||||
MessageResult_T = TypeVar("MessageResult_T", bound="Message", covariant=True)
|
|
||||||
|
|
||||||
|
|
||||||
class MessageFormatter(Formatter, Generic[MessageResult_T]):
|
|
||||||
|
|
||||||
def __init__(self, factory: Type[MessageResult_T], template: str) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.template = template
|
self.template = template
|
||||||
self.factory = factory
|
self.factory = factory
|
||||||
|
|
||||||
def format(self, *args: AddAble, **kwargs: AddAble) -> MessageResult_T:
|
def format(self, *args: TAddable[TM, TMS], **kwargs: TAddable[TM,
|
||||||
msg: AddAble = super().format(self.template, *args, **kwargs)
|
TMS]) -> TM:
|
||||||
return msg if isinstance(msg, self.factory) else self.factory(
|
msg = self.vformat(self.template, args, kwargs)
|
||||||
msg) # type: ignore
|
return msg if isinstance(msg, self.factory) else self.factory(msg)
|
||||||
|
|
||||||
def vformat(self, format_string: str, args: Sequence[AddAble],
|
def vformat(self, format_string: str, args: Sequence[TAddable[TM, TMS]],
|
||||||
kwargs: Mapping[str, AddAble]):
|
kwargs: Mapping[str, TAddable[TM, TMS]]) -> TM:
|
||||||
result, arg_index, used_args = self._vformat(format_string, args,
|
used_args = set()
|
||||||
kwargs, set(), 2)
|
result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
|
||||||
self.check_unused_args(list(used_args), args, kwargs)
|
self.check_unused_args(list(used_args), args, kwargs)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _vformat(
|
def _vformat(
|
||||||
self,
|
self,
|
||||||
format_string: str,
|
format_string: str,
|
||||||
args: Sequence[Any],
|
args: Sequence[TAddable[TM, TMS]],
|
||||||
kwargs: Mapping[str, Any],
|
kwargs: Mapping[str, TAddable[TM, TMS]],
|
||||||
used_args: Set[Union[int, str]],
|
used_args: Set[Union[int, str]],
|
||||||
recursion_depth: int,
|
recursion_depth: int,
|
||||||
auto_arg_index: int = 0,
|
auto_arg_index: int = 0,
|
||||||
) -> Tuple[AddAble, int, Set[Union[int, str]]]:
|
) -> Tuple[TM, int]:
|
||||||
|
|
||||||
if recursion_depth < 0:
|
if recursion_depth < 0:
|
||||||
raise ValueError("Max string recursion exceeded")
|
raise ValueError("Max string recursion exceeded")
|
||||||
|
|
||||||
results: List[AddAble] = []
|
results: List[TAddable[TM, TMS]] = []
|
||||||
|
|
||||||
for (literal_text, field_name, format_spec,
|
for (literal_text, field_name, format_spec,
|
||||||
conversion) in self.parse(format_string):
|
conversion) in self.parse(format_string):
|
||||||
@ -95,24 +85,23 @@ class MessageFormatter(Formatter, Generic[MessageResult_T]):
|
|||||||
obj = self.convert_field(obj, conversion) if conversion else obj
|
obj = self.convert_field(obj, conversion) if conversion else obj
|
||||||
|
|
||||||
# expand the format spec, if needed
|
# expand the format spec, if needed
|
||||||
format_control, auto_arg_index, formatted_args = self._vformat(
|
format_control, auto_arg_index = self._vformat(
|
||||||
format_spec,
|
format_spec,
|
||||||
args,
|
args,
|
||||||
kwargs,
|
kwargs,
|
||||||
used_args.copy(),
|
used_args,
|
||||||
recursion_depth - 1,
|
recursion_depth - 1,
|
||||||
auto_arg_index,
|
auto_arg_index,
|
||||||
)
|
)
|
||||||
used_args |= formatted_args
|
|
||||||
|
|
||||||
# format the object and append to the result
|
# format the object and append to the result
|
||||||
formatted_text = self.format_field(obj, str(format_control))
|
formatted_text = self.format_field(obj, str(format_control))
|
||||||
results.append(formatted_text)
|
results.append(formatted_text)
|
||||||
|
|
||||||
return functools.reduce(operator.add, results or
|
return self.factory(functools.reduce(operator.add, results or
|
||||||
[""]), auto_arg_index, used_args
|
[""])), auto_arg_index
|
||||||
|
|
||||||
def format_field(self, value: AddAble_T,
|
def format_field(self, value: TAddable[TM, TMS],
|
||||||
format_spec: str) -> Union[AddAble_T, str]:
|
format_spec: str) -> TAddable[TM, TMS]:
|
||||||
return super().format_field(value,
|
return super().format_field(value,
|
||||||
format_spec) if format_spec else value
|
format_spec) if format_spec else value
|
||||||
|
Loading…
x
Reference in New Issue
Block a user