diff --git a/nonebot/adapters/mirai/bot.py b/nonebot/adapters/mirai/bot.py index 0a1262c7..4f5cb196 100644 --- a/nonebot/adapters/mirai/bot.py +++ b/nonebot/adapters/mirai/bot.py @@ -1,8 +1,7 @@ from datetime import datetime, timedelta -from functools import wraps from io import BytesIO from ipaddress import IPv4Address -from typing import (Any, Dict, List, NoReturn, Optional, Tuple, Union) +from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union import httpx @@ -10,15 +9,12 @@ from nonebot.adapters import Bot as BaseBot from nonebot.config import Config from nonebot.drivers import Driver, WebSocket from nonebot.exception import ApiNotAvailable, RequestDenied -from nonebot.log import logger -from nonebot.message import handle_event from nonebot.typing import overrides -from nonebot.utils import escape_tag from .config import Config as MiraiConfig from .event import Event, FriendMessage, GroupMessage, TempMessage from .message import MessageChain, MessageSegment -from .utils import catch_network_error, argument_validation, check_tome, Log +from .utils import Log, argument_validation, catch_network_error, process_event class SessionManager: @@ -212,20 +208,15 @@ class Bot(BaseBot): async def handle_message(self, message: dict): Log.debug(f'received message {message}') try: - await handle_event( + await process_event( bot=self, - event=await check_tome( - bot=self, - event=Event.new({ - **message, - 'self_id': self.self_id, - }), - ), + event=Event.new({ + **message, + 'self_id': self.self_id, + }), ) except Exception as e: - logger.opt(colors=True, exception=e).exception( - 'Failed to handle message ' - f'{escape_tag(str(message))}: ') + Log.error(f'Failed to handle message: {message}', e) @overrides(BaseBot) async def call_api(self, api: str, **data) -> NoReturn: @@ -262,10 +253,8 @@ class Bot(BaseBot): * ``message: Union[MessageChain, MessageSegment, str]``: 要发送的消息 * ``at_sender: bool``: 是否 @ 事件主体 """ - if isinstance(message, MessageSegment): + if not isinstance(message, MessageChain): message = MessageChain(message) - elif isinstance(message, str): - message = MessageChain(MessageSegment.plain(message)) if isinstance(event, FriendMessage): return await self.send_friend_message(target=event.sender.id, message_chain=message) diff --git a/nonebot/adapters/mirai/event/__init__.py b/nonebot/adapters/mirai/event/__init__.py index 1cf92096..91f4b127 100644 --- a/nonebot/adapters/mirai/event/__init__.py +++ b/nonebot/adapters/mirai/event/__init__.py @@ -13,7 +13,7 @@ from .request import * __all__ = [ 'Event', 'GroupChatInfo', 'GroupInfo', 'PrivateChatInfo', 'UserPermission', - 'MessageChain', 'MessageEvent', 'GroupMessage', 'FriendMessage', + 'MessageSource', 'MessageEvent', 'GroupMessage', 'FriendMessage', 'TempMessage', 'NoticeEvent', 'MuteEvent', 'BotMuteEvent', 'BotUnmuteEvent', 'MemberMuteEvent', 'MemberUnmuteEvent', 'BotJoinGroupEvent', 'BotLeaveEventActive', 'BotLeaveEventKick', 'MemberJoinEvent', diff --git a/nonebot/adapters/mirai/event/message.py b/nonebot/adapters/mirai/event/message.py index 26d534d4..5dda0857 100644 --- a/nonebot/adapters/mirai/event/message.py +++ b/nonebot/adapters/mirai/event/message.py @@ -1,6 +1,7 @@ -from typing import Any +from datetime import datetime +from typing import Any, Optional -from pydantic import Field +from pydantic import BaseModel, Field from nonebot.typing import overrides @@ -8,9 +9,15 @@ from ..message import MessageChain from .base import Event, GroupChatInfo, PrivateChatInfo +class MessageSource(BaseModel): + id: int + time: datetime + + class MessageEvent(Event): """消息事件基类""" message_chain: MessageChain = Field(alias='messageChain') + source: Optional[MessageSource] = None sender: Any @overrides(Event) diff --git a/nonebot/adapters/mirai/message.py b/nonebot/adapters/mirai/message.py index 26fb198c..d2a3ec39 100644 --- a/nonebot/adapters/mirai/message.py +++ b/nonebot/adapters/mirai/message.py @@ -44,8 +44,9 @@ class MessageSegment(BaseMessageSegment): @overrides(BaseMessageSegment) def __str__(self) -> str: - if self.is_text(): - return self.data.get('text', '') + return self.data['text'] if self.is_text() else repr(self) + + def __repr__(self) -> str: return '[mirai:%s]' % ','.join([ self.type.value, *map( @@ -273,12 +274,14 @@ class MessageChain(BaseMessage): """ @overrides(BaseMessage) - def __init__(self, message: Union[List[Dict[str, Any]], - Iterable[MessageSegment], MessageSegment], - **kwargs): + def __init__(self, message: Union[List[Dict[str, + Any]], Iterable[MessageSegment], + MessageSegment, str], **kwargs): super().__init__(**kwargs) if isinstance(message, MessageSegment): self.append(message) + elif isinstance(message, str): + self.append(MessageSegment.plain(text=message)) elif isinstance(message, Iterable): self.extend(self._construct(message)) else: @@ -306,5 +309,13 @@ class MessageChain(BaseMessage): *map(lambda segment: segment.as_dict(), self.copy()) # type: ignore ] + def extract_first(self, *type: MessageType) -> Optional[MessageSegment]: + if not len(self): + return None + first: MessageSegment = self[0] + if (not type) or (first.type in type): + return self.pop(0) + return None + def __repr__(self) -> str: return f'<{self.__class__.__name__} {[*self.copy()]}>' diff --git a/nonebot/adapters/mirai/utils.py b/nonebot/adapters/mirai/utils.py index db94dfed..385bd3c6 100644 --- a/nonebot/adapters/mirai/utils.py +++ b/nonebot/adapters/mirai/utils.py @@ -7,10 +7,11 @@ from pydantic import Extra, ValidationError, validate_arguments import nonebot.exception as exception from nonebot.log import logger +from nonebot.message import handle_event from nonebot.utils import escape_tag, logger_wrapper -from .event import Event, GroupMessage -from .message import MessageSegment, MessageType +from .event import Event, GroupMessage, MessageEvent, MessageSource +from .message import MessageType if TYPE_CHECKING: from .bot import Bot @@ -22,27 +23,26 @@ _AnyCallable = TypeVar("_AnyCallable", bound=Callable) class Log: @staticmethod - def _log(level: str, message: Any, exception: Optional[Exception] = None): + def log(level: str, message: str, exception: Optional[Exception] = None): logger = logger_wrapper('MIRAI') - logger(level=level, - message=escape_tag(str(message)), - exception=exception) + message = '' + escape_tag(message) + '' + logger(level=level.upper(), message=message, exception=exception) @classmethod def info(cls, message: Any): - cls._log('INFO', escape_tag(str(message))) + cls.log('INFO', str(message)) @classmethod def debug(cls, message: Any): - cls._log('DEBUG', escape_tag(str(message))) + cls.log('DEBUG', str(message)) @classmethod def warn(cls, message: Any): - cls._log('WARNING', escape_tag(str(message))) + cls.log('WARNING', str(message)) @classmethod def error(cls, message: Any, exception: Optional[Exception] = None): - cls._log('ERROR', escape_tag(str(message)), exception=exception) + cls.log('ERROR', str(message), exception=exception) class ActionFailed(exception.ActionFailed): @@ -124,39 +124,54 @@ def argument_validation(function: _AnyCallable) -> _AnyCallable: return wrapper # type: ignore -async def check_tome(bot: "Bot", event: "Event") -> "Event": - if not isinstance(event, GroupMessage): - return event - - def _is_at(event: GroupMessage) -> bool: - for segment in event.message_chain: - segment: MessageSegment - if segment.type != MessageType.AT: - continue - if segment.data['target'] == event.self_id: - return True - return False - - def _is_nick(event: GroupMessage) -> bool: - text = event.get_plaintext() - if not text: - return False - nick_regex = '|'.join( - {i.strip() for i in bot.config.nickname if i.strip()}) - matched = re.search(rf"^({nick_regex})([\s,,]*|$)", text, re.IGNORECASE) - if matched is None: - return False - Log.info(f'User is calling me {matched.group(1)}') - return True - - def _is_reply(event: GroupMessage) -> bool: - for segment in event.message_chain: - segment: MessageSegment - if segment.type != MessageType.QUOTE: - continue - if segment.data['senderId'] == event.self_id: - return True - return False - - event.to_me = any([_is_at(event), _is_reply(event), _is_nick(event)]) +def process_source(bot: "Bot", event: MessageEvent) -> MessageEvent: + source = event.message_chain.extract_first(MessageType.SOURCE) + if source is not None: + event.source = MessageSource.parse_obj(source.data) return event + + +def process_at(bot: "Bot", event: GroupMessage) -> GroupMessage: + at = event.message_chain.extract_first(MessageType.AT) + if at is not None: + if at.data['target'] == event.self_id: + event.to_me = True + else: + event.message_chain.insert(0, at) + return event + + +def process_nick(bot: "Bot", event: GroupMessage) -> GroupMessage: + plain = event.message_chain.extract_first(MessageType.PLAIN) + if plain is not None: + text = str(plain) + nick_regex = '|'.join(filter(lambda x: x, bot.config.nickname)) + matched = re.search(rf"^({nick_regex})([\s,,]*|$)", text, re.IGNORECASE) + if matched is not None: + event.to_me = True + nickname = matched.group(1) + Log.info(f'User is calling me {nickname}') + plain.data['text'] = text[matched.end():] + event.message_chain.insert(0, plain) + return event + + +def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage: + reply = event.message_chain.extract_first(MessageType.QUOTE) + if reply is not None: + if reply.data['senderId'] == event.self_id: + event.to_me = True + else: + event.message_chain.insert(0, reply) + return event + + +async def process_event(bot: "Bot", event: Event) -> None: + if isinstance(event, MessageEvent): + Log.debug(event.message_chain) + event = process_source(bot, event) + if isinstance(event, GroupMessage): + event = process_nick(bot, event) + event = process_reply(bot, event) + event = process_at(bot, event) + await handle_event(bot, event) \ No newline at end of file diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 4c22be8f..0fda9f3d 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -418,7 +418,7 @@ class Matcher(metaclass=MatcherMeta): """ bot = current_bot.get() event = current_event.get() - return await bot.send(event=event, message=message, **kwargs) + await bot.send(event=event, message=message, **kwargs) @classmethod async def finish(cls, diff --git a/nonebot/plugin.py b/nonebot/plugin.py index 3270fd12..dd992f57 100644 --- a/nonebot/plugin.py +++ b/nonebot/plugin.py @@ -13,7 +13,7 @@ from types import ModuleType from dataclasses import dataclass from importlib._bootstrap import _load from contextvars import Context, ContextVar, copy_context -from typing import Any, Set, List, Dict, Type, Tuple, Union, Optional, TYPE_CHECKING +from typing import Any, Set, List, Dict, Type, Tuple, Union, Optional, TYPE_CHECKING, Iterable from nonebot.log import logger from nonebot.matcher import Matcher @@ -22,7 +22,7 @@ from nonebot.typing import T_State, T_StateFactory, T_Handler, T_RuleChecker from nonebot.rule import Rule, startswith, endswith, keyword, command, shell_command, ArgumentParser, regex if TYPE_CHECKING: - from nonebot.adapters import Bot, Event + from nonebot.adapters import Bot, Event, MessageSegment plugins: Dict[str, "Plugin"] = {} """ @@ -422,12 +422,15 @@ def on_command(cmd: Union[str, Tuple[str, ...]], async def _strip_cmd(bot: "Bot", event: "Event", state: T_State): message = event.get_message() - segment = message.pop(0) - new_message = message.__class__( - str(segment) - [len(state["_prefix"]["raw_command"]):].lstrip()) # type: ignore - for new_segment in reversed(new_message): - message.insert(0, new_segment) + text_processed = False + for index, segment in enumerate(message): + segment: "MessageSegment" = message.pop(index) + if segment.is_text() and not text_processed: + segment, *_ = message.__class__( + str(segment)[len(state["_prefix"]["raw_command"]):].lstrip( + )) # type: ignore + text_processed = True + message.insert(index, segment) handlers = kwargs.pop("handlers", []) handlers.insert(0, _strip_cmd) diff --git a/nonebot/rule.py b/nonebot/rule.py index d9f75a24..45146ae8 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -25,7 +25,7 @@ from nonebot.exception import ParserExit from nonebot.typing import T_State, T_RuleChecker if TYPE_CHECKING: - from nonebot.adapters import Bot, Event + from nonebot.adapters import Bot, Event, MessageSegment class Rule: @@ -137,8 +137,9 @@ class TrieRule: prefix = None suffix = None message = event.get_message() - message_seg = message[0] - if message_seg.is_text(): + message_seg: Optional["MessageSegment"] = next( + filter(lambda x: x.is_text(), message), None) + if message_seg is not None: prefix = cls.prefix.longest_prefix(str(message_seg).lstrip()) message_seg_r = message[-1] if message_seg_r.is_text(): diff --git a/tests/test_plugins/test_mirai.py b/tests/test_plugins/test_mirai.py index a5da93ae..c518290a 100644 --- a/tests/test_plugins/test_mirai.py +++ b/tests/test_plugins/test_mirai.py @@ -1,13 +1,20 @@ -from nonebot.plugin import on_message +from nonebot.plugin import on_keyword, on_command +from nonebot.rule import to_me from nonebot.adapters.mirai import Bot, MessageEvent -message_test = on_message() +message_test = on_keyword({'reply'}, rule=to_me()) @message_test.handle() async def _message(bot: Bot, event: MessageEvent): text = event.get_plaintext() - if not text: - return - reversed_text = ''.join(reversed(text)) - await bot.send(event, reversed_text, at_sender=True) + await bot.send(event, text, at_sender=True) + + +command_test = on_command('miecho') + + +@command_test.handle() +async def _echo(bot: Bot, event: MessageEvent): + text = event.get_plaintext() + await bot.send(event, text, at_sender=True) \ No newline at end of file