mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-11-04 00:46:43 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			185 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			185 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import functools
 | 
						|
from string import Formatter
 | 
						|
from typing import (
 | 
						|
    TYPE_CHECKING,
 | 
						|
    Any,
 | 
						|
    Set,
 | 
						|
    Dict,
 | 
						|
    List,
 | 
						|
    Type,
 | 
						|
    Tuple,
 | 
						|
    Union,
 | 
						|
    Generic,
 | 
						|
    Mapping,
 | 
						|
    TypeVar,
 | 
						|
    Callable,
 | 
						|
    Optional,
 | 
						|
    Sequence,
 | 
						|
    cast,
 | 
						|
    overload,
 | 
						|
)
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    from .message import Message, MessageSegment
 | 
						|
 | 
						|
# Type variable for the concrete ``Message`` subclass used as a factory.
TM = TypeVar("TM", bound="Message")
# Result type produced by a template: a plain ``str`` or a ``Message``.
TF = TypeVar("TF", str, "Message")

# A custom format spec: called with the field value, returns its rendering.
FormatSpecFunc = Callable[[Any], str]
FormatSpecFunc_T = TypeVar("FormatSpecFunc_T", bound=FormatSpecFunc)


class MessageTemplate(Formatter, Generic[TF]):
    """Message template formatter.

    Extends :class:`string.Formatter` so that the result is built with
    ``factory`` (``str`` or a message class) instead of always being a ``str``.

    Args:
        template: The template (a format string or a message instance).
        factory: Message type factory, defaults to ``str``.
    """

    @overload
    def __init__(
        self: "MessageTemplate[str]", template: str, factory: Type[str] = str
    ) -> None:
        ...

    @overload
    def __init__(
        self: "MessageTemplate[TM]", template: Union[str, TM], factory: Type[TM]
    ) -> None:
        ...

    def __init__(self, template, factory=str) -> None:
        self.template: TF = template
        self.factory: Type[TF] = factory
        # Custom format specs registered through ``add_format_spec``.
        self.format_specs: Dict[str, FormatSpecFunc] = {}

    def add_format_spec(
        self, spec: FormatSpecFunc_T, name: Optional[str] = None
    ) -> FormatSpecFunc_T:
        """Register a custom format spec and return it unchanged.

        Returning ``spec`` allows the method to be used as a decorator.

        Args:
            spec: Callable applied to the field value when the spec is used.
            name: Name of the spec inside the template; defaults to
                ``spec.__name__``.

        Raises:
            ValueError: A spec with the same name is already registered.
        """
        name = name or spec.__name__
        if name in self.format_specs:
            raise ValueError(f"Format spec {name} already exists!")
        self.format_specs[name] = spec
        return spec

    def format(self, *args, **kwargs) -> TF:
        """Build the message object from the template and the given arguments."""
        return self._format(args, kwargs)

    def format_map(self, mapping: Mapping[str, Any]) -> TF:
        """Build the message object from the template and a mapping.

        Useful when the field names are not valid Python identifiers.
        """
        return self._format([], mapping)

    def _format(self, args: Sequence[Any], kwargs: Mapping[str, Any]) -> TF:
        """Render the template into a fresh ``factory()`` instance.

        Raises:
            TypeError: The template is neither a ``str`` nor an instance of
                ``factory``.
        """
        msg = self.factory()
        if isinstance(self.template, str):
            msg += self.vformat(self.template, args, kwargs)
        elif isinstance(self.template, self.factory):
            template = cast("Message[MessageSegment]", self.template)
            for seg in template:
                # Only text segments are treated as format strings; every
                # other segment is appended as-is.
                msg += self.vformat(str(seg), args, kwargs) if seg.is_text() else seg
        else:
            raise TypeError("template must be a string or instance of Message!")

        return msg  # type:ignore

    def vformat(
        self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]
    ) -> TF:
        """Format ``format_string`` and return a ``factory``-typed result."""
        used_args: Set[Union[int, str]] = set()
        # Depth 2: the format string itself plus one level of nested
        # replacement fields inside a format spec.
        result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
        self.check_unused_args(list(used_args), args, kwargs)
        return result

    def _vformat(
        self,
        format_string: str,
        args: Sequence[Any],
        kwargs: Mapping[str, Any],
        used_args: Set[Union[int, str]],
        recursion_depth: int,
        auto_arg_index: int = 0,
    ) -> Tuple[TF, int]:
        """Format one level of ``format_string``.

        Args:
            format_string: The string to parse and expand.
            args: Positional arguments available to replacement fields.
            kwargs: Keyword arguments available to replacement fields.
            used_args: Collects the argument keys referenced so far.
            recursion_depth: Remaining nesting budget for format specs.
            auto_arg_index: Next index for ``{}`` fields, or ``False`` once a
                manual index such as ``{0}`` has been used.

        Returns:
            The accumulated result and the updated ``auto_arg_index``.

        Raises:
            ValueError: Nesting is too deep, or automatic and manual field
                numbering are mixed.
        """
        if recursion_depth < 0:
            raise ValueError("Max string recursion exceeded")

        # Seed with an empty factory instance so the reduction below yields
        # the factory type even when the format string is empty.
        results: List[Any] = [self.factory()]

        for literal_text, field_name, format_spec, conversion in self.parse(
            format_string
        ):
            # output the literal text
            if literal_text:
                results.append(literal_text)

            # if there's a field, output it
            if field_name is not None:
                # handle arg indexing when empty field_names are given.
                if field_name == "":
                    if auto_arg_index is False:
                        raise ValueError(
                            "cannot switch from manual field specification to "
                            "automatic field numbering"
                        )
                    field_name = str(auto_arg_index)
                    auto_arg_index += 1
                elif field_name.isdigit():
                    if auto_arg_index:
                        # Automatic numbering is already in use, so this is a
                        # switch from automatic to manual numbering.
                        raise ValueError(
                            "cannot switch from automatic field numbering to "
                            "manual field specification"
                        )
                    # disable auto arg incrementing, if it gets
                    # used later on, then an exception will be raised
                    auto_arg_index = False

                # given the field_name, find the object it references
                # and the argument it came from
                obj, arg_used = self.get_field(field_name, args, kwargs)
                used_args.add(arg_used)

                assert format_spec is not None

                # do any conversion on the resulting object
                obj = self.convert_field(obj, conversion) if conversion else obj

                # expand the format spec, if needed
                format_control, auto_arg_index = self._vformat(
                    format_spec,
                    args,
                    kwargs,
                    used_args,
                    recursion_depth - 1,
                    auto_arg_index,
                )

                # format the object and append to the result
                formatted_text = self.format_field(obj, str(format_control))
                results.append(formatted_text)

        return functools.reduce(self._add, results), auto_arg_index

    def format_field(self, value: Any, format_spec: str) -> Any:
        """Format a single field value.

        Lookup order for ``format_spec``:

        1. a spec registered in ``self.format_specs``;
        2. when ``factory`` is not a ``str`` subclass, a callable attribute of
           ``factory.get_segment_class()`` with that name whose ``__name__``
           does not start with an underscore;
        3. the standard :meth:`string.Formatter.format_field` behaviour.
        """
        formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec)
        if formatter is None and not issubclass(self.factory, str):
            segment_class: Type["MessageSegment"] = self.factory.get_segment_class()
            method = getattr(segment_class, format_spec, None)
            if callable(method):
                # Not every callable carries ``__name__`` (e.g. callable
                # instances); fall back to the spec name in that case.
                method_name = cast(str, getattr(method, "__name__", format_spec))
                if not method_name.startswith("_"):
                    formatter = method
        if formatter is None:
            return super().format_field(value, format_spec)
        return formatter(value)

    def _add(self, a: Any, b: Any) -> Any:
        """Concatenate ``a`` and ``b``, using ``str(b)`` if ``a + b`` is unsupported."""
        try:
            return a + b
        except TypeError:
            return a + str(b)