mirror of
				https://github.com/LiteyukiStudio/LiteyukiBot.git
				synced 2025-11-04 05:16:23 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			372 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			372 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import os
 | 
						||
import pickle
 | 
						||
import sqlite3
 | 
						||
from types import NoneType
 | 
						||
from typing import Any
 | 
						||
from packaging.version import parse
 | 
						||
 | 
						||
import nonebot
 | 
						||
import pydantic
 | 
						||
from pydantic import BaseModel
 | 
						||
 | 
						||
 | 
						||
class LiteModel(BaseModel):
 | 
						||
    TABLE_NAME: str = None
 | 
						||
    id: int = None
 | 
						||
 | 
						||
    def dump(self, *args, **kwargs):
 | 
						||
        if parse(pydantic.__version__) < parse("2.0.0"):
 | 
						||
            return self.dict(*args, **kwargs)
 | 
						||
        else:
 | 
						||
            return self.model_dump(*args, **kwargs)
 | 
						||
 | 
						||
 | 
						||
class Database:
 | 
						||
    def __init__(self, db_name: str):
 | 
						||
 | 
						||
        if os.path.dirname(db_name) != "" and not os.path.exists(os.path.dirname(db_name)):
 | 
						||
            os.makedirs(os.path.dirname(db_name))
 | 
						||
 | 
						||
        self.db_name = db_name
 | 
						||
        self.conn = sqlite3.connect(db_name)
 | 
						||
        self.cursor = self.conn.cursor()
 | 
						||
 | 
						||
    def where_one(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> LiteModel | Any | None:
 | 
						||
        """查询第一个
 | 
						||
        Args:
 | 
						||
            model: 数据模型实例
 | 
						||
            condition: 查询条件,不给定则查询所有
 | 
						||
            *args: 参数化查询参数
 | 
						||
            default: 默认值
 | 
						||
 | 
						||
        Returns:
 | 
						||
 | 
						||
        """
 | 
						||
        all_results = self.where_all(model, condition, *args)
 | 
						||
        return all_results[0] if all_results else default
 | 
						||
 | 
						||
    def where_all(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> list[LiteModel | Any] | None:
 | 
						||
        """查询所有
 | 
						||
        Args:
 | 
						||
            model: 数据模型实例
 | 
						||
            condition: 查询条件,不给定则查询所有
 | 
						||
            *args: 参数化查询参数
 | 
						||
            default: 默认值
 | 
						||
 | 
						||
        Returns:
 | 
						||
 | 
						||
        """
 | 
						||
        table_name = model.TABLE_NAME
 | 
						||
        model_type = type(model)
 | 
						||
        nonebot.logger.debug(f"Selecting {model.TABLE_NAME} WHERE {condition.replace('?', '%s') % args}")
 | 
						||
        if not table_name:
 | 
						||
            raise ValueError(f"数据模型{model_type.__name__}未提供表名")
 | 
						||
 | 
						||
        # condition = f"WHERE {condition}"
 | 
						||
        # print(f"SELECT * FROM {table_name} {condition}", args)
 | 
						||
        # if len(args) == 0:
 | 
						||
        #     results = self.cursor.execute(f"SELECT * FROM {table_name} {condition}").fetchall()
 | 
						||
        # else:
 | 
						||
        #     results = self.cursor.execute(f"SELECT * FROM {table_name} {condition}", args).fetchall()
 | 
						||
        if condition:
 | 
						||
            results = self.cursor.execute(f"SELECT * FROM {table_name} WHERE {condition}", args).fetchall()
 | 
						||
        else:
 | 
						||
            results = self.cursor.execute(f"SELECT * FROM {table_name}").fetchall()
 | 
						||
        fields = [description[0] for description in self.cursor.description]
 | 
						||
        if not results:
 | 
						||
            return default
 | 
						||
        else:
 | 
						||
            return [model_type(**self._load(dict(zip(fields, result)))) for result in results]
 | 
						||
 | 
						||
    def save(self, *args: LiteModel):
 | 
						||
        """增/改操作
 | 
						||
        Args:
 | 
						||
            *args:
 | 
						||
        Returns:
 | 
						||
        """
 | 
						||
        table_list = [item[0] for item in self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()]
 | 
						||
        for model in args:
 | 
						||
            nonebot.logger.debug(f"Upserting {model}")
 | 
						||
            if not model.TABLE_NAME:
 | 
						||
                raise ValueError(f"数据模型 {model.__class__.__name__} 未提供表名")
 | 
						||
            elif model.TABLE_NAME not in table_list:
 | 
						||
                raise ValueError(f"数据模型 {model.__class__.__name__} 表 {model.TABLE_NAME} 不存在,请先迁移")
 | 
						||
            else:
 | 
						||
                self._save(model.dump(by_alias=True))
 | 
						||
 | 
						||
    def _save(self, obj: Any) -> Any:
 | 
						||
        # obj = copy.deepcopy(obj)
 | 
						||
        if isinstance(obj, dict):
 | 
						||
            table_name = obj.get("TABLE_NAME")
 | 
						||
            row_id = obj.get("id")
 | 
						||
            new_obj = {}
 | 
						||
            for field, value in obj.items():
 | 
						||
                if isinstance(value, self.ITERABLE_TYPE):
 | 
						||
                    new_obj[self._get_stored_field_prefix(value) + field] = self._save(value)  # self._save(value)  # -> bytes
 | 
						||
                elif isinstance(value, self.BASIC_TYPE):
 | 
						||
                    new_obj[field] = value
 | 
						||
                else:
 | 
						||
                    raise ValueError(f"数据模型{table_name}包含不支持的数据类型,字段:{field} 值:{value} 值类型:{type(value)}")
 | 
						||
            if table_name:
 | 
						||
                fields, values = [], []
 | 
						||
                for n_field, n_value in new_obj.items():
 | 
						||
                    if n_field not in ["TABLE_NAME", "id"]:
 | 
						||
                        fields.append(n_field)
 | 
						||
                        values.append(n_value)
 | 
						||
                # 移除TABLE_NAME和id
 | 
						||
                fields = list(fields)
 | 
						||
                values = list(values)
 | 
						||
                if row_id is not None:
 | 
						||
                    # 如果 _id 不为空,将 'id' 插入到字段列表的开始
 | 
						||
                    fields.insert(0, 'id')
 | 
						||
                    # 将 _id 插入到值列表的开始
 | 
						||
                    values.insert(0, row_id)
 | 
						||
                fields = ', '.join([f'"{field}"' for field in fields])
 | 
						||
                placeholders = ', '.join('?' for _ in values)
 | 
						||
                self.cursor.execute(f"INSERT OR REPLACE INTO {table_name}({fields}) VALUES ({placeholders})", tuple(values))
 | 
						||
                self.conn.commit()
 | 
						||
                foreign_id = self.cursor.execute("SELECT last_insert_rowid()").fetchone()[0]
 | 
						||
                return f"{self.FOREIGN_KEY_PREFIX}{foreign_id}@{table_name}"  # -> FOREIGN_KEY_123456@{table_name} id@{table_name}
 | 
						||
            else:
 | 
						||
                return pickle.dumps(new_obj)  # -> bytes
 | 
						||
        elif isinstance(obj, (list, set, tuple)):
 | 
						||
            obj_type = type(obj)  # 到时候转回去
 | 
						||
            new_obj = []
 | 
						||
            for item in obj:
 | 
						||
                if isinstance(item, self.ITERABLE_TYPE):
 | 
						||
                    new_obj.append(self._save(item))
 | 
						||
                elif isinstance(item, self.BASIC_TYPE):
 | 
						||
                    new_obj.append(item)
 | 
						||
                else:
 | 
						||
                    raise ValueError(f"数据模型包含不支持的数据类型,值:{item} 值类型:{type(item)}")
 | 
						||
            return pickle.dumps(obj_type(new_obj))  # -> bytes
 | 
						||
        else:
 | 
						||
            raise ValueError(f"数据模型包含不支持的数据类型,值:{obj} 值类型:{type(obj)}")
 | 
						||
 | 
						||
    def _load(self, obj: Any) -> Any:
 | 
						||
 | 
						||
        if isinstance(obj, dict):
 | 
						||
 | 
						||
            new_obj = {}
 | 
						||
 | 
						||
            for field, value in obj.items():
 | 
						||
 | 
						||
                field: str
 | 
						||
 | 
						||
                if field.startswith(self.BYTES_PREFIX):
 | 
						||
 | 
						||
                    new_obj[field.replace(self.BYTES_PREFIX, "")] = self._load(pickle.loads(value))
 | 
						||
 | 
						||
                elif field.startswith(self.FOREIGN_KEY_PREFIX):
 | 
						||
 | 
						||
                    new_obj[field.replace(self.FOREIGN_KEY_PREFIX, "")] = self._load(self._get_foreign_data(value))
 | 
						||
 | 
						||
                else:
 | 
						||
                    new_obj[field] = value
 | 
						||
            return new_obj
 | 
						||
        elif isinstance(obj, (list, set, tuple)):
 | 
						||
 | 
						||
            new_obj = []
 | 
						||
            for item in obj:
 | 
						||
 | 
						||
                if isinstance(item, bytes):
 | 
						||
 | 
						||
                    # 对bytes进行尝试解析,解析失败则返回原始bytes
 | 
						||
                    try:
 | 
						||
                        new_obj.append(self._load(pickle.loads(item)))
 | 
						||
                    except Exception as e:
 | 
						||
                        new_obj.append(self._load(item))
 | 
						||
 | 
						||
                elif isinstance(item, str) and item.startswith(self.FOREIGN_KEY_PREFIX):
 | 
						||
                    new_obj.append(self._load(self._get_foreign_data(item)))
 | 
						||
                else:
 | 
						||
                    new_obj.append(self._load(item))
 | 
						||
            return new_obj
 | 
						||
        else:
 | 
						||
            return obj
 | 
						||
 | 
						||
    def delete(self, model: LiteModel, condition: str, *args: Any, allow_empty: bool = False):
 | 
						||
        """
 | 
						||
        删除满足条件的数据
 | 
						||
        Args:
 | 
						||
            allow_empty: 允许空条件删除整个表
 | 
						||
            model:
 | 
						||
            condition:
 | 
						||
            *args:
 | 
						||
 | 
						||
        Returns:
 | 
						||
 | 
						||
        """
 | 
						||
        table_name = model.TABLE_NAME
 | 
						||
        nonebot.logger.debug(f"Deleting {model} WHERE {condition} {args}")
 | 
						||
        if not table_name:
 | 
						||
            raise ValueError(f"数据模型{model.__class__.__name__}未提供表名")
 | 
						||
        if model.id is not None:
 | 
						||
            condition = f"id = {model.id}"
 | 
						||
        if not condition and not allow_empty:
 | 
						||
            raise ValueError("删除操作必须提供条件")
 | 
						||
        self.cursor.execute(f"DELETE FROM {table_name} WHERE {condition}", args)
 | 
						||
        self.conn.commit()
 | 
						||
 | 
						||
    def auto_migrate(self, *args: LiteModel):
 | 
						||
 | 
						||
        """
 | 
						||
        自动迁移模型
 | 
						||
        Args:
 | 
						||
            *args: 模型类实例化对象,支持空默认值,不支持嵌套迁移
 | 
						||
 | 
						||
        Returns:
 | 
						||
 | 
						||
        """
 | 
						||
        for model in args:
 | 
						||
            if not model.TABLE_NAME:
 | 
						||
                raise ValueError(f"数据模型{type(model).__name__}未提供表名")
 | 
						||
 | 
						||
            # 若无则创建表
 | 
						||
            self.cursor.execute(
 | 
						||
                f'CREATE TABLE IF NOT EXISTS "{model.TABLE_NAME}" (id INTEGER PRIMARY KEY AUTOINCREMENT)'
 | 
						||
            )
 | 
						||
 | 
						||
            # 获取表结构,field -> SqliteType
 | 
						||
            new_structure = {}
 | 
						||
            for n_field, n_value in model.dump(by_alias=True).items():
 | 
						||
                if n_field not in ["TABLE_NAME", "id"]:
 | 
						||
                    new_structure[self._get_stored_field_prefix(n_value) + n_field] = self._get_stored_type(n_value)
 | 
						||
 | 
						||
            # 原有的字段列表
 | 
						||
            existing_structure = dict([(column[1], column[2]) for column in self.cursor.execute(f'PRAGMA table_info({model.TABLE_NAME})').fetchall()])
 | 
						||
            # 检测缺失字段,由于SQLite是动态类型,所以不需要检测类型
 | 
						||
            for n_field, n_type in new_structure.items():
 | 
						||
                if n_field not in existing_structure.keys() and n_field.lower() not in ["id", "table_name"]:
 | 
						||
                    default_value = self.DEFAULT_MAPPING.get(n_type, 'NULL')
 | 
						||
                    self.cursor.execute(
 | 
						||
                        f"ALTER TABLE '{model.TABLE_NAME}' ADD COLUMN {n_field} {n_type} DEFAULT {self.DEFAULT_MAPPING.get(n_type, default_value)}"
 | 
						||
                    )
 | 
						||
 | 
						||
            # 检测多余字段进行删除
 | 
						||
            for e_field in existing_structure.keys():
 | 
						||
                if e_field not in new_structure.keys() and e_field.lower() not in ['id']:
 | 
						||
                    self.cursor.execute(
 | 
						||
                        f'ALTER TABLE "{model.TABLE_NAME}" DROP COLUMN "{e_field}"'
 | 
						||
                    )
 | 
						||
        self.conn.commit()
 | 
						||
        # 已完成
 | 
						||
 | 
						||
    def _get_stored_field_prefix(self, value) -> str:
 | 
						||
        """根据类型获取存储字段前缀,一定在后加上字段名
 | 
						||
        * -> ""
 | 
						||
        Args:
 | 
						||
            value: 储存的值
 | 
						||
 | 
						||
        Returns:
 | 
						||
            Sqlite3存储字段
 | 
						||
        """
 | 
						||
 | 
						||
        if isinstance(value, LiteModel) or isinstance(value, dict) and "TABLE_NAME" in value:
 | 
						||
            return self.FOREIGN_KEY_PREFIX
 | 
						||
        elif type(value) in self.ITERABLE_TYPE:
 | 
						||
            return self.BYTES_PREFIX
 | 
						||
        return ""
 | 
						||
 | 
						||
    def _get_stored_type(self, value) -> str:
 | 
						||
        """获取存储类型
 | 
						||
 | 
						||
        Args:
 | 
						||
            value: 储存的值
 | 
						||
 | 
						||
        Returns:
 | 
						||
            Sqlite3存储类型
 | 
						||
        """
 | 
						||
        if isinstance(value, dict) and "TABLE_NAME" in value:
 | 
						||
            # 是一个模型字典,储存外键
 | 
						||
            return "INTEGER"
 | 
						||
        return self.TYPE_MAPPING.get(type(value), "TEXT")
 | 
						||
 | 
						||
    def _get_foreign_data(self, foreign_value: str) -> dict:
 | 
						||
        """
 | 
						||
        获取外键数据
 | 
						||
        Args:
 | 
						||
            foreign_value:
 | 
						||
 | 
						||
        Returns:
 | 
						||
 | 
						||
        """
 | 
						||
        foreign_value = foreign_value.replace(self.FOREIGN_KEY_PREFIX, "")
 | 
						||
        table_name = foreign_value.split("@")[-1]
 | 
						||
        foreign_id = foreign_value.split("@")[0]
 | 
						||
        fields = [description[1] for description in self.cursor.execute(f"PRAGMA table_info({table_name})").fetchall()]
 | 
						||
        result = self.cursor.execute(f"SELECT * FROM {table_name} WHERE id = ?", (foreign_id,)).fetchone()
 | 
						||
        return dict(zip(fields, result))
 | 
						||
 | 
						||
    TYPE_MAPPING = {
 | 
						||
            int      : "INTEGER",
 | 
						||
            float    : "REAL",
 | 
						||
            str      : "TEXT",
 | 
						||
            bool     : "INTEGER",
 | 
						||
            bytes    : "BLOB",
 | 
						||
            NoneType : "NULL",
 | 
						||
            # dict     : "TEXT",
 | 
						||
            # list     : "TEXT",
 | 
						||
            # tuple    : "TEXT",
 | 
						||
            # set      : "TEXT",
 | 
						||
 | 
						||
            dict     : "BLOB",  # LITEYUKIDICT{key_name}
 | 
						||
            list     : "BLOB",  # LITEYUKILIST{key_name}
 | 
						||
            tuple    : "BLOB",  # LITEYUKITUPLE{key_name}
 | 
						||
            set      : "BLOB",  # LITEYUKISET{key_name}
 | 
						||
            LiteModel: "TEXT"  # FOREIGN_KEY_{table_name}
 | 
						||
    }
 | 
						||
    DEFAULT_MAPPING = {
 | 
						||
            "TEXT"   : "''",
 | 
						||
            "INTEGER": 0,
 | 
						||
            "REAL"   : 0.0,
 | 
						||
            "BLOB"   : None,
 | 
						||
            "NULL"   : None
 | 
						||
    }
 | 
						||
 | 
						||
    # 基础类型
 | 
						||
    BASIC_TYPE = (int, float, str, bool, bytes, NoneType)
 | 
						||
    # 可序列化类型
 | 
						||
    ITERABLE_TYPE = (dict, list, tuple, set, LiteModel)
 | 
						||
 | 
						||
    # 外键前缀
 | 
						||
    FOREIGN_KEY_PREFIX = "FOREIGN_KEY_"
 | 
						||
    # 转换为的字节前缀
 | 
						||
    BYTES_PREFIX = "PICKLE_BYTES_"
 | 
						||
 | 
						||
    # transaction tx 事务操作
 | 
						||
    def first(self, model: LiteModel) -> "Database":
 | 
						||
        pass
 | 
						||
 | 
						||
    def where(self, condition: str, *args) -> "Database":
 | 
						||
        pass
 | 
						||
 | 
						||
    def limit(self, limit: int) -> "Database":
 | 
						||
        pass
 | 
						||
 | 
						||
    def order(self, order: str) -> "Database":
 | 
						||
        pass
 | 
						||
 | 
						||
 | 
						||
def check_sqlite_keyword(name):
 | 
						||
    sqlite_keywords = [
 | 
						||
            "ABORT", "ACTION", "ADD", "AFTER", "ALL", "ALTER", "ANALYZE", "AND", "AS", "ASC",
 | 
						||
            "ATTACH", "AUTOINCREMENT", "BEFORE", "BEGIN", "BETWEEN", "BY", "CASCADE", "CASE",
 | 
						||
            "CAST", "CHECK", "COLLATE", "COLUMN", "COMMIT", "CONFLICT", "CONSTRAINT", "CREATE",
 | 
						||
            "CROSS", "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP", "DATABASE", "DEFAULT",
 | 
						||
            "DEFERRABLE", "DEFERRED", "DELETE", "DESC", "DETACH", "DISTINCT", "DROP", "EACH",
 | 
						||
            "ELSE", "END", "ESCAPE", "EXCEPT", "EXCLUSIVE", "EXISTS", "EXPLAIN", "FAIL", "FOR",
 | 
						||
            "FOREIGN", "FROM", "FULL", "GLOB", "GROUP", "HAVING", "IF", "IGNORE", "IMMEDIATE",
 | 
						||
            "IN", "INDEX", "INDEXED", "INITIALLY", "INNER", "INSERT", "INSTEAD", "INTERSECT",
 | 
						||
            "INTO", "IS", "ISNULL", "JOIN", "KEY", "LEFT", "LIKE", "LIMIT", "MATCH", "NATURAL",
 | 
						||
            "NO", "NOT", "NOTNULL", "NULL", "OF", "OFFSET", "ON", "OR", "ORDER", "OUTER", "PLAN",
 | 
						||
            "PRAGMA", "PRIMARY", "QUERY", "RAISE", "RECURSIVE", "REFERENCES", "REGEXP", "REINDEX",
 | 
						||
            "RELEASE", "RENAME", "REPLACE", "RESTRICT", "RIGHT", "ROLLBACK", "ROW", "SAVEPOINT",
 | 
						||
            "SELECT", "SET", "TABLE", "TEMP", "TEMPORARY", "THEN", "TO", "TRANSACTION", "TRIGGER",
 | 
						||
            "UNION", "UNIQUE", "UPDATE", "USING", "VACUUM", "VALUES", "VIEW", "VIRTUAL", "WHEN",
 | 
						||
            "WHERE", "WITH", "WITHOUT"
 | 
						||
    ]
 | 
						||
    return True
 | 
						||
    # if name.upper() in sqlite_keywords:
 | 
						||
    #     raise ValueError(f"'{name}' 是SQLite保留字,不建议使用,请更换名称")
 |