超级性能提升

This commit is contained in:
XuChenXu
2025-09-22 21:26:19 +08:00
parent eee9ac2ab3
commit 152128391c
8 changed files with 105 additions and 315 deletions

View File

@ -16,7 +16,6 @@ from typing import Optional, Union
import nonebot_plugin_saa as saa
from arclet.alconna import ArparmaBehavior
from arclet.alconna.arparma import Arparma
from nonebot import on_command
from nonebot.adapters import Bot, Event
from nonebot.log import logger
from nonebot.params import Arg, Depends
@ -31,10 +30,10 @@ from nonebot_plugin_alconna import (
Option,
on_alconna,
)
from nonebot_plugin_chatrecorder import get_message_records
from nonebot_plugin_uninfo import Session, Uninfo, get_session
from .config import Config, plugin_config
# from .storage import build_cache, get_cache
from .time import (
get_datetime_fromisoformat_with_timezone,
@ -46,8 +45,8 @@ from .utils import (
get_rank_image,
get_user_infos,
got_rank,
msg_counter,
persist_id2user_id,
get_user_message_counts,
)
__plugin_meta__ = PluginMetadata(
@ -80,17 +79,6 @@ def wrapper(slot: Union[int, str], content: Optional[str], context) -> str:
return '' # pragma: no cover
build_cache_cmd = on_command('build_cache', aliases={'重建缓存'}, block=True)


@build_cache_cmd.handle()
async def _build_cache(bot: Bot, event: Event):
    # NOTE(review): this early return disables the command entirely; the
    # three statements below are unreachable dead code. Confirm whether the
    # handler should be deleted outright or the return removed.
    return
    await saa.Text('正在重建缓存,请稍等。').send(reply=True)  # type: ignore
    await build_cache()
    await saa.Text('重建缓存完成。').send(reply=True)
b_cmd = on_alconna(
Alconna(
'看看B话',
@ -122,18 +110,16 @@ async def handle_b_cmd(
if not gid:
await b_cmd.finish('请指定群号。')
if keyword.available:
keywords = keyword.result
else:
keywords = None
_keyword = keyword.result
messages = await get_message_records(
d = await get_user_message_counts(
keyword=_keyword,
scene_ids=[gid],
user_ids=[id],
types=['message'], # 排除机器人自己发的消息
exclude_user_ids=plugin_config.excluded_people,
)
d = msg_counter(messages, keywords)
rank = got_rank(d)
if not rank:
await b_cmd.finish(
@ -302,24 +288,16 @@ async def handle_rank(
keyword = state['keyword']
if plugin_config.counting_cache:
await saa.Text("缓存暂不支持").finish()
# if keyword:
# await saa.Text('已开启缓存~缓存不支持关键词查询哦').finish()
# t1 = t.time()
# raw_rank = await get_cache(start, stop, id)
# logger.debug(f'获取计数消息花费时间:{t.time() - t1}')
else:
t1 = t.time()
messages = await get_message_records(
scene_ids=[id],
types=['message'], # 排除机器人自己发的消息
time_start=start,
time_stop=stop,
exclude_user_ids=plugin_config.excluded_people,
)
raw_rank = msg_counter(messages, keyword)
logger.debug(f'获取计数消息花费时间:{t.time() - t1}')
t1 = t.time()
raw_rank = await get_user_message_counts(
keyword=keyword,
scene_ids=[id],
types=['message'], # 排除机器人自己发的消息
time_start=start,
time_stop=stop,
exclude_user_ids=plugin_config.excluded_people,
)
logger.debug(f'获取计数消息花费时间:{t.time() - t1}')
if not raw_rank:
await saa.Text(

View File

@ -5,24 +5,25 @@ from pydantic import BaseModel
class ScopedConfig(BaseModel):
get_num: int = 5 # 获取人数数量
font: str = "SimHei" # 字体格式
suffix: bool = True # 是否显示后缀
excluded_self: bool = True # 是否排除自己
visualization: bool = True # 是否可视化
show_text_rank: bool = True # 是否显示文本排名
counting_cache: bool = False # 计数缓存(能够提高回复速度)
excluded_people: List[str] = [] # 排除的人的QQ号
use_user_info_cache: bool = False # 是否使用用户信息缓存
aggregate_transmission: bool = False # 是否聚合转发消息
timezone: Optional[str] = "Asia/Shanghai" # 时区,影响统计时间
string_suffix: str = "统计花费时间:{timecost}" # 消息格式后缀
template_path: str = "./template/rank_template.j2" # 模板路径
string_format: str = "{index}名:\n{nickname},{chatdatanum}条消息\n" # 消息格式
get_num: int = 5 # 获取人数数量
font: str = 'SimHei' # 字体格式
suffix: bool = True # 是否显示后缀
excluded_self: bool = True # 是否排除自己
visualization: bool = True # 是否可视化
show_text_rank: bool = True # 是否显示文本排名
excluded_people: List[str] = [] # 排除的人的QQ号
use_user_info_cache: bool = False # 是否使用用户信息缓存
aggregate_transmission: bool = False # 是否聚合转发消息
timezone: Optional[str] = 'Asia/Shanghai' # 时区,影响统计时间
string_suffix: str = '统计花费时间:{timecost}' # 消息格式后缀
template_path: str = './template/rank_template.j2' # 模板路径
string_format: str = (
'{index}名:\n{nickname},{chatdatanum}条消息\n' # 消息格式
)
class Config(BaseModel):
dialectlist: ScopedConfig = ScopedConfig()
dialectlist: ScopedConfig = ScopedConfig()
global_config = get_driver().config

View File

@ -1,47 +0,0 @@
"""update dialectlist
迁移 ID: fb88e4d27eb8
父迁移: 60daff81fcdc
创建时间: 2025-04-20 11:00:50.931679
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
revision: str = 'fb88e4d27eb8'
down_revision: str | Sequence[str] | None = '60daff81fcdc'
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade(name: str = "") -> None:
    """Apply this migration: create the ``sessionmodel`` table.

    Alembic calls this once per database bind; ``name`` identifies the
    bind, and this migration only targets the default (unnamed) bind.
    """
    if name:
        # Skip named (non-default) binds.
        return
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('sessionmodel',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('bot_id', sa.String(length=64), nullable=False),
        sa.Column('bot_type', sa.String(length=32), nullable=False),
        sa.Column('platform', sa.String(length=32), nullable=False),
        sa.Column('level', sa.Integer(), nullable=False),
        sa.Column('id1', sa.String(length=64), nullable=False),
        sa.Column('id2', sa.String(length=64), nullable=False),
        sa.Column('id3', sa.String(length=64), nullable=False),
        sa.PrimaryKeyConstraint('id', name=op.f('pk_sessionmodel')),
        sa.UniqueConstraint('bot_id', 'bot_type', 'platform', 'level', 'id1', 'id2', 'id3', name='unique_session'),
        info={'bind_key': ''}
    )
    # ### end Alembic commands ###
def downgrade(name: str = "") -> None:
    """Revert this migration: drop the ``sessionmodel`` table.

    ``name`` identifies the database bind; only the default (unnamed)
    bind is affected, mirroring ``upgrade``.
    """
    if name:
        # Skip named (non-default) binds.
        return
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('sessionmodel')
    # ### end Alembic commands ###

View File

@ -1,23 +0,0 @@
from datetime import datetime
from typing import Union
from nonebot_plugin_orm import Model
from nonebot_plugin_userinfo import UserInfo
from sqlalchemy import Integer
from sqlalchemy.orm import Mapped, mapped_column
class UserRankInfo(UserInfo):
    """Per-user ranking entry extending ``UserInfo`` with chat statistics."""

    # Number of messages counted for this user.
    user_bnum: int
    # Presumably this user's share of the group total — TODO confirm against
    # the code that populates it.
    user_proportion: float
    # Display name rendered in the rank output.
    user_nickname: str
    # Rank position; may be an int or a placeholder string.
    user_index: Union[int, str]
    # Raw avatar image bytes used when rendering the rank image.
    user_avatar_bytes: bytes
class MessageCountCache(Model):
    """ORM row caching the message count of one session for one time bucket."""

    # Allow re-registration of the table if the module is imported twice.
    __table_args__ = {"extend_existing": True}

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # Time bucket the count belongs to; stored timezone-naive (callers pass
    # values through remove_timezone) — confirm against the cache writers.
    time: Mapped[datetime]
    # Persisted session id the count belongs to; indexed for lookups.
    session_id: Mapped[int] = mapped_column(Integer, index=True)
    # Number of messages accumulated in this bucket.
    session_bnum: Mapped[int] = mapped_column(Integer)

View File

@ -0,0 +1,11 @@
from typing import Union
from nonebot_plugin_userinfo import UserInfo
class UserRankInfo(UserInfo):
    """Per-user ranking entry extending ``UserInfo`` with chat statistics."""

    # Number of messages counted for this user.
    user_bnum: int
    # Presumably this user's share of the group total — TODO confirm against
    # the code that populates it.
    user_proportion: float
    # Display name rendered in the rank output.
    user_nickname: str
    # Rank position; may be an int or a placeholder string.
    user_index: Union[int, str]
    # Raw avatar image bytes used when rendering the rank image.
    user_avatar_bytes: bytes

View File

@ -1,136 +0,0 @@
import json
import os
from datetime import datetime
from nonebot import get_driver
from nonebot.adapters import Bot, Event
from nonebot.log import logger
from nonebot.message import event_postprocessor
from nonebot.params import Depends
from nonebot_plugin_chatrecorder import get_message_records
from nonebot_plugin_chatrecorder.utils import remove_timezone
from nonebot_plugin_uninfo import Session, get_session
from nonebot_plugin_uninfo.orm import SessionModel,get_session_persist_id
from nonebot_plugin_localstore import get_data_file
from nonebot_plugin_orm import get_session
from sqlalchemy import delete, or_, select
from .config import plugin_config
from .model import MessageCountCache
async def get_cache(time_start: datetime, time_stop: datetime, group_id: str):
    """Aggregate cached message counts for one group within a time window.

    Args:
        time_start: inclusive lower bound of the window.
        time_stop: inclusive upper bound of the window.
        group_id: scene id of the group (matched against ``SessionModel.id2``).

    Returns:
        Dict mapping ``session_id`` to its summed ``session_bnum``.
    """
    async with get_session() as db_session:
        # Resolve every stored session belonging to this group.
        sessions = (
            await db_session.scalars(
                select(SessionModel).where(SessionModel.id2 == group_id)
            )
        ).all()
        # Fetch cache rows for those sessions inside the window. The
        # original built (and discarded) an intermediate statement before
        # appending the time filters; build the final statement once.
        statement = select(MessageCountCache).where(
            or_(*[MessageCountCache.session_id == s.id for s in sessions]),
            MessageCountCache.time >= remove_timezone(time_start),
            MessageCountCache.time <= remove_timezone(time_stop),
        )
        user_caches = (await db_session.scalars(statement)).all()
        raw_rank = {}
        for cache in user_caches:
            raw_rank[cache.session_id] = (
                raw_rank.get(cache.session_id, 0) + cache.session_bnum
            )
        return raw_rank
async def build_cache():
    """Rebuild the per-session message-count cache from the full chat history.

    Deletes every existing ``MessageCountCache`` row, then replays all
    recorded ``message`` events and re-accumulates one cache row per
    (session, day-bucket) pair.

    NOTE(review): this issues one SELECT per historical message (an N+1
    pattern), so a full rebuild on a large history is slow — consider
    aggregating in memory and bulk-inserting instead.
    """
    async with get_session() as db_session:
        # Drop any stale cache rows before rebuilding.
        await db_session.execute(delete(MessageCountCache))
        await db_session.commit()
        logger.info("先前可能存在的缓存已清空")
    messages = await get_message_records(types=["message"])
    async with get_session() as db_session:
        for msg in messages:
            msg_session_id = msg.session_persist_id
            # Bucket each message into a per-day slot; hour is pinned to 1
            # rather than 0 — presumably a timezone workaround, TODO confirm.
            where = [or_(MessageCountCache.session_id == msg_session_id)]
            where.append(
                or_(
                    MessageCountCache.time
                    == remove_timezone(
                        msg.time.replace(hour=1, minute=0, second=0, microsecond=0)
                    )
                )
            )
            statement = select(MessageCountCache).where(*where)
            user_cache = (await db_session.scalars(statement)).all()
            if user_cache:
                # Existing row for this (session, day): bump its counter.
                user_cache[0].session_bnum += 1
            else:
                # First message in this bucket: create the row.
                user_cache = MessageCountCache(
                    session_id=msg.session_persist_id,
                    time=remove_timezone(
                        msg.time.replace(hour=1, minute=0, second=0, microsecond=0)
                    ),
                    session_bnum=1,
                )
                db_session.add(user_cache)
        await db_session.commit()
        logger.info("缓存构建完成")
driver = get_driver()


@driver.on_startup
async def _():
    """Build the message-count cache once, on first startup with caching on.

    A marker JSON file (``is-pre-cached.json``) records whether the cache
    was already built, so the expensive full rebuild runs only once.
    """
    if not plugin_config.counting_cache:
        return
    f_name = get_data_file("nonebot-plugin-dialectlist", "is-pre-cached.json")
    if not os.path.exists(f_name):
        # Initialize the marker file on first run.
        with open(f_name, "w", encoding="utf-8") as f:
            s = json.dumps({"is-pre-cached": False}, ensure_ascii=False, indent=4)
            f.write(s)
    with open(f_name, "r", encoding="utf-8") as f:
        if json.load(f)["is-pre-cached"]:
            return
    logger.info("未检查到缓存,开始构建缓存")
    # Build first, then write the marker. The original opened the file in
    # "w" mode (truncating it) BEFORE awaiting build_cache(), so a failed
    # or interrupted build left an empty/invalid JSON file that crashed
    # json.load on the next startup.
    await build_cache()
    with open(f_name, "w", encoding="utf-8") as f:
        json.dump({"is-pre-cached": True}, f, ensure_ascii=False, indent=4)
@event_postprocessor
async def _(bot: Bot, event: Event, session: Session = Depends(extract_session)):
    """Increment today's cached message count for the sender's session.

    Runs after every processed event; counts only ``message`` events that
    occur in a group scene, and only when ``counting_cache`` is enabled.

    NOTE(review): ``extract_session`` does not appear to be imported in
    this module, and the uninfo ``get_session`` import is shadowed by
    ``nonebot_plugin_orm.get_session`` — verify this handler resolves at
    import time.
    """
    if not plugin_config.counting_cache:
        return
    # id2 is the group/scene id (see get_cache); skip non-group sessions.
    if not session.id2:
        return
    if event.get_type() != "message":
        return
    # Bucket into the current day; hour pinned to 1 to match build_cache().
    now = datetime.now()
    now = now.replace(hour=1, minute=0, second=0, microsecond=0)
    async with get_session() as db_session:
        session_id = await get_session_persist_id(session)
        logger.debug("session_id:" + str(session_id))
        where = [or_(MessageCountCache.session_id == session_id)]
        where.append(or_(MessageCountCache.time == remove_timezone(now)))
        statement = select(MessageCountCache).where(*where)
        user_cache = (await db_session.scalars(statement)).first()
        if user_cache:
            # Existing row for today's bucket: bump the counter.
            user_cache.session_bnum += 1
        else:
            # First message today for this session: create the row.
            user_cache = MessageCountCache(
                session_id=session_id,
                time=remove_timezone(now),
                session_bnum=1,
            )
            db_session.add(user_cache)
        await db_session.commit()
        logger.debug("已计入缓存")
# TODO: 修复缓存储存

View File

@ -1,7 +1,7 @@
from inspect import cleandoc
__usage__ = cleandoc(
"""
"""
快速调用:
/今日B话榜 ————看看今天群友发了多少消息。
@ -18,7 +18,7 @@ __usage__ = cleandoc(
-`/前日B话榜` ————看看前天的群友发了多少消息!
-`/本周B话榜` ————看看本周的群友发了多少消息!
-`/上周B话榜` ————看看上周的群友发了多少消息!
-`/本月B话榜` ————看看这个月的群友发了多少消息!

View File

@ -2,9 +2,11 @@ import asyncio
import os
import re
import unicodedata
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Any
import httpx
from sqlalchemy import select, func
from nonebot.adapters import Bot, Event
from nonebot.compat import model_dump
from nonebot.log import logger
@ -20,10 +22,11 @@ from nonebot_plugin_uninfo.model import SceneType
from nonebot_plugin_uninfo.orm import SessionModel, UserModel
from nonebot_plugin_uninfo import get_session as extract_session
from nonebot_plugin_userinfo.exception import NetworkError
from sqlalchemy import or_, select
from nonebot_plugin_chatrecorder.record import filter_statement
from .config import plugin_config
from .model import UserRankInfo
from .schema import UserRankInfo
cache_path = get_cache_dir('nonebot_plugin_dialectlist')
@ -38,71 +41,35 @@ async def ensure_group(
async def persist_id2user_id(ids: List) -> List[str]:
user_ids = []
user_persist_ids = []
if not ids:
return user_ids
async with get_session() as db_session:
for i in ids:
session = await db_session.scalar(
select(SessionModel).where(or_(*[SessionModel.id == i]))
)
if session is not None:
user_persist_id = session.user_persist_id
user_persist_ids.append(user_persist_id)
for i in user_persist_ids:
user = await db_session.scalar(
select(UserModel).where(UserModel.id == i)
)
if user is not None:
user_ids.append(user.user_id)
statement = (
select(UserModel.user_id)
.join(SessionModel, UserModel.id == SessionModel.user_persist_id)
.where(SessionModel.id.in_(ids))
)
result = await db_session.scalars(statement)
user_ids = result.all()
return user_ids
return list(user_ids)
def msg_counter(
    msg_list: 'List[MessageRecord]', keyword: Optional[str]
) -> Dict[str, int]:
    """Count messages per session.

    Args:
        msg_list: message records to process.
        keyword: optional regex pattern; only messages whose plain text
            matches it (via ``re.search``) are counted.

    Returns:
        Mapping of ``session_persist_id`` (stringified) to message count.
        (The original docstring said the keys were users; they are in fact
        session persist ids.)
    """
    counts: Dict[str, int] = {}
    logger.info('wow , there are {} msgs to count !!!'.format(len(msg_list)))
    # Compile the pattern once instead of letting re.search re-scan the
    # pattern string on every message.
    pattern = re.compile(keyword) if keyword else None
    for record in msg_list:
        if pattern is not None and not pattern.search(record.plain_text):
            continue
        key = str(record.session_persist_id)
        # dict.get replaces the original per-item try/except KeyError,
        # which is slow when the exception actually fires (first message
        # of every session).
        counts[key] = counts.get(key, 0) + 1
    logger.debug(f'finish counting, result is {counts}')
    return counts
def got_rank(msg_dict: Dict[str, int]) -> List:
def got_rank(msg_dict: Dict[int, int]) -> List[List[Any]]:
"""### 获得排行榜
Args:
msg_dict (Dict[str,int]): 要处理的字典
msg_dict (Dict[int,int]): 要处理的字典
Returns:
List[Tuple[str,int]]: 排行榜列表(已排序)
List[Tuple[int,int]]: 排行榜列表(已排序)
"""
rank = []
while len(rank) < plugin_config.get_num:
try:
max_key = max(msg_dict.items(), key=lambda x: x[1])
rank.append(list(max_key))
rank.append(tuple(max_key))
msg_dict.pop(max_key[0])
except ValueError:
logger.error(
@ -261,3 +228,42 @@ async def get_user_infos(
rank2.append(user)
return rank2
async def get_user_message_counts(
    keyword: Optional[str] = None, **kwargs
) -> Dict[int, int]:
    """Count messages per user directly at the database level.

    Args:
        keyword: optional substring; only messages whose plain text
            contains it (case-insensitive ``ILIKE``) are counted.
        **kwargs: filter arguments forwarded to ``filter_statement``.

    Returns:
        Mapping of ``user_persist_id`` to that user's message count.
        (The original docstring said ``Dict[str, int]``; the keys are the
        integer persist ids selected below, matching the annotation.)
    """
    whereclause = filter_statement(**kwargs)
    if keyword:
        # Substring match in SQL replaces the old per-row regex in Python.
        # NOTE: ILIKE wildcards (% and _) in `keyword` are not escaped, so
        # a keyword containing them widens the match.
        whereclause.append(MessageRecord.plain_text.ilike(f'%{keyword}%'))
    # GROUP BY + COUNT so only (user, count) pairs cross the wire instead
    # of every message row.
    statement = (
        select(
            SessionModel.user_persist_id,
            func.count(MessageRecord.id).label('message_count'),
        )
        .select_from(MessageRecord)
        .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id)
        .where(*whereclause)
        .group_by(SessionModel.user_persist_id)
    )
    async with get_session() as db_session:
        result = await db_session.execute(statement)
        # Each row is a (user_persist_id, count) pair.
        return dict(result.all())