mirror of
https://github.com/ChenXu233/nonebot_plugin_dialectlist.git
synced 2025-09-23 04:46:23 +00:00
⚡ 超级性能提升
This commit is contained in:
@ -16,7 +16,6 @@ from typing import Optional, Union
|
|||||||
import nonebot_plugin_saa as saa
|
import nonebot_plugin_saa as saa
|
||||||
from arclet.alconna import ArparmaBehavior
|
from arclet.alconna import ArparmaBehavior
|
||||||
from arclet.alconna.arparma import Arparma
|
from arclet.alconna.arparma import Arparma
|
||||||
from nonebot import on_command
|
|
||||||
from nonebot.adapters import Bot, Event
|
from nonebot.adapters import Bot, Event
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.params import Arg, Depends
|
from nonebot.params import Arg, Depends
|
||||||
@ -31,10 +30,10 @@ from nonebot_plugin_alconna import (
|
|||||||
Option,
|
Option,
|
||||||
on_alconna,
|
on_alconna,
|
||||||
)
|
)
|
||||||
from nonebot_plugin_chatrecorder import get_message_records
|
|
||||||
from nonebot_plugin_uninfo import Session, Uninfo, get_session
|
from nonebot_plugin_uninfo import Session, Uninfo, get_session
|
||||||
|
|
||||||
from .config import Config, plugin_config
|
from .config import Config, plugin_config
|
||||||
|
|
||||||
# from .storage import build_cache, get_cache
|
# from .storage import build_cache, get_cache
|
||||||
from .time import (
|
from .time import (
|
||||||
get_datetime_fromisoformat_with_timezone,
|
get_datetime_fromisoformat_with_timezone,
|
||||||
@ -46,8 +45,8 @@ from .utils import (
|
|||||||
get_rank_image,
|
get_rank_image,
|
||||||
get_user_infos,
|
get_user_infos,
|
||||||
got_rank,
|
got_rank,
|
||||||
msg_counter,
|
|
||||||
persist_id2user_id,
|
persist_id2user_id,
|
||||||
|
get_user_message_counts,
|
||||||
)
|
)
|
||||||
|
|
||||||
__plugin_meta__ = PluginMetadata(
|
__plugin_meta__ = PluginMetadata(
|
||||||
@ -80,17 +79,6 @@ def wrapper(slot: Union[int, str], content: Optional[str], context) -> str:
|
|||||||
return '' # pragma: no cover
|
return '' # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
build_cache_cmd = on_command('build_cache', aliases={'重建缓存'}, block=True)
|
|
||||||
|
|
||||||
|
|
||||||
@build_cache_cmd.handle()
|
|
||||||
async def _build_cache(bot: Bot, event: Event):
|
|
||||||
return
|
|
||||||
await saa.Text('正在重建缓存,请稍等。').send(reply=True) # type: ignore
|
|
||||||
await build_cache()
|
|
||||||
await saa.Text('重建缓存完成。').send(reply=True)
|
|
||||||
|
|
||||||
|
|
||||||
b_cmd = on_alconna(
|
b_cmd = on_alconna(
|
||||||
Alconna(
|
Alconna(
|
||||||
'看看B话',
|
'看看B话',
|
||||||
@ -122,18 +110,16 @@ async def handle_b_cmd(
|
|||||||
if not gid:
|
if not gid:
|
||||||
await b_cmd.finish('请指定群号。')
|
await b_cmd.finish('请指定群号。')
|
||||||
|
|
||||||
if keyword.available:
|
_keyword = keyword.result
|
||||||
keywords = keyword.result
|
|
||||||
else:
|
|
||||||
keywords = None
|
|
||||||
|
|
||||||
messages = await get_message_records(
|
d = await get_user_message_counts(
|
||||||
|
keyword=_keyword,
|
||||||
scene_ids=[gid],
|
scene_ids=[gid],
|
||||||
user_ids=[id],
|
user_ids=[id],
|
||||||
types=['message'], # 排除机器人自己发的消息
|
types=['message'], # 排除机器人自己发的消息
|
||||||
exclude_user_ids=plugin_config.excluded_people,
|
exclude_user_ids=plugin_config.excluded_people,
|
||||||
)
|
)
|
||||||
d = msg_counter(messages, keywords)
|
|
||||||
rank = got_rank(d)
|
rank = got_rank(d)
|
||||||
if not rank:
|
if not rank:
|
||||||
await b_cmd.finish(
|
await b_cmd.finish(
|
||||||
@ -302,23 +288,15 @@ async def handle_rank(
|
|||||||
|
|
||||||
keyword = state['keyword']
|
keyword = state['keyword']
|
||||||
|
|
||||||
if plugin_config.counting_cache:
|
|
||||||
await saa.Text("缓存暂不支持").finish()
|
|
||||||
# if keyword:
|
|
||||||
# await saa.Text('已开启缓存~缓存不支持关键词查询哦').finish()
|
|
||||||
# t1 = t.time()
|
|
||||||
# raw_rank = await get_cache(start, stop, id)
|
|
||||||
# logger.debug(f'获取计数消息花费时间:{t.time() - t1}')
|
|
||||||
else:
|
|
||||||
t1 = t.time()
|
t1 = t.time()
|
||||||
messages = await get_message_records(
|
raw_rank = await get_user_message_counts(
|
||||||
|
keyword=keyword,
|
||||||
scene_ids=[id],
|
scene_ids=[id],
|
||||||
types=['message'], # 排除机器人自己发的消息
|
types=['message'], # 排除机器人自己发的消息
|
||||||
time_start=start,
|
time_start=start,
|
||||||
time_stop=stop,
|
time_stop=stop,
|
||||||
exclude_user_ids=plugin_config.excluded_people,
|
exclude_user_ids=plugin_config.excluded_people,
|
||||||
)
|
)
|
||||||
raw_rank = msg_counter(messages, keyword)
|
|
||||||
logger.debug(f'获取计数消息花费时间:{t.time() - t1}')
|
logger.debug(f'获取计数消息花费时间:{t.time() - t1}')
|
||||||
|
|
||||||
if not raw_rank:
|
if not raw_rank:
|
||||||
|
@ -6,19 +6,20 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
class ScopedConfig(BaseModel):
|
class ScopedConfig(BaseModel):
|
||||||
get_num: int = 5 # 获取人数数量
|
get_num: int = 5 # 获取人数数量
|
||||||
font: str = "SimHei" # 字体格式
|
font: str = 'SimHei' # 字体格式
|
||||||
suffix: bool = True # 是否显示后缀
|
suffix: bool = True # 是否显示后缀
|
||||||
excluded_self: bool = True # 是否排除自己
|
excluded_self: bool = True # 是否排除自己
|
||||||
visualization: bool = True # 是否可视化
|
visualization: bool = True # 是否可视化
|
||||||
show_text_rank: bool = True # 是否显示文本排名
|
show_text_rank: bool = True # 是否显示文本排名
|
||||||
counting_cache: bool = False # 计数缓存(能够提高回复速度)
|
|
||||||
excluded_people: List[str] = [] # 排除的人的QQ号
|
excluded_people: List[str] = [] # 排除的人的QQ号
|
||||||
use_user_info_cache: bool = False # 是否使用用户信息缓存
|
use_user_info_cache: bool = False # 是否使用用户信息缓存
|
||||||
aggregate_transmission: bool = False # 是否聚合转发消息
|
aggregate_transmission: bool = False # 是否聚合转发消息
|
||||||
timezone: Optional[str] = "Asia/Shanghai" # 时区,影响统计时间
|
timezone: Optional[str] = 'Asia/Shanghai' # 时区,影响统计时间
|
||||||
string_suffix: str = "统计花费时间:{timecost}秒" # 消息格式后缀
|
string_suffix: str = '统计花费时间:{timecost}秒' # 消息格式后缀
|
||||||
template_path: str = "./template/rank_template.j2" # 模板路径
|
template_path: str = './template/rank_template.j2' # 模板路径
|
||||||
string_format: str = "第{index}名:\n{nickname},{chatdatanum}条消息\n" # 消息格式
|
string_format: str = (
|
||||||
|
'第{index}名:\n{nickname},{chatdatanum}条消息\n' # 消息格式
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
class Config(BaseModel):
|
||||||
|
@ -1,47 +0,0 @@
|
|||||||
"""update dialectlist
|
|
||||||
|
|
||||||
迁移 ID: fb88e4d27eb8
|
|
||||||
父迁移: 60daff81fcdc
|
|
||||||
创建时间: 2025-04-20 11:00:50.931679
|
|
||||||
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
revision: str = 'fb88e4d27eb8'
|
|
||||||
down_revision: str | Sequence[str] | None = '60daff81fcdc'
|
|
||||||
branch_labels: str | Sequence[str] | None = None
|
|
||||||
depends_on: str | Sequence[str] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade(name: str = "") -> None:
|
|
||||||
if name:
|
|
||||||
return
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.create_table('sessionmodel',
|
|
||||||
sa.Column('id', sa.Integer(), nullable=False),
|
|
||||||
sa.Column('bot_id', sa.String(length=64), nullable=False),
|
|
||||||
sa.Column('bot_type', sa.String(length=32), nullable=False),
|
|
||||||
sa.Column('platform', sa.String(length=32), nullable=False),
|
|
||||||
sa.Column('level', sa.Integer(), nullable=False),
|
|
||||||
sa.Column('id1', sa.String(length=64), nullable=False),
|
|
||||||
sa.Column('id2', sa.String(length=64), nullable=False),
|
|
||||||
sa.Column('id3', sa.String(length=64), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint('id', name=op.f('pk_sessionmodel')),
|
|
||||||
sa.UniqueConstraint('bot_id', 'bot_type', 'platform', 'level', 'id1', 'id2', 'id3', name='unique_session'),
|
|
||||||
info={'bind_key': ''}
|
|
||||||
)
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade(name: str = "") -> None:
|
|
||||||
if name:
|
|
||||||
return
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.drop_table('sessionmodel')
|
|
||||||
# ### end Alembic commands ###
|
|
@ -1,23 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from nonebot_plugin_orm import Model
|
|
||||||
from nonebot_plugin_userinfo import UserInfo
|
|
||||||
from sqlalchemy import Integer
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
|
||||||
|
|
||||||
|
|
||||||
class UserRankInfo(UserInfo):
|
|
||||||
user_bnum: int
|
|
||||||
user_proportion: float
|
|
||||||
user_nickname: str
|
|
||||||
user_index: Union[int, str]
|
|
||||||
user_avatar_bytes: bytes
|
|
||||||
|
|
||||||
|
|
||||||
class MessageCountCache(Model):
|
|
||||||
__table_args__ = {"extend_existing": True}
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
|
||||||
time: Mapped[datetime]
|
|
||||||
session_id: Mapped[int] = mapped_column(Integer, index=True)
|
|
||||||
session_bnum: Mapped[int] = mapped_column(Integer)
|
|
11
nonebot_plugin_dialectlist/schema.py
Normal file
11
nonebot_plugin_dialectlist/schema.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from nonebot_plugin_userinfo import UserInfo
|
||||||
|
|
||||||
|
|
||||||
|
class UserRankInfo(UserInfo):
|
||||||
|
user_bnum: int
|
||||||
|
user_proportion: float
|
||||||
|
user_nickname: str
|
||||||
|
user_index: Union[int, str]
|
||||||
|
user_avatar_bytes: bytes
|
@ -1,136 +0,0 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from nonebot import get_driver
|
|
||||||
from nonebot.adapters import Bot, Event
|
|
||||||
from nonebot.log import logger
|
|
||||||
from nonebot.message import event_postprocessor
|
|
||||||
from nonebot.params import Depends
|
|
||||||
from nonebot_plugin_chatrecorder import get_message_records
|
|
||||||
from nonebot_plugin_chatrecorder.utils import remove_timezone
|
|
||||||
from nonebot_plugin_uninfo import Session, get_session
|
|
||||||
from nonebot_plugin_uninfo.orm import SessionModel,get_session_persist_id
|
|
||||||
from nonebot_plugin_localstore import get_data_file
|
|
||||||
from nonebot_plugin_orm import get_session
|
|
||||||
from sqlalchemy import delete, or_, select
|
|
||||||
|
|
||||||
from .config import plugin_config
|
|
||||||
from .model import MessageCountCache
|
|
||||||
|
|
||||||
|
|
||||||
async def get_cache(time_start: datetime, time_stop: datetime, group_id: str):
|
|
||||||
async with get_session() as db_session:
|
|
||||||
where = [or_(SessionModel.id2 == group_id)]
|
|
||||||
statement = select(SessionModel).where(*where)
|
|
||||||
|
|
||||||
sessions = (await db_session.scalars(statement)).all()
|
|
||||||
|
|
||||||
where = [
|
|
||||||
or_(*[MessageCountCache.session_id == session.id for session in sessions])
|
|
||||||
]
|
|
||||||
statement = select(MessageCountCache).where(*where)
|
|
||||||
where.append(or_(MessageCountCache.time >= remove_timezone(time_start)))
|
|
||||||
where.append(or_(MessageCountCache.time <= remove_timezone(time_stop)))
|
|
||||||
statement = select(MessageCountCache).where(*where)
|
|
||||||
|
|
||||||
user_caches = (await db_session.scalars(statement)).all()
|
|
||||||
raw_rank = {}
|
|
||||||
for i in user_caches:
|
|
||||||
raw_rank[i.session_id] = raw_rank.get(i.session_id, 0) + i.session_bnum
|
|
||||||
return raw_rank
|
|
||||||
|
|
||||||
|
|
||||||
async def build_cache():
|
|
||||||
async with get_session() as db_session:
|
|
||||||
await db_session.execute(delete(MessageCountCache))
|
|
||||||
await db_session.commit()
|
|
||||||
logger.info("先前可能存在的缓存已清空")
|
|
||||||
messages = await get_message_records(types=["message"])
|
|
||||||
async with get_session() as db_session:
|
|
||||||
for msg in messages:
|
|
||||||
msg_session_id = msg.session_persist_id
|
|
||||||
|
|
||||||
where = [or_(MessageCountCache.session_id == msg_session_id)]
|
|
||||||
where.append(
|
|
||||||
or_(
|
|
||||||
MessageCountCache.time
|
|
||||||
== remove_timezone(
|
|
||||||
msg.time.replace(hour=1, minute=0, second=0, microsecond=0)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
statement = select(MessageCountCache).where(*where)
|
|
||||||
|
|
||||||
user_cache = (await db_session.scalars(statement)).all()
|
|
||||||
|
|
||||||
if user_cache:
|
|
||||||
user_cache[0].session_bnum += 1
|
|
||||||
else:
|
|
||||||
user_cache = MessageCountCache(
|
|
||||||
session_id=msg.session_persist_id,
|
|
||||||
time=remove_timezone(
|
|
||||||
msg.time.replace(hour=1, minute=0, second=0, microsecond=0)
|
|
||||||
),
|
|
||||||
session_bnum=1,
|
|
||||||
)
|
|
||||||
db_session.add(user_cache)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
logger.info("缓存构建完成")
|
|
||||||
|
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
|
|
||||||
|
|
||||||
@driver.on_startup
|
|
||||||
async def _():
|
|
||||||
if not plugin_config.counting_cache:
|
|
||||||
return
|
|
||||||
f_name = get_data_file("nonebot-plugin-dialectlist", "is-pre-cached.json")
|
|
||||||
if not os.path.exists(f_name):
|
|
||||||
with open(f_name, "w", encoding="utf-8") as f:
|
|
||||||
s = json.dumps({"is-pre-cached": False}, ensure_ascii=False, indent=4)
|
|
||||||
f.write(s)
|
|
||||||
|
|
||||||
with open(f_name, "r", encoding="utf-8") as f:
|
|
||||||
if json.load(f)["is-pre-cached"]:
|
|
||||||
return
|
|
||||||
logger.info("未检查到缓存,开始构建缓存")
|
|
||||||
with open(f_name, "w", encoding="utf-8") as f:
|
|
||||||
await build_cache()
|
|
||||||
json.dump({"is-pre-cached": True}, f, ensure_ascii=False, indent=4)
|
|
||||||
|
|
||||||
|
|
||||||
@event_postprocessor
|
|
||||||
async def _(bot: Bot, event: Event, session: Session = Depends(extract_session)):
|
|
||||||
if not plugin_config.counting_cache:
|
|
||||||
return
|
|
||||||
if not session.id2:
|
|
||||||
return
|
|
||||||
if event.get_type() != "message":
|
|
||||||
return
|
|
||||||
now = datetime.now()
|
|
||||||
now = now.replace(hour=1, minute=0, second=0, microsecond=0)
|
|
||||||
|
|
||||||
async with get_session() as db_session:
|
|
||||||
session_id = await get_session_persist_id(session)
|
|
||||||
logger.debug("session_id:" + str(session_id))
|
|
||||||
where = [or_(MessageCountCache.session_id == session_id)]
|
|
||||||
where.append(or_(MessageCountCache.time == remove_timezone(now)))
|
|
||||||
statement = select(MessageCountCache).where(*where)
|
|
||||||
user_cache = (await db_session.scalars(statement)).first()
|
|
||||||
if user_cache:
|
|
||||||
user_cache.session_bnum += 1
|
|
||||||
else:
|
|
||||||
user_cache = MessageCountCache(
|
|
||||||
session_id=session_id,
|
|
||||||
time=remove_timezone(now),
|
|
||||||
session_bnum=1,
|
|
||||||
)
|
|
||||||
db_session.add(user_cache)
|
|
||||||
await db_session.commit()
|
|
||||||
logger.debug("已计入缓存")
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: 修复缓存储存
|
|
@ -2,9 +2,11 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import unicodedata
|
import unicodedata
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
|
||||||
from nonebot.adapters import Bot, Event
|
from nonebot.adapters import Bot, Event
|
||||||
from nonebot.compat import model_dump
|
from nonebot.compat import model_dump
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
@ -20,10 +22,11 @@ from nonebot_plugin_uninfo.model import SceneType
|
|||||||
from nonebot_plugin_uninfo.orm import SessionModel, UserModel
|
from nonebot_plugin_uninfo.orm import SessionModel, UserModel
|
||||||
from nonebot_plugin_uninfo import get_session as extract_session
|
from nonebot_plugin_uninfo import get_session as extract_session
|
||||||
from nonebot_plugin_userinfo.exception import NetworkError
|
from nonebot_plugin_userinfo.exception import NetworkError
|
||||||
from sqlalchemy import or_, select
|
from nonebot_plugin_chatrecorder.record import filter_statement
|
||||||
|
|
||||||
|
|
||||||
from .config import plugin_config
|
from .config import plugin_config
|
||||||
from .model import UserRankInfo
|
from .schema import UserRankInfo
|
||||||
|
|
||||||
cache_path = get_cache_dir('nonebot_plugin_dialectlist')
|
cache_path = get_cache_dir('nonebot_plugin_dialectlist')
|
||||||
|
|
||||||
@ -38,71 +41,35 @@ async def ensure_group(
|
|||||||
|
|
||||||
async def persist_id2user_id(ids: List) -> List[str]:
|
async def persist_id2user_id(ids: List) -> List[str]:
|
||||||
user_ids = []
|
user_ids = []
|
||||||
user_persist_ids = []
|
if not ids:
|
||||||
async with get_session() as db_session:
|
|
||||||
for i in ids:
|
|
||||||
session = await db_session.scalar(
|
|
||||||
select(SessionModel).where(or_(*[SessionModel.id == i]))
|
|
||||||
)
|
|
||||||
if session is not None:
|
|
||||||
user_persist_id = session.user_persist_id
|
|
||||||
user_persist_ids.append(user_persist_id)
|
|
||||||
for i in user_persist_ids:
|
|
||||||
user = await db_session.scalar(
|
|
||||||
select(UserModel).where(UserModel.id == i)
|
|
||||||
)
|
|
||||||
if user is not None:
|
|
||||||
user_ids.append(user.user_id)
|
|
||||||
|
|
||||||
return user_ids
|
return user_ids
|
||||||
|
|
||||||
|
async with get_session() as db_session:
|
||||||
|
statement = (
|
||||||
|
select(UserModel.user_id)
|
||||||
|
.join(SessionModel, UserModel.id == SessionModel.user_persist_id)
|
||||||
|
.where(SessionModel.id.in_(ids))
|
||||||
|
)
|
||||||
|
result = await db_session.scalars(statement)
|
||||||
|
user_ids = result.all()
|
||||||
|
|
||||||
def msg_counter(
|
return list(user_ids)
|
||||||
msg_list: List[MessageRecord], keyword: Optional[str]
|
|
||||||
) -> Dict[str, int]:
|
|
||||||
"""### 计算每个人的消息量
|
|
||||||
|
|
||||||
Args:
|
|
||||||
msg_list (list[MessageRecord]): 需要处理的消息列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(dict[str,int]): 处理后的消息数量字典,键为用户,值为消息数量
|
|
||||||
"""
|
|
||||||
|
|
||||||
lst: Dict[str, int] = {}
|
|
||||||
msg_len = len(msg_list)
|
|
||||||
logger.info('wow , there are {} msgs to count !!!'.format(msg_len))
|
|
||||||
|
|
||||||
for i in msg_list:
|
|
||||||
# logger.debug(f"processing msg {i.plain_text}")
|
|
||||||
if keyword:
|
|
||||||
match = re.search(keyword, i.plain_text)
|
|
||||||
if not match:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
lst[str(i.session_persist_id)] += 1
|
|
||||||
except KeyError:
|
|
||||||
lst[str(i.session_persist_id)] = 1
|
|
||||||
|
|
||||||
logger.debug(f'finish counting, result is {lst}')
|
|
||||||
|
|
||||||
return lst
|
|
||||||
|
|
||||||
|
|
||||||
def got_rank(msg_dict: Dict[str, int]) -> List:
|
def got_rank(msg_dict: Dict[int, int]) -> List[List[Any]]:
|
||||||
"""### 获得排行榜
|
"""### 获得排行榜
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
msg_dict (Dict[str,int]): 要处理的字典
|
msg_dict (Dict[int,int]): 要处理的字典
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Tuple[str,int]]: 排行榜列表(已排序)
|
List[Tuple[int,int]]: 排行榜列表(已排序)
|
||||||
"""
|
"""
|
||||||
rank = []
|
rank = []
|
||||||
while len(rank) < plugin_config.get_num:
|
while len(rank) < plugin_config.get_num:
|
||||||
try:
|
try:
|
||||||
max_key = max(msg_dict.items(), key=lambda x: x[1])
|
max_key = max(msg_dict.items(), key=lambda x: x[1])
|
||||||
rank.append(list(max_key))
|
rank.append(tuple(max_key))
|
||||||
msg_dict.pop(max_key[0])
|
msg_dict.pop(max_key[0])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.error(
|
logger.error(
|
||||||
@ -261,3 +228,42 @@ async def get_user_infos(
|
|||||||
rank2.append(user)
|
rank2.append(user)
|
||||||
|
|
||||||
return rank2
|
return rank2
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_message_counts(
|
||||||
|
keyword: Optional[str] = None, **kwargs
|
||||||
|
) -> Dict[int, int]:
|
||||||
|
"""获取每个用户的消息数量(直接在数据库层面统计)
|
||||||
|
|
||||||
|
参数:
|
||||||
|
* ``keyword``: 可选,关键词,只统计包含该关键词的消息
|
||||||
|
* ``**kwargs``: 筛选参数,具体查看 `filter_statement` 中的定义
|
||||||
|
|
||||||
|
返回值:
|
||||||
|
* ``Dict[str, int]``: 键为user_persist_id,值为该用户的消息数量
|
||||||
|
"""
|
||||||
|
whereclause = filter_statement(**kwargs)
|
||||||
|
|
||||||
|
# 如果提供了关键词,添加关键词过滤条件
|
||||||
|
if keyword:
|
||||||
|
# 构造LIKE条件,类似于msg_counter函数中的正则匹配
|
||||||
|
# 根据数据库类型不同,可能需要调整LIKE的语法
|
||||||
|
keyword_condition = MessageRecord.plain_text.ilike(f'%{keyword}%')
|
||||||
|
whereclause.append(keyword_condition)
|
||||||
|
|
||||||
|
# 使用SQL的GROUP BY和COUNT进行分组统计
|
||||||
|
statement = (
|
||||||
|
select(
|
||||||
|
SessionModel.user_persist_id,
|
||||||
|
func.count(MessageRecord.id).label('message_count'),
|
||||||
|
)
|
||||||
|
.select_from(MessageRecord)
|
||||||
|
.join(SessionModel, SessionModel.id == MessageRecord.session_persist_id)
|
||||||
|
.where(*whereclause)
|
||||||
|
.group_by(SessionModel.user_persist_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
async with get_session() as db_session:
|
||||||
|
result = await db_session.execute(statement)
|
||||||
|
# 转换为字典格式返回
|
||||||
|
return {user_id: count for user_id, count in result.all()}
|
||||||
|
Reference in New Issue
Block a user