Implement argument filters

This commit is contained in:
Richard Chien 2019-01-25 00:14:30 +08:00
parent 6b6daf7235
commit f8ecc7bba1
5 changed files with 284 additions and 47 deletions

View File

@ -3,16 +3,18 @@ import re
import shlex import shlex
from datetime import datetime from datetime import datetime
from typing import ( from typing import (
Tuple, Union, Callable, Iterable, Any, Optional, List, Dict Tuple, Union, Callable, Iterable, Any, Optional, List, Dict,
Awaitable
) )
from . import NoneBot, permission as perm from nonebot import NoneBot, permission as perm
from .helpers import context_id, send, render_expression from nonebot.command.argfilter import ArgFilter_T, ValidateError
from .log import logger from nonebot.helpers import context_id, send, render_expression
from .message import Message from nonebot.log import logger
from .session import BaseSession from nonebot.message import Message
from .typing import ( from nonebot.session import BaseSession
Context_T, CommandName_T, CommandArgs_T, Message_T from nonebot.typing import (
Context_T, CommandName_T, CommandArgs_T, Message_T, State_T
) )
# key: one segment of command name # key: one segment of command name
@ -27,19 +29,25 @@ _aliases = {} # type: Dict[str, CommandName_T]
# value: CommandSession object # value: CommandSession object
_sessions = {} # type: Dict[str, CommandSession] _sessions = {} # type: Dict[str, CommandSession]
CommandHandler_T = Callable[['CommandSession'], Any]
class Command: class Command:
__slots__ = ('name', 'func', 'permission', __slots__ = ('name', 'func', 'permission',
'only_to_me', 'privileged', 'args_parser_func') 'only_to_me', 'privileged', 'args_parser_func')
def __init__(self, *, name: CommandName_T, func: Callable, def __init__(self, *,
permission: int, only_to_me: bool, privileged: bool): name: CommandName_T,
func: CommandHandler_T,
permission: int,
only_to_me: bool,
privileged: bool):
self.name = name self.name = name
self.func = func self.func = func
self.permission = permission self.permission = permission
self.only_to_me = only_to_me self.only_to_me = only_to_me
self.privileged = privileged self.privileged = privileged
self.args_parser_func = None self.args_parser_func: Optional[CommandHandler_T] = None
async def run(self, session, *, async def run(self, session, *,
check_perm: bool = True, check_perm: bool = True,
@ -56,8 +64,28 @@ class Command:
if self.func and has_perm: if self.func and has_perm:
if dry: if dry:
return True return True
if self.args_parser_func:
await self.args_parser_func(session) if session.current_arg_filters is not None and \
session.current_key is not None:
# argument-level filters are given, use them
arg = session.current_arg
for f in session.current_arg_filters:
try:
res = f(arg)
if isinstance(res, Awaitable):
res = await res
arg = res
except ValidateError as e:
# validation failed
session.pause(e.message)
# passed all filters
session.state[session.current_key] = arg
else:
# fallback to command-level args_parser_func
if self.args_parser_func:
await self.args_parser_func(session)
await self.func(session) await self.func(session)
return True return True
return False return False
@ -77,14 +105,14 @@ class Command:
class CommandFunc: class CommandFunc:
__slots__ = ('cmd', 'func') __slots__ = ('cmd', 'func')
def __init__(self, cmd: Command, func: Callable): def __init__(self, cmd: Command, func: CommandHandler_T):
self.cmd = cmd self.cmd = cmd
self.func = func self.func = func
def __call__(self, *args, **kwargs): def __call__(self, session: 'CommandSession') -> Any:
return self.func(*args, **kwargs) return self.func(session)
def args_parser(self, parser_func: Callable): def args_parser(self, parser_func: CommandHandler_T) -> CommandHandler_T:
""" """
Decorator to register a function as the arguments parser of Decorator to register a function as the arguments parser of
the corresponding command. the corresponding command.
@ -110,7 +138,7 @@ def on_command(name: Union[str, CommandName_T], *,
:param shell_like: use shell-like syntax to split arguments :param shell_like: use shell-like syntax to split arguments
""" """
def deco(func: Callable) -> Callable: def deco(func: CommandHandler_T) -> CommandHandler_T:
if not isinstance(name, (str, tuple)): if not isinstance(name, (str, tuple)):
raise TypeError('the name of a command must be a str or tuple') raise TypeError('the name of a command must be a str or tuple')
if not name: if not name:
@ -153,7 +181,7 @@ class CommandGroup:
privileged: Optional[bool] = None, privileged: Optional[bool] = None,
shell_like: Optional[bool] = None): shell_like: Optional[bool] = None):
self.basename = (name,) if isinstance(name, str) else name self.basename = (name,) if isinstance(name, str) else name
self.permission = permission self.permission = permission # TODO: use .pyi
self.only_to_me = only_to_me self.only_to_me = only_to_me
self.privileged = privileged self.privileged = privileged
self.shell_like = shell_like self.shell_like = shell_like
@ -204,10 +232,10 @@ def _find_command(name: Union[str, CommandName_T]) -> Optional[Command]:
return cmd if isinstance(cmd, Command) else None return cmd if isinstance(cmd, Command) else None
class _FurtherInteractionNeeded(Exception): class _PauseException(Exception):
""" """
Raised by session.pause() indicating that the command should Raised by session.pause() indicating that the command session
enter interactive mode to ask the user for some arguments. should be paused to ask the user for some arguments.
""" """
pass pass
@ -244,23 +272,48 @@ class SwitchException(Exception):
class CommandSession(BaseSession): class CommandSession(BaseSession):
__slots__ = ('cmd', 'current_key', 'current_arg', __slots__ = ('cmd', 'current_key', 'current_arg', 'current_arg_filters',
'current_arg_text', 'current_arg_images', 'current_arg_text', 'current_arg_images',
'args', '_last_interaction', '_running') '_state', '_last_interaction', '_running')
def __init__(self, bot: NoneBot, ctx: Context_T, cmd: Command, *, def __init__(self, bot: NoneBot, ctx: Context_T, cmd: Command, *,
current_arg: str = '', args: Optional[CommandArgs_T] = None): current_arg: str = '', args: Optional[CommandArgs_T] = None):
super().__init__(bot, ctx) super().__init__(bot, ctx)
self.cmd = cmd # Command object self.cmd = cmd # Command object
self.current_key = None # current key that the command handler needs
self.current_arg = None # current argument (with potential CQ codes) # unique key of the argument that is currently requesting (asking)
self.current_arg_text = None # current argument without any CQ codes self.current_key: Optional[str] = None
self.current_arg_images = None # image urls in current argument
self.refresh(ctx, current_arg=current_arg) # initialize current argument
self.args = args or {} self.current_arg: str = '' # with potential CQ codes
self.current_arg_text: str = '' # without any CQ codes TODO: property
self.current_arg_images: List[str] = [] # image urls
self.refresh(ctx, current_arg=current_arg) # fill the above
# initialize current argument filters
self.current_arg_filters: Optional[List[ArgFilter_T]] = None
self._state: State_T = {}
if args:
self._state.update(args)
self._last_interaction = None # last interaction time of this session self._last_interaction = None # last interaction time of this session
self._running = False self._running = False
@property
def state(self) -> State_T:
"""
State of the session.
This contains all named arguments and
other session scope temporary values.
"""
return self._state
@property
def args(self) -> CommandArgs_T:
"""Deprecated. Use `session.state` instead."""
return self.state
@property @property
def running(self) -> bool: def running(self) -> bool:
return self._running return self._running
@ -292,7 +345,7 @@ class CommandSession(BaseSession):
Shell-like argument list, similar to sys.argv. Shell-like argument list, similar to sys.argv.
Only available while shell_like is True in on_command decorator. Only available while shell_like is True in on_command decorator.
""" """
return self.get_optional('argv', []) return self.state.get('argv', [])
def refresh(self, ctx: Context_T, *, current_arg: str = '') -> None: def refresh(self, ctx: Context_T, *, current_arg: str = '') -> None:
""" """
@ -308,38 +361,46 @@ class CommandSession(BaseSession):
self.current_arg_images = [s.data['url'] for s in current_arg_as_msg self.current_arg_images = [s.data['url'] for s in current_arg_as_msg
if s.type == 'image' and 'url' in s.data] if s.type == 'image' and 'url' in s.data]
def get(self, key: Any, *, def get(self, key: str, *,
prompt: Optional[Message_T] = None, **kwargs) -> Any: prompt: Optional[Message_T] = None,
arg_filters: Optional[List[ArgFilter_T]] = None,
**kwargs) -> Any:
""" """
Get an argument with a given key. Get an argument with a given key.
If the argument does not exist in the current session, If the argument does not exist in the current session,
a FurtherInteractionNeeded exception will be raised, a pause exception will be raised, and the caller of
and the caller of the command will know it should keep the command will know it should keep the session for
the session for further interaction with the user. further interaction with the user.
:param key: argument key :param key: argument key
:param prompt: prompt to ask the user :param prompt: prompt to ask the user
:param arg_filters: argument filters for next user input
:return: the argument value :return: the argument value
""" """
value = self.get_optional(key) if key in self.state:
if value is not None: return self.state[key]
return value
self.current_key = key self.current_key = key
self.current_arg_filters = arg_filters
# TODO: self.current_send_kwargs
# ask the user for more information # ask the user for more information
self.pause(prompt, **kwargs) self.pause(prompt, **kwargs)
def get_optional(self, key: Any, def get_optional(self, key: str,
default: Optional[Any] = None) -> Optional[Any]: default: Optional[Any] = None) -> Optional[Any]:
"""Simply get a argument with given key.""" """
return self.args.get(key, default) Simply get a argument with given key.
Deprecated. Use `session.state.get()` instead.
"""
return self.state.get(key, default)
def pause(self, message: Optional[Message_T] = None, **kwargs) -> None: def pause(self, message: Optional[Message_T] = None, **kwargs) -> None:
"""Pause the session for further interaction.""" """Pause the session for further interaction."""
if message: if message:
asyncio.ensure_future(self.send(message, **kwargs)) asyncio.ensure_future(self.send(message, **kwargs))
raise _FurtherInteractionNeeded raise _PauseException
def finish(self, message: Optional[Message_T] = None, **kwargs) -> None: def finish(self, message: Optional[Message_T] = None, **kwargs) -> None:
"""Finish the session.""" """Finish the session."""
@ -564,9 +625,7 @@ async def _real_run_command(session: CommandSession,
handled = future.result() handled = future.result()
except asyncio.TimeoutError: except asyncio.TimeoutError:
handled = True handled = True
except (_FurtherInteractionNeeded, except (_PauseException, _FinishException, SwitchException) as e:
_FinishException,
SwitchException) as e:
raise e raise e
except Exception as e: except Exception as e:
logger.error(f'An exception occurred while ' logger.error(f'An exception occurred while '
@ -574,7 +633,7 @@ async def _real_run_command(session: CommandSession,
logger.exception(e) logger.exception(e)
handled = True handled = True
raise _FinishException(handled) raise _FinishException(handled)
except _FurtherInteractionNeeded: except _PauseException:
session.running = False session.running = False
if disable_interaction: if disable_interaction:
# if the command needs further interaction, we view it as failed # if the command needs further interaction, we view it as failed

View File

@ -0,0 +1,8 @@
from typing import Callable, Any, Awaitable, Union
ArgFilter_T = Callable[[Any], Union[Any, Awaitable[Any]]]
class ValidateError(ValueError):
def __init__(self, message=None):
self.message = message

View File

@ -0,0 +1,40 @@
from typing import Optional, List
def _simple_chinese_to_bool(text: str) -> Optional[bool]:
"""
Convert a chinese text to boolean.
Examples:
是的 -> True
好的呀 -> True
不要 -> False
不用了 -> False
你好呀 -> None
"""
text = text.strip().lower().replace(' ', '') \
.rstrip(',.!?~,。!?~了的呢吧呀啊呗啦')
if text in {'', '', '', '', '', '', '',
'ok', 'okay', 'yeah', 'yep',
'当真', '当然', '必须', '可以', '肯定', '没错', '确定', '确认'}:
return True
if text in {'', '不要', '不用', '不是', '', '不好', '不对', '不行', '',
'no', 'nono', 'nonono', 'nope', '不ok', '不可以', '不能',
'不可以'}:
return False
return None
def _split_nonempty_lines(text: str) -> List[str]:
return list(filter(lambda x: x, text.splitlines()))
def _split_nonempty_stripped_lines(text: str) -> List[str]:
return list(filter(lambda x: x,
map(lambda x: x.strip(), text.splitlines())))
simple_chinese_to_bool = _simple_chinese_to_bool
split_nonempty_lines = _split_nonempty_lines
split_nonempty_stripped_lines = _split_nonempty_stripped_lines

View File

@ -0,0 +1,29 @@
import re
from typing import List
from nonebot.message import Message
from nonebot.typing import Message_T
def _extract_text(arg: Message_T) -> str:
"""Extract all plain text segments from a message-like object."""
arg_as_msg = Message(arg)
return arg_as_msg.extract_plain_text()
def _extract_image_urls(arg: Message_T) -> List[str]:
"""Extract all image urls from a message-like object."""
arg_as_msg = Message(arg)
return [s.data['url'] for s in arg_as_msg
if s.type == 'image' and 'url' in s.data]
def _extract_numbers(arg: Message_T) -> List[float]:
"""Extract all numbers (integers and floats) from a message-like object."""
s = str(arg)
return list(map(float, re.findall(r'[+-]?(\d*\.?\d+|\d+\.?\d*)', s)))
extract_text = _extract_text
extract_image_urls = _extract_image_urls
extract_numbers = _extract_numbers

View File

@ -0,0 +1,101 @@
import re
from typing import Callable, Any
from nonebot.command.argfilter import ValidateError
class BaseValidator:
def __init__(self, message=None):
self.message = message
def raise_failure(self):
raise ValidateError(self.message)
class not_empty(BaseValidator):
"""
Validate any object to ensure it's not empty (is None or has no elements).
"""
def __call__(self, value):
if value is None:
self.raise_failure()
if hasattr(value, '__len__') and value.__len__() == 0:
self.raise_failure()
return value
class fit_size(BaseValidator):
"""
Validate any sized object to ensure the size/length
is in a given range [min_length, max_length].
"""
def __init__(self, min_length: int = 0, max_length: int = None,
message=None):
super().__init__(message)
self.min_length = min_length
self.max_length = max_length
def __call__(self, value):
length = len(value) if value is not None else 0
if length < self.min_length or \
(self.max_length is not None and length > self.max_length):
self.raise_failure()
return value
class match_regex(BaseValidator):
"""
Validate any string object to ensure it matches a given pattern.
"""
def __init__(self, pattern: str, message=None, *, flags=0,
fullmatch: bool = False):
super().__init__(message)
self.pattern = re.compile(pattern, flags)
self.fullmatch = fullmatch
def __call__(self, value):
if self.fullmatch:
if not re.fullmatch(self.pattern, value):
self.raise_failure()
else:
if not re.match(self.pattern, value):
self.raise_failure()
return value
class ensure_true(BaseValidator):
"""
Validate any object to ensure the result of applying
a boolean function to it is True.
"""
def __init__(self, bool_func: Callable[[Any], bool], message=None):
super().__init__(message)
self.bool_func = bool_func
def __call__(self, value):
if self.bool_func(value) is not True:
self.raise_failure()
return value
class between_inclusive(BaseValidator):
"""
Validate any comparable object to ensure it's between
`start` and `end` inclusively.
"""
def __init__(self, start=None, end=None, message=None):
super().__init__(message)
self.start = start
self.end = end
def __call__(self, value):
if self.start is not None and value < self.start:
self.raise_failure()
if self.end is not None and self.end < value:
self.raise_failure()
return value