mirror of
https://github.com/LiteyukiStudio/LiteyukiBot.git
synced 2025-07-27 13:20:55 +00:00
fix: 数据库支持
This commit is contained in:
@ -1,374 +1,358 @@
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import sqlite3
|
||||
import types
|
||||
from abc import ABC
|
||||
from collections.abc import Iterable
|
||||
|
||||
import nonebot
|
||||
from pydantic import BaseModel
|
||||
from types import NoneType
|
||||
from typing import Any
|
||||
|
||||
BaseIterable = list | tuple | set | dict
|
||||
import nonebot
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LiteModel(BaseModel):
|
||||
"""轻量级模型基类
|
||||
类型注解统一使用Python3.9的PEP585标准,如需使用泛型请使用typing模块的泛型类型
|
||||
"""
|
||||
TABLE_NAME: str = None
|
||||
id: int = None
|
||||
|
||||
|
||||
class BaseORMAdapter(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def auto_migrate(self, *args, **kwargs):
|
||||
"""自动迁移
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def upsert(self, *args, **kwargs):
|
||||
"""存储数据
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def first(self, *args, **kwargs):
|
||||
"""查询第一条数据
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def all(self, *args, **kwargs):
|
||||
"""查询所有数据
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def delete(self, *args, **kwargs):
|
||||
"""删除数据
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
"""更新数据
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
def dump(self, *args, **kwargs):
|
||||
if pydantic.__version__ < "1.8.2":
|
||||
return self.dict(by_alias=True)
|
||||
else:
|
||||
return self.model_dump(by_alias=True)
|
||||
|
||||
|
||||
class Database(BaseORMAdapter):
|
||||
"""SQLiteORM适配器,严禁使用`FORIEGNID`和`JSON`作为主键前缀,严禁使用`$ID:`作为字符串值前缀
|
||||
|
||||
Attributes:
|
||||
|
||||
"""
|
||||
type_map = {
|
||||
# default: TEXT
|
||||
str : 'TEXT',
|
||||
int : 'INTEGER',
|
||||
float: 'REAL',
|
||||
bool : 'INTEGER',
|
||||
list : 'TEXT'
|
||||
}
|
||||
|
||||
DEFAULT_VALUE = {
|
||||
'TEXT' : '',
|
||||
'INTEGER': 0,
|
||||
'REAL' : 0.0
|
||||
}
|
||||
|
||||
FOREIGNID = 'FOREIGNID'
|
||||
JSON = 'JSON'
|
||||
LIST = 'LIST'
|
||||
DICT = 'DICT'
|
||||
ID = '$ID'
|
||||
|
||||
class Database:
|
||||
def __init__(self, db_name: str):
|
||||
super().__init__()
|
||||
if not os.path.exists(os.path.dirname(db_name)):
|
||||
|
||||
if os.path.dirname(db_name) != "" and not os.path.exists(os.path.dirname(db_name)):
|
||||
os.makedirs(os.path.dirname(db_name))
|
||||
|
||||
self.db_name = db_name
|
||||
self.conn = sqlite3.connect(db_name)
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
def auto_migrate(self, *args: type(LiteModel)):
|
||||
"""自动迁移,检测新模型字段和原有表字段的差异,如有差异自动增删新字段
|
||||
|
||||
def first(self, model: LiteModel, condition: str, *args: Any, default: Any = None) -> LiteModel | Any | None:
|
||||
"""查询第一个
|
||||
Args:
|
||||
*args: 模型类
|
||||
model: 数据模型实例
|
||||
condition: 查询条件,不给定则查询所有
|
||||
*args: 参数化查询参数
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
table_name = ''
|
||||
all_results = self.all(model, condition, *args)
|
||||
return all_results[0] if all_results else default
|
||||
|
||||
def all(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> list[LiteModel | Any] | None:
|
||||
"""查询所有
|
||||
Args:
|
||||
model: 数据模型实例
|
||||
condition: 查询条件,不给定则查询所有
|
||||
*args: 参数化查询参数
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
table_name = model.TABLE_NAME
|
||||
model_type = type(model)
|
||||
if not table_name:
|
||||
raise ValueError(f"数据模型{model_type.__name__}未提供表名")
|
||||
|
||||
# condition = f"WHERE {condition}"
|
||||
# print(f"SELECT * FROM {table_name} {condition}", args)
|
||||
# if len(args) == 0:
|
||||
# results = self.cursor.execute(f"SELECT * FROM {table_name} {condition}").fetchall()
|
||||
# else:
|
||||
# results = self.cursor.execute(f"SELECT * FROM {table_name} {condition}", args).fetchall()
|
||||
if condition:
|
||||
results = self.cursor.execute(f"SELECT * FROM {table_name} WHERE {condition}", args).fetchall()
|
||||
else:
|
||||
results = self.cursor.execute(f"SELECT * FROM {table_name}").fetchall()
|
||||
fields = [description[0] for description in self.cursor.description]
|
||||
if not results:
|
||||
return default
|
||||
else:
|
||||
return [model_type(**self._load(dict(zip(fields, result)))) for result in results]
|
||||
|
||||
def upsert(self, *args: LiteModel):
|
||||
"""增/改操作
|
||||
Args:
|
||||
*args:
|
||||
|
||||
Returns:
|
||||
"""
|
||||
table_list = [item[0] for item in self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()]
|
||||
for model in args:
|
||||
model: type(LiteModel)
|
||||
# 检测并创建表,若模型未定义id字段则使用自增主键,有定义的话使用id字段,且id有可能为字符串
|
||||
table_name = model.__name__
|
||||
if 'id' in model.__annotations__ and model.__annotations__['id'] is not None:
|
||||
# 如果模型定义了id字段,那么使用模型的id字段
|
||||
id_type = self.type_map.get(model.__annotations__['id'], 'TEXT')
|
||||
self.cursor.execute(f'CREATE TABLE IF NOT EXISTS {table_name} (id {id_type} PRIMARY KEY)')
|
||||
if not model.TABLE_NAME:
|
||||
raise ValueError(f"数据模型 {model.__class__.__name__} 未提供表名")
|
||||
elif model.TABLE_NAME not in table_list:
|
||||
raise ValueError(f"数据模型 {model.__class__.__name__} 的表 {model.TABLE_NAME} 不存在,请先迁移")
|
||||
else:
|
||||
# 如果模型未定义id字段,那么使用自增主键
|
||||
self.cursor.execute(f'CREATE TABLE IF NOT EXISTS {table_name} (id INTEGER PRIMARY KEY AUTOINCREMENT)')
|
||||
# 获取表字段
|
||||
self.cursor.execute(f'PRAGMA table_info({table_name})')
|
||||
table_fields = self.cursor.fetchall()
|
||||
table_fields = [field[1] for field in table_fields]
|
||||
self._save(model.dump(by_alias=True))
|
||||
|
||||
raw_fields, raw_types = zip(*model.__annotations__.items())
|
||||
# 获取模型字段,若有模型则添加FOREIGNID前缀,若为BaseIterable则添加JSON前缀,用多行if判断
|
||||
model_fields = []
|
||||
model_types = []
|
||||
for field, r_type in zip(raw_fields, raw_types):
|
||||
if isinstance(r_type, type(LiteModel)):
|
||||
model_fields.append(f'{self.FOREIGNID}{field}')
|
||||
model_types.append('TEXT')
|
||||
elif r_type in [list[str], list[int], list[float], list[bool], list]:
|
||||
model_fields.append(f'{self.LIST}{field}')
|
||||
model_types.append('TEXT')
|
||||
elif r_type in [dict[str, str], dict[str, int], dict[str, float], dict[str, bool], dict]:
|
||||
model_fields.append(f'{self.DICT}{field}')
|
||||
model_types.append('TEXT')
|
||||
elif isinstance(r_type, types.GenericAlias):
|
||||
model_fields.append(f'{self.JSON}{field}')
|
||||
model_types.append('TEXT')
|
||||
def _save(self, obj: Any) -> Any:
|
||||
# obj = copy.deepcopy(obj)
|
||||
if isinstance(obj, dict):
|
||||
table_name = obj.get("TABLE_NAME")
|
||||
row_id = obj.get("id")
|
||||
new_obj = {}
|
||||
for field, value in obj.items():
|
||||
if isinstance(value, self.ITERABLE_TYPE):
|
||||
new_obj[self._get_stored_field_prefix(value) + field] = self._save(value) # self._save(value) # -> bytes
|
||||
elif isinstance(value, self.BASIC_TYPE):
|
||||
new_obj[field] = value
|
||||
else:
|
||||
model_fields.append(field)
|
||||
model_types.append(self.type_map.get(r_type, 'TEXT'))
|
||||
|
||||
# 检测新字段或字段类型是否有变化,有则增删字段,已经加了前缀类型
|
||||
for field_changed, type_, r_type in zip(model_fields, model_types, raw_types):
|
||||
if field_changed not in table_fields:
|
||||
nonebot.logger.debug(f'ALTER TABLE {table_name} ADD COLUMN {field_changed} {type_}')
|
||||
self.cursor.execute(f'ALTER TABLE {table_name} ADD COLUMN {field_changed} {type_}')
|
||||
# 在原有的行中添加新字段对应类型的默认值,从DEFAULT_TYPE中获取
|
||||
self.cursor.execute(f'UPDATE {table_name} SET {field_changed} = ? WHERE {field_changed} IS NULL', (self.DEFAULT_VALUE.get(type_, ""),))
|
||||
|
||||
# 检测多余字段,除了id字段
|
||||
for field in table_fields:
|
||||
if field not in model_fields and field != 'id':
|
||||
nonebot.logger.debug(f'ALTER TABLE {table_name} DROP COLUMN {field}')
|
||||
self.cursor.execute(f'ALTER TABLE {table_name} DROP COLUMN {field}')
|
||||
|
||||
self.conn.commit()
|
||||
nonebot.logger.debug(f'Table {table_name} migrated successfully')
|
||||
|
||||
def upsert(self, *models: LiteModel) -> int | tuple:
|
||||
"""存储数据,检查id字段,如果有id字段则更新,没有则插入
|
||||
|
||||
Args:
|
||||
models: 数据
|
||||
|
||||
Returns:
|
||||
id: 数据id,如果有多个数据则返回id元组
|
||||
"""
|
||||
|
||||
ids = []
|
||||
for model in models:
|
||||
table_name = model.__class__.__name__
|
||||
if not self._detect_for_table(table_name):
|
||||
raise ValueError(f'表{table_name}不存在,请先迁移')
|
||||
key_list = []
|
||||
value_list = []
|
||||
# 处理外键,添加前缀'$IDFieldName'
|
||||
for field, value in model.__dict__.items():
|
||||
if isinstance(value, LiteModel):
|
||||
key_list.append(f'{self.FOREIGNID}{field}')
|
||||
value_list.append(f'{self.ID}:{value.__class__.__name__}:{self.upsert(value)}')
|
||||
elif isinstance(value, list):
|
||||
key_list.append(f'{self.LIST}{field}')
|
||||
value_list.append(self._flat(value))
|
||||
elif isinstance(value, dict):
|
||||
key_list.append(f'{self.DICT}{field}')
|
||||
value_list.append(self._flat(value))
|
||||
elif isinstance(value, BaseIterable):
|
||||
key_list.append(f'{self.JSON}{field}')
|
||||
value_list.append(self._flat(value))
|
||||
else:
|
||||
key_list.append(field)
|
||||
value_list.append(value)
|
||||
# 更新或插入数据,用?占位
|
||||
nonebot.logger.debug(f'INSERT OR REPLACE INTO {table_name} ({",".join(key_list)}) VALUES ({",".join(["?" for _ in key_list])})')
|
||||
self.cursor.execute(f'INSERT OR REPLACE INTO {table_name} ({",".join(key_list)}) VALUES ({",".join(["?" for _ in key_list])})', value_list)
|
||||
|
||||
ids.append(self.cursor.lastrowid)
|
||||
self.conn.commit()
|
||||
return ids[0] if len(ids) == 1 else tuple(ids)
|
||||
|
||||
def _flat(self, data: Iterable) -> str:
|
||||
"""扁平化数据,返回扁平化对象
|
||||
|
||||
Args:
|
||||
|
||||
data: 数据,可迭代对象
|
||||
|
||||
Returns: json字符串
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
return_data = {}
|
||||
for k, v in data.items():
|
||||
if isinstance(v, LiteModel):
|
||||
return_data[f"{self.FOREIGNID}{k}"] = f"{self.ID}:{v.__class__.__name__}:{self.upsert(v)}"
|
||||
elif isinstance(v, list):
|
||||
return_data[f"{self.LIST}{k}"] = self._flat(v)
|
||||
elif isinstance(v, dict):
|
||||
return_data[f"{self.DICT}{k}"] = self._flat(v)
|
||||
elif isinstance(v, BaseIterable):
|
||||
return_data[f"{self.JSON}{k}"] = self._flat(v)
|
||||
else:
|
||||
return_data[k] = v
|
||||
|
||||
elif isinstance(data, list | tuple | set):
|
||||
return_data = []
|
||||
for v in data:
|
||||
if isinstance(v, LiteModel):
|
||||
return_data.append(f"{self.ID}:{v.__class__.__name__}:{self.upsert(v)}")
|
||||
elif isinstance(v, list):
|
||||
return_data.append(self._flat(v))
|
||||
elif isinstance(v, dict):
|
||||
return_data.append(self._flat(v))
|
||||
elif isinstance(v, BaseIterable):
|
||||
return_data.append(self._flat(v))
|
||||
else:
|
||||
return_data.append(v)
|
||||
else:
|
||||
raise ValueError("数据类型错误")
|
||||
|
||||
return json.dumps(return_data)
|
||||
|
||||
def _detect_for_table(self, table_name: str) -> bool:
|
||||
"""在进行增删查改前检测表是否存在
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return self.cursor.execute(f"SELECT * FROM sqlite_master WHERE type = 'table' AND name = ?", (table_name,)).fetchone()
|
||||
|
||||
def first(self, model: type(LiteModel), conditions, *args, default: Any = None) -> LiteModel | None:
|
||||
"""查询第一条数据
|
||||
|
||||
Args:
|
||||
model: 模型
|
||||
conditions: 查询条件
|
||||
*args: 参数化查询条件参数
|
||||
default: 未查询到结果默认返回值
|
||||
|
||||
Returns: 数据
|
||||
"""
|
||||
table_name = model.__name__
|
||||
|
||||
if not self._detect_for_table(table_name):
|
||||
return default
|
||||
|
||||
self.cursor.execute(f"SELECT * FROM {table_name} WHERE {conditions}", args)
|
||||
if row_data := self.cursor.fetchone():
|
||||
data = dict(row_data)
|
||||
return model(**self.convert_to_dict(data))
|
||||
return default
|
||||
|
||||
def all(self, model: type(LiteModel), conditions=None, *args, default: Any = None) -> list[LiteModel] | None:
|
||||
"""查询所有数据
|
||||
|
||||
Args:
|
||||
model: 模型
|
||||
conditions: 查询条件
|
||||
*args: 参数化查询条件参数
|
||||
default: 未查询到结果默认返回值
|
||||
|
||||
Returns: 数据
|
||||
"""
|
||||
table_name = model.__name__
|
||||
|
||||
if not self._detect_for_table(table_name):
|
||||
return default
|
||||
|
||||
if conditions:
|
||||
self.cursor.execute(f"SELECT * FROM {table_name} WHERE {conditions}", args)
|
||||
else:
|
||||
self.cursor.execute(f"SELECT * FROM {table_name}")
|
||||
if row_datas := self.cursor.fetchall():
|
||||
datas = [dict(row_data) for row_data in row_datas]
|
||||
return [model(**self.convert_to_dict(d)) for d in datas] if datas else default
|
||||
return default
|
||||
|
||||
def delete(self, model: type(LiteModel), conditions, *args):
|
||||
"""删除数据
|
||||
|
||||
Args:
|
||||
model: 模型
|
||||
conditions: 查询条件
|
||||
*args: 参数化查询条件参数
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
table_name = model.__name__
|
||||
|
||||
if not self._detect_for_table(table_name):
|
||||
return
|
||||
nonebot.logger.debug(f"DELETE FROM {table_name} WHERE {conditions}")
|
||||
self.cursor.execute(f"DELETE FROM {table_name} WHERE {conditions}", args)
|
||||
self.conn.commit()
|
||||
|
||||
def convert_to_dict(self, data: dict) -> dict:
|
||||
"""将json字符串转换为字典
|
||||
|
||||
Args:
|
||||
data: json字符串
|
||||
|
||||
Returns: 字典
|
||||
"""
|
||||
|
||||
def load(d: BaseIterable) -> BaseIterable:
|
||||
"""递归加载数据,去除前缀"""
|
||||
if isinstance(d, dict):
|
||||
new_d = {}
|
||||
for k, v in d.items():
|
||||
if k.startswith(self.FOREIGNID):
|
||||
new_d[k.replace(self.FOREIGNID, "")] = load(
|
||||
dict(self.cursor.execute(f"SELECT * FROM {v.split(':', 2)[1]} WHERE id = ?", (v.split(":", 2)[2],)).fetchone()))
|
||||
|
||||
elif k.startswith(self.LIST):
|
||||
if v == '': v = '[]'
|
||||
new_d[k.replace(self.LIST, '')] = load(json.loads(v))
|
||||
elif k.startswith(self.DICT):
|
||||
if v == '': v = '{}'
|
||||
new_d[k.replace(self.DICT, '')] = load(json.loads(v))
|
||||
elif k.startswith(self.JSON):
|
||||
if v == '': v = '[]'
|
||||
new_d[k.replace(self.JSON, '')] = load(json.loads(v))
|
||||
else:
|
||||
new_d[k] = v
|
||||
elif isinstance(d, list | tuple | set):
|
||||
new_d = []
|
||||
for i, v in enumerate(d):
|
||||
if isinstance(v, str) and v.startswith(self.ID):
|
||||
new_d.append(load(dict(self.cursor.execute(f'SELECT * FROM {v.split(":", 2)[1]} WHERE id = ?', (v.split(":", 2)[2],)).fetchone())))
|
||||
elif isinstance(v, BaseIterable):
|
||||
new_d.append(load(v))
|
||||
raise ValueError(f"数据模型{table_name}包含不支持的数据类型,字段:{field} 值:{value} 值类型:{type(value)}")
|
||||
if table_name:
|
||||
fields, values = [], []
|
||||
for n_field, n_value in new_obj.items():
|
||||
if n_field not in ["TABLE_NAME", "id"]:
|
||||
fields.append(n_field)
|
||||
values.append(n_value)
|
||||
# 移除TABLE_NAME和id
|
||||
fields = list(fields)
|
||||
values = list(values)
|
||||
if row_id is not None:
|
||||
# 如果 _id 不为空,将 'id' 插入到字段列表的开始
|
||||
fields.insert(0, 'id')
|
||||
# 将 _id 插入到值列表的开始
|
||||
values.insert(0, row_id)
|
||||
fields = ', '.join([f'"{field}"' for field in fields])
|
||||
placeholders = ', '.join('?' for _ in values)
|
||||
self.cursor.execute(f"INSERT OR REPLACE INTO {table_name}({fields}) VALUES ({placeholders})", tuple(values))
|
||||
self.conn.commit()
|
||||
foreign_id = self.cursor.execute("SELECT last_insert_rowid()").fetchone()[0]
|
||||
return f"{self.FOREIGN_KEY_PREFIX}{foreign_id}@{table_name}" # -> FOREIGN_KEY_123456@{table_name} id@{table_name}
|
||||
else:
|
||||
new_d = d
|
||||
return new_d
|
||||
return pickle.dumps(new_obj) # -> bytes
|
||||
elif isinstance(obj, (list, set, tuple)):
|
||||
obj_type = type(obj) # 到时候转回去
|
||||
new_obj = []
|
||||
for item in obj:
|
||||
if isinstance(item, self.ITERABLE_TYPE):
|
||||
new_obj.append(self._save(item))
|
||||
elif isinstance(item, self.BASIC_TYPE):
|
||||
new_obj.append(item)
|
||||
else:
|
||||
raise ValueError(f"数据模型包含不支持的数据类型,值:{item} 值类型:{type(item)}")
|
||||
return pickle.dumps(obj_type(new_obj)) # -> bytes
|
||||
else:
|
||||
raise ValueError(f"数据模型包含不支持的数据类型,值:{obj} 值类型:{type(obj)}")
|
||||
|
||||
return load(data)
|
||||
def _load(self, obj: Any) -> Any:
|
||||
|
||||
if isinstance(obj, dict):
|
||||
|
||||
new_obj = {}
|
||||
|
||||
for field, value in obj.items():
|
||||
|
||||
field: str
|
||||
|
||||
if field.startswith(self.BYTES_PREFIX):
|
||||
|
||||
new_obj[field.replace(self.BYTES_PREFIX, "")] = self._load(pickle.loads(value))
|
||||
|
||||
elif field.startswith(self.FOREIGN_KEY_PREFIX):
|
||||
|
||||
new_obj[field.replace(self.FOREIGN_KEY_PREFIX, "")] = self._load(self._get_foreign_data(value))
|
||||
|
||||
else:
|
||||
new_obj[field] = value
|
||||
return new_obj
|
||||
elif isinstance(obj, (list, set, tuple)):
|
||||
|
||||
print(" - Load as List")
|
||||
|
||||
new_obj = []
|
||||
for item in obj:
|
||||
|
||||
print(" - Loading Item", item)
|
||||
|
||||
if isinstance(item, bytes):
|
||||
|
||||
# 对bytes进行尝试解析,解析失败则返回原始bytes
|
||||
try:
|
||||
new_obj.append(self._load(pickle.loads(item)))
|
||||
except Exception as e:
|
||||
new_obj.append(self._load(item))
|
||||
|
||||
print(" - Load as Bytes | Result:", new_obj[-1])
|
||||
|
||||
elif isinstance(item, str) and item.startswith(self.FOREIGN_KEY_PREFIX):
|
||||
new_obj.append(self._load(self._get_foreign_data(item)))
|
||||
else:
|
||||
new_obj.append(self._load(item))
|
||||
return new_obj
|
||||
else:
|
||||
return obj
|
||||
|
||||
def delete(self, model: LiteModel, condition: str, *args: Any, allow_empty: bool = False):
|
||||
"""
|
||||
删除满足条件的数据
|
||||
Args:
|
||||
allow_empty: 允许空条件删除整个表
|
||||
model:
|
||||
condition:
|
||||
*args:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
table_name = model.TABLE_NAME
|
||||
if not table_name:
|
||||
raise ValueError(f"数据模型{model.__class__.__name__}未提供表名")
|
||||
if not condition and not allow_empty:
|
||||
raise ValueError("删除操作必须提供条件")
|
||||
self.cursor.execute(f"DELETE FROM {table_name} WHERE {condition}", args)
|
||||
|
||||
def auto_migrate(self, *args: LiteModel):
|
||||
|
||||
"""
|
||||
自动迁移模型
|
||||
Args:
|
||||
*args: 模型类实例化对象,支持空默认值,不支持嵌套迁移
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
for model in args:
|
||||
if not model.TABLE_NAME:
|
||||
raise ValueError(f"数据模型{type(model).__name__}未提供表名")
|
||||
|
||||
# 若无则创建表
|
||||
self.cursor.execute(
|
||||
f'CREATE TABLE IF NOT EXISTS "{model.TABLE_NAME}" (id INTEGER PRIMARY KEY AUTOINCREMENT)'
|
||||
)
|
||||
|
||||
# 获取表结构,field -> SqliteType
|
||||
new_structure = {}
|
||||
for n_field, n_value in model.dump(by_alias=True).items():
|
||||
if n_field not in ["TABLE_NAME", "id"]:
|
||||
new_structure[self._get_stored_field_prefix(n_value) + n_field] = self._get_stored_type(n_value)
|
||||
|
||||
# 原有的字段列表
|
||||
existing_structure = dict([(column[1], column[2]) for column in self.cursor.execute(f'PRAGMA table_info({model.TABLE_NAME})').fetchall()])
|
||||
# 检测缺失字段,由于SQLite是动态类型,所以不需要检测类型
|
||||
for n_field, n_type in new_structure.items():
|
||||
if n_field not in existing_structure.keys() and n_field.lower() not in ["id", "table_name"]:
|
||||
print(n_type, self.DEFAULT_MAPPING.get(n_type, ''))
|
||||
self.cursor.execute(
|
||||
f"ALTER TABLE '{model.TABLE_NAME}' ADD COLUMN {n_field} {n_type} DEFAULT {self.DEFAULT_MAPPING.get(n_type, '')}"
|
||||
)
|
||||
|
||||
# 检测多余字段进行删除
|
||||
for e_field in existing_structure.keys():
|
||||
if e_field not in new_structure.keys() and e_field.lower() not in ['id']:
|
||||
self.cursor.execute(
|
||||
f'ALTER TABLE "{model.TABLE_NAME}" DROP COLUMN "{e_field}"'
|
||||
)
|
||||
self.conn.commit()
|
||||
# 已完成
|
||||
|
||||
def _get_stored_field_prefix(self, value) -> str:
|
||||
"""根据类型获取存储字段前缀,一定在后加上字段名
|
||||
* -> ""
|
||||
Args:
|
||||
value: 储存的值
|
||||
|
||||
Returns:
|
||||
Sqlite3存储字段
|
||||
"""
|
||||
|
||||
if isinstance(value, LiteModel) or isinstance(value, dict) and "TABLE_NAME" in value:
|
||||
return self.FOREIGN_KEY_PREFIX
|
||||
elif type(value) in self.ITERABLE_TYPE:
|
||||
return self.BYTES_PREFIX
|
||||
return ""
|
||||
|
||||
def _get_stored_type(self, value) -> str:
|
||||
"""获取存储类型
|
||||
|
||||
Args:
|
||||
value: 储存的值
|
||||
|
||||
Returns:
|
||||
Sqlite3存储类型
|
||||
"""
|
||||
if isinstance(value, dict) and "TABLE_NAME" in value:
|
||||
# 是一个模型字典,储存外键
|
||||
return "INTEGER"
|
||||
return self.TYPE_MAPPING.get(type(value), "TEXT")
|
||||
|
||||
def _get_foreign_data(self, foreign_value: str) -> dict:
|
||||
"""
|
||||
获取外键数据
|
||||
Args:
|
||||
foreign_value:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
foreign_value = foreign_value.replace(self.FOREIGN_KEY_PREFIX, "")
|
||||
table_name = foreign_value.split("@")[-1]
|
||||
foreign_id = foreign_value.split("@")[0]
|
||||
fields = [description[1] for description in self.cursor.execute(f"PRAGMA table_info({table_name})").fetchall()]
|
||||
result = self.cursor.execute(f"SELECT * FROM {table_name} WHERE id = ?", (foreign_id,)).fetchone()
|
||||
return dict(zip(fields, result))
|
||||
|
||||
TYPE_MAPPING = {
|
||||
int : "INTEGER",
|
||||
float : "REAL",
|
||||
str : "TEXT",
|
||||
bool : "INTEGER",
|
||||
bytes : "BLOB",
|
||||
NoneType : "NULL",
|
||||
# dict : "TEXT",
|
||||
# list : "TEXT",
|
||||
# tuple : "TEXT",
|
||||
# set : "TEXT",
|
||||
|
||||
dict : "BLOB", # LITEYUKIDICT{key_name}
|
||||
list : "BLOB", # LITEYUKILIST{key_name}
|
||||
tuple : "BLOB", # LITEYUKITUPLE{key_name}
|
||||
set : "BLOB", # LITEYUKISET{key_name}
|
||||
LiteModel: "TEXT" # FOREIGN_KEY_{table_name}
|
||||
}
|
||||
DEFAULT_MAPPING = {
|
||||
"TEXT" : "''",
|
||||
"INTEGER": 0,
|
||||
"REAL" : 0.0,
|
||||
"BLOB" : b"",
|
||||
"NULL" : None
|
||||
}
|
||||
|
||||
# 基础类型
|
||||
BASIC_TYPE = (int, float, str, bool, bytes, NoneType)
|
||||
# 可序列化类型
|
||||
ITERABLE_TYPE = (dict, list, tuple, set, LiteModel)
|
||||
|
||||
# 外键前缀
|
||||
FOREIGN_KEY_PREFIX = "FOREIGN_KEY_"
|
||||
# 转换为的字节前缀
|
||||
BYTES_PREFIX = "PICKLE_BYTES_"
|
||||
|
||||
|
||||
def check_sqlite_keyword(name):
|
||||
sqlite_keywords = [
|
||||
"ABORT", "ACTION", "ADD", "AFTER", "ALL", "ALTER", "ANALYZE", "AND", "AS", "ASC",
|
||||
"ATTACH", "AUTOINCREMENT", "BEFORE", "BEGIN", "BETWEEN", "BY", "CASCADE", "CASE",
|
||||
"CAST", "CHECK", "COLLATE", "COLUMN", "COMMIT", "CONFLICT", "CONSTRAINT", "CREATE",
|
||||
"CROSS", "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP", "DATABASE", "DEFAULT",
|
||||
"DEFERRABLE", "DEFERRED", "DELETE", "DESC", "DETACH", "DISTINCT", "DROP", "EACH",
|
||||
"ELSE", "END", "ESCAPE", "EXCEPT", "EXCLUSIVE", "EXISTS", "EXPLAIN", "FAIL", "FOR",
|
||||
"FOREIGN", "FROM", "FULL", "GLOB", "GROUP", "HAVING", "IF", "IGNORE", "IMMEDIATE",
|
||||
"IN", "INDEX", "INDEXED", "INITIALLY", "INNER", "INSERT", "INSTEAD", "INTERSECT",
|
||||
"INTO", "IS", "ISNULL", "JOIN", "KEY", "LEFT", "LIKE", "LIMIT", "MATCH", "NATURAL",
|
||||
"NO", "NOT", "NOTNULL", "NULL", "OF", "OFFSET", "ON", "OR", "ORDER", "OUTER", "PLAN",
|
||||
"PRAGMA", "PRIMARY", "QUERY", "RAISE", "RECURSIVE", "REFERENCES", "REGEXP", "REINDEX",
|
||||
"RELEASE", "RENAME", "REPLACE", "RESTRICT", "RIGHT", "ROLLBACK", "ROW", "SAVEPOINT",
|
||||
"SELECT", "SET", "TABLE", "TEMP", "TEMPORARY", "THEN", "TO", "TRANSACTION", "TRIGGER",
|
||||
"UNION", "UNIQUE", "UPDATE", "USING", "VACUUM", "VALUES", "VIEW", "VIRTUAL", "WHEN",
|
||||
"WHERE", "WITH", "WITHOUT"
|
||||
]
|
||||
return True
|
||||
# if name.upper() in sqlite_keywords:
|
||||
# raise ValueError(f"'{name}' 是SQLite保留字,不建议使用,请更换名称")
|
||||
|
Reference in New Issue
Block a user