mirror of
				https://github.com/nonebot/nonebot2.git
				synced 2025-10-26 20:46:39 +00:00 
			
		
		
		
	make matcher running concurrently and add to me checking
This commit is contained in:
		| @@ -51,7 +51,16 @@ class BaseEvent(abc.ABC): | ||||
|  | ||||
|     def __repr__(self) -> str: | ||||
|         # TODO: pretty print | ||||
|         return f"<Event: >" | ||||
|         return f"<Event: {self.type}/{self.detail_type} {self.raw_message}>" | ||||
|  | ||||
|     @property | ||||
|     def raw_event(self) -> dict: | ||||
|         return self._raw_event | ||||
|  | ||||
|     @property | ||||
|     @abc.abstractmethod | ||||
|     def self_id(self) -> str: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @property | ||||
|     @abc.abstractmethod | ||||
| @@ -93,6 +102,16 @@ class BaseEvent(abc.ABC): | ||||
|     def user_id(self, value) -> None: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @property | ||||
|     @abc.abstractmethod | ||||
|     def to_me(self) -> Optional[bool]: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @to_me.setter | ||||
|     @abc.abstractmethod | ||||
|     def to_me(self, value) -> None: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @property | ||||
|     @abc.abstractmethod | ||||
|     def message(self) -> Optional[Message]: | ||||
|   | ||||
| @@ -7,11 +7,12 @@ import asyncio | ||||
|  | ||||
| import httpx | ||||
|  | ||||
| from nonebot.log import logger | ||||
| from nonebot.config import Config | ||||
| from nonebot.message import handle_event | ||||
| from nonebot.typing import overrides, Driver, WebSocket, NoReturn | ||||
| from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional | ||||
| from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable | ||||
| from nonebot.typing import overrides, Driver, WebSocket, NoReturn | ||||
| from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment | ||||
|  | ||||
|  | ||||
| @@ -41,6 +42,67 @@ def _b2s(b: bool) -> str: | ||||
|     return str(b).lower() | ||||
|  | ||||
|  | ||||
| def _check_at_me(bot: "Bot", event: "Event"): | ||||
|     if event.type != "message": | ||||
|         return | ||||
|  | ||||
|     if event.detail_type == "private": | ||||
|         event.to_me = True | ||||
|     else: | ||||
|         event.to_me = False | ||||
|         at_me_seg = MessageSegment.at(event.self_id) | ||||
|  | ||||
|         # check the first segment | ||||
|         first_msg_seg = event.message[0] | ||||
|         if first_msg_seg == at_me_seg: | ||||
|             event.to_me = True | ||||
|             del event.message[0] | ||||
|  | ||||
|         if not event.to_me: | ||||
|             # check the last segment | ||||
|             i = -1 | ||||
|             last_msg_seg = event.message[i] | ||||
|             if last_msg_seg.type == "text" and \ | ||||
|                     not last_msg_seg.data["text"].strip() and \ | ||||
|                     len(event.message) >= 2: | ||||
|                 i -= 1 | ||||
|                 last_msg_seg = event.message[i] | ||||
|  | ||||
|             if last_msg_seg == at_me_seg: | ||||
|                 event.to_me = True | ||||
|                 del event.message[i:] | ||||
|  | ||||
|         if not event.message: | ||||
|             event.message.append(MessageSegment.text("")) | ||||
|  | ||||
|  | ||||
| def _check_nickname(bot: "Bot", event: "Event"): | ||||
|     if event.type != "message": | ||||
|         return | ||||
|  | ||||
|     first_msg_seg = event.message[0] | ||||
|     if first_msg_seg.type != "text": | ||||
|         return | ||||
|  | ||||
|     first_text = first_msg_seg.data["text"] | ||||
|  | ||||
|     if bot.config.NICKNAME: | ||||
|         # check if the user is calling me with my nickname | ||||
|         if isinstance(bot.config.NICKNAME, str) or \ | ||||
|                 not isinstance(bot.config.NICKNAME, Iterable): | ||||
|             nicknames = (bot.config.NICKNAME,) | ||||
|         else: | ||||
|             nicknames = filter(lambda n: n, bot.config.NICKNAME) | ||||
|         nickname_regex = "|".join(nicknames) | ||||
|         m = re.search(rf"^({nickname_regex})([\s,,]*|$)", first_text, | ||||
|                       re.IGNORECASE) | ||||
|         if m: | ||||
|             nickname = m.group(1) | ||||
|             logger.debug(f"User is calling me {nickname}") | ||||
|             event.to_me = True | ||||
|             first_msg_seg.data["text"] = first_text[m.end():] | ||||
|  | ||||
|  | ||||
| def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any: | ||||
|     if isinstance(result, dict): | ||||
|         if result.get("status") == "failed": | ||||
| @@ -108,6 +170,10 @@ class Bot(BaseBot): | ||||
|  | ||||
|         event = Event(message) | ||||
|  | ||||
|         # Check whether user is calling me | ||||
|         _check_at_me(self, event) | ||||
|         _check_nickname(self, event) | ||||
|  | ||||
|         await handle_event(self, event) | ||||
|  | ||||
|     @overrides(BaseBot) | ||||
| @@ -166,6 +232,11 @@ class Event(BaseEvent): | ||||
|  | ||||
|         super().__init__(raw_event) | ||||
|  | ||||
|     @property | ||||
|     @overrides(BaseEvent) | ||||
|     def self_id(self) -> str: | ||||
|         return str(self._raw_event["self_id"]) | ||||
|  | ||||
|     @property | ||||
|     @overrides(BaseEvent) | ||||
|     def type(self) -> str: | ||||
| @@ -206,6 +277,16 @@ class Event(BaseEvent): | ||||
|     def user_id(self, value) -> None: | ||||
|         self._raw_event["user_id"] = value | ||||
|  | ||||
|     @property | ||||
|     @overrides(BaseEvent) | ||||
|     def to_me(self) -> Optional[bool]: | ||||
|         return self._raw_event.get("to_me") | ||||
|  | ||||
|     @to_me.setter | ||||
|     @overrides(BaseEvent) | ||||
|     def to_me(self, value) -> None: | ||||
|         self._raw_event["to_me"] = value | ||||
|  | ||||
|     @property | ||||
|     @overrides(BaseEvent) | ||||
|     def message(self) -> Optional["Message"]: | ||||
| @@ -244,6 +325,18 @@ class Event(BaseEvent): | ||||
|  | ||||
| class MessageSegment(BaseMessageSegment): | ||||
|  | ||||
|     @overrides(BaseMessageSegment) | ||||
|     def __init__(self, type: str, data: Dict[str, str]) -> None: | ||||
|         if type == "at" and data.get("qq") == "all": | ||||
|             type = "at_all" | ||||
|             data.clear() | ||||
|         elif type == "shake": | ||||
|             type = "poke" | ||||
|             data = {"type": "Poke"} | ||||
|         elif type == "text": | ||||
|             data["text"] = unescape(data["text"]) | ||||
|         super().__init__(type=type, data=data) | ||||
|  | ||||
|     @overrides(BaseMessageSegment) | ||||
|     def __str__(self): | ||||
|         type_ = self.type | ||||
| @@ -271,7 +364,7 @@ class MessageSegment(BaseMessageSegment): | ||||
|         return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)}) | ||||
|  | ||||
|     @staticmethod | ||||
|     def at(user_id: int) -> "MessageSegment": | ||||
|     def at(user_id: Union[int, str]) -> "MessageSegment": | ||||
|         return MessageSegment("at", {"qq": str(user_id)}) | ||||
|  | ||||
|     @staticmethod | ||||
|   | ||||
| @@ -8,7 +8,13 @@ | ||||
| 这些异常并非所有需要用户处理,在 NoneBot 内部运行时被捕获,并进行对应操作。 | ||||
| """ | ||||
|  | ||||
| from nonebot.typing import Optional | ||||
| from nonebot.typing import List, Type, Optional | ||||
|  | ||||
|  | ||||
| class _ExceptionContainer(Exception): | ||||
|  | ||||
|     def __init__(self, exceptions: List[Type[Exception]]) -> None: | ||||
|         self.exceptions = exceptions | ||||
|  | ||||
|  | ||||
| class IgnoredException(Exception): | ||||
| @@ -37,12 +43,12 @@ class PausedException(Exception): | ||||
|     """ | ||||
|     :说明: | ||||
|  | ||||
|       指示 NoneBot 结束当前 Handler 并等待下一条消息后继续下一个 Handler。 | ||||
|       指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后继续下一个 ``Handler``。 | ||||
|       可用于用户输入新信息。 | ||||
|  | ||||
|     :用法: | ||||
|  | ||||
|       可以在 Handler 中通过 Matcher.pause() 抛出。 | ||||
|       可以在 ``Handler`` 中通过 ``Matcher.pause()`` 抛出。 | ||||
|     """ | ||||
|     pass | ||||
|  | ||||
| @@ -51,12 +57,12 @@ class RejectedException(Exception): | ||||
|     """ | ||||
|     :说明: | ||||
|  | ||||
|       指示 NoneBot 结束当前 Handler 并等待下一条消息后重新运行当前 Handler。 | ||||
|       指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后重新运行当前 ``Handler``。 | ||||
|       可用于用户重新输入。 | ||||
|  | ||||
|     :用法: | ||||
|  | ||||
|       可以在 Handler 中通过 Matcher.reject() 抛出。 | ||||
|       可以在 ``Handler`` 中通过 ``Matcher.reject()`` 抛出。 | ||||
|     """ | ||||
|     pass | ||||
|  | ||||
| @@ -65,12 +71,38 @@ class FinishedException(Exception): | ||||
|     """ | ||||
|     :说明: | ||||
|  | ||||
|       指示 NoneBot 结束当前 Handler 且后续 Handler 不再被运行。 | ||||
|       指示 NoneBot 结束当前 ``Handler`` 且后续 ``Handler`` 不再被运行。 | ||||
|       可用于结束用户会话。 | ||||
|  | ||||
|     :用法: | ||||
|  | ||||
|       可以在 Handler 中通过 Matcher.finish() 抛出。 | ||||
|       可以在 ``Handler`` 中通过 ``Matcher.finish()`` 抛出。 | ||||
|     """ | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class ExpiredException(Exception): | ||||
|     """ | ||||
|     :说明: | ||||
|  | ||||
|       指示 NoneBot 当前 ``Matcher`` 已失效。 | ||||
|  | ||||
|     :用法: | ||||
|  | ||||
|       当 ``Matcher`` 运行前检查时抛出。 | ||||
|     """ | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class StopPropagation(Exception): | ||||
|     """ | ||||
|     :说明: | ||||
|  | ||||
|       指示 NoneBot 终止事件向下层传播。 | ||||
|  | ||||
|     :用法: | ||||
|  | ||||
|       在 ``Matcher.block == True`` 时抛出。 | ||||
|     """ | ||||
|     pass | ||||
|  | ||||
|   | ||||
| @@ -26,6 +26,7 @@ class Matcher: | ||||
|     temp: bool = False | ||||
|     expire_time: Optional[datetime] = None | ||||
|     priority: int = 1 | ||||
|     block: bool = False | ||||
|  | ||||
|     _default_state: dict = {} | ||||
|  | ||||
| @@ -45,6 +46,7 @@ class Matcher: | ||||
|             handlers: list = [], | ||||
|             temp: bool = False, | ||||
|             priority: int = 1, | ||||
|             block: bool = False, | ||||
|             *, | ||||
|             default_state: dict = {}, | ||||
|             expire_time: Optional[datetime] = None) -> Type["Matcher"]: | ||||
| @@ -63,6 +65,7 @@ class Matcher: | ||||
|                 "temp": temp, | ||||
|                 "expire_time": expire_time, | ||||
|                 "priority": priority, | ||||
|                 "block": block, | ||||
|                 "_default_state": default_state | ||||
|             }) | ||||
|  | ||||
|   | ||||
| @@ -7,8 +7,10 @@ from datetime import datetime | ||||
| from nonebot.log import logger | ||||
| from nonebot.rule import TrieRule | ||||
| from nonebot.matcher import matchers | ||||
| from nonebot.exception import IgnoredException | ||||
| from nonebot.typing import Bot, Set, Event, PreProcessor | ||||
| from nonebot.typing import Set, Type, Union, NoReturn | ||||
| from nonebot.typing import Bot, Event, Matcher, PreProcessor | ||||
| from nonebot.exception import IgnoredException, ExpiredException | ||||
| from nonebot.exception import StopPropagation, _ExceptionContainer | ||||
|  | ||||
| _event_preprocessors: Set[PreProcessor] = set() | ||||
|  | ||||
| @@ -18,6 +20,38 @@ def event_preprocessor(func: PreProcessor) -> PreProcessor: | ||||
|     return func | ||||
|  | ||||
|  | ||||
| async def _run_matcher(Matcher: Type[Matcher], bot: Bot, event: Event, | ||||
|                        state: dict) -> Union[None, NoReturn]: | ||||
|     if datetime.now() > Matcher.expire_time: | ||||
|         raise _ExceptionContainer([ExpiredException]) | ||||
|  | ||||
|     try: | ||||
|         if not await Matcher.check_perm( | ||||
|                 bot, event) or not await Matcher.check_rule(bot, event, state): | ||||
|             return | ||||
|     except Exception as e: | ||||
|         logger.error(f"Rule check failed for matcher {Matcher}. Ignored.") | ||||
|         logger.exception(e) | ||||
|         return | ||||
|  | ||||
|     matcher = Matcher() | ||||
|     # TODO: BeforeMatcherRun | ||||
|     try: | ||||
|         logger.debug(f"Running matcher {matcher}") | ||||
|         await matcher.run(bot, event, state) | ||||
|     except Exception as e: | ||||
|         logger.error(f"Running matcher {matcher} failed.") | ||||
|         logger.exception(e) | ||||
|  | ||||
|     exceptions = [] | ||||
|     if Matcher.temp: | ||||
|         exceptions.append(ExpiredException) | ||||
|     if Matcher.block: | ||||
|         exceptions.append(StopPropagation) | ||||
|     if exceptions: | ||||
|         raise _ExceptionContainer(exceptions) | ||||
|  | ||||
|  | ||||
| async def handle_event(bot: Bot, event: Event): | ||||
|     coros = [] | ||||
|     state = {} | ||||
| @@ -33,37 +67,24 @@ async def handle_event(bot: Bot, event: Event): | ||||
|     # Trie Match | ||||
|     _, _ = TrieRule.get_value(bot, event, state) | ||||
|  | ||||
|     break_flag = False | ||||
|     for priority in sorted(matchers.keys()): | ||||
|         index = 0 | ||||
|         while index <= len(matchers[priority]): | ||||
|             Matcher = matchers[priority][index] | ||||
|         if break_flag: | ||||
|             break | ||||
|  | ||||
|             # Delete expired Matcher | ||||
|             if datetime.now() > Matcher.expire_time: | ||||
|                 del matchers[priority][index] | ||||
|                 continue | ||||
|         pending_tasks = [ | ||||
|             _run_matcher(matcher, bot, event, state.copy()) | ||||
|             for matcher in matchers[priority] | ||||
|         ] | ||||
|  | ||||
|             # Check rule | ||||
|             try: | ||||
|                 if not await Matcher.check_perm( | ||||
|                         bot, event) or not await Matcher.check_rule( | ||||
|                             bot, event, state): | ||||
|                     index += 1 | ||||
|                     continue | ||||
|             except Exception as e: | ||||
|                 logger.error( | ||||
|                     f"Rule check failed for matcher {Matcher}. Ignored.") | ||||
|                 logger.exception(e) | ||||
|                 continue | ||||
|         results = await asyncio.gather(*pending_tasks, return_exceptions=True) | ||||
|  | ||||
|             matcher = Matcher() | ||||
|             # TODO: BeforeMatcherRun | ||||
|             if Matcher.temp: | ||||
|                 del matchers[priority][index] | ||||
|  | ||||
|             try: | ||||
|                 await matcher.run(bot, event, state) | ||||
|             except Exception as e: | ||||
|                 logger.error(f"Running matcher {matcher} failed.") | ||||
|                 logger.exception(e) | ||||
|             return | ||||
|         i = 0 | ||||
|         for index, result in enumerate(results): | ||||
|             if isinstance(result, _ExceptionContainer): | ||||
|                 e_list = result.exceptions | ||||
|                 if StopPropagation in e_list: | ||||
|                     break_flag = True | ||||
|                 if ExpiredException in e_list: | ||||
|                     del matchers[priority][index - i] | ||||
|                     i += 1 | ||||
|   | ||||
| @@ -33,12 +33,14 @@ def on(rule: Union[Rule, RuleChecker] = Rule(), | ||||
|        handlers=[], | ||||
|        temp=False, | ||||
|        priority: int = 1, | ||||
|        block: bool = False, | ||||
|        state={}) -> Type[Matcher]: | ||||
|     matcher = Matcher.new("", | ||||
|                           Rule() & rule, | ||||
|                           permission, | ||||
|                           temp=temp, | ||||
|                           priority=priority, | ||||
|                           block=block, | ||||
|                           handlers=handlers, | ||||
|                           default_state=state) | ||||
|     _tmp_matchers.add(matcher) | ||||
| @@ -50,12 +52,14 @@ def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(), | ||||
|                  handlers=[], | ||||
|                  temp=False, | ||||
|                  priority: int = 1, | ||||
|                  block: bool = False, | ||||
|                  state={}) -> Type[Matcher]: | ||||
|     matcher = Matcher.new("meta_event", | ||||
|                           Rule() & rule, | ||||
|                           Permission(), | ||||
|                           temp=temp, | ||||
|                           priority=priority, | ||||
|                           block=block, | ||||
|                           handlers=handlers, | ||||
|                           default_state=state) | ||||
|     _tmp_matchers.add(matcher) | ||||
| @@ -68,12 +72,14 @@ def on_message(rule: Union[Rule, RuleChecker] = Rule(), | ||||
|                handlers=[], | ||||
|                temp=False, | ||||
|                priority: int = 1, | ||||
|                block: bool = True, | ||||
|                state={}) -> Type[Matcher]: | ||||
|     matcher = Matcher.new("message", | ||||
|                           Rule() & rule, | ||||
|                           permission, | ||||
|                           temp=temp, | ||||
|                           priority=priority, | ||||
|                           block=block, | ||||
|                           handlers=handlers, | ||||
|                           default_state=state) | ||||
|     _tmp_matchers.add(matcher) | ||||
| @@ -85,12 +91,14 @@ def on_notice(rule: Union[Rule, RuleChecker] = Rule(), | ||||
|               handlers=[], | ||||
|               temp=False, | ||||
|               priority: int = 1, | ||||
|               block: bool = False, | ||||
|               state={}) -> Type[Matcher]: | ||||
|     matcher = Matcher.new("notice", | ||||
|                           Rule() & rule, | ||||
|                           Permission(), | ||||
|                           temp=temp, | ||||
|                           priority=priority, | ||||
|                           block=block, | ||||
|                           handlers=handlers, | ||||
|                           default_state=state) | ||||
|     _tmp_matchers.add(matcher) | ||||
| @@ -102,12 +110,14 @@ def on_request(rule: Union[Rule, RuleChecker] = Rule(), | ||||
|                handlers=[], | ||||
|                temp=False, | ||||
|                priority: int = 1, | ||||
|                block: bool = False, | ||||
|                state={}) -> Type[Matcher]: | ||||
|     matcher = Matcher.new("request", | ||||
|                           Rule() & rule, | ||||
|                           Permission(), | ||||
|                           temp=temp, | ||||
|                           priority=priority, | ||||
|                           block=block, | ||||
|                           handlers=handlers, | ||||
|                           default_state=state) | ||||
|     _tmp_matchers.add(matcher) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user