mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-31 06:56:39 +00:00 
			
		
		
		
	♻️ use class rule and permission
This commit is contained in:
		| @@ -69,22 +69,22 @@ def get_sub_dependant( | ||||
|     allow_types: Optional[List[Type[Param]]] = None, | ||||
| ) -> Dependent: | ||||
|     sub_dependant = get_dependent( | ||||
|         func=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types | ||||
|         call=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types | ||||
|     ) | ||||
|     return sub_dependant | ||||
|  | ||||
|  | ||||
| def get_dependent( | ||||
|     *, | ||||
|     func: T_Handler, | ||||
|     call: T_Handler, | ||||
|     name: Optional[str] = None, | ||||
|     use_cache: bool = True, | ||||
|     allow_types: Optional[List[Type[Param]]] = None, | ||||
| ) -> Dependent: | ||||
|     signature = get_typed_signature(func) | ||||
|     signature = get_typed_signature(call) | ||||
|     params = signature.parameters | ||||
|     dependent = Dependent( | ||||
|         func=func, name=name, allow_types=allow_types, use_cache=use_cache | ||||
|         call=call, name=name, allow_types=allow_types, use_cache=use_cache | ||||
|     ) | ||||
|     for param_name, param in params.items(): | ||||
|         if isinstance(param.default, DependsWrapper): | ||||
| @@ -108,7 +108,7 @@ def get_dependent( | ||||
|                     break | ||||
|             else: | ||||
|                 raise ValueError( | ||||
|                     f"Unknown parameter {param_name} for function {func} with type {param.annotation}" | ||||
|                     f"Unknown parameter {param_name} for function {call} with type {param.annotation}" | ||||
|                 ) | ||||
|  | ||||
|         annotation: Any = Any | ||||
| @@ -153,7 +153,7 @@ async def solve_dependencies( | ||||
|         if errs_: | ||||
|             logger.debug( | ||||
|                 f"{field_info} " | ||||
|                 f"type {type(value)} not match depends {_dependent.func} " | ||||
|                 f"type {type(value)} not match depends {_dependent.call} " | ||||
|                 f"annotation {field._type_display()}, ignored" | ||||
|             ) | ||||
|             raise SkippedException(field, value) | ||||
| @@ -163,9 +163,9 @@ async def solve_dependencies( | ||||
|     # solve sub dependencies | ||||
|     sub_dependent: Dependent | ||||
|     for sub_dependent in chain(_sub_dependents or tuple(), _dependent.dependencies): | ||||
|         sub_dependent.func = cast(Callable[..., Any], sub_dependent.func) | ||||
|         sub_dependent.call = cast(Callable[..., Any], sub_dependent.call) | ||||
|         sub_dependent.cache_key = cast(Callable[..., Any], sub_dependent.cache_key) | ||||
|         func = sub_dependent.func | ||||
|         call = sub_dependent.call | ||||
|  | ||||
|         # solve sub dependency with current cache | ||||
|         solved_result = await solve_dependencies( | ||||
| @@ -179,19 +179,19 @@ async def solve_dependencies( | ||||
|         async with cache_lock: | ||||
|             if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache: | ||||
|                 solved = dependency_cache[sub_dependent.cache_key] | ||||
|             elif is_gen_callable(func) or is_async_gen_callable(func): | ||||
|             elif is_gen_callable(call) or is_async_gen_callable(call): | ||||
|                 assert isinstance( | ||||
|                     _stack, AsyncExitStack | ||||
|                 ), "Generator dependency should be called in context" | ||||
|                 if is_gen_callable(func): | ||||
|                     cm = run_sync_ctx_manager(contextmanager(func)(**sub_values)) | ||||
|                 if is_gen_callable(call): | ||||
|                     cm = run_sync_ctx_manager(contextmanager(call)(**sub_values)) | ||||
|                 else: | ||||
|                     cm = asynccontextmanager(func)(**sub_values) | ||||
|                     cm = asynccontextmanager(call)(**sub_values) | ||||
|                 solved = await _stack.enter_async_context(cm) | ||||
|             elif is_coroutine_callable(func): | ||||
|                 solved = await func(**sub_values) | ||||
|             elif is_coroutine_callable(call): | ||||
|                 solved = await call(**sub_values) | ||||
|             else: | ||||
|                 solved = await run_sync(func)(**sub_values) | ||||
|                 solved = await run_sync(call)(**sub_values) | ||||
|  | ||||
|             # parameter dependency | ||||
|             if sub_dependent.name is not None: | ||||
|   | ||||
| @@ -36,17 +36,17 @@ class Dependent: | ||||
|     def __init__( | ||||
|         self, | ||||
|         *, | ||||
|         func: Optional[T_Handler] = None, | ||||
|         call: Optional[T_Handler] = None, | ||||
|         name: Optional[str] = None, | ||||
|         params: Optional[List[ModelField]] = None, | ||||
|         allow_types: Optional[List[Type[Param]]] = None, | ||||
|         dependencies: Optional[List["Dependent"]] = None, | ||||
|         use_cache: bool = True, | ||||
|     ) -> None: | ||||
|         self.func = func | ||||
|         self.call = call | ||||
|         self.name = name | ||||
|         self.params = params or [] | ||||
|         self.allow_types = allow_types or [] | ||||
|         self.dependencies = dependencies or [] | ||||
|         self.use_cache = use_cache | ||||
|         self.cache_key = self.func | ||||
|         self.cache_key = self.call | ||||
|   | ||||
| @@ -7,9 +7,9 @@ from pydantic.typing import ForwardRef, evaluate_forwardref | ||||
| from nonebot.typing import T_Handler | ||||
|  | ||||
|  | ||||
| def get_typed_signature(func: T_Handler) -> inspect.Signature: | ||||
|     signature = inspect.signature(func) | ||||
|     globalns = getattr(func, "__globals__", {}) | ||||
| def get_typed_signature(call: T_Handler) -> inspect.Signature: | ||||
|     signature = inspect.signature(call) | ||||
|     globalns = getattr(call, "__globals__", {}) | ||||
|     typed_params = [ | ||||
|         inspect.Parameter( | ||||
|             name=param.name, | ||||
|   | ||||
| @@ -25,7 +25,7 @@ class Handler: | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         func: Callable[..., Any], | ||||
|         call: Callable[..., Any], | ||||
|         *, | ||||
|         name: Optional[str] = None, | ||||
|         dependencies: Optional[List[DependsWrapper]] = None, | ||||
| @@ -38,17 +38,17 @@ class Handler: | ||||
|  | ||||
|         :参数: | ||||
|  | ||||
|           * ``func: Callable[..., Any]``: 事件处理函数。 | ||||
|           * ``call: Callable[..., Any]``: 事件处理函数。 | ||||
|           * ``name: Optional[str]``: 事件处理器名称。默认为函数名。 | ||||
|           * ``dependencies: Optional[List[DependsWrapper]]``: 额外的非参数依赖注入。 | ||||
|           * ``allow_types: Optional[List[Type[Param]]]``: 允许的参数类型。 | ||||
|         """ | ||||
|         self.func = func | ||||
|         self.call = call | ||||
|         """ | ||||
|         :类型: ``Callable[..., Any]`` | ||||
|         :说明: 事件处理函数 | ||||
|         """ | ||||
|         self.name = get_name(func) if name is None else name | ||||
|         self.name = get_name(call) if name is None else name | ||||
|         """ | ||||
|         :类型: ``str`` | ||||
|         :说明: 事件处理函数名 | ||||
| @@ -68,7 +68,7 @@ class Handler: | ||||
|         if dependencies: | ||||
|             for depends in dependencies: | ||||
|                 self.cache_dependent(depends) | ||||
|         self.dependent = get_dependent(func=func, allow_types=self.allow_types) | ||||
|         self.dependent = get_dependent(call=call, allow_types=self.allow_types) | ||||
|  | ||||
|     def __repr__(self) -> str: | ||||
|         return f"<Handler {self.name}({', '.join(map(str, self.dependent.params))})>" | ||||
| @@ -94,10 +94,10 @@ class Handler: | ||||
|             **params, | ||||
|         ) | ||||
|  | ||||
|         if asyncio.iscoroutinefunction(self.func): | ||||
|             return await self.func(**values) | ||||
|         if asyncio.iscoroutinefunction(self.call): | ||||
|             return await self.call(**values) | ||||
|         else: | ||||
|             return await run_sync(self.func)(**values) | ||||
|             return await run_sync(self.call)(**values) | ||||
|  | ||||
|     def cache_dependent(self, dependency: DependsWrapper): | ||||
|         if not dependency.dependency: | ||||
|   | ||||
| @@ -442,7 +442,7 @@ class Matcher(metaclass=MatcherMeta): | ||||
|  | ||||
|         def _decorator(func: T_Handler) -> T_Handler: | ||||
|  | ||||
|             if cls.handlers and cls.handlers[-1].func is func: | ||||
|             if cls.handlers and cls.handlers[-1].call is func: | ||||
|                 func_handler = cls.handlers[-1] | ||||
|                 for depend in reversed(_dependencies): | ||||
|                     func_handler.prepend_dependency(depend) | ||||
| @@ -513,7 +513,7 @@ class Matcher(metaclass=MatcherMeta): | ||||
|  | ||||
|         def _decorator(func: T_Handler) -> T_Handler: | ||||
|  | ||||
|             if cls.handlers and cls.handlers[-1].func is func: | ||||
|             if cls.handlers and cls.handlers[-1].call is func: | ||||
|                 func_handler = cls.handlers[-1] | ||||
|                 for depend in reversed(_dependencies): | ||||
|                     func_handler.prepend_dependency(depend) | ||||
|   | ||||
| @@ -11,7 +11,17 @@ r""" | ||||
|  | ||||
| import asyncio | ||||
| from contextlib import AsyncExitStack | ||||
| from typing import Any, Dict, List, Type, Union, Callable, NoReturn, Optional | ||||
| from typing import ( | ||||
|     Any, | ||||
|     Dict, | ||||
|     List, | ||||
|     Type, | ||||
|     Tuple, | ||||
|     Union, | ||||
|     Callable, | ||||
|     NoReturn, | ||||
|     Optional, | ||||
| ) | ||||
|  | ||||
| from nonebot import params | ||||
| from nonebot.handler import Handler | ||||
| @@ -119,41 +129,59 @@ class Permission: | ||||
|             return Permission(*self.checkers, other) | ||||
|  | ||||
|  | ||||
| async def _message(event: Event) -> bool: | ||||
|     return event.get_type() == "message" | ||||
| class Message: | ||||
|     async def __call__(self, event: Event) -> bool: | ||||
|         return event.get_type() == "message" | ||||
|  | ||||
|  | ||||
| async def _notice(event: Event) -> bool: | ||||
|     return event.get_type() == "notice" | ||||
| class Notice: | ||||
|     async def __call__(self, event: Event) -> bool: | ||||
|         return event.get_type() == "notice" | ||||
|  | ||||
|  | ||||
| async def _request(event: Event) -> bool: | ||||
|     return event.get_type() == "request" | ||||
| class Request: | ||||
|     async def __call__(self, event: Event) -> bool: | ||||
|         return event.get_type() == "request" | ||||
|  | ||||
|  | ||||
| async def _metaevent(event: Event) -> bool: | ||||
|     return event.get_type() == "meta_event" | ||||
| class MetaEvent: | ||||
|     async def __call__(self, event: Event) -> bool: | ||||
|         return event.get_type() == "meta_event" | ||||
|  | ||||
|  | ||||
| MESSAGE = Permission(_message) | ||||
| MESSAGE = Permission(Message()) | ||||
| """ | ||||
| - **说明**: 匹配任意 ``message`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 message type 的 Matcher。 | ||||
| """ | ||||
| NOTICE = Permission(_notice) | ||||
| NOTICE = Permission(Notice()) | ||||
| """ | ||||
| - **说明**: 匹配任意 ``notice`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 notice type 的 Matcher。 | ||||
| """ | ||||
| REQUEST = Permission(_request) | ||||
| REQUEST = Permission(Request()) | ||||
| """ | ||||
| - **说明**: 匹配任意 ``request`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 request type 的 Matcher。 | ||||
| """ | ||||
| METAEVENT = Permission(_metaevent) | ||||
| METAEVENT = Permission(MetaEvent()) | ||||
| """ | ||||
| - **说明**: 匹配任意 ``meta_event`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 meta_event type 的 Matcher。 | ||||
| """ | ||||
|  | ||||
|  | ||||
| def USER(*user: str, perm: Optional[Permission] = None): | ||||
| class User: | ||||
|     def __init__( | ||||
|         self, users: Tuple[str, ...], perm: Optional[Permission] = None | ||||
|     ) -> None: | ||||
|         self.users = users | ||||
|         self.perm = perm | ||||
|  | ||||
|     async def __call__(self, bot: Bot, event: Event) -> bool: | ||||
|         return bool( | ||||
|             event.get_session_id() in self.users | ||||
|             and (self.perm is None or await self.perm(bot, event)) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def USER(*users: str, perm: Optional[Permission] = None): | ||||
|     """ | ||||
|     :说明: | ||||
|  | ||||
| @@ -165,21 +193,18 @@ def USER(*user: str, perm: Optional[Permission] = None): | ||||
|       * ``perm: Optional[Permission]``: 需要同时满足的权限 | ||||
|     """ | ||||
|  | ||||
|     async def _user(bot: Bot, event: Event) -> bool: | ||||
|         return bool( | ||||
|             event.get_session_id() in user and (perm is None or await perm(bot, event)) | ||||
|     return Permission(User(users, perm)) | ||||
|  | ||||
|  | ||||
| class SuperUser: | ||||
|     async def __call__(self, bot: Bot, event: Event) -> bool: | ||||
|         return ( | ||||
|             event.get_type() == "message" | ||||
|             and event.get_user_id() in bot.config.superusers | ||||
|         ) | ||||
|  | ||||
|     return Permission(_user) | ||||
|  | ||||
|  | ||||
| async def _superuser(bot: Bot, event: Event) -> bool: | ||||
|     return ( | ||||
|         event.get_type() == "message" and event.get_user_id() in bot.config.superusers | ||||
|     ) | ||||
|  | ||||
|  | ||||
| SUPERUSER = Permission(_superuser) | ||||
| SUPERUSER = Permission(SuperUser()) | ||||
| """ | ||||
| - **说明**: 匹配任意超级用户消息类型事件 | ||||
| """ | ||||
|   | ||||
							
								
								
									
										168
									
								
								nonebot/rule.py
									
									
									
									
									
								
							
							
						
						
									
										168
									
								
								nonebot/rule.py
									
									
									
									
									
								
							| @@ -203,6 +203,24 @@ class TrieRule: | ||||
|         return prefix, suffix | ||||
|  | ||||
|  | ||||
| class Startswith: | ||||
|     def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False): | ||||
|         self.msg = msg | ||||
|         self.ignorecase = ignorecase | ||||
|  | ||||
|     async def __call__(self, event: Event) -> Any: | ||||
|         if event.get_type() != "message": | ||||
|             return False | ||||
|         text = event.get_plaintext() | ||||
|         return bool( | ||||
|             re.match( | ||||
|                 f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})", | ||||
|                 text, | ||||
|                 re.IGNORECASE if self.ignorecase else 0, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule: | ||||
|     """ | ||||
|     :说明: | ||||
| @@ -216,18 +234,25 @@ def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Ru | ||||
|     if isinstance(msg, str): | ||||
|         msg = (msg,) | ||||
|  | ||||
|     pattern = re.compile( | ||||
|         f"^(?:{'|'.join(re.escape(prefix) for prefix in msg)})", | ||||
|         re.IGNORECASE if ignorecase else 0, | ||||
|     ) | ||||
|     return Rule(Startswith(msg, ignorecase)) | ||||
|  | ||||
|     async def _startswith(bot: Bot, event: Event, state: T_State) -> bool: | ||||
|  | ||||
| class Endswith: | ||||
|     def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False): | ||||
|         self.msg = msg | ||||
|         self.ignorecase = ignorecase | ||||
|  | ||||
|     async def __call__(self, event: Event) -> Any: | ||||
|         if event.get_type() != "message": | ||||
|             return False | ||||
|         text = event.get_plaintext() | ||||
|         return bool(pattern.match(text)) | ||||
|  | ||||
|     return Rule(_startswith) | ||||
|         return bool( | ||||
|             re.search( | ||||
|                 f"(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})$", | ||||
|                 text, | ||||
|                 re.IGNORECASE if self.ignorecase else 0, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule: | ||||
| @@ -243,18 +268,18 @@ def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule | ||||
|     if isinstance(msg, str): | ||||
|         msg = (msg,) | ||||
|  | ||||
|     pattern = re.compile( | ||||
|         f"(?:{'|'.join(re.escape(prefix) for prefix in msg)})$", | ||||
|         re.IGNORECASE if ignorecase else 0, | ||||
|     ) | ||||
|     return Rule(Endswith(msg, ignorecase)) | ||||
|  | ||||
|     async def _endswith(bot: Bot, event: Event, state: T_State) -> bool: | ||||
|  | ||||
| class Keywords: | ||||
|     def __init__(self, *keywords: str): | ||||
|         self.keywords = keywords | ||||
|  | ||||
|     async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool: | ||||
|         if event.get_type() != "message": | ||||
|             return False | ||||
|         text = event.get_plaintext() | ||||
|         return bool(pattern.search(text)) | ||||
|  | ||||
|     return Rule(_endswith) | ||||
|         return bool(text and any(keyword in text for keyword in self.keywords)) | ||||
|  | ||||
|  | ||||
| def keyword(*keywords: str) -> Rule: | ||||
| @@ -268,13 +293,18 @@ def keyword(*keywords: str) -> Rule: | ||||
|       * ``*keywords: str``: 关键词 | ||||
|     """ | ||||
|  | ||||
|     async def _keyword(event: Event) -> bool: | ||||
|         if event.get_type() != "message": | ||||
|             return False | ||||
|         text = event.get_plaintext() | ||||
|         return bool(text and any(keyword in text for keyword in keywords)) | ||||
|     return Rule(Keywords(*keywords)) | ||||
|  | ||||
|     return Rule(_keyword) | ||||
|  | ||||
| class Command: | ||||
|     def __init__(self, cmds: List[Tuple[str, ...]]): | ||||
|         self.cmds = cmds | ||||
|  | ||||
|     async def __call__(self, state: T_State) -> bool: | ||||
|         return state[PREFIX_KEY][CMD_KEY] in self.cmds | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return f"<Command {self.cmds}>" | ||||
|  | ||||
|  | ||||
| def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule: | ||||
| @@ -304,10 +334,12 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule: | ||||
|     config = get_driver().config | ||||
|     command_start = config.command_start | ||||
|     command_sep = config.command_sep | ||||
|     commands = list(cmds) | ||||
|     for index, command in enumerate(commands): | ||||
|     commands: List[Tuple[str, ...]] = [] | ||||
|     for command in cmds: | ||||
|         if isinstance(command, str): | ||||
|             commands[index] = command = (command,) | ||||
|             command = (command,) | ||||
|  | ||||
|         commands.append(command) | ||||
|  | ||||
|         if len(command) == 1: | ||||
|             for start in command_start: | ||||
| @@ -316,10 +348,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule: | ||||
|             for start, sep in product(command_start, command_sep): | ||||
|                 TrieRule.add_prefix(f"{start}{sep.join(command)}", command) | ||||
|  | ||||
|     async def _command(state: T_State) -> bool: | ||||
|         return state[PREFIX_KEY][CMD_KEY] in commands | ||||
|  | ||||
|     return Rule(_command) | ||||
|     return Rule(Command(commands)) | ||||
|  | ||||
|  | ||||
| class ArgumentParser(ArgParser): | ||||
| @@ -350,6 +379,27 @@ class ArgumentParser(ArgParser): | ||||
|         return super().parse_args(args=args, namespace=namespace)  # type: ignore | ||||
|  | ||||
|  | ||||
| class ShellCommand: | ||||
|     def __init__(self, cmds: List[Tuple[str, ...]], parser: Optional[ArgumentParser]): | ||||
|         self.cmds = cmds | ||||
|         self.parser = parser | ||||
|  | ||||
|     async def __call__(self, event: Event, state: T_State) -> bool: | ||||
|         if state[PREFIX_KEY][CMD_KEY] in self.cmds: | ||||
|             message = str(event.get_message()) | ||||
|             strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip() | ||||
|             state[SHELL_ARGV] = shlex.split(strip_message) | ||||
|             if self.parser: | ||||
|                 try: | ||||
|                     args = self.parser.parse_args(state[SHELL_ARGV]) | ||||
|                     state[SHELL_ARGS] = args | ||||
|                 except ParserExit as e: | ||||
|                     state[SHELL_ARGS] = e | ||||
|             return True | ||||
|         else: | ||||
|             return False | ||||
|  | ||||
|  | ||||
| def shell_command( | ||||
|     *cmds: Union[str, Tuple[str, ...]], parser: Optional[ArgumentParser] = None | ||||
| ) -> Rule: | ||||
| @@ -392,10 +442,12 @@ def shell_command( | ||||
|     config = get_driver().config | ||||
|     command_start = config.command_start | ||||
|     command_sep = config.command_sep | ||||
|     commands = list(cmds) | ||||
|     for index, command in enumerate(commands): | ||||
|     commands: List[Tuple[str, ...]] = [] | ||||
|     for command in cmds: | ||||
|         if isinstance(command, str): | ||||
|             commands[index] = command = (command,) | ||||
|             command = (command,) | ||||
|  | ||||
|         commands.append(command) | ||||
|  | ||||
|         if len(command) == 1: | ||||
|             for start in command_start: | ||||
| @@ -404,23 +456,26 @@ def shell_command( | ||||
|             for start, sep in product(command_start, command_sep): | ||||
|                 TrieRule.add_prefix(f"{start}{sep.join(command)}", command) | ||||
|  | ||||
|     async def _shell_command(event: Event, state: T_State) -> bool: | ||||
|         if state[PREFIX_KEY][CMD_KEY] in commands: | ||||
|             message = str(event.get_message()) | ||||
|             strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip() | ||||
|             state[SHELL_ARGV] = shlex.split(strip_message) | ||||
|             if parser: | ||||
|                 try: | ||||
|                     args = parser.parse_args(state[SHELL_ARGV]) | ||||
|                     state[SHELL_ARGS] = args | ||||
|                 except ParserExit as e: | ||||
|                     state[SHELL_ARGS] = e | ||||
|     return Rule(ShellCommand(commands, parser)) | ||||
|  | ||||
|  | ||||
| class Regex: | ||||
|     def __init__(self, regex: str, flags: int = 0): | ||||
|         self.regex = regex | ||||
|         self.flags = flags | ||||
|  | ||||
|     async def __call__(self, event: Event, state: T_State) -> bool: | ||||
|         if event.get_type() != "message": | ||||
|             return False | ||||
|         matched = re.search(self.regex, str(event.get_message()), self.flags) | ||||
|         if matched: | ||||
|             state[REGEX_MATCHED] = matched.group() | ||||
|             state[REGEX_GROUP] = matched.groups() | ||||
|             state[REGEX_DICT] = matched.groupdict() | ||||
|             return True | ||||
|         else: | ||||
|             return False | ||||
|  | ||||
|     return Rule(_shell_command) | ||||
|  | ||||
|  | ||||
| def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule: | ||||
|     r""" | ||||
| @@ -441,25 +496,12 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule: | ||||
|     \:\:\: | ||||
|     """ | ||||
|  | ||||
|     pattern = re.compile(regex, flags) | ||||
|  | ||||
|     async def _regex(event: Event, state: T_State) -> bool: | ||||
|         if event.get_type() != "message": | ||||
|             return False | ||||
|         matched = pattern.search(str(event.get_message())) | ||||
|         if matched: | ||||
|             state[REGEX_MATCHED] = matched.group() | ||||
|             state[REGEX_GROUP] = matched.groups() | ||||
|             state[REGEX_DICT] = matched.groupdict() | ||||
|             return True | ||||
|         else: | ||||
|             return False | ||||
|  | ||||
|     return Rule(_regex) | ||||
|     return Rule(Regex(regex, flags)) | ||||
|  | ||||
|  | ||||
| async def _to_me(event: Event) -> bool: | ||||
|     return event.is_tome() | ||||
| class ToMe: | ||||
|     async def __call__(self, event: Event) -> bool: | ||||
|         return event.is_tome() | ||||
|  | ||||
|  | ||||
| def to_me() -> Rule: | ||||
| @@ -473,4 +515,4 @@ def to_me() -> Rule: | ||||
|       * 无 | ||||
|     """ | ||||
|  | ||||
|     return Rule(_to_me) | ||||
|     return Rule(ToMe()) | ||||
|   | ||||
| @@ -66,30 +66,30 @@ def generic_check_issubclass( | ||||
|         raise | ||||
|  | ||||
|  | ||||
| def is_coroutine_callable(func: Callable[..., Any]) -> bool: | ||||
|     if inspect.isroutine(func): | ||||
|         return inspect.iscoroutinefunction(func) | ||||
|     if inspect.isclass(func): | ||||
| def is_coroutine_callable(call: Callable[..., Any]) -> bool: | ||||
|     if inspect.isroutine(call): | ||||
|         return inspect.iscoroutinefunction(call) | ||||
|     if inspect.isclass(call): | ||||
|         return False | ||||
|     func_ = getattr(func, "__call__", None) | ||||
|     func_ = getattr(call, "__call__", None) | ||||
|     return inspect.iscoroutinefunction(func_) | ||||
|  | ||||
|  | ||||
| def is_gen_callable(func: Callable[..., Any]) -> bool: | ||||
|     if inspect.isgeneratorfunction(func): | ||||
| def is_gen_callable(call: Callable[..., Any]) -> bool: | ||||
|     if inspect.isgeneratorfunction(call): | ||||
|         return True | ||||
|     func_ = getattr(func, "__call__", None) | ||||
|     func_ = getattr(call, "__call__", None) | ||||
|     return inspect.isgeneratorfunction(func_) | ||||
|  | ||||
|  | ||||
| def is_async_gen_callable(func: Callable[..., Any]) -> bool: | ||||
|     if inspect.isasyncgenfunction(func): | ||||
| def is_async_gen_callable(call: Callable[..., Any]) -> bool: | ||||
|     if inspect.isasyncgenfunction(call): | ||||
|         return True | ||||
|     func_ = getattr(func, "__call__", None) | ||||
|     func_ = getattr(call, "__call__", None) | ||||
|     return inspect.isasyncgenfunction(func_) | ||||
|  | ||||
|  | ||||
| def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]: | ||||
| def run_sync(call: Callable[P, R]) -> Callable[P, Awaitable[R]]: | ||||
|     """ | ||||
|     :说明: | ||||
|  | ||||
| @@ -97,17 +97,17 @@ def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]: | ||||
|  | ||||
|     :参数: | ||||
|  | ||||
|       * ``func: Callable[P, R]``: 被装饰的同步函数 | ||||
|       * ``call: Callable[P, R]``: 被装饰的同步函数 | ||||
|  | ||||
|     :返回: | ||||
|  | ||||
|       - ``Callable[P, Awaitable[R]]`` | ||||
|     """ | ||||
|  | ||||
|     @wraps(func) | ||||
|     @wraps(call) | ||||
|     async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R: | ||||
|         loop = asyncio.get_running_loop() | ||||
|         pfunc = partial(func, *args, **kwargs) | ||||
|         pfunc = partial(call, *args, **kwargs) | ||||
|         result = await loop.run_in_executor(None, pfunc) | ||||
|         return result | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user