diff --git a/nonebot/adapters/mirai/bot.py b/nonebot/adapters/mirai/bot.py
index 0a1262c7..4f5cb196 100644
--- a/nonebot/adapters/mirai/bot.py
+++ b/nonebot/adapters/mirai/bot.py
@@ -1,8 +1,7 @@
from datetime import datetime, timedelta
-from functools import wraps
from io import BytesIO
from ipaddress import IPv4Address
-from typing import (Any, Dict, List, NoReturn, Optional, Tuple, Union)
+from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
import httpx
@@ -10,15 +9,12 @@ from nonebot.adapters import Bot as BaseBot
from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
from nonebot.exception import ApiNotAvailable, RequestDenied
-from nonebot.log import logger
-from nonebot.message import handle_event
from nonebot.typing import overrides
-from nonebot.utils import escape_tag
from .config import Config as MiraiConfig
from .event import Event, FriendMessage, GroupMessage, TempMessage
from .message import MessageChain, MessageSegment
-from .utils import catch_network_error, argument_validation, check_tome, Log
+from .utils import Log, argument_validation, catch_network_error, process_event
class SessionManager:
@@ -212,20 +208,15 @@ class Bot(BaseBot):
async def handle_message(self, message: dict):
Log.debug(f'received message {message}')
try:
- await handle_event(
+ await process_event(
bot=self,
- event=await check_tome(
- bot=self,
- event=Event.new({
- **message,
- 'self_id': self.self_id,
- }),
- ),
+ event=Event.new({
+ **message,
+ 'self_id': self.self_id,
+ }),
)
except Exception as e:
- logger.opt(colors=True, exception=e).exception(
- 'Failed to handle message '
- f'{escape_tag(str(message))}: ')
+ Log.error(f'Failed to handle message: {message}', e)
@overrides(BaseBot)
async def call_api(self, api: str, **data) -> NoReturn:
@@ -262,10 +253,8 @@ class Bot(BaseBot):
* ``message: Union[MessageChain, MessageSegment, str]``: 要发送的消息
* ``at_sender: bool``: 是否 @ 事件主体
"""
- if isinstance(message, MessageSegment):
+ if not isinstance(message, MessageChain):
message = MessageChain(message)
- elif isinstance(message, str):
- message = MessageChain(MessageSegment.plain(message))
if isinstance(event, FriendMessage):
return await self.send_friend_message(target=event.sender.id,
message_chain=message)
diff --git a/nonebot/adapters/mirai/event/__init__.py b/nonebot/adapters/mirai/event/__init__.py
index 1cf92096..91f4b127 100644
--- a/nonebot/adapters/mirai/event/__init__.py
+++ b/nonebot/adapters/mirai/event/__init__.py
@@ -13,7 +13,7 @@ from .request import *
__all__ = [
'Event', 'GroupChatInfo', 'GroupInfo', 'PrivateChatInfo', 'UserPermission',
- 'MessageChain', 'MessageEvent', 'GroupMessage', 'FriendMessage',
+ 'MessageSource', 'MessageEvent', 'GroupMessage', 'FriendMessage',
'TempMessage', 'NoticeEvent', 'MuteEvent', 'BotMuteEvent', 'BotUnmuteEvent',
'MemberMuteEvent', 'MemberUnmuteEvent', 'BotJoinGroupEvent',
'BotLeaveEventActive', 'BotLeaveEventKick', 'MemberJoinEvent',
diff --git a/nonebot/adapters/mirai/event/message.py b/nonebot/adapters/mirai/event/message.py
index 26d534d4..5dda0857 100644
--- a/nonebot/adapters/mirai/event/message.py
+++ b/nonebot/adapters/mirai/event/message.py
@@ -1,6 +1,7 @@
-from typing import Any
+from datetime import datetime
+from typing import Any, Optional
-from pydantic import Field
+from pydantic import BaseModel, Field
from nonebot.typing import overrides
@@ -8,9 +9,15 @@ from ..message import MessageChain
from .base import Event, GroupChatInfo, PrivateChatInfo
+class MessageSource(BaseModel):
+ id: int
+ time: datetime
+
+
class MessageEvent(Event):
"""消息事件基类"""
message_chain: MessageChain = Field(alias='messageChain')
+ source: Optional[MessageSource] = None
sender: Any
@overrides(Event)
diff --git a/nonebot/adapters/mirai/message.py b/nonebot/adapters/mirai/message.py
index 26fb198c..d2a3ec39 100644
--- a/nonebot/adapters/mirai/message.py
+++ b/nonebot/adapters/mirai/message.py
@@ -44,8 +44,9 @@ class MessageSegment(BaseMessageSegment):
@overrides(BaseMessageSegment)
def __str__(self) -> str:
- if self.is_text():
- return self.data.get('text', '')
+ return self.data['text'] if self.is_text() else repr(self)
+
+ def __repr__(self) -> str:
return '[mirai:%s]' % ','.join([
self.type.value,
*map(
@@ -273,12 +274,14 @@ class MessageChain(BaseMessage):
"""
@overrides(BaseMessage)
- def __init__(self, message: Union[List[Dict[str, Any]],
- Iterable[MessageSegment], MessageSegment],
- **kwargs):
+ def __init__(self, message: Union[List[Dict[str,
+ Any]], Iterable[MessageSegment],
+ MessageSegment, str], **kwargs):
super().__init__(**kwargs)
if isinstance(message, MessageSegment):
self.append(message)
+ elif isinstance(message, str):
+ self.append(MessageSegment.plain(text=message))
elif isinstance(message, Iterable):
self.extend(self._construct(message))
else:
@@ -306,5 +309,13 @@ class MessageChain(BaseMessage):
*map(lambda segment: segment.as_dict(), self.copy()) # type: ignore
]
+ def extract_first(self, *type: MessageType) -> Optional[MessageSegment]:
+ if not len(self):
+ return None
+ first: MessageSegment = self[0]
+ if (not type) or (first.type in type):
+ return self.pop(0)
+ return None
+
def __repr__(self) -> str:
return f'<{self.__class__.__name__} {[*self.copy()]}>'
diff --git a/nonebot/adapters/mirai/utils.py b/nonebot/adapters/mirai/utils.py
index db94dfed..385bd3c6 100644
--- a/nonebot/adapters/mirai/utils.py
+++ b/nonebot/adapters/mirai/utils.py
@@ -7,10 +7,11 @@ from pydantic import Extra, ValidationError, validate_arguments
import nonebot.exception as exception
from nonebot.log import logger
+from nonebot.message import handle_event
from nonebot.utils import escape_tag, logger_wrapper
-from .event import Event, GroupMessage
-from .message import MessageSegment, MessageType
+from .event import Event, GroupMessage, MessageEvent, MessageSource
+from .message import MessageType
if TYPE_CHECKING:
from .bot import Bot
@@ -22,27 +23,26 @@ _AnyCallable = TypeVar("_AnyCallable", bound=Callable)
class Log:
@staticmethod
- def _log(level: str, message: Any, exception: Optional[Exception] = None):
+ def log(level: str, message: str, exception: Optional[Exception] = None):
logger = logger_wrapper('MIRAI')
- logger(level=level,
- message=escape_tag(str(message)),
- exception=exception)
+ message = '' + escape_tag(message) + ''
+ logger(level=level.upper(), message=message, exception=exception)
@classmethod
def info(cls, message: Any):
- cls._log('INFO', escape_tag(str(message)))
+ cls.log('INFO', str(message))
@classmethod
def debug(cls, message: Any):
- cls._log('DEBUG', escape_tag(str(message)))
+ cls.log('DEBUG', str(message))
@classmethod
def warn(cls, message: Any):
- cls._log('WARNING', escape_tag(str(message)))
+ cls.log('WARNING', str(message))
@classmethod
def error(cls, message: Any, exception: Optional[Exception] = None):
- cls._log('ERROR', escape_tag(str(message)), exception=exception)
+ cls.log('ERROR', str(message), exception=exception)
class ActionFailed(exception.ActionFailed):
@@ -124,39 +124,54 @@ def argument_validation(function: _AnyCallable) -> _AnyCallable:
return wrapper # type: ignore
-async def check_tome(bot: "Bot", event: "Event") -> "Event":
- if not isinstance(event, GroupMessage):
- return event
-
- def _is_at(event: GroupMessage) -> bool:
- for segment in event.message_chain:
- segment: MessageSegment
- if segment.type != MessageType.AT:
- continue
- if segment.data['target'] == event.self_id:
- return True
- return False
-
- def _is_nick(event: GroupMessage) -> bool:
- text = event.get_plaintext()
- if not text:
- return False
- nick_regex = '|'.join(
- {i.strip() for i in bot.config.nickname if i.strip()})
- matched = re.search(rf"^({nick_regex})([\s,,]*|$)", text, re.IGNORECASE)
- if matched is None:
- return False
- Log.info(f'User is calling me {matched.group(1)}')
- return True
-
- def _is_reply(event: GroupMessage) -> bool:
- for segment in event.message_chain:
- segment: MessageSegment
- if segment.type != MessageType.QUOTE:
- continue
- if segment.data['senderId'] == event.self_id:
- return True
- return False
-
- event.to_me = any([_is_at(event), _is_reply(event), _is_nick(event)])
+def process_source(bot: "Bot", event: MessageEvent) -> MessageEvent:
+ source = event.message_chain.extract_first(MessageType.SOURCE)
+ if source is not None:
+ event.source = MessageSource.parse_obj(source.data)
return event
+
+
+def process_at(bot: "Bot", event: GroupMessage) -> GroupMessage:
+ at = event.message_chain.extract_first(MessageType.AT)
+ if at is not None:
+ if at.data['target'] == event.self_id:
+ event.to_me = True
+ else:
+ event.message_chain.insert(0, at)
+ return event
+
+
+def process_nick(bot: "Bot", event: GroupMessage) -> GroupMessage:
+ plain = event.message_chain.extract_first(MessageType.PLAIN)
+ if plain is not None:
+ text = str(plain)
+ nick_regex = '|'.join(filter(lambda x: x, bot.config.nickname))
+ matched = re.search(rf"^({nick_regex})([\s,,]*|$)", text, re.IGNORECASE)
+ if matched is not None:
+ event.to_me = True
+ nickname = matched.group(1)
+ Log.info(f'User is calling me {nickname}')
+ plain.data['text'] = text[matched.end():]
+ event.message_chain.insert(0, plain)
+ return event
+
+
+def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage:
+ reply = event.message_chain.extract_first(MessageType.QUOTE)
+ if reply is not None:
+ if reply.data['senderId'] == event.self_id:
+ event.to_me = True
+ else:
+ event.message_chain.insert(0, reply)
+ return event
+
+
+async def process_event(bot: "Bot", event: Event) -> None:
+ if isinstance(event, MessageEvent):
+ Log.debug(event.message_chain)
+ event = process_source(bot, event)
+ if isinstance(event, GroupMessage):
+ event = process_nick(bot, event)
+ event = process_reply(bot, event)
+ event = process_at(bot, event)
+ await handle_event(bot, event)
\ No newline at end of file
diff --git a/nonebot/matcher.py b/nonebot/matcher.py
index 4c22be8f..0fda9f3d 100644
--- a/nonebot/matcher.py
+++ b/nonebot/matcher.py
@@ -418,7 +418,7 @@ class Matcher(metaclass=MatcherMeta):
"""
bot = current_bot.get()
event = current_event.get()
- return await bot.send(event=event, message=message, **kwargs)
+ await bot.send(event=event, message=message, **kwargs)
@classmethod
async def finish(cls,
diff --git a/nonebot/plugin.py b/nonebot/plugin.py
index 3270fd12..dd992f57 100644
--- a/nonebot/plugin.py
+++ b/nonebot/plugin.py
@@ -13,7 +13,7 @@ from types import ModuleType
from dataclasses import dataclass
from importlib._bootstrap import _load
from contextvars import Context, ContextVar, copy_context
-from typing import Any, Set, List, Dict, Type, Tuple, Union, Optional, TYPE_CHECKING
+from typing import Any, Set, List, Dict, Type, Tuple, Union, Optional, TYPE_CHECKING, Iterable
from nonebot.log import logger
from nonebot.matcher import Matcher
@@ -22,7 +22,7 @@ from nonebot.typing import T_State, T_StateFactory, T_Handler, T_RuleChecker
from nonebot.rule import Rule, startswith, endswith, keyword, command, shell_command, ArgumentParser, regex
if TYPE_CHECKING:
- from nonebot.adapters import Bot, Event
+ from nonebot.adapters import Bot, Event, MessageSegment
plugins: Dict[str, "Plugin"] = {}
"""
@@ -422,12 +422,15 @@ def on_command(cmd: Union[str, Tuple[str, ...]],
async def _strip_cmd(bot: "Bot", event: "Event", state: T_State):
message = event.get_message()
- segment = message.pop(0)
- new_message = message.__class__(
- str(segment)
- [len(state["_prefix"]["raw_command"]):].lstrip()) # type: ignore
- for new_segment in reversed(new_message):
- message.insert(0, new_segment)
+ text_processed = False
+ for index, segment in enumerate(message):
+ segment: "MessageSegment" = message.pop(index)
+ if segment.is_text() and not text_processed:
+ segment, *_ = message.__class__(
+ str(segment)[len(state["_prefix"]["raw_command"]):].lstrip(
+ )) # type: ignore
+ text_processed = True
+ message.insert(index, segment)
handlers = kwargs.pop("handlers", [])
handlers.insert(0, _strip_cmd)
diff --git a/nonebot/rule.py b/nonebot/rule.py
index d9f75a24..45146ae8 100644
--- a/nonebot/rule.py
+++ b/nonebot/rule.py
@@ -25,7 +25,7 @@ from nonebot.exception import ParserExit
from nonebot.typing import T_State, T_RuleChecker
if TYPE_CHECKING:
- from nonebot.adapters import Bot, Event
+ from nonebot.adapters import Bot, Event, MessageSegment
class Rule:
@@ -137,8 +137,9 @@ class TrieRule:
prefix = None
suffix = None
message = event.get_message()
- message_seg = message[0]
- if message_seg.is_text():
+ message_seg: Optional["MessageSegment"] = next(
+ filter(lambda x: x.is_text(), message), None)
+ if message_seg is not None:
prefix = cls.prefix.longest_prefix(str(message_seg).lstrip())
message_seg_r = message[-1]
if message_seg_r.is_text():
diff --git a/tests/test_plugins/test_mirai.py b/tests/test_plugins/test_mirai.py
index a5da93ae..c518290a 100644
--- a/tests/test_plugins/test_mirai.py
+++ b/tests/test_plugins/test_mirai.py
@@ -1,13 +1,20 @@
-from nonebot.plugin import on_message
+from nonebot.plugin import on_keyword, on_command
+from nonebot.rule import to_me
from nonebot.adapters.mirai import Bot, MessageEvent
-message_test = on_message()
+message_test = on_keyword({'reply'}, rule=to_me())
@message_test.handle()
async def _message(bot: Bot, event: MessageEvent):
text = event.get_plaintext()
- if not text:
- return
- reversed_text = ''.join(reversed(text))
- await bot.send(event, reversed_text, at_sender=True)
+ await bot.send(event, text, at_sender=True)
+
+
+command_test = on_command('miecho')
+
+
+@command_test.handle()
+async def _echo(bot: Bot, event: MessageEvent):
+ text = event.get_plaintext()
+ await bot.send(event, text, at_sender=True)
\ No newline at end of file