mirror of
				https://github.com/LiteyukiStudio/LiteyukiBot.git
				synced 2025-10-25 21:06:31 +00:00 
			
		
		
		
	🐛 fix 通道类回调函数在进程间传递时无法序列号的问题
This commit is contained in:
		
							
								
								
									
										29
									
								
								src/nonebot_plugins/liteyuki_statistics/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								src/nonebot_plugins/liteyuki_statistics/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| from nonebot.plugin import PluginMetadata | ||||
| from .stat_matchers import * | ||||
| from .stat_monitors import * | ||||
| from .stat_restful_api import * | ||||
|  | ||||
| __author__ = "snowykami" | ||||
| __plugin_meta__ = PluginMetadata( | ||||
|     name="统计信息", | ||||
|     description="统计机器人的信息,包括消息、群聊等,支持排名、图表等功能", | ||||
|     usage=( | ||||
|             "```\nstatistic message 查看统计消息\n" | ||||
|             "可选参数:\n" | ||||
|             "  -g|--group [group_id] 指定群聊\n" | ||||
|             "  -u|--user [user_id] 指定用户\n" | ||||
|             "  -d|--duration [duration] 指定时长\n" | ||||
|             "  -p|--period [period] 指定次数统计周期\n" | ||||
|             "  -b|--bot [bot_id] 指定机器人\n" | ||||
|             "命令别名:\n" | ||||
|             "  statistic|stat  message|msg|m\n" | ||||
|             "```" | ||||
|     ), | ||||
|     type="application", | ||||
|     homepage="https://github.com/snowykami/LiteyukiBot", | ||||
|     extra={ | ||||
|             "liteyuki"      : True, | ||||
|             "toggleable"    : False, | ||||
|             "default_enable": True, | ||||
|     } | ||||
| ) | ||||
							
								
								
									
										21
									
								
								src/nonebot_plugins/liteyuki_statistics/common.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								src/nonebot_plugins/liteyuki_statistics/common.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | ||||
| from src.utils.base.data import Database, LiteModel | ||||
|  | ||||
|  | ||||
| class MessageEventModel(LiteModel): | ||||
|     TABLE_NAME: str = "message_event" | ||||
|     time: int = 0 | ||||
|  | ||||
|     bot_id: str = "" | ||||
|     adapter: str = "" | ||||
|  | ||||
|     user_id: str = "" | ||||
|     group_id: str = "" | ||||
|  | ||||
|     message_id: str = "" | ||||
|     message: list = [] | ||||
|     message_text: str = "" | ||||
|     message_type: str = "" | ||||
|  | ||||
|  | ||||
| msg_db = Database("data/liteyuki/msg.ldb") | ||||
| msg_db.auto_migrate(MessageEventModel()) | ||||
							
								
								
									
										172
									
								
								src/nonebot_plugins/liteyuki_statistics/data_source.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										172
									
								
								src/nonebot_plugins/liteyuki_statistics/data_source.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,172 @@ | ||||
