mirror of
https://github.com/ChenXu233/nonebot_plugin_dialectlist.git
synced 2025-09-22 20:36:23 +00:00
⚡ 超级性能提升
This commit is contained in:
@ -16,7 +16,6 @@ from typing import Optional, Union
|
||||
import nonebot_plugin_saa as saa
|
||||
from arclet.alconna import ArparmaBehavior
|
||||
from arclet.alconna.arparma import Arparma
|
||||
from nonebot import on_command
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.log import logger
|
||||
from nonebot.params import Arg, Depends
|
||||
@ -31,10 +30,10 @@ from nonebot_plugin_alconna import (
|
||||
Option,
|
||||
on_alconna,
|
||||
)
|
||||
from nonebot_plugin_chatrecorder import get_message_records
|
||||
from nonebot_plugin_uninfo import Session, Uninfo, get_session
|
||||
|
||||
from .config import Config, plugin_config
|
||||
|
||||
# from .storage import build_cache, get_cache
|
||||
from .time import (
|
||||
get_datetime_fromisoformat_with_timezone,
|
||||
@ -46,8 +45,8 @@ from .utils import (
|
||||
get_rank_image,
|
||||
get_user_infos,
|
||||
got_rank,
|
||||
msg_counter,
|
||||
persist_id2user_id,
|
||||
get_user_message_counts,
|
||||
)
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
@ -80,17 +79,6 @@ def wrapper(slot: Union[int, str], content: Optional[str], context) -> str:
|
||||
return '' # pragma: no cover
|
||||
|
||||
|
||||
build_cache_cmd = on_command('build_cache', aliases={'重建缓存'}, block=True)
|
||||
|
||||
|
||||
@build_cache_cmd.handle()
|
||||
async def _build_cache(bot: Bot, event: Event):
|
||||
return
|
||||
await saa.Text('正在重建缓存,请稍等。').send(reply=True) # type: ignore
|
||||
await build_cache()
|
||||
await saa.Text('重建缓存完成。').send(reply=True)
|
||||
|
||||
|
||||
b_cmd = on_alconna(
|
||||
Alconna(
|
||||
'看看B话',
|
||||
@ -122,18 +110,16 @@ async def handle_b_cmd(
|
||||
if not gid:
|
||||
await b_cmd.finish('请指定群号。')
|
||||
|
||||
if keyword.available:
|
||||
keywords = keyword.result
|
||||
else:
|
||||
keywords = None
|
||||
_keyword = keyword.result
|
||||
|
||||
messages = await get_message_records(
|
||||
d = await get_user_message_counts(
|
||||
keyword=_keyword,
|
||||
scene_ids=[gid],
|
||||
user_ids=[id],
|
||||
types=['message'], # 排除机器人自己发的消息
|
||||
exclude_user_ids=plugin_config.excluded_people,
|
||||
)
|
||||
d = msg_counter(messages, keywords)
|
||||
|
||||
rank = got_rank(d)
|
||||
if not rank:
|
||||
await b_cmd.finish(
|
||||
@ -302,24 +288,16 @@ async def handle_rank(
|
||||
|
||||
keyword = state['keyword']
|
||||
|
||||
if plugin_config.counting_cache:
|
||||
await saa.Text("缓存暂不支持").finish()
|
||||
# if keyword:
|
||||
# await saa.Text('已开启缓存~缓存不支持关键词查询哦').finish()
|
||||
# t1 = t.time()
|
||||
# raw_rank = await get_cache(start, stop, id)
|
||||
# logger.debug(f'获取计数消息花费时间:{t.time() - t1}')
|
||||
else:
|
||||
t1 = t.time()
|
||||
messages = await get_message_records(
|
||||
scene_ids=[id],
|
||||
types=['message'], # 排除机器人自己发的消息
|
||||
time_start=start,
|
||||
time_stop=stop,
|
||||
exclude_user_ids=plugin_config.excluded_people,
|
||||
)
|
||||
raw_rank = msg_counter(messages, keyword)
|
||||
logger.debug(f'获取计数消息花费时间:{t.time() - t1}')
|
||||
t1 = t.time()
|
||||
raw_rank = await get_user_message_counts(
|
||||
keyword=keyword,
|
||||
scene_ids=[id],
|
||||
types=['message'], # 排除机器人自己发的消息
|
||||
time_start=start,
|
||||
time_stop=stop,
|
||||
exclude_user_ids=plugin_config.excluded_people,
|
||||
)
|
||||
logger.debug(f'获取计数消息花费时间:{t.time() - t1}')
|
||||
|
||||
if not raw_rank:
|
||||
await saa.Text(
|
||||
|
@ -5,24 +5,25 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class ScopedConfig(BaseModel):
|
||||
get_num: int = 5 # 获取人数数量
|
||||
font: str = "SimHei" # 字体格式
|
||||
suffix: bool = True # 是否显示后缀
|
||||
excluded_self: bool = True # 是否排除自己
|
||||
visualization: bool = True # 是否可视化
|
||||
show_text_rank: bool = True # 是否显示文本排名
|
||||
counting_cache: bool = False # 计数缓存(能够提高回复速度)
|
||||
excluded_people: List[str] = [] # 排除的人的QQ号
|
||||
use_user_info_cache: bool = False # 是否使用用户信息缓存
|
||||
aggregate_transmission: bool = False # 是否聚合转发消息
|
||||
timezone: Optional[str] = "Asia/Shanghai" # 时区,影响统计时间
|
||||
string_suffix: str = "统计花费时间:{timecost}秒" # 消息格式后缀
|
||||
template_path: str = "./template/rank_template.j2" # 模板路径
|
||||
string_format: str = "第{index}名:\n{nickname},{chatdatanum}条消息\n" # 消息格式
|
||||
get_num: int = 5 # 获取人数数量
|
||||
font: str = 'SimHei' # 字体格式
|
||||
suffix: bool = True # 是否显示后缀
|
||||
excluded_self: bool = True # 是否排除自己
|
||||
visualization: bool = True # 是否可视化
|
||||
show_text_rank: bool = True # 是否显示文本排名
|
||||
excluded_people: List[str] = [] # 排除的人的QQ号
|
||||
use_user_info_cache: bool = False # 是否使用用户信息缓存
|
||||
aggregate_transmission: bool = False # 是否聚合转发消息
|
||||
timezone: Optional[str] = 'Asia/Shanghai' # 时区,影响统计时间
|
||||
string_suffix: str = '统计花费时间:{timecost}秒' # 消息格式后缀
|
||||
template_path: str = './template/rank_template.j2' # 模板路径
|
||||
string_format: str = (
|
||||
'第{index}名:\n{nickname},{chatdatanum}条消息\n' # 消息格式
|
||||
)
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
dialectlist: ScopedConfig = ScopedConfig()
|
||||
dialectlist: ScopedConfig = ScopedConfig()
|
||||
|
||||
|
||||
global_config = get_driver().config
|
||||
|
@ -1,47 +0,0 @@
|
||||
"""update dialectlist
|
||||
|
||||
迁移 ID: fb88e4d27eb8
|
||||
父迁移: 60daff81fcdc
|
||||
创建时间: 2025-04-20 11:00:50.931679
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision: str = 'fb88e4d27eb8'
|
||||
down_revision: str | Sequence[str] | None = '60daff81fcdc'
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade(name: str = "") -> None:
|
||||
if name:
|
||||
return
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('sessionmodel',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('bot_id', sa.String(length=64), nullable=False),
|
||||
sa.Column('bot_type', sa.String(length=32), nullable=False),
|
||||
sa.Column('platform', sa.String(length=32), nullable=False),
|
||||
sa.Column('level', sa.Integer(), nullable=False),
|
||||
sa.Column('id1', sa.String(length=64), nullable=False),
|
||||
sa.Column('id2', sa.String(length=64), nullable=False),
|
||||
sa.Column('id3', sa.String(length=64), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name=op.f('pk_sessionmodel')),
|
||||
sa.UniqueConstraint('bot_id', 'bot_type', 'platform', 'level', 'id1', 'id2', 'id3', name='unique_session'),
|
||||
info={'bind_key': ''}
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade(name: str = "") -> None:
|
||||
if name:
|
||||
return
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('sessionmodel')
|
||||
# ### end Alembic commands ###
|
@ -1,23 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Union
|
||||
|
||||
from nonebot_plugin_orm import Model
|
||||
from nonebot_plugin_userinfo import UserInfo
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class UserRankInfo(UserInfo):
|
||||
user_bnum: int
|
||||
user_proportion: float
|
||||
user_nickname: str
|
||||
user_index: Union[int, str]
|
||||
user_avatar_bytes: bytes
|
||||
|
||||
|
||||
class MessageCountCache(Model):
|
||||
__table_args__ = {"extend_existing": True}
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
time: Mapped[datetime]
|
||||
session_id: Mapped[int] = mapped_column(Integer, index=True)
|
||||
session_bnum: Mapped[int] = mapped_column(Integer)
|
11
nonebot_plugin_dialectlist/schema.py
Normal file
11
nonebot_plugin_dialectlist/schema.py
Normal file
@ -0,0 +1,11 @@
|
||||
from typing import Union
|
||||
|
||||
from nonebot_plugin_userinfo import UserInfo
|
||||
|
||||
|
||||
class UserRankInfo(UserInfo):
|
||||
user_bnum: int
|
||||
user_proportion: float
|
||||
user_nickname: str
|
||||
user_index: Union[int, str]
|
||||
user_avatar_bytes: bytes
|
@ -1,136 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from nonebot import get_driver
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.log import logger
|
||||
from nonebot.message import event_postprocessor
|
||||
from nonebot.params import Depends
|
||||
from nonebot_plugin_chatrecorder import get_message_records
|
||||
from nonebot_plugin_chatrecorder.utils import remove_timezone
|
||||
from nonebot_plugin_uninfo import Session, get_session
|
||||
from nonebot_plugin_uninfo.orm import SessionModel,get_session_persist_id
|
||||
from nonebot_plugin_localstore import get_data_file
|
||||
from nonebot_plugin_orm import get_session
|
||||
from sqlalchemy import delete, or_, select
|
||||
|
||||
from .config import plugin_config
|
||||
from .model import MessageCountCache
|
||||
|
||||
|
||||
async def get_cache(time_start: datetime, time_stop: datetime, group_id: str):
|
||||
async with get_session() as db_session:
|
||||
where = [or_(SessionModel.id2 == group_id)]
|
||||
statement = select(SessionModel).where(*where)
|
||||
|
||||
sessions = (await db_session.scalars(statement)).all()
|
||||
|
||||
where = [
|
||||
or_(*[MessageCountCache.session_id == session.id for session in sessions])
|
||||
]
|
||||
statement = select(MessageCountCache).where(*where)
|
||||
where.append(or_(MessageCountCache.time >= remove_timezone(time_start)))
|
||||
where.append(or_(MessageCountCache.time <= remove_timezone(time_stop)))
|
||||
statement = select(MessageCountCache).where(*where)
|
||||
|
||||
user_caches = (await db_session.scalars(statement)).all()
|
||||
raw_rank = {}
|
||||
for i in user_caches:
|
||||
raw_rank[i.session_id] = raw_rank.get(i.session_id, 0) + i.session_bnum
|
||||
return raw_rank
|
||||
|
||||
|
||||
async def build_cache():
|
||||
async with get_session() as db_session:
|
||||
await db_session.execute(delete(MessageCountCache))
|
||||
await db_session.commit()
|
||||
logger.info("先前可能存在的缓存已清空")
|
||||
messages = await get_message_records(types=["message"])
|
||||
async with get_session() as db_session:
|
||||
for msg in messages:
|
||||
msg_session_id = msg.session_persist_id
|
||||
|
||||
where = [or_(MessageCountCache.session_id == msg_session_id)]
|
||||
where.append(
|
||||
or_(
|
||||
MessageCountCache.time
|
||||
== remove_timezone(
|
||||
msg.time.replace(hour=1, minute=0, second=0, microsecond=0)
|
||||
)
|
||||
)
|
||||
)
|
||||
statement = select(MessageCountCache).where(*where)
|
||||
|
||||
user_cache = (await db_session.scalars(statement)).all()
|
||||
|
||||
if user_cache:
|
||||
user_cache[0].session_bnum += 1
|
||||
else:
|
||||
user_cache = MessageCountCache(
|
||||
session_id=msg.session_persist_id,
|
||||
time=remove_timezone(
|
||||
msg.time.replace(hour=1, minute=0, second=0, microsecond=0)
|
||||
),
|
||||
session_bnum=1,
|
||||
)
|
||||
db_session.add(user_cache)
|
||||
await db_session.commit()
|
||||
|
||||
logger.info("缓存构建完成")
|
||||
|
||||
|
||||
driver = get_driver()
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def _():
|
||||
if not plugin_config.counting_cache:
|
||||
return
|
||||
f_name = get_data_file("nonebot-plugin-dialectlist", "is-pre-cached.json")
|
||||
if not os.path.exists(f_name):
|
||||
with open(f_name, "w", encoding="utf-8") as f:
|
||||
s = json.dumps({"is-pre-cached": False}, ensure_ascii=False, indent=4)
|
||||
f.write(s)
|
||||
|
||||
with open(f_name, "r", encoding="utf-8") as f:
|
||||
if json.load(f)["is-pre-cached"]:
|
||||
return
|
||||
logger.info("未检查到缓存,开始构建缓存")
|
||||
with open(f_name, "w", encoding="utf-8") as f:
|
||||
await build_cache()
|
||||
json.dump({"is-pre-cached": True}, f, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
@event_postprocessor
|
||||
async def _(bot: Bot, event: Event, session: Session = Depends(extract_session)):
|
||||
if not plugin_config.counting_cache:
|
||||
return
|
||||
if not session.id2:
|
||||
return
|
||||
if event.get_type() != "message":
|
||||
return
|
||||
now = datetime.now()
|
||||
now = now.replace(hour=1, minute=0, second=0, microsecond=0)
|
||||
|
||||
async with get_session() as db_session:
|
||||
session_id = await get_session_persist_id(session)
|
||||
logger.debug("session_id:" + str(session_id))
|
||||
where = [or_(MessageCountCache.session_id == session_id)]
|
||||
where.append(or_(MessageCountCache.time == remove_timezone(now)))
|
||||
statement = select(MessageCountCache).where(*where)
|
||||
user_cache = (await db_session.scalars(statement)).first()
|
||||
if user_cache:
|
||||
user_cache.session_bnum += 1
|
||||
else:
|
||||
user_cache = MessageCountCache(
|
||||
session_id=session_id,
|
||||
time=remove_timezone(now),
|
||||
session_bnum=1,
|
||||
)
|
||||
db_session.add(user_cache)
|
||||
await db_session.commit()
|
||||
logger.debug("已计入缓存")
|
||||
|
||||
|
||||
# TODO: 修复缓存储存
|
@ -1,7 +1,7 @@
|
||||
from inspect import cleandoc
|
||||
|
||||
__usage__ = cleandoc(
|
||||
"""
|
||||
"""
|
||||
快速调用:
|
||||
/今日B话榜 ————看看今天群友发了多少消息。
|
||||
|
||||
@ -18,7 +18,7 @@ __usage__ = cleandoc(
|
||||
-`/前日B话榜` ————看看前天的群友发了多少消息!
|
||||
|
||||
-`/本周B话榜` ————看看本周的群友发了多少消息!
|
||||
|
||||
|
||||
-`/上周B话榜` ————看看上周的群友发了多少消息!
|
||||
|
||||
-`/本月B话榜` ————看看这个月的群友发了多少消息!
|
||||
|
@ -2,9 +2,11 @@ import asyncio
|
||||
import os
|
||||
import re
|
||||
import unicodedata
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select, func
|
||||
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.compat import model_dump
|
||||
from nonebot.log import logger
|
||||
@ -20,10 +22,11 @@ from nonebot_plugin_uninfo.model import SceneType
|
||||
from nonebot_plugin_uninfo.orm import SessionModel, UserModel
|
||||
from nonebot_plugin_uninfo import get_session as extract_session
|
||||
from nonebot_plugin_userinfo.exception import NetworkError
|
||||
from sqlalchemy import or_, select
|
||||
from nonebot_plugin_chatrecorder.record import filter_statement
|
||||
|
||||
|
||||
from .config import plugin_config
|
||||
from .model import UserRankInfo
|
||||
from .schema import UserRankInfo
|
||||
|
||||
cache_path = get_cache_dir('nonebot_plugin_dialectlist')
|
||||
|
||||
@ -38,71 +41,35 @@ async def ensure_group(
|
||||
|
||||
async def persist_id2user_id(ids: List) -> List[str]:
|
||||
user_ids = []
|
||||
user_persist_ids = []
|
||||
if not ids:
|
||||
return user_ids
|
||||
|
||||
async with get_session() as db_session:
|
||||
for i in ids:
|
||||
session = await db_session.scalar(
|
||||
select(SessionModel).where(or_(*[SessionModel.id == i]))
|
||||
)
|
||||
if session is not None:
|
||||
user_persist_id = session.user_persist_id
|
||||
user_persist_ids.append(user_persist_id)
|
||||
for i in user_persist_ids:
|
||||
user = await db_session.scalar(
|
||||
select(UserModel).where(UserModel.id == i)
|
||||
)
|
||||
if user is not None:
|
||||
user_ids.append(user.user_id)
|
||||
statement = (
|
||||
select(UserModel.user_id)
|
||||
.join(SessionModel, UserModel.id == SessionModel.user_persist_id)
|
||||
.where(SessionModel.id.in_(ids))
|
||||
)
|
||||
result = await db_session.scalars(statement)
|
||||
user_ids = result.all()
|
||||
|
||||
return user_ids
|
||||
return list(user_ids)
|
||||
|
||||
|
||||
def msg_counter(
|
||||
msg_list: List[MessageRecord], keyword: Optional[str]
|
||||
) -> Dict[str, int]:
|
||||
"""### 计算每个人的消息量
|
||||
|
||||
Args:
|
||||
msg_list (list[MessageRecord]): 需要处理的消息列表
|
||||
|
||||
Returns:
|
||||
(dict[str,int]): 处理后的消息数量字典,键为用户,值为消息数量
|
||||
"""
|
||||
|
||||
lst: Dict[str, int] = {}
|
||||
msg_len = len(msg_list)
|
||||
logger.info('wow , there are {} msgs to count !!!'.format(msg_len))
|
||||
|
||||
for i in msg_list:
|
||||
# logger.debug(f"processing msg {i.plain_text}")
|
||||
if keyword:
|
||||
match = re.search(keyword, i.plain_text)
|
||||
if not match:
|
||||
continue
|
||||
try:
|
||||
lst[str(i.session_persist_id)] += 1
|
||||
except KeyError:
|
||||
lst[str(i.session_persist_id)] = 1
|
||||
|
||||
logger.debug(f'finish counting, result is {lst}')
|
||||
|
||||
return lst
|
||||
|
||||
|
||||
def got_rank(msg_dict: Dict[str, int]) -> List:
|
||||
def got_rank(msg_dict: Dict[int, int]) -> List[List[Any]]:
|
||||
"""### 获得排行榜
|
||||
|
||||
Args:
|
||||
msg_dict (Dict[str,int]): 要处理的字典
|
||||
msg_dict (Dict[int,int]): 要处理的字典
|
||||
|
||||
Returns:
|
||||
List[Tuple[str,int]]: 排行榜列表(已排序)
|
||||
List[Tuple[int,int]]: 排行榜列表(已排序)
|
||||
"""
|
||||
rank = []
|
||||
while len(rank) < plugin_config.get_num:
|
||||
try:
|
||||
max_key = max(msg_dict.items(), key=lambda x: x[1])
|
||||
rank.append(list(max_key))
|
||||
rank.append(tuple(max_key))
|
||||
msg_dict.pop(max_key[0])
|
||||
except ValueError:
|
||||
logger.error(
|
||||
@ -261,3 +228,42 @@ async def get_user_infos(
|
||||
rank2.append(user)
|
||||
|
||||
return rank2
|
||||
|
||||
|
||||
async def get_user_message_counts(
|
||||
keyword: Optional[str] = None, **kwargs
|
||||
) -> Dict[int, int]:
|
||||
"""获取每个用户的消息数量(直接在数据库层面统计)
|
||||
|
||||
参数:
|
||||
* ``keyword``: 可选,关键词,只统计包含该关键词的消息
|
||||
* ``**kwargs``: 筛选参数,具体查看 `filter_statement` 中的定义
|
||||
|
||||
返回值:
|
||||
* ``Dict[str, int]``: 键为user_persist_id,值为该用户的消息数量
|
||||
"""
|
||||
whereclause = filter_statement(**kwargs)
|
||||
|
||||
# 如果提供了关键词,添加关键词过滤条件
|
||||
if keyword:
|
||||
# 构造LIKE条件,类似于msg_counter函数中的正则匹配
|
||||
# 根据数据库类型不同,可能需要调整LIKE的语法
|
||||
keyword_condition = MessageRecord.plain_text.ilike(f'%{keyword}%')
|
||||
whereclause.append(keyword_condition)
|
||||
|
||||
# 使用SQL的GROUP BY和COUNT进行分组统计
|
||||
statement = (
|
||||
select(
|
||||
SessionModel.user_persist_id,
|
||||
func.count(MessageRecord.id).label('message_count'),
|
||||
)
|
||||
.select_from(MessageRecord)
|
||||
.join(SessionModel, SessionModel.id == MessageRecord.session_persist_id)
|
||||
.where(*whereclause)
|
||||
.group_by(SessionModel.user_persist_id)
|
||||
)
|
||||
|
||||
async with get_session() as db_session:
|
||||
result = await db_session.execute(statement)
|
||||
# 转换为字典格式返回
|
||||
return {user_id: count for user_id, count in result.all()}
|
||||
|
Reference in New Issue
Block a user