diff --git a/nonebot/command.py b/nonebot/command/__init__.py similarity index 82% rename from nonebot/command.py rename to nonebot/command/__init__.py index c949778c..3fda70af 100644 --- a/nonebot/command.py +++ b/nonebot/command/__init__.py @@ -3,16 +3,18 @@ import re import shlex from datetime import datetime from typing import ( - Tuple, Union, Callable, Iterable, Any, Optional, List, Dict + Tuple, Union, Callable, Iterable, Any, Optional, List, Dict, + Awaitable ) -from . import NoneBot, permission as perm -from .helpers import context_id, send, render_expression -from .log import logger -from .message import Message -from .session import BaseSession -from .typing import ( - Context_T, CommandName_T, CommandArgs_T, Message_T +from nonebot import NoneBot, permission as perm +from nonebot.command.argfilter import ArgFilter_T, ValidateError +from nonebot.helpers import context_id, send, render_expression +from nonebot.log import logger +from nonebot.message import Message +from nonebot.session import BaseSession +from nonebot.typing import ( + Context_T, CommandName_T, CommandArgs_T, Message_T, State_T ) # key: one segment of command name @@ -27,19 +29,25 @@ _aliases = {} # type: Dict[str, CommandName_T] # value: CommandSession object _sessions = {} # type: Dict[str, CommandSession] +CommandHandler_T = Callable[['CommandSession'], Any] + class Command: __slots__ = ('name', 'func', 'permission', 'only_to_me', 'privileged', 'args_parser_func') - def __init__(self, *, name: CommandName_T, func: Callable, - permission: int, only_to_me: bool, privileged: bool): + def __init__(self, *, + name: CommandName_T, + func: CommandHandler_T, + permission: int, + only_to_me: bool, + privileged: bool): self.name = name self.func = func self.permission = permission self.only_to_me = only_to_me self.privileged = privileged - self.args_parser_func = None + self.args_parser_func: Optional[CommandHandler_T] = None async def run(self, session, *, check_perm: bool = True, @@ -56,8 +64,28 @@ class Command: if self.func and has_perm: if dry: return True - if self.args_parser_func: - await self.args_parser_func(session) + + if session.current_arg_filters is not None and \ + session.current_key is not None: + # argument-level filters are given, use them + arg = session.current_arg + for f in session.current_arg_filters: + try: + res = f(arg) + if isinstance(res, Awaitable): + res = await res + arg = res + except ValidateError as e: + # validation failed + session.pause(e.message) + + # passed all filters + session.state[session.current_key] = arg + else: + # fallback to command-level args_parser_func + if self.args_parser_func: + await self.args_parser_func(session) + await self.func(session) return True return False @@ -77,14 +105,14 @@ class Command: class CommandFunc: __slots__ = ('cmd', 'func') - def __init__(self, cmd: Command, func: Callable): + def __init__(self, cmd: Command, func: CommandHandler_T): self.cmd = cmd self.func = func - def __call__(self, *args, **kwargs): - return self.func(*args, **kwargs) + def __call__(self, session: 'CommandSession') -> Any: + return self.func(session) - def args_parser(self, parser_func: Callable): + def args_parser(self, parser_func: CommandHandler_T) -> CommandHandler_T: """ Decorator to register a function as the arguments parser of the corresponding command. @@ -110,7 +138,7 @@ def on_command(name: Union[str, CommandName_T], *, :param shell_like: use shell-like syntax to split arguments """ - def deco(func: Callable) -> Callable: + def deco(func: CommandHandler_T) -> CommandHandler_T: if not isinstance(name, (str, tuple)): raise TypeError('the name of a command must be a str or tuple') if not name: @@ -153,7 +181,7 @@ class CommandGroup: privileged: Optional[bool] = None, shell_like: Optional[bool] = None): self.basename = (name,) if isinstance(name, str) else name - self.permission = permission + self.permission = permission # TODO: use .pyi self.only_to_me = only_to_me self.privileged = privileged self.shell_like = shell_like @@ -204,10 +232,10 @@ def _find_command(name: Union[str, CommandName_T]) -> Optional[Command]: return cmd if isinstance(cmd, Command) else None -class _FurtherInteractionNeeded(Exception): +class _PauseException(Exception): """ - Raised by session.pause() indicating that the command should - enter interactive mode to ask the user for some arguments. + Raised by session.pause() indicating that the command session + should be paused to ask the user for some arguments. """ pass @@ -244,23 +272,48 @@ class SwitchException(Exception): class CommandSession(BaseSession): - __slots__ = ('cmd', 'current_key', 'current_arg', + __slots__ = ('cmd', 'current_key', 'current_arg', 'current_arg_filters', 'current_arg_text', 'current_arg_images', - 'args', '_last_interaction', '_running') + '_state', '_last_interaction', '_running') def __init__(self, bot: NoneBot, ctx: Context_T, cmd: Command, *, current_arg: str = '', args: Optional[CommandArgs_T] = None): super().__init__(bot, ctx) self.cmd = cmd # Command object - self.current_key = None # current key that the command handler needs - self.current_arg = None # current argument (with potential CQ codes) - self.current_arg_text = None # current argument without any CQ codes - self.current_arg_images = None # image urls in current argument - self.refresh(ctx, current_arg=current_arg) - self.args = args or {} + + # unique key of the argument that is currently requesting (asking) + self.current_key: Optional[str] = None + + # initialize current argument + self.current_arg: str = '' # with potential CQ codes + self.current_arg_text: str = '' # without any CQ codes TODO: property + self.current_arg_images: List[str] = [] # image urls + self.refresh(ctx, current_arg=current_arg) # fill the above + + # initialize current argument filters + self.current_arg_filters: Optional[List[ArgFilter_T]] = None + + self._state: State_T = {} + if args: + self._state.update(args) self._last_interaction = None # last interaction time of this session self._running = False + @property + def state(self) -> State_T: + """ + State of the session. + + This contains all named arguments and + other session scope temporary values. + """ + return self._state + + @property + def args(self) -> CommandArgs_T: + """Deprecated. Use `session.state` instead.""" + return self.state + @property def running(self) -> bool: return self._running @@ -292,7 +345,7 @@ class CommandSession(BaseSession): Shell-like argument list, similar to sys.argv. Only available while shell_like is True in on_command decorator. """ - return self.get_optional('argv', []) + return self.state.get('argv', []) def refresh(self, ctx: Context_T, *, current_arg: str = '') -> None: """ @@ -308,38 +361,46 @@ class CommandSession(BaseSession): self.current_arg_images = [s.data['url'] for s in current_arg_as_msg if s.type == 'image' and 'url' in s.data] - def get(self, key: Any, *, - prompt: Optional[Message_T] = None, **kwargs) -> Any: + def get(self, key: str, *, + prompt: Optional[Message_T] = None, + arg_filters: Optional[List[ArgFilter_T]] = None, + **kwargs) -> Any: """ Get an argument with a given key. If the argument does not exist in the current session, - a FurtherInteractionNeeded exception will be raised, - and the caller of the command will know it should keep - the session for further interaction with the user. + a pause exception will be raised, and the caller of + the command will know it should keep the session for + further interaction with the user. :param key: argument key :param prompt: prompt to ask the user + :param arg_filters: argument filters for next user input :return: the argument value """ - value = self.get_optional(key) - if value is not None: - return value + if key in self.state: + return self.state[key] self.current_key = key + self.current_arg_filters = arg_filters + # TODO: self.current_send_kwargs # ask the user for more information self.pause(prompt, **kwargs) - def get_optional(self, key: Any, + def get_optional(self, key: str, default: Optional[Any] = None) -> Optional[Any]: - """Simply get a argument with given key.""" - return self.args.get(key, default) + """ + Simply get a argument with given key. + + Deprecated. Use `session.state.get()` instead. + """ + return self.state.get(key, default) def pause(self, message: Optional[Message_T] = None, **kwargs) -> None: """Pause the session for further interaction.""" if message: asyncio.ensure_future(self.send(message, **kwargs)) - raise _FurtherInteractionNeeded + raise _PauseException def finish(self, message: Optional[Message_T] = None, **kwargs) -> None: """Finish the session.""" @@ -564,9 +625,7 @@ async def _real_run_command(session: CommandSession, handled = future.result() except asyncio.TimeoutError: handled = True - except (_FurtherInteractionNeeded, - _FinishException, - SwitchException) as e: + except (_PauseException, _FinishException, SwitchException) as e: raise e except Exception as e: logger.error(f'An exception occurred while ' @@ -574,7 +633,7 @@ async def _real_run_command(session: CommandSession, logger.exception(e) handled = True raise _FinishException(handled) - except _FurtherInteractionNeeded: + except _PauseException: session.running = False if disable_interaction: # if the command needs further interaction, we view it as failed diff --git a/nonebot/command/argfilter/__init__.py b/nonebot/command/argfilter/__init__.py new file mode 100644 index 00000000..507f89ac --- /dev/null +++ b/nonebot/command/argfilter/__init__.py @@ -0,0 +1,8 @@ +from typing import Callable, Any, Awaitable, Union + +ArgFilter_T = Callable[[Any], Union[Any, Awaitable[Any]]] + + +class ValidateError(ValueError): + def __init__(self, message=None): + self.message = message diff --git a/nonebot/command/argfilter/converters.py b/nonebot/command/argfilter/converters.py new file mode 100644 index 00000000..2dff731f --- /dev/null +++ b/nonebot/command/argfilter/converters.py @@ -0,0 +1,40 @@ +from typing import Optional, List + + +def _simple_chinese_to_bool(text: str) -> Optional[bool]: + """ + Convert a chinese text to boolean. + + Examples: + + 是的 -> True + 好的呀 -> True + 不要 -> False + 不用了 -> False + 你好呀 -> None + """ + text = text.strip().lower().replace(' ', '') \ + .rstrip(',.!?~,。!?~了的呢吧呀啊呗啦') + if text in {'要', '用', '是', '好', '对', '嗯', '行', + 'ok', 'okay', 'yeah', 'yep', + '当真', '当然', '必须', '可以', '肯定', '没错', '确定', '确认'}: + return True + if text in {'不', '不要', '不用', '不是', '否', '不好', '不对', '不行', '别', + 'no', 'nono', 'nonono', 'nope', '不ok', '不可以', '不能', + '不可以'}: + return False + return None + + +def _split_nonempty_lines(text: str) -> List[str]: + return list(filter(lambda x: x, text.splitlines())) + + +def _split_nonempty_stripped_lines(text: str) -> List[str]: + return list(filter(lambda x: x, + map(lambda x: x.strip(), text.splitlines()))) + + +simple_chinese_to_bool = _simple_chinese_to_bool +split_nonempty_lines = _split_nonempty_lines +split_nonempty_stripped_lines = _split_nonempty_stripped_lines diff --git a/nonebot/command/argfilter/extractors.py b/nonebot/command/argfilter/extractors.py new file mode 100644 index 00000000..6c0e54c2 --- /dev/null +++ b/nonebot/command/argfilter/extractors.py @@ -0,0 +1,29 @@ +import re +from typing import List + +from nonebot.message import Message +from nonebot.typing import Message_T + + +def _extract_text(arg: Message_T) -> str: + """Extract all plain text segments from a message-like object.""" + arg_as_msg = Message(arg) + return arg_as_msg.extract_plain_text() + + +def _extract_image_urls(arg: Message_T) -> List[str]: + """Extract all image urls from a message-like object.""" + arg_as_msg = Message(arg) + return [s.data['url'] for s in arg_as_msg + if s.type == 'image' and 'url' in s.data] + + +def _extract_numbers(arg: Message_T) -> List[float]: + """Extract all numbers (integers and floats) from a message-like object.""" + s = str(arg) + return list(map(float, re.findall(r'[+-]?(\d*\.?\d+|\d+\.?\d*)', s))) + + +extract_text = _extract_text +extract_image_urls = _extract_image_urls +extract_numbers = _extract_numbers diff --git a/nonebot/command/argfilter/validators.py b/nonebot/command/argfilter/validators.py new file mode 100644 index 00000000..373760af --- /dev/null +++ b/nonebot/command/argfilter/validators.py @@ -0,0 +1,101 @@ +import re +from typing import Callable, Any + +from nonebot.command.argfilter import ValidateError + + +class BaseValidator: + def __init__(self, message=None): + self.message = message + + def raise_failure(self): + raise ValidateError(self.message) + + +class not_empty(BaseValidator): + """ + Validate any object to ensure it's not empty (is None or has no elements). + """ + + def __call__(self, value): + if value is None: + self.raise_failure() + if hasattr(value, '__len__') and value.__len__() == 0: + self.raise_failure() + return value + + +class fit_size(BaseValidator): + """ + Validate any sized object to ensure the size/length + is in a given range [min_length, max_length]. + """ + + def __init__(self, min_length: int = 0, max_length: int = None, + message=None): + super().__init__(message) + self.min_length = min_length + self.max_length = max_length + + def __call__(self, value): + length = len(value) if value is not None else 0 + if length < self.min_length or \ + (self.max_length is not None and length > self.max_length): + self.raise_failure() + return value + + +class match_regex(BaseValidator): + """ + Validate any string object to ensure it matches a given pattern. + """ + + def __init__(self, pattern: str, message=None, *, flags=0, + fullmatch: bool = False): + super().__init__(message) + self.pattern = re.compile(pattern, flags) + self.fullmatch = fullmatch + + def __call__(self, value): + if self.fullmatch: + if not re.fullmatch(self.pattern, value): + self.raise_failure() + else: + if not re.match(self.pattern, value): + self.raise_failure() + return value + + +class ensure_true(BaseValidator): + """ + Validate any object to ensure the result of applying + a boolean function to it is True. + """ + + def __init__(self, bool_func: Callable[[Any], bool], message=None): + super().__init__(message) + self.bool_func = bool_func + + def __call__(self, value): + if self.bool_func(value) is not True: + self.raise_failure() + return value + + +class between_inclusive(BaseValidator): + """ + Validate any comparable object to ensure it's between + `start` and `end` inclusively. + """ + + def __init__(self, start=None, end=None, message=None): + super().__init__(message) + self.start = start + self.end = end + + def __call__(self, value): + if self.start is not None and value < self.start: + self.raise_failure() + if self.end is not None and self.end < value: + self.raise_failure() + return value