⚗️ support tuple prefix in startswith

This commit is contained in:
yanyongyu 2021-04-04 12:28:10 +08:00
parent 0d467d9275
commit d1a438a287
3 changed files with 29 additions and 18 deletions

View File

@ -288,7 +288,7 @@ def on_request(rule: Optional[Union[Rule, T_RuleChecker]] = None,
return matcher return matcher
def on_startswith(msg: str, def on_startswith(msg: Union[str, Tuple[str, ...]],
rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = None, rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = None,
ignorecase: bool = False, ignorecase: bool = False,
**kwargs) -> Type[Matcher]: **kwargs) -> Type[Matcher]:
@ -299,7 +299,7 @@ def on_startswith(msg: str,
:参数: :参数:
* ``msg: str``: 指定消息开头内容 * ``msg: Union[str, Tuple[str, ...]]``: 指定消息开头内容
* ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则
* ``ignorecase: bool``: 是否忽略大小写 * ``ignorecase: bool``: 是否忽略大小写
* ``permission: Optional[Permission]``: 事件响应权限 * ``permission: Optional[Permission]``: 事件响应权限
@ -317,7 +317,7 @@ def on_startswith(msg: str,
return on_message(startswith(msg, ignorecase) & rule, **kwargs) return on_message(startswith(msg, ignorecase) & rule, **kwargs)
def on_endswith(msg: str, def on_endswith(msg: Union[str, Tuple[str, ...]],
rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = None, rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = None,
ignorecase: bool = False, ignorecase: bool = False,
**kwargs) -> Type[Matcher]: **kwargs) -> Type[Matcher]:
@ -328,7 +328,7 @@ def on_endswith(msg: str,
:参数: :参数:
* ``msg: str``: 指定消息结尾内容 * ``msg: Union[str, Tuple[str, ...]]``: 指定消息结尾内容
* ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则
* ``ignorecase: bool``: 是否忽略大小写 * ``ignorecase: bool``: 是否忽略大小写
* ``permission: Optional[Permission]``: 事件响应权限 * ``permission: Optional[Permission]``: 事件响应权限
@ -728,7 +728,8 @@ class MatcherGroup:
self.matchers.append(matcher) self.matchers.append(matcher)
return matcher return matcher
def on_startswith(self, msg: str, **kwargs) -> Type[Matcher]: def on_startswith(self, msg: Union[str, Tuple[str, ...]],
**kwargs) -> Type[Matcher]:
""" """
:说明: :说明:
@ -736,7 +737,7 @@ class MatcherGroup:
:参数: :参数:
* ``msg: str``: 指定消息开头内容 * ``msg: Union[str, Tuple[str, ...]]``: 指定消息开头内容
* ``ignorecase: bool``: 是否忽略大小写 * ``ignorecase: bool``: 是否忽略大小写
* ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则
* ``permission: Optional[Permission]``: 事件响应权限 * ``permission: Optional[Permission]``: 事件响应权限
@ -758,7 +759,8 @@ class MatcherGroup:
self.matchers.append(matcher) self.matchers.append(matcher)
return matcher return matcher
def on_endswith(self, msg: str, **kwargs) -> Type[Matcher]: def on_endswith(self, msg: Union[str, Tuple[str, ...]],
**kwargs) -> Type[Matcher]:
""" """
:说明: :说明:
@ -766,7 +768,7 @@ class MatcherGroup:
:参数: :参数:
* ``msg: str``: 指定消息结尾内容 * ``msg: Union[str, Tuple[str, ...]]``: 指定消息结尾内容
* ``ignorecase: bool``: 是否忽略大小写 * ``ignorecase: bool``: 是否忽略大小写
* ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则
* ``permission: Optional[Permission]``: 事件响应权限 * ``permission: Optional[Permission]``: 事件响应权限

View File

@ -90,7 +90,7 @@ def on_request(rule: Optional[Union[Rule, T_RuleChecker]] = ...,
def on_startswith( def on_startswith(
msg: str, msg: Union[str, Tuple[str, ...]],
rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = ..., rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = ...,
ignorecase: bool = ..., ignorecase: bool = ...,
*, *,
@ -104,7 +104,7 @@ def on_startswith(
... ...
def on_endswith(msg: str, def on_endswith(msg: Union[str, Tuple[str, ...]],
rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = ..., rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = ...,
ignorecase: bool = ..., ignorecase: bool = ...,
*, *,
@ -300,7 +300,7 @@ class MatcherGroup:
def on_startswith( def on_startswith(
self, self,
msg: str, msg: Union[str, Tuple[str, ...]],
*, *,
ignorecase: bool = ..., ignorecase: bool = ...,
rule: Optional[Union[Rule, T_RuleChecker]] = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ...,
@ -315,7 +315,7 @@ class MatcherGroup:
def on_endswith( def on_endswith(
self, self,
msg: str, msg: Union[str, Tuple[str, ...]],
*, *,
ignorecase: bool = ..., ignorecase: bool = ...,
rule: Optional[Union[Rule, T_RuleChecker]] = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ...,

View File

@ -175,7 +175,8 @@ class TrieRule:
}) })
def startswith(msg: str, ignorecase: bool = False) -> Rule: def startswith(msg: Union[str, Tuple[str, ...]],
ignorecase: bool = False) -> Rule:
""" """
:说明: :说明:
@ -185,9 +186,12 @@ def startswith(msg: str, ignorecase: bool = False) -> Rule:
* ``msg: str``: 消息开头字符串 * ``msg: str``: 消息开头字符串
""" """
if isinstance(msg, str):
msg = (msg,)
pattern = re.compile(f"^{re.escape(msg)}", pattern = re.compile(
re.IGNORECASE if ignorecase else 0) f"^(?:{'|'.join(re.escape(prefix) for prefix in msg)})",
re.IGNORECASE if ignorecase else 0)
async def _startswith(bot: "Bot", event: "Event", state: T_State) -> bool: async def _startswith(bot: "Bot", event: "Event", state: T_State) -> bool:
if event.get_type() != "message": if event.get_type() != "message":
@ -198,7 +202,8 @@ def startswith(msg: str, ignorecase: bool = False) -> Rule:
return Rule(_startswith) return Rule(_startswith)
def endswith(msg: str, ignorecase: bool = False) -> Rule: def endswith(msg: Union[str, Tuple[str, ...]],
ignorecase: bool = False) -> Rule:
""" """
:说明: :说明:
@ -208,8 +213,12 @@ def endswith(msg: str, ignorecase: bool = False) -> Rule:
* ``msg: str``: 消息结尾字符串 * ``msg: str``: 消息结尾字符串
""" """
pattern = re.compile(f"{re.escape(msg)}$", if isinstance(msg, str):
re.IGNORECASE if ignorecase else 0) msg = (msg,)
pattern = re.compile(
f"(?:{'|'.join(re.escape(prefix) for prefix in msg)})$",
re.IGNORECASE if ignorecase else 0)
async def _endswith(bot: "Bot", event: "Event", state: T_State) -> bool: async def _endswith(bot: "Bot", event: "Event", state: T_State) -> bool:
if event.get_type() != "message": if event.get_type() != "message":