mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-07-16 02:50:48 +00:00
✨ Feature: 命令匹配支持强制指定空白符 (#1748)
This commit is contained in:
@ -39,8 +39,8 @@ from nonebot.log import logger
|
||||
from nonebot.typing import T_State
|
||||
from nonebot.exception import ParserExit
|
||||
from nonebot.internal.rule import Rule as Rule
|
||||
from nonebot.params import Command, EventToMe, CommandArg
|
||||
from nonebot.adapters import Bot, Event, Message, MessageSegment
|
||||
from nonebot.params import Command, EventToMe, CommandArg, CommandWhitespace
|
||||
from nonebot.consts import (
|
||||
CMD_KEY,
|
||||
REGEX_STR,
|
||||
@ -57,6 +57,7 @@ from nonebot.consts import (
|
||||
FULLMATCH_KEY,
|
||||
REGEX_MATCHED,
|
||||
STARTSWITH_KEY,
|
||||
CMD_WHITESPACE_KEY,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
@ -68,6 +69,7 @@ CMD_RESULT = TypedDict(
|
||||
"raw_command": Optional[str],
|
||||
"command_arg": Optional[Message[MessageSegment]],
|
||||
"command_start": Optional[str],
|
||||
"command_whitespace": Optional[str],
|
||||
},
|
||||
)
|
||||
|
||||
@ -91,7 +93,11 @@ class TrieRule:
|
||||
@classmethod
|
||||
def get_value(cls, bot: Bot, event: Event, state: T_State) -> CMD_RESULT:
|
||||
prefix = CMD_RESULT(
|
||||
command=None, raw_command=None, command_arg=None, command_start=None
|
||||
command=None,
|
||||
raw_command=None,
|
||||
command_arg=None,
|
||||
command_start=None,
|
||||
command_whitespace=None,
|
||||
)
|
||||
state[PREFIX_KEY] = prefix
|
||||
if event.get_type() != "message":
|
||||
@ -106,11 +112,25 @@ class TrieRule:
|
||||
prefix[RAW_CMD_KEY] = pf.key
|
||||
prefix[CMD_START_KEY] = value.command_start
|
||||
prefix[CMD_KEY] = value.command
|
||||
|
||||
msg = message.copy()
|
||||
msg.pop(0)
|
||||
new_message = msg.__class__(segment_text[len(pf.key) :].lstrip())
|
||||
for new_segment in reversed(new_message):
|
||||
msg.insert(0, new_segment)
|
||||
|
||||
# check whitespace
|
||||
arg_str = segment_text[len(pf.key) :]
|
||||
arg_str_stripped = arg_str.lstrip()
|
||||
has_arg = arg_str_stripped or msg
|
||||
if (
|
||||
has_arg
|
||||
and (stripped_len := len(arg_str) - len(arg_str_stripped)) > 0
|
||||
):
|
||||
prefix[CMD_WHITESPACE_KEY] = arg_str[:stripped_len]
|
||||
|
||||
# construct command arg
|
||||
if arg_str_stripped:
|
||||
new_message = msg.__class__(arg_str_stripped)
|
||||
for new_segment in reversed(new_message):
|
||||
msg.insert(0, new_segment)
|
||||
prefix[CMD_ARG_KEY] = msg
|
||||
|
||||
return prefix
|
||||
@ -339,12 +359,18 @@ class CommandRule:
|
||||
|
||||
参数:
|
||||
cmds: 指定命令元组列表
|
||||
force_whitespace: 是否强制命令后必须有指定空白符
|
||||
"""
|
||||
|
||||
__slots__ = ("cmds",)
|
||||
__slots__ = ("cmds", "force_whitespace")
|
||||
|
||||
def __init__(self, cmds: List[Tuple[str, ...]]):
|
||||
def __init__(
|
||||
self,
|
||||
cmds: List[Tuple[str, ...]],
|
||||
force_whitespace: Optional[Union[str, bool]] = None,
|
||||
):
|
||||
self.cmds = tuple(cmds)
|
||||
self.force_whitespace = force_whitespace
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Command(cmds={self.cmds})"
|
||||
@ -357,11 +383,24 @@ class CommandRule:
|
||||
def __hash__(self) -> int:
|
||||
return hash((frozenset(self.cmds),))
|
||||
|
||||
async def __call__(self, cmd: Optional[Tuple[str, ...]] = Command()) -> bool:
|
||||
return cmd in self.cmds
|
||||
async def __call__(
|
||||
self,
|
||||
cmd: Optional[Tuple[str, ...]] = Command(),
|
||||
cmd_whitespace: Optional[str] = CommandWhitespace(),
|
||||
) -> bool:
|
||||
if cmd not in self.cmds:
|
||||
return False
|
||||
if self.force_whitespace is None:
|
||||
return True
|
||||
if isinstance(self.force_whitespace, str):
|
||||
return self.force_whitespace == cmd_whitespace
|
||||
return self.force_whitespace == (cmd_whitespace is not None)
|
||||
|
||||
|
||||
def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
|
||||
def command(
|
||||
*cmds: Union[str, Tuple[str, ...]],
|
||||
force_whitespace: Optional[Union[str, bool]] = None,
|
||||
) -> Rule:
|
||||
"""匹配消息命令。
|
||||
|
||||
根据配置里提供的 {ref}``command_start` <nonebot.config.Config.command_start>`,
|
||||
@ -373,6 +412,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
|
||||
|
||||
参数:
|
||||
cmds: 命令文本或命令元组
|
||||
force_whitespace: 是否强制命令后必须有指定空白符
|
||||
|
||||
用法:
|
||||
使用默认 `command_start`, `command_sep` 配置
|
||||
@ -404,7 +444,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
|
||||
f"{start}{sep.join(command)}", TRIE_VALUE(start, command)
|
||||
)
|
||||
|
||||
return Rule(CommandRule(commands))
|
||||
return Rule(CommandRule(commands, force_whitespace))
|
||||
|
||||
|
||||
class ArgumentParser(ArgParser):
|
||||
|
Reference in New Issue
Block a user