♻️ use class rule and permission

This commit is contained in:
yanyongyu
2021-12-06 10:10:51 +08:00
parent ca4d7397f8
commit 5b75b72720
8 changed files with 202 additions and 135 deletions

View File

@ -203,6 +203,24 @@ class TrieRule:
return prefix, suffix
class Startswith:
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
self.msg = msg
self.ignorecase = ignorecase
async def __call__(self, event: Event) -> Any:
if event.get_type() != "message":
return False
text = event.get_plaintext()
return bool(
re.match(
f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})",
text,
re.IGNORECASE if self.ignorecase else 0,
)
)
def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
"""
:说明:
@ -216,18 +234,25 @@ def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Ru
if isinstance(msg, str):
msg = (msg,)
pattern = re.compile(
f"^(?:{'|'.join(re.escape(prefix) for prefix in msg)})",
re.IGNORECASE if ignorecase else 0,
)
return Rule(Startswith(msg, ignorecase))
async def _startswith(bot: Bot, event: Event, state: T_State) -> bool:
class Endswith:
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
self.msg = msg
self.ignorecase = ignorecase
async def __call__(self, event: Event) -> Any:
if event.get_type() != "message":
return False
text = event.get_plaintext()
return bool(pattern.match(text))
return Rule(_startswith)
return bool(
re.search(
f"(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})$",
text,
re.IGNORECASE if self.ignorecase else 0,
)
)
def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
@ -243,18 +268,18 @@ def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule
if isinstance(msg, str):
msg = (msg,)
pattern = re.compile(
f"(?:{'|'.join(re.escape(prefix) for prefix in msg)})$",
re.IGNORECASE if ignorecase else 0,
)
return Rule(Endswith(msg, ignorecase))
async def _endswith(bot: Bot, event: Event, state: T_State) -> bool:
class Keywords:
def __init__(self, *keywords: str):
self.keywords = keywords
async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
text = event.get_plaintext()
return bool(pattern.search(text))
return Rule(_endswith)
return bool(text and any(keyword in text for keyword in self.keywords))
def keyword(*keywords: str) -> Rule:
@ -268,13 +293,18 @@ def keyword(*keywords: str) -> Rule:
* ``*keywords: str``: 关键词
"""
async def _keyword(event: Event) -> bool:
if event.get_type() != "message":
return False
text = event.get_plaintext()
return bool(text and any(keyword in text for keyword in keywords))
return Rule(Keywords(*keywords))
return Rule(_keyword)
class Command:
def __init__(self, cmds: List[Tuple[str, ...]]):
self.cmds = cmds
async def __call__(self, state: T_State) -> bool:
return state[PREFIX_KEY][CMD_KEY] in self.cmds
def __repr__(self):
return f"<Command {self.cmds}>"
def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
@ -304,10 +334,12 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
config = get_driver().config
command_start = config.command_start
command_sep = config.command_sep
commands = list(cmds)
for index, command in enumerate(commands):
commands: List[Tuple[str, ...]] = []
for command in cmds:
if isinstance(command, str):
commands[index] = command = (command,)
command = (command,)
commands.append(command)
if len(command) == 1:
for start in command_start:
@ -316,10 +348,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _command(state: T_State) -> bool:
return state[PREFIX_KEY][CMD_KEY] in commands
return Rule(_command)
return Rule(Command(commands))
class ArgumentParser(ArgParser):
@ -350,6 +379,27 @@ class ArgumentParser(ArgParser):
return super().parse_args(args=args, namespace=namespace) # type: ignore
class ShellCommand:
def __init__(self, cmds: List[Tuple[str, ...]], parser: Optional[ArgumentParser]):
self.cmds = cmds
self.parser = parser
async def __call__(self, event: Event, state: T_State) -> bool:
if state[PREFIX_KEY][CMD_KEY] in self.cmds:
message = str(event.get_message())
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
state[SHELL_ARGV] = shlex.split(strip_message)
if self.parser:
try:
args = self.parser.parse_args(state[SHELL_ARGV])
state[SHELL_ARGS] = args
except ParserExit as e:
state[SHELL_ARGS] = e
return True
else:
return False
def shell_command(
*cmds: Union[str, Tuple[str, ...]], parser: Optional[ArgumentParser] = None
) -> Rule:
@ -392,10 +442,12 @@ def shell_command(
config = get_driver().config
command_start = config.command_start
command_sep = config.command_sep
commands = list(cmds)
for index, command in enumerate(commands):
commands: List[Tuple[str, ...]] = []
for command in cmds:
if isinstance(command, str):
commands[index] = command = (command,)
command = (command,)
commands.append(command)
if len(command) == 1:
for start in command_start:
@ -404,23 +456,26 @@ def shell_command(
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _shell_command(event: Event, state: T_State) -> bool:
if state[PREFIX_KEY][CMD_KEY] in commands:
message = str(event.get_message())
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
state[SHELL_ARGV] = shlex.split(strip_message)
if parser:
try:
args = parser.parse_args(state[SHELL_ARGV])
state[SHELL_ARGS] = args
except ParserExit as e:
state[SHELL_ARGS] = e
return Rule(ShellCommand(commands, parser))
class Regex:
def __init__(self, regex: str, flags: int = 0):
self.regex = regex
self.flags = flags
async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
matched = re.search(self.regex, str(event.get_message()), self.flags)
if matched:
state[REGEX_MATCHED] = matched.group()
state[REGEX_GROUP] = matched.groups()
state[REGEX_DICT] = matched.groupdict()
return True
else:
return False
return Rule(_shell_command)
def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
r"""
@ -441,25 +496,12 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
\:\:\:
"""
pattern = re.compile(regex, flags)
async def _regex(event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
matched = pattern.search(str(event.get_message()))
if matched:
state[REGEX_MATCHED] = matched.group()
state[REGEX_GROUP] = matched.groups()
state[REGEX_DICT] = matched.groupdict()
return True
else:
return False
return Rule(_regex)
return Rule(Regex(regex, flags))
async def _to_me(event: Event) -> bool:
return event.is_tome()
class ToMe:
async def __call__(self, event: Event) -> bool:
return event.is_tome()
def to_me() -> Rule:
@ -473,4 +515,4 @@ def to_me() -> Rule:
* 无
"""
return Rule(_to_me)
return Rule(ToMe())