mirror of
				https://github.com/LiteyukiStudio/LiteyukiBot.git
				synced 2025-10-25 00:06:24 +00:00 
			
		
		
		
	🔥 小型重构
This commit is contained in:
		
							
								
								
									
										86
									
								
								src/utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								src/utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,86 @@ | ||||
| import json | ||||
| import os.path | ||||
| import platform | ||||
| import sys | ||||
| import time | ||||
|  | ||||
| import nonebot | ||||
|  | ||||
| __NAME__ = "LiteyukiBot" | ||||
| __VERSION__ = "6.3.2"  # 60201 | ||||
|  | ||||
| import requests | ||||
|  | ||||
| from src.utils.base.config import load_from_yaml, config | ||||
| from src.utils.base.log import init_log | ||||
| from src.utils.base.data_manager import TempConfig, auto_migrate, common_db | ||||
| from git import Repo | ||||
|  | ||||
|  | ||||
| major, minor, patch = map(int, __VERSION__.split(".")) | ||||
| __VERSION_I__ = major * 10000 + minor * 100 + patch | ||||
|  | ||||
|  | ||||
| def register_bot(): | ||||
|     url = "https://api.liteyuki.icu/register" | ||||
|     data = { | ||||
|             "name"     : __NAME__, | ||||
|             "version"  : __VERSION__, | ||||
|             "version_i": __VERSION_I__, | ||||
|             "python"   : f"{platform.python_implementation()} {platform.python_version()}", | ||||
|             "os"       : f"{platform.system()} {platform.version()} {platform.machine()}" | ||||
|     } | ||||
|     try: | ||||
|         nonebot.logger.info("Waiting for register to Liteyuki...") | ||||
|         resp = requests.post(url, json=data, timeout=(10, 15)) | ||||
|         if resp.status_code == 200: | ||||
|             data = resp.json() | ||||
|             if liteyuki_id := data.get("liteyuki_id"): | ||||
|                 with open("data/liteyuki/liteyuki.json", "wb") as f: | ||||
|                     f.write(json.dumps(data).encode("utf-8")) | ||||
|                 nonebot.logger.success(f"Register {liteyuki_id} to Liteyuki successfully") | ||||
|             else: | ||||
|                 raise ValueError(f"Register to Liteyuki failed: {data}") | ||||
|  | ||||
|     except Exception as e: | ||||
|         nonebot.logger.warning(f"Register to Liteyuki failed, but it's no matter: {e}") | ||||
|  | ||||
|  | ||||
| def init(): | ||||
|     """ | ||||
|     初始化 | ||||
|     Returns: | ||||
|  | ||||
|     """ | ||||
|     # 检测python版本是否高于3.10 | ||||
|     auto_migrate() | ||||
|     init_log() | ||||
|     if sys.version_info < (3, 10): | ||||
|         nonebot.logger.error("Requires Python3.10+ to run, please upgrade your Python Environment.") | ||||
|         exit(1) | ||||
|  | ||||
|     try: | ||||
|         # 检测git仓库 | ||||
|         repo = Repo(".") | ||||
|     except Exception as e: | ||||
|         nonebot.logger.error(f"Failed to load git repository: {e}, please clone this project from GitHub instead of downloading the zip file.") | ||||
|  | ||||
|     temp_data: TempConfig = common_db.where_one(TempConfig(), default=TempConfig()) | ||||
|     temp_data.data["start_time"] = time.time() | ||||
|     common_db.save(temp_data) | ||||
|  | ||||
|     # 在加载完成语言后再初始化日志 | ||||
|     nonebot.logger.info("Liteyuki is initializing...") | ||||
|  | ||||
|     if not os.path.exists("data/liteyuki/liteyuki.json"): | ||||
|         register_bot() | ||||
|  | ||||
|     if not os.path.exists("pyproject.toml"): | ||||
|         with open("pyproject.toml", "w", encoding="utf-8") as f: | ||||
|             f.write("[tool.nonebot]\n") | ||||
|  | ||||
|     nonebot.logger.info( | ||||
|         f"Run Liteyuki with Python{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro} " | ||||
|         f"at {sys.executable}" | ||||
|     ) | ||||
|     nonebot.logger.info(f"{__NAME__} {__VERSION__}({__VERSION_I__}) is running") | ||||
							
								
								
									
										14
									
								
								src/utils/adapter_manager/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								src/utils/adapter_manager/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| from . import ( | ||||
|     satori, | ||||
|     onebot | ||||
| ) | ||||
|  | ||||
|  | ||||
| def init(config: dict): | ||||
|     onebot.init() | ||||
|     satori.init(config) | ||||
|  | ||||
|  | ||||
| def register(): | ||||
|     onebot.register() | ||||
|     satori.register() | ||||
							
								
								
									
										12
									
								
								src/utils/adapter_manager/onebot.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								src/utils/adapter_manager/onebot.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,12 @@ | ||||
| import nonebot | ||||
| from nonebot.adapters.onebot import v11, v12 | ||||
|  | ||||
|  | ||||
| def init(): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| def register(): | ||||
|     driver = nonebot.get_driver() | ||||
|     driver.register_adapter(v11.Adapter) | ||||
|     driver.register_adapter(v12.Adapter) | ||||
							
								
								
									
										26
									
								
								src/utils/adapter_manager/satori.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								src/utils/adapter_manager/satori.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | ||||
| import json | ||||
| import os | ||||
|  | ||||
| import nonebot | ||||
| from nonebot.adapters import satori | ||||
|  | ||||
|  | ||||
| def init(config: dict): | ||||
|     if config.get("satori", None) is None: | ||||
|         nonebot.logger.info("Satori config not found, skip Satori init.") | ||||
|         return None | ||||
|     satori_config = config.get("satori") | ||||
|     if not satori_config.get("enable", False): | ||||
|         nonebot.logger.info("Satori not enabled, skip Satori init.") | ||||
|         return None | ||||
|     if os.getenv("SATORI_CLIENTS", None) is not None: | ||||
|         nonebot.logger.info("Satori clients already set in environment variable, skip.") | ||||
|     os.environ["SATORI_CLIENTS"] = json.dumps(satori_config.get("hosts", []), ensure_ascii=False) | ||||
|     config['satori_clients'] = satori_config.get("hosts", []) | ||||
|     return | ||||
|  | ||||
|  | ||||
| def register(): | ||||
|     if os.getenv("SATORI_CLIENTS", None) is not None: | ||||
|         driver = nonebot.get_driver() | ||||
|         driver.register_adapter(satori.Adapter) | ||||
							
								
								
									
										0
									
								
								src/utils/base/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/utils/base/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										100
									
								
								src/utils/base/config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								src/utils/base/config.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,100 @@ | ||||
| import os | ||||
| from typing import List | ||||
|  | ||||
| import nonebot | ||||
| import yaml | ||||
| from pydantic import BaseModel | ||||
|  | ||||
| from .data_manager import StoredConfig, TempConfig, common_db | ||||
| from .ly_typing import T_Bot | ||||
| from ..message.tools import random_hex_string | ||||
|  | ||||
| config = {}  # 全局配置,确保加载后读取 | ||||
|  | ||||
|  | ||||
| class SatoriNodeConfig(BaseModel): | ||||
|     host: str = "" | ||||
|     port: str = "5500" | ||||
|     path: str = "" | ||||
|     token: str = "" | ||||
|  | ||||
|  | ||||
| class SatoriConfig(BaseModel): | ||||
|     comment: str = "These features are still in development. Do not enable in production environment." | ||||
|     enable: bool = False | ||||
|     hosts: List[SatoriNodeConfig] = [SatoriNodeConfig()] | ||||
|  | ||||
|  | ||||
| class BasicConfig(BaseModel): | ||||
|     host: str = "127.0.0.1" | ||||
|     port: int = 20216 | ||||
|     superusers: list[str] = [] | ||||
|     command_start: list[str] = ["/", ""] | ||||
|     nickname: list[str] = [f"LiteyukiBot-{random_hex_string(6)}"] | ||||
|     satori: SatoriConfig = SatoriConfig() | ||||
|  | ||||
|  | ||||
| def load_from_yaml(file: str) -> dict: | ||||
|     global config | ||||
|     nonebot.logger.debug("Loading config from %s" % file) | ||||
|     if not os.path.exists(file): | ||||
|         nonebot.logger.warning(f"Config file {file} not found, created default config, please modify it and restart") | ||||
|         with open(file, "w", encoding="utf-8") as f: | ||||
|             yaml.dump(BasicConfig().dict(), f, default_flow_style=False) | ||||
|  | ||||
|     with open(file, "r", encoding="utf-8") as f: | ||||
|         conf = init_conf(yaml.load(f, Loader=yaml.FullLoader)) | ||||
|         config = conf | ||||
|         if conf is None: | ||||
|             nonebot.logger.warning(f"Config file {file} is empty, use default config. please modify it and restart") | ||||
|             conf = BasicConfig().dict() | ||||
|         return conf | ||||
|  | ||||
|  | ||||
| def get_config(key: str, default=None): | ||||
|     """获取配置项,优先级:bot > config > db > yaml""" | ||||
|     try: | ||||
|         bot = nonebot.get_bot() | ||||
|     except: | ||||
|         bot = None | ||||
|  | ||||
|     if bot is None: | ||||
|         bot_config = {} | ||||
|     else: | ||||
|         bot_config = bot.config.dict() | ||||
|  | ||||
|     if key in bot_config: | ||||
|         return bot_config[key] | ||||
|  | ||||
|     elif key in config: | ||||
|         return config[key] | ||||
|  | ||||
|     elif key in common_db.where_one(StoredConfig(), default=StoredConfig()).config: | ||||
|         return common_db.where_one(StoredConfig(), default=StoredConfig()).config[key] | ||||
|  | ||||
|     elif key in load_from_yaml("config.yml"): | ||||
|         return load_from_yaml("config.yml")[key] | ||||
|  | ||||
|     else: | ||||
|         return default | ||||
|  | ||||
|  | ||||
| def set_stored_config(key: str, value): | ||||
|     temp_config: TempConfig = common_db.where_one(TempConfig(), default=TempConfig()) | ||||
|     temp_config.data[key] = value | ||||
|     common_db.save(temp_config) | ||||
|  | ||||
|  | ||||
| def init_conf(conf: dict) -> dict: | ||||
|     """ | ||||
|     初始化配置文件,确保配置文件中的必要字段存在,且不会冲突 | ||||
|     Args: | ||||
|         conf: | ||||
|  | ||||
|     Returns: | ||||
|  | ||||
|     """ | ||||
|     # 若command_start中无"",则添加必要命令头,开启alconna_use_command_start防止冲突 | ||||
|     if "" not in conf.get("command_start", []): | ||||
|         conf["alconna_use_command_start"] = True | ||||
|     return conf | ||||
							
								
								
									
										405
									
								
								src/utils/base/data.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										405
									
								
								src/utils/base/data.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,405 @@ | ||||
