⚗️ change rule to use handler

This commit is contained in:
yanyongyu
2021-11-19 18:18:53 +08:00
parent ee619a33a9
commit 471d306e13
8 changed files with 182 additions and 148 deletions

View File

@ -14,16 +14,18 @@ import shlex
import asyncio
from itertools import product
from argparse import Namespace
from contextlib import AsyncExitStack
from typing_extensions import TypedDict
from argparse import ArgumentParser as ArgParser
from typing import (Any, Tuple, Union, Callable, NoReturn, Optional, Sequence,
Awaitable)
from typing import (Any, Dict, Tuple, Union, Callable, NoReturn, Optional,
Sequence, Awaitable)
from pygtrie import CharTrie
from nonebot import get_driver
from nonebot.log import logger
from nonebot.utils import run_sync
from nonebot.handler import Handler
from nonebot import params, get_driver
from nonebot.exception import ParserExit
from nonebot.typing import T_State, T_RuleChecker
from nonebot.adapters import Bot, Event, MessageSegment
@ -62,16 +64,22 @@ class Rule:
"""
__slots__ = ("checkers",)
def __init__(
self, *checkers: Callable[[Bot, Event, T_State],
Awaitable[bool]]) -> None:
HANDLER_PARAM_TYPES = [
params.BotParam, params.EventParam, params.StateParam
]
def __init__(self, *checkers: T_RuleChecker) -> None:
"""
:参数:
* ``*checkers: Callable[[Bot, Event, T_State], Awaitable[bool]]``: **异步** RuleChecker
* ``*checkers: T_RuleChecker``: RuleChecker
"""
self.checkers = set(checkers)
self.checkers = set(
Handler(checker,
allow_types=self.HANDLER_PARAM_TYPES,
dependency_overrides_provider=get_driver())
for checker in checkers)
"""
:说明:
@ -79,10 +87,17 @@ class Rule:
:类型:
* ``Set[Callable[[Bot, Event, T_State], Awaitable[bool]]]``
* ``Set[Handler]``
"""
async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool:
async def __call__(
self,
bot: Bot,
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None) -> bool:
"""
:说明:
@ -99,19 +114,21 @@ class Rule:
- ``bool``
"""
results = await asyncio.gather(
*map(lambda c: c(bot, event, state), self.checkers))
checker(bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache)
for checker in self.checkers)
return all(results)
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
checkers = self.checkers.copy()
if other is None:
return self
elif isinstance(other, Rule):
checkers |= other.checkers
elif asyncio.iscoroutinefunction(other):
checkers.add(other) # type: ignore
checkers = [*self.checkers, *other.checkers]
else:
checkers.add(run_sync(other))
checkers = [*self.checkers, other]
return Rule(*checkers)
def __or__(self, other) -> NoReturn:
@ -226,7 +243,7 @@ def keyword(*keywords: str) -> Rule:
* ``*keywords: str``: 关键词
"""
async def _keyword(bot: Bot, event: Event, state: T_State) -> bool:
async def _keyword(event: Event) -> bool:
if event.get_type() != "message":
return False
text = event.get_plaintext()
@ -274,7 +291,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _command(bot: Bot, event: Event, state: T_State) -> bool:
async def _command(state: T_State) -> bool:
return state[PREFIX_KEY][CMD_KEY] in commands
return Rule(_command)
@ -294,7 +311,7 @@ class ArgumentParser(ArgParser):
old_message += message
setattr(self, "message", old_message)
def exit(self, status=0, message=None):
def exit(self, status: int = 0, message: Optional[str] = None):
raise ParserExit(status=status,
message=message or getattr(self, "message", None))
@ -360,7 +377,7 @@ def shell_command(*cmds: Union[str, Tuple[str, ...]],
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _shell_command(bot: Bot, event: Event, state: T_State) -> bool:
async def _shell_command(event: Event, state: T_State) -> bool:
if state[PREFIX_KEY][CMD_KEY] in commands:
message = str(event.get_message())
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]
@ -400,7 +417,7 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
pattern = re.compile(regex, flags)
async def _regex(bot: Bot, event: Event, state: T_State) -> bool:
async def _regex(event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
matched = pattern.search(str(event.get_message()))
@ -415,6 +432,10 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
return Rule(_regex)
async def _to_me(event: Event) -> bool:
return event.is_tome()
def to_me() -> Rule:
"""
:说明:
@ -426,7 +447,4 @@ def to_me() -> Rule:
* 无
"""
async def _to_me(bot: Bot, event: Event, state: T_State) -> bool:
return event.is_tome()
return Rule(_to_me)