diff --git a/nonebot_plugin_dialectlist/__init__.py b/nonebot_plugin_dialectlist/__init__.py index fc55153..84990ec 100644 --- a/nonebot_plugin_dialectlist/__init__.py +++ b/nonebot_plugin_dialectlist/__init__.py @@ -88,11 +88,15 @@ async def _build_cache(bot: Bot, event: Event): rank_cmd = on_alconna( Alconna( "B话榜", - Args["type?", ["今日", "昨日", "本周", "上周", "本月", "上月", "年度", "历史"]][ + Args[ + "type?", + ["今日", "昨日", "本周", "上周", "本月", "上月", "年度", "历史"] + ][ "time?", str, ], Option("-g|--group_id", Args["group_id?", str]), + Option("-k|--keyword", Args["keyword?", str]), behaviors=[SameTime()], ), aliases={"废话榜"}, @@ -129,6 +133,7 @@ async def _group_message( type: Optional[str] = None, time: Optional[str] = None, group_id: Optional[str] = None, + keyword: Optional[str] = None, ): t1 = t.time() state["t1"] = t1 @@ -140,6 +145,8 @@ async def _group_message( if group_id: state["group_id"] = group_id + state["keyword"] = keyword + if not type: await rank_cmd.finish(__plugin_meta__.usage) @@ -225,8 +232,12 @@ async def handle_rank( if not id: await saa.Text("没有指定群哦").finish() + + keyword = state["keyword"] if plugin_config.counting_cache: + if keyword: + await saa.Text("已开启缓存~缓存不支持关键词查询哦").finish() t1 = t.time() raw_rank = await get_cache(start, stop, id) logger.debug(f"获取计数消息花费时间:{t.time() - t1}") @@ -242,7 +253,7 @@ async def handle_rank( time_stop=stop, exclude_id1s=plugin_config.excluded_people, ) - raw_rank = msg_counter(messages) + raw_rank = msg_counter(messages,keyword) logger.debug(f"获取计数消息花费时间:{t.time() - t1}") if not raw_rank: @@ -259,6 +270,12 @@ async def handle_rank( logger.debug(f"获取用户信息花费时间:{t.time() - t1}") string: str = "" + + if keyword: + string += f"关于{keyword}的话痨榜结果:\n" + else: + string += "话痨榜:\n" + for i in rank2: logger.debug(i.user_name) for i in range(len(rank2)): diff --git a/nonebot_plugin_dialectlist/utils.py b/nonebot_plugin_dialectlist/utils.py index 63a0e13..b0bae93 100644 --- a/nonebot_plugin_dialectlist/utils.py +++ b/nonebot_plugin_dialectlist/utils.py @@ -3,7 +3,7 @@ import httpx import asyncio import unicodedata -from typing import Dict, List +from typing import Dict, List, Optional from sqlalchemy import or_, select from sqlalchemy.sql import ColumnElement @@ -77,7 +77,7 @@ async def persist_id2group_id(ids: List[str]) -> List[str]: return [i.id2 for i in records] -def msg_counter(msg_list: List[MessageRecord]) -> Dict[str, int]: +def msg_counter(msg_list: List[MessageRecord],keyword:Optional[str]) -> Dict[str, int]: """### 计算每个人的消息量 Args: @@ -92,6 +92,9 @@ def msg_counter(msg_list: List[MessageRecord]) -> Dict[str, int]: logger.info("wow , there are {} msgs to count !!!".format(msg_len)) for i in msg_list: + if keyword: + if keyword not in i.plain_text: + continue try: lst[str(i.session_persist_id)] += 1 except KeyError: