mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-06-17 18:08:05 +00:00
Implement argument filters
This commit is contained in:
parent
6b6daf7235
commit
f8ecc7bba1
@ -3,16 +3,18 @@ import re
|
|||||||
import shlex
|
import shlex
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import (
|
from typing import (
|
||||||
Tuple, Union, Callable, Iterable, Any, Optional, List, Dict
|
Tuple, Union, Callable, Iterable, Any, Optional, List, Dict,
|
||||||
|
Awaitable
|
||||||
)
|
)
|
||||||
|
|
||||||
from . import NoneBot, permission as perm
|
from nonebot import NoneBot, permission as perm
|
||||||
from .helpers import context_id, send, render_expression
|
from nonebot.command.argfilter import ArgFilter_T, ValidateError
|
||||||
from .log import logger
|
from nonebot.helpers import context_id, send, render_expression
|
||||||
from .message import Message
|
from nonebot.log import logger
|
||||||
from .session import BaseSession
|
from nonebot.message import Message
|
||||||
from .typing import (
|
from nonebot.session import BaseSession
|
||||||
Context_T, CommandName_T, CommandArgs_T, Message_T
|
from nonebot.typing import (
|
||||||
|
Context_T, CommandName_T, CommandArgs_T, Message_T, State_T
|
||||||
)
|
)
|
||||||
|
|
||||||
# key: one segment of command name
|
# key: one segment of command name
|
||||||
@ -27,19 +29,25 @@ _aliases = {} # type: Dict[str, CommandName_T]
|
|||||||
# value: CommandSession object
|
# value: CommandSession object
|
||||||
_sessions = {} # type: Dict[str, CommandSession]
|
_sessions = {} # type: Dict[str, CommandSession]
|
||||||
|
|
||||||
|
CommandHandler_T = Callable[['CommandSession'], Any]
|
||||||
|
|
||||||
|
|
||||||
class Command:
|
class Command:
|
||||||
__slots__ = ('name', 'func', 'permission',
|
__slots__ = ('name', 'func', 'permission',
|
||||||
'only_to_me', 'privileged', 'args_parser_func')
|
'only_to_me', 'privileged', 'args_parser_func')
|
||||||
|
|
||||||
def __init__(self, *, name: CommandName_T, func: Callable,
|
def __init__(self, *,
|
||||||
permission: int, only_to_me: bool, privileged: bool):
|
name: CommandName_T,
|
||||||
|
func: CommandHandler_T,
|
||||||
|
permission: int,
|
||||||
|
only_to_me: bool,
|
||||||
|
privileged: bool):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.func = func
|
self.func = func
|
||||||
self.permission = permission
|
self.permission = permission
|
||||||
self.only_to_me = only_to_me
|
self.only_to_me = only_to_me
|
||||||
self.privileged = privileged
|
self.privileged = privileged
|
||||||
self.args_parser_func = None
|
self.args_parser_func: Optional[CommandHandler_T] = None
|
||||||
|
|
||||||
async def run(self, session, *,
|
async def run(self, session, *,
|
||||||
check_perm: bool = True,
|
check_perm: bool = True,
|
||||||
@ -56,8 +64,28 @@ class Command:
|
|||||||
if self.func and has_perm:
|
if self.func and has_perm:
|
||||||
if dry:
|
if dry:
|
||||||
return True
|
return True
|
||||||
if self.args_parser_func:
|
|
||||||
await self.args_parser_func(session)
|
if session.current_arg_filters is not None and \
|
||||||
|
session.current_key is not None:
|
||||||
|
# argument-level filters are given, use them
|
||||||
|
arg = session.current_arg
|
||||||
|
for f in session.current_arg_filters:
|
||||||
|
try:
|
||||||
|
res = f(arg)
|
||||||
|
if isinstance(res, Awaitable):
|
||||||
|
res = await res
|
||||||
|
arg = res
|
||||||
|
except ValidateError as e:
|
||||||
|
# validation failed
|
||||||
|
session.pause(e.message)
|
||||||
|
|
||||||
|
# passed all filters
|
||||||
|
session.state[session.current_key] = arg
|
||||||
|
else:
|
||||||
|
# fallback to command-level args_parser_func
|
||||||
|
if self.args_parser_func:
|
||||||
|
await self.args_parser_func(session)
|
||||||
|
|
||||||
await self.func(session)
|
await self.func(session)
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
@ -77,14 +105,14 @@ class Command:
|
|||||||
class CommandFunc:
|
class CommandFunc:
|
||||||
__slots__ = ('cmd', 'func')
|
__slots__ = ('cmd', 'func')
|
||||||
|
|
||||||
def __init__(self, cmd: Command, func: Callable):
|
def __init__(self, cmd: Command, func: CommandHandler_T):
|
||||||
self.cmd = cmd
|
self.cmd = cmd
|
||||||
self.func = func
|
self.func = func
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, session: 'CommandSession') -> Any:
|
||||||
return self.func(*args, **kwargs)
|
return self.func(session)
|
||||||
|
|
||||||
def args_parser(self, parser_func: Callable):
|
def args_parser(self, parser_func: CommandHandler_T) -> CommandHandler_T:
|
||||||
"""
|
"""
|
||||||
Decorator to register a function as the arguments parser of
|
Decorator to register a function as the arguments parser of
|
||||||
the corresponding command.
|
the corresponding command.
|
||||||
@ -110,7 +138,7 @@ def on_command(name: Union[str, CommandName_T], *,
|
|||||||
:param shell_like: use shell-like syntax to split arguments
|
:param shell_like: use shell-like syntax to split arguments
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def deco(func: Callable) -> Callable:
|
def deco(func: CommandHandler_T) -> CommandHandler_T:
|
||||||
if not isinstance(name, (str, tuple)):
|
if not isinstance(name, (str, tuple)):
|
||||||
raise TypeError('the name of a command must be a str or tuple')
|
raise TypeError('the name of a command must be a str or tuple')
|
||||||
if not name:
|
if not name:
|
||||||
@ -153,7 +181,7 @@ class CommandGroup:
|
|||||||
privileged: Optional[bool] = None,
|
privileged: Optional[bool] = None,
|
||||||
shell_like: Optional[bool] = None):
|
shell_like: Optional[bool] = None):
|
||||||
self.basename = (name,) if isinstance(name, str) else name
|
self.basename = (name,) if isinstance(name, str) else name
|
||||||
self.permission = permission
|
self.permission = permission # TODO: use .pyi
|
||||||
self.only_to_me = only_to_me
|
self.only_to_me = only_to_me
|
||||||
self.privileged = privileged
|
self.privileged = privileged
|
||||||
self.shell_like = shell_like
|
self.shell_like = shell_like
|
||||||
@ -204,10 +232,10 @@ def _find_command(name: Union[str, CommandName_T]) -> Optional[Command]:
|
|||||||
return cmd if isinstance(cmd, Command) else None
|
return cmd if isinstance(cmd, Command) else None
|
||||||
|
|
||||||
|
|
||||||
class _FurtherInteractionNeeded(Exception):
|
class _PauseException(Exception):
|
||||||
"""
|
"""
|
||||||
Raised by session.pause() indicating that the command should
|
Raised by session.pause() indicating that the command session
|
||||||
enter interactive mode to ask the user for some arguments.
|
should be paused to ask the user for some arguments.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -244,23 +272,48 @@ class SwitchException(Exception):
|
|||||||
|
|
||||||
|
|
||||||
class CommandSession(BaseSession):
|
class CommandSession(BaseSession):
|
||||||
__slots__ = ('cmd', 'current_key', 'current_arg',
|
__slots__ = ('cmd', 'current_key', 'current_arg', 'current_arg_filters',
|
||||||
'current_arg_text', 'current_arg_images',
|
'current_arg_text', 'current_arg_images',
|
||||||
'args', '_last_interaction', '_running')
|
'_state', '_last_interaction', '_running')
|
||||||
|
|
||||||
def __init__(self, bot: NoneBot, ctx: Context_T, cmd: Command, *,
|
def __init__(self, bot: NoneBot, ctx: Context_T, cmd: Command, *,
|
||||||
current_arg: str = '', args: Optional[CommandArgs_T] = None):
|
current_arg: str = '', args: Optional[CommandArgs_T] = None):
|
||||||
super().__init__(bot, ctx)
|
super().__init__(bot, ctx)
|
||||||
self.cmd = cmd # Command object
|
self.cmd = cmd # Command object
|
||||||
self.current_key = None # current key that the command handler needs
|
|
||||||
self.current_arg = None # current argument (with potential CQ codes)
|
# unique key of the argument that is currently requesting (asking)
|
||||||
self.current_arg_text = None # current argument without any CQ codes
|
self.current_key: Optional[str] = None
|
||||||
self.current_arg_images = None # image urls in current argument
|
|
||||||
self.refresh(ctx, current_arg=current_arg)
|
# initialize current argument
|
||||||
self.args = args or {}
|
self.current_arg: str = '' # with potential CQ codes
|
||||||
|
self.current_arg_text: str = '' # without any CQ codes TODO: property
|
||||||
|
self.current_arg_images: List[str] = [] # image urls
|
||||||
|
self.refresh(ctx, current_arg=current_arg) # fill the above
|
||||||
|
|
||||||
|
# initialize current argument filters
|
||||||
|
self.current_arg_filters: Optional[List[ArgFilter_T]] = None
|
||||||
|
|
||||||
|
self._state: State_T = {}
|
||||||
|
if args:
|
||||||
|
self._state.update(args)
|
||||||
self._last_interaction = None # last interaction time of this session
|
self._last_interaction = None # last interaction time of this session
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self) -> State_T:
|
||||||
|
"""
|
||||||
|
State of the session.
|
||||||
|
|
||||||
|
This contains all named arguments and
|
||||||
|
other session scope temporary values.
|
||||||
|
"""
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args(self) -> CommandArgs_T:
|
||||||
|
"""Deprecated. Use `session.state` instead."""
|
||||||
|
return self.state
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def running(self) -> bool:
|
def running(self) -> bool:
|
||||||
return self._running
|
return self._running
|
||||||
@ -292,7 +345,7 @@ class CommandSession(BaseSession):
|
|||||||
Shell-like argument list, similar to sys.argv.
|
Shell-like argument list, similar to sys.argv.
|
||||||
Only available while shell_like is True in on_command decorator.
|
Only available while shell_like is True in on_command decorator.
|
||||||
"""
|
"""
|
||||||
return self.get_optional('argv', [])
|
return self.state.get('argv', [])
|
||||||
|
|
||||||
def refresh(self, ctx: Context_T, *, current_arg: str = '') -> None:
|
def refresh(self, ctx: Context_T, *, current_arg: str = '') -> None:
|
||||||
"""
|
"""
|
||||||
@ -308,38 +361,46 @@ class CommandSession(BaseSession):
|
|||||||
self.current_arg_images = [s.data['url'] for s in current_arg_as_msg
|
self.current_arg_images = [s.data['url'] for s in current_arg_as_msg
|
||||||
if s.type == 'image' and 'url' in s.data]
|
if s.type == 'image' and 'url' in s.data]
|
||||||
|
|
||||||
def get(self, key: Any, *,
|
def get(self, key: str, *,
|
||||||
prompt: Optional[Message_T] = None, **kwargs) -> Any:
|
prompt: Optional[Message_T] = None,
|
||||||
|
arg_filters: Optional[List[ArgFilter_T]] = None,
|
||||||
|
**kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
Get an argument with a given key.
|
Get an argument with a given key.
|
||||||
|
|
||||||
If the argument does not exist in the current session,
|
If the argument does not exist in the current session,
|
||||||
a FurtherInteractionNeeded exception will be raised,
|
a pause exception will be raised, and the caller of
|
||||||
and the caller of the command will know it should keep
|
the command will know it should keep the session for
|
||||||
the session for further interaction with the user.
|
further interaction with the user.
|
||||||
|
|
||||||
:param key: argument key
|
:param key: argument key
|
||||||
:param prompt: prompt to ask the user
|
:param prompt: prompt to ask the user
|
||||||
|
:param arg_filters: argument filters for next user input
|
||||||
:return: the argument value
|
:return: the argument value
|
||||||
"""
|
"""
|
||||||
value = self.get_optional(key)
|
if key in self.state:
|
||||||
if value is not None:
|
return self.state[key]
|
||||||
return value
|
|
||||||
|
|
||||||
self.current_key = key
|
self.current_key = key
|
||||||
|
self.current_arg_filters = arg_filters
|
||||||
|
# TODO: self.current_send_kwargs
|
||||||
# ask the user for more information
|
# ask the user for more information
|
||||||
self.pause(prompt, **kwargs)
|
self.pause(prompt, **kwargs)
|
||||||
|
|
||||||
def get_optional(self, key: Any,
|
def get_optional(self, key: str,
|
||||||
default: Optional[Any] = None) -> Optional[Any]:
|
default: Optional[Any] = None) -> Optional[Any]:
|
||||||
"""Simply get a argument with given key."""
|
"""
|
||||||
return self.args.get(key, default)
|
Simply get a argument with given key.
|
||||||
|
|
||||||
|
Deprecated. Use `session.state.get()` instead.
|
||||||
|
"""
|
||||||
|
return self.state.get(key, default)
|
||||||
|
|
||||||
def pause(self, message: Optional[Message_T] = None, **kwargs) -> None:
|
def pause(self, message: Optional[Message_T] = None, **kwargs) -> None:
|
||||||
"""Pause the session for further interaction."""
|
"""Pause the session for further interaction."""
|
||||||
if message:
|
if message:
|
||||||
asyncio.ensure_future(self.send(message, **kwargs))
|
asyncio.ensure_future(self.send(message, **kwargs))
|
||||||
raise _FurtherInteractionNeeded
|
raise _PauseException
|
||||||
|
|
||||||
def finish(self, message: Optional[Message_T] = None, **kwargs) -> None:
|
def finish(self, message: Optional[Message_T] = None, **kwargs) -> None:
|
||||||
"""Finish the session."""
|
"""Finish the session."""
|
||||||
@ -564,9 +625,7 @@ async def _real_run_command(session: CommandSession,
|
|||||||
handled = future.result()
|
handled = future.result()
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
handled = True
|
handled = True
|
||||||
except (_FurtherInteractionNeeded,
|
except (_PauseException, _FinishException, SwitchException) as e:
|
||||||
_FinishException,
|
|
||||||
SwitchException) as e:
|
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f'An exception occurred while '
|
logger.error(f'An exception occurred while '
|
||||||
@ -574,7 +633,7 @@ async def _real_run_command(session: CommandSession,
|
|||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
handled = True
|
handled = True
|
||||||
raise _FinishException(handled)
|
raise _FinishException(handled)
|
||||||
except _FurtherInteractionNeeded:
|
except _PauseException:
|
||||||
session.running = False
|
session.running = False
|
||||||
if disable_interaction:
|
if disable_interaction:
|
||||||
# if the command needs further interaction, we view it as failed
|
# if the command needs further interaction, we view it as failed
|
8
nonebot/command/argfilter/__init__.py
Normal file
8
nonebot/command/argfilter/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
from typing import Callable, Any, Awaitable, Union
|
||||||
|
|
||||||
|
ArgFilter_T = Callable[[Any], Union[Any, Awaitable[Any]]]
|
||||||
|
|
||||||
|
|
||||||
|
class ValidateError(ValueError):
|
||||||
|
def __init__(self, message=None):
|
||||||
|
self.message = message
|
40
nonebot/command/argfilter/converters.py
Normal file
40
nonebot/command/argfilter/converters.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
|
||||||
|
def _simple_chinese_to_bool(text: str) -> Optional[bool]:
|
||||||
|
"""
|
||||||
|
Convert a chinese text to boolean.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
是的 -> True
|
||||||
|
好的呀 -> True
|
||||||
|
不要 -> False
|
||||||
|
不用了 -> False
|
||||||
|
你好呀 -> None
|
||||||
|
"""
|
||||||
|
text = text.strip().lower().replace(' ', '') \
|
||||||
|
.rstrip(',.!?~,。!?~了的呢吧呀啊呗啦')
|
||||||
|
if text in {'要', '用', '是', '好', '对', '嗯', '行',
|
||||||
|
'ok', 'okay', 'yeah', 'yep',
|
||||||
|
'当真', '当然', '必须', '可以', '肯定', '没错', '确定', '确认'}:
|
||||||
|
return True
|
||||||
|
if text in {'不', '不要', '不用', '不是', '否', '不好', '不对', '不行', '别',
|
||||||
|
'no', 'nono', 'nonono', 'nope', '不ok', '不可以', '不能',
|
||||||
|
'不可以'}:
|
||||||
|
return False
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _split_nonempty_lines(text: str) -> List[str]:
|
||||||
|
return list(filter(lambda x: x, text.splitlines()))
|
||||||
|
|
||||||
|
|
||||||
|
def _split_nonempty_stripped_lines(text: str) -> List[str]:
|
||||||
|
return list(filter(lambda x: x,
|
||||||
|
map(lambda x: x.strip(), text.splitlines())))
|
||||||
|
|
||||||
|
|
||||||
|
simple_chinese_to_bool = _simple_chinese_to_bool
|
||||||
|
split_nonempty_lines = _split_nonempty_lines
|
||||||
|
split_nonempty_stripped_lines = _split_nonempty_stripped_lines
|
29
nonebot/command/argfilter/extractors.py
Normal file
29
nonebot/command/argfilter/extractors.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import re
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from nonebot.message import Message
|
||||||
|
from nonebot.typing import Message_T
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text(arg: Message_T) -> str:
|
||||||
|
"""Extract all plain text segments from a message-like object."""
|
||||||
|
arg_as_msg = Message(arg)
|
||||||
|
return arg_as_msg.extract_plain_text()
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_image_urls(arg: Message_T) -> List[str]:
|
||||||
|
"""Extract all image urls from a message-like object."""
|
||||||
|
arg_as_msg = Message(arg)
|
||||||
|
return [s.data['url'] for s in arg_as_msg
|
||||||
|
if s.type == 'image' and 'url' in s.data]
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_numbers(arg: Message_T) -> List[float]:
|
||||||
|
"""Extract all numbers (integers and floats) from a message-like object."""
|
||||||
|
s = str(arg)
|
||||||
|
return list(map(float, re.findall(r'[+-]?(\d*\.?\d+|\d+\.?\d*)', s)))
|
||||||
|
|
||||||
|
|
||||||
|
extract_text = _extract_text
|
||||||
|
extract_image_urls = _extract_image_urls
|
||||||
|
extract_numbers = _extract_numbers
|
101
nonebot/command/argfilter/validators.py
Normal file
101
nonebot/command/argfilter/validators.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
import re
|
||||||
|
from typing import Callable, Any
|
||||||
|
|
||||||
|
from nonebot.command.argfilter import ValidateError
|
||||||
|
|
||||||
|
|
||||||
|
class BaseValidator:
|
||||||
|
def __init__(self, message=None):
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
def raise_failure(self):
|
||||||
|
raise ValidateError(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
class not_empty(BaseValidator):
|
||||||
|
"""
|
||||||
|
Validate any object to ensure it's not empty (is None or has no elements).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, value):
|
||||||
|
if value is None:
|
||||||
|
self.raise_failure()
|
||||||
|
if hasattr(value, '__len__') and value.__len__() == 0:
|
||||||
|
self.raise_failure()
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class fit_size(BaseValidator):
|
||||||
|
"""
|
||||||
|
Validate any sized object to ensure the size/length
|
||||||
|
is in a given range [min_length, max_length].
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, min_length: int = 0, max_length: int = None,
|
||||||
|
message=None):
|
||||||
|
super().__init__(message)
|
||||||
|
self.min_length = min_length
|
||||||
|
self.max_length = max_length
|
||||||
|
|
||||||
|
def __call__(self, value):
|
||||||
|
length = len(value) if value is not None else 0
|
||||||
|
if length < self.min_length or \
|
||||||
|
(self.max_length is not None and length > self.max_length):
|
||||||
|
self.raise_failure()
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class match_regex(BaseValidator):
|
||||||
|
"""
|
||||||
|
Validate any string object to ensure it matches a given pattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, pattern: str, message=None, *, flags=0,
|
||||||
|
fullmatch: bool = False):
|
||||||
|
super().__init__(message)
|
||||||
|
self.pattern = re.compile(pattern, flags)
|
||||||
|
self.fullmatch = fullmatch
|
||||||
|
|
||||||
|
def __call__(self, value):
|
||||||
|
if self.fullmatch:
|
||||||
|
if not re.fullmatch(self.pattern, value):
|
||||||
|
self.raise_failure()
|
||||||
|
else:
|
||||||
|
if not re.match(self.pattern, value):
|
||||||
|
self.raise_failure()
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class ensure_true(BaseValidator):
|
||||||
|
"""
|
||||||
|
Validate any object to ensure the result of applying
|
||||||
|
a boolean function to it is True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, bool_func: Callable[[Any], bool], message=None):
|
||||||
|
super().__init__(message)
|
||||||
|
self.bool_func = bool_func
|
||||||
|
|
||||||
|
def __call__(self, value):
|
||||||
|
if self.bool_func(value) is not True:
|
||||||
|
self.raise_failure()
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class between_inclusive(BaseValidator):
|
||||||
|
"""
|
||||||
|
Validate any comparable object to ensure it's between
|
||||||
|
`start` and `end` inclusively.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, start=None, end=None, message=None):
|
||||||
|
super().__init__(message)
|
||||||
|
self.start = start
|
||||||
|
self.end = end
|
||||||
|
|
||||||
|
def __call__(self, value):
|
||||||
|
if self.start is not None and value < self.start:
|
||||||
|
self.raise_failure()
|
||||||
|
if self.end is not None and self.end < value:
|
||||||
|
self.raise_failure()
|
||||||
|
return value
|
Loading…
x
Reference in New Issue
Block a user