forked from bot/app

feat: 统一双引号 (standardize on double quotes)

2024-03-26 17:14:41 +08:00
parent 04fc9c3dd7
commit ecbe06a9e8
18 changed files with 131 additions and 359 deletions

View File

@ -19,14 +19,14 @@ def load_from_yaml(file: str) -> dict:
global config
nonebot.logger.debug("Loading config from %s" % file)
if not os.path.exists(file):
nonebot.logger.warning(f'Config file {file} not found, created default config, please modify it and restart')
with open(file, 'w', encoding='utf-8') as f:
nonebot.logger.warning(f"Config file {file} not found, created default config, please modify it and restart")
with open(file, "w", encoding="utf-8") as f:
yaml.dump(BasicConfig().dict(), f, default_flow_style=False)
with open(file, 'r', encoding='utf-8') as f:
with open(file, "r", encoding="utf-8") as f:
conf = yaml.load(f, Loader=yaml.FullLoader)
config = conf
if conf is None:
nonebot.logger.warning(f'Config file {file} is empty, use default config. please modify it and restart')
nonebot.logger.warning(f"Config file {file} is empty, use default config. please modify it and restart")
conf = BasicConfig().dict()
return conf

View File

@ -226,13 +226,13 @@ class Database(BaseORMAdapter):
return_data = {}
for k, v in data.items():
if isinstance(v, LiteModel):
return_data[f'{self.FOREIGNID}{k}'] = f'{self.ID}:{v.__class__.__name__}:{self.upsert(v)}'
return_data[f"{self.FOREIGNID}{k}"] = f"{self.ID}:{v.__class__.__name__}:{self.upsert(v)}"
elif isinstance(v, list):
return_data[f'{self.LIST}{k}'] = self._flat(v)
return_data[f"{self.LIST}{k}"] = self._flat(v)
elif isinstance(v, dict):
return_data[f'{self.DICT}{k}'] = self._flat(v)
return_data[f"{self.DICT}{k}"] = self._flat(v)
elif isinstance(v, BaseIterable):
return_data[f'{self.JSON}{k}'] = self._flat(v)
return_data[f"{self.JSON}{k}"] = self._flat(v)
else:
return_data[k] = v
@ -240,7 +240,7 @@ class Database(BaseORMAdapter):
return_data = []
for v in data:
if isinstance(v, LiteModel):
return_data.append(f'{self.ID}:{v.__class__.__name__}:{self.upsert(v)}')
return_data.append(f"{self.ID}:{v.__class__.__name__}:{self.upsert(v)}")
elif isinstance(v, list):
return_data.append(self._flat(v))
elif isinstance(v, dict):
@ -250,7 +250,7 @@ class Database(BaseORMAdapter):
else:
return_data.append(v)
else:
raise ValueError('数据类型错误')
raise ValueError("数据类型错误")
return json.dumps(return_data)
@ -263,7 +263,7 @@ class Database(BaseORMAdapter):
Returns:
"""
return self.cursor.execute(f'SELECT * FROM sqlite_master WHERE type = "table" AND name = ?', (table_name,)).fetchone()
return self.cursor.execute(f"SELECT * FROM sqlite_master WHERE type = 'table' AND name = ?", (table_name,)).fetchone()
def first(self, model: type(LiteModel), conditions, *args, default: Any = None) -> LiteModel | None:
"""查询第一条数据
@ -281,7 +281,7 @@ class Database(BaseORMAdapter):
if not self._detect_for_table(table_name):
return default
self.cursor.execute(f'SELECT * FROM {table_name} WHERE {conditions}', args)
self.cursor.execute(f"SELECT * FROM {table_name} WHERE {conditions}", args)
if row_data := self.cursor.fetchone():
data = dict(row_data)
return model(**self.convert_to_dict(data))
@ -304,9 +304,9 @@ class Database(BaseORMAdapter):
return default
if conditions:
self.cursor.execute(f'SELECT * FROM {table_name} WHERE {conditions}', args)
self.cursor.execute(f"SELECT * FROM {table_name} WHERE {conditions}", args)
else:
self.cursor.execute(f'SELECT * FROM {table_name}')
self.cursor.execute(f"SELECT * FROM {table_name}")
if row_datas := self.cursor.fetchall():
datas = [dict(row_data) for row_data in row_datas]
return [model(**self.convert_to_dict(d)) for d in datas] if datas else default
@ -327,8 +327,8 @@ class Database(BaseORMAdapter):
if not self._detect_for_table(table_name):
return
nonebot.logger.debug(f'DELETE FROM {table_name} WHERE {conditions}')
self.cursor.execute(f'DELETE FROM {table_name} WHERE {conditions}', args)
nonebot.logger.debug(f"DELETE FROM {table_name} WHERE {conditions}")
self.cursor.execute(f"DELETE FROM {table_name} WHERE {conditions}", args)
self.conn.commit()
def convert_to_dict(self, data: dict) -> dict:
@ -346,8 +346,8 @@ class Database(BaseORMAdapter):
new_d = {}
for k, v in d.items():
if k.startswith(self.FOREIGNID):
new_d[k.replace(self.FOREIGNID, '')] = load(
dict(self.cursor.execute(f'SELECT * FROM {v.split(":", 2)[1]} WHERE id = ?', (v.split(":", 2)[2],)).fetchone()))
new_d[k.replace(self.FOREIGNID, "")] = load(
dict(self.cursor.execute(f"SELECT * FROM {v.split(':', 2)[1]} WHERE id = ?", (v.split(":", 2)[2],)).fetchone()))
elif k.startswith(self.LIST):
if v == '': v = '[]'

View File

@ -6,36 +6,36 @@ from liteyuki.utils.data import LiteModel, Database as DB
DATA_PATH = "data/liteyuki"
user_db = DB(os.path.join(DATA_PATH, 'users.ldb'))
group_db = DB(os.path.join(DATA_PATH, 'groups.ldb'))
plugin_db = DB(os.path.join(DATA_PATH, 'plugins.ldb'))
common_db = DB(os.path.join(DATA_PATH, 'common.ldb'))
user_db = DB(os.path.join(DATA_PATH, "users.ldb"))
group_db = DB(os.path.join(DATA_PATH, "groups.ldb"))
plugin_db = DB(os.path.join(DATA_PATH, "plugins.ldb"))
common_db = DB(os.path.join(DATA_PATH, "common.ldb"))
class User(LiteModel):
user_id: str = Field(str(), alias='user_id')
username: str = Field(str(), alias='username')
profile: dict[str, str] = Field(dict(), alias='profile')
enabled_plugins: list[str] = Field(list(), alias='enabled_plugins')
disabled_plugins: list[str] = Field(list(), alias='disabled_plugins')
user_id: str = Field(str(), alias="user_id")
username: str = Field(str(), alias="username")
profile: dict[str, str] = Field(dict(), alias="profile")
enabled_plugins: list[str] = Field(list(), alias="enabled_plugins")
disabled_plugins: list[str] = Field(list(), alias="disabled_plugins")
class GroupChat(LiteModel):
# Group是一个关键字所以这里用GroupChat
group_id: str = Field(str(), alias='group_id')
group_name: str = Field(str(), alias='group_name')
enabled_plugins: list[str] = Field([], alias='enabled_plugins')
disabled_plugins: list[str] = Field([], alias='disabled_plugins')
group_id: str = Field(str(), alias="group_id")
group_name: str = Field(str(), alias="group_name")
enabled_plugins: list[str] = Field([], alias="enabled_plugins")
disabled_plugins: list[str] = Field([], alias="disabled_plugins")
class InstalledPlugin(LiteModel):
module_name: str = Field(str(), alias='module_name')
version: str = Field(str(), alias='version')
module_name: str = Field(str(), alias="module_name")
version: str = Field(str(), alias="version")
class GlobalPlugin(LiteModel):
module_name: str = Field(str(), alias='module_name')
enabled: bool = Field(True, alias='enabled')
module_name: str = Field(str(), alias="module_name")
enabled: bool = Field(True, alias="enabled")
def auto_migrate():

View File

@ -1,20 +1,11 @@
import copy
import json
import os
import pickle
import sqlite3
import types
from types import NoneType
from collections.abc import Iterable
from pydantic import BaseModel, Field
from typing import Any
LOG_OUT = True
def log(*args, **kwargs):
if LOG_OUT:
print(*args, **kwargs)
import pydantic
from pydantic import BaseModel
class LiteModel(BaseModel):
@ -85,7 +76,12 @@ class Database:
elif model.TABLE_NAME not in table_list:
raise ValueError(f"数据模型 {model.__class__.__name__} 的表 {model.TABLE_NAME} 不存在,请先迁移")
else:
self._save(model.model_dump(by_alias=True))
if pydantic.__version__ < "1.8.2":
# 兼容pydantic 1.8.2以下版本
model_dict = model.dict(by_alias=True)
else:
model_dict = model.model_dump(by_alias=True)
self._save(model_dict)
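One caveat on the compatibility branch above: comparing pydantic.__version__ as a string is lexicographic, so "1.10.x" sorts before "1.8.2" while "1.9.x" sorts after it. A minimal sketch of a numeric check, assuming the intent is to call .dict() on pydantic v1 and .model_dump() on v2; dump_model is a hypothetical helper, not part of this module:

import pydantic

def dump_model(model):
    # Compare the major version numerically rather than comparing the
    # version string lexicographically.
    major = int(pydantic.VERSION.split(".")[0])
    if major >= 2:
        return model.model_dump(by_alias=True)
    return model.dict(by_alias=True)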
def _save(self, obj: Any) -> Any:
# obj = copy.deepcopy(obj)

View File

@ -30,14 +30,14 @@ def load_from_lang(file_path: str, lang_code: str = None):
"""
try:
if lang_code is None:
lang_code = os.path.basename(file_path).split('.')[0]
with open(file_path, 'r', encoding='utf-8') as file:
lang_code = os.path.basename(file_path).split(".")[0]
with open(file_path, "r", encoding="utf-8") as file:
data = {}
for line in file:
line = line.strip()
if not line or line.startswith('#'): # 空行或注释
if not line or line.startswith("#"): # 空行或注释
continue
key, value = line.split('=', 1)
key, value = line.split("=", 1)
data[key.strip()] = value.strip()
if lang_code not in _language_data:
_language_data[lang_code] = {}
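For reference, a standalone sketch of the key=value parsing that load_from_lang performs, run here against an in-memory sample instead of a file; the keys are illustrative only:

sample = """
# comment lines and blank lines are skipped
log.debug===DEBUG
main.current_language=English
"""

data = {}
for line in sample.splitlines():
    line = line.strip()
    if not line or line.startswith("#"):
        continue
    key, value = line.split("=", 1)  # split on the first "=" only
    data[key.strip()] = value.strip()

print(data)  # {'log.debug': '==DEBUG', 'main.current_language': 'English'}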
@ -56,8 +56,8 @@ def load_from_json(file_path: str, lang_code: str = None):
"""
try:
if lang_code is None:
lang_code = os.path.basename(file_path).split('.')[0]
with open(file_path, 'r', encoding='utf-8') as file:
lang_code = os.path.basename(file_path).split(".")[0]
with open(file_path, "r", encoding="utf-8") as file:
data = json.load(file)
if lang_code not in _language_data:
_language_data[lang_code] = {}
@ -77,9 +77,9 @@ def load_from_dir(dir_path: str):
try:
file_path = os.path.join(dir_path, file)
if os.path.isfile(file_path):
if file.endswith('.lang'):
if file.endswith(".lang"):
load_from_lang(file_path)
elif file.endswith('.json'):
elif file.endswith(".json"):
load_from_json(file_path)
except Exception as e:
nonebot.logger.error(f"Failed to load language data from {file}: {e}")
@ -140,7 +140,7 @@ def get_user_lang(user_id: str) -> Language:
username="Unknown"
))
return Language(user.profile.get('lang', config.get("default_language", get_system_lang_code())))
return Language(user.profile.get("lang", config.get("default_language", get_system_lang_code())))
def get_system_lang_code() -> str:

View File

@ -61,11 +61,11 @@ def init_log():
show_icon = config.get("log_icon", True)
lang = Language(config.get("default_language", get_system_lang_code()))
debug = lang.get('log.debug', default="==DEBUG")
info = lang.get('log.info', default="===INFO")
success = lang.get('log.success', default="SUCCESS")
warning = lang.get('log.warning', default="WARNING")
error = lang.get('log.error', default="==ERROR")
debug = lang.get("log.debug", default="==DEBUG")
info = lang.get("log.info", default="===INFO")
success = lang.get("log.success", default="SUCCESS")
warning = lang.get("log.warning", default="WARNING")
error = lang.get("log.error", default="==ERROR")
logger.level("DEBUG", color="<blue>", icon=f"{'*️⃣' if show_icon else ''}{debug}")
logger.level("INFO", color="<white>", icon=f"{'' if show_icon else ''}{info}")

View File

@ -4,4 +4,4 @@ T_Bot = v11.Bot | v12.Bot
T_GroupMessageEvent = v11.GroupMessageEvent | v12.GroupMessageEvent
T_PrivateMessageEvent = v11.PrivateMessageEvent | v12.PrivateMessageEvent
T_MessageEvent = v11.MessageEvent | v12.MessageEvent
T_Message = v11.Message | v12.Message
T_Message = v11.Message | v12.Message

View File

@ -1,14 +1,16 @@
from urllib.parse import quote
import nonebot
from nonebot.adapters.onebot import v11, v12
from typing import Any
from .tools import de_escape, encode_url
from .tools import encode_url
from .ly_typing import T_Bot, T_MessageEvent
async def send_markdown(markdown: str, bot: T_Bot, *, message_type: str = None, session_id: str | int = None, event: T_MessageEvent = None, **kwargs) -> dict[
str, Any]:
formatted_md = de_escape(markdown).replace("\n", r"\n").replace("\"", r'\\\"')
str, Any]:
formatted_md = v11.unescape(markdown).replace("\n", r"\n").replace("\"", r'\\\"')
if event is not None and message_type is None:
message_type = event.message_type
session_id = event.user_id if event.message_type == "private" else event.group_id
@ -89,7 +91,7 @@ class Markdown:
markdown格式的可点击回调按钮
"""
return f"[{name}](mqqapi://aio/inlinecmd?command={encode_url(cmd)}&reply={str(reply).lower()}&enter={str(enter).lower()})"
return f"[{name}](mqqapi://aio/inlinecmd?command={quote(cmd)}&reply={str(reply).lower()}&enter={str(enter).lower()})"
@staticmethod
def link(name: str, url: str) -> str:

View File

@ -1,207 +0,0 @@
import os
import pickle
import sqlite3
from types import NoneType
from typing import Any
import nonebot
from pydantic import BaseModel, Field
class LiteModel(BaseModel):
"""轻量级模型基类
类型注解统一使用Python3.9的PEP585标准如需使用泛型请使用typing模块的泛型类型
不允许使用id, table_name以及其他SQLite关键字作为字段名不允许使用JSON和ID必须指定默认值且默认值类型必须与字段类型一致
"""
__ID__: int = Field(None, alias='id')
__TABLE_NAME__: str = Field(None, alias='table_name')
class Database:
TYPE_MAPPING = {
int : "INTEGER",
float : "REAL",
str : "TEXT",
bool : "INTEGER",
bytes : "BLOB",
NoneType: "NULL",
dict : "BLOB", # LITEYUKIDICT{key_name}
list : "BLOB", # LITEYUKILIST{key_name}
tuple : "BLOB", # LITEYUKITUPLE{key_name}
set : "BLOB", # LITEYUKISET{key_name}
}
# 基础类型
BASIC_TYPE = [int, float, str, bool, bytes, NoneType]
# 可序列化类型
ITERABLE_TYPE = [dict, list, tuple, set]
LITEYUKI = "LITEYUKI"
# 字段前缀映射,默认基础类型为""
FIELD_PREFIX_MAPPING = {
dict : f"{LITEYUKI}DICT",
list : f"{LITEYUKI}LIST",
tuple : f"{LITEYUKI}TUPLE",
set : f"{LITEYUKI}SET",
type(LiteModel): f"{LITEYUKI}MODEL"
}
def __init__(self, db_name: str):
if not os.path.exists(os.path.dirname(db_name)):
os.makedirs(os.path.dirname(db_name))
self.conn = sqlite3.connect(db_name) # 连接对象
self.conn.row_factory = sqlite3.Row # 以字典形式返回查询结果
self.cursor = self.conn.cursor() # 游标对象
def auto_migrate(self, *args: LiteModel):
"""
自动迁移模型
Args:
*args: 模型类实例化对象,支持空默认值,不支持嵌套迁移
Returns:
"""
for model in args:
if not model.__TABLE_NAME__:
raise ValueError(f"数据模型{model.__class__.__name__}未提供表名")
# 若无则创建表
self.cursor.execute(
f'CREATE TABLE IF NOT EXISTS {model.__TABLE_NAME__} (id INTEGER PRIMARY KEY AUTOINCREMENT)'
)
# 获取表结构
new_fields, new_stored_types = (
zip(
*[(self._get_stored_field_prefix(model.__getattribute__(field)) + field, self._get_stored_type(model.__getattribute__(field)))
for field in model.__annotations__]
)
)
# 原有的字段列表
existing_fields = self.cursor.execute(f'PRAGMA table_info({model.__TABLE_NAME__})').fetchall()
existing_types = [field['name'] for field in existing_fields]
# 检测缺失字段由于SQLite是动态类型所以不需要检测类型
for n_field, n_type in zip(new_fields, new_stored_types):
if n_field not in existing_types:
nonebot.logger.debug(f'ALTER TABLE {model.__TABLE_NAME__} ADD COLUMN {n_field} {n_type}')
self.cursor.execute(
f'ALTER TABLE {model.__TABLE_NAME__} ADD COLUMN {n_field} {n_type}'
)
# 检测多余字段进行删除
for e_field in existing_types:
if e_field not in new_fields and e_field not in ['id']:
nonebot.logger.debug(f'ALTER TABLE {model.__TABLE_NAME__} DROP COLUMN {e_field}')
self.cursor.execute(
f'ALTER TABLE {model.__TABLE_NAME__} DROP COLUMN {e_field}'
)
self.conn.commit()
def save(self, *args: LiteModel) -> [int | tuple[int, ...]]:
"""
保存或更新模型
Args:
*args: 模型类实例化对象,支持空默认值,不支持嵌套迁移
Returns:
"""
ids = []
for model in args:
if not model.__TABLE_NAME__:
raise ValueError(f"数据模型{model.__class__.__name__}未提供表名")
if not self.cursor.execute(f'PRAGMA table_info({model.__TABLE_NAME__})').fetchall():
raise ValueError(f"数据表{model.__TABLE_NAME__}不存在,请先迁移{model.__class__.__name__}模型")
stored_fields, stored_values = [], []
for r_field in model.__annotations__:
r_value = model.__getattribute__(r_field)
stored_fields.append(self._get_stored_field_prefix(r_value) + r_field)
if type(r_value) in Database.BASIC_TYPE:
# int str float bool bytes NoneType
stored_values.append(r_value)
elif type(r_value) in Database.ITERABLE_TYPE:
# dict list tuple set
stored_values.append(pickle.dumps(self._flat_save(r_value)))
elif isinstance(r_value, LiteModel):
# LiteModel TABLE_NAME:ID
stored_values.append(f"{r_value.__TABLE_NAME__}:{self.save(r_value)}")
else:
raise ValueError(f"不支持的数据类型{type(r_value)}")
nonebot.logger.debug(f"INSERT OR REPLACE INTO {model.__TABLE_NAME__} ({','.join(stored_fields)}) VALUES ({','.join([_ for _ in stored_values])})")
self.cursor.execute(
f"INSERT OR REPLACE INTO {model.__TABLE_NAME__} ({','.join(stored_fields)}) VALUES ({','.join(['?' for _ in stored_values])})",
stored_values
)
ids.append(self.cursor.lastrowid)
self.conn.commit()
return tuple(ids) if len(ids) > 1 else ids[0]
# 检测id字段是否有1有则更新无则插入
def _flat_save(self, obj) -> Any:
"""扁平化存储
Args:
obj: 需要存储的对象
Returns:
存储的字节流
"""
# TODO 递归扁平化存储
if type(obj) in Database.ITERABLE_TYPE:
for i, item in enumerate(obj) if type(obj) in [list, tuple, set] else obj.items():
if type(item) in Database.BASIC_TYPE:
continue
elif type(item) in Database.ITERABLE_TYPE:
obj[i] = pickle.dumps(self._flat_save(item))
elif isinstance(item, LiteModel):
obj[i] = f"{item.__TABLE_NAME__}:{self.save(item)}"
else:
raise ValueError(f"不支持的数据类型{type(item)}")
else:
raise ValueError(f"不支持的数据类型{type(obj)}")
@staticmethod
def _get_stored_field_prefix(value) -> str:
"""获取存储字段前缀,一定在后加上字段名
LiteModel -> LITEYUKIID
dict -> LITEYUKIDICT
list -> LITEYUKILIST
tuple -> LITEYUKITUPLE
set -> LITEYUKISET
* -> ""
Args:
value: 储存的值
Returns:
Sqlite3存储字段
"""
return Database.FIELD_PREFIX_MAPPING.get(type(value), "")
@staticmethod
def _get_stored_type(value) -> str:
"""获取存储类型
Args:
value: 储存的值
Returns:
Sqlite3存储类型
"""
return Database.TYPE_MAPPING.get(type(value), "TEXT")
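To illustrate the column-naming scheme defined by the mappings above, a quick sketch of what the two static helpers return, assuming the class as shown here; the sample values are arbitrary:

print(Database._get_stored_field_prefix({"a": 1}))  # LITEYUKIDICT
print(Database._get_stored_field_prefix([1, 2]))    # LITEYUKILIST
print(Database._get_stored_field_prefix("text"))    # ""  (basic types get no prefix)
print(Database._get_stored_type(3.14))              # REAL
print(Database._get_stored_type(b"raw"))            # BLOB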

View File

@ -35,23 +35,6 @@ def convert_size(size: int, precision: int = 2, add_unit: bool = True, suffix: s
return f"{size:.{precision}f}"
def de_escape(text: str) -> str:
str_map = {
"&#91;": "[",
"&#93;": "]",
"&amp;": "&",
"&#44;": ",",
}
for k, v in str_map.items():
text = text.replace(k, v)
return text
def encode_url(text: str) -> str:
return quote(text, safe="")
def keywords_in_text(keywords: list[str], text: str, all_matched: bool) -> bool:
"""
检查关键词是否在文本中