fix: 数据库支持

This commit is contained in:
2024-03-26 21:33:40 +08:00
parent 90c9ef31a1
commit 58e603e1ad
14 changed files with 528 additions and 728 deletions

View File

@ -1,374 +1,358 @@
import json
import os
import pickle
import sqlite3
import types
from abc import ABC
from collections.abc import Iterable
import nonebot
from pydantic import BaseModel
from types import NoneType
from typing import Any
BaseIterable = list | tuple | set | dict
import nonebot
import pydantic
from pydantic import BaseModel
class LiteModel(BaseModel):
"""轻量级模型基类
类型注解统一使用Python3.9的PEP585标准如需使用泛型请使用typing模块的泛型类型
"""
TABLE_NAME: str = None
id: int = None
class BaseORMAdapter(ABC):
def __init__(self):
pass
def auto_migrate(self, *args, **kwargs):
"""自动迁移
Returns:
"""
raise NotImplementedError
def upsert(self, *args, **kwargs):
"""存储数据
Returns:
"""
raise NotImplementedError
def first(self, *args, **kwargs):
"""查询第一条数据
Returns:
"""
raise NotImplementedError
def all(self, *args, **kwargs):
"""查询所有数据
Returns:
"""
raise NotImplementedError
def delete(self, *args, **kwargs):
"""删除数据
Returns:
"""
raise NotImplementedError
def update(self, *args, **kwargs):
"""更新数据
Returns:
"""
raise NotImplementedError
def dump(self, *args, **kwargs):
if pydantic.__version__ < "1.8.2":
return self.dict(by_alias=True)
else:
return self.model_dump(by_alias=True)
class Database(BaseORMAdapter):
"""SQLiteORM适配器严禁使用`FORIEGNID`和`JSON`作为主键前缀,严禁使用`$ID:`作为字符串值前缀
Attributes:
"""
type_map = {
# default: TEXT
str : 'TEXT',
int : 'INTEGER',
float: 'REAL',
bool : 'INTEGER',
list : 'TEXT'
}
DEFAULT_VALUE = {
'TEXT' : '',
'INTEGER': 0,
'REAL' : 0.0
}
FOREIGNID = 'FOREIGNID'
JSON = 'JSON'
LIST = 'LIST'
DICT = 'DICT'
ID = '$ID'
class Database:
def __init__(self, db_name: str):
super().__init__()
if not os.path.exists(os.path.dirname(db_name)):
if os.path.dirname(db_name) != "" and not os.path.exists(os.path.dirname(db_name)):
os.makedirs(os.path.dirname(db_name))
self.db_name = db_name
self.conn = sqlite3.connect(db_name)
self.conn.row_factory = sqlite3.Row
self.cursor = self.conn.cursor()
def auto_migrate(self, *args: type(LiteModel)):
"""自动迁移,检测新模型字段和原有表字段的差异,如有差异自动增删新字段
def first(self, model: LiteModel, condition: str, *args: Any, default: Any = None) -> LiteModel | Any | None:
"""查询第一个
Args:
*args: 模型类
model: 数据模型实例
condition: 查询条件,不给定则查询所有
*args: 参数化查询参数
default: 默认值
Returns:
"""
table_name = ''
all_results = self.all(model, condition, *args)
return all_results[0] if all_results else default
def all(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> list[LiteModel | Any] | None:
"""查询所有
Args:
model: 数据模型实例
condition: 查询条件,不给定则查询所有
*args: 参数化查询参数
default: 默认值
Returns:
"""
table_name = model.TABLE_NAME
model_type = type(model)
if not table_name:
raise ValueError(f"数据模型{model_type.__name__}未提供表名")
# condition = f"WHERE {condition}"
# print(f"SELECT * FROM {table_name} {condition}", args)
# if len(args) == 0:
# results = self.cursor.execute(f"SELECT * FROM {table_name} {condition}").fetchall()
# else:
# results = self.cursor.execute(f"SELECT * FROM {table_name} {condition}", args).fetchall()
if condition:
results = self.cursor.execute(f"SELECT * FROM {table_name} WHERE {condition}", args).fetchall()
else:
results = self.cursor.execute(f"SELECT * FROM {table_name}").fetchall()
fields = [description[0] for description in self.cursor.description]
if not results:
return default
else:
return [model_type(**self._load(dict(zip(fields, result)))) for result in results]
def upsert(self, *args: LiteModel):
"""增/改操作
Args:
*args:
Returns:
"""
table_list = [item[0] for item in self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()]
for model in args:
model: type(LiteModel)
# 检测并创建表若模型未定义id字段则使用自增主键有定义的话使用id字段且id有可能为字符串
table_name = model.__name__
if 'id' in model.__annotations__ and model.__annotations__['id'] is not None:
# 如果模型定义了id字段那么使用模型的id字段
id_type = self.type_map.get(model.__annotations__['id'], 'TEXT')
self.cursor.execute(f'CREATE TABLE IF NOT EXISTS {table_name} (id {id_type} PRIMARY KEY)')
if not model.TABLE_NAME:
raise ValueError(f"数据模型 {model.__class__.__name__} 未提供表名")
elif model.TABLE_NAME not in table_list:
raise ValueError(f"数据模型 {model.__class__.__name__} 的表 {model.TABLE_NAME} 不存在,请先迁移")
else:
# 如果模型未定义id字段那么使用自增主键
self.cursor.execute(f'CREATE TABLE IF NOT EXISTS {table_name} (id INTEGER PRIMARY KEY AUTOINCREMENT)')
# 获取表字段
self.cursor.execute(f'PRAGMA table_info({table_name})')
table_fields = self.cursor.fetchall()
table_fields = [field[1] for field in table_fields]
self._save(model.dump(by_alias=True))
raw_fields, raw_types = zip(*model.__annotations__.items())
# 获取模型字段若有模型则添加FOREIGNID前缀若为BaseIterable则添加JSON前缀用多行if判断
model_fields = []
model_types = []
for field, r_type in zip(raw_fields, raw_types):
if isinstance(r_type, type(LiteModel)):
model_fields.append(f'{self.FOREIGNID}{field}')
model_types.append('TEXT')
elif r_type in [list[str], list[int], list[float], list[bool], list]:
model_fields.append(f'{self.LIST}{field}')
model_types.append('TEXT')
elif r_type in [dict[str, str], dict[str, int], dict[str, float], dict[str, bool], dict]:
model_fields.append(f'{self.DICT}{field}')
model_types.append('TEXT')
elif isinstance(r_type, types.GenericAlias):
model_fields.append(f'{self.JSON}{field}')
model_types.append('TEXT')
def _save(self, obj: Any) -> Any:
# obj = copy.deepcopy(obj)
if isinstance(obj, dict):
table_name = obj.get("TABLE_NAME")
row_id = obj.get("id")
new_obj = {}
for field, value in obj.items():
if isinstance(value, self.ITERABLE_TYPE):
new_obj[self._get_stored_field_prefix(value) + field] = self._save(value) # self._save(value) # -> bytes
elif isinstance(value, self.BASIC_TYPE):
new_obj[field] = value
else:
model_fields.append(field)
model_types.append(self.type_map.get(r_type, 'TEXT'))
# 检测新字段或字段类型是否有变化,有则增删字段,已经加了前缀类型
for field_changed, type_, r_type in zip(model_fields, model_types, raw_types):
if field_changed not in table_fields:
nonebot.logger.debug(f'ALTER TABLE {table_name} ADD COLUMN {field_changed} {type_}')
self.cursor.execute(f'ALTER TABLE {table_name} ADD COLUMN {field_changed} {type_}')
# 在原有的行中添加新字段对应类型的默认值从DEFAULT_TYPE中获取
self.cursor.execute(f'UPDATE {table_name} SET {field_changed} = ? WHERE {field_changed} IS NULL', (self.DEFAULT_VALUE.get(type_, ""),))
# 检测多余字段除了id字段
for field in table_fields:
if field not in model_fields and field != 'id':
nonebot.logger.debug(f'ALTER TABLE {table_name} DROP COLUMN {field}')
self.cursor.execute(f'ALTER TABLE {table_name} DROP COLUMN {field}')
self.conn.commit()
nonebot.logger.debug(f'Table {table_name} migrated successfully')
def upsert(self, *models: LiteModel) -> int | tuple:
"""存储数据检查id字段如果有id字段则更新没有则插入
Args:
models: 数据
Returns:
id: 数据id如果有多个数据则返回id元组
"""
ids = []
for model in models:
table_name = model.__class__.__name__
if not self._detect_for_table(table_name):
raise ValueError(f'{table_name}不存在,请先迁移')
key_list = []
value_list = []
# 处理外键,添加前缀'$IDFieldName'
for field, value in model.__dict__.items():
if isinstance(value, LiteModel):
key_list.append(f'{self.FOREIGNID}{field}')
value_list.append(f'{self.ID}:{value.__class__.__name__}:{self.upsert(value)}')
elif isinstance(value, list):
key_list.append(f'{self.LIST}{field}')
value_list.append(self._flat(value))
elif isinstance(value, dict):
key_list.append(f'{self.DICT}{field}')
value_list.append(self._flat(value))
elif isinstance(value, BaseIterable):
key_list.append(f'{self.JSON}{field}')
value_list.append(self._flat(value))
else:
key_list.append(field)
value_list.append(value)
# 更新或插入数据,用?占位
nonebot.logger.debug(f'INSERT OR REPLACE INTO {table_name} ({",".join(key_list)}) VALUES ({",".join(["?" for _ in key_list])})')
self.cursor.execute(f'INSERT OR REPLACE INTO {table_name} ({",".join(key_list)}) VALUES ({",".join(["?" for _ in key_list])})', value_list)
ids.append(self.cursor.lastrowid)
self.conn.commit()
return ids[0] if len(ids) == 1 else tuple(ids)
def _flat(self, data: Iterable) -> str:
"""扁平化数据,返回扁平化对象
Args:
data: 数据,可迭代对象
Returns: json字符串
"""
if isinstance(data, dict):
return_data = {}
for k, v in data.items():
if isinstance(v, LiteModel):
return_data[f"{self.FOREIGNID}{k}"] = f"{self.ID}:{v.__class__.__name__}:{self.upsert(v)}"
elif isinstance(v, list):
return_data[f"{self.LIST}{k}"] = self._flat(v)
elif isinstance(v, dict):
return_data[f"{self.DICT}{k}"] = self._flat(v)
elif isinstance(v, BaseIterable):
return_data[f"{self.JSON}{k}"] = self._flat(v)
else:
return_data[k] = v
elif isinstance(data, list | tuple | set):
return_data = []
for v in data:
if isinstance(v, LiteModel):
return_data.append(f"{self.ID}:{v.__class__.__name__}:{self.upsert(v)}")
elif isinstance(v, list):
return_data.append(self._flat(v))
elif isinstance(v, dict):
return_data.append(self._flat(v))
elif isinstance(v, BaseIterable):
return_data.append(self._flat(v))
else:
return_data.append(v)
else:
raise ValueError("数据类型错误")
return json.dumps(return_data)
def _detect_for_table(self, table_name: str) -> bool:
"""在进行增删查改前检测表是否存在
Args:
table_name: 表名
Returns:
"""
return self.cursor.execute(f"SELECT * FROM sqlite_master WHERE type = 'table' AND name = ?", (table_name,)).fetchone()
def first(self, model: type(LiteModel), conditions, *args, default: Any = None) -> LiteModel | None:
"""查询第一条数据
Args:
model: 模型
conditions: 查询条件
*args: 参数化查询条件参数
default: 未查询到结果默认返回值
Returns: 数据
"""
table_name = model.__name__
if not self._detect_for_table(table_name):
return default
self.cursor.execute(f"SELECT * FROM {table_name} WHERE {conditions}", args)
if row_data := self.cursor.fetchone():
data = dict(row_data)
return model(**self.convert_to_dict(data))
return default
def all(self, model: type(LiteModel), conditions=None, *args, default: Any = None) -> list[LiteModel] | None:
"""查询所有数据
Args:
model: 模型
conditions: 查询条件
*args: 参数化查询条件参数
default: 未查询到结果默认返回值
Returns: 数据
"""
table_name = model.__name__
if not self._detect_for_table(table_name):
return default
if conditions:
self.cursor.execute(f"SELECT * FROM {table_name} WHERE {conditions}", args)
else:
self.cursor.execute(f"SELECT * FROM {table_name}")
if row_datas := self.cursor.fetchall():
datas = [dict(row_data) for row_data in row_datas]
return [model(**self.convert_to_dict(d)) for d in datas] if datas else default
return default
def delete(self, model: type(LiteModel), conditions, *args):
"""删除数据
Args:
model: 模型
conditions: 查询条件
*args: 参数化查询条件参数
Returns:
"""
table_name = model.__name__
if not self._detect_for_table(table_name):
return
nonebot.logger.debug(f"DELETE FROM {table_name} WHERE {conditions}")
self.cursor.execute(f"DELETE FROM {table_name} WHERE {conditions}", args)
self.conn.commit()
def convert_to_dict(self, data: dict) -> dict:
"""将json字符串转换为字典
Args:
data: json字符串
Returns: 字典
"""
def load(d: BaseIterable) -> BaseIterable:
"""递归加载数据,去除前缀"""
if isinstance(d, dict):
new_d = {}
for k, v in d.items():
if k.startswith(self.FOREIGNID):
new_d[k.replace(self.FOREIGNID, "")] = load(
dict(self.cursor.execute(f"SELECT * FROM {v.split(':', 2)[1]} WHERE id = ?", (v.split(":", 2)[2],)).fetchone()))
elif k.startswith(self.LIST):
if v == '': v = '[]'
new_d[k.replace(self.LIST, '')] = load(json.loads(v))
elif k.startswith(self.DICT):
if v == '': v = '{}'
new_d[k.replace(self.DICT, '')] = load(json.loads(v))
elif k.startswith(self.JSON):
if v == '': v = '[]'
new_d[k.replace(self.JSON, '')] = load(json.loads(v))
else:
new_d[k] = v
elif isinstance(d, list | tuple | set):
new_d = []
for i, v in enumerate(d):
if isinstance(v, str) and v.startswith(self.ID):
new_d.append(load(dict(self.cursor.execute(f'SELECT * FROM {v.split(":", 2)[1]} WHERE id = ?', (v.split(":", 2)[2],)).fetchone())))
elif isinstance(v, BaseIterable):
new_d.append(load(v))
raise ValueError(f"数据模型{table_name}包含不支持的数据类型,字段:{field} 值:{value} 值类型:{type(value)}")
if table_name:
fields, values = [], []
for n_field, n_value in new_obj.items():
if n_field not in ["TABLE_NAME", "id"]:
fields.append(n_field)
values.append(n_value)
# 移除TABLE_NAME和id
fields = list(fields)
values = list(values)
if row_id is not None:
# 如果 _id 不为空,将 'id' 插入到字段列表的开始
fields.insert(0, 'id')
# 将 _id 插入到值列表的开始
values.insert(0, row_id)
fields = ', '.join([f'"{field}"' for field in fields])
placeholders = ', '.join('?' for _ in values)
self.cursor.execute(f"INSERT OR REPLACE INTO {table_name}({fields}) VALUES ({placeholders})", tuple(values))
self.conn.commit()
foreign_id = self.cursor.execute("SELECT last_insert_rowid()").fetchone()[0]
return f"{self.FOREIGN_KEY_PREFIX}{foreign_id}@{table_name}" # -> FOREIGN_KEY_123456@{table_name} id@{table_name}
else:
new_d = d
return new_d
return pickle.dumps(new_obj) # -> bytes
elif isinstance(obj, (list, set, tuple)):
obj_type = type(obj) # 到时候转回去
new_obj = []
for item in obj:
if isinstance(item, self.ITERABLE_TYPE):
new_obj.append(self._save(item))
elif isinstance(item, self.BASIC_TYPE):
new_obj.append(item)
else:
raise ValueError(f"数据模型包含不支持的数据类型,值:{item} 值类型:{type(item)}")
return pickle.dumps(obj_type(new_obj)) # -> bytes
else:
raise ValueError(f"数据模型包含不支持的数据类型,值:{obj} 值类型:{type(obj)}")
return load(data)
def _load(self, obj: Any) -> Any:
if isinstance(obj, dict):
new_obj = {}
for field, value in obj.items():
field: str
if field.startswith(self.BYTES_PREFIX):
new_obj[field.replace(self.BYTES_PREFIX, "")] = self._load(pickle.loads(value))
elif field.startswith(self.FOREIGN_KEY_PREFIX):
new_obj[field.replace(self.FOREIGN_KEY_PREFIX, "")] = self._load(self._get_foreign_data(value))
else:
new_obj[field] = value
return new_obj
elif isinstance(obj, (list, set, tuple)):
print(" - Load as List")
new_obj = []
for item in obj:
print(" - Loading Item", item)
if isinstance(item, bytes):
# 对bytes进行尝试解析解析失败则返回原始bytes
try:
new_obj.append(self._load(pickle.loads(item)))
except Exception as e:
new_obj.append(self._load(item))
print(" - Load as Bytes | Result:", new_obj[-1])
elif isinstance(item, str) and item.startswith(self.FOREIGN_KEY_PREFIX):
new_obj.append(self._load(self._get_foreign_data(item)))
else:
new_obj.append(self._load(item))
return new_obj
else:
return obj
def delete(self, model: LiteModel, condition: str, *args: Any, allow_empty: bool = False):
"""
删除满足条件的数据
Args:
allow_empty: 允许空条件删除整个表
model:
condition:
*args:
Returns:
"""
table_name = model.TABLE_NAME
if not table_name:
raise ValueError(f"数据模型{model.__class__.__name__}未提供表名")
if not condition and not allow_empty:
raise ValueError("删除操作必须提供条件")
self.cursor.execute(f"DELETE FROM {table_name} WHERE {condition}", args)
def auto_migrate(self, *args: LiteModel):
"""
自动迁移模型
Args:
*args: 模型类实例化对象,支持空默认值,不支持嵌套迁移
Returns:
"""
for model in args:
if not model.TABLE_NAME:
raise ValueError(f"数据模型{type(model).__name__}未提供表名")
# 若无则创建表
self.cursor.execute(
f'CREATE TABLE IF NOT EXISTS "{model.TABLE_NAME}" (id INTEGER PRIMARY KEY AUTOINCREMENT)'
)
# 获取表结构,field -> SqliteType
new_structure = {}
for n_field, n_value in model.dump(by_alias=True).items():
if n_field not in ["TABLE_NAME", "id"]:
new_structure[self._get_stored_field_prefix(n_value) + n_field] = self._get_stored_type(n_value)
# 原有的字段列表
existing_structure = dict([(column[1], column[2]) for column in self.cursor.execute(f'PRAGMA table_info({model.TABLE_NAME})').fetchall()])
# 检测缺失字段由于SQLite是动态类型所以不需要检测类型
for n_field, n_type in new_structure.items():
if n_field not in existing_structure.keys() and n_field.lower() not in ["id", "table_name"]:
print(n_type, self.DEFAULT_MAPPING.get(n_type, ''))
self.cursor.execute(
f"ALTER TABLE '{model.TABLE_NAME}' ADD COLUMN {n_field} {n_type} DEFAULT {self.DEFAULT_MAPPING.get(n_type, '')}"
)
# 检测多余字段进行删除
for e_field in existing_structure.keys():
if e_field not in new_structure.keys() and e_field.lower() not in ['id']:
self.cursor.execute(
f'ALTER TABLE "{model.TABLE_NAME}" DROP COLUMN "{e_field}"'
)
self.conn.commit()
# 已完成
def _get_stored_field_prefix(self, value) -> str:
"""根据类型获取存储字段前缀,一定在后加上字段名
* -> ""
Args:
value: 储存的值
Returns:
Sqlite3存储字段
"""
if isinstance(value, LiteModel) or isinstance(value, dict) and "TABLE_NAME" in value:
return self.FOREIGN_KEY_PREFIX
elif type(value) in self.ITERABLE_TYPE:
return self.BYTES_PREFIX
return ""
def _get_stored_type(self, value) -> str:
"""获取存储类型
Args:
value: 储存的值
Returns:
Sqlite3存储类型
"""
if isinstance(value, dict) and "TABLE_NAME" in value:
# 是一个模型字典,储存外键
return "INTEGER"
return self.TYPE_MAPPING.get(type(value), "TEXT")
def _get_foreign_data(self, foreign_value: str) -> dict:
"""
获取外键数据
Args:
foreign_value:
Returns:
"""
foreign_value = foreign_value.replace(self.FOREIGN_KEY_PREFIX, "")
table_name = foreign_value.split("@")[-1]
foreign_id = foreign_value.split("@")[0]
fields = [description[1] for description in self.cursor.execute(f"PRAGMA table_info({table_name})").fetchall()]
result = self.cursor.execute(f"SELECT * FROM {table_name} WHERE id = ?", (foreign_id,)).fetchone()
return dict(zip(fields, result))
TYPE_MAPPING = {
int : "INTEGER",
float : "REAL",
str : "TEXT",
bool : "INTEGER",
bytes : "BLOB",
NoneType : "NULL",
# dict : "TEXT",
# list : "TEXT",
# tuple : "TEXT",
# set : "TEXT",
dict : "BLOB", # LITEYUKIDICT{key_name}
list : "BLOB", # LITEYUKILIST{key_name}
tuple : "BLOB", # LITEYUKITUPLE{key_name}
set : "BLOB", # LITEYUKISET{key_name}
LiteModel: "TEXT" # FOREIGN_KEY_{table_name}
}
DEFAULT_MAPPING = {
"TEXT" : "''",
"INTEGER": 0,
"REAL" : 0.0,
"BLOB" : b"",
"NULL" : None
}
# 基础类型
BASIC_TYPE = (int, float, str, bool, bytes, NoneType)
# 可序列化类型
ITERABLE_TYPE = (dict, list, tuple, set, LiteModel)
# 外键前缀
FOREIGN_KEY_PREFIX = "FOREIGN_KEY_"
# 转换为的字节前缀
BYTES_PREFIX = "PICKLE_BYTES_"
def check_sqlite_keyword(name):
sqlite_keywords = [
"ABORT", "ACTION", "ADD", "AFTER", "ALL", "ALTER", "ANALYZE", "AND", "AS", "ASC",
"ATTACH", "AUTOINCREMENT", "BEFORE", "BEGIN", "BETWEEN", "BY", "CASCADE", "CASE",
"CAST", "CHECK", "COLLATE", "COLUMN", "COMMIT", "CONFLICT", "CONSTRAINT", "CREATE",
"CROSS", "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP", "DATABASE", "DEFAULT",
"DEFERRABLE", "DEFERRED", "DELETE", "DESC", "DETACH", "DISTINCT", "DROP", "EACH",
"ELSE", "END", "ESCAPE", "EXCEPT", "EXCLUSIVE", "EXISTS", "EXPLAIN", "FAIL", "FOR",
"FOREIGN", "FROM", "FULL", "GLOB", "GROUP", "HAVING", "IF", "IGNORE", "IMMEDIATE",
"IN", "INDEX", "INDEXED", "INITIALLY", "INNER", "INSERT", "INSTEAD", "INTERSECT",
"INTO", "IS", "ISNULL", "JOIN", "KEY", "LEFT", "LIKE", "LIMIT", "MATCH", "NATURAL",
"NO", "NOT", "NOTNULL", "NULL", "OF", "OFFSET", "ON", "OR", "ORDER", "OUTER", "PLAN",
"PRAGMA", "PRIMARY", "QUERY", "RAISE", "RECURSIVE", "REFERENCES", "REGEXP", "REINDEX",
"RELEASE", "RENAME", "REPLACE", "RESTRICT", "RIGHT", "ROLLBACK", "ROW", "SAVEPOINT",
"SELECT", "SET", "TABLE", "TEMP", "TEMPORARY", "THEN", "TO", "TRANSACTION", "TRIGGER",
"UNION", "UNIQUE", "UPDATE", "USING", "VACUUM", "VALUES", "VIEW", "VIRTUAL", "WHEN",
"WHERE", "WITH", "WITHOUT"
]
return True
# if name.upper() in sqlite_keywords:
# raise ValueError(f"'{name}' 是SQLite保留字不建议使用请更换名称")