From 58d10abd325f7cdeef087f9fee9aea7d390dd746 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Fri, 27 Aug 2021 14:46:15 +0800 Subject: [PATCH] :art: change typing for formatter --- nonebot/adapters/_base.py | 18 +++++----- nonebot/adapters/_formatter.py | 65 ++++++++++++++-------------------- 2 files changed, 37 insertions(+), 46 deletions(-) diff --git a/nonebot/adapters/_base.py b/nonebot/adapters/_base.py index 123914e6..387cf316 100644 --- a/nonebot/adapters/_base.py +++ b/nonebot/adapters/_base.py @@ -8,19 +8,19 @@ import abc import asyncio from copy import deepcopy -from dataclasses import asdict, dataclass, field from functools import partial -from typing import (Any, Dict, Generic, Iterable, List, Mapping, Optional, Set, - Tuple, Type, TypeVar, Union) +from typing_extensions import Protocol +from dataclasses import asdict, dataclass, field +from typing import (Any, Set, List, Dict, Type, Tuple, Union, TypeVar, Mapping, + Generic, Optional, Iterable) from pydantic import BaseModel -from typing_extensions import Protocol -from nonebot.config import Config -from nonebot.drivers import Driver, HTTPConnection, HTTPResponse from nonebot.log import logger -from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook +from nonebot.config import Config from nonebot.utils import DataclassEncoder +from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook +from nonebot.drivers import Driver, HTTPConnection, HTTPResponse from ._formatter import MessageFormatter @@ -332,7 +332,9 @@ class Message(List[TMS], abc.ABC): self.extend(self._construct(message)) @classmethod - def template(cls: Type[TM], format_string: str) -> MessageFormatter[TM]: + def template( + cls: Type[TM], + format_string: str) -> MessageFormatter[TM, TMS]: # type: ignore return MessageFormatter(cls, format_string) @classmethod diff --git a/nonebot/adapters/_formatter.py b/nonebot/adapters/_formatter.py index 1c11afa3..9efc004c 100644 --- a/nonebot/adapters/_formatter.py +++ b/nonebot/adapters/_formatter.py @@ -1,59 +1,49 @@ import functools import operator from string import Formatter -from typing import (Any, Generic, List, Mapping, Protocol, Sequence, Set, Tuple, - Type, TypeVar, Union, TYPE_CHECKING) +from typing import (Any, Set, List, Type, Tuple, Union, TypeVar, Mapping, + Generic, Sequence, TYPE_CHECKING) if TYPE_CHECKING: - from nonebot.adapters import Message + from . import Message, MessageSegment + +TM = TypeVar("TM", bound="Message") +TMS = TypeVar("TMS", bound="MessageSegment") +TAddable = Union[str, TM, TMS] -class AddAble(Protocol): +class MessageFormatter(Formatter, Generic[TM, TMS]): - def __add__(self, __s: Any) -> "AddAble": - ... - - def __str__(self) -> str: - ... - - -AddAble_T = TypeVar("AddAble_T", bound=AddAble) -MessageResult_T = TypeVar("MessageResult_T", bound="Message", covariant=True) - - -class MessageFormatter(Formatter, Generic[MessageResult_T]): - - def __init__(self, factory: Type[MessageResult_T], template: str) -> None: - super().__init__() + def __init__(self, factory: Type[TM], template: str) -> None: self.template = template self.factory = factory - def format(self, *args: AddAble, **kwargs: AddAble) -> MessageResult_T: - msg: AddAble = super().format(self.template, *args, **kwargs) - return msg if isinstance(msg, self.factory) else self.factory( - msg) # type: ignore + def format(self, *args: TAddable[TM, TMS], **kwargs: TAddable[TM, + TMS]) -> TM: + msg = self.vformat(self.template, args, kwargs) + return msg if isinstance(msg, self.factory) else self.factory(msg) - def vformat(self, format_string: str, args: Sequence[AddAble], - kwargs: Mapping[str, AddAble]): - result, arg_index, used_args = self._vformat(format_string, args, - kwargs, set(), 2) + def vformat(self, format_string: str, args: Sequence[TAddable[TM, TMS]], + kwargs: Mapping[str, TAddable[TM, TMS]]) -> TM: + used_args = set() + result, _ = self._vformat(format_string, args, kwargs, used_args, 2) self.check_unused_args(list(used_args), args, kwargs) return result def _vformat( self, format_string: str, - args: Sequence[Any], - kwargs: Mapping[str, Any], + args: Sequence[TAddable[TM, TMS]], + kwargs: Mapping[str, TAddable[TM, TMS]], used_args: Set[Union[int, str]], recursion_depth: int, auto_arg_index: int = 0, - ) -> Tuple[AddAble, int, Set[Union[int, str]]]: + ) -> Tuple[TM, int]: if recursion_depth < 0: raise ValueError("Max string recursion exceeded") - results: List[AddAble] = [] + results: List[TAddable[TM, TMS]] = [] for (literal_text, field_name, format_spec, conversion) in self.parse(format_string): @@ -95,24 +85,23 @@ class MessageFormatter(Formatter, Generic[MessageResult_T]): obj = self.convert_field(obj, conversion) if conversion else obj # expand the format spec, if needed - format_control, auto_arg_index, formatted_args = self._vformat( + format_control, auto_arg_index = self._vformat( format_spec, args, kwargs, - used_args.copy(), + used_args, recursion_depth - 1, auto_arg_index, ) - used_args |= formatted_args # format the object and append to the result formatted_text = self.format_field(obj, str(format_control)) results.append(formatted_text) - return functools.reduce(operator.add, results or - [""]), auto_arg_index, used_args + return self.factory(functools.reduce(operator.add, results or + [""])), auto_arg_index - def format_field(self, value: AddAble_T, - format_spec: str) -> Union[AddAble_T, str]: + def format_field(self, value: TAddable[TM, TMS], + format_spec: str) -> TAddable[TM, TMS]: return super().format_field(value, format_spec) if format_spec else value