| import time | ||||
| from typing import Any | ||||
|  | ||||
| from collections import Counter | ||||
|  | ||||
| from nonebot import Bot | ||||
|  | ||||
| from src.utils.message.html_tool import template2image | ||||
| from .common import MessageEventModel, msg_db | ||||
| from src.utils.base.language import Language | ||||
| from src.utils.base.resource import get_path | ||||
| from src.utils.message.string_tool import convert_seconds_to_time | ||||
| from ...utils.external.logo import get_group_icon, get_user_icon | ||||
|  | ||||
|  | ||||
| async def count_msg_by_bot_id(bot_id: str) -> int: | ||||
|     condition = " AND bot_id = ?" | ||||
|     condition_args = [bot_id] | ||||
|  | ||||
|     msg_rows = msg_db.where_all( | ||||
|         MessageEventModel(), | ||||
|         condition, | ||||
|         *condition_args | ||||
|     ) | ||||
|  | ||||
|     return len(msg_rows) | ||||
|  | ||||
|  | ||||
| async def get_stat_msg_image( | ||||
|         duration: int, | ||||
|         period: int, | ||||
|         group_id: str = None, | ||||
|         bot_id: str = None, | ||||
|         user_id: str = None, | ||||
|         ulang: Language = Language() | ||||
| ) -> bytes: | ||||
|     """ | ||||
|     获取统计消息 | ||||
|     Args: | ||||
|         user_id: | ||||
|         ulang: | ||||
|         bot_id: | ||||
|         group_id: | ||||
|         duration: 统计时间,单位秒 | ||||
|         period: 统计周期,单位秒 | ||||
|  | ||||
|     Returns: | ||||
|         tuple: [int,], [int,] 两个列表,分别为周期中心时间戳和消息数量 | ||||
|     """ | ||||
|     now = int(time.time()) | ||||
|     start_time = (now - duration) | ||||
|  | ||||
|     condition = "time > ?" | ||||
|     condition_args = [start_time] | ||||
|  | ||||
|     if group_id: | ||||
|         condition += " AND group_id = ?" | ||||
|         condition_args.append(group_id) | ||||
|     if bot_id: | ||||
|         condition += " AND bot_id = ?" | ||||
|         condition_args.append(bot_id) | ||||
|  | ||||
|     if user_id: | ||||
|         condition += " AND user_id = ?" | ||||
|         condition_args.append(user_id) | ||||
|  | ||||
|     msg_rows = msg_db.where_all( | ||||
|         MessageEventModel(), | ||||
|         condition, | ||||
|         *condition_args | ||||
|     ) | ||||
|     timestamps = [] | ||||
|     msg_count = [] | ||||
|     msg_rows.sort(key=lambda x: x.time) | ||||
|  | ||||
|     start_time = max(msg_rows[0].time, start_time) | ||||
|  | ||||
|     for i in range(start_time, now, period): | ||||
|         timestamps.append(i + period // 2) | ||||
|         msg_count.append(0) | ||||
|  | ||||
|     for msg in msg_rows: | ||||
|         period_start_time = start_time + (msg.time - start_time) // period * period | ||||
|         period_center_time = period_start_time + period // 2 | ||||
|         index = timestamps.index(period_center_time) | ||||
|         msg_count[index] += 1 | ||||
|  | ||||
|     templates = { | ||||
|             "data": [ | ||||
|                     { | ||||
|                             "name"  : ulang.get("stat.message") | ||||
|                                       + f"    Period {convert_seconds_to_time(period)}" + f"    Duration {convert_seconds_to_time(duration)}" | ||||
|                                       + (f"    Group {group_id}" if group_id else "") + (f"    Bot {bot_id}" if bot_id else "") + ( | ||||
|                                               f"    User {user_id}" if user_id else ""), | ||||
|                             "times" : timestamps, | ||||
|                             "counts": msg_count | ||||
|                     } | ||||
|             ] | ||||
|     } | ||||
|  | ||||
|     return await template2image(get_path("templates/stat_msg.html"), templates) | ||||
|  | ||||
|  | ||||
| async def get_stat_rank_image( | ||||
|         rank_type: str, | ||||
|         limit: dict[str, Any], | ||||
|         ulang: Language = Language(), | ||||
|         bot: Bot = None, | ||||
| ) -> bytes: | ||||
|     if rank_type == "user": | ||||
|         condition = "user_id != ''" | ||||
|         condition_args = [] | ||||
|     else: | ||||
|         condition = "group_id != ''" | ||||
|         condition_args = [] | ||||
|  | ||||
|     for k, v in limit.items(): | ||||
|         match k: | ||||
|             case "user_id": | ||||
|                 condition += " AND user_id = ?" | ||||
|                 condition_args.append(v) | ||||
|             case "group_id": | ||||
|                 condition += " AND group_id = ?" | ||||
|                 condition_args.append(v) | ||||
|             case "bot_id": | ||||
|                 condition += " AND bot_id = ?" | ||||
|                 condition_args.append(v) | ||||
|             case "duration": | ||||
|                 condition += " AND time > ?" | ||||
|                 condition_args.append(v) | ||||
|  | ||||
|     msg_rows = msg_db.where_all( | ||||
|         MessageEventModel(), | ||||
|         condition, | ||||
|         *condition_args | ||||
|     ) | ||||
|  | ||||
|     """ | ||||
|         { | ||||
|             name: string,   # user name or group name | ||||
|             count: int,     # message count | ||||
|             icon: string    # icon url | ||||
|         } | ||||
|     """ | ||||
|  | ||||
|     if rank_type == "user": | ||||
|         ranking_counter = Counter([msg.user_id for msg in msg_rows]) | ||||
|     else: | ||||
|         ranking_counter = Counter([msg.group_id for msg in msg_rows]) | ||||
|     sorted_data = sorted(ranking_counter.items(), key=lambda x: x[1], reverse=True) | ||||
|  | ||||
|     ranking: list[dict[str, Any]] = [ | ||||
|             { | ||||
|                     "name" : _[0], | ||||
|                     "count": _[1], | ||||
|                     "icon" : await (get_group_icon(platform="qq", group_id=_[0]) if rank_type == "group" else get_user_icon( | ||||
|                         platform="qq", user_id=_[0] | ||||
|                     )) | ||||
|             } | ||||
|             for _ in sorted_data[0:min(len(sorted_data), limit["rank"])] | ||||
|     ] | ||||
|  | ||||
|     templates = { | ||||
|             "data": | ||||
|                 { | ||||
|                         "name"   : ulang.get("stat.rank") + f"    Type {rank_type}" + f"    Limit {limit}", | ||||
|                         "ranking": ranking | ||||
|                 } | ||||
|  | ||||
|     } | ||||
|  | ||||
|     return await template2image(get_path("templates/stat_rank.html"), templates, debug=True) | ||||
							
								
								
									
										134
									
								
								src/nonebot_plugins/liteyuki_statistics/stat_matchers.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										134
									
								
								src/nonebot_plugins/liteyuki_statistics/stat_matchers.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,134 @@ | ||||
| from nonebot import Bot, require | ||||
| from src.utils.message.string_tool import convert_duration, convert_time_to_seconds | ||||
| from .data_source import * | ||||
| from src.utils import event as event_utils | ||||
| from src.utils.base.language import Language | ||||
| from src.utils.base.ly_typing import T_MessageEvent | ||||
|  | ||||
| require("nonebot_plugin_alconna") | ||||
|  | ||||
| from nonebot_plugin_alconna import ( | ||||
|     UniMessage, | ||||
|     on_alconna, | ||||
|     Alconna, | ||||
|     Args, | ||||
|     Subcommand, | ||||
|     Arparma, | ||||
|     Option, | ||||
|     MultiVar | ||||
| ) | ||||
|  | ||||
| stat_msg = on_alconna( | ||||
|     Alconna( | ||||
|         "statistic", | ||||
|         Subcommand( | ||||
|             "message", | ||||
|             # Args["duration", str, "2d"]["period", str, "60s"],  # 默认为1天 | ||||
|             Option( | ||||
|                 "-d|--duration", | ||||
|                 Args["duration", str, "2d"], | ||||
|                 help_text="统计时间", | ||||
|             ), | ||||
|             Option( | ||||
|                 "-p|--period", | ||||
|                 Args["period", str, "60s"], | ||||
|                 help_text="统计周期", | ||||
|             ), | ||||
|             Option( | ||||
|                 "-b|--bot",  # 生成图表 | ||||
|                 Args["bot_id", str, "current"], | ||||
|                 help_text="是否指定机器人", | ||||
|             ), | ||||
|             Option( | ||||
|                 "-g|--group", | ||||
|                 Args["group_id", str, "current"], | ||||
|                 help_text="指定群组" | ||||
|             ), | ||||
|             Option( | ||||
|                 "-u|--user", | ||||
|                 Args["user_id", str, "current"], | ||||
|                 help_text="指定用户" | ||||
|             ), | ||||
|             alias={"msg", "m"}, | ||||
|             help_text="查看统计次数内的消息" | ||||
|         ), | ||||
|         Subcommand( | ||||
|             "rank", | ||||
|             Option( | ||||
|                 "-u|--user", | ||||
|                 help_text="以用户为指标", | ||||
|             ), | ||||
|             Option( | ||||
|                 "-g|--group", | ||||
|                 help_text="以群组为指标", | ||||
|             ), | ||||
|             Option( | ||||
|                 "-l|--limit", | ||||
|                 Args["limit", MultiVar(str)], | ||||
|                 help_text="限制参数,使用key=val格式", | ||||
|             ), | ||||
|             Option( | ||||
|                 "-d|--duration", | ||||
|                 Args["duration", str, "1d"], | ||||
|                 help_text="统计时间", | ||||
|             ), | ||||
|             Option( | ||||
|                 "-r|--rank", | ||||
|                 Args["rank", int, 20], | ||||
|                 help_text="指定排名", | ||||
|             ), | ||||
|             alias={"r"}, | ||||
|         ) | ||||
|     ), | ||||
|     aliases={"stat"} | ||||
| ) | ||||
|  | ||||
|  | ||||
| @stat_msg.assign("message") | ||||
| async def _(result: Arparma, event: T_MessageEvent, bot: Bot): | ||||
|     ulang = Language(event_utils.get_user_id(event)) | ||||
|     try: | ||||
|         duration = convert_time_to_seconds(result.other_args.get("duration", "2d"))  # 秒数 | ||||
|         period = convert_time_to_seconds(result.other_args.get("period", "1m")) | ||||
|     except BaseException as e: | ||||
|         await stat_msg.send(ulang.get("liteyuki.invalid_command", TEXT=str(e.__str__()))) | ||||
|         return | ||||
|  | ||||
|     group_id = result.other_args.get("group_id") | ||||
|     bot_id = result.other_args.get("bot_id") | ||||
|     user_id = result.other_args.get("user_id") | ||||
|  | ||||
|     if group_id in ["current", "c"]: | ||||
|         group_id = str(event_utils.get_group_id(event)) | ||||
|  | ||||
|     if group_id in ["all", "a"]: | ||||
|         group_id = "all" | ||||
|  | ||||
|     if bot_id in ["current", "c"]: | ||||
|         bot_id = str(bot.self_id) | ||||
|  | ||||
|     if user_id in ["current", "c"]: | ||||
|         user_id = str(event_utils.get_user_id(event)) | ||||
|  | ||||
|     img = await get_stat_msg_image(duration=duration, period=period, group_id=group_id, bot_id=bot_id, user_id=user_id, ulang=ulang) | ||||
|     await stat_msg.send(UniMessage.image(raw=img)) | ||||
|  | ||||
|  | ||||
| @stat_msg.assign("rank") | ||||
| async def _(result: Arparma, event: T_MessageEvent, bot: Bot): | ||||
|     ulang = Language(event_utils.get_user_id(event)) | ||||
|     rank_type = "user" | ||||
|     duration = convert_time_to_seconds(result.other_args.get("duration", "1d")) | ||||
|     if result.subcommands.get("rank").options.get("user"): | ||||
|         rank_type = "user" | ||||
|     elif result.subcommands.get("rank").options.get("group"): | ||||
|         rank_type = "group" | ||||
|  | ||||
|     limit = result.other_args.get("limit", {}) | ||||
|     if limit: | ||||
|         limit = dict([i.split("=") for i in limit]) | ||||
|     limit["duration"] = time.time() - duration  # 起始时间戳 | ||||
|     limit["rank"] = result.other_args.get("rank", 20) | ||||
|  | ||||
|     img = await get_stat_rank_image(rank_type=rank_type, limit=limit, ulang=ulang) | ||||
|     await stat_msg.send(UniMessage.image(raw=img)) | ||||
							
								
								
									
										92
									
								
								src/nonebot_plugins/liteyuki_statistics/stat_monitors.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								src/nonebot_plugins/liteyuki_statistics/stat_monitors.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,92 @@ | ||||
| import time | ||||
|  | ||||
| from nonebot import require | ||||
| from nonebot.message import event_postprocessor | ||||
|  | ||||
| from src.utils.base.data import Database, LiteModel | ||||
| from src.utils.base.ly_typing import v11, v12, satori | ||||
|  | ||||
| from src.utils.base.ly_typing import T_Bot, T_MessageEvent | ||||
|  | ||||
| from .common import MessageEventModel, msg_db | ||||
| from src.utils import event as event_utils | ||||
|  | ||||
| require("nonebot_plugin_alconna") | ||||
|  | ||||
|  | ||||
| async def general_event_monitor(bot: T_Bot, event: T_MessageEvent): | ||||
|     pass | ||||
|     # if isinstance(bot, satori.Bot): | ||||
|     #     print("POST PROCESS SATORI EVENT") | ||||
|     #     return await satori_event_monitor(bot, event) | ||||
|     # elif isinstance(bot, v11.Bot): | ||||
|     #     print("POST PROCESS V11 EVENT") | ||||
|     #     return await onebot_v11_event_monitor(bot, event) | ||||
|  | ||||
|  | ||||
| @event_postprocessor | ||||
| async def onebot_v11_event_monitor(bot: v11.Bot, event: v11.MessageEvent): | ||||
|     if event.message_type == "group": | ||||
|         event: v11.GroupMessageEvent | ||||
|         group_id = str(event.group_id) | ||||
|     else: | ||||
|         group_id = "" | ||||
|     mem = MessageEventModel( | ||||
|         time=int(time.time()), | ||||
|         bot_id=bot.self_id, | ||||
|         adapter="onebot.v11", | ||||
|         group_id=group_id, | ||||
|         user_id=str(event.user_id), | ||||
|  | ||||
|         message_id=str(event.message_id), | ||||
|  | ||||
|         message=[ms.__dict__ for ms in event.message], | ||||
|         message_text=event.raw_message, | ||||
|         message_type=event.message_type, | ||||
|     ) | ||||
|     msg_db.save(mem) | ||||
|  | ||||
|  | ||||
| @event_postprocessor | ||||
| async def onebot_v12_event_monitor(bot: v12.Bot, event: v12.MessageEvent): | ||||
|     if event.message_type == "group": | ||||
|         event: v12.GroupMessageEvent | ||||
|         group_id = str(event.group_id) | ||||
|     else: | ||||
|         group_id = "" | ||||
|     mem = MessageEventModel( | ||||
|         time=int(time.time()), | ||||
|         bot_id=bot.self_id, | ||||
|         adapter="onebot.v12", | ||||
|         group_id=group_id, | ||||
|         user_id=str(event.user_id), | ||||
|  | ||||
|         message_id=[ms.__dict__ for ms in event.message], | ||||
|  | ||||
|         message=event.message, | ||||
|         message_text=event.raw_message, | ||||
|         message_type=event.message_type, | ||||
|     ) | ||||
|     msg_db.save(mem) | ||||
|  | ||||
|  | ||||
| @event_postprocessor | ||||
| async def satori_event_monitor(bot: satori.Bot, event: satori.MessageEvent): | ||||
|     if event.guild is not None: | ||||
|         event: satori.MessageEvent | ||||
|         group_id = str(event.guild.id) | ||||
|     else: | ||||
|         group_id = "" | ||||
|  | ||||
|     mem = MessageEventModel( | ||||
|         time=int(time.time()), | ||||
|         bot_id=bot.self_id, | ||||
|         adapter="satori", | ||||
|         group_id=group_id, | ||||
|         user_id=str(event.user.id), | ||||
|         message_id=[ms.__str__() for ms in event.message], | ||||
|         message=event.message, | ||||
|         message_text=event.message.content, | ||||
|         message_type=event_utils.get_message_type(event), | ||||
|     ) | ||||
|     msg_db.save(mem) | ||||
							
								
								
									
										21
									
								
								src/nonebot_plugins/liteyuki_statistics/word_cloud/LICENSE
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								src/nonebot_plugins/liteyuki_statistics/word_cloud/LICENSE
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | ||||
| MIT License | ||||
|  | ||||
| Copyright (c) 2022 hemengyang | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|  | ||||
| The above copyright notice and this permission notice shall be included in all | ||||
| copies or substantial portions of the Software. | ||||
|  | ||||
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||
| SOFTWARE. | ||||
| @@ -0,0 +1,107 @@ | ||||
| import asyncio | ||||
| import concurrent.futures | ||||
| import contextlib | ||||
| import re | ||||
| from functools import partial | ||||
| from io import BytesIO | ||||
| from random import choice | ||||
| from typing import Optional | ||||
|  | ||||
| import jieba | ||||
| import jieba.analyse | ||||
| import numpy as np | ||||
| from emoji import replace_emoji | ||||
| from PIL import Image | ||||
| from wordcloud import WordCloud | ||||
|  | ||||
| from .config import global_config, plugin_config | ||||
|  | ||||
|  | ||||
| def pre_precess(msg: str) -> str: | ||||
|     """对消息进行预处理""" | ||||
|     # 去除网址 | ||||
|     # https://stackoverflow.com/a/17773849/9212748 | ||||
|     url_regex = re.compile( | ||||
|         r"(https?:\/\/(?:www\.|(?!www))[a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9]\.[^\s]{2,}|www\.[a-zA-Z0-9][a-zA-Z0-9-]" | ||||
|         r"+[a-zA-Z0-9]\.[^\s]{2,}|https?:\/\/(?:www\.|(?!www))[a-zA-Z0-9]+\.[^\s]{2,}|www\.[a-zA-Z0-9]+\.[^\s]{2,})" | ||||
|     ) | ||||
|     msg = url_regex.sub("", msg) | ||||
|  | ||||
|     # 去除 \u200b | ||||
|     msg = re.sub(r"\u200b", "", msg) | ||||
|  | ||||
|     # 去除 emoji | ||||
|     # https://github.com/carpedm20/emoji | ||||
|     msg = replace_emoji(msg) | ||||
|  | ||||
|     return msg | ||||
|  | ||||
|  | ||||
| def analyse_message(msg: str) -> dict[str, float]: | ||||
|     """分析消息 | ||||
|  | ||||
|     分词,并统计词频 | ||||
|     """ | ||||
|     # 设置停用词表 | ||||
|     if plugin_config.wordcloud_stopwords_path: | ||||
|         jieba.analyse.set_stop_words(plugin_config.wordcloud_stopwords_path) | ||||
|     # 加载用户词典 | ||||
|     if plugin_config.wordcloud_userdict_path: | ||||
|         jieba.load_userdict(str(plugin_config.wordcloud_userdict_path)) | ||||
|     # 基于 TF-IDF 算法的关键词抽取 | ||||
|     # 返回所有关键词,因为设置了数量其实也只是 tags[:topK],不如交给词云库处理 | ||||
|     words = jieba.analyse.extract_tags(msg, topK=0, withWeight=True) | ||||
|     return dict(words) | ||||
|  | ||||
|  | ||||
| def get_mask(key: str): | ||||
|     """获取 mask""" | ||||
|     mask_path = plugin_config.get_mask_path(key) | ||||
|     if mask_path.exists(): | ||||
|         return np.array(Image.open(mask_path)) | ||||
|     # 如果指定 mask 文件不存在,则尝试默认 mask | ||||
|     default_mask_path = plugin_config.get_mask_path() | ||||
|     if default_mask_path.exists(): | ||||
|         return np.array(Image.open(default_mask_path)) | ||||
|  | ||||
|  | ||||
| def _get_wordcloud(messages: list[str], mask_key: str) -> Optional[bytes]: | ||||
|     # 过滤掉命令 | ||||
|     command_start = tuple(i for i in global_config.command_start if i) | ||||
|     message = " ".join(m for m in messages if not m.startswith(command_start)) | ||||
|     # 预处理 | ||||
|     message = pre_precess(message) | ||||
|     # 分析消息。分词,并统计词频 | ||||
|     frequency = analyse_message(message) | ||||
|     # 词云参数 | ||||
|     wordcloud_options = {} | ||||
|     wordcloud_options.update(plugin_config.wordcloud_options) | ||||
|     wordcloud_options.setdefault("font_path", str(plugin_config.wordcloud_font_path)) | ||||
|     wordcloud_options.setdefault("width", plugin_config.wordcloud_width) | ||||
|     wordcloud_options.setdefault("height", plugin_config.wordcloud_height) | ||||
|     wordcloud_options.setdefault( | ||||
|         "background_color", plugin_config.wordcloud_background_color | ||||
|     ) | ||||
|     # 如果 colormap 是列表,则随机选择一个 | ||||
|     colormap = ( | ||||
|         plugin_config.wordcloud_colormap | ||||
|         if isinstance(plugin_config.wordcloud_colormap, str) | ||||
|         else choice(plugin_config.wordcloud_colormap) | ||||
|     ) | ||||
|     wordcloud_options.setdefault("colormap", colormap) | ||||
|     wordcloud_options.setdefault("mask", get_mask(mask_key)) | ||||
|     with contextlib.suppress(ValueError): | ||||
|         wordcloud = WordCloud(**wordcloud_options) | ||||
|         image = wordcloud.generate_from_frequencies(frequency).to_image() | ||||
|         image_bytes = BytesIO() | ||||
|         image.save(image_bytes, format="PNG") | ||||
|         return image_bytes.getvalue() | ||||
|  | ||||
|  | ||||
| async def get_wordcloud(messages: list[str], mask_key: str) -> Optional[bytes]: | ||||
|     loop = asyncio.get_running_loop() | ||||
|     pfunc = partial(_get_wordcloud, messages, mask_key) | ||||
|     # 虽然不知道具体是哪里泄漏了,但是通过每次关闭线程池可以避免这个问题 | ||||
|     # https://github.com/he0119/nonebot-plugin-wordcloud/issues/99 | ||||
|     with concurrent.futures.ThreadPoolExecutor() as pool: | ||||
|         return await loop.run_in_executor(pool, pfunc) | ||||
		Reference in New Issue
	
	Block a user