Feat: 添加 CommandStart 依赖注入参数 (#915)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: yanyongyu <42488585+yanyongyu@users.noreply.github.com>
This commit is contained in:
MeetWq
2022-04-20 14:43:29 +08:00
committed by GitHub
parent f989710cd6
commit 533e99418c
6 changed files with 74 additions and 11 deletions

View File

@ -14,7 +14,7 @@ from itertools import product
from argparse import Namespace
from typing_extensions import TypedDict
from argparse import ArgumentParser as ArgParser
from typing import Any, List, Tuple, Union, Optional, Sequence
from typing import Any, List, Tuple, Union, Optional, Sequence, NamedTuple
from pygtrie import CharTrie
@ -41,6 +41,7 @@ from nonebot.consts import (
CMD_ARG_KEY,
RAW_CMD_KEY,
REGEX_GROUP,
CMD_START_KEY,
REGEX_MATCHED,
)
@ -50,15 +51,20 @@ CMD_RESULT = TypedDict(
"command": Optional[Tuple[str, ...]],
"raw_command": Optional[str],
"command_arg": Optional[Message[MessageSegment]],
"command_start": Optional[str],
},
)
TRIE_VALUE = NamedTuple(
"TRIE_VALUE", [("command_start", str), ("command", Tuple[str, ...])]
)
class TrieRule:
prefix: CharTrie = CharTrie()
@classmethod
def add_prefix(cls, prefix: str, value: Any):
def add_prefix(cls, prefix: str, value: TRIE_VALUE) -> None:
if prefix in cls.prefix:
logger.warning(f'Duplicated prefix rule "{prefix}"')
return
@ -66,7 +72,9 @@ class TrieRule:
@classmethod
def get_value(cls, bot: Bot, event: Event, state: T_State) -> CMD_RESULT:
prefix = CMD_RESULT(command=None, raw_command=None, command_arg=None)
prefix = CMD_RESULT(
command=None, raw_command=None, command_arg=None, command_start=None
)
state[PREFIX_KEY] = prefix
if event.get_type() != "message":
return prefix
@ -76,9 +84,11 @@ class TrieRule:
if message_seg.is_text():
segment_text = str(message_seg).lstrip()
pf = cls.prefix.longest_prefix(segment_text)
prefix[RAW_CMD_KEY] = pf.key
prefix[CMD_KEY] = pf.value
if pf.key:
if pf:
value: TRIE_VALUE = pf.value
prefix[RAW_CMD_KEY] = pf.key
prefix[CMD_START_KEY] = value.command_start
prefix[CMD_KEY] = value.command
msg = message.copy()
msg.pop(0)
new_message = msg.__class__(segment_text[len(pf.key) :].lstrip())
@ -292,10 +302,12 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
if len(command) == 1:
for start in command_start:
TrieRule.add_prefix(f"{start}{command[0]}", command)
TrieRule.add_prefix(f"{start}{command[0]}", TRIE_VALUE(start, command))
else:
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
TrieRule.add_prefix(
f"{start}{sep.join(command)}", TRIE_VALUE(start, command)
)
return Rule(CommandRule(commands))
@ -416,10 +428,12 @@ def shell_command(
if len(command) == 1:
for start in command_start:
TrieRule.add_prefix(f"{start}{command[0]}", command)
TrieRule.add_prefix(f"{start}{command[0]}", TRIE_VALUE(start, command))
else:
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
TrieRule.add_prefix(
f"{start}{sep.join(command)}", TRIE_VALUE(start, command)
)
return Rule(ShellCommandRule(commands, parser))