diff --git a/nonebot_plugin_dialectlist/__init__.py b/nonebot_plugin_dialectlist/__init__.py index ceb54ab..c26082e 100644 --- a/nonebot_plugin_dialectlist/__init__.py +++ b/nonebot_plugin_dialectlist/__init__.py @@ -1,6 +1,6 @@ import re import time -from typing import Dict, List, Tuple, Union +from typing import List, Tuple, Union from datetime import datetime, timedelta try: @@ -14,92 +14,16 @@ from nonebot.params import Command, CommandArg, Arg, Depends from nonebot.typing import T_State from nonebot.matcher import Matcher from nonebot.adapters import Bot -from nonebot.adapters.onebot.v11 import GroupMessageEvent,Message -from nonebot.adapters.onebot.v11.exception import ActionFailed - - +from nonebot.adapters.onebot.v11 import GroupMessageEvent, PrivateMessageEvent, Message require("nonebot_plugin_datastore") require("nonebot_plugin_chatrecorder") -from .record4dialectlist import get_message_records +require("nonebot_plugin_guild_patch") +from nonebot_plugin_guild_patch import GuildMessageEvent + +from .qqDBRecorder import get_message_records,msg_counter from .config import plugin_config - -async def msg_got_counter( - gid:int, - bot:Bot, - start_time=None, - stop_time=None, - got_num:int=10 -)->Message: - ''' - 计算出结果并返回可以直接发送的字符串 - ''' - st = time.time() - gids:List[str] = [str(gid)] - bot_id = await bot.call_api('get_login_info') - bot_id = [str(bot_id['user_id'])] - - logger.debug('loading msg form group {}'.format(gid)) - - gnl = await bot.call_api('get_group_member_list',group_id=int(gid)) - - logger.debug('group {} have number {}'.format(gid,len(gnl))) - - msg = await get_message_records( - group_ids=gids, - exclude_user_ids=bot_id, - message_type='group', - time_start=start_time, - time_stop=stop_time - ) - - lst:Dict[str,int] = {} - for i in msg: - try: - lst[i.user_id] +=1 - except KeyError: - lst[i.user_id] =1 - - logger.debug(lst) - logger.debug('group number num is '+str(len(lst))) - - ranking = [] - while len(ranking) < got_num: - try: - maxkey = max(lst, key=lst.get) # type: ignore - except ValueError: - ranking.append(None) - continue - - logger.debug('searching number {} form group {}'.format(str(maxkey),str(gid))) - try: - - t = await bot.call_api( - "get_group_member_info", - group_id=int(gid), - user_id=int(maxkey), - no_cache=True - ) - - nickname:str = t['nickname']if not t['card'] else t['card'] - ranking.append([nickname.strip(),lst.pop(maxkey)]) - - except ActionFailed as e: - - logger.warning(e) - logger.warning('number {} not exit in group {}'.format(str(maxkey),str(gid))) - lst.pop(maxkey) - - logger.debug('loaded list:\n{}'.format(ranking)) - - - out:str = '' - for i in range(got_num): - str_example = '第{}名:\n{}条消息\n'.format(i+1,str(ranking[i])[1:-1]) - out = out + str_example - out = out + '\n\n你们的职业是水群吗?————MYX\n计算花费时间:{}秒'.format(time.time()-st) - - return Message(out) +from .qqGuildJsonRecorder import get_guild_message_records def parse_datetime(key: str): @@ -157,7 +81,16 @@ rankings = on_command( ) @rankings.handle() -async def _a(event:GroupMessageEvent,state: T_State,commands: Tuple[str, ...] = Command(),args: Message = CommandArg()): +async def _group_message( + event:Union[GroupMessageEvent, GuildMessageEvent], + state: T_State,commands: Tuple[str, ...] = Command(), + args: Message = CommandArg() + ): + + if isinstance(event, GroupMessageEvent): + logger.debug('handle command from qq') + elif isinstance(event, GuildMessageEvent): + logger.debug('handle command from qqguild') dt = get_datetime_now_with_timezone() command = commands[0] @@ -204,6 +137,15 @@ async def _a(event:GroupMessageEvent,state: T_State,commands: Tuple[str, ...] = await rankings.finish("请输入正确的日期,不然我没法理解呢!") else: pass + +@rankings.handle() +async def _private_message( + event:PrivateMessageEvent, + state: T_State,commands: Tuple[str, ...] = Command(), + args: Message = CommandArg() + ): + # TODO:支持私聊的查询 + await rankings.finish('暂不支持私聊查询,今后可能会添加这一项功能') @rankings.got( "start", @@ -217,19 +159,31 @@ async def _a(event:GroupMessageEvent,state: T_State,commands: Tuple[str, ...] = ) async def handle_message( bot: Bot, - event: GroupMessageEvent, - start: datetime = Arg(), - stop: datetime = Arg() + event: Union[GroupMessageEvent, GuildMessageEvent], + stop: datetime = Arg(), + start: datetime = Arg() ): - - # 将时间转换到 UTC 时区 - msg = await msg_got_counter( - gid=event.group_id, - bot=bot, - start_time=start.astimezone(ZoneInfo("UTC")), - stop_time=stop.astimezone(ZoneInfo("UTC")) + + st = time.time() + bot_id = await bot.call_api('get_login_info') + bot_id = [str(bot_id['user_id'])] + if isinstance(event,GroupMessageEvent): + + gids:List[str] = [str(event.group_id)] + msg = await get_message_records( + group_ids=gids, + exclude_user_ids=bot_id, + message_type='group', + time_start=start.astimezone(ZoneInfo("UTC")), + time_stop=stop.astimezone(ZoneInfo("UTC")) ) + msg = await msg_counter(gid=event.group_id, bot=bot, msg=msg,got_num=plugin_config.dialectlist_get_num) + + elif isinstance(event, GuildMessageEvent): + + guild_id = event.guild_id + msg = await get_guild_message_records(guild_id=str(guild_id),bot=bot) + + msg += plugin_config.dialectlist_string_suffix_format.format(timecost=time.time()-st) await rankings.finish(msg) - - - + \ No newline at end of file diff --git a/nonebot_plugin_dialectlist/config.py b/nonebot_plugin_dialectlist/config.py index c7e54fa..517d37b 100644 --- a/nonebot_plugin_dialectlist/config.py +++ b/nonebot_plugin_dialectlist/config.py @@ -1,15 +1,18 @@ from typing import Optional - from nonebot import get_driver from pydantic import BaseModel, Extra +from pathlib import Path +import os class Config(BaseModel, extra=Extra.ignore): timezone: Optional[str] - - - + dialectlist_string_format: str = '第{index}名:\n{nickname},{chatdatanum}条消息\n' + dialectlist_string_suffix_format: str = '\n你们的职业是水群吗?————MYX\n计算花费时间:{timecost}秒' + dialectlist_path:str = os.path.dirname(__file__) + dialectlist_json_path:Path = Path(dialectlist_path)/'qqguild.json' + dialectlist_get_num:int = 10 global_config = get_driver().config plugin_config = Config.parse_obj(global_config) \ No newline at end of file diff --git a/nonebot_plugin_dialectlist/qqDBRecorder.py b/nonebot_plugin_dialectlist/qqDBRecorder.py new file mode 100644 index 0000000..96da136 --- /dev/null +++ b/nonebot_plugin_dialectlist/qqDBRecorder.py @@ -0,0 +1,148 @@ +import time + +from datetime import datetime +from sqlmodel import select, or_ +from typing_extensions import Literal +from typing import Iterable, List, Optional, Dict + +from nonebot.log import logger +from nonebot.adapters import Bot +from nonebot.adapters.onebot.v11 import Message +from nonebot.adapters.onebot.v11.exception import ActionFailed + +from nonebot_plugin_datastore import create_session + +from nonebot_plugin_chatrecorder.model import MessageRecord + +from .config import plugin_config + +async def get_message_records( + user_ids: Optional[Iterable[str]] = None, + group_ids: Optional[Iterable[str]] = None, + platforms: Optional[Iterable[str]] = None, + exclude_user_ids: Optional[Iterable[str]] = None, + exclude_group_ids: Optional[Iterable[str]] = None, + message_type: Optional[Literal['private', 'group']] = None, + time_start: Optional[datetime] = None, + time_stop: Optional[datetime] = None, +)->List[MessageRecord]: + """ + :说明: + + 获取消息记录 + + :参数: + + * ``user_ids: Optional[Iterable[str]]``: 用户列表,为空表示所有用户 + * ``group_ids: Optional[Iterable[str]]``: 群组列表,为空表示所有群组 + * ``platform: OPtional[Iterable[str]]``: 消息来源列表,为空表示所有来源 + * ``exclude_user_ids: Optional[Iterable[str]]``: 不包含的用户列表,为空表示不限制 + * ``exclude_group_ids: Optional[Iterable[str]]``: 不包含的群组列表,为空表示不限制 + * ``message_type: Optional[Literal['private', 'group']]``: 消息类型,可选值:'private' 和 'group',为空表示所有类型 + * ``time_start: Optional[datetime]``: 起始时间,UTC 时间,为空表示不限制起始时间 + * ``time_stop: Optional[datetime]``: 结束时间,UTC 时间,为空表示不限制结束时间 + + :返回值: + * ``List[MessageRecord]``:返回信息 + """ + + whereclause = [] + if user_ids: + whereclause.append( + or_(*[MessageRecord.user_id == user_id for user_id in user_ids]) # type: ignore + ) + if group_ids: + whereclause.append( + or_(*[MessageRecord.group_id == group_id for group_id in group_ids]) # type: ignore + ) + if platforms: + whereclause.append( + or_(*[MessageRecord.platform == platform for platform in platforms]) # type: ignore + ) + if exclude_user_ids: + for user_id in exclude_user_ids: + whereclause.append(MessageRecord.user_id != user_id) + if exclude_group_ids: + for group_id in exclude_group_ids: + whereclause.append(MessageRecord.group_id != group_id) + if message_type: + whereclause.append(MessageRecord.detail_type == message_type) + if time_start: + whereclause.append(MessageRecord.time >= time_start) + if time_stop: + whereclause.append(MessageRecord.time <= time_stop) + + statement = select(MessageRecord).where(*whereclause) + async with create_session() as session: + records: List[MessageRecord] = (await session.exec(statement)).all() # type: ignore + return records + + + +async def msg_counter( + gid:int, + bot:Bot, + msg:List[MessageRecord], + got_num:int=10, +)->Message: + ''' + 计算出结果并返回可以直接发送的字符串 + ''' + st = time.time() + + logger.debug('loading msg from group {}'.format(gid)) + gnl = await bot.call_api('get_group_member_list',group_id=int(gid)) + logger.debug('group {} have number {}'.format(gid,len(gnl))) + + lst:Dict[str,int] = {} + msg_len = len(msg) + for i in msg: + try: + lst[i.user_id] +=1 + except KeyError: + lst[i.user_id] =1 + + logger.debug(lst) + logger.debug('group number num is '+str(len(lst))) + + ranking = [] + while len(ranking) < got_num: + + try: + maxkey = max(lst, key=lst.get) # type: ignore + except ValueError: + ranking.append(("null",0)) + continue + + logger.debug('searching number {} from group {}'.format(str(maxkey),str(gid))) + + try: + + member_info = await bot.call_api( + "get_group_member_info", + group_id=int(gid), + user_id=int(maxkey), + no_cache=True + ) + nickname:str = member_info['nickname']if not member_info['card'] else member_info['card'] + ranking.append([nickname.strip(),lst.pop(maxkey)]) + + except ActionFailed as e: + + logger.warning(e) + logger.warning('number {} is not exit in group {}'.format(str(maxkey),str(gid))) + lst.pop(maxkey) + + logger.debug('loaded list:\n{}'.format(ranking)) + + out:str = '' + for i in range(got_num): + index = i+1 + nickname,chatdatanum = str(ranking[i]) + str_example = plugin_config.dialectlist_string_format.format(index=index,nickname=nickname,chatdatanum=chatdatanum) + out = out + str_example + + logger.debug(out) + logger.info('spent {} seconds to count from {} msg'.format(time.time()-st,msg_len)) + + return Message(out) diff --git a/nonebot_plugin_dialectlist/qqGuildJsonRecorder.py b/nonebot_plugin_dialectlist/qqGuildJsonRecorder.py new file mode 100644 index 0000000..084cfb6 --- /dev/null +++ b/nonebot_plugin_dialectlist/qqGuildJsonRecorder.py @@ -0,0 +1,65 @@ +import json +from typing import Dict + +from nonebot.log import logger +from nonebot.message import event_postprocessor +from nonebot.adapters import Bot +from nonebot.adapters.onebot.v11 import Message +from nonebot.adapters.onebot.v11.exception import ActionFailed + +from nonebot_plugin_guild_patch import GuildMessageEvent + +from .config import plugin_config + + +def update_json(updatedata:Dict): + + with open(plugin_config.dialectlist_json_path, 'w', encoding='utf-8') as f: + json.dump(updatedata, f, ensure_ascii=False, indent=4) + +def get_json()-> Dict[str,Dict]: + + if not plugin_config.dialectlist_json_path.exists(): + return {} + + with open(plugin_config.dialectlist_json_path, 'r', encoding='utf-8') as f: + data:Dict = json.load(f) + return data + + +@event_postprocessor +async def _pocesser(event:GuildMessageEvent): + + data = get_json() + try: + data[str(event.guild_id)][str(event.sender.nickname)] += 1 + except KeyError: + data[str(event.guild_id)] = {str(event.sender.nickname):1} + update_json(data) + + +async def get_guild_message_records( + guild_id:str, + bot:Bot, + got_num:int=10, +)->Message: + data = get_json() + ranking = [] + while len(ranking) < got_num: + + try: + maxkey = max(data[guild_id], key=data[guild_id].get) # type: ignore + except ValueError: + ranking.append(("null",0)) + continue + ranking.append((maxkey,data[guild_id].pop(maxkey))) + + logger.debug('loaded list:\n{}'.format(ranking)) + out:str = '' + for i in range(got_num): + index = i+1 + nickname,chatdatanum = ranking[i] + str_example = plugin_config.dialectlist_string_format.format(index=index,nickname=nickname,chatdatanum=chatdatanum) + out = out + str_example + + return Message(out) diff --git a/nonebot_plugin_dialectlist/record4dialectlist.py b/nonebot_plugin_dialectlist/record4dialectlist.py deleted file mode 100644 index 63d8ded..0000000 --- a/nonebot_plugin_dialectlist/record4dialectlist.py +++ /dev/null @@ -1,64 +0,0 @@ -from datetime import datetime -from sqlmodel import select, or_ -from typing_extensions import Literal -from typing import Iterable, List, Optional - -from nonebot_plugin_datastore import create_session - -from nonebot_plugin_chatrecorder.model import MessageRecord - - -async def get_message_records( - user_ids: Optional[Iterable[str]] = None, - group_ids: Optional[Iterable[str]] = None, - exclude_user_ids: Optional[Iterable[str]] = None, - exclude_group_ids: Optional[Iterable[str]] = None, - message_type: Optional[Literal['private', 'group']] = None, - time_start: Optional[datetime] = None, - time_stop: Optional[datetime] = None, -)->List[MessageRecord]: - """ - :说明: - - 获取消息记录 - - :参数: - - * ``user_ids: Optional[Iterable[str]]``: 用户列表,为空表示所有用户 - * ``group_ids: Optional[Iterable[str]]``: 群组列表,为空表示所有群组 - * ``exclude_user_ids: Optional[Iterable[str]]``: 不包含的用户列表,为空表示不限制 - * ``exclude_group_ids: Optional[Iterable[str]]``: 不包含的群组列表,为空表示不限制 - * ``message_type: Optional[Literal['private', 'group']]``: 消息类型,可选值:'private' 和 'group',为空表示所有类型 - * ``time_start: Optional[datetime]``: 起始时间,UTC 时间,为空表示不限制起始时间 - * ``time_stop: Optional[datetime]``: 结束时间,UTC 时间,为空表示不限制结束时间 - - :返回值: - * ``List[MessageRecord]``:返回信息 - """ - - whereclause = [] - if user_ids: - whereclause.append( - or_(*[MessageRecord.user_id == user_id for user_id in user_ids]) # type: ignore - ) - if group_ids: - whereclause.append( - or_(*[MessageRecord.group_id == group_id for group_id in group_ids]) # type: ignore - ) - if exclude_user_ids: - for user_id in exclude_user_ids: - whereclause.append(MessageRecord.user_id != user_id) - if exclude_group_ids: - for group_id in exclude_group_ids: - whereclause.append(MessageRecord.group_id != group_id) - if message_type: - whereclause.append(MessageRecord.detail_type == message_type) - if time_start: - whereclause.append(MessageRecord.time >= time_start) - if time_stop: - whereclause.append(MessageRecord.time <= time_stop) - - statement = select(MessageRecord).where(*whereclause) - async with create_session() as session: - records: List[MessageRecord] = (await session.exec(statement)).all() # type: ignore - return records \ No newline at end of file