🎨 update typing support

This commit is contained in:
yanyongyu
2020-12-06 02:30:19 +08:00
parent 9ab7176eaf
commit 629eed08b6
26 changed files with 247 additions and 205 deletions

View File

@ -12,13 +12,17 @@
import re
import asyncio
from itertools import product
from typing import Any, Dict, Union, Tuple, Optional, Callable, NoReturn, Awaitable, TYPE_CHECKING
from pygtrie import CharTrie
from nonebot import get_driver
from nonebot.log import logger
from nonebot.utils import run_sync
from nonebot.typing import Bot, Any, Dict, Event, Union, Tuple, NoReturn, Optional, Callable, Awaitable, RuleChecker
from nonebot.typing import State, RuleChecker
if TYPE_CHECKING:
from nonebot.adapters import BaseBot as Bot, BaseEvent as Event
class Rule:
@ -39,12 +43,12 @@ class Rule:
__slots__ = ("checkers",)
def __init__(
self, *checkers: Callable[[Bot, Event, dict],
self, *checkers: Callable[["Bot", "Event", State],
Awaitable[bool]]) -> None:
"""
:参数:
* ``*checkers: Callable[[Bot, Event, dict], Awaitable[bool]]``: **异步** RuleChecker
* ``*checkers: Callable[[Bot, Event, State], Awaitable[bool]]``: **异步** RuleChecker
"""
self.checkers = set(checkers)
@ -55,10 +59,10 @@ class Rule:
:类型:
* ``Set[Callable[[Bot, Event, dict], Awaitable[bool]]]``
* ``Set[Callable[[Bot, Event, State], Awaitable[bool]]]``
"""
async def __call__(self, bot: Bot, event: Event, state: dict) -> bool:
async def __call__(self, bot: "Bot", event: "Event", state: State) -> bool:
"""
:说明:
@ -68,7 +72,7 @@ class Rule:
* ``bot: Bot``: Bot 对象
* ``event: Event``: Event 对象
* ``state: dict``: 当前 State
* ``state: State``: 当前 State
:返回:
@ -113,8 +117,8 @@ class TrieRule:
cls.suffix[suffix[::-1]] = value
@classmethod
def get_value(cls, bot: Bot, event: Event,
state: dict) -> Tuple[Dict[str, Any], Dict[str, Any]]:
def get_value(cls, bot: "Bot", event: "Event",
state: State) -> Tuple[Dict[str, Any], Dict[str, Any]]:
if event.type != "message":
state["_prefix"] = {"raw_command": None, "command": None}
state["_suffix"] = {"raw_command": None, "command": None}
@ -176,7 +180,7 @@ def startswith(msg: str) -> Rule:
* ``msg: str``: 消息开头字符串
"""
async def _startswith(bot: Bot, event: Event, state: dict) -> bool:
async def _startswith(bot: "Bot", event: "Event", state: State) -> bool:
return event.plain_text.startswith(msg)
return Rule(_startswith)
@ -193,7 +197,7 @@ def endswith(msg: str) -> Rule:
* ``msg: str``: 消息结尾字符串
"""
async def _endswith(bot: Bot, event: Event, state: dict) -> bool:
async def _endswith(bot: "Bot", event: "Event", state: State) -> bool:
return event.plain_text.endswith(msg)
return Rule(_endswith)
@ -210,7 +214,7 @@ def keyword(*keywords: str) -> Rule:
* ``*keywords: str``: 关键词
"""
async def _keyword(bot: Bot, event: Event, state: dict) -> bool:
async def _keyword(bot: "Bot", event: "Event", state: State) -> bool:
return bool(event.plain_text and
any(keyword in event.plain_text for keyword in keywords))
@ -256,7 +260,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _command(bot: Bot, event: Event, state: dict) -> bool:
async def _command(bot: "Bot", event: "Event", state: State) -> bool:
return state["_prefix"]["command"] in commands
return Rule(_command)
@ -282,7 +286,7 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
pattern = re.compile(regex, flags)
async def _regex(bot: Bot, event: Event, state: dict) -> bool:
async def _regex(bot: "Bot", event: "Event", state: State) -> bool:
matched = pattern.search(str(event.message))
if matched:
state["_matched"] = matched.group()
@ -305,7 +309,7 @@ def to_me() -> Rule:
* 无
"""
async def _to_me(bot: Bot, event: Event, state: dict) -> bool:
async def _to_me(bot: "Bot", event: "Event", state: State) -> bool:
return bool(event.to_me)
return Rule(_to_me)