From abcdbc4de9acd1679df5a5b6dc5726fcfa591f9d Mon Sep 17 00:00:00 2001 From: Mix Date: Sun, 7 Feb 2021 02:21:31 +0800 Subject: [PATCH 1/7] :boom: :bug: add support for non-plaintext start message --- nonebot/plugin.py | 21 ++++++++++++--------- nonebot/rule.py | 7 ++++--- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/nonebot/plugin.py b/nonebot/plugin.py index 3270fd12..7bb8716f 100644 --- a/nonebot/plugin.py +++ b/nonebot/plugin.py @@ -13,7 +13,7 @@ from types import ModuleType from dataclasses import dataclass from importlib._bootstrap import _load from contextvars import Context, ContextVar, copy_context -from typing import Any, Set, List, Dict, Type, Tuple, Union, Optional, TYPE_CHECKING +from typing import Any, Set, List, Dict, Type, Tuple, Union, Optional, TYPE_CHECKING, Iterable from nonebot.log import logger from nonebot.matcher import Matcher @@ -22,7 +22,7 @@ from nonebot.typing import T_State, T_StateFactory, T_Handler, T_RuleChecker from nonebot.rule import Rule, startswith, endswith, keyword, command, shell_command, ArgumentParser, regex if TYPE_CHECKING: - from nonebot.adapters import Bot, Event + from nonebot.adapters import Bot, Event, MessageSegment plugins: Dict[str, "Plugin"] = {} """ @@ -421,13 +421,16 @@ def on_command(cmd: Union[str, Tuple[str, ...]], """ async def _strip_cmd(bot: "Bot", event: "Event", state: T_State): - message = event.get_message() - segment = message.pop(0) - new_message = message.__class__( - str(segment) - [len(state["_prefix"]["raw_command"]):].lstrip()) # type: ignore - for new_segment in reversed(new_message): - message.insert(0, new_segment) + message: Iterable[MessageSegment] = event.get_message() + text_processed = False + for index, segment in enumerate(message): + segment: MessageSegment = message.pop(index) + if segment.is_text() and not text_processed: + segment, *_ = message.__class__( + str(segment)[len(state["_prefix"]["raw_command"]):].lstrip( + )) # type: ignore + text_processed = True + message.insert(index, segment) handlers = kwargs.pop("handlers", []) handlers.insert(0, _strip_cmd) diff --git a/nonebot/rule.py b/nonebot/rule.py index d9f75a24..002622a8 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -25,7 +25,7 @@ from nonebot.exception import ParserExit from nonebot.typing import T_State, T_RuleChecker if TYPE_CHECKING: - from nonebot.adapters import Bot, Event + from nonebot.adapters import Bot, Event, MessageSegment class Rule: @@ -137,8 +137,9 @@ class TrieRule: prefix = None suffix = None message = event.get_message() - message_seg = message[0] - if message_seg.is_text(): + message_seg: Optional[MessageSegment] = next( + filter(lambda x: x.is_text(), message), None) + if message_seg is not None: prefix = cls.prefix.longest_prefix(str(message_seg).lstrip()) message_seg_r = message[-1] if message_seg_r.is_text(): From bdd9f5ae30e9c3b825d024dc04c40c27193275e4 Mon Sep 17 00:00:00 2001 From: Mix Date: Sun, 7 Feb 2021 02:27:09 +0800 Subject: [PATCH 2/7] :bug: fix bad type hinting --- nonebot/plugin.py | 4 ++-- nonebot/rule.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nonebot/plugin.py b/nonebot/plugin.py index 7bb8716f..dd992f57 100644 --- a/nonebot/plugin.py +++ b/nonebot/plugin.py @@ -421,10 +421,10 @@ def on_command(cmd: Union[str, Tuple[str, ...]], """ async def _strip_cmd(bot: "Bot", event: "Event", state: T_State): - message: Iterable[MessageSegment] = event.get_message() + message = event.get_message() text_processed = False for index, segment in enumerate(message): - segment: MessageSegment = message.pop(index) + segment: "MessageSegment" = message.pop(index) if segment.is_text() and not text_processed: segment, *_ = message.__class__( str(segment)[len(state["_prefix"]["raw_command"]):].lstrip( diff --git a/nonebot/rule.py b/nonebot/rule.py index 002622a8..45146ae8 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -137,7 +137,7 @@ class TrieRule: prefix = None suffix = None message = event.get_message() - message_seg: Optional[MessageSegment] = next( + message_seg: Optional["MessageSegment"] = next( filter(lambda x: x.is_text(), message), None) if message_seg is not None: prefix = cls.prefix.longest_prefix(str(message_seg).lstrip()) From b59ff03abfca7204bb4bf99ea76c334310bb3ac1 Mon Sep 17 00:00:00 2001 From: Mix Date: Sun, 7 Feb 2021 10:14:19 +0800 Subject: [PATCH 3/7] :rewind: revert changes to change implement method This reverts commit bf7b2a8cbeafd55c2cf576545b63ff17d50b8866. --- nonebot/matcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 4c22be8f..0fda9f3d 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -418,7 +418,7 @@ class Matcher(metaclass=MatcherMeta): """ bot = current_bot.get() event = current_event.get() - return await bot.send(event=event, message=message, **kwargs) + await bot.send(event=event, message=message, **kwargs) @classmethod async def finish(cls, From 49010bf5b74f2dee5a08a4d728ac3acdb7a460b9 Mon Sep 17 00:00:00 2001 From: Mix Date: Sun, 7 Feb 2021 11:52:50 +0800 Subject: [PATCH 4/7] :alembic: trying to change mirai adapter message processing behavior --- nonebot/adapters/mirai/bot.py | 25 +++---- nonebot/adapters/mirai/event/__init__.py | 2 +- nonebot/adapters/mirai/event/message.py | 11 ++- nonebot/adapters/mirai/message.py | 8 +++ nonebot/adapters/mirai/utils.py | 88 ++++++++++++++---------- 5 files changed, 77 insertions(+), 57 deletions(-) diff --git a/nonebot/adapters/mirai/bot.py b/nonebot/adapters/mirai/bot.py index 0a1262c7..7bd40968 100644 --- a/nonebot/adapters/mirai/bot.py +++ b/nonebot/adapters/mirai/bot.py @@ -1,8 +1,7 @@ from datetime import datetime, timedelta -from functools import wraps from io import BytesIO from ipaddress import IPv4Address -from typing import (Any, Dict, List, NoReturn, Optional, Tuple, Union) +from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union import httpx @@ -10,15 +9,12 @@ from nonebot.adapters import Bot as BaseBot from nonebot.config import Config from nonebot.drivers import Driver, WebSocket from nonebot.exception import ApiNotAvailable, RequestDenied -from nonebot.log import logger -from nonebot.message import handle_event from nonebot.typing import overrides -from nonebot.utils import escape_tag from .config import Config as MiraiConfig from .event import Event, FriendMessage, GroupMessage, TempMessage from .message import MessageChain, MessageSegment -from .utils import catch_network_error, argument_validation, check_tome, Log +from .utils import Log, argument_validation, catch_network_error, process_event class SessionManager: @@ -212,20 +208,15 @@ class Bot(BaseBot): async def handle_message(self, message: dict): Log.debug(f'received message {message}') try: - await handle_event( + await process_event( bot=self, - event=await check_tome( - bot=self, - event=Event.new({ - **message, - 'self_id': self.self_id, - }), - ), + event=Event.new({ + **message, + 'self_id': self.self_id, + }), ) except Exception as e: - logger.opt(colors=True, exception=e).exception( - 'Failed to handle message ' - f'{escape_tag(str(message))}: ') + Log.error(f'Failed to handle message: {message}', e) @overrides(BaseBot) async def call_api(self, api: str, **data) -> NoReturn: diff --git a/nonebot/adapters/mirai/event/__init__.py b/nonebot/adapters/mirai/event/__init__.py index 1cf92096..91f4b127 100644 --- a/nonebot/adapters/mirai/event/__init__.py +++ b/nonebot/adapters/mirai/event/__init__.py @@ -13,7 +13,7 @@ from .request import * __all__ = [ 'Event', 'GroupChatInfo', 'GroupInfo', 'PrivateChatInfo', 'UserPermission', - 'MessageChain', 'MessageEvent', 'GroupMessage', 'FriendMessage', + 'MessageSource', 'MessageEvent', 'GroupMessage', 'FriendMessage', 'TempMessage', 'NoticeEvent', 'MuteEvent', 'BotMuteEvent', 'BotUnmuteEvent', 'MemberMuteEvent', 'MemberUnmuteEvent', 'BotJoinGroupEvent', 'BotLeaveEventActive', 'BotLeaveEventKick', 'MemberJoinEvent', diff --git a/nonebot/adapters/mirai/event/message.py b/nonebot/adapters/mirai/event/message.py index 26d534d4..5dda0857 100644 --- a/nonebot/adapters/mirai/event/message.py +++ b/nonebot/adapters/mirai/event/message.py @@ -1,6 +1,7 @@ -from typing import Any +from datetime import datetime +from typing import Any, Optional -from pydantic import Field +from pydantic import BaseModel, Field from nonebot.typing import overrides @@ -8,9 +9,15 @@ from ..message import MessageChain from .base import Event, GroupChatInfo, PrivateChatInfo +class MessageSource(BaseModel): + id: int + time: datetime + + class MessageEvent(Event): """消息事件基类""" message_chain: MessageChain = Field(alias='messageChain') + source: Optional[MessageSource] = None sender: Any @overrides(Event) diff --git a/nonebot/adapters/mirai/message.py b/nonebot/adapters/mirai/message.py index 26fb198c..2285ceda 100644 --- a/nonebot/adapters/mirai/message.py +++ b/nonebot/adapters/mirai/message.py @@ -306,5 +306,13 @@ class MessageChain(BaseMessage): *map(lambda segment: segment.as_dict(), self.copy()) # type: ignore ] + def extract_first(self, *type: MessageType) -> Optional[MessageSegment]: + if not len(self): + return None + first: MessageSegment = self[0] + if (not type) or (first.type in type): + return self.pop(0) + return None + def __repr__(self) -> str: return f'<{self.__class__.__name__} {[*self.copy()]}>' diff --git a/nonebot/adapters/mirai/utils.py b/nonebot/adapters/mirai/utils.py index db94dfed..bc7aa7dc 100644 --- a/nonebot/adapters/mirai/utils.py +++ b/nonebot/adapters/mirai/utils.py @@ -7,10 +7,11 @@ from pydantic import Extra, ValidationError, validate_arguments import nonebot.exception as exception from nonebot.log import logger +from nonebot.message import handle_event from nonebot.utils import escape_tag, logger_wrapper -from .event import Event, GroupMessage -from .message import MessageSegment, MessageType +from .event import Event, GroupMessage, MessageEvent, MessageSource +from .message import MessageType if TYPE_CHECKING: from .bot import Bot @@ -124,39 +125,52 @@ def argument_validation(function: _AnyCallable) -> _AnyCallable: return wrapper # type: ignore -async def check_tome(bot: "Bot", event: "Event") -> "Event": - if not isinstance(event, GroupMessage): - return event - - def _is_at(event: GroupMessage) -> bool: - for segment in event.message_chain: - segment: MessageSegment - if segment.type != MessageType.AT: - continue - if segment.data['target'] == event.self_id: - return True - return False - - def _is_nick(event: GroupMessage) -> bool: - text = event.get_plaintext() - if not text: - return False - nick_regex = '|'.join( - {i.strip() for i in bot.config.nickname if i.strip()}) - matched = re.search(rf"^({nick_regex})([\s,,]*|$)", text, re.IGNORECASE) - if matched is None: - return False - Log.info(f'User is calling me {matched.group(1)}') - return True - - def _is_reply(event: GroupMessage) -> bool: - for segment in event.message_chain: - segment: MessageSegment - if segment.type != MessageType.QUOTE: - continue - if segment.data['senderId'] == event.self_id: - return True - return False - - event.to_me = any([_is_at(event), _is_reply(event), _is_nick(event)]) +def process_source(bot: "Bot", event: MessageEvent) -> MessageEvent: + source = event.message_chain.extract_first(MessageType.SOURCE) + if source is not None: + event.source = MessageSource.parse_obj(source.data) return event + + +def process_at(bot: "Bot", event: GroupMessage) -> GroupMessage: + at = event.message_chain.extract_first(MessageType.AT) + if at is not None: + if at.data['target'] == event.self_id: + event.to_me = True + else: + event.message_chain.insert(0, at) + return event + + +def process_nick(bot: "Bot", event: GroupMessage) -> GroupMessage: + plain = event.message_chain.extract_first(MessageType.PLAIN) + if plain is not None: + text = str(plain) + nick_regex = '|'.join(filter(lambda x: x, bot.config.nickname)) + matched = re.search(rf"^({nick_regex})([\s,,]*|$)", text, re.IGNORECASE) + if matched is not None: + nickname = matched.group(1) + Log.info(f'User is calling me {nickname}') + plain.data['text'] = text[matched.end():] + event.message_chain.insert(0, plain) + return event + + +def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage: + reply = event.message_chain.extract_first(MessageType.QUOTE) + if reply is not None: + if reply.data['sender_id'] == event.self_id: + event.to_me = True + else: + event.message_chain.insert(0, reply) + return event + + +async def process_event(bot: "Bot", event: Event) -> None: + if isinstance(event, MessageEvent): + event = process_source(bot, event) + if isinstance(event, GroupMessage): + event = process_nick(bot, event) + event = process_reply(bot, event) + event = process_at(bot, event) + await handle_event(bot, event) \ No newline at end of file From 85aba9e36f92a3c8d6f2a1124445f04bf953456a Mon Sep 17 00:00:00 2001 From: Mix Date: Sun, 7 Feb 2021 12:17:21 +0800 Subject: [PATCH 5/7] :bug: fix bug founded during test in mirai adapter --- nonebot/adapters/mirai/bot.py | 5 ++--- nonebot/adapters/mirai/message.py | 8 +++++--- nonebot/adapters/mirai/utils.py | 3 ++- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/nonebot/adapters/mirai/bot.py b/nonebot/adapters/mirai/bot.py index 7bd40968..900e3ec5 100644 --- a/nonebot/adapters/mirai/bot.py +++ b/nonebot/adapters/mirai/bot.py @@ -253,10 +253,9 @@ class Bot(BaseBot): * ``message: Union[MessageChain, MessageSegment, str]``: 要发送的消息 * ``at_sender: bool``: 是否 @ 事件主体 """ - if isinstance(message, MessageSegment): + print(event, message, at_sender) + if not isinstance(message, MessageChain): message = MessageChain(message) - elif isinstance(message, str): - message = MessageChain(MessageSegment.plain(message)) if isinstance(event, FriendMessage): return await self.send_friend_message(target=event.sender.id, message_chain=message) diff --git a/nonebot/adapters/mirai/message.py b/nonebot/adapters/mirai/message.py index 2285ceda..f6af1ab6 100644 --- a/nonebot/adapters/mirai/message.py +++ b/nonebot/adapters/mirai/message.py @@ -273,12 +273,14 @@ class MessageChain(BaseMessage): """ @overrides(BaseMessage) - def __init__(self, message: Union[List[Dict[str, Any]], - Iterable[MessageSegment], MessageSegment], - **kwargs): + def __init__(self, message: Union[List[Dict[str, + Any]], Iterable[MessageSegment], + MessageSegment, str], **kwargs): super().__init__(**kwargs) if isinstance(message, MessageSegment): self.append(message) + elif isinstance(message, str): + self.append(MessageSegment.plain(text=message)) elif isinstance(message, Iterable): self.extend(self._construct(message)) else: diff --git a/nonebot/adapters/mirai/utils.py b/nonebot/adapters/mirai/utils.py index bc7aa7dc..c9c9c143 100644 --- a/nonebot/adapters/mirai/utils.py +++ b/nonebot/adapters/mirai/utils.py @@ -149,6 +149,7 @@ def process_nick(bot: "Bot", event: GroupMessage) -> GroupMessage: nick_regex = '|'.join(filter(lambda x: x, bot.config.nickname)) matched = re.search(rf"^({nick_regex})([\s,,]*|$)", text, re.IGNORECASE) if matched is not None: + event.to_me = True nickname = matched.group(1) Log.info(f'User is calling me {nickname}') plain.data['text'] = text[matched.end():] @@ -159,7 +160,7 @@ def process_nick(bot: "Bot", event: GroupMessage) -> GroupMessage: def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage: reply = event.message_chain.extract_first(MessageType.QUOTE) if reply is not None: - if reply.data['sender_id'] == event.self_id: + if reply.data['senderId'] == event.self_id: event.to_me = True else: event.message_chain.insert(0, reply) From 24349953e3ff1cc9d3a1478d50b2564e3bd8c4c6 Mon Sep 17 00:00:00 2001 From: Mix Date: Sun, 7 Feb 2021 12:17:34 +0800 Subject: [PATCH 6/7] :white_check_mark: update test case --- tests/test_plugins/test_mirai.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/test_plugins/test_mirai.py b/tests/test_plugins/test_mirai.py index a5da93ae..c518290a 100644 --- a/tests/test_plugins/test_mirai.py +++ b/tests/test_plugins/test_mirai.py @@ -1,13 +1,20 @@ -from nonebot.plugin import on_message +from nonebot.plugin import on_keyword, on_command +from nonebot.rule import to_me from nonebot.adapters.mirai import Bot, MessageEvent -message_test = on_message() +message_test = on_keyword({'reply'}, rule=to_me()) @message_test.handle() async def _message(bot: Bot, event: MessageEvent): text = event.get_plaintext() - if not text: - return - reversed_text = ''.join(reversed(text)) - await bot.send(event, reversed_text, at_sender=True) + await bot.send(event, text, at_sender=True) + + +command_test = on_command('miecho') + + +@command_test.handle() +async def _echo(bot: Bot, event: MessageEvent): + text = event.get_plaintext() + await bot.send(event, text, at_sender=True) \ No newline at end of file From 382a9b6e125f6c126d96e1291e4767852b81e316 Mon Sep 17 00:00:00 2001 From: Mix Date: Sun, 7 Feb 2021 12:40:31 +0800 Subject: [PATCH 7/7] :loud_sound: improve message logging --- nonebot/adapters/mirai/bot.py | 1 - nonebot/adapters/mirai/message.py | 5 +++-- nonebot/adapters/mirai/utils.py | 16 ++++++++-------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/nonebot/adapters/mirai/bot.py b/nonebot/adapters/mirai/bot.py index 900e3ec5..4f5cb196 100644 --- a/nonebot/adapters/mirai/bot.py +++ b/nonebot/adapters/mirai/bot.py @@ -253,7 +253,6 @@ class Bot(BaseBot): * ``message: Union[MessageChain, MessageSegment, str]``: 要发送的消息 * ``at_sender: bool``: 是否 @ 事件主体 """ - print(event, message, at_sender) if not isinstance(message, MessageChain): message = MessageChain(message) if isinstance(event, FriendMessage): diff --git a/nonebot/adapters/mirai/message.py b/nonebot/adapters/mirai/message.py index f6af1ab6..d2a3ec39 100644 --- a/nonebot/adapters/mirai/message.py +++ b/nonebot/adapters/mirai/message.py @@ -44,8 +44,9 @@ class MessageSegment(BaseMessageSegment): @overrides(BaseMessageSegment) def __str__(self) -> str: - if self.is_text(): - return self.data.get('text', '') + return self.data['text'] if self.is_text() else repr(self) + + def __repr__(self) -> str: return '[mirai:%s]' % ','.join([ self.type.value, *map( diff --git a/nonebot/adapters/mirai/utils.py b/nonebot/adapters/mirai/utils.py index c9c9c143..385bd3c6 100644 --- a/nonebot/adapters/mirai/utils.py +++ b/nonebot/adapters/mirai/utils.py @@ -23,27 +23,26 @@ _AnyCallable = TypeVar("_AnyCallable", bound=Callable) class Log: @staticmethod - def _log(level: str, message: Any, exception: Optional[Exception] = None): + def log(level: str, message: str, exception: Optional[Exception] = None): logger = logger_wrapper('MIRAI') - logger(level=level, - message=escape_tag(str(message)), - exception=exception) + message = '' + escape_tag(message) + '' + logger(level=level.upper(), message=message, exception=exception) @classmethod def info(cls, message: Any): - cls._log('INFO', escape_tag(str(message))) + cls.log('INFO', str(message)) @classmethod def debug(cls, message: Any): - cls._log('DEBUG', escape_tag(str(message))) + cls.log('DEBUG', str(message)) @classmethod def warn(cls, message: Any): - cls._log('WARNING', escape_tag(str(message))) + cls.log('WARNING', str(message)) @classmethod def error(cls, message: Any, exception: Optional[Exception] = None): - cls._log('ERROR', escape_tag(str(message)), exception=exception) + cls.log('ERROR', str(message), exception=exception) class ActionFailed(exception.ActionFailed): @@ -169,6 +168,7 @@ def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage: async def process_event(bot: "Bot", event: Event) -> None: if isinstance(event, MessageEvent): + Log.debug(event.message_chain) event = process_source(bot, event) if isinstance(event, GroupMessage): event = process_nick(bot, event)