Implement argument filters

This commit is contained in:
Richard Chien
2019-01-25 00:14:30 +08:00
parent 6b6daf7235
commit f8ecc7bba1
5 changed files with 284 additions and 47 deletions

View File

@ -0,0 +1,8 @@
from typing import Callable, Any, Awaitable, Union
ArgFilter_T = Callable[[Any], Union[Any, Awaitable[Any]]]
class ValidateError(ValueError):
def __init__(self, message=None):
self.message = message

View File

@ -0,0 +1,40 @@
from typing import Optional, List
def _simple_chinese_to_bool(text: str) -> Optional[bool]:
"""
Convert a chinese text to boolean.
Examples:
是的 -> True
好的呀 -> True
不要 -> False
不用了 -> False
你好呀 -> None
"""
text = text.strip().lower().replace(' ', '') \
.rstrip(',.!?~,。!?~了的呢吧呀啊呗啦')
if text in {'', '', '', '', '', '', '',
'ok', 'okay', 'yeah', 'yep',
'当真', '当然', '必须', '可以', '肯定', '没错', '确定', '确认'}:
return True
if text in {'', '不要', '不用', '不是', '', '不好', '不对', '不行', '',
'no', 'nono', 'nonono', 'nope', '不ok', '不可以', '不能',
'不可以'}:
return False
return None
def _split_nonempty_lines(text: str) -> List[str]:
return list(filter(lambda x: x, text.splitlines()))
def _split_nonempty_stripped_lines(text: str) -> List[str]:
return list(filter(lambda x: x,
map(lambda x: x.strip(), text.splitlines())))
simple_chinese_to_bool = _simple_chinese_to_bool
split_nonempty_lines = _split_nonempty_lines
split_nonempty_stripped_lines = _split_nonempty_stripped_lines

View File

@ -0,0 +1,29 @@
import re
from typing import List
from nonebot.message import Message
from nonebot.typing import Message_T
def _extract_text(arg: Message_T) -> str:
"""Extract all plain text segments from a message-like object."""
arg_as_msg = Message(arg)
return arg_as_msg.extract_plain_text()
def _extract_image_urls(arg: Message_T) -> List[str]:
"""Extract all image urls from a message-like object."""
arg_as_msg = Message(arg)
return [s.data['url'] for s in arg_as_msg
if s.type == 'image' and 'url' in s.data]
def _extract_numbers(arg: Message_T) -> List[float]:
"""Extract all numbers (integers and floats) from a message-like object."""
s = str(arg)
return list(map(float, re.findall(r'[+-]?(\d*\.?\d+|\d+\.?\d*)', s)))
extract_text = _extract_text
extract_image_urls = _extract_image_urls
extract_numbers = _extract_numbers

View File

@ -0,0 +1,101 @@
import re
from typing import Callable, Any
from nonebot.command.argfilter import ValidateError
class BaseValidator:
def __init__(self, message=None):
self.message = message
def raise_failure(self):
raise ValidateError(self.message)
class not_empty(BaseValidator):
"""
Validate any object to ensure it's not empty (is None or has no elements).
"""
def __call__(self, value):
if value is None:
self.raise_failure()
if hasattr(value, '__len__') and value.__len__() == 0:
self.raise_failure()
return value
class fit_size(BaseValidator):
"""
Validate any sized object to ensure the size/length
is in a given range [min_length, max_length].
"""
def __init__(self, min_length: int = 0, max_length: int = None,
message=None):
super().__init__(message)
self.min_length = min_length
self.max_length = max_length
def __call__(self, value):
length = len(value) if value is not None else 0
if length < self.min_length or \
(self.max_length is not None and length > self.max_length):
self.raise_failure()
return value
class match_regex(BaseValidator):
"""
Validate any string object to ensure it matches a given pattern.
"""
def __init__(self, pattern: str, message=None, *, flags=0,
fullmatch: bool = False):
super().__init__(message)
self.pattern = re.compile(pattern, flags)
self.fullmatch = fullmatch
def __call__(self, value):
if self.fullmatch:
if not re.fullmatch(self.pattern, value):
self.raise_failure()
else:
if not re.match(self.pattern, value):
self.raise_failure()
return value
class ensure_true(BaseValidator):
"""
Validate any object to ensure the result of applying
a boolean function to it is True.
"""
def __init__(self, bool_func: Callable[[Any], bool], message=None):
super().__init__(message)
self.bool_func = bool_func
def __call__(self, value):
if self.bool_func(value) is not True:
self.raise_failure()
return value
class between_inclusive(BaseValidator):
"""
Validate any comparable object to ensure it's between
`start` and `end` inclusively.
"""
def __init__(self, start=None, end=None, message=None):
super().__init__(message)
self.start = start
self.end = end
def __call__(self, value):
if self.start is not None and value < self.start:
self.raise_failure()
if self.end is not None and self.end < value:
self.raise_failure()
return value