| import os | ||||
| import pickle | ||||
| import sqlite3 | ||||
| from types import NoneType | ||||
| from typing import Any, Callable | ||||
| from packaging.version import parse | ||||
| import inspect | ||||
| import nonebot | ||||
| import pydantic | ||||
| from pydantic import BaseModel | ||||
|  | ||||
|  | ||||
| class LiteModel(BaseModel): | ||||
|     TABLE_NAME: str = None | ||||
|     id: int = None | ||||
|  | ||||
|     def dump(self, *args, **kwargs): | ||||
|         if parse(pydantic.__version__) < parse("2.0.0"): | ||||
|             return self.dict(*args, **kwargs) | ||||
|         else: | ||||
|             return self.model_dump(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| class Database: | ||||
|     def __init__(self, db_name: str): | ||||
|  | ||||
|         if os.path.dirname(db_name) != "" and not os.path.exists(os.path.dirname(db_name)): | ||||
|             os.makedirs(os.path.dirname(db_name)) | ||||
|  | ||||
|         self.db_name = db_name | ||||
|         self.conn = sqlite3.connect(db_name) | ||||
|         self.cursor = self.conn.cursor() | ||||
|  | ||||
|         self._on_save_callbacks = [] | ||||
|  | ||||
|     def where_one(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> LiteModel | Any | None: | ||||
|         """查询第一个 | ||||
|         Args: | ||||
|             model: 数据模型实例 | ||||
|             condition: 查询条件,不给定则查询所有 | ||||
|             *args: 参数化查询参数 | ||||
|             default: 默认值 | ||||
|  | ||||
|         Returns: | ||||
|  | ||||
|         """ | ||||
|         all_results = self.where_all(model, condition, *args) | ||||
|         return all_results[0] if all_results else default | ||||
|  | ||||
|     def where_all(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> list[LiteModel | Any] | None: | ||||
|         """查询所有 | ||||
|         Args: | ||||
|             model: 数据模型实例 | ||||
|             condition: 查询条件,不给定则查询所有 | ||||
|             *args: 参数化查询参数 | ||||
|             default: 默认值 | ||||
|  | ||||
|         Returns: | ||||
|  | ||||
|         """ | ||||
|         table_name = model.TABLE_NAME | ||||
|         model_type = type(model) | ||||
|         nonebot.logger.debug(f"Selecting {model.TABLE_NAME} WHERE {condition.replace('?', '%s') % args}") | ||||
|         if not table_name: | ||||
|             raise ValueError(f"数据模型{model_type.__name__}未提供表名") | ||||
|  | ||||
|         # condition = f"WHERE {condition}" | ||||
|         # print(f"SELECT * FROM {table_name} {condition}", args) | ||||
|         # if len(args) == 0: | ||||
|         #     results = self.cursor.execute(f"SELECT * FROM {table_name} {condition}").fetchall() | ||||
|         # else: | ||||
|         #     results = self.cursor.execute(f"SELECT * FROM {table_name} {condition}", args).fetchall() | ||||
|         if condition: | ||||
|             results = self.cursor.execute(f"SELECT * FROM {table_name} WHERE {condition}", args).fetchall() | ||||
|         else: | ||||
|             results = self.cursor.execute(f"SELECT * FROM {table_name}").fetchall() | ||||
|         fields = [description[0] for description in self.cursor.description] | ||||
|         if not results: | ||||
|             return default | ||||
|         else: | ||||
|             return [model_type(**self._load(dict(zip(fields, result)))) for result in results] | ||||
|  | ||||
|     def save(self, *args: LiteModel): | ||||
|         """增/改操作 | ||||
|         Args: | ||||
|             *args: | ||||
|         Returns: | ||||
|         """ | ||||
|         table_list = [item[0] for item in self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()] | ||||
|         for model in args: | ||||
|             nonebot.logger.debug(f"Upserting {model}") | ||||
|             if not model.TABLE_NAME: | ||||
|                 raise ValueError(f"数据模型 {model.__class__.__name__} 未提供表名") | ||||
|             elif model.TABLE_NAME not in table_list: | ||||
|                 raise ValueError(f"数据模型 {model.__class__.__name__} 表 {model.TABLE_NAME} 不存在,请先迁移") | ||||
|             else: | ||||
|                 self._save(model.dump(by_alias=True)) | ||||
|  | ||||
|             for callback in self._on_save_callbacks: | ||||
|                 callback(model) | ||||
|  | ||||
|     def _save(self, obj: Any) -> Any: | ||||
|         # obj = copy.deepcopy(obj) | ||||
|         if isinstance(obj, dict): | ||||
|             table_name = obj.get("TABLE_NAME") | ||||
|             row_id = obj.get("id") | ||||
|             new_obj = {} | ||||
|             for field, value in obj.items(): | ||||
|                 if isinstance(value, self.ITERABLE_TYPE): | ||||
|                     new_obj[self._get_stored_field_prefix(value) + field] = self._save(value)  # self._save(value)  # -> bytes | ||||
|                 elif isinstance(value, self.BASIC_TYPE): | ||||
|                     new_obj[field] = value | ||||
|                 else: | ||||
|                     raise ValueError(f"数据模型{table_name}包含不支持的数据类型,字段:{field} 值:{value} 值类型:{type(value)}") | ||||
|             if table_name: | ||||
|                 fields, values = [], [] | ||||
|                 for n_field, n_value in new_obj.items(): | ||||
|                     if n_field not in ["TABLE_NAME", "id"]: | ||||
|                         fields.append(n_field) | ||||
|                         values.append(n_value) | ||||
|                 # 移除TABLE_NAME和id | ||||
|                 fields = list(fields) | ||||
|                 values = list(values) | ||||
|                 if row_id is not None: | ||||
|                     # 如果 _id 不为空,将 'id' 插入到字段列表的开始 | ||||
|                     fields.insert(0, 'id') | ||||
|                     # 将 _id 插入到值列表的开始 | ||||
|                     values.insert(0, row_id) | ||||
|                 fields = ', '.join([f'"{field}"' for field in fields]) | ||||
|                 placeholders = ', '.join('?' for _ in values) | ||||
|                 self.cursor.execute(f"INSERT OR REPLACE INTO {table_name}({fields}) VALUES ({placeholders})", tuple(values)) | ||||
|                 self.conn.commit() | ||||
|                 foreign_id = self.cursor.execute("SELECT last_insert_rowid()").fetchone()[0] | ||||
|                 return f"{self.FOREIGN_KEY_PREFIX}{foreign_id}@{table_name}"  # -> FOREIGN_KEY_123456@{table_name} id@{table_name} | ||||
|             else: | ||||
|                 return pickle.dumps(new_obj)  # -> bytes | ||||
|         elif isinstance(obj, (list, set, tuple)): | ||||
|             obj_type = type(obj)  # 到时候转回去 | ||||
|             new_obj = [] | ||||
|             for item in obj: | ||||
|                 if isinstance(item, self.ITERABLE_TYPE): | ||||
|                     new_obj.append(self._save(item)) | ||||
|                 elif isinstance(item, self.BASIC_TYPE): | ||||
|                     new_obj.append(item) | ||||
|                 else: | ||||
|                     raise ValueError(f"数据模型包含不支持的数据类型,值:{item} 值类型:{type(item)}") | ||||
|             return pickle.dumps(obj_type(new_obj))  # -> bytes | ||||
|         else: | ||||
|             raise ValueError(f"数据模型包含不支持的数据类型,值:{obj} 值类型:{type(obj)}") | ||||
|  | ||||
|     def _load(self, obj: Any) -> Any: | ||||
|  | ||||
|         if isinstance(obj, dict): | ||||
|  | ||||
|             new_obj = {} | ||||
|  | ||||
|             for field, value in obj.items(): | ||||
|  | ||||
|                 field: str | ||||
|  | ||||
|                 if field.startswith(self.BYTES_PREFIX): | ||||
|                     if isinstance(value, bytes): | ||||
|                         new_obj[field.replace(self.BYTES_PREFIX, "")] = self._load(pickle.loads(value)) | ||||
|                     else:  # 从value字段可能为None,fix at 2024/6/13 | ||||
|                         pass | ||||
|                         # 暂时不作处理,后面再修 | ||||
|  | ||||
|                 elif field.startswith(self.FOREIGN_KEY_PREFIX): | ||||
|  | ||||
|                     new_obj[field.replace(self.FOREIGN_KEY_PREFIX, "")] = self._load(self._get_foreign_data(value)) | ||||
|  | ||||
|                 else: | ||||
|                     new_obj[field] = value | ||||
|             return new_obj | ||||
|         elif isinstance(obj, (list, set, tuple)): | ||||
|  | ||||
|             new_obj = [] | ||||
|             for item in obj: | ||||
|  | ||||
|                 if isinstance(item, bytes): | ||||
|  | ||||
|                     # 对bytes进行尝试解析,解析失败则返回原始bytes | ||||
|                     try: | ||||
|                         new_obj.append(self._load(pickle.loads(item))) | ||||
|                     except Exception as e: | ||||
|                         new_obj.append(self._load(item)) | ||||
|  | ||||
|                 elif isinstance(item, str) and item.startswith(self.FOREIGN_KEY_PREFIX): | ||||
|                     new_obj.append(self._load(self._get_foreign_data(item))) | ||||
|                 else: | ||||
|                     new_obj.append(self._load(item)) | ||||
|             return new_obj | ||||
|         else: | ||||
|             return obj | ||||
|  | ||||
|     def delete(self, model: LiteModel, condition: str, *args: Any, allow_empty: bool = False): | ||||
|         """ | ||||
|         删除满足条件的数据 | ||||
|         Args: | ||||
|             allow_empty: 允许空条件删除整个表 | ||||
|             model: | ||||
|             condition: | ||||
|             *args: | ||||
|  | ||||
|         Returns: | ||||
|  | ||||
|         """ | ||||
|         table_name = model.TABLE_NAME | ||||
|         nonebot.logger.debug(f"Deleting {model} WHERE {condition} {args}") | ||||
|         if not table_name: | ||||
|             raise ValueError(f"数据模型{model.__class__.__name__}未提供表名") | ||||
|         if model.id is not None: | ||||
|             condition = f"id = {model.id}" | ||||
|         if not condition and not allow_empty: | ||||
|             raise ValueError("删除操作必须提供条件") | ||||
|         self.cursor.execute(f"DELETE FROM {table_name} WHERE {condition}", args) | ||||
|         self.conn.commit() | ||||
|  | ||||
|     def auto_migrate(self, *args: LiteModel): | ||||
|  | ||||
|         """ | ||||
|         自动迁移模型 | ||||
|         Args: | ||||
|             *args: 模型类实例化对象,支持空默认值,不支持嵌套迁移 | ||||
|  | ||||
|         Returns: | ||||
|  | ||||
|         """ | ||||
|         for model in args: | ||||
|             if not model.TABLE_NAME: | ||||
|                 raise ValueError(f"数据模型{type(model).__name__}未提供表名") | ||||
|  | ||||
|             # 若无则创建表 | ||||
|             self.cursor.execute( | ||||
|                 f'CREATE TABLE IF NOT EXISTS "{model.TABLE_NAME}" (id INTEGER PRIMARY KEY AUTOINCREMENT)' | ||||
|             ) | ||||
|  | ||||
|             # 获取表结构,field -> SqliteType | ||||
|             new_structure = {} | ||||
|             for n_field, n_value in model.dump(by_alias=True).items(): | ||||
|                 if n_field not in ["TABLE_NAME", "id"]: | ||||
|                     new_structure[self._get_stored_field_prefix(n_value) + n_field] = self._get_stored_type(n_value) | ||||
|  | ||||
|             # 原有的字段列表 | ||||
|             existing_structure = dict([(column[1], column[2]) for column in self.cursor.execute(f'PRAGMA table_info({model.TABLE_NAME})').fetchall()]) | ||||
|             # 检测缺失字段,由于SQLite是动态类型,所以不需要检测类型 | ||||
|             for n_field, n_type in new_structure.items(): | ||||
|                 if n_field not in existing_structure.keys() and n_field.lower() not in ["id", "table_name"]: | ||||
|                     default_value = self.DEFAULT_MAPPING.get(n_type, 'NULL') | ||||
|                     self.cursor.execute( | ||||
|                         f"ALTER TABLE '{model.TABLE_NAME}' ADD COLUMN {n_field} {n_type} DEFAULT {self.DEFAULT_MAPPING.get(n_type, default_value)}" | ||||
|                     ) | ||||
|  | ||||
|             # 检测多余字段进行删除 | ||||
|             for e_field in existing_structure.keys(): | ||||
|                 if e_field not in new_structure.keys() and e_field.lower() not in ['id']: | ||||
|                     self.cursor.execute( | ||||
|                         f'ALTER TABLE "{model.TABLE_NAME}" DROP COLUMN "{e_field}"' | ||||
|                     ) | ||||
|         self.conn.commit() | ||||
|         # 已完成 | ||||
|  | ||||
|     def _get_stored_field_prefix(self, value) -> str: | ||||
|         """根据类型获取存储字段前缀,一定在后加上字段名 | ||||
|         * -> "" | ||||
|         Args: | ||||
|             value: 储存的值 | ||||
|  | ||||
|         Returns: | ||||
|             Sqlite3存储字段 | ||||
|         """ | ||||
|  | ||||
|         if isinstance(value, LiteModel) or isinstance(value, dict) and "TABLE_NAME" in value: | ||||
|             return self.FOREIGN_KEY_PREFIX | ||||
|         elif type(value) in self.ITERABLE_TYPE: | ||||
|             return self.BYTES_PREFIX | ||||
|         return "" | ||||
|  | ||||
|     def _get_stored_type(self, value) -> str: | ||||
|         """获取存储类型 | ||||
|  | ||||
|         Args: | ||||
|             value: 储存的值 | ||||
|  | ||||
|         Returns: | ||||
|             Sqlite3存储类型 | ||||
|         """ | ||||
|         if isinstance(value, dict) and "TABLE_NAME" in value: | ||||
|             # 是一个模型字典,储存外键 | ||||
|             return "INTEGER" | ||||
|         return self.TYPE_MAPPING.get(type(value), "TEXT") | ||||
|  | ||||
|     def _get_foreign_data(self, foreign_value: str) -> dict: | ||||
|         """ | ||||
|         获取外键数据 | ||||
|         Args: | ||||
|             foreign_value: | ||||
|  | ||||
|         Returns: | ||||
|  | ||||
|         """ | ||||
|         foreign_value = foreign_value.replace(self.FOREIGN_KEY_PREFIX, "") | ||||
|         table_name = foreign_value.split("@")[-1] | ||||
|         foreign_id = foreign_value.split("@")[0] | ||||
|         fields = [description[1] for description in self.cursor.execute(f"PRAGMA table_info({table_name})").fetchall()] | ||||
|         result = self.cursor.execute(f"SELECT * FROM {table_name} WHERE id = ?", (foreign_id,)).fetchone() | ||||
|         return dict(zip(fields, result)) | ||||
|  | ||||
|     def on_save(self, func: Callable[[LiteModel | Any], None]): | ||||
|         """ | ||||
|         装饰一个可调用对象使其在储存数据模型时被调用 | ||||
|         Args: | ||||
|             func: | ||||
|         Returns: | ||||
|         """ | ||||
|  | ||||
|         def wrapper(model): | ||||
|             # 检查被装饰函数声明的model类型和传入的model类型是否一致 | ||||
|             sign = inspect.signature(func) | ||||
|             if param := sign.parameters.get("model"): | ||||
|                 if isinstance(model, param.annotation): | ||||
|                     pass | ||||
|                 else: | ||||
|                     return | ||||
|             else: | ||||
|                 return | ||||
|             result = func(model) | ||||
|             for callback in self._on_save_callbacks: | ||||
|                 callback(result) | ||||
|             return result | ||||
|  | ||||
|         self._on_save_callbacks.append(wrapper) | ||||
|         return wrapper | ||||
|  | ||||
|     TYPE_MAPPING = { | ||||
|             int      : "INTEGER", | ||||
|             float    : "REAL", | ||||
|             str      : "TEXT", | ||||
|             bool     : "INTEGER", | ||||
|             bytes    : "BLOB", | ||||
|             NoneType : "NULL", | ||||
|             # dict     : "TEXT", | ||||
|             # list     : "TEXT", | ||||
|             # tuple    : "TEXT", | ||||
|             # set      : "TEXT", | ||||
|  | ||||
|             dict     : "BLOB",  # LITEYUKIDICT{key_name} | ||||
|             list     : "BLOB",  # LITEYUKILIST{key_name} | ||||
|             tuple    : "BLOB",  # LITEYUKITUPLE{key_name} | ||||
|             set      : "BLOB",  # LITEYUKISET{key_name} | ||||
|             LiteModel: "TEXT"  # FOREIGN_KEY_{table_name} | ||||
|     } | ||||
|     DEFAULT_MAPPING = { | ||||
|             "TEXT"   : "''", | ||||
|             "INTEGER": 0, | ||||
|             "REAL"   : 0.0, | ||||
|             "BLOB"   : None, | ||||
|             "NULL"   : None | ||||
|     } | ||||
|  | ||||
|     # 基础类型 | ||||
|     BASIC_TYPE = (int, float, str, bool, bytes, NoneType) | ||||
|     # 可序列化类型 | ||||
|     ITERABLE_TYPE = (dict, list, tuple, set, LiteModel) | ||||
|  | ||||
|     # 外键前缀 | ||||
|     FOREIGN_KEY_PREFIX = "FOREIGN_KEY_" | ||||
|     # 转换为的字节前缀 | ||||
|     BYTES_PREFIX = "PICKLE_BYTES_" | ||||
|  | ||||
|     # transaction tx 事务操作 | ||||
|     def first(self, model: LiteModel) -> "Database": | ||||
|         pass | ||||
|  | ||||
|     def where(self, condition: str, *args) -> "Database": | ||||
|         pass | ||||
|  | ||||
|     def limit(self, limit: int) -> "Database": | ||||
|         pass | ||||
|  | ||||
|     def order(self, order: str) -> "Database": | ||||
|         pass | ||||
|  | ||||
|  | ||||
| def check_sqlite_keyword(name): | ||||
|     sqlite_keywords = [ | ||||
|             "ABORT", "ACTION", "ADD", "AFTER", "ALL", "ALTER", "ANALYZE", "AND", "AS", "ASC", | ||||
|             "ATTACH", "AUTOINCREMENT", "BEFORE", "BEGIN", "BETWEEN", "BY", "CASCADE", "CASE", | ||||
|             "CAST", "CHECK", "COLLATE", "COLUMN", "COMMIT", "CONFLICT", "CONSTRAINT", "CREATE", | ||||
|             "CROSS", "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP", "DATABASE", "DEFAULT", | ||||
|             "DEFERRABLE", "DEFERRED", "DELETE", "DESC", "DETACH", "DISTINCT", "DROP", "EACH", | ||||
|             "ELSE", "END", "ESCAPE", "EXCEPT", "EXCLUSIVE", "EXISTS", "EXPLAIN", "FAIL", "FOR", | ||||
|             "FOREIGN", "FROM", "FULL", "GLOB", "GROUP", "HAVING", "IF", "IGNORE", "IMMEDIATE", | ||||
|             "IN", "INDEX", "INDEXED", "INITIALLY", "INNER", "INSERT", "INSTEAD", "INTERSECT", | ||||
|             "INTO", "IS", "ISNULL", "JOIN", "KEY", "LEFT", "LIKE", "LIMIT", "MATCH", "NATURAL", | ||||
|             "NO", "NOT", "NOTNULL", "NULL", "OF", "OFFSET", "ON", "OR", "ORDER", "OUTER", "PLAN", | ||||
|             "PRAGMA", "PRIMARY", "QUERY", "RAISE", "RECURSIVE", "REFERENCES", "REGEXP", "REINDEX", | ||||
|             "RELEASE", "RENAME", "REPLACE", "RESTRICT", "RIGHT", "ROLLBACK", "ROW", "SAVEPOINT", | ||||
|             "SELECT", "SET", "TABLE", "TEMP", "TEMPORARY", "THEN", "TO", "TRANSACTION", "TRIGGER", | ||||
|             "UNION", "UNIQUE", "UPDATE", "USING", "VACUUM", "VALUES", "VIEW", "VIRTUAL", "WHEN", | ||||
|             "WHERE", "WITH", "WITHOUT" | ||||
|     ] | ||||
|     return True | ||||
|     # if name.upper() in sqlite_keywords: | ||||
|     #     raise ValueError(f"'{name}' 是SQLite保留字,不建议使用,请更换名称") | ||||
							
								
								
									
										96
									
								
								src/utils/base/data_manager.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										96
									
								
								src/utils/base/data_manager.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,96 @@ | ||||
| import os | ||||
|  | ||||
| from pydantic import Field | ||||
|  | ||||
| from .data import Database, LiteModel, Database | ||||
|  | ||||
| DATA_PATH = "data/liteyuki" | ||||
|  | ||||
| user_db: Database = Database(os.path.join(DATA_PATH, "users.ldb")) | ||||
| group_db: Database = Database(os.path.join(DATA_PATH, "groups.ldb")) | ||||
| plugin_db: Database = Database(os.path.join(DATA_PATH, "plugins.ldb")) | ||||
| common_db: Database = Database(os.path.join(DATA_PATH, "common.ldb")) | ||||
|  | ||||
| # 内存数据库,临时用于存储数据 | ||||
| memory_database = { | ||||
|  | ||||
| } | ||||
|  | ||||
|  | ||||
| class User(LiteModel): | ||||
|     TABLE_NAME: str = "user" | ||||
|     user_id: str = Field(str(), alias="user_id") | ||||
|     username: str = Field(str(), alias="username") | ||||
|     profile: dict[str, str] = Field(dict(), alias="profile") | ||||
|     enabled_plugins: list[str] = Field(list(), alias="enabled_plugins") | ||||
|     disabled_plugins: list[str] = Field(list(), alias="disabled_plugins") | ||||
|  | ||||
|  | ||||
| class Group(LiteModel): | ||||
|     TABLE_NAME: str = "group_chat" | ||||
|     # Group是一个关键字,所以这里用GroupChat | ||||
|     group_id: str = Field(str(), alias="group_id") | ||||
|     group_name: str = Field(str(), alias="group_name") | ||||
|     enabled_plugins: list[str] = Field([], alias="enabled_plugins") | ||||
|     disabled_plugins: list[str] = Field([], alias="disabled_plugins") | ||||
|     enable: bool = Field(True, alias="enable")  # 群聊全局机器人是否启用 | ||||
|     config: dict = Field({}, alias="config") | ||||
|  | ||||
|  | ||||
| class InstalledPlugin(LiteModel): | ||||
|     TABLE_NAME: str = "installed_plugin" | ||||
|     module_name: str = Field(str(), alias="module_name") | ||||
|     version: str = Field(str(), alias="version") | ||||
|  | ||||
|  | ||||
| class GlobalPlugin(LiteModel): | ||||
|     TABLE_NAME: str = "global_plugin" | ||||
|     liteyuki: bool = Field(True, alias="liteyuki")  # 是否为LiteYuki插件 | ||||
|     module_name: str = Field(str(), alias="module_name") | ||||
|     enabled: bool = Field(True, alias="enabled") | ||||
|  | ||||
|  | ||||
| class StoredConfig(LiteModel): | ||||
|     TABLE_NAME: str = "stored_config" | ||||
|     config: dict = {} | ||||
|  | ||||
|  | ||||
| class TempConfig(LiteModel): | ||||
|     """储存临时键值对的表""" | ||||
|     TABLE_NAME: str = "temp_data" | ||||
|     data: dict = {} | ||||
|  | ||||
|  | ||||
| def auto_migrate(): | ||||
|     user_db.auto_migrate(User()) | ||||
|     group_db.auto_migrate(Group()) | ||||
|     plugin_db.auto_migrate(InstalledPlugin(), GlobalPlugin()) | ||||
|     common_db.auto_migrate(GlobalPlugin(), StoredConfig(), TempConfig()) | ||||
|  | ||||
|  | ||||
| def set_memory_data(key: str, value) -> None: | ||||
|     """ | ||||
|     设置内存数据库的数据,类似于redis | ||||
|     Args: | ||||
|         key: | ||||
|         value: | ||||
|  | ||||
|     Returns: | ||||
|  | ||||
|     """ | ||||
|     return memory_database.update({ | ||||
|             key: value | ||||
|     }) | ||||
|  | ||||
|  | ||||
| def get_memory_data(key: str, default=None) -> any: | ||||
|     """ | ||||
|     获取内存数据库的数据,类似于redis | ||||
|     Args: | ||||
|         key: | ||||
|         default: | ||||
|  | ||||
|     Returns: | ||||
|  | ||||
|     """ | ||||
|     return memory_database.get(key, default) | ||||
							
								
								
									
										216
									
								
								src/utils/base/language.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										216
									
								
								src/utils/base/language.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,216 @@ | ||||
| """ | ||||
| 语言模块,添加对多语言的支持 | ||||
| """ | ||||
|  | ||||
| import json | ||||
| import locale | ||||
| import os | ||||
| from typing import Any, overload | ||||
|  | ||||
| import nonebot | ||||
|  | ||||
| from .config import config, get_config | ||||
| from .data_manager import User, user_db | ||||
|  | ||||
| _language_data = { | ||||
|         "en": { | ||||
|                 "name": "English", | ||||
|         } | ||||
| } | ||||
|  | ||||
| _user_lang = { | ||||
|         "user_id": "zh-CN" | ||||
| } | ||||
|  | ||||
|  | ||||
| def load_from_lang(file_path: str, lang_code: str = None): | ||||
|     """ | ||||
|     从lang文件中加载语言数据,用于简单的文本键值对 | ||||
|  | ||||
|     Args: | ||||
|         file_path: lang文件路径 | ||||
|         lang_code: 语言代码,如果为None则从文件名中获取 | ||||
|     """ | ||||
|     try: | ||||
|         if lang_code is None: | ||||
|             lang_code = os.path.basename(file_path).split(".")[0] | ||||
|         with open(file_path, "r", encoding="utf-8") as file: | ||||
|             data = {} | ||||
|             for line in file: | ||||
|                 line = line.strip() | ||||
|                 if not line or line.startswith("#"):  # 空行或注释 | ||||
|                     continue | ||||
|                 key, value = line.split("=", 1) | ||||
|                 data[key.strip()] = value.strip() | ||||
|             if lang_code not in _language_data: | ||||
|                 _language_data[lang_code] = {} | ||||
|             _language_data[lang_code].update(data) | ||||
|         nonebot.logger.debug(f"Loaded language data from {file_path}") | ||||
|     except Exception as e: | ||||
|         nonebot.logger.error(f"Failed to load language data from {file_path}: {e}") | ||||
|  | ||||
|  | ||||
| def load_from_json(file_path: str, lang_code: str = None): | ||||
|     """ | ||||
|     从json文件中加载语言数据,可以定义一些变量 | ||||
|  | ||||
|     Args: | ||||
|         lang_code: 语言代码,如果为None则从文件名中获取 | ||||
|         file_path: json文件路径 | ||||
|     """ | ||||
|     try: | ||||
|         if lang_code is None: | ||||
|             lang_code = os.path.basename(file_path).split(".")[0] | ||||
|         with open(file_path, "r", encoding="utf-8") as file: | ||||
|             data = json.load(file) | ||||
|             if lang_code not in _language_data: | ||||
|                 _language_data[lang_code] = {} | ||||
|             _language_data[lang_code].update(data) | ||||
|         nonebot.logger.debug(f"Loaded language data from {file_path}") | ||||
|     except Exception as e: | ||||
|         nonebot.logger.error(f"Failed to load language data from {file_path}: {e}") | ||||
|  | ||||
|  | ||||
| def load_from_dir(dir_path: str): | ||||
|     """ | ||||
|     从目录中加载语言数据 | ||||
|  | ||||
|     Args: | ||||
|         dir_path: 目录路径 | ||||
|     """ | ||||
|     for file in os.listdir(dir_path): | ||||
|         try: | ||||
|             file_path = os.path.join(dir_path, file) | ||||
|             if os.path.isfile(file_path): | ||||
|                 if file.endswith(".lang"): | ||||
|                     load_from_lang(file_path) | ||||
|                 elif file.endswith(".json"): | ||||
|                     load_from_json(file_path) | ||||
|         except Exception as e: | ||||
|             nonebot.logger.error(f"Failed to load language data from {file}: {e}") | ||||
|             continue | ||||
|  | ||||
|  | ||||
| def load_from_dict(data: dict, lang_code: str): | ||||
|     """ | ||||
|     从字典中加载语言数据 | ||||
|  | ||||
|     Args: | ||||
|         lang_code: 语言代码 | ||||
|         data: 字典数据 | ||||
|     """ | ||||
|     if lang_code not in _language_data: | ||||
|         _language_data[lang_code] = {} | ||||
|     _language_data[lang_code].update(data) | ||||
|  | ||||
|  | ||||
| class Language: | ||||
|     # 三重fallback | ||||
|     # 用户语言 > 默认语言/系统语言 > zh-CN | ||||
|     def __init__(self, lang_code: str = None, fallback_lang_code: str = None): | ||||
|         self.lang_code = lang_code | ||||
|  | ||||
|         if self.lang_code is None: | ||||
|             self.lang_code = get_default_lang_code() | ||||
|  | ||||
|         self.fallback_lang_code = fallback_lang_code | ||||
|         if self.fallback_lang_code is None: | ||||
|             self.fallback_lang_code = config.get("default_language", get_system_lang_code()) | ||||
|  | ||||
|     def get(self, item: str, *args, **kwargs) -> str | Any: | ||||
|         """ | ||||
|         获取当前语言文本,kwargs中的default参数为默认文本 | ||||
|         Args: | ||||
|             item:   文本键 | ||||
|             *args:  格式化参数 | ||||
|             **kwargs: 格式化参数 | ||||
|  | ||||
|         Returns: | ||||
|             str: 当前语言的文本 | ||||
|  | ||||
|         """ | ||||
|         default = kwargs.pop("default", None) | ||||
|         fallback = (self.lang_code, self.fallback_lang_code, "zh-CN") | ||||
|  | ||||
|         for lang_code in fallback: | ||||
|             if lang_code in _language_data and item in _language_data[lang_code]: | ||||
|                 trans: str = _language_data[lang_code][item] | ||||
|                 try: | ||||
|                     return trans.format(*args, **kwargs) | ||||
|                 except Exception as e: | ||||
|                     nonebot.logger.warning(f"Failed to format language data: {e}") | ||||
|                     return trans | ||||
|         return default or item | ||||
|  | ||||
|     def get_many(self, *args: str, **kwargs) -> dict[str, str]: | ||||
|         """ | ||||
|         获取多个文本 | ||||
|         Args: | ||||
|             *args:  文本键 | ||||
|             **kwargs: 文本键和默认文本 | ||||
|  | ||||
|         Returns: | ||||
|             dict: 多个文本 | ||||
|         """ | ||||
|         args_data = {item: self.get(item) for item in args} | ||||
|         kwargs_data = {item: self.get(item, default=default) for item, default in kwargs.items()} | ||||
|         args_data.update(kwargs_data) | ||||
|         return args_data | ||||
|  | ||||
|  | ||||
| def change_user_lang(user_id: str, lang_code: str): | ||||
|     """ | ||||
|     修改用户的语言,同时储存到数据库和内存中 | ||||
|     """ | ||||
|     user = user_db.where_one(User(), "user_id = ?", user_id, default=User(user_id=user_id)) | ||||
|     user.profile["lang"] = lang_code | ||||
|     user_db.save(user) | ||||
|     _user_lang[user_id] = lang_code | ||||
|  | ||||
|  | ||||
| def get_user_lang(user_id: str) -> Language: | ||||
|     """ | ||||
|     获取用户的语言实例,优先从内存中获取 | ||||
|     """ | ||||
|     user_id = str(user_id) | ||||
|  | ||||
|     if user_id not in _user_lang: | ||||
|         nonebot.logger.debug(f"Loading user language for {user_id}") | ||||
|         user = user_db.where_one( | ||||
|             User(), "user_id = ?", user_id, default=User( | ||||
|                 user_id=user_id, | ||||
|                 username="Unknown" | ||||
|             ) | ||||
|         ) | ||||
|         lang_code = user.profile.get("lang", get_default_lang_code()) | ||||
|         _user_lang[user_id] = lang_code | ||||
|  | ||||
|     return Language(_user_lang[user_id]) | ||||
|  | ||||
|  | ||||
| def get_system_lang_code() -> str: | ||||
|     """ | ||||
|     获取系统语言代码 | ||||
|     """ | ||||
|     return locale.getdefaultlocale()[0].replace('_', '-') | ||||
|  | ||||
|  | ||||
| def get_default_lang_code() -> str: | ||||
|     """ | ||||
|     获取默认语言代码,若没有设置则使用系统语言 | ||||
|     Returns: | ||||
|  | ||||
|     """ | ||||
|     return get_config("default_language", default=get_system_lang_code()) | ||||
|  | ||||
|  | ||||
| def get_all_lang() -> dict[str, str]: | ||||
|     """ | ||||
|     获取所有语言 | ||||
|     Returns | ||||
|         {'en': 'English'} | ||||
|     """ | ||||
|     d = {} | ||||
|     for key in _language_data: | ||||
|         d[key] = _language_data[key].get("language.name", key) | ||||
|     return d | ||||
							
								
								
									
										79
									
								
								src/utils/base/log.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								src/utils/base/log.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,79 @@ | ||||
| import sys | ||||
| import loguru | ||||
| from typing import TYPE_CHECKING | ||||
| from .config import load_from_yaml | ||||
| from .language import Language, get_default_lang_code | ||||
|  | ||||
| logger = loguru.logger | ||||
| if TYPE_CHECKING: | ||||
|     # avoid sphinx autodoc resolve annotation failed | ||||
|     # because loguru module do not have `Logger` class actually | ||||
|     from loguru import Record | ||||
|  | ||||
|  | ||||
| def default_filter(record: "Record"): | ||||
|     """默认的日志过滤器,根据 `config.log_level` 配置改变日志等级。""" | ||||
|     log_level = record["extra"].get("nonebot_log_level", "INFO") | ||||
|     levelno = logger.level(log_level).no if isinstance(log_level, str) else log_level | ||||
|     return record["level"].no >= levelno | ||||
|  | ||||
|  | ||||
| # DEBUG日志格式 | ||||
| debug_format: str = ( | ||||
|         "<c>{time:YYYY-MM-DD HH:mm:ss}</c> " | ||||
|         "<lvl>[{level.icon}]</lvl> " | ||||
|         "<c><{name}.{module}.{function}:{line}></c> " | ||||
|         "{message}" | ||||
| ) | ||||
|  | ||||
| # 默认日志格式 | ||||
| default_format: str = ( | ||||
|         "<c>{time:MM-DD HH:mm:ss}</c> " | ||||
|         "<lvl>[{level.icon}]</lvl> " | ||||
|         "<c><{name}></c> " | ||||
|         "{message}" | ||||
| ) | ||||
|  | ||||
|  | ||||
| def get_format(level: str) -> str: | ||||
|     if level == "DEBUG": | ||||
|         return debug_format | ||||
|     else: | ||||
|         return default_format | ||||
|  | ||||
|  | ||||
| logger = loguru.logger.bind() | ||||
|  | ||||
|  | ||||
| def init_log(): | ||||
|     """ | ||||
|     在语言加载完成后执行 | ||||
|     Returns: | ||||
|  | ||||
|     """ | ||||
|     global logger | ||||
|  | ||||
|     config = load_from_yaml("config.yml") | ||||
|  | ||||
|     logger.remove() | ||||
|     logger.add( | ||||
|         sys.stdout, | ||||
|         level=0, | ||||
|         diagnose=False, | ||||
|         filter=default_filter, | ||||
|         format=get_format(config.get("log_level", "INFO")), | ||||
|     ) | ||||
|     show_icon = config.get("log_icon", True) | ||||
|     lang = Language(get_default_lang_code()) | ||||
|  | ||||
|     debug = lang.get("log.debug", default="==DEBUG") | ||||
|     info = lang.get("log.info", default="===INFO") | ||||
|     success = lang.get("log.success", default="SUCCESS") | ||||
|     warning = lang.get("log.warning", default="WARNING") | ||||
|     error = lang.get("log.error", default="==ERROR") | ||||
|  | ||||
|     logger.level("DEBUG", color="<blue>", icon=f"{'🐛' if show_icon else ''}{debug}") | ||||
|     logger.level("INFO", color="<normal>", icon=f"{'ℹ️' if show_icon else ''}{info}") | ||||
|     logger.level("SUCCESS", color="<green>", icon=f"{'✅' if show_icon else ''}{success}") | ||||
|     logger.level("WARNING", color="<yellow>", icon=f"{'⚠️' if show_icon else ''}{warning}") | ||||
|     logger.level("ERROR", color="<red>", icon=f"{'⭕' if show_icon else ''}{error}") | ||||
							
								
								
									
										90
									
								
								src/utils/base/ly_api.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								src/utils/base/ly_api.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,90 @@ | ||||
| import json | ||||
| import os.path | ||||
| import platform | ||||
|  | ||||
| import aiohttp | ||||
| import nonebot | ||||
| import psutil | ||||
| import requests | ||||
| from aiohttp import FormData | ||||
|  | ||||
| from .. import __VERSION_I__, __VERSION__, __NAME__ | ||||
| from .config import load_from_yaml | ||||
|  | ||||
|  | ||||
| class LiteyukiAPI: | ||||
|     def __init__(self): | ||||
|         self.liteyuki_id = None | ||||
|         if os.path.exists("data/liteyuki/liteyuki.json"): | ||||
|             with open("data/liteyuki/liteyuki.json", "rb") as f: | ||||
|                 self.data = json.loads(f.read()) | ||||
|                 self.liteyuki_id = self.data.get("liteyuki_id") | ||||
|         self.report = load_from_yaml("config.yml").get("auto_report", True) | ||||
|         if self.report: | ||||
|             nonebot.logger.info("Auto bug report is enabled") | ||||
|  | ||||
|     @property | ||||
|     def device_info(self) -> dict: | ||||
|         """ | ||||
|         获取设备信息 | ||||
|         Returns: | ||||
|  | ||||
|         """ | ||||
|         return { | ||||
|                 "name"        : __NAME__, | ||||
|                 "version"     : __VERSION__, | ||||
|                 "version_i"   : __VERSION_I__, | ||||
|                 "python"      : f"{platform.python_implementation()} {platform.python_version()}", | ||||
|                 "os"          : f"{platform.system()} {platform.version()} {platform.machine()}", | ||||
|                 "cpu"         : f"{psutil.cpu_count(logical=False)}c{psutil.cpu_count()}t{psutil.cpu_freq().current}MHz", | ||||
|                 "memory_total": f"{psutil.virtual_memory().total / 1024 / 1024 / 1024:.2f}GB", | ||||
|                 "memory_used" : f"{psutil.virtual_memory().used / 1024 / 1024 / 1024:.2f}GB", | ||||
|                 "memory_bot"  : f"{psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024:.2f}MB", | ||||
|                 "disk"        : f"{psutil.disk_usage('/').total / 1024 / 1024 / 1024:.2f}GB" | ||||
|         } | ||||
|  | ||||
|     def bug_report(self, content: str): | ||||
|         """ | ||||
|         提交bug报告 | ||||
|         Args: | ||||
|             content: | ||||
|  | ||||
|         Returns: | ||||
|  | ||||
|         """ | ||||
|         if self.report: | ||||
|             nonebot.logger.warning(f"Reporting bug...: {content}") | ||||
|             url = "https://api.liteyuki.icu/bug_report" | ||||
|             data = { | ||||
|                     "liteyuki_id": self.liteyuki_id, | ||||
|                     "content"    : content, | ||||
|                     "device_info": self.device_info | ||||
|             } | ||||
|             resp = requests.post(url, json=data) | ||||
|             if resp.status_code == 200: | ||||
|                 nonebot.logger.success(f"Bug report sent successfully, report_id: {resp.json().get('report_id')}") | ||||
|             else: | ||||
|                 nonebot.logger.error(f"Bug report failed: {resp.text}") | ||||
|         else: | ||||
|             nonebot.logger.warning(f"Bug report is disabled: {content}") | ||||
|  | ||||
|     async def heartbeat_report(self): | ||||
|         """ | ||||
|         提交心跳,预留接口 | ||||
|         Returns: | ||||
|  | ||||
|         """ | ||||
|         url = "https://api.liteyuki.icu/heartbeat" | ||||
|         data = { | ||||
|                 "liteyuki_id": self.liteyuki_id, | ||||
|                 "version": __VERSION__, | ||||
|         } | ||||
|         async with aiohttp.ClientSession() as session: | ||||
|             async with session.post(url, json=data) as resp: | ||||
|                 if resp.status == 200: | ||||
|                     nonebot.logger.success("Heartbeat sent successfully") | ||||
|                 else: | ||||
|                     nonebot.logger.error(f"Heartbeat failed: {await resp.text()}") | ||||
|  | ||||
|  | ||||
| liteyuki_api = LiteyukiAPI() | ||||
							
								
								
									
										197
									
								
								src/utils/base/ly_function.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										197
									
								
								src/utils/base/ly_function.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,197 @@ | ||||
| """ | ||||
| liteyuki function是一种类似于mcfunction的函数,用于在liteyuki中实现一些功能,例如自定义指令等,也可与Python函数绑定 | ||||
| 使用 /function function_name *args **kwargs来调用 | ||||
| 例如 /function test/hello user_id=123456 | ||||
| 可以用于一些轻量级插件的编写,无需Python代码 | ||||
| SnowyKami | ||||
| """ | ||||
| import asyncio | ||||
| import functools | ||||
| # cmd *args **kwargs | ||||
| # api api_name **kwargs | ||||
| import os | ||||
| from typing import Any, Awaitable, Callable, Coroutine | ||||
|  | ||||
| import nonebot | ||||
| from nonebot import Bot | ||||
| from nonebot.adapters.satori import bot | ||||
| from nonebot.internal.matcher import Matcher | ||||
|  | ||||
| ly_function_extensions = ( | ||||
|         "lyf", | ||||
|         "lyfunction", | ||||
|         "mcfunction" | ||||
| ) | ||||
|  | ||||
| loaded_functions = dict() | ||||
|  | ||||
|  | ||||
| class LiteyukiFunction: | ||||
|     def __init__(self, name: str): | ||||
|         self.name = name | ||||
|         self.functions: list[str] = list() | ||||
|         self.bot: Bot = None | ||||
|         self.kwargs_data = dict() | ||||
|         self.args_data = list() | ||||
|         self.matcher: Matcher = None | ||||
|         self.end = False | ||||
|  | ||||
|         self.sub_tasks: list[asyncio.Task] = list() | ||||
|  | ||||
|     async def __call__(self, *args, **kwargs): | ||||
|         self.kwargs_data.update(kwargs) | ||||
|         self.args_data = list(set(self.args_data + list(args))) | ||||
|         for i, cmd in enumerate(self.functions): | ||||
|             r = await self.execute_line(cmd, i, *args, **kwargs) | ||||
|             if r == 0: | ||||
|                 msg = f"End function {self.name} by line {i}" | ||||
|                 nonebot.logger.debug(msg) | ||||
|                 for task in self.sub_tasks: | ||||
|                     task.cancel(msg) | ||||
|                 return | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"LiteyukiFunction({self.name})" | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return self.__str__() | ||||
|  | ||||
|     async def execute_line(self, cmd: str, line: int = 0, *args, **kwargs) -> Any: | ||||
|         """ | ||||
|         解析一行轻雪函数 | ||||
|         Args: | ||||
|             cmd: 命令 | ||||
|             line: 行数 | ||||
|         Returns: | ||||
|         """ | ||||
|  | ||||
|         try: | ||||
|             if "${" in cmd: | ||||
|                 # 此种情况下,{}内容不用管,只对${}内的内容进行format | ||||
|                 for i in range(len(cmd) - 1): | ||||
|                     if cmd[i] == "$" and cmd[i + 1] == "{": | ||||
|                         end = cmd.find("}", i) | ||||
|                         key = cmd[i + 2:end] | ||||
|                         cmd = cmd.replace(f"${{{key}}}", str(self.kwargs_data.get(key, ""))) | ||||
|             else: | ||||
|                 cmd = cmd.format(*self.args_data, **self.kwargs_data) | ||||
|         except Exception as e: | ||||
|             pass | ||||
|  | ||||
|         no_head = cmd.split(" ", 1)[1] if len(cmd.split(" ")) > 1 else "" | ||||
|         try: | ||||
|             head, cmd_args, cmd_kwargs = self.get_args(cmd) | ||||
|         except Exception as e: | ||||
|             error_msg = f"Parsing error in {self.name} at line {line}: {e}" | ||||
|             nonebot.logger.error(error_msg) | ||||
|             await self.matcher.send(error_msg) | ||||
|             return | ||||
|  | ||||
|         if head == "var": | ||||
|             # 变量定义 | ||||
|             self.kwargs_data.update(cmd_kwargs) | ||||
|  | ||||
|         elif head == "cmd": | ||||
|             # 在当前计算机上执行命令 | ||||
|             os.system(no_head) | ||||
|  | ||||
|         elif head == "api": | ||||
|             # 调用Bot API 需要Bot实例 | ||||
|             await self.bot.call_api(cmd_args[1], **cmd_kwargs) | ||||
|  | ||||
|         elif head == "function": | ||||
|             # 调用轻雪函数 | ||||
|             func = get_function(cmd_args[1]) | ||||
|             func.bot = self.bot | ||||
|             func.matcher = self.matcher | ||||
|             await func(*cmd_args[2:], **cmd_kwargs) | ||||
|  | ||||
|         elif head == "sleep": | ||||
|             # 等待一段时间 | ||||
|             await asyncio.sleep(float(cmd_args[1])) | ||||
|  | ||||
|         elif head == "nohup": | ||||
|             # 挂起运行 | ||||
|             task = asyncio.create_task(self.execute_line(no_head)) | ||||
|             self.sub_tasks.append(task) | ||||
|  | ||||
|         elif head == "end": | ||||
|             # 结束所有函数 | ||||
|             self.end = True | ||||
|             return 0 | ||||
|  | ||||
|  | ||||
|         elif head == "await": | ||||
|             # 等待所有协程执行完毕 | ||||
|             await asyncio.gather(*self.sub_tasks) | ||||
|  | ||||
|     def get_args(self, line: str) -> tuple[str, tuple[str, ...], dict[str, Any]]: | ||||
|         """ | ||||
|         获取参数 | ||||
|         Args: | ||||
|             line: 命令 | ||||
|         Returns: | ||||
|             命令头 参数 关键字 | ||||
|         """ | ||||
|         line = line.replace("\\=", "EQUAL_SIGN") | ||||
|         head = "" | ||||
|         args = list() | ||||
|         kwargs = dict() | ||||
|         for i, arg in enumerate(line.split(" ")): | ||||
|             if "=" in arg: | ||||
|                 key, value = arg.split("=", 1) | ||||
|                 value = value.replace("EQUAL_SIGN", "=") | ||||
|                 try: | ||||
|                     value = eval(value) | ||||
|                 except: | ||||
|                     value = self.kwargs_data.get(value, value) | ||||
|                 kwargs[key] = value | ||||
|             else: | ||||
|                 if i == 0: | ||||
|                     head = arg | ||||
|                 args.append(arg) | ||||
|         return head, tuple(args), kwargs | ||||
|  | ||||
|  | ||||
| def get_function(name: str) -> LiteyukiFunction | None: | ||||
|     """ | ||||
|     获取一个轻雪函数 | ||||
|     Args: | ||||
|         name: 函数名 | ||||
|     Returns: | ||||
|     """ | ||||
|     return loaded_functions.get(name) | ||||
|  | ||||
|  | ||||
| def load_from_dir(path: str): | ||||
|     """ | ||||
|     从目录及其子目录中递归加载所有轻雪函数,类似mcfunction | ||||
|  | ||||
|     Args: | ||||
|         path: 目录路径 | ||||
|     """ | ||||
|     for f in os.listdir(path): | ||||
|         f = os.path.join(path, f) | ||||
|         if os.path.isfile(f): | ||||
|             if f.endswith(ly_function_extensions): | ||||
|                 load_from_file(f) | ||||
|         if os.path.isdir(f): | ||||
|             load_from_dir(f) | ||||
|  | ||||
|  | ||||
| def load_from_file(path: str): | ||||
|     """ | ||||
|     从文件中加载轻雪函数 | ||||
|     Args: | ||||
|         path: | ||||
|     Returns: | ||||
|     """ | ||||
|     with open(path, "r", encoding="utf-8") as f: | ||||
|         name = ".".join(os.path.basename(path).split(".")[:-1]) | ||||
|         func = LiteyukiFunction(name) | ||||
|         for i, line in enumerate(f.read().split("\n")): | ||||
|             if line.startswith("#") or line.strip() == "": | ||||
|                 continue | ||||
|             func.functions.append(line) | ||||
|         loaded_functions[name] = func | ||||
|         nonebot.logger.debug(f"Loaded function {name}") | ||||
							
								
								
									
										8
									
								
								src/utils/base/ly_typing.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								src/utils/base/ly_typing.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,8 @@ | ||||
| from nonebot.adapters.onebot import v11, v12 | ||||
| from nonebot.adapters import satori | ||||
|  | ||||
| T_Bot = v11.Bot | v12.Bot | satori.Bot | ||||
| T_GroupMessageEvent = v11.GroupMessageEvent | v12.GroupMessageEvent | ||||
| T_PrivateMessageEvent = v11.PrivateMessageEvent | v12.PrivateMessageEvent | ||||
| T_MessageEvent = v11.MessageEvent | v12.MessageEvent | satori.MessageEvent | ||||
| T_Message = v11.Message | v12.Message | satori.Message | ||||
							
								
								
									
										5
									
								
								src/utils/base/permission.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								src/utils/base/permission.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| from nonebot.adapters.onebot import v11 | ||||
|  | ||||
| GROUP_ADMIN = v11.GROUP_ADMIN | ||||
| GROUP_OWNER = v11.GROUP_OWNER | ||||
|  | ||||
							
								
								
									
										61
									
								
								src/utils/base/reloader.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								src/utils/base/reloader.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,61 @@ | ||||
| import threading | ||||
| from multiprocessing import get_context | ||||
|  | ||||
| import nonebot | ||||
| from nonebot import logger | ||||
|  | ||||
| reboot_grace_time_limit: int = 20 | ||||
|  | ||||
| _nb_run = nonebot.run | ||||
|  | ||||
|  | ||||
| class Reloader: | ||||
|     event: threading.Event = None | ||||
|  | ||||
|     @classmethod | ||||
|     def reload(cls, delay: int = 0): | ||||
|         if cls.event is None: | ||||
|             raise RuntimeError() | ||||
|         if delay > 0: | ||||
|             threading.Timer(delay, function=cls.event.set).start() | ||||
|             return | ||||
|         cls.event.set() | ||||
|  | ||||
|  | ||||
| def _run(ev: threading.Event, *args, **kwargs): | ||||
|     Reloader.event = ev | ||||
|     _nb_run(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| def run(*args, **kwargs): | ||||
|     should_exit = False | ||||
|     ctx = get_context("spawn") | ||||
|     while not should_exit: | ||||
|         event = ctx.Event() | ||||
|         process = ctx.Process( | ||||
|             target=_run, | ||||
|             args=( | ||||
|                     event, | ||||
|                     *args, | ||||
|             ), | ||||
|             kwargs=kwargs, | ||||
|         ) | ||||
|         process.start() | ||||
|         while not should_exit: | ||||
|             if event.wait(1): | ||||
|                 logger.info("Receive reboot event") | ||||
|                 process.terminate() | ||||
|                 process.join(reboot_grace_time_limit) | ||||
|                 if process.is_alive(): | ||||
|                     logger.warning( | ||||
|                         f"Cannot shutdown gracefully in {reboot_grace_time_limit} second, force kill process." | ||||
|                     ) | ||||
|                     process.kill() | ||||
|                 break | ||||
|             elif process.is_alive(): | ||||
|                 continue | ||||
|             else: | ||||
|                 should_exit = True | ||||
|  | ||||
|  | ||||
| nonebot.run = run | ||||
							
								
								
									
										260
									
								
								src/utils/base/resource.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										260
									
								
								src/utils/base/resource.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,260 @@ | ||||
| import json | ||||
| import os | ||||
| import shutil | ||||
| import zipfile | ||||
| from typing import Any | ||||
|  | ||||
| import nonebot | ||||
| import yaml | ||||
|  | ||||
| from .data import LiteModel | ||||
| from .language import Language, get_default_lang_code | ||||
| from .ly_function import loaded_functions | ||||
|  | ||||
| _loaded_resource_packs: list["ResourceMetadata"] = []  # 按照加载顺序排序 | ||||
| temp_resource_root = "data/liteyuki/resources" | ||||
| temp_extract_root = "data/liteyuki/temp" | ||||
| lang = Language(get_default_lang_code()) | ||||
|  | ||||
|  | ||||
| class ResourceMetadata(LiteModel): | ||||
|     name: str = "Unknown" | ||||
|     version: str = "0.0.1" | ||||
|     description: str = "Unknown" | ||||
|     path: str = "" | ||||
|     folder: str = "" | ||||
|  | ||||
|  | ||||
| def load_resource_from_dir(path: str): | ||||
|     """ | ||||
|     把资源包按照文件相对路径复制到运行临时文件夹data/liteyuki/resources | ||||
|     Args: | ||||
|         path:  资源文件夹 | ||||
|     Returns: | ||||
|     """ | ||||
|     if os.path.exists(os.path.join(path, "metadata.yml")): | ||||
|         with open(os.path.join(path, "metadata.yml"), "r", encoding="utf-8") as f: | ||||
|             metadata = yaml.safe_load(f) | ||||
|     elif os.path.isfile(path) and path.endswith(".zip"): | ||||
|         # zip文件 | ||||
|         # 临时解压并读取metadata.yml | ||||
|         with zipfile.ZipFile(path, "r") as zip_ref: | ||||
|             # 解压至临时目录 data/liteyuki/temp/{pack_name}.zip | ||||
|             zip_ref.extractall(os.path.join(temp_extract_root, os.path.basename(path))) | ||||
|             with zip_ref.open("metadata.yml") as f: | ||||
|                 metadata = yaml.safe_load(f) | ||||
|         path = os.path.join(temp_extract_root, os.path.basename(path)) | ||||
|     else: | ||||
|         # 没有metadata.yml文件,不是一个资源包 | ||||
|         return | ||||
|     for root, dirs, files in os.walk(path): | ||||
|         for file in files: | ||||
|             relative_path = os.path.relpath(os.path.join(root, file), path) | ||||
|             copy_file(os.path.join(root, file), os.path.join(temp_resource_root, relative_path)) | ||||
|     metadata["path"] = path | ||||
|     metadata["folder"] = os.path.basename(path) | ||||
|  | ||||
|     if os.path.exists(os.path.join(path, "lang")): | ||||
|         # 加载语言 | ||||
|         from src.utils.base.language import load_from_dir | ||||
|         load_from_dir(os.path.join(path, "lang")) | ||||
|  | ||||
|     if os.path.exists(os.path.join(path, "functions")): | ||||
|         # 加载功能 | ||||
|         from src.utils.base.ly_function import load_from_dir | ||||
|         load_from_dir(os.path.join(path, "functions")) | ||||
|  | ||||
|     if os.path.exists(os.path.join(path, "word_bank")): | ||||
|         # 加载词库 | ||||
|         from src.utils.base.word_bank import load_from_dir | ||||
|         load_from_dir(os.path.join(path, "word_bank")) | ||||
|  | ||||
|     _loaded_resource_packs.insert(0, ResourceMetadata(**metadata)) | ||||
|  | ||||
|  | ||||
| def get_path(path: str, abs_path: bool = True, default: Any = None, debug: bool = False) -> str | Any: | ||||
|     """ | ||||
|     获取资源包中的文件 | ||||
|     Args: | ||||
|         debug: 启用调试,每次都会先重载资源 | ||||
|         abs_path: 是否返回绝对路径 | ||||
|         default: 默认 | ||||
|         path: 文件相对路径 | ||||
|     Returns: 文件绝对路径 | ||||
|     """ | ||||
|     if debug: | ||||
|         nonebot.logger.debug("Enable resource debug, Reloading resources") | ||||
|         load_resources() | ||||
|     resource_relative_path = os.path.join(temp_resource_root, path) | ||||
|     if os.path.exists(resource_relative_path): | ||||
|         return os.path.abspath(resource_relative_path) if abs_path else resource_relative_path | ||||
|     else: | ||||
|         return default | ||||
|  | ||||
|  | ||||
| def get_files(path: str, abs_path: bool = False) -> list[str]: | ||||
|     """ | ||||
|     获取资源包中一个文件夹的所有文件 | ||||
|     Args: | ||||
|         abs_path: | ||||
|         path: 文件夹相对路径 | ||||
|     Returns: 文件绝对路径 | ||||
|     """ | ||||
|     resource_relative_path = os.path.join(temp_resource_root, path) | ||||
|     if os.path.exists(resource_relative_path): | ||||
|         return [os.path.abspath(os.path.join(resource_relative_path, file)) if abs_path else os.path.join(resource_relative_path, file) for file in | ||||
|                 os.listdir(resource_relative_path)] | ||||
|     else: | ||||
|         return [] | ||||
|  | ||||
|  | ||||
| def get_loaded_resource_packs() -> list[ResourceMetadata]: | ||||
|     """ | ||||
|     获取已加载的资源包,优先级从前到后 | ||||
|     Returns: 资源包列表 | ||||
|     """ | ||||
|     return _loaded_resource_packs | ||||
|  | ||||
|  | ||||
| def copy_file(src, dst): | ||||
|     # 获取目标文件的目录 | ||||
|     dst_dir = os.path.dirname(dst) | ||||
|     # 如果目标目录不存在,创建它 | ||||
|     if not os.path.exists(dst_dir): | ||||
|         os.makedirs(dst_dir) | ||||
|     # 复制文件 | ||||
|     shutil.copy(src, dst) | ||||
|  | ||||
|  | ||||
| def load_resources(): | ||||
|     """用于外部主程序调用的资源加载函数 | ||||
|     Returns: | ||||
|     """ | ||||
|     # 加载默认资源和语言 | ||||
|     # 清空临时资源包路径data/liteyuki/resources | ||||
|     _loaded_resource_packs.clear() | ||||
|     loaded_functions.clear() | ||||
|     if os.path.exists(temp_resource_root): | ||||
|         shutil.rmtree(temp_resource_root) | ||||
|     os.makedirs(temp_resource_root, exist_ok=True) | ||||
|  | ||||
|     # 加载内置资源 | ||||
|     standard_resources_path = "src/resources" | ||||
|     for resource_dir in os.listdir(standard_resources_path): | ||||
|         load_resource_from_dir(os.path.join(standard_resources_path, resource_dir)) | ||||
|  | ||||
|     # 加载其他资源包 | ||||
|     if not os.path.exists("resources"): | ||||
|         os.makedirs("resources", exist_ok=True) | ||||
|  | ||||
|     if not os.path.exists("resources/index.json"): | ||||
|         json.dump([], open("resources/index.json", "w", encoding="utf-8")) | ||||
|  | ||||
|     resource_index: list[str] = json.load(open("resources/index.json", "r", encoding="utf-8")) | ||||
|     resource_index.reverse()  # 优先级高的后加载,但是排在前面 | ||||
|     for resource in resource_index: | ||||
|         load_resource_from_dir(os.path.join("resources", resource)) | ||||
|  | ||||
|  | ||||
| def check_status(name: str) -> bool: | ||||
|     """ | ||||
|     检查资源包是否已加载 | ||||
|     Args: | ||||
|         name: 资源包名称,文件夹名 | ||||
|     Returns: 是否已加载 | ||||
|     """ | ||||
|     return name in [rp.folder for rp in get_loaded_resource_packs()] | ||||
|  | ||||
|  | ||||
| def check_exist(name: str) -> bool: | ||||
|     """ | ||||
|     检查资源包文件夹是否存在于resources文件夹 | ||||
|     Args: | ||||
|         name: 资源包名称,文件夹名 | ||||
|     Returns: 是否存在 | ||||
|     """ | ||||
|     path = os.path.join("resources", name) | ||||
|     return os.path.exists(os.path.join(path, "metadata.yml")) or (os.path.isfile(path) and name.endswith(".zip")) | ||||
|  | ||||
|  | ||||
| def add_resource_pack(name: str) -> bool: | ||||
|     """ | ||||
|     添加资源包,该操作仅修改index.json文件,不会加载资源包,要生效请重载资源 | ||||
|     Args: | ||||
|         name: 资源包名称,文件夹名 | ||||
|     Returns: | ||||
|     """ | ||||
|     if check_exist(name): | ||||
|         old_index: list[str] = json.load(open("resources/index.json", "r", encoding="utf-8")) | ||||
|         if name not in old_index: | ||||
|             old_index.append(name) | ||||
|             json.dump(old_index, open("resources/index.json", "w", encoding="utf-8")) | ||||
|             load_resource_from_dir(os.path.join("resources", name)) | ||||
|             return True | ||||
|         else: | ||||
|             nonebot.logger.warning(lang.get("liteyuki.resource_loaded", name=name)) | ||||
|             return False | ||||
|     else: | ||||
|         nonebot.logger.warning(lang.get("liteyuki.resource_not_exist", name=name)) | ||||
|         return False | ||||
|  | ||||
|  | ||||
| def remove_resource_pack(name: str) -> bool: | ||||
|     """ | ||||
|     移除资源包,该操作仅修改加载索引,要生效请重载资源 | ||||
|     Args: | ||||
|         name: 资源包名称,文件夹名 | ||||
|     Returns: | ||||
|     """ | ||||
|     if check_exist(name): | ||||
|         old_index: list[str] = json.load(open("resources/index.json", "r", encoding="utf-8")) | ||||
|         if name in old_index: | ||||
|             old_index.remove(name) | ||||
|             json.dump(old_index, open("resources/index.json", "w", encoding="utf-8")) | ||||
|             return True | ||||
|         else: | ||||
|             nonebot.logger.warning(lang.get("liteyuki.resource_not_loaded", name=name)) | ||||
|             return False | ||||
|     else: | ||||
|         nonebot.logger.warning(lang.get("liteyuki.resource_not_exist", name=name)) | ||||
|         return False | ||||
|  | ||||
|  | ||||
| def change_priority(name: str, delta: int) -> bool: | ||||
|     """ | ||||
|     修改资源包优先级 | ||||
|     Args: | ||||
|         name: 资源包名称,文件夹名 | ||||
|         delta: 优先级变化,正数表示后移,负数表示前移,0表示移到最前 | ||||
|     Returns: | ||||
|     """ | ||||
|     # 正数表示前移,负数表示后移 | ||||
|     old_resource_list: list[str] = json.load(open("resources/index.json", "r", encoding="utf-8")) | ||||
|     new_resource_list = old_resource_list.copy() | ||||
|     if name in old_resource_list: | ||||
|         index = old_resource_list.index(name) | ||||
|         if 0 <= index + delta < len(old_resource_list): | ||||
|             new_index = index + delta | ||||
|             new_resource_list.remove(name) | ||||
|             new_resource_list.insert(new_index, name) | ||||
|             json.dump(new_resource_list, open("resources/index.json", "w", encoding="utf-8")) | ||||
|             return True | ||||
|         else: | ||||
|             nonebot.logger.warning("Priority change failed, out of range") | ||||
|             return False | ||||
|     else: | ||||
|         nonebot.logger.debug("Priority change failed, resource not loaded") | ||||
|         return False | ||||
|  | ||||
|  | ||||
| def get_resource_metadata(name: str) -> ResourceMetadata: | ||||
|     """ | ||||
|     获取资源包元数据 | ||||
|     Args: | ||||
|         name: 资源包名称,文件夹名 | ||||
|     Returns: | ||||
|     """ | ||||
|     for rp in get_loaded_resource_packs(): | ||||
|         if rp.folder == name: | ||||
|             return rp | ||||
|     return ResourceMetadata() | ||||
							
								
								
									
										57
									
								
								src/utils/base/word_bank.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								src/utils/base/word_bank.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,57 @@ | ||||
| import json | ||||
| import os | ||||
| import random | ||||
| from typing import Iterable | ||||
|  | ||||
| import nonebot | ||||
|  | ||||
| word_bank: dict[str, set[str]] = {} | ||||
|  | ||||
|  | ||||
| def load_from_file(file_path: str): | ||||
|     """ | ||||
|     从json文件中加载词库 | ||||
|  | ||||
|     Args: | ||||
|         file_path: 文件路径 | ||||
|     """ | ||||
|     with open(file_path, "r", encoding="utf-8") as file: | ||||
|         data = json.load(file) | ||||
|         for key, value_list in data.items(): | ||||
|             if key not in word_bank: | ||||
|                 word_bank[key] = set() | ||||
|             word_bank[key].update(value_list) | ||||
|  | ||||
|     nonebot.logger.debug(f"Loaded word bank from {file_path}") | ||||
|  | ||||
|  | ||||
| def load_from_dir(dir_path: str): | ||||
|     """ | ||||
|     从目录中加载词库 | ||||
|  | ||||
|     Args: | ||||
|         dir_path: 目录路径 | ||||
|     """ | ||||
|     for file in os.listdir(dir_path): | ||||
|         try: | ||||
|             file_path = os.path.join(dir_path, file) | ||||
|             if os.path.isfile(file_path): | ||||
|                 if file.endswith(".json"): | ||||
|                     load_from_file(file_path) | ||||
|         except Exception as e: | ||||
|             nonebot.logger.error(f"Failed to load language data from {file}: {e}") | ||||
|             continue | ||||
|  | ||||
|  | ||||
| def get_reply(kws: Iterable[str]) -> str | None: | ||||
|     """ | ||||
|     获取回复 | ||||
|     Args: | ||||
|         kws: 关键词 | ||||
|     Returns: | ||||
|     """ | ||||
|     for kw in kws: | ||||
|         if kw in word_bank: | ||||
|             return random.choice(list(word_bank[kw])) | ||||
|  | ||||
|     return None | ||||
							
								
								
									
										1
									
								
								src/utils/canvas/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								src/utils/canvas/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| from PIL import Image, ImageDraw, ImageFont | ||||
							
								
								
									
										6
									
								
								src/utils/driver_manager/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								src/utils/driver_manager/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | ||||
| from .auto_set_env import auto_set_env | ||||
|  | ||||
|  | ||||
| def init(config: dict): | ||||
|     auto_set_env(config) | ||||
|     return | ||||
							
								
								
									
										21
									
								
								src/utils/driver_manager/auto_set_env.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								src/utils/driver_manager/auto_set_env.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | ||||
| import os | ||||
|  | ||||
| import dotenv | ||||
| import nonebot | ||||
|  | ||||
| from .defines import * | ||||
|  | ||||
|  | ||||
| def auto_set_env(config: dict): | ||||
|     dotenv.load_dotenv(".env") | ||||
|     if os.getenv("DRIVER", None) is not None: | ||||
|         print(os.getenv("DRIVER")) | ||||
|         nonebot.logger.info("Driver already set in environment variable, skip auto configure.") | ||||
|         return | ||||
|     if config.get("satori", {'enable': False}).get("enable", False): | ||||
|         os.environ["DRIVER"] = get_driver_string(ASGI_DRIVER, HTTPX_DRIVER, WEBSOCKETS_DRIVER) | ||||
|         nonebot.logger.info("Enable Satori, set driver to ASGI+HTTPX+WEBSOCKETS") | ||||
|     else: | ||||
|         os.environ["DRIVER"] = get_driver_string(ASGI_DRIVER) | ||||
|         nonebot.logger.info("Disable Satori, set driver to ASGI") | ||||
|     return | ||||
							
								
								
									
										17
									
								
								src/utils/driver_manager/defines.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								src/utils/driver_manager/defines.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | ||||
| ASGI_DRIVER = "~fastapi" | ||||
| HTTPX_DRIVER = "~httpx" | ||||
| WEBSOCKETS_DRIVER = "~websockets" | ||||
|  | ||||
|  | ||||
| def get_driver_string(*argv): | ||||
|     output_string = "" | ||||
|     if ASGI_DRIVER in argv: | ||||
|         output_string += ASGI_DRIVER | ||||
|     for arg in argv: | ||||
|         if arg != ASGI_DRIVER: | ||||
|             output_string = f"{output_string}+{arg}" | ||||
|     return output_string | ||||
|  | ||||
|  | ||||
| def get_driver_full_string(*argv): | ||||
|     return f"DRIVER={get_driver_string(argv)}" | ||||
							
								
								
									
										1
									
								
								src/utils/event/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								src/utils/event/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| from .get_info import * | ||||
							
								
								
									
										24
									
								
								src/utils/event/get_info.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								src/utils/event/get_info.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | ||||
| from nonebot.adapters import satori | ||||
|  | ||||
| from src.utils.base.ly_typing import T_MessageEvent | ||||
|  | ||||
|  | ||||
| def get_user_id(event: T_MessageEvent): | ||||
|     if isinstance(event, satori.event.Event): | ||||
|         return event.user.id | ||||
|     else: | ||||
|         return event.user_id | ||||
|  | ||||
|  | ||||
| def get_group_id(event: T_MessageEvent): | ||||
|     if isinstance(event, satori.event.Event): | ||||
|         return event.guild.id | ||||
|     else: | ||||
|         return event.group_id | ||||
|  | ||||
|  | ||||
| def get_message_type(event: T_MessageEvent) -> str: | ||||
|     if isinstance(event, satori.event.Event): | ||||
|         return "private" if event.guild is None else "group" | ||||
|     else: | ||||
|         return event.message_type | ||||
							
								
								
									
										0
									
								
								src/utils/external/__init__.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/utils/external/__init__.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
								
								
									
										40
									
								
								src/utils/external/logo.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								src/utils/external/logo.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,40 @@ | ||||
| async def get_user_icon(platform: str, user_id: str) -> str: | ||||
|     """ | ||||
|     获取用户头像 | ||||
|     Args: | ||||
|         platform: qq, telegram, discord... | ||||
|         user_id: 1234567890 | ||||
|  | ||||
|     Returns: | ||||
|         str: 头像链接 | ||||
|     """ | ||||
|     match platform: | ||||
|         case "qq": | ||||
|             return f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=640" | ||||
|         case "telegram": | ||||
|             return f"https://t.me/i/userpic/320/{user_id}.jpg" | ||||
|         case "discord": | ||||
|             return f"https://cdn.discordapp.com/avatars/{user_id}/" | ||||
|         case _: | ||||
|             return "" | ||||
|  | ||||
|  | ||||
| async def get_group_icon(platform: str, group_id: str) -> str: | ||||
|     """ | ||||
|     获取群组头像 | ||||
|     Args: | ||||
|         platform: qq, telegram, discord... | ||||
|         group_id: 1234567890 | ||||
|  | ||||
|     Returns: | ||||
|         str: 头像链接 | ||||
|     """ | ||||
|     match platform: | ||||
|         case "qq": | ||||
|             return f"http://p.qlogo.cn/gh/{group_id}/{group_id}/640" | ||||
|         case "telegram": | ||||
|             return f"https://t.me/c/{group_id}/" | ||||
|         case "discord": | ||||
|             return f"https://cdn.discordapp.com/icons/{group_id}/" | ||||
|         case _: | ||||
|             return "" | ||||
							
								
								
									
										0
									
								
								src/utils/message/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/utils/message/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										113
									
								
								src/utils/message/html_tool.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										113
									
								
								src/utils/message/html_tool.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,113 @@ | ||||
| import os.path | ||||
| import time | ||||
| from os import getcwd | ||||
|  | ||||
| import aiofiles | ||||
| import nonebot | ||||
| from nonebot_plugin_htmlrender import * | ||||
| from .tools import random_hex_string | ||||
|  | ||||
|  | ||||
| async def html2image( | ||||
|         html: str, | ||||
|         wait: int = 0, | ||||
| ): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| async def template2html( | ||||
|         template: str, | ||||
|         templates: dict, | ||||
| ) -> str: | ||||
|     """ | ||||
|     Args: | ||||
|         template: str: 模板文件 | ||||
|         **templates: dict: 模板参数 | ||||
|     Returns: | ||||
|         HTML 正文 | ||||
|     """ | ||||
|     template_path = os.path.dirname(template) | ||||
|     template_name = os.path.basename(template) | ||||
|     return await template_to_html(template_path, template_name, **templates) | ||||
|  | ||||
|  | ||||
| async def template2image( | ||||
|         template: str, | ||||
|         templates: dict, | ||||
|         pages=None, | ||||
|         wait: int = 0, | ||||
|         scale_factor: float = 1, | ||||
|         debug: bool = False, | ||||
| ) -> bytes: | ||||
|     """ | ||||
|     template -> html -> image | ||||
|     Args: | ||||
|         debug: 输入渲染好的 html | ||||
|         wait: 等待时间,单位秒 | ||||
|         pages: 页面参数 | ||||
|         template: str: 模板文件 | ||||
|         templates: dict: 模板参数 | ||||
|         scale_factor: 缩放因子,越高越清晰 | ||||
|     Returns: | ||||
|         图片二进制数据 | ||||
|     """ | ||||
|     if pages is None: | ||||
|         pages = { | ||||
|                 "viewport": { | ||||
|                         "width" : 1080, | ||||
|                         "height": 10 | ||||
|                 }, | ||||
|                 "base_url": f"file://{getcwd()}", | ||||
|         } | ||||
|     template_path = os.path.dirname(template) | ||||
|     template_name = os.path.basename(template) | ||||
|  | ||||
|     if debug: | ||||
|         # 重载资源 | ||||
|         raw_html = await template_to_html( | ||||
|             template_name=template_name, | ||||
|             template_path=template_path, | ||||
|             **templates, | ||||
|         ) | ||||
|         random_file_name = f"debug-{random_hex_string(6)}.html" | ||||
|         async with aiofiles.open(os.path.join(template_path, random_file_name), "w", encoding="utf-8") as f: | ||||
|             await f.write(raw_html) | ||||
|         nonebot.logger.info("Debug HTML: %s" % f"{random_file_name}") | ||||
|  | ||||
|     return await template_to_pic( | ||||
|         template_name=template_name, | ||||
|         template_path=template_path, | ||||
|         templates=templates, | ||||
|         pages=pages, | ||||
|         wait=wait, | ||||
|         device_scale_factor=scale_factor, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| async def url2image( | ||||
|         url: str, | ||||
|         wait: int = 0, | ||||
|         scale_factor: float = 1, | ||||
|         type: str = "png", | ||||
|         quality: int = 100, | ||||
|         **kwargs | ||||
| ) -> bytes: | ||||
|     """ | ||||
|     Args: | ||||
|         quality: | ||||
|         type: | ||||
|         url: str: URL | ||||
|         wait: int: 等待时间 | ||||
|         scale_factor: float: 缩放因子 | ||||
|         **kwargs: page 参数 | ||||
|     Returns: | ||||
|         图片二进制数据 | ||||
|     """ | ||||
|     async with get_new_page(scale_factor) as page: | ||||
|         await page.goto(url) | ||||
|         await page.wait_for_timeout(wait) | ||||
|         return await page.screenshot( | ||||
|             full_page=True, | ||||
|             type=type, | ||||
|             quality=quality | ||||
|         ) | ||||
							
								
								
									
										209
									
								
								src/utils/message/markdown.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										209
									
								
								src/utils/message/markdown.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,209 @@ | ||||
| import base64 | ||||
| from io import BytesIO | ||||
| from urllib.parse import quote | ||||
|  | ||||
| import aiohttp | ||||
| from PIL import Image | ||||
|  | ||||
| from ..base.config import get_config | ||||
| from ..base.data import LiteModel | ||||
| from ..base.ly_typing import T_Bot | ||||
|  | ||||
|  | ||||
| def escape_md(text: str) -> str: | ||||
|     """ | ||||
|     转义Markdown特殊字符 | ||||
|     Args: | ||||
|         text: str: 文本 | ||||
|  | ||||
|     Returns: | ||||
|         str: 转义后文本 | ||||
|     """ | ||||
|     spacial_chars = r"\`*_{}[]()#+-.!" | ||||
|     for char in spacial_chars: | ||||
|         text = text.replace(char, "\\\\" + char) | ||||
|     return text.replace("\n", r"\n").replace('"', r'\\\"') | ||||
|  | ||||
|  | ||||
| def escape_decorator(func): | ||||
|     def wrapper(text: str): | ||||
|         return func(escape_md(text)) | ||||
|  | ||||
|     return wrapper | ||||
|  | ||||
|  | ||||
| def compile_md(comps: list[str]) -> str: | ||||
|     """ | ||||
|     合成Markdown文本 | ||||
|     Args: | ||||
|         comps: list[str]: 组件列表 | ||||
|  | ||||
|     Returns: | ||||
|         str: 编译后文本 | ||||
|     """ | ||||
|     return "".join(comps) | ||||
|  | ||||
|  | ||||
| class MarkdownComponent: | ||||
|     @staticmethod | ||||
|     def heading(text: str, level: int = 1) -> str: | ||||
|         """标题""" | ||||
|         assert 1 <= level <= 6, "标题级别应在 1-6 之间" | ||||
|         return f"{'#' * level} {text}\n" | ||||
|  | ||||
|     @staticmethod | ||||
|     def bold(text: str) -> str: | ||||
|         """粗体""" | ||||
|         return f"**{text}**" | ||||
|  | ||||
|     @staticmethod | ||||
|     def italic(text: str) -> str: | ||||
|         """斜体""" | ||||
|         return f"*{text}*" | ||||
|  | ||||
|     @staticmethod | ||||
|     def strike(text: str) -> str: | ||||
|         """删除线""" | ||||
|         return f"~~{text}~~" | ||||
|  | ||||
|     @staticmethod | ||||
|     def code(text: str) -> str: | ||||
|         """行内代码""" | ||||
|         return f"`{text}`" | ||||
|  | ||||
|     @staticmethod | ||||
|     def code_block(text: str, language: str = "") -> str: | ||||
|         """代码块""" | ||||
|         return f"```{language}\n{text}\n```\n" | ||||
|  | ||||
|     @staticmethod | ||||
|     def quote(text: str) -> str: | ||||
|         """引用""" | ||||
|         return f"> {text}\n\n" | ||||
|  | ||||
|     @staticmethod | ||||
|     def link(text: str, url: str, symbol: bool = True) -> str: | ||||
|         """ | ||||
|         链接 | ||||
|  | ||||
|         Args: | ||||
|             text: 链接文本 | ||||
|             url: 链接地址 | ||||
|             symbol: 是否显示链接图标, mqqapi请使用False | ||||
|         """ | ||||
|         return f"[{'🔗' if symbol else ''}{text}]({url})" | ||||
|  | ||||
|     @staticmethod | ||||
|     def image(url: str, *, size: tuple[int, int]) -> str: | ||||
|         """ | ||||
|         图片,本地图片不建议直接使用 | ||||
|         Args: | ||||
|             url: 图片链接 | ||||
|             size: 图片大小 | ||||
|  | ||||
|         Returns: | ||||
|             markdown格式的图片 | ||||
|         """ | ||||
|         return f"![image #{size[0]}px #{size[1]}px]({url})" | ||||
|  | ||||
|     @staticmethod | ||||
|     async def auto_image(image: str | bytes, bot: T_Bot) -> str: | ||||
|         """ | ||||
|         自动获取图片大小 | ||||
|         Args: | ||||
|             image: 本地图片路径 | 图片url http/file | 图片bytes | ||||
|             bot: bot对象,用于上传图片到图床 | ||||
|  | ||||
|         Returns: | ||||
|             markdown格式的图片 | ||||
|         """ | ||||
|         if isinstance(image, bytes): | ||||
|             # 传入为二进制图片 | ||||
|             image_obj = Image.open(BytesIO(image)) | ||||
|             base64_string = base64.b64encode(image_obj.tobytes()).decode("utf-8") | ||||
|             url = await bot.call_api("upload_image", file=f"base64://{base64_string}") | ||||
|             size = image_obj.size | ||||
|         elif isinstance(image, str): | ||||
|             # 传入链接或本地路径 | ||||
|             if image.startswith("http"): | ||||
|                 # 网络请求 | ||||
|                 async with aiohttp.ClientSession() as session: | ||||
|                     async with session.get(image) as resp: | ||||
|                         image_data = await resp.read() | ||||
|                 url = image | ||||
|                 size = Image.open(BytesIO(image_data)).size | ||||
|  | ||||
|             else: | ||||
|                 # 本地路径/file:// | ||||
|                 image_obj = Image.open(image.replace("file://", "")) | ||||
|                 base64_string = base64.b64encode(image_obj.tobytes()).decode("utf-8") | ||||
|                 url = await bot.call_api("upload_image", file=f"base64://{base64_string}") | ||||
|                 size = image_obj.size | ||||
|         else: | ||||
|             raise ValueError("图片类型错误") | ||||
|  | ||||
|         return MarkdownComponent.image(url, size=size) | ||||
|  | ||||
|     @staticmethod | ||||
|     def table(data: list[list[any]]) -> str: | ||||
|         """ | ||||
|         表格 | ||||
|         Args: | ||||
|             data: 表格数据,二维列表 | ||||
|         Returns: | ||||
|             markdown格式的表格 | ||||
|         """ | ||||
|         # 表头 | ||||
|         table = "|".join(map(str, data[0])) + "\n" | ||||
|         table += "|".join([":-:" for _ in range(len(data[0]))]) + "\n" | ||||
|         # 表内容 | ||||
|         for row in data[1:]: | ||||
|             table += "|".join(map(str, row)) + "\n" | ||||
|         return table | ||||
|  | ||||
|     @staticmethod | ||||
|     def paragraph(text: str) -> str: | ||||
|         """ | ||||
|         段落 | ||||
|         Args: | ||||
|             text: 段落内容 | ||||
|         Returns: | ||||
|             markdown格式的段落 | ||||
|         """ | ||||
|         return f"{text}\n" | ||||
|  | ||||
|  | ||||
| class Mqqapi: | ||||
|     @staticmethod | ||||
|     @escape_decorator | ||||
|     def cmd(text: str, cmd: str, enter: bool = True, reply: bool = False, use_cmd_start: bool = True) -> str: | ||||
|         """ | ||||
|         生成点击回调文本 | ||||
|         Args: | ||||
|             text: 显示内容 | ||||
|             cmd: 命令 | ||||
|             enter: 是否自动发送 | ||||
|             reply: 是否回复 | ||||
|             use_cmd_start: 是否使用配置的命令前缀 | ||||
|  | ||||
|         Returns: | ||||
|             [text](mqqapi://)   markdown格式的可点击回调文本,类似于链接 | ||||
|         """ | ||||
|  | ||||
|         if use_cmd_start: | ||||
|             command_start = get_config("command_start", []) | ||||
|             if command_start: | ||||
|                 # 若命令前缀不为空,则使用配置的第一个命令前缀 | ||||
|                 cmd = f"{command_start[0]}{cmd}" | ||||
|         return f"[{text}](mqqapi://aio/inlinecmd?command={quote(cmd)}&reply={str(reply).lower()}&enter={str(enter).lower()})" | ||||
|  | ||||
|  | ||||
| class RenderData(LiteModel): | ||||
|     label: str | ||||
|     visited_label: str | ||||
|     style: int | ||||
|  | ||||
|  | ||||
| class Button(LiteModel): | ||||
|     id: int | ||||
|     render_data: RenderData | ||||
							
								
								
									
										297
									
								
								src/utils/message/message.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										297
									
								
								src/utils/message/message.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,297 @@ | ||||
| import base64 | ||||
| import io | ||||
| from urllib.parse import quote | ||||
|  | ||||
| import aiofiles | ||||
| from PIL import Image | ||||
| import aiohttp | ||||
| import nonebot | ||||
| from nonebot import require | ||||
| from nonebot.adapters import satori | ||||
| from nonebot.adapters.onebot import v11 | ||||
| from typing import Any, Type | ||||
|  | ||||
| from nonebot.internal.adapter import MessageSegment | ||||
| from nonebot.internal.adapter.message import TM | ||||
|  | ||||
| from .. import load_from_yaml | ||||
| from ..base.ly_typing import T_Bot, T_Message, T_MessageEvent | ||||
|  | ||||
| require("nonebot_plugin_htmlrender") | ||||
| from nonebot_plugin_htmlrender import md_to_pic | ||||
|  | ||||
| config = load_from_yaml("config.yml") | ||||
|  | ||||
| can_send_markdown = {}  # 用于存储机器人是否支持发送markdown消息,id->bool | ||||
|  | ||||
|  | ||||
| class TencentBannedMarkdownError(BaseException): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| async def broadcast_to_superusers(message: str | T_Message, markdown: bool = False): | ||||
|     """广播消息给超级用户""" | ||||
|     for bot in nonebot.get_bots().values(): | ||||
|         for user_id in config.get("superusers", []): | ||||
|             if markdown: | ||||
|                 await MarkdownMessage.send_md(message, bot, message_type="private", session_id=user_id) | ||||
|             else: | ||||
|                 await bot.send_private_msg(user_id=user_id, message=message) | ||||
|  | ||||
|  | ||||
| class MarkdownMessage: | ||||
|     @staticmethod | ||||
|     async def send_md( | ||||
|             markdown: str, | ||||
|             bot: T_Bot, *, | ||||
|             message_type: str = None, | ||||
|             session_id: str | int = None, | ||||
|             event: T_MessageEvent = None, | ||||
|             retry_as_image: bool = True, | ||||
|             **kwargs | ||||
|     ) -> dict[str, Any] | None: | ||||
|         """ | ||||
|         发送Markdown消息,支持自动转为图片发送 | ||||
|         Args: | ||||
|             markdown: | ||||
|             bot: | ||||
|             message_type: | ||||
|             session_id: | ||||
|             event: | ||||
|             retry_as_image: 发送失败后是否尝试以图片形式发送,否则失败返回None | ||||
|             **kwargs: | ||||
|  | ||||
|         Returns: | ||||
|  | ||||
|         """ | ||||
|         formatted_md = v11.unescape(markdown).replace("\n", r"\n").replace('"', r'\\\"') | ||||
|         if event is not None and message_type is None: | ||||
|             if isinstance(event, satori.event.Event): | ||||
|                 message_type = "private" if event.guild is None else "group" | ||||
|                 group_id = event.guild.id if event.guild is not None else None | ||||
|             else: | ||||
|                 assert event is not None | ||||
|                 message_type = event.message_type | ||||
|                 group_id = event.group_id if message_type == "group" else None | ||||
|             user_id = event.user.id if isinstance(event, satori.event.Event) else event.user_id | ||||
|             session_id = user_id if message_type == "private" else group_id | ||||
|         else: | ||||
|             pass | ||||
|         try: | ||||
|             raise TencentBannedMarkdownError("Tencent banned markdown") | ||||
|             forward_id = await bot.call_api( | ||||
|                 "send_private_forward_msg", | ||||
|                 messages=[ | ||||
|                     { | ||||
|                         "type": "node", | ||||
|                         "data": { | ||||
|                             "content": [ | ||||
|                                 { | ||||
|                                     "data": { | ||||
|                                         "content": "{\"content\":\"%s\"}" % formatted_md, | ||||
|                                     }, | ||||
|                                     "type": "markdown" | ||||
|                                 } | ||||
|                             ], | ||||
|                             "name": "[]", | ||||
|                             "uin": bot.self_id | ||||
|                         } | ||||
|                     } | ||||
|                 ], | ||||
|                 user_id=bot.self_id | ||||
|  | ||||
|             ) | ||||
|             data = await bot.send_msg( | ||||
|                 user_id=session_id, | ||||
|                 group_id=session_id, | ||||
|                 message_type=message_type, | ||||
|                 message=[ | ||||
|                     { | ||||
|                         "type": "longmsg", | ||||
|                         "data": { | ||||
|                             "id": forward_id | ||||
|                         } | ||||
|                     }, | ||||
|                 ], | ||||
|                 **kwargs | ||||
|             ) | ||||
|         except BaseException as e: | ||||
|             nonebot.logger.error(f"send markdown error, retry as image: {e}") | ||||
|             # 发送失败,渲染为图片发送 | ||||
|             # if not retry_as_image: | ||||
|             #     return None | ||||
|  | ||||
|             plain_markdown = markdown.replace("[🔗", "[") | ||||
|             md_image_bytes = await md_to_pic( | ||||
|                 md=plain_markdown, | ||||
|                 width=540, | ||||
|                 device_scale_factor=4 | ||||
|             ) | ||||
|             if isinstance(bot, satori.Bot): | ||||
|                 msg_seg = satori.MessageSegment.image(raw=md_image_bytes,mime="image/png") | ||||
|                 data = await bot.send( | ||||
|                     event=event, | ||||
|                     message=msg_seg | ||||
|                 ) | ||||
|             else: | ||||
|                 data = await bot.send_msg( | ||||
|                     message_type=message_type, | ||||
|                     group_id=session_id, | ||||
|                     user_id=session_id, | ||||
|                     message=v11.MessageSegment.image(md_image_bytes), | ||||
|                 ) | ||||
|         return data | ||||
|  | ||||
|     @staticmethod | ||||
|     async def send_image( | ||||
|             image: bytes | str, | ||||
|             bot: T_Bot, *, | ||||
|             message_type: str = None, | ||||
|             session_id: str | int = None, | ||||
|             event: T_MessageEvent = None, | ||||
|             **kwargs | ||||
|     ) -> dict: | ||||
|         """ | ||||
|         发送单张装逼大图 | ||||
|         Args: | ||||
|             image: 图片字节流或图片本地路径,链接请使用Markdown.image_async方法获取后通过send_md发送 | ||||
|             bot: bot instance | ||||
|             message_type: message type | ||||
|             session_id: session id | ||||
|             event: event | ||||
|             kwargs: other arguments | ||||
|         Returns: | ||||
|             dict: response data | ||||
|  | ||||
|         """ | ||||
|         if isinstance(image, str): | ||||
|             async with aiofiles.open(image, "rb") as f: | ||||
|                 image = await f.read() | ||||
|         method = 2 | ||||
|         # 1.轻雪图床方案 | ||||
|         # if method == 1: | ||||
|         #     image_url = await liteyuki_api.upload_image(image) | ||||
|         #     image_size = Image.open(io.BytesIO(image)).size | ||||
|         #     image_md = Markdown.image(image_url, image_size) | ||||
|         #     data = await Markdown.send_md(image_md, bot, message_type=message_type, session_id=session_id, event=event, | ||||
|         #                                   retry_as_image=False, | ||||
|         #                                   **kwargs) | ||||
|  | ||||
|         # Lagrange.OneBot方案 | ||||
|         if method == 2: | ||||
|             base64_string = base64.b64encode(image).decode("utf-8") | ||||
|             data = await bot.call_api("upload_image", file=f"base64://{base64_string}") | ||||
|             await MarkdownMessage.send_md(MarkdownMessage.image(data, Image.open(io.BytesIO(image)).size), bot, | ||||
|                                           event=event, message_type=message_type, | ||||
|                                           session_id=session_id, **kwargs) | ||||
|  | ||||
|         # 其他实现端方案 | ||||
|         else: | ||||
|             image_message_id = (await bot.send_private_msg( | ||||
|                 user_id=bot.self_id, | ||||
|                 message=[ | ||||
|                     v11.MessageSegment.image(file=image) | ||||
|                 ] | ||||
|             ))["message_id"] | ||||
|             image_url = (await bot.get_msg(message_id=image_message_id))["message"][0]["data"]["url"] | ||||
|             image_size = Image.open(io.BytesIO(image)).size | ||||
|             image_md = MarkdownMessage.image(image_url, image_size) | ||||
|             return await MarkdownMessage.send_md(image_md, bot, message_type=message_type, session_id=session_id, | ||||
|                                                  event=event, **kwargs) | ||||
|  | ||||
|         if data is None: | ||||
|             data = await bot.send_msg( | ||||
|                 message_type=message_type, | ||||
|                 group_id=session_id, | ||||
|                 user_id=session_id, | ||||
|                 message=v11.MessageSegment.image(image), | ||||
|                 **kwargs | ||||
|             ) | ||||
|         return data | ||||
|  | ||||
|     @staticmethod | ||||
|     async def get_image_url(image: bytes | str, bot: T_Bot) -> str: | ||||
|         """把图片上传到图床,返回链接 | ||||
|         Args: | ||||
|             bot: 发送的bot | ||||
|             image: 图片字节流或图片本地路径 | ||||
|         Returns: | ||||
|         """ | ||||
|         # 等林文轩修好Lagrange.OneBot再说 | ||||
|  | ||||
|     @staticmethod | ||||
|     def btn_cmd(name: str, cmd: str, reply: bool = False, enter: bool = True) -> str: | ||||
|         """生成点击回调按钮 | ||||
|         Args: | ||||
|             name: 按钮显示内容 | ||||
|             cmd: 发送的命令,已在函数内url编码,不需要再次编码 | ||||
|             reply: 是否以回复的方式发送消息 | ||||
|             enter: 自动发送消息则为True,否则填充到输入框 | ||||
|  | ||||
|         Returns: | ||||
|             markdown格式的可点击回调按钮 | ||||
|  | ||||
|         """ | ||||
|         if "" not in config.get("command_start", ["/"]) and config.get("alconna_use_command_start", False): | ||||
|             cmd = f"{config['command_start'][0]}{cmd}" | ||||
|         return f"[{name}](mqqapi://aio/inlinecmd?command={quote(cmd)}&reply={str(reply).lower()}&enter={str(enter).lower()})" | ||||
|  | ||||
|     @staticmethod | ||||
|     def btn_link(name: str, url: str) -> str: | ||||
|         """生成点击链接按钮 | ||||
|         Args: | ||||
|             name: 链接显示内容 | ||||
|             url: 链接地址 | ||||
|  | ||||
|         Returns: | ||||
|             markdown格式的链接 | ||||
|  | ||||
|         """ | ||||
|         return f"[🔗{name}]({url})" | ||||
|  | ||||
|     @staticmethod | ||||
|     def image(url: str, size: tuple[int, int]) -> str: | ||||
|         """构建图片链接 | ||||
|         Args: | ||||
|             size: | ||||
|             url: 图片链接 | ||||
|  | ||||
|         Returns: | ||||
|             markdown格式的图片 | ||||
|  | ||||
|         """ | ||||
|         return f"![image #{size[0]}px #{size[1]}px]({url})" | ||||
|  | ||||
|     @staticmethod | ||||
|     async def image_async(url: str) -> str: | ||||
|         """获取图片,自动请求获取大小,异步 | ||||
|         Args: | ||||
|             url: 图片链接 | ||||
|  | ||||
|         Returns: | ||||
|             图片Markdown语法:  | ||||
|  | ||||
|         """ | ||||
|         try: | ||||
|             async with aiohttp.ClientSession() as session: | ||||
|                 async with session.get(url) as resp: | ||||
|                     image = Image.open(io.BytesIO(await resp.read())) | ||||
|                     return MarkdownMessage.image(url, image.size) | ||||
|         except Exception as e: | ||||
|             nonebot.logger.error(f"get image error: {e}") | ||||
|             return "[Image Error]" | ||||
|  | ||||
|     @staticmethod | ||||
|     def escape(text: str) -> str: | ||||
|         """转义特殊字符 | ||||
|         Args: | ||||
|             text: 需要转义的文本,请勿直接把整个markdown文本传入,否则会转义掉所有字符 | ||||
|  | ||||
|         Returns: | ||||
|             转义后的文本 | ||||
|  | ||||
|         """ | ||||
|         chars = "*[]()~_`>#+=|{}.!" | ||||
|         for char in chars: | ||||
|             text = text.replace(char, f"\\\\{char}") | ||||
|         return text | ||||
							
								
								
									
										101
									
								
								src/utils/message/string_tool.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										101
									
								
								src/utils/message/string_tool.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,101 @@ | ||||
| import nonebot | ||||
|  | ||||
|  | ||||
| def convert_duration(text: str, default) -> float: | ||||
|     """ | ||||
|     转换自然语言时间为秒数 | ||||
|     Args: | ||||
|         text: 1d2h3m | ||||
|         default: 出错时返回 | ||||
|  | ||||
|     Returns: | ||||
|         float: 总秒数 | ||||
|     """ | ||||
|     units = { | ||||
|             "d" : 86400, | ||||
|             "h" : 3600, | ||||
|             "m" : 60, | ||||
|             "s" : 1, | ||||
|             "ms": 0.001 | ||||
|     } | ||||
|  | ||||
|     duration = 0 | ||||
|     current_number = '' | ||||
|     current_unit = '' | ||||
|     try: | ||||
|         for char in text: | ||||
|             if char.isdigit(): | ||||
|                 current_number += char | ||||
|             else: | ||||
|                 if current_number: | ||||
|                     duration += int(current_number) * units[current_unit] | ||||
|                     current_number = '' | ||||
|                 if char in units: | ||||
|                     current_unit = char | ||||
|                 else: | ||||
|                     current_unit = '' | ||||
|  | ||||
|         if current_number: | ||||
|             duration += int(current_number) * units[current_unit] | ||||
|  | ||||
|         return duration | ||||
|  | ||||
|     except BaseException as e: | ||||
|         nonebot.logger.info(f"convert_duration error: {e}") | ||||
|         return default | ||||
|  | ||||
|  | ||||
| def convert_time_to_seconds(time_str): | ||||
|     """转换自然语言时长为秒数 | ||||
|     Args: | ||||
|         time_str: 1d2m3s | ||||
|  | ||||
|     Returns: | ||||
|  | ||||
|     """ | ||||
|     seconds = 0 | ||||
|     current_number = '' | ||||
|  | ||||
|     for char in time_str: | ||||
|         if char.isdigit() or char == '.': | ||||
|             current_number += char | ||||
|         elif char == 'd': | ||||
|             seconds += float(current_number) * 24 * 60 * 60 | ||||
|             current_number = '' | ||||
|         elif char == 'h': | ||||
|             seconds += float(current_number) * 60 * 60 | ||||
|             current_number = '' | ||||
|         elif char == 'm': | ||||
|             seconds += float(current_number) * 60 | ||||
|             current_number = '' | ||||
|         elif char == 's': | ||||
|             seconds += float(current_number) | ||||
|             current_number = '' | ||||
|  | ||||
|     return int(seconds) | ||||
|  | ||||
|  | ||||
| def convert_seconds_to_time(seconds): | ||||
|     """转换秒数为自然语言时长 | ||||
|     Args: | ||||
|         seconds: 10000 | ||||
|  | ||||
|     Returns: | ||||
|  | ||||
|     """ | ||||
|     d = seconds // (24 * 60 * 60) | ||||
|     h = (seconds % (24 * 60 * 60)) // (60 * 60) | ||||
|     m = (seconds % (60 * 60)) // 60 | ||||
|     s = seconds % 60 | ||||
|  | ||||
|     # 若值为0则不显示 | ||||
|     time_str = '' | ||||
|     if d: | ||||
|         time_str += f"{d}d" | ||||
|     if h: | ||||
|         time_str += f"{h}h" | ||||
|     if m: | ||||
|         time_str += f"{m}m" | ||||
|     if not time_str: | ||||
|         time_str = f"{s}s" | ||||
|     return time_str | ||||
							
								
								
									
										99
									
								
								src/utils/message/tools.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								src/utils/message/tools.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,99 @@ | ||||
| import random | ||||
| from importlib.metadata import PackageNotFoundError, version | ||||
|  | ||||
|  | ||||
| def clamp(value: float, min_value: float, max_value: float) -> float | int: | ||||
|     """将值限制在最小值和最大值之间 | ||||
|  | ||||
|     Args: | ||||
|         value (float): 要限制的值 | ||||
|         min_value (float): 最小值 | ||||
|         max_value (float): 最大值 | ||||
|  | ||||
|     Returns: | ||||
|         float: 限制后的值 | ||||
|     """ | ||||
|     return max(min(value, max_value), min_value) | ||||
|  | ||||
|  | ||||
| def convert_size(size: int, precision: int = 2, add_unit: bool = True, suffix: str = " XiB") -> str | float: | ||||
|     """把字节数转换为人类可读的字符串,计算正负 | ||||
|  | ||||
|     Args: | ||||
|  | ||||
|         add_unit:  是否添加单位,False后则suffix无效 | ||||
|         suffix: XiB或XB | ||||
|         precision: 浮点数的小数点位数 | ||||
|         size (int): 字节数 | ||||
|  | ||||
|     Returns: | ||||
|  | ||||
|         str: The human-readable string, e.g. "1.23 GB". | ||||
|     """ | ||||
|     is_negative = size < 0 | ||||
|     size = abs(size) | ||||
|     for unit in ("", "K", "M", "G", "T", "P", "E", "Z"): | ||||
|         if size < 1024: | ||||
|             break | ||||
|         size /= 1024 | ||||
|     if is_negative: | ||||
|         size = -size | ||||
|     if add_unit: | ||||
|         return f"{size:.{precision}f}{suffix.replace('X', unit)}" | ||||
|     else: | ||||
|         return size | ||||
|  | ||||
|  | ||||
| def keywords_in_text(keywords: list[str], text: str, all_matched: bool) -> bool: | ||||
|     """ | ||||
|     检查关键词是否在文本中 | ||||
|     Args: | ||||
|         keywords: 关键词列表 | ||||
|         text: 文本 | ||||
|         all_matched: 是否需要全部匹配 | ||||
|  | ||||
|     Returns: | ||||
|  | ||||
|     """ | ||||
|     if all_matched: | ||||
|         for keyword in keywords: | ||||
|             if keyword not in text: | ||||
|                 return False | ||||
|         return True | ||||
|     else: | ||||
|         for keyword in keywords: | ||||
|             if keyword in text: | ||||
|                 return True | ||||
|         return False | ||||
|  | ||||
|  | ||||
| def check_for_package(package_name: str) -> bool: | ||||
|     try: | ||||
|         version(package_name) | ||||
|         return True | ||||
|     except PackageNotFoundError: | ||||
|         return False | ||||
|  | ||||
|  | ||||
| def random_ascii_string(length: int) -> str: | ||||
|     """ | ||||
|     生成随机ASCII字符串 | ||||
|     Args: | ||||
|         length: | ||||
|  | ||||
|     Returns: | ||||
|  | ||||
|     """ | ||||
|     return "".join([chr(random.randint(33, 126)) for _ in range(length)]) | ||||
|  | ||||
|  | ||||
| def random_hex_string(length: int) -> str: | ||||
|     """ | ||||
|     生成随机十六进制字符串 | ||||
|     Args: | ||||
|         length: | ||||
|  | ||||
|     Returns: | ||||
|  | ||||
|     """ | ||||
|     return "".join([random.choice("0123456789abcdef") for _ in range(length)]) | ||||
							
								
								
									
										0
									
								
								src/utils/message/union.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/utils/message/union.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								src/utils/nb/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/utils/nb/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										15
									
								
								src/utils/network/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								src/utils/network/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | ||||
| from aiohttp import ClientSession | ||||
|  | ||||
|  | ||||
| async def simple_get(url: str) -> str: | ||||
|     """ | ||||
|     简单异步get请求 | ||||
|     Args: | ||||
|         url: | ||||
|  | ||||
|     Returns: | ||||
|  | ||||
|     """ | ||||
|     async with ClientSession() as session: | ||||
|         async with session.get(url) as resp: | ||||
|             return await resp.text() | ||||
							
								
								
									
										3
									
								
								src/utils/satori_utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								src/utils/satori_utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| from .user_info import user_infos | ||||
| from .count_friends import count_friends | ||||
| from .count_groups import count_groups | ||||
							
								
								
									
										13
									
								
								src/utils/satori_utils/count_friends.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								src/utils/satori_utils/count_friends.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | ||||
| from nonebot.adapters import satori | ||||
|  | ||||
|  | ||||
| async def count_friends(bot: satori.Bot) -> int: | ||||
|     cnt: int = 0 | ||||
|  | ||||
|     friend_response = await bot.friend_list() | ||||
|     while friend_response.next is not None: | ||||
|         cnt += len(friend_response.data) | ||||
|         friend_response = await bot.friend_list(next_token=friend_response.next) | ||||
|  | ||||
|     cnt += len(friend_response.data) | ||||
|     return cnt - 1 | ||||
							
								
								
									
										13
									
								
								src/utils/satori_utils/count_groups.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								src/utils/satori_utils/count_groups.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | ||||
| from nonebot.adapters import satori | ||||
|  | ||||
|  | ||||
| async def count_groups(bot: satori.Bot) -> int: | ||||
|     cnt: int = 0 | ||||
|  | ||||
|     group_response = await bot.guild_list() | ||||
|     while group_response.next is not None: | ||||
|         cnt += len(group_response.data) | ||||
|         group_response = await bot.friend_list(next_token=group_response.next) | ||||
|  | ||||
|     cnt += len(group_response.data) | ||||
|     return cnt - 1 | ||||
							
								
								
									
										64
									
								
								src/utils/satori_utils/user_info.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								src/utils/satori_utils/user_info.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,64 @@ | ||||
| import nonebot | ||||
|  | ||||
| from nonebot.adapters import satori | ||||
| from nonebot.adapters.satori.models import User | ||||
|  | ||||
|  | ||||
| class UserInfo: | ||||
|     user_infos: dict = {} | ||||
|  | ||||
|     async def load_friends(self, bot: satori.Bot): | ||||
|         nonebot.logger.info("Update user info from friends") | ||||
|         friend_response = await bot.friend_list() | ||||
|         while friend_response.next is not None: | ||||
|             for i in friend_response.data: | ||||
|                 i: User = i | ||||
|                 self.user_infos[str(i.id)] = i | ||||
|             friend_response = await bot.friend_list(next_token=friend_response.next) | ||||
|  | ||||
|         for i in friend_response.data: | ||||
|             i: User = i | ||||
|             self.user_infos[str(i.id)] = i | ||||
|  | ||||
|         nonebot.logger.info("Finish update user info") | ||||
|  | ||||
|     async def get(self, uid: int | str) -> User | None: | ||||
|         try: | ||||
|             return self.user_infos[str(uid)] | ||||
|         except KeyError: | ||||
|             return None | ||||
|  | ||||
|     async def put(self, user: User) -> bool: | ||||
|         """ | ||||
|         向用户信息数据库中添加/修改一项,返回值仅代表数据是否变更,不代表操作是否成功 | ||||
|         Args: | ||||
|             user: 要加入数据库的用户 | ||||
|  | ||||
|         Returns: 当数据库中用户信息发生变化时返回 True, 否则返回 False | ||||
|  | ||||
|         """ | ||||
|         try: | ||||
|             old_user: User = self.user_infos[str(user.id)] | ||||
|             attr_edited = False | ||||
|             if user.name is not None: | ||||
|                 if old_user.name != user.name: | ||||
|                     attr_edited = True | ||||
|                     self.user_infos[str(user.id)].name = user.name | ||||
|             if user.nick is not None: | ||||
|                 if old_user.nick != user.nick: | ||||
|                     attr_edited = True | ||||
|                     self.user_infos[str(user.id)].nick = user.nick | ||||
|             if user.avatar is not None: | ||||
|                 if old_user.avatar != user.avatar: | ||||
|                     attr_edited = True | ||||
|                     self.user_infos[str(user.id)].avatar = user.avatar | ||||
|             return attr_edited | ||||
|         except KeyError: | ||||
|             self.user_infos[str(user.id)] = user | ||||
|             return True | ||||
|  | ||||
|     def __init__(self): | ||||
|         pass | ||||
|  | ||||
|  | ||||
| user_infos = UserInfo() | ||||
		Reference in New Issue
	
	Block a user