mirror of
https://github.com/LiteyukiStudio/LiteyukiBot.git
synced 2025-09-05 19:26:24 +00:00
✨ 分离magicoca
和croterline
This commit is contained in:
@ -1,53 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
|
||||
@Time : 2024/8/11 下午5:24
|
||||
@Author : snowykami
|
||||
@Email : snowykami@outlook.com
|
||||
@File : __init__.py.py
|
||||
@Software: PyCharm
|
||||
"""
|
||||
|
||||
import nonebot
|
||||
from liteyuki.utils import IS_MAIN_PROCESS
|
||||
from liteyuki.plugin import PluginMetadata, PluginType
|
||||
from .nb_utils import adapter_manager, driver_manager # type: ignore
|
||||
from liteyuki.log import logger
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="NoneBot2启动器",
|
||||
type=PluginType.APPLICATION,
|
||||
)
|
||||
|
||||
|
||||
def nb_run(*args, **kwargs):
|
||||
"""
|
||||
初始化NoneBot并运行在子进程
|
||||
Args:
|
||||
**kwargs:
|
||||
|
||||
Returns:
|
||||
"""
|
||||
# 给子进程传递通道对象
|
||||
kwargs.update(kwargs.get("nonebot", {})) # nonebot配置优先
|
||||
nonebot.init(**kwargs)
|
||||
|
||||
driver_manager.init(config=kwargs)
|
||||
adapter_manager.init(kwargs)
|
||||
adapter_manager.register()
|
||||
|
||||
try:
|
||||
# nonebot.load_plugin("nonebot-plugin-lnpm") # 尝试加载轻雪NoneBot插件加载器(Nonebot插件)
|
||||
nonebot.load_plugin("src.liteyuki_main") # 尝试加载轻雪主插件(Nonebot插件)
|
||||
except Exception as e:
|
||||
pass
|
||||
nonebot.run()
|
||||
|
||||
|
||||
if IS_MAIN_PROCESS:
|
||||
from liteyuki import get_bot
|
||||
from .dev_reloader import *
|
||||
|
||||
liteyuki = get_bot()
|
||||
liteyuki.process_manager.add_target(name="nonebot", target=nb_run, args=(), kwargs=liteyuki.config)
|
@ -1,24 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
NoneBot 开发环境重载监视器
|
||||
"""
|
||||
import os.path
|
||||
|
||||
from liteyuki.dev import observer
|
||||
from liteyuki import get_bot, logger
|
||||
from liteyuki.utils import IS_MAIN_PROCESS
|
||||
from watchdog.events import FileSystemEvent
|
||||
|
||||
|
||||
liteyuki = get_bot()
|
||||
|
||||
exclude_extensions = (".pyc", ".pyo")
|
||||
|
||||
|
||||
@observer.on_file_system_event(
|
||||
directories=("src/nonebot_plugins",),
|
||||
event_filter=lambda event: not event.src_path.endswith(exclude_extensions) and ("__pycache__" not in event.src_path ) and os.path.isfile(event.src_path)
|
||||
)
|
||||
def restart_nonebot_process(event: FileSystemEvent):
|
||||
logger.debug(f"File {event.src_path} changed, reloading nonebot...")
|
||||
liteyuki.restart_process("nonebot")
|
33
src/liteyuki_plugins/nonebot/__init__.py
Normal file
33
src/liteyuki_plugins/nonebot/__init__.py
Normal file
@ -0,0 +1,33 @@
|
||||
import os.path
|
||||
from pathlib import Path
|
||||
|
||||
import nonebot
|
||||
from croterline.utils import IsMainProcess
|
||||
|
||||
from liteyuki import get_bot
|
||||
from liteyuki.core import sub_process_manager
|
||||
from liteyuki.plugin import PluginMetadata, PluginType
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="NoneBot2启动器",
|
||||
type=PluginType.APPLICATION,
|
||||
)
|
||||
|
||||
|
||||
def nb_run(*args, **kwargs):
|
||||
nonebot.init(**kwargs)
|
||||
|
||||
from .nb_utils import driver_manager, adapter_manager
|
||||
driver_manager.init(config=kwargs)
|
||||
adapter_manager.init(kwargs)
|
||||
adapter_manager.register()
|
||||
nonebot.load_plugin(Path(os.path.dirname(__file__)) / "np_main")
|
||||
nonebot.run()
|
||||
|
||||
|
||||
if IsMainProcess:
|
||||
from .dev_reloader import *
|
||||
bot = get_bot()
|
||||
sub_process_manager.add(
|
||||
name="nonebot", func=nb_run, **bot.config.get("nonebot", {})
|
||||
)
|
@ -10,7 +10,7 @@ from .common import MessageEventModel, msg_db
|
||||
from src.utils.base.language import Language
|
||||
from src.utils.base.resource import get_path
|
||||
from src.utils.message.string_tool import convert_seconds_to_time
|
||||
from ...utils.external.logo import get_group_icon, get_user_icon
|
||||
from src.utils.external.logo import get_group_icon, get_user_icon
|
||||
|
||||
|
||||
async def count_msg_by_bot_id(bot_id: str) -> int:
|
@ -15,7 +15,7 @@ __plugin_meta__ = PluginMetadata(
|
||||
}
|
||||
)
|
||||
|
||||
from ...utils.base.data_manager import set_memory_data
|
||||
from src.utils.base.data_manager import set_memory_data
|
||||
|
||||
driver = get_driver()
|
||||
|
@ -3,8 +3,8 @@ import aiohttp
|
||||
from .qw_models import *
|
||||
import httpx
|
||||
|
||||
from ...utils.base.data_manager import get_memory_data
|
||||
from ...utils.base.language import Language
|
||||
from src.utils.base.data_manager import get_memory_data
|
||||
from src.utils.base.language import Language
|
||||
|
||||
dev_url = "https://devapi.qweather.com/" # 开发HBa
|
||||
com_url = "https://api.qweather.com/" # 正式环境
|
@ -1,20 +1,20 @@
|
||||
from nonebot.plugin import PluginMetadata
|
||||
|
||||
from .core import *
|
||||
from .loader import *
|
||||
__author__ = "snowykami"
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="轻雪核心插件",
|
||||
description="轻雪主程序插件,包含了许多初始化的功能",
|
||||
usage="",
|
||||
homepage="https://github.com/snowykami/LiteyukiBot",
|
||||
extra={
|
||||
"liteyuki" : True,
|
||||
"toggleable": False,
|
||||
}
|
||||
)
|
||||
|
||||
from ..utils.base.language import Language, get_default_lang_code
|
||||
|
||||
sys_lang = Language(get_default_lang_code())
|
||||
from nonebot.plugin import PluginMetadata
|
||||
|
||||
from .core import *
|
||||
from .loader import *
|
||||
__author__ = "snowykami"
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="轻雪核心插件",
|
||||
description="轻雪主程序插件,包含了许多初始化的功能",
|
||||
usage="",
|
||||
homepage="https://github.com/snowykami/LiteyukiBot",
|
||||
extra={
|
||||
"liteyuki" : True,
|
||||
"toggleable": False,
|
||||
}
|
||||
)
|
||||
|
||||
from src.utils.base.language import Language, get_default_lang_code
|
||||
|
||||
sys_lang = Language(get_default_lang_code())
|
||||
nonebot.logger.info(sys_lang.get("main.current_language", LANG=sys_lang.get("language.name")))
|
@ -1,47 +1,47 @@
|
||||
import nonebot
|
||||
from git import Repo
|
||||
|
||||
from src.utils.base.config import get_config
|
||||
|
||||
remote_urls = [
|
||||
"https://github.com/LiteyukiStudio/LiteyukiBot.git",
|
||||
"https://gitee.com/snowykami/LiteyukiBot.git"
|
||||
]
|
||||
|
||||
|
||||
def detect_update() -> bool:
|
||||
# 对每个远程仓库进行检查,只要有一个仓库有更新,就返回True
|
||||
for remote_url in remote_urls:
|
||||
repo = Repo(".")
|
||||
repo.remotes.origin.set_url(remote_url)
|
||||
repo.remotes.origin.fetch()
|
||||
if repo.head.commit != repo.commit('origin/main'):
|
||||
return True
|
||||
|
||||
|
||||
def update_liteyuki() -> tuple[bool, str]:
|
||||
"""更新轻雪
|
||||
:return: 是否更新成功,更新变动"""
|
||||
|
||||
if get_config("allow_update", True):
|
||||
new_commit_detected = detect_update()
|
||||
if new_commit_detected:
|
||||
repo = Repo(".")
|
||||
logs = ""
|
||||
# 对每个远程仓库进行更新
|
||||
for remote_url in remote_urls:
|
||||
try:
|
||||
logs += f"\nremote: {remote_url}"
|
||||
repo.remotes.origin.set_url(remote_url)
|
||||
repo.remotes.origin.pull()
|
||||
diffs = repo.head.commit.diff("origin/main")
|
||||
for diff in diffs.iter_change_type('M'):
|
||||
logs += f"\n{diff.a_path}"
|
||||
return True, logs
|
||||
except:
|
||||
continue
|
||||
else:
|
||||
return False, "Nothing Changed"
|
||||
|
||||
else:
|
||||
raise PermissionError("Update is not allowed.")
|
||||
import nonebot
|
||||
from git import Repo
|
||||
|
||||
from src.utils.base.config import get_config
|
||||
|
||||
remote_urls = [
|
||||
"https://github.com/LiteyukiStudio/LiteyukiBot.git",
|
||||
"https://gitee.com/snowykami/LiteyukiBot.git"
|
||||
]
|
||||
|
||||
|
||||
def detect_update() -> bool:
|
||||
# 对每个远程仓库进行检查,只要有一个仓库有更新,就返回True
|
||||
for remote_url in remote_urls:
|
||||
repo = Repo(".")
|
||||
repo.remotes.origin.set_url(remote_url)
|
||||
repo.remotes.origin.fetch()
|
||||
if repo.head.commit != repo.commit('origin/main'):
|
||||
return True
|
||||
|
||||
|
||||
def update_liteyuki() -> tuple[bool, str]:
|
||||
"""更新轻雪
|
||||
:return: 是否更新成功,更新变动"""
|
||||
|
||||
if get_config("allow_update", True):
|
||||
new_commit_detected = detect_update()
|
||||
if new_commit_detected:
|
||||
repo = Repo(".")
|
||||
logs = ""
|
||||
# 对每个远程仓库进行更新
|
||||
for remote_url in remote_urls:
|
||||
try:
|
||||
logs += f"\nremote: {remote_url}"
|
||||
repo.remotes.origin.set_url(remote_url)
|
||||
repo.remotes.origin.pull()
|
||||
diffs = repo.head.commit.diff("origin/main")
|
||||
for diff in diffs.iter_change_type('M'):
|
||||
logs += f"\n{diff.a_path}"
|
||||
return True, logs
|
||||
except:
|
||||
continue
|
||||
else:
|
||||
return False, "Nothing Changed"
|
||||
|
||||
else:
|
||||
raise PermissionError("Update is not allowed.")
|
@ -1,301 +1,301 @@
|
||||
import time
|
||||
from typing import AnyStr
|
||||
|
||||
import time
|
||||
from typing import AnyStr
|
||||
|
||||
import nonebot
|
||||
import pip
|
||||
from nonebot import get_driver, require
|
||||
from nonebot.adapters import onebot, satori
|
||||
from nonebot.adapters.onebot.v11 import Message, unescape
|
||||
from nonebot.internal.matcher import Matcher
|
||||
from nonebot.permission import SUPERUSER
|
||||
|
||||
# from src.liteyuki.core import Reloader
|
||||
from src.utils import event as event_utils, satori_utils
|
||||
from src.utils.base.config import get_config
|
||||
from src.utils.base.data_manager import TempConfig, common_db
|
||||
from src.utils.base.language import get_user_lang
|
||||
from src.utils.base.ly_typing import T_Bot, T_MessageEvent
|
||||
from src.utils.message.message import MarkdownMessage as md, broadcast_to_superusers
|
||||
from .api import update_liteyuki # type: ignore
|
||||
from ..utils.base import reload # type: ignore
|
||||
from ..utils.base.ly_function import get_function # type: ignore
|
||||
from ..utils.message.html_tool import md_to_pic
|
||||
|
||||
require("nonebot_plugin_alconna")
|
||||
require("nonebot_plugin_apscheduler")
|
||||
from nonebot_plugin_alconna import UniMessage, on_alconna, Alconna, Args, Arparma, MultiVar
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
|
||||
|
||||
driver = get_driver()
|
||||
|
||||
|
||||
@on_alconna(
|
||||
command=Alconna(
|
||||
"liteecho",
|
||||
Args["text", str, ""],
|
||||
),
|
||||
permission=SUPERUSER
|
||||
).handle()
|
||||
# Satori OK
|
||||
async def _(bot: T_Bot, matcher: Matcher, result: Arparma):
|
||||
if text := result.main_args.get("text"):
|
||||
await matcher.finish(Message(unescape(text)))
|
||||
else:
|
||||
await matcher.finish(f"Hello, Liteyuki!\nBot {bot.self_id}")
|
||||
|
||||
|
||||
@on_alconna(
|
||||
aliases={"更新轻雪"},
|
||||
command=Alconna(
|
||||
"update-liteyuki"
|
||||
),
|
||||
permission=SUPERUSER
|
||||
).handle()
|
||||
# Satori OK
|
||||
async def _(bot: T_Bot, event: T_MessageEvent, matcher: Matcher):
|
||||
# 使用git pull更新
|
||||
|
||||
ulang = get_user_lang(str(event.user.id if isinstance(event, satori.event.Event) else event.user_id))
|
||||
success, logs = update_liteyuki()
|
||||
reply = "Liteyuki updated!\n"
|
||||
reply += f"```\n{logs}\n```\n"
|
||||
btn_restart = md.btn_cmd(ulang.get("liteyuki.restart_now"), "reload-liteyuki")
|
||||
pip.main(["install", "-r", "requirements.txt"])
|
||||
reply += f"{ulang.get('liteyuki.update_restart', RESTART=btn_restart)}"
|
||||
# await md.send_md(reply, bot)
|
||||
img_bytes = await md_to_pic(reply)
|
||||
await UniMessage.send(UniMessage.image(raw=img_bytes))
|
||||
|
||||
|
||||
@on_alconna(
|
||||
aliases={"重启轻雪"},
|
||||
command=Alconna(
|
||||
"reload-liteyuki"
|
||||
),
|
||||
permission=SUPERUSER
|
||||
).handle()
|
||||
# Satori OK
|
||||
async def _(matcher: Matcher, bot: T_Bot, event: T_MessageEvent):
|
||||
await matcher.send("Liteyuki reloading")
|
||||
temp_data = common_db.where_one(TempConfig(), default=TempConfig())
|
||||
|
||||
temp_data.data.update(
|
||||
{
|
||||
"reload" : True,
|
||||
"reload_time" : time.time(),
|
||||
"reload_bot_id" : bot.self_id,
|
||||
"reload_session_type": event_utils.get_message_type(event),
|
||||
"reload_session_id" : (event.group_id if event.message_type == "group" else event.user_id)
|
||||
if not isinstance(event, satori.event.Event) else event.chan_active.id,
|
||||
"delta_time" : 0
|
||||
}
|
||||
)
|
||||
|
||||
common_db.save(temp_data)
|
||||
reload()
|
||||
|
||||
|
||||
@on_alconna(
|
||||
command=Alconna(
|
||||
"liteyuki-docs",
|
||||
),
|
||||
aliases={"轻雪文档"},
|
||||
).handle()
|
||||
# Satori OK
|
||||
async def _(matcher: Matcher):
|
||||
await matcher.finish("https://bot.liteyuki.icu/")
|
||||
|
||||
|
||||
@on_alconna(
|
||||
command=Alconna(
|
||||
"/function",
|
||||
Args["function", str]["args", MultiVar(str), ()],
|
||||
),
|
||||
permission=SUPERUSER
|
||||
).handle()
|
||||
async def _(result: Arparma, bot: T_Bot, event: T_MessageEvent, matcher: Matcher):
|
||||
"""
|
||||
调用轻雪函数
|
||||
Args:
|
||||
result:
|
||||
bot:
|
||||
event:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
function_name = result.main_args.get("function")
|
||||
args: tuple[str] = result.main_args.get("args", ())
|
||||
_args = []
|
||||
_kwargs = {
|
||||
"USER_ID" : str(event.user_id),
|
||||
"GROUP_ID": str(event.group_id) if event.message_type == "group" else "0",
|
||||
"BOT_ID" : str(bot.self_id)
|
||||
}
|
||||
|
||||
for arg in args:
|
||||
arg = arg.replace("\\=", "EQUAL_SIGN")
|
||||
if "=" in arg:
|
||||
key, value = arg.split("=", 1)
|
||||
value = unescape(value.replace("EQUAL_SIGN", "="))
|
||||
try:
|
||||
value = eval(value)
|
||||
except:
|
||||
value = value
|
||||
_kwargs[key] = value
|
||||
else:
|
||||
_args.append(arg.replace("EQUAL_SIGN", "="))
|
||||
|
||||
ly_func = get_function(function_name)
|
||||
ly_func.bot = bot if "BOT_ID" not in _kwargs else nonebot.get_bot(_kwargs["BOT_ID"])
|
||||
ly_func.matcher = matcher
|
||||
|
||||
await ly_func(*tuple(_args), **_kwargs)
|
||||
|
||||
|
||||
@on_alconna(
|
||||
command=Alconna(
|
||||
"/api",
|
||||
Args["api", str]["args", MultiVar(AnyStr), ()],
|
||||
),
|
||||
permission=SUPERUSER
|
||||
).handle()
|
||||
async def _(result: Arparma, bot: T_Bot, event: T_MessageEvent, matcher: Matcher):
|
||||
"""
|
||||
调用API
|
||||
Args:
|
||||
result:
|
||||
bot:
|
||||
event:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
api_name = result.main_args.get("api")
|
||||
args: tuple[str] = result.main_args.get("args", ()) # 类似于url参数,但每个参数间用空格分隔,空格是%20
|
||||
args_dict = {}
|
||||
|
||||
for arg in args:
|
||||
key, value = arg.split("=", 1)
|
||||
|
||||
args_dict[key] = unescape(value.replace("%20", " "))
|
||||
|
||||
if api_name in need_user_id and "user_id" not in args_dict:
|
||||
args_dict["user_id"] = str(event.user_id)
|
||||
if api_name in need_group_id and "group_id" not in args_dict and event.message_type == "group":
|
||||
args_dict["group_id"] = str(event.group_id)
|
||||
|
||||
if "message" in args_dict:
|
||||
args_dict["message"] = Message(eval(args_dict["message"]))
|
||||
|
||||
if "messages" in args_dict:
|
||||
args_dict["messages"] = Message(eval(args_dict["messages"]))
|
||||
|
||||
try:
|
||||
result = await bot.call_api(api_name, **args_dict)
|
||||
except Exception as e:
|
||||
result = str(e)
|
||||
|
||||
args_show = "\n".join("- %s: %s" % (k, v) for k, v in args_dict.items())
|
||||
await matcher.finish(f"API: {api_name}\n\nArgs: \n{args_show}\n\nResult: {result}")
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def on_startup():
|
||||
temp_data = common_db.where_one(TempConfig(), default=TempConfig())
|
||||
# 储存重启信息
|
||||
if temp_data.data.get("reload", False):
|
||||
delta_time = time.time() - temp_data.data.get("reload_time", 0)
|
||||
temp_data.data["delta_time"] = delta_time
|
||||
common_db.save(temp_data) # 更新数据
|
||||
"""
|
||||
该部分将迁移至轻雪生命周期
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@driver.on_shutdown
|
||||
async def on_shutdown():
|
||||
pass
|
||||
|
||||
|
||||
@driver.on_bot_connect
|
||||
async def _(bot: T_Bot):
|
||||
temp_data = common_db.where_one(TempConfig(), default=TempConfig())
|
||||
if isinstance(bot, satori.Bot):
|
||||
await satori_utils.user_infos.load_friends(bot)
|
||||
# 用于重启计时
|
||||
if temp_data.data.get("reload", False):
|
||||
temp_data.data["reload"] = False
|
||||
reload_bot_id = temp_data.data.get("reload_bot_id", 0)
|
||||
if reload_bot_id != bot.self_id:
|
||||
return
|
||||
reload_session_type = temp_data.data.get("reload_session_type", "private")
|
||||
reload_session_id = temp_data.data.get("reload_session_id", 0)
|
||||
delta_time = temp_data.data.get("delta_time", 0)
|
||||
common_db.save(temp_data) # 更新数据
|
||||
|
||||
if delta_time <= 20.0: # 启动时间太长就别发了,丢人
|
||||
if isinstance(bot, satori.Bot):
|
||||
await bot.send_message(
|
||||
channel_id=reload_session_id,
|
||||
message="Liteyuki reloaded in %.2f s" % delta_time
|
||||
)
|
||||
elif isinstance(bot, onebot.v11.Bot):
|
||||
await bot.send_msg(
|
||||
message_type=reload_session_type,
|
||||
user_id=reload_session_id,
|
||||
group_id=reload_session_id,
|
||||
message="Liteyuki reloaded in %.2f s" % delta_time
|
||||
)
|
||||
|
||||
elif isinstance(bot, onebot.v12.Bot):
|
||||
await bot.send_message(
|
||||
message_type=reload_session_type,
|
||||
user_id=reload_session_id,
|
||||
group_id=reload_session_id,
|
||||
message="Liteyuki reloaded in %.2f s" % delta_time,
|
||||
detail_type="group"
|
||||
)
|
||||
|
||||
|
||||
# 每天4点更新
|
||||
@scheduler.scheduled_job("cron", hour=4)
|
||||
async def every_day_update():
|
||||
if get_config("auto_update", default=True):
|
||||
result, logs = update_liteyuki()
|
||||
pip.main(["install", "-r", "requirements.txt"])
|
||||
if result:
|
||||
await broadcast_to_superusers(f"Liteyuki updated: ```\n{logs}\n```")
|
||||
nonebot.logger.info(f"Liteyuki updated: {logs}")
|
||||
reload()
|
||||
else:
|
||||
nonebot.logger.info(logs)
|
||||
|
||||
|
||||
# 需要用户id的api
|
||||
need_user_id = (
|
||||
"send_private_msg",
|
||||
"send_msg",
|
||||
"set_group_card",
|
||||
"set_group_special_title",
|
||||
"get_stranger_info",
|
||||
"get_group_member_info"
|
||||
)
|
||||
|
||||
need_group_id = (
|
||||
"send_group_msg",
|
||||
"send_msg",
|
||||
"set_group_card",
|
||||
"set_group_name",
|
||||
|
||||
"set_group_special_title",
|
||||
"get_group_member_info",
|
||||
"get_group_member_list",
|
||||
"get_group_honor_info"
|
||||
)
|
||||
import time
|
||||
from typing import AnyStr
|
||||
|
||||
import time
|
||||
from typing import AnyStr
|
||||
|
||||
import nonebot
|
||||
import pip
|
||||
from nonebot import get_driver, require
|
||||
from nonebot.adapters import onebot, satori
|
||||
from nonebot.adapters.onebot.v11 import Message, unescape
|
||||
from nonebot.internal.matcher import Matcher
|
||||
from nonebot.permission import SUPERUSER
|
||||
|
||||
# from src.liteyuki.core import Reloader
|
||||
from src.utils import event as event_utils, satori_utils
|
||||
from src.utils.base.config import get_config
|
||||
from src.utils.base.data_manager import TempConfig, common_db
|
||||
from src.utils.base.language import get_user_lang
|
||||
from src.utils.base.ly_typing import T_Bot, T_MessageEvent
|
||||
from src.utils.message.message import MarkdownMessage as md, broadcast_to_superusers
|
||||
from .api import update_liteyuki # type: ignore
|
||||
from src.utils.base import reload # type: ignore
|
||||
from src.utils.base.ly_function import get_function # type: ignore
|
||||
from src.utils.message.html_tool import md_to_pic
|
||||
|
||||
require("nonebot_plugin_alconna")
|
||||
require("nonebot_plugin_apscheduler")
|
||||
from nonebot_plugin_alconna import UniMessage, on_alconna, Alconna, Args, Arparma, MultiVar
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
|
||||
|
||||
driver = get_driver()
|
||||
|
||||
|
||||
@on_alconna(
|
||||
command=Alconna(
|
||||
"liteecho",
|
||||
Args["text", str, ""],
|
||||
),
|
||||
permission=SUPERUSER
|
||||
).handle()
|
||||
# Satori OK
|
||||
async def _(bot: T_Bot, matcher: Matcher, result: Arparma):
|
||||
if text := result.main_args.get("text"):
|
||||
await matcher.finish(Message(unescape(text)))
|
||||
else:
|
||||
await matcher.finish(f"Hello, Liteyuki!\nBot {bot.self_id}")
|
||||
|
||||
|
||||
@on_alconna(
|
||||
aliases={"更新轻雪"},
|
||||
command=Alconna(
|
||||
"update-liteyuki"
|
||||
),
|
||||
permission=SUPERUSER
|
||||
).handle()
|
||||
# Satori OK
|
||||
async def _(bot: T_Bot, event: T_MessageEvent, matcher: Matcher):
|
||||
# 使用git pull更新
|
||||
|
||||
ulang = get_user_lang(str(event.user.id if isinstance(event, satori.event.Event) else event.user_id))
|
||||
success, logs = update_liteyuki()
|
||||
reply = "Liteyuki updated!\n"
|
||||
reply += f"```\n{logs}\n```\n"
|
||||
btn_restart = md.btn_cmd(ulang.get("liteyuki.restart_now"), "reload-liteyuki")
|
||||
pip.main(["install", "-r", "requirements.txt"])
|
||||
reply += f"{ulang.get('liteyuki.update_restart', RESTART=btn_restart)}"
|
||||
# await md.send_md(reply, bot)
|
||||
img_bytes = await md_to_pic(reply)
|
||||
await UniMessage.send(UniMessage.image(raw=img_bytes))
|
||||
|
||||
|
||||
@on_alconna(
|
||||
aliases={"重启轻雪"},
|
||||
command=Alconna(
|
||||
"reload-liteyuki"
|
||||
),
|
||||
permission=SUPERUSER
|
||||
).handle()
|
||||
# Satori OK
|
||||
async def _(matcher: Matcher, bot: T_Bot, event: T_MessageEvent):
|
||||
await matcher.send("Liteyuki reloading")
|
||||
temp_data = common_db.where_one(TempConfig(), default=TempConfig())
|
||||
|
||||
temp_data.data.update(
|
||||
{
|
||||
"reload" : True,
|
||||
"reload_time" : time.time(),
|
||||
"reload_bot_id" : bot.self_id,
|
||||
"reload_session_type": event_utils.get_message_type(event),
|
||||
"reload_session_id" : (event.group_id if event.message_type == "group" else event.user_id)
|
||||
if not isinstance(event, satori.event.Event) else event.chan_active.id,
|
||||
"delta_time" : 0
|
||||
}
|
||||
)
|
||||
|
||||
common_db.save(temp_data)
|
||||
reload()
|
||||
|
||||
|
||||
@on_alconna(
|
||||
command=Alconna(
|
||||
"liteyuki-docs",
|
||||
),
|
||||
aliases={"轻雪文档"},
|
||||
).handle()
|
||||
# Satori OK
|
||||
async def _(matcher: Matcher):
|
||||
await matcher.finish("https://bot.liteyuki.icu/")
|
||||
|
||||
|
||||
@on_alconna(
|
||||
command=Alconna(
|
||||
"/function",
|
||||
Args["function", str]["args", MultiVar(str), ()],
|
||||
),
|
||||
permission=SUPERUSER
|
||||
).handle()
|
||||
async def _(result: Arparma, bot: T_Bot, event: T_MessageEvent, matcher: Matcher):
|
||||
"""
|
||||
调用轻雪函数
|
||||
Args:
|
||||
result:
|
||||
bot:
|
||||
event:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
function_name = result.main_args.get("function")
|
||||
args: tuple[str] = result.main_args.get("args", ())
|
||||
_args = []
|
||||
_kwargs = {
|
||||
"USER_ID" : str(event.user_id),
|
||||
"GROUP_ID": str(event.group_id) if event.message_type == "group" else "0",
|
||||
"BOT_ID" : str(bot.self_id)
|
||||
}
|
||||
|
||||
for arg in args:
|
||||
arg = arg.replace("\\=", "EQUAL_SIGN")
|
||||
if "=" in arg:
|
||||
key, value = arg.split("=", 1)
|
||||
value = unescape(value.replace("EQUAL_SIGN", "="))
|
||||
try:
|
||||
value = eval(value)
|
||||
except:
|
||||
value = value
|
||||
_kwargs[key] = value
|
||||
else:
|
||||
_args.append(arg.replace("EQUAL_SIGN", "="))
|
||||
|
||||
ly_func = get_function(function_name)
|
||||
ly_func.bot = bot if "BOT_ID" not in _kwargs else nonebot.get_bot(_kwargs["BOT_ID"])
|
||||
ly_func.matcher = matcher
|
||||
|
||||
await ly_func(*tuple(_args), **_kwargs)
|
||||
|
||||
|
||||
@on_alconna(
|
||||
command=Alconna(
|
||||
"/api",
|
||||
Args["api", str]["args", MultiVar(AnyStr), ()],
|
||||
),
|
||||
permission=SUPERUSER
|
||||
).handle()
|
||||
async def _(result: Arparma, bot: T_Bot, event: T_MessageEvent, matcher: Matcher):
|
||||
"""
|
||||
调用API
|
||||
Args:
|
||||
result:
|
||||
bot:
|
||||
event:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
api_name = result.main_args.get("api")
|
||||
args: tuple[str] = result.main_args.get("args", ()) # 类似于url参数,但每个参数间用空格分隔,空格是%20
|
||||
args_dict = {}
|
||||
|
||||
for arg in args:
|
||||
key, value = arg.split("=", 1)
|
||||
|
||||
args_dict[key] = unescape(value.replace("%20", " "))
|
||||
|
||||
if api_name in need_user_id and "user_id" not in args_dict:
|
||||
args_dict["user_id"] = str(event.user_id)
|
||||
if api_name in need_group_id and "group_id" not in args_dict and event.message_type == "group":
|
||||
args_dict["group_id"] = str(event.group_id)
|
||||
|
||||
if "message" in args_dict:
|
||||
args_dict["message"] = Message(eval(args_dict["message"]))
|
||||
|
||||
if "messages" in args_dict:
|
||||
args_dict["messages"] = Message(eval(args_dict["messages"]))
|
||||
|
||||
try:
|
||||
result = await bot.call_api(api_name, **args_dict)
|
||||
except Exception as e:
|
||||
result = str(e)
|
||||
|
||||
args_show = "\n".join("- %s: %s" % (k, v) for k, v in args_dict.items())
|
||||
await matcher.finish(f"API: {api_name}\n\nArgs: \n{args_show}\n\nResult: {result}")
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def on_startup():
|
||||
temp_data = common_db.where_one(TempConfig(), default=TempConfig())
|
||||
# 储存重启信息
|
||||
if temp_data.data.get("reload", False):
|
||||
delta_time = time.time() - temp_data.data.get("reload_time", 0)
|
||||
temp_data.data["delta_time"] = delta_time
|
||||
common_db.save(temp_data) # 更新数据
|
||||
"""
|
||||
该部分将迁移至轻雪生命周期
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@driver.on_shutdown
|
||||
async def on_shutdown():
|
||||
pass
|
||||
|
||||
|
||||
@driver.on_bot_connect
|
||||
async def _(bot: T_Bot):
|
||||
temp_data = common_db.where_one(TempConfig(), default=TempConfig())
|
||||
if isinstance(bot, satori.Bot):
|
||||
await satori_utils.user_infos.load_friends(bot)
|
||||
# 用于重启计时
|
||||
if temp_data.data.get("reload", False):
|
||||
temp_data.data["reload"] = False
|
||||
reload_bot_id = temp_data.data.get("reload_bot_id", 0)
|
||||
if reload_bot_id != bot.self_id:
|
||||
return
|
||||
reload_session_type = temp_data.data.get("reload_session_type", "private")
|
||||
reload_session_id = temp_data.data.get("reload_session_id", 0)
|
||||
delta_time = temp_data.data.get("delta_time", 0)
|
||||
common_db.save(temp_data) # 更新数据
|
||||
|
||||
if delta_time <= 20.0: # 启动时间太长就别发了,丢人
|
||||
if isinstance(bot, satori.Bot):
|
||||
await bot.send_message(
|
||||
channel_id=reload_session_id,
|
||||
message="Liteyuki reloaded in %.2f s" % delta_time
|
||||
)
|
||||
elif isinstance(bot, onebot.v11.Bot):
|
||||
await bot.send_msg(
|
||||
message_type=reload_session_type,
|
||||
user_id=reload_session_id,
|
||||
group_id=reload_session_id,
|
||||
message="Liteyuki reloaded in %.2f s" % delta_time
|
||||
)
|
||||
|
||||
elif isinstance(bot, onebot.v12.Bot):
|
||||
await bot.send_message(
|
||||
message_type=reload_session_type,
|
||||
user_id=reload_session_id,
|
||||
group_id=reload_session_id,
|
||||
message="Liteyuki reloaded in %.2f s" % delta_time,
|
||||
detail_type="group"
|
||||
)
|
||||
|
||||
|
||||
# 每天4点更新
|
||||
@scheduler.scheduled_job("cron", hour=4)
|
||||
async def every_day_update():
|
||||
if get_config("auto_update", default=True):
|
||||
result, logs = update_liteyuki()
|
||||
pip.main(["install", "-r", "requirements.txt"])
|
||||
if result:
|
||||
await broadcast_to_superusers(f"Liteyuki updated: ```\n{logs}\n```")
|
||||
nonebot.logger.info(f"Liteyuki updated: {logs}")
|
||||
reload()
|
||||
else:
|
||||
nonebot.logger.info(logs)
|
||||
|
||||
|
||||
# 需要用户id的api
|
||||
need_user_id = (
|
||||
"send_private_msg",
|
||||
"send_msg",
|
||||
"set_group_card",
|
||||
"set_group_special_title",
|
||||
"get_stranger_info",
|
||||
"get_group_member_info"
|
||||
)
|
||||
|
||||
need_group_id = (
|
||||
"send_group_msg",
|
||||
"send_msg",
|
||||
"set_group_card",
|
||||
"set_group_name",
|
||||
|
||||
"set_group_special_title",
|
||||
"get_group_member_info",
|
||||
"get_group_member_list",
|
||||
"get_group_honor_info"
|
||||
)
|
@ -1,33 +1,39 @@
|
||||
import asyncio
|
||||
|
||||
import nonebot.plugin
|
||||
from nonebot import get_driver
|
||||
from src.utils import init_log
|
||||
from src.utils.base.config import get_config
|
||||
from src.utils.base.data_manager import InstalledPlugin, plugin_db
|
||||
from src.utils.base.resource import load_resources
|
||||
from src.utils.message.tools import check_for_package
|
||||
|
||||
load_resources()
|
||||
init_log()
|
||||
|
||||
driver = get_driver()
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def load_plugins():
|
||||
nonebot.plugin.load_plugins("src/nonebot_plugins")
|
||||
# 从数据库读取已安装的插件
|
||||
if not get_config("safe_mode", False):
|
||||
# 安全模式下,不加载插件
|
||||
installed_plugins: list[InstalledPlugin] = plugin_db.where_all(InstalledPlugin())
|
||||
if installed_plugins:
|
||||
for installed_plugin in installed_plugins:
|
||||
if not check_for_package(installed_plugin.module_name):
|
||||
nonebot.logger.error(
|
||||
f"{installed_plugin.module_name} not installed, but still in loader index.")
|
||||
else:
|
||||
nonebot.load_plugin(installed_plugin.module_name)
|
||||
nonebot.plugin.load_plugins("plugins")
|
||||
else:
|
||||
nonebot.logger.info("Safe mode is on, no plugin loaded.")
|
||||
import asyncio
|
||||
import os.path
|
||||
from pathlib import Path
|
||||
|
||||
import nonebot.plugin
|
||||
from nonebot import get_driver
|
||||
from src.utils import init_log
|
||||
from src.utils.base.config import get_config
|
||||
from src.utils.base.data_manager import InstalledPlugin, plugin_db
|
||||
from src.utils.base.resource import load_resources
|
||||
from src.utils.message.tools import check_for_package
|
||||
|
||||
load_resources()
|
||||
init_log()
|
||||
|
||||
driver = get_driver()
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def load_plugins():
|
||||
print("load from", os.path.join(os.path.dirname(__file__), "../nonebot_plugins"))
|
||||
nonebot.plugin.load_plugins(os.path.abspath(os.path.join(os.path.dirname(__file__), "../nonebot_plugins")))
|
||||
# 从数据库读取已安装的插件
|
||||
if not get_config("safe_mode", False):
|
||||
# 安全模式下,不加载插件
|
||||
installed_plugins: list[InstalledPlugin] = plugin_db.where_all(
|
||||
InstalledPlugin()
|
||||
)
|
||||
if installed_plugins:
|
||||
for installed_plugin in installed_plugins:
|
||||
if not check_for_package(installed_plugin.module_name):
|
||||
nonebot.logger.error(
|
||||
f"{installed_plugin.module_name} not installed, but still in loader index."
|
||||
)
|
||||
else:
|
||||
nonebot.load_plugin(installed_plugin.module_name)
|
||||
nonebot.plugin.load_plugins("plugins")
|
||||
else:
|
||||
nonebot.logger.info("Safe mode is on, no plugin loaded.")
|
@ -1,16 +0,0 @@
|
||||
from nonebot.plugin import PluginMetadata
|
||||
from .auto_update import *
|
||||
|
||||
__author__ = "expliyh"
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="Satori 用户数据自动更新(临时措施)",
|
||||
description="",
|
||||
usage="",
|
||||
type="application",
|
||||
homepage="https://github.com/snowykami/LiteyukiBot",
|
||||
extra={
|
||||
"liteyuki": True,
|
||||
"toggleable" : True,
|
||||
"default_enable" : True,
|
||||
}
|
||||
)
|
@ -1,20 +0,0 @@
|
||||
import nonebot
|
||||
|
||||
from nonebot.message import event_preprocessor
|
||||
from src.utils.base.ly_typing import T_MessageEvent
|
||||
from src.utils import satori_utils
|
||||
from nonebot.adapters import satori
|
||||
from nonebot_plugin_alconna.typings import Event
|
||||
from src.nonebot_plugins.liteyuki_status.counter_for_satori import satori_counter
|
||||
|
||||
|
||||
@event_preprocessor
|
||||
async def pre_handle(event: Event):
|
||||
if isinstance(event, satori.MessageEvent):
|
||||
if event.user.id == event.self_id:
|
||||
satori_counter.msg_sent += 1
|
||||
else:
|
||||
satori_counter.msg_received += 1
|
||||
if event.user.name is not None:
|
||||
if await satori_utils.user_infos.put(event.user):
|
||||
nonebot.logger.info(f"Satori user {event.user.name}<{event.user.id}> updated")
|
@ -1,42 +1,42 @@
|
||||
import sys
|
||||
|
||||
import nonebot
|
||||
|
||||
__NAME__ = "LiteyukiBot"
|
||||
__VERSION__ = "6.3.2" # 60201
|
||||
|
||||
from src.utils.base.config import load_from_yaml, config
|
||||
from src.utils.base.log import init_log
|
||||
from git import Repo
|
||||
|
||||
major, minor, patch = map(int, __VERSION__.split("."))
|
||||
__VERSION_I__ = major * 10000 + minor * 100 + patch
|
||||
|
||||
|
||||
def init():
|
||||
"""
|
||||
初始化
|
||||
Returns:
|
||||
|
||||
"""
|
||||
# 检测python版本是否高于3.10
|
||||
init_log()
|
||||
if sys.version_info < (3, 10):
|
||||
nonebot.logger.error("Requires Python3.10+ to run, please upgrade your Python Environment.")
|
||||
exit(1)
|
||||
|
||||
try:
|
||||
# 检测git仓库
|
||||
repo = Repo(".")
|
||||
except Exception as e:
|
||||
nonebot.logger.error(f"Failed to load git repository: {e}, please clone this project from GitHub instead of downloading the zip file.")
|
||||
|
||||
# temp_data: TempConfig = common_db.where_one(TempConfig(), default=TempConfig())
|
||||
# temp_data.data["start_time"] = time.time()
|
||||
# common_db.save(temp_data)
|
||||
|
||||
nonebot.logger.info(
|
||||
f"Run Liteyuki-NoneBot with Python{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro} "
|
||||
f"at {sys.executable}"
|
||||
)
|
||||
nonebot.logger.info(f"{__NAME__} {__VERSION__}({__VERSION_I__}) is running")
|
||||
import sys
|
||||
|
||||
import nonebot
|
||||
|
||||
__NAME__ = "LiteyukiBot"
|
||||
__VERSION__ = "6.3.2" # 60201
|
||||
|
||||
from src.utils.base.config import load_from_yaml, config
|
||||
from src.utils.base.log import init_log
|
||||
from git import Repo
|
||||
|
||||
major, minor, patch = map(int, __VERSION__.split("."))
|
||||
__VERSION_I__ = major * 10000 + minor * 100 + patch
|
||||
|
||||
|
||||
def init():
|
||||
"""
|
||||
初始化
|
||||
Returns:
|
||||
|
||||
"""
|
||||
# 检测python版本是否高于3.10
|
||||
init_log()
|
||||
if sys.version_info < (3, 10):
|
||||
nonebot.logger.error("Requires Python3.10+ to run, please upgrade your Python Environment.")
|
||||
exit(1)
|
||||
|
||||
try:
|
||||
# 检测git仓库
|
||||
repo = Repo(".")
|
||||
except Exception as e:
|
||||
nonebot.logger.error(f"Failed to load git repository: {e}, please clone this project from GitHub instead of downloading the zip file.")
|
||||
|
||||
# temp_data: TempConfig = common_db.where_one(TempConfig(), default=TempConfig())
|
||||
# temp_data.data["start_time"] = time.time()
|
||||
# common_db.save(temp_data)
|
||||
|
||||
nonebot.logger.info(
|
||||
f"Run Liteyuki-NoneBot with Python{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro} "
|
||||
f"at {sys.executable}"
|
||||
)
|
||||
nonebot.logger.info(f"{__NAME__} {__VERSION__}({__VERSION_I__}) is running")
|
||||
|
@ -1,109 +1,109 @@
|
||||
import os
|
||||
import platform
|
||||
from typing import List
|
||||
|
||||
import nonebot
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..message.tools import random_hex_string
|
||||
|
||||
|
||||
config = {} # 全局配置,确保加载后读取
|
||||
|
||||
|
||||
class SatoriNodeConfig(BaseModel):
|
||||
host: str = ""
|
||||
port: str = "5500"
|
||||
path: str = ""
|
||||
token: str = ""
|
||||
|
||||
|
||||
class SatoriConfig(BaseModel):
|
||||
comment: str = (
|
||||
"These features are still in development. Do not enable in production environment."
|
||||
)
|
||||
enable: bool = False
|
||||
hosts: List[SatoriNodeConfig] = [SatoriNodeConfig()]
|
||||
|
||||
|
||||
class BasicConfig(BaseModel):
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 20216
|
||||
superusers: list[str] = []
|
||||
command_start: list[str] = ["/", ""]
|
||||
nickname: list[str] = [f"LiteyukiBot-{random_hex_string(6)}"]
|
||||
satori: SatoriConfig = SatoriConfig()
|
||||
data_path: str = "data/liteyuki"
|
||||
chromium_path: str = (
|
||||
"/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome" # type: ignore
|
||||
if platform.system() == "Darwin"
|
||||
else (
|
||||
"C:/Program Files (x86)/Microsoft/Edge/Application/msedge.exe"
|
||||
if platform.system() == "Windows"
|
||||
else "/usr/bin/chromium-browser"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def load_from_yaml(file_: str) -> dict:
|
||||
global config
|
||||
nonebot.logger.debug("Loading config from %s" % file_)
|
||||
if not os.path.exists(file_):
|
||||
nonebot.logger.warning(
|
||||
f"Config file {file_} not found, created default config, please modify it and restart"
|
||||
)
|
||||
with open(file_, "w", encoding="utf-8") as f:
|
||||
yaml.dump(BasicConfig().dict(), f, default_flow_style=False)
|
||||
|
||||
with open(file_, "r", encoding="utf-8") as f:
|
||||
conf = init_conf(yaml.load(f, Loader=yaml.FullLoader))
|
||||
config = conf
|
||||
if conf is None:
|
||||
nonebot.logger.warning(
|
||||
f"Config file {file_} is empty, use default config. please modify it and restart"
|
||||
)
|
||||
conf = BasicConfig().dict()
|
||||
return conf
|
||||
|
||||
|
||||
def get_config(key: str, default=None):
|
||||
"""获取配置项,优先级:bot > config > db > yaml"""
|
||||
try:
|
||||
bot = nonebot.get_bot()
|
||||
except:
|
||||
bot = None
|
||||
|
||||
if bot is None:
|
||||
bot_config = {}
|
||||
else:
|
||||
bot_config = bot.config.dict()
|
||||
|
||||
if key in bot_config:
|
||||
return bot_config[key]
|
||||
|
||||
elif key in config:
|
||||
return config[key]
|
||||
|
||||
elif key in load_from_yaml("config.yml"):
|
||||
return load_from_yaml("config.yml")[key]
|
||||
|
||||
else:
|
||||
return default
|
||||
|
||||
|
||||
def init_conf(conf: dict) -> dict:
|
||||
"""
|
||||
初始化配置文件,确保配置文件中的必要字段存在,且不会冲突
|
||||
Args:
|
||||
conf:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
# 若command_start中无"",则添加必要命令头,开启alconna_use_command_start防止冲突
|
||||
# 以下内容由于issue #53 被注释
|
||||
# if "" not in conf.get("command_start", []):
|
||||
# conf["alconna_use_command_start"] = True
|
||||
return conf
|
||||
pass
|
||||
import os
|
||||
import platform
|
||||
from typing import List
|
||||
|
||||
import nonebot
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..message.tools import random_hex_string
|
||||
|
||||
|
||||
config = {} # 全局配置,确保加载后读取
|
||||
|
||||
|
||||
class SatoriNodeConfig(BaseModel):
|
||||
host: str = ""
|
||||
port: str = "5500"
|
||||
path: str = ""
|
||||
token: str = ""
|
||||
|
||||
|
||||
class SatoriConfig(BaseModel):
|
||||
comment: str = (
|
||||
"These features are still in development. Do not enable in production environment."
|
||||
)
|
||||
enable: bool = False
|
||||
hosts: List[SatoriNodeConfig] = [SatoriNodeConfig()]
|
||||
|
||||
|
||||
class BasicConfig(BaseModel):
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 20216
|
||||
superusers: list[str] = []
|
||||
command_start: list[str] = ["/", ""]
|
||||
nickname: list[str] = [f"LiteyukiBot-{random_hex_string(6)}"]
|
||||
satori: SatoriConfig = SatoriConfig()
|
||||
data_path: str = "data/liteyuki"
|
||||
chromium_path: str = (
|
||||
"/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome" # type: ignore
|
||||
if platform.system() == "Darwin"
|
||||
else (
|
||||
"C:/Program Files (x86)/Microsoft/Edge/Application/msedge.exe"
|
||||
if platform.system() == "Windows"
|
||||
else "/usr/bin/chromium-browser"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def load_from_yaml(file_: str) -> dict:
|
||||
global config
|
||||
nonebot.logger.debug("Loading config from %s" % file_)
|
||||
if not os.path.exists(file_):
|
||||
nonebot.logger.warning(
|
||||
f"Config file {file_} not found, created default config, please modify it and restart"
|
||||
)
|
||||
with open(file_, "w", encoding="utf-8") as f:
|
||||
yaml.dump(BasicConfig().dict(), f, default_flow_style=False)
|
||||
|
||||
with open(file_, "r", encoding="utf-8") as f:
|
||||
conf = init_conf(yaml.load(f, Loader=yaml.FullLoader))
|
||||
config = conf
|
||||
if conf is None:
|
||||
nonebot.logger.warning(
|
||||
f"Config file {file_} is empty, use default config. please modify it and restart"
|
||||
)
|
||||
conf = BasicConfig().dict()
|
||||
return conf
|
||||
|
||||
|
||||
def get_config(key: str, default=None):
|
||||
"""获取配置项,优先级:bot > config > db > yaml"""
|
||||
try:
|
||||
bot = nonebot.get_bot()
|
||||
except:
|
||||
bot = None
|
||||
|
||||
if bot is None:
|
||||
bot_config = {}
|
||||
else:
|
||||
bot_config = bot.config.dict()
|
||||
|
||||
if key in bot_config:
|
||||
return bot_config[key]
|
||||
|
||||
elif key in config:
|
||||
return config[key]
|
||||
|
||||
elif key in load_from_yaml("config.yml"):
|
||||
return load_from_yaml("config.yml")[key]
|
||||
|
||||
else:
|
||||
return default
|
||||
|
||||
|
||||
def init_conf(conf: dict) -> dict:
|
||||
"""
|
||||
初始化配置文件,确保配置文件中的必要字段存在,且不会冲突
|
||||
Args:
|
||||
conf:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
# 若command_start中无"",则添加必要命令头,开启alconna_use_command_start防止冲突
|
||||
# 以下内容由于issue #53 被注释
|
||||
# if "" not in conf.get("command_start", []):
|
||||
# conf["alconna_use_command_start"] = True
|
||||
return conf
|
||||
pass
|
||||
|
@ -1,436 +1,436 @@
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
import sqlite3
|
||||
from types import NoneType
|
||||
from typing import Any, Callable
|
||||
|
||||
from nonebot import logger
|
||||
from nonebot.compat import PYDANTIC_V2
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LiteModel(BaseModel):
|
||||
TABLE_NAME: str = None
|
||||
id: int = None
|
||||
|
||||
def dump(self, *args, **kwargs):
|
||||
if PYDANTIC_V2:
|
||||
return self.model_dump(*args, **kwargs)
|
||||
else:
|
||||
return self.dict(*args, **kwargs)
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, db_name: str):
|
||||
|
||||
if os.path.dirname(db_name) != "" and not os.path.exists(os.path.dirname(db_name)):
|
||||
os.makedirs(os.path.dirname(db_name))
|
||||
|
||||
self.db_name = db_name
|
||||
self.conn = sqlite3.connect(db_name, check_same_thread=False)
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
self._on_save_callbacks = []
|
||||
self._is_locked = False
|
||||
|
||||
def lock(self):
|
||||
self.cursor.execute("BEGIN TRANSACTION")
|
||||
self._is_locked = True
|
||||
|
||||
def lock_query(self, query: str, *args):
|
||||
"""锁定查询"""
|
||||
self.cursor.execute(query, args).fetchall()
|
||||
|
||||
def lock_model(self, model: LiteModel) -> LiteModel | Any | None:
|
||||
"""锁定行
|
||||
Args:
|
||||
model: 数据模型实例
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
def unlock(self):
|
||||
self.cursor.execute("COMMIT")
|
||||
self._is_locked = False
|
||||
|
||||
def where_one(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> LiteModel | Any | None:
|
||||
"""查询第一个
|
||||
Args:
|
||||
model: 数据模型实例
|
||||
condition: 查询条件,不给定则查询所有
|
||||
*args: 参数化查询参数
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
all_results = self.where_all(model, condition, *args)
|
||||
return all_results[0] if all_results else default
|
||||
|
||||
def where_all(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> list[LiteModel | Any] | None:
|
||||
"""查询所有
|
||||
Args:
|
||||
model: 数据模型实例
|
||||
condition: 查询条件,不给定则查询所有
|
||||
*args: 参数化查询参数
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
table_name = model.TABLE_NAME
|
||||
model_type = type(model)
|
||||
logger.debug(f"Selecting {model.TABLE_NAME} WHERE {condition.replace('?', '%s') % args}")
|
||||
if not table_name:
|
||||
raise ValueError(f"数据模型{model_type.__name__}未提供表名")
|
||||
|
||||
# condition = f"WHERE {condition}"
|
||||
# print(f"SELECT * FROM {table_name} {condition}", args)
|
||||
# if len(args) == 0:
|
||||
# results = self.cursor.execute(f"SELECT * FROM {table_name} {condition}").fetchall()
|
||||
# else:
|
||||
# results = self.cursor.execute(f"SELECT * FROM {table_name} {condition}", args).fetchall()
|
||||
if condition:
|
||||
results = self.cursor.execute(f"SELECT * FROM {table_name} WHERE {condition}", args).fetchall()
|
||||
else:
|
||||
results = self.cursor.execute(f"SELECT * FROM {table_name}").fetchall()
|
||||
fields = [description[0] for description in self.cursor.description]
|
||||
if not results:
|
||||
return default
|
||||
else:
|
||||
return [model_type(**self._load(dict(zip(fields, result)))) for result in results]
|
||||
|
||||
def save(self, *args: LiteModel):
|
||||
self.returns_ = """增/改操作
|
||||
Args:
|
||||
*args:
|
||||
Returns:
|
||||
"""
|
||||
table_list = [item[0] for item in self.cursor.execute("SELECT name FROM sqlite_master WHERE type ='table'").fetchall()]
|
||||
for model in args:
|
||||
logger.debug(f"Upserting {model}")
|
||||
if not model.TABLE_NAME:
|
||||
raise ValueError(f"数据模型 {model.__class__.__name__} 未提供表名")
|
||||
elif model.TABLE_NAME not in table_list:
|
||||
raise ValueError(f"数据模型 {model.__class__.__name__} 表 {model.TABLE_NAME} 不存在,请先迁移")
|
||||
else:
|
||||
self._save(model.dump(by_alias=True))
|
||||
|
||||
for callback in self._on_save_callbacks:
|
||||
callback(model)
|
||||
|
||||
def _save(self, obj: Any) -> Any:
|
||||
# obj = copy.deepcopy(obj)
|
||||
if isinstance(obj, dict):
|
||||
table_name = obj.get("TABLE_NAME")
|
||||
row_id = obj.get("id")
|
||||
new_obj = {}
|
||||
for field, value in obj.items():
|
||||
if isinstance(value, self.ITERABLE_TYPE):
|
||||
new_obj[self._get_stored_field_prefix(value) + field] = self._save(value) # self._save(value) # -> bytes
|
||||
elif isinstance(value, self.BASIC_TYPE):
|
||||
new_obj[field] = value
|
||||
else:
|
||||
raise ValueError(f"数据模型{table_name}包含不支持的数据类型,字段:{field} 值:{value} 值类型:{type(value)}")
|
||||
if table_name:
|
||||
fields, values = [], []
|
||||
for n_field, n_value in new_obj.items():
|
||||
if n_field not in ["TABLE_NAME", "id"]:
|
||||
fields.append(n_field)
|
||||
values.append(n_value)
|
||||
# 移除TABLE_NAME和id
|
||||
fields = list(fields)
|
||||
values = list(values)
|
||||
if row_id is not None:
|
||||
# 如果 _id 不为空,将 'id' 插入到字段列表的开始
|
||||
fields.insert(0, 'id')
|
||||
# 将 _id 插入到值列表的开始
|
||||
values.insert(0, row_id)
|
||||
fields = ', '.join([f'"{field}"' for field in fields])
|
||||
placeholders = ', '.join('?' for _ in values)
|
||||
self.cursor.execute(f"INSERT OR REPLACE INTO {table_name}({fields}) VALUES ({placeholders})", tuple(values))
|
||||
# self.conn.commit()
|
||||
if self._is_locked:
|
||||
pass
|
||||
else:
|
||||
self.conn.commit()
|
||||
foreign_id = self.cursor.execute("SELECT last_insert_rowid()").fetchone()[0]
|
||||
return f"{self.FOREIGN_KEY_PREFIX}{foreign_id}@{table_name}" # -> FOREIGN_KEY_123456@{table_name} id@{table_name}
|
||||
else:
|
||||
return pickle.dumps(new_obj) # -> bytes
|
||||
elif isinstance(obj, (list, set, tuple)):
|
||||
obj_type = type(obj) # 到时候转回去
|
||||
new_obj = []
|
||||
for item in obj:
|
||||
if isinstance(item, self.ITERABLE_TYPE):
|
||||
new_obj.append(self._save(item))
|
||||
elif isinstance(item, self.BASIC_TYPE):
|
||||
new_obj.append(item)
|
||||
else:
|
||||
raise ValueError(f"数据模型包含不支持的数据类型,值:{item} 值类型:{type(item)}")
|
||||
return pickle.dumps(obj_type(new_obj)) # -> bytes
|
||||
else:
|
||||
raise ValueError(f"数据模型包含不支持的数据类型,值:{obj} 值类型:{type(obj)}")
|
||||
|
||||
def _load(self, obj: Any) -> Any:
|
||||
|
||||
if isinstance(obj, dict):
|
||||
|
||||
new_obj = {}
|
||||
|
||||
for field, value in obj.items():
|
||||
|
||||
field: str
|
||||
|
||||
if field.startswith(self.BYTES_PREFIX):
|
||||
if isinstance(value, bytes):
|
||||
new_obj[field.replace(self.BYTES_PREFIX, "")] = self._load(pickle.loads(value))
|
||||
else: # 从value字段可能为None,fix at 2024/6/13
|
||||
pass
|
||||
# 暂时不作处理,后面再修
|
||||
|
||||
elif field.startswith(self.FOREIGN_KEY_PREFIX):
|
||||
|
||||
new_obj[field.replace(self.FOREIGN_KEY_PREFIX, "")] = self._load(self._get_foreign_data(value))
|
||||
|
||||
else:
|
||||
new_obj[field] = value
|
||||
return new_obj
|
||||
elif isinstance(obj, (list, set, tuple)):
|
||||
|
||||
new_obj = []
|
||||
for item in obj:
|
||||
|
||||
if isinstance(item, bytes):
|
||||
|
||||
# 对bytes进行尝试解析,解析失败则返回原始bytes
|
||||
try:
|
||||
new_obj.append(self._load(pickle.loads(item)))
|
||||
except Exception as e:
|
||||
new_obj.append(self._load(item))
|
||||
|
||||
elif isinstance(item, str) and item.startswith(self.FOREIGN_KEY_PREFIX):
|
||||
new_obj.append(self._load(self._get_foreign_data(item)))
|
||||
else:
|
||||
new_obj.append(self._load(item))
|
||||
return new_obj
|
||||
else:
|
||||
return obj
|
||||
|
||||
def delete(self, model: LiteModel, condition: str, *args: Any, allow_empty: bool = False):
|
||||
"""
|
||||
删除满足条件的数据
|
||||
Args:
|
||||
allow_empty: 允许空条件删除整个表
|
||||
model:
|
||||
condition:
|
||||
*args:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
table_name = model.TABLE_NAME
|
||||
logger.debug(f"Deleting {model} WHERE {condition} {args}")
|
||||
if not table_name:
|
||||
raise ValueError(f"数据模型{model.__class__.__name__}未提供表名")
|
||||
if model.id is not None:
|
||||
condition = f"id = {model.id}"
|
||||
if not condition and not allow_empty:
|
||||
raise ValueError("删除操作必须提供条件")
|
||||
self.cursor.execute(f"DELETE FROM {table_name} WHERE {condition}", args)
|
||||
if self._is_locked:
|
||||
pass
|
||||
else:
|
||||
self.conn.commit()
|
||||
|
||||
def auto_migrate(self, *args: LiteModel):
|
||||
|
||||
"""
|
||||
自动迁移模型
|
||||
Args:
|
||||
*args: 模型类实例化对象,支持空默认值,不支持嵌套迁移
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
for model in args:
|
||||
if not model.TABLE_NAME:
|
||||
raise ValueError(f"数据模型{type(model).__name__}未提供表名")
|
||||
|
||||
# 若无则创建表
|
||||
self.cursor.execute(
|
||||
f'CREATE TABLE IF NOT EXISTS "{model.TABLE_NAME}" (id INTEGER PRIMARY KEY AUTOINCREMENT)'
|
||||
)
|
||||
|
||||
# 获取表结构,field -> SqliteType
|
||||
new_structure = {}
|
||||
for n_field, n_value in model.dump(by_alias=True).items():
|
||||
if n_field not in ["TABLE_NAME", "id"]:
|
||||
new_structure[self._get_stored_field_prefix(n_value) + n_field] = self._get_stored_type(n_value)
|
||||
|
||||
# 原有的字段列表
|
||||
existing_structure = dict([(column[1], column[2]) for column in self.cursor.execute(f'PRAGMA table_info({model.TABLE_NAME})').fetchall()])
|
||||
# 检测缺失字段,由于SQLite是动态类型,所以不需要检测类型
|
||||
for n_field, n_type in new_structure.items():
|
||||
if n_field not in existing_structure.keys() and n_field.lower() not in ["id", "table_name"]:
|
||||
default_value = self.DEFAULT_MAPPING.get(n_type, 'NULL')
|
||||
self.cursor.execute(
|
||||
f"ALTER TABLE '{model.TABLE_NAME}' ADD COLUMN {n_field} {n_type} DEFAULT {self.DEFAULT_MAPPING.get(n_type, default_value)}"
|
||||
)
|
||||
|
||||
# 检测多余字段进行删除
|
||||
for e_field in existing_structure.keys():
|
||||
if e_field not in new_structure.keys() and e_field.lower() not in ['id']:
|
||||
self.cursor.execute(
|
||||
f'ALTER TABLE "{model.TABLE_NAME}" DROP COLUMN "{e_field}"'
|
||||
)
|
||||
self.conn.commit()
|
||||
# 已完成
|
||||
|
||||
def _get_stored_field_prefix(self, value) -> str:
|
||||
"""根据类型获取存储字段前缀,一定在后加上字段名
|
||||
* -> ""
|
||||
Args:
|
||||
value: 储存的值
|
||||
|
||||
Returns:
|
||||
Sqlite3存储字段
|
||||
"""
|
||||
|
||||
if isinstance(value, LiteModel) or isinstance(value, dict) and "TABLE_NAME" in value:
|
||||
return self.FOREIGN_KEY_PREFIX
|
||||
elif type(value) in self.ITERABLE_TYPE:
|
||||
return self.BYTES_PREFIX
|
||||
return ""
|
||||
|
||||
def _get_stored_type(self, value) -> str:
|
||||
"""获取存储类型
|
||||
|
||||
Args:
|
||||
value: 储存的值
|
||||
|
||||
Returns:
|
||||
Sqlite3存储类型
|
||||
"""
|
||||
if isinstance(value, dict) and "TABLE_NAME" in value:
|
||||
# 是一个模型字典,储存外键
|
||||
return "INTEGER"
|
||||
return self.TYPE_MAPPING.get(type(value), "TEXT")
|
||||
|
||||
def _get_foreign_data(self, foreign_value: str) -> dict:
|
||||
"""
|
||||
获取外键数据
|
||||
Args:
|
||||
foreign_value:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
foreign_value = foreign_value.replace(self.FOREIGN_KEY_PREFIX, "")
|
||||
table_name = foreign_value.split("@")[-1]
|
||||
foreign_id = foreign_value.split("@")[0]
|
||||
fields = [description[1] for description in self.cursor.execute(f"PRAGMA table_info({table_name})").fetchall()]
|
||||
result = self.cursor.execute(f"SELECT * FROM {table_name} WHERE id = ?", (foreign_id,)).fetchone()
|
||||
return dict(zip(fields, result))
|
||||
|
||||
def on_save(self, func: Callable[[LiteModel | Any], None]):
|
||||
"""
|
||||
装饰一个可调用对象使其在储存数据模型时被调用
|
||||
Args:
|
||||
func:
|
||||
Returns:
|
||||
"""
|
||||
|
||||
def wrapper(model):
|
||||
# 检查被装饰函数声明的model类型和传入的model类型是否一致
|
||||
sign = inspect.signature(func)
|
||||
if param := sign.parameters.get("model"):
|
||||
if isinstance(model, param.annotation):
|
||||
pass
|
||||
else:
|
||||
return
|
||||
else:
|
||||
return
|
||||
result = func(model)
|
||||
for callback in self._on_save_callbacks:
|
||||
callback(result)
|
||||
return result
|
||||
|
||||
self._on_save_callbacks.append(wrapper)
|
||||
return wrapper
|
||||
|
||||
TYPE_MAPPING = {
|
||||
int : "INTEGER",
|
||||
float : "REAL",
|
||||
str : "TEXT",
|
||||
bool : "INTEGER",
|
||||
bytes : "BLOB",
|
||||
NoneType : "NULL",
|
||||
# dict : "TEXT",
|
||||
# list : "TEXT",
|
||||
# tuple : "TEXT",
|
||||
# set : "TEXT",
|
||||
|
||||
dict : "BLOB", # LITEYUKIDICT{key_name}
|
||||
list : "BLOB", # LITEYUKILIST{key_name}
|
||||
tuple : "BLOB", # LITEYUKITUPLE{key_name}
|
||||
set : "BLOB", # LITEYUKISET{key_name}
|
||||
LiteModel: "TEXT" # FOREIGN_KEY_{table_name}
|
||||
}
|
||||
DEFAULT_MAPPING = {
|
||||
"TEXT" : "''",
|
||||
"INTEGER": 0,
|
||||
"REAL" : 0.0,
|
||||
"BLOB" : None,
|
||||
"NULL" : None
|
||||
}
|
||||
|
||||
# 基础类型
|
||||
BASIC_TYPE = (int, float, str, bool, bytes, NoneType)
|
||||
# 可序列化类型
|
||||
ITERABLE_TYPE = (dict, list, tuple, set, LiteModel)
|
||||
|
||||
# 外键前缀
|
||||
FOREIGN_KEY_PREFIX = "FOREIGN_KEY_"
|
||||
# 转换为的字节前缀
|
||||
BYTES_PREFIX = "PICKLE_BYTES_"
|
||||
|
||||
# transaction tx 事务操作
|
||||
def first(self, model: LiteModel) -> "Database":
|
||||
pass
|
||||
|
||||
def where(self, condition: str, *args) -> "Database":
|
||||
pass
|
||||
|
||||
def limit(self, limit: int) -> "Database":
|
||||
pass
|
||||
|
||||
def order(self, order: str) -> "Database":
|
||||
pass
|
||||
|
||||
|
||||
def check_sqlite_keyword(name):
|
||||
sqlite_keywords = [
|
||||
"ABORT", "ACTION", "ADD", "AFTER", "ALL", "ALTER", "ANALYZE", "AND", "AS", "ASC",
|
||||
"ATTACH", "AUTOINCREMENT", "BEFORE", "BEGIN", "BETWEEN", "BY", "CASCADE", "CASE",
|
||||
"CAST", "CHECK", "COLLATE", "COLUMN", "COMMIT", "CONFLICT", "CONSTRAINT", "CREATE",
|
||||
"CROSS", "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP", "DATABASE", "DEFAULT",
|
||||
"DEFERRABLE", "DEFERRED", "DELETE", "DESC", "DETACH", "DISTINCT", "DROP", "EACH",
|
||||
"ELSE", "END", "ESCAPE", "EXCEPT", "EXCLUSIVE", "EXISTS", "EXPLAIN", "FAIL", "FOR",
|
||||
"FOREIGN", "FROM", "FULL", "GLOB", "GROUP", "HAVING", "IF", "IGNORE", "IMMEDIATE",
|
||||
"IN", "INDEX", "INDEXED", "INITIALLY", "INNER", "INSERT", "INSTEAD", "INTERSECT",
|
||||
"INTO", "IS", "ISNULL", "JOIN", "KEY", "LEFT", "LIKE", "LIMIT", "MATCH", "NATURAL",
|
||||
"NO", "NOT", "NOTNULL", "NULL", "OF", "OFFSET", "ON", "OR", "ORDER", "OUTER", "PLAN",
|
||||
"PRAGMA", "PRIMARY", "QUERY", "RAISE", "RECURSIVE", "REFERENCES", "REGEXP", "REINDEX",
|
||||
"RELEASE", "RENAME", "REPLACE", "RESTRICT", "RIGHT", "ROLLBACK", "ROW", "SAVEPOINT",
|
||||
"SELECT", "SET", "TABLE", "TEMP", "TEMPORARY", "THEN", "TO", "TRANSACTION", "TRIGGER",
|
||||
"UNION", "UNIQUE", "UPDATE", "USING", "VACUUM", "VALUES", "VIEW", "VIRTUAL", "WHEN",
|
||||
"WHERE", "WITH", "WITHOUT"
|
||||
]
|
||||
return True
|
||||
# if name.upper() in sqlite_keywords:
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
import sqlite3
|
||||
from types import NoneType
|
||||
from typing import Any, Callable
|
||||
|
||||
from nonebot import logger
|
||||
from nonebot.compat import PYDANTIC_V2
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LiteModel(BaseModel):
|
||||
TABLE_NAME: str = None
|
||||
id: int = None
|
||||
|
||||
def dump(self, *args, **kwargs):
|
||||
if PYDANTIC_V2:
|
||||
return self.model_dump(*args, **kwargs)
|
||||
else:
|
||||
return self.dict(*args, **kwargs)
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, db_name: str):
|
||||
|
||||
if os.path.dirname(db_name) != "" and not os.path.exists(os.path.dirname(db_name)):
|
||||
os.makedirs(os.path.dirname(db_name))
|
||||
|
||||
self.db_name = db_name
|
||||
self.conn = sqlite3.connect(db_name, check_same_thread=False)
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
self._on_save_callbacks = []
|
||||
self._is_locked = False
|
||||
|
||||
def lock(self):
|
||||
self.cursor.execute("BEGIN TRANSACTION")
|
||||
self._is_locked = True
|
||||
|
||||
def lock_query(self, query: str, *args):
|
||||
"""锁定查询"""
|
||||
self.cursor.execute(query, args).fetchall()
|
||||
|
||||
def lock_model(self, model: LiteModel) -> LiteModel | Any | None:
|
||||
"""锁定行
|
||||
Args:
|
||||
model: 数据模型实例
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
def unlock(self):
|
||||
self.cursor.execute("COMMIT")
|
||||
self._is_locked = False
|
||||
|
||||
def where_one(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> LiteModel | Any | None:
|
||||
"""查询第一个
|
||||
Args:
|
||||
model: 数据模型实例
|
||||
condition: 查询条件,不给定则查询所有
|
||||
*args: 参数化查询参数
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
all_results = self.where_all(model, condition, *args)
|
||||
return all_results[0] if all_results else default
|
||||
|
||||
def where_all(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> list[LiteModel | Any] | None:
|
||||
"""查询所有
|
||||
Args:
|
||||
model: 数据模型实例
|
||||
condition: 查询条件,不给定则查询所有
|
||||
*args: 参数化查询参数
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
table_name = model.TABLE_NAME
|
||||
model_type = type(model)
|
||||
logger.debug(f"Selecting {model.TABLE_NAME} WHERE {condition.replace('?', '%s') % args}")
|
||||
if not table_name:
|
||||
raise ValueError(f"数据模型{model_type.__name__}未提供表名")
|
||||
|
||||
# condition = f"WHERE {condition}"
|
||||
# print(f"SELECT * FROM {table_name} {condition}", args)
|
||||
# if len(args) == 0:
|
||||
# results = self.cursor.execute(f"SELECT * FROM {table_name} {condition}").fetchall()
|
||||
# else:
|
||||
# results = self.cursor.execute(f"SELECT * FROM {table_name} {condition}", args).fetchall()
|
||||
if condition:
|
||||
results = self.cursor.execute(f"SELECT * FROM {table_name} WHERE {condition}", args).fetchall()
|
||||
else:
|
||||
results = self.cursor.execute(f"SELECT * FROM {table_name}").fetchall()
|
||||
fields = [description[0] for description in self.cursor.description]
|
||||
if not results:
|
||||
return default
|
||||
else:
|
||||
return [model_type(**self._load(dict(zip(fields, result)))) for result in results]
|
||||
|
||||
def save(self, *args: LiteModel):
|
||||
self.returns_ = """增/改操作
|
||||
Args:
|
||||
*args:
|
||||
Returns:
|
||||
"""
|
||||
table_list = [item[0] for item in self.cursor.execute("SELECT name FROM sqlite_master WHERE type ='table'").fetchall()]
|
||||
for model in args:
|
||||
logger.debug(f"Upserting {model}")
|
||||
if not model.TABLE_NAME:
|
||||
raise ValueError(f"数据模型 {model.__class__.__name__} 未提供表名")
|
||||
elif model.TABLE_NAME not in table_list:
|
||||
raise ValueError(f"数据模型 {model.__class__.__name__} 表 {model.TABLE_NAME} 不存在,请先迁移")
|
||||
else:
|
||||
self._save(model.dump(by_alias=True))
|
||||
|
||||
for callback in self._on_save_callbacks:
|
||||
callback(model)
|
||||
|
||||
def _save(self, obj: Any) -> Any:
|
||||
# obj = copy.deepcopy(obj)
|
||||
if isinstance(obj, dict):
|
||||
table_name = obj.get("TABLE_NAME")
|
||||
row_id = obj.get("id")
|
||||
new_obj = {}
|
||||
for field, value in obj.items():
|
||||
if isinstance(value, self.ITERABLE_TYPE):
|
||||
new_obj[self._get_stored_field_prefix(value) + field] = self._save(value) # self._save(value) # -> bytes
|
||||
elif isinstance(value, self.BASIC_TYPE):
|
||||
new_obj[field] = value
|
||||
else:
|
||||
raise ValueError(f"数据模型{table_name}包含不支持的数据类型,字段:{field} 值:{value} 值类型:{type(value)}")
|
||||
if table_name:
|
||||
fields, values = [], []
|
||||
for n_field, n_value in new_obj.items():
|
||||
if n_field not in ["TABLE_NAME", "id"]:
|
||||
fields.append(n_field)
|
||||
values.append(n_value)
|
||||
# 移除TABLE_NAME和id
|
||||
fields = list(fields)
|
||||
values = list(values)
|
||||
if row_id is not None:
|
||||
# 如果 _id 不为空,将 'id' 插入到字段列表的开始
|
||||
fields.insert(0, 'id')
|
||||
# 将 _id 插入到值列表的开始
|
||||
values.insert(0, row_id)
|
||||
fields = ', '.join([f'"{field}"' for field in fields])
|
||||
placeholders = ', '.join('?' for _ in values)
|
||||
self.cursor.execute(f"INSERT OR REPLACE INTO {table_name}({fields}) VALUES ({placeholders})", tuple(values))
|
||||
# self.conn.commit()
|
||||
if self._is_locked:
|
||||
pass
|
||||
else:
|
||||
self.conn.commit()
|
||||
foreign_id = self.cursor.execute("SELECT last_insert_rowid()").fetchone()[0]
|
||||
return f"{self.FOREIGN_KEY_PREFIX}{foreign_id}@{table_name}" # -> FOREIGN_KEY_123456@{table_name} id@{table_name}
|
||||
else:
|
||||
return pickle.dumps(new_obj) # -> bytes
|
||||
elif isinstance(obj, (list, set, tuple)):
|
||||
obj_type = type(obj) # 到时候转回去
|
||||
new_obj = []
|
||||
for item in obj:
|
||||
if isinstance(item, self.ITERABLE_TYPE):
|
||||
new_obj.append(self._save(item))
|
||||
elif isinstance(item, self.BASIC_TYPE):
|
||||
new_obj.append(item)
|
||||
else:
|
||||
raise ValueError(f"数据模型包含不支持的数据类型,值:{item} 值类型:{type(item)}")
|
||||
return pickle.dumps(obj_type(new_obj)) # -> bytes
|
||||
else:
|
||||
raise ValueError(f"数据模型包含不支持的数据类型,值:{obj} 值类型:{type(obj)}")
|
||||
|
||||
def _load(self, obj: Any) -> Any:
|
||||
|
||||
if isinstance(obj, dict):
|
||||
|
||||
new_obj = {}
|
||||
|
||||
for field, value in obj.items():
|
||||
|
||||
field: str
|
||||
|
||||
if field.startswith(self.BYTES_PREFIX):
|
||||
if isinstance(value, bytes):
|
||||
new_obj[field.replace(self.BYTES_PREFIX, "")] = self._load(pickle.loads(value))
|
||||
else: # 从value字段可能为None,fix at 2024/6/13
|
||||
pass
|
||||
# 暂时不作处理,后面再修
|
||||
|
||||
elif field.startswith(self.FOREIGN_KEY_PREFIX):
|
||||
|
||||
new_obj[field.replace(self.FOREIGN_KEY_PREFIX, "")] = self._load(self._get_foreign_data(value))
|
||||
|
||||
else:
|
||||
new_obj[field] = value
|
||||
return new_obj
|
||||
elif isinstance(obj, (list, set, tuple)):
|
||||
|
||||
new_obj = []
|
||||
for item in obj:
|
||||
|
||||
if isinstance(item, bytes):
|
||||
|
||||
# 对bytes进行尝试解析,解析失败则返回原始bytes
|
||||
try:
|
||||
new_obj.append(self._load(pickle.loads(item)))
|
||||
except Exception as e:
|
||||
new_obj.append(self._load(item))
|
||||
|
||||
elif isinstance(item, str) and item.startswith(self.FOREIGN_KEY_PREFIX):
|
||||
new_obj.append(self._load(self._get_foreign_data(item)))
|
||||
else:
|
||||
new_obj.append(self._load(item))
|
||||
return new_obj
|
||||
else:
|
||||
return obj
|
||||
|
||||
def delete(self, model: LiteModel, condition: str, *args: Any, allow_empty: bool = False):
|
||||
"""
|
||||
删除满足条件的数据
|
||||
Args:
|
||||
allow_empty: 允许空条件删除整个表
|
||||
model:
|
||||
condition:
|
||||
*args:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
table_name = model.TABLE_NAME
|
||||
logger.debug(f"Deleting {model} WHERE {condition} {args}")
|
||||
if not table_name:
|
||||
raise ValueError(f"数据模型{model.__class__.__name__}未提供表名")
|
||||
if model.id is not None:
|
||||
condition = f"id = {model.id}"
|
||||
if not condition and not allow_empty:
|
||||
raise ValueError("删除操作必须提供条件")
|
||||
self.cursor.execute(f"DELETE FROM {table_name} WHERE {condition}", args)
|
||||
if self._is_locked:
|
||||
pass
|
||||
else:
|
||||
self.conn.commit()
|
||||
|
||||
def auto_migrate(self, *args: LiteModel):
|
||||
|
||||
"""
|
||||
自动迁移模型
|
||||
Args:
|
||||
*args: 模型类实例化对象,支持空默认值,不支持嵌套迁移
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
for model in args:
|
||||
if not model.TABLE_NAME:
|
||||
raise ValueError(f"数据模型{type(model).__name__}未提供表名")
|
||||
|
||||
# 若无则创建表
|
||||
self.cursor.execute(
|
||||
f'CREATE TABLE IF NOT EXISTS "{model.TABLE_NAME}" (id INTEGER PRIMARY KEY AUTOINCREMENT)'
|
||||
)
|
||||
|
||||
# 获取表结构,field -> SqliteType
|
||||
new_structure = {}
|
||||
for n_field, n_value in model.dump(by_alias=True).items():
|
||||
if n_field not in ["TABLE_NAME", "id"]:
|
||||
new_structure[self._get_stored_field_prefix(n_value) + n_field] = self._get_stored_type(n_value)
|
||||
|
||||
# 原有的字段列表
|
||||
existing_structure = dict([(column[1], column[2]) for column in self.cursor.execute(f'PRAGMA table_info({model.TABLE_NAME})').fetchall()])
|
||||
# 检测缺失字段,由于SQLite是动态类型,所以不需要检测类型
|
||||
for n_field, n_type in new_structure.items():
|
||||
if n_field not in existing_structure.keys() and n_field.lower() not in ["id", "table_name"]:
|
||||
default_value = self.DEFAULT_MAPPING.get(n_type, 'NULL')
|
||||
self.cursor.execute(
|
||||
f"ALTER TABLE '{model.TABLE_NAME}' ADD COLUMN {n_field} {n_type} DEFAULT {self.DEFAULT_MAPPING.get(n_type, default_value)}"
|
||||
)
|
||||
|
||||
# 检测多余字段进行删除
|
||||
for e_field in existing_structure.keys():
|
||||
if e_field not in new_structure.keys() and e_field.lower() not in ['id']:
|
||||
self.cursor.execute(
|
||||
f'ALTER TABLE "{model.TABLE_NAME}" DROP COLUMN "{e_field}"'
|
||||
)
|
||||
self.conn.commit()
|
||||
# 已完成
|
||||
|
||||
def _get_stored_field_prefix(self, value) -> str:
|
||||
"""根据类型获取存储字段前缀,一定在后加上字段名
|
||||
* -> ""
|
||||
Args:
|
||||
value: 储存的值
|
||||
|
||||
Returns:
|
||||
Sqlite3存储字段
|
||||
"""
|
||||
|
||||
if isinstance(value, LiteModel) or isinstance(value, dict) and "TABLE_NAME" in value:
|
||||
return self.FOREIGN_KEY_PREFIX
|
||||
elif type(value) in self.ITERABLE_TYPE:
|
||||
return self.BYTES_PREFIX
|
||||
return ""
|
||||
|
||||
def _get_stored_type(self, value) -> str:
|
||||
"""获取存储类型
|
||||
|
||||
Args:
|
||||
value: 储存的值
|
||||
|
||||
Returns:
|
||||
Sqlite3存储类型
|
||||
"""
|
||||
if isinstance(value, dict) and "TABLE_NAME" in value:
|
||||
# 是一个模型字典,储存外键
|
||||
return "INTEGER"
|
||||
return self.TYPE_MAPPING.get(type(value), "TEXT")
|
||||
|
||||
def _get_foreign_data(self, foreign_value: str) -> dict:
|
||||
"""
|
||||
获取外键数据
|
||||
Args:
|
||||
foreign_value:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
foreign_value = foreign_value.replace(self.FOREIGN_KEY_PREFIX, "")
|
||||
table_name = foreign_value.split("@")[-1]
|
||||
foreign_id = foreign_value.split("@")[0]
|
||||
fields = [description[1] for description in self.cursor.execute(f"PRAGMA table_info({table_name})").fetchall()]
|
||||
result = self.cursor.execute(f"SELECT * FROM {table_name} WHERE id = ?", (foreign_id,)).fetchone()
|
||||
return dict(zip(fields, result))
|
||||
|
||||
def on_save(self, func: Callable[[LiteModel | Any], None]):
|
||||
"""
|
||||
装饰一个可调用对象使其在储存数据模型时被调用
|
||||
Args:
|
||||
func:
|
||||
Returns:
|
||||
"""
|
||||
|
||||
def wrapper(model):
|
||||
# 检查被装饰函数声明的model类型和传入的model类型是否一致
|
||||
sign = inspect.signature(func)
|
||||
if param := sign.parameters.get("model"):
|
||||
if isinstance(model, param.annotation):
|
||||
pass
|
||||
else:
|
||||
return
|
||||
else:
|
||||
return
|
||||
result = func(model)
|
||||
for callback in self._on_save_callbacks:
|
||||
callback(result)
|
||||
return result
|
||||
|
||||
self._on_save_callbacks.append(wrapper)
|
||||
return wrapper
|
||||
|
||||
TYPE_MAPPING = {
|
||||
int : "INTEGER",
|
||||
float : "REAL",
|
||||
str : "TEXT",
|
||||
bool : "INTEGER",
|
||||
bytes : "BLOB",
|
||||
NoneType : "NULL",
|
||||
# dict : "TEXT",
|
||||
# list : "TEXT",
|
||||
# tuple : "TEXT",
|
||||
# set : "TEXT",
|
||||
|
||||
dict : "BLOB", # LITEYUKIDICT{key_name}
|
||||
list : "BLOB", # LITEYUKILIST{key_name}
|
||||
tuple : "BLOB", # LITEYUKITUPLE{key_name}
|
||||
set : "BLOB", # LITEYUKISET{key_name}
|
||||
LiteModel: "TEXT" # FOREIGN_KEY_{table_name}
|
||||
}
|
||||
DEFAULT_MAPPING = {
|
||||
"TEXT" : "''",
|
||||
"INTEGER": 0,
|
||||
"REAL" : 0.0,
|
||||
"BLOB" : None,
|
||||
"NULL" : None
|
||||
}
|
||||
|
||||
# 基础类型
|
||||
BASIC_TYPE = (int, float, str, bool, bytes, NoneType)
|
||||
# 可序列化类型
|
||||
ITERABLE_TYPE = (dict, list, tuple, set, LiteModel)
|
||||
|
||||
# 外键前缀
|
||||
FOREIGN_KEY_PREFIX = "FOREIGN_KEY_"
|
||||
# 转换为的字节前缀
|
||||
BYTES_PREFIX = "PICKLE_BYTES_"
|
||||
|
||||
# transaction tx 事务操作
|
||||
def first(self, model: LiteModel) -> "Database":
|
||||
pass
|
||||
|
||||
def where(self, condition: str, *args) -> "Database":
|
||||
pass
|
||||
|
||||
def limit(self, limit: int) -> "Database":
|
||||
pass
|
||||
|
||||
def order(self, order: str) -> "Database":
|
||||
pass
|
||||
|
||||
|
||||
def check_sqlite_keyword(name):
|
||||
sqlite_keywords = [
|
||||
"ABORT", "ACTION", "ADD", "AFTER", "ALL", "ALTER", "ANALYZE", "AND", "AS", "ASC",
|
||||
"ATTACH", "AUTOINCREMENT", "BEFORE", "BEGIN", "BETWEEN", "BY", "CASCADE", "CASE",
|
||||
"CAST", "CHECK", "COLLATE", "COLUMN", "COMMIT", "CONFLICT", "CONSTRAINT", "CREATE",
|
||||
"CROSS", "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP", "DATABASE", "DEFAULT",
|
||||
"DEFERRABLE", "DEFERRED", "DELETE", "DESC", "DETACH", "DISTINCT", "DROP", "EACH",
|
||||
"ELSE", "END", "ESCAPE", "EXCEPT", "EXCLUSIVE", "EXISTS", "EXPLAIN", "FAIL", "FOR",
|
||||
"FOREIGN", "FROM", "FULL", "GLOB", "GROUP", "HAVING", "IF", "IGNORE", "IMMEDIATE",
|
||||
"IN", "INDEX", "INDEXED", "INITIALLY", "INNER", "INSERT", "INSTEAD", "INTERSECT",
|
||||
"INTO", "IS", "ISNULL", "JOIN", "KEY", "LEFT", "LIKE", "LIMIT", "MATCH", "NATURAL",
|
||||
"NO", "NOT", "NOTNULL", "NULL", "OF", "OFFSET", "ON", "OR", "ORDER", "OUTER", "PLAN",
|
||||
"PRAGMA", "PRIMARY", "QUERY", "RAISE", "RECURSIVE", "REFERENCES", "REGEXP", "REINDEX",
|
||||
"RELEASE", "RENAME", "REPLACE", "RESTRICT", "RIGHT", "ROLLBACK", "ROW", "SAVEPOINT",
|
||||
"SELECT", "SET", "TABLE", "TEMP", "TEMPORARY", "THEN", "TO", "TRANSACTION", "TRIGGER",
|
||||
"UNION", "UNIQUE", "UPDATE", "USING", "VACUUM", "VALUES", "VIEW", "VIRTUAL", "WHEN",
|
||||
"WHERE", "WITH", "WITHOUT"
|
||||
]
|
||||
return True
|
||||
# if name.upper() in sqlite_keywords:
|
||||
# raise ValueError(f"'{name}' 是SQLite保留字,不建议使用,请更换名称")
|
@ -1,99 +1,99 @@
|
||||
import os
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .data import Database, LiteModel
|
||||
|
||||
DATA_PATH = "data/liteyuki"
|
||||
user_db: Database = Database(os.path.join(DATA_PATH, "users.ldb"))
|
||||
group_db: Database = Database(os.path.join(DATA_PATH, "groups.ldb"))
|
||||
plugin_db: Database = Database(os.path.join(DATA_PATH, "plugins.ldb"))
|
||||
common_db: Database = Database(os.path.join(DATA_PATH, "common.ldb"))
|
||||
|
||||
# 内存数据库,临时用于存储数据
|
||||
memory_database = {
|
||||
|
||||
}
|
||||
|
||||
|
||||
class User(LiteModel):
|
||||
TABLE_NAME: str = "user"
|
||||
user_id: str = Field(str(), alias="user_id")
|
||||
username: str = Field(str(), alias="username")
|
||||
profile: dict[str, str] = Field(dict(), alias="profile")
|
||||
enabled_plugins: list[str] = Field(list(), alias="enabled_plugins")
|
||||
disabled_plugins: list[str] = Field(list(), alias="disabled_plugins")
|
||||
|
||||
|
||||
class Group(LiteModel):
|
||||
TABLE_NAME: str = "group_chat"
|
||||
# Group是一个关键字,所以这里用GroupChat
|
||||
group_id: str = Field(str(), alias="group_id")
|
||||
group_name: str = Field(str(), alias="group_name")
|
||||
enabled_plugins: list[str] = Field([], alias="enabled_plugins")
|
||||
disabled_plugins: list[str] = Field([], alias="disabled_plugins")
|
||||
enable: bool = Field(True, alias="enable") # 群聊全局机器人是否启用
|
||||
config: dict = Field({}, alias="config")
|
||||
|
||||
|
||||
class InstalledPlugin(LiteModel):
|
||||
TABLE_NAME: str = "installed_plugin"
|
||||
module_name: str = Field(str(), alias="module_name")
|
||||
version: str = Field(str(), alias="version")
|
||||
|
||||
|
||||
class GlobalPlugin(LiteModel):
|
||||
TABLE_NAME: str = "global_plugin"
|
||||
liteyuki: bool = Field(True, alias="liteyuki") # 是否为LiteYuki插件
|
||||
module_name: str = Field(str(), alias="module_name")
|
||||
enabled: bool = Field(True, alias="enabled")
|
||||
|
||||
|
||||
class StoredConfig(LiteModel):
|
||||
TABLE_NAME: str = "stored_config"
|
||||
config: dict = {}
|
||||
|
||||
|
||||
class TempConfig(LiteModel):
|
||||
"""储存临时键值对的表"""
|
||||
TABLE_NAME: str = "temp_data"
|
||||
data: dict = {}
|
||||
|
||||
|
||||
|
||||
def auto_migrate():
|
||||
user_db.auto_migrate(User())
|
||||
group_db.auto_migrate(Group())
|
||||
plugin_db.auto_migrate(InstalledPlugin(), GlobalPlugin())
|
||||
common_db.auto_migrate(GlobalPlugin(), TempConfig())
|
||||
|
||||
|
||||
auto_migrate()
|
||||
|
||||
|
||||
def set_memory_data(key: str, value) -> None:
|
||||
"""
|
||||
设置内存数据库的数据,类似于redis
|
||||
Args:
|
||||
key:
|
||||
value:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return memory_database.update({
|
||||
key: value
|
||||
})
|
||||
|
||||
|
||||
def get_memory_data(key: str, default=None) -> any:
|
||||
"""
|
||||
获取内存数据库的数据,类似于redis
|
||||
Args:
|
||||
key:
|
||||
default:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return memory_database.get(key, default)
|
||||
import os
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .data import Database, LiteModel
|
||||
|
||||
DATA_PATH = "data/liteyuki"
|
||||
user_db: Database = Database(os.path.join(DATA_PATH, "users.ldb"))
|
||||
group_db: Database = Database(os.path.join(DATA_PATH, "groups.ldb"))
|
||||
plugin_db: Database = Database(os.path.join(DATA_PATH, "plugins.ldb"))
|
||||
common_db: Database = Database(os.path.join(DATA_PATH, "common.ldb"))
|
||||
|
||||
# 内存数据库,临时用于存储数据
|
||||
memory_database = {
|
||||
|
||||
}
|
||||
|
||||
|
||||
class User(LiteModel):
|
||||
TABLE_NAME: str = "user"
|
||||
user_id: str = Field(str(), alias="user_id")
|
||||
username: str = Field(str(), alias="username")
|
||||
profile: dict[str, str] = Field(dict(), alias="profile")
|
||||
enabled_plugins: list[str] = Field(list(), alias="enabled_plugins")
|
||||
disabled_plugins: list[str] = Field(list(), alias="disabled_plugins")
|
||||
|
||||
|
||||
class Group(LiteModel):
|
||||
TABLE_NAME: str = "group_chat"
|
||||
# Group是一个关键字,所以这里用GroupChat
|
||||
group_id: str = Field(str(), alias="group_id")
|
||||
group_name: str = Field(str(), alias="group_name")
|
||||
enabled_plugins: list[str] = Field([], alias="enabled_plugins")
|
||||
disabled_plugins: list[str] = Field([], alias="disabled_plugins")
|
||||
enable: bool = Field(True, alias="enable") # 群聊全局机器人是否启用
|
||||
config: dict = Field({}, alias="config")
|
||||
|
||||
|
||||
class InstalledPlugin(LiteModel):
|
||||
TABLE_NAME: str = "installed_plugin"
|
||||
module_name: str = Field(str(), alias="module_name")
|
||||
version: str = Field(str(), alias="version")
|
||||
|
||||
|
||||
class GlobalPlugin(LiteModel):
|
||||
TABLE_NAME: str = "global_plugin"
|
||||
liteyuki: bool = Field(True, alias="liteyuki") # 是否为LiteYuki插件
|
||||
module_name: str = Field(str(), alias="module_name")
|
||||
enabled: bool = Field(True, alias="enabled")
|
||||
|
||||
|
||||
class StoredConfig(LiteModel):
|
||||
TABLE_NAME: str = "stored_config"
|
||||
config: dict = {}
|
||||
|
||||
|
||||
class TempConfig(LiteModel):
|
||||
"""储存临时键值对的表"""
|
||||
TABLE_NAME: str = "temp_data"
|
||||
data: dict = {}
|
||||
|
||||
|
||||
|
||||
def auto_migrate():
|
||||
user_db.auto_migrate(User())
|
||||
group_db.auto_migrate(Group())
|
||||
plugin_db.auto_migrate(InstalledPlugin(), GlobalPlugin())
|
||||
common_db.auto_migrate(GlobalPlugin(), TempConfig())
|
||||
|
||||
|
||||
auto_migrate()
|
||||
|
||||
|
||||
def set_memory_data(key: str, value) -> None:
|
||||
"""
|
||||
设置内存数据库的数据,类似于redis
|
||||
Args:
|
||||
key:
|
||||
value:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return memory_database.update({
|
||||
key: value
|
||||
})
|
||||
|
||||
|
||||
def get_memory_data(key: str, default=None) -> any:
|
||||
"""
|
||||
获取内存数据库的数据,类似于redis
|
||||
Args:
|
||||
key:
|
||||
default:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return memory_database.get(key, default)
|
||||
|
@ -1,237 +1,237 @@
|
||||
"""
|
||||
语言模块,添加对多语言的支持
|
||||
"""
|
||||
|
||||
import json
|
||||
import locale
|
||||
import os
|
||||
from typing import Any, overload
|
||||
|
||||
import nonebot
|
||||
|
||||
from .config import config, get_config
|
||||
from .data_manager import User, user_db
|
||||
|
||||
_language_data = {
|
||||
"en": {
|
||||
"name": "English",
|
||||
}
|
||||
}
|
||||
|
||||
_user_lang = {"user_id": "zh-CN"}
|
||||
|
||||
|
||||
def load_from_lang(file_path: str, lang_code: str = None):
|
||||
"""
|
||||
从lang文件中加载语言数据,用于简单的文本键值对
|
||||
|
||||
Args:
|
||||
file_path: lang文件路径
|
||||
lang_code: 语言代码,如果为None则从文件名中获取
|
||||
"""
|
||||
try:
|
||||
if lang_code is None:
|
||||
lang_code = os.path.basename(file_path).split(".")[0]
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
data = {}
|
||||
for line in file:
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"): # 空行或注释
|
||||
continue
|
||||
key, value = line.split("=", 1)
|
||||
data[key.strip()] = value.strip()
|
||||
if lang_code not in _language_data:
|
||||
_language_data[lang_code] = {}
|
||||
_language_data[lang_code].update(data)
|
||||
nonebot.logger.debug(f"Loaded language data from {file_path}")
|
||||
except Exception as e:
|
||||
nonebot.logger.error(f"Failed to load language data from {file_path}: {e}")
|
||||
|
||||
|
||||
def load_from_json(file_path: str, lang_code: str = None):
|
||||
"""
|
||||
从json文件中加载语言数据,可以定义一些变量
|
||||
|
||||
Args:
|
||||
lang_code: 语言代码,如果为None则从文件名中获取
|
||||
file_path: json文件路径
|
||||
"""
|
||||
try:
|
||||
if lang_code is None:
|
||||
lang_code = os.path.basename(file_path).split(".")[0]
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
data = json.load(file)
|
||||
if lang_code not in _language_data:
|
||||
_language_data[lang_code] = {}
|
||||
_language_data[lang_code].update(data)
|
||||
nonebot.logger.debug(f"Loaded language data from {file_path}")
|
||||
except Exception as e:
|
||||
nonebot.logger.error(f"Failed to load language data from {file_path}: {e}")
|
||||
|
||||
|
||||
def load_from_dir(dir_path: str):
|
||||
"""
|
||||
从目录中加载语言数据
|
||||
|
||||
Args:
|
||||
dir_path: 目录路径
|
||||
"""
|
||||
for file in os.listdir(dir_path):
|
||||
try:
|
||||
file_path = os.path.join(dir_path, file)
|
||||
if os.path.isfile(file_path):
|
||||
if file.endswith(".lang"):
|
||||
load_from_lang(file_path)
|
||||
elif file.endswith(".json"):
|
||||
load_from_json(file_path)
|
||||
except Exception as e:
|
||||
nonebot.logger.error(f"Failed to load language data from {file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
def load_from_dict(data: dict, lang_code: str):
|
||||
"""
|
||||
从字典中加载语言数据
|
||||
|
||||
Args:
|
||||
lang_code: 语言代码
|
||||
data: 字典数据
|
||||
"""
|
||||
if lang_code not in _language_data:
|
||||
_language_data[lang_code] = {}
|
||||
_language_data[lang_code].update(data)
|
||||
|
||||
|
||||
class Language:
|
||||
# 三重fallback
|
||||
# 用户语言 > 默认语言/系统语言 > zh-CN
|
||||
def __init__(self, lang_code: str = None, fallback_lang_code: str = None):
|
||||
self.lang_code = lang_code
|
||||
|
||||
if self.lang_code is None:
|
||||
self.lang_code = get_default_lang_code()
|
||||
|
||||
self.fallback_lang_code = fallback_lang_code
|
||||
if self.fallback_lang_code is None:
|
||||
self.fallback_lang_code = config.get(
|
||||
"default_language", get_system_lang_code()
|
||||
)
|
||||
|
||||
def _get(self, item: str, *args, **kwargs) -> str | Any:
|
||||
"""
|
||||
获取当前语言文本,kwargs中的default参数为默认文本
|
||||
|
||||
**请不要重写本函数**
|
||||
|
||||
Args:
|
||||
item: 文本键
|
||||
*args: 格式化参数
|
||||
**kwargs: 格式化参数
|
||||
|
||||
Returns:
|
||||
str: 当前语言的文本
|
||||
|
||||
"""
|
||||
default = kwargs.pop("default", None)
|
||||
fallback = (self.lang_code, self.fallback_lang_code, "zh-CN")
|
||||
|
||||
for lang_code in fallback:
|
||||
if lang_code in _language_data and item in _language_data[lang_code]:
|
||||
trans: str = _language_data[lang_code][item]
|
||||
try:
|
||||
return trans.format(*args, **kwargs)
|
||||
except Exception as e:
|
||||
nonebot.logger.warning(f"Failed to format language data: {e}")
|
||||
return trans
|
||||
return default or item
|
||||
|
||||
def get(self, item: str, *args, **kwargs) -> str | Any:
|
||||
"""
|
||||
获取当前语言文本,kwargs中的default参数为默认文本
|
||||
Args:
|
||||
item: 文本键
|
||||
*args: 格式化参数
|
||||
**kwargs: 格式化参数
|
||||
|
||||
Returns:
|
||||
str: 当前语言的文本
|
||||
|
||||
"""
|
||||
return self._get(item, *args, **kwargs)
|
||||
|
||||
def get_many(self, *args: str, **kwargs) -> dict[str, str]:
|
||||
"""
|
||||
获取多个文本
|
||||
Args:
|
||||
*args: 文本键
|
||||
**kwargs: 文本键和默认文本
|
||||
|
||||
Returns:
|
||||
dict: 多个文本
|
||||
"""
|
||||
args_data = {item: self.get(item) for item in args}
|
||||
kwargs_data = {
|
||||
item: self.get(item, default=default) for item, default in kwargs.items()
|
||||
}
|
||||
args_data.update(kwargs_data)
|
||||
return args_data
|
||||
|
||||
|
||||
def change_user_lang(user_id: str, lang_code: str):
|
||||
"""
|
||||
修改用户的语言,同时储存到数据库和内存中
|
||||
"""
|
||||
user = user_db.where_one(
|
||||
User(), "user_id = ?", user_id, default=User(user_id=user_id)
|
||||
)
|
||||
user.profile["lang"] = lang_code
|
||||
user_db.save(user)
|
||||
_user_lang[user_id] = lang_code
|
||||
|
||||
|
||||
def get_user_lang(user_id: str) -> Language:
|
||||
"""
|
||||
获取用户的语言实例,优先从内存中获取
|
||||
"""
|
||||
user_id = str(user_id)
|
||||
|
||||
if user_id not in _user_lang:
|
||||
nonebot.logger.debug(f"Loading user language for {user_id}")
|
||||
user = user_db.where_one(
|
||||
User(),
|
||||
"user_id = ?",
|
||||
user_id,
|
||||
default=User(user_id=user_id, username="Unknown"),
|
||||
)
|
||||
lang_code = user.profile.get("lang", get_default_lang_code())
|
||||
_user_lang[user_id] = lang_code
|
||||
|
||||
return Language(_user_lang[user_id])
|
||||
|
||||
|
||||
def get_system_lang_code() -> str:
|
||||
"""
|
||||
获取系统语言代码
|
||||
"""
|
||||
return locale.getdefaultlocale()[0].replace("_", "-")
|
||||
|
||||
|
||||
def get_default_lang_code() -> str:
|
||||
"""
|
||||
获取默认语言代码,若没有设置则使用系统语言
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return get_config("default_language", default=get_system_lang_code())
|
||||
|
||||
|
||||
def get_all_lang() -> dict[str, str]:
|
||||
"""
|
||||
获取所有语言
|
||||
Returns
|
||||
{'en': 'English'}
|
||||
"""
|
||||
d = {}
|
||||
for key in _language_data:
|
||||
d[key] = _language_data[key].get("language.name", key)
|
||||
return d
|
||||
"""
|
||||
语言模块,添加对多语言的支持
|
||||
"""
|
||||
|
||||
import json
|
||||
import locale
|
||||
import os
|
||||
from typing import Any, overload
|
||||
|
||||
import nonebot
|
||||
|
||||
from .config import config, get_config
|
||||
from .data_manager import User, user_db
|
||||
|
||||
_language_data = {
|
||||
"en": {
|
||||
"name": "English",
|
||||
}
|
||||
}
|
||||
|
||||
_user_lang = {"user_id": "zh-CN"}
|
||||
|
||||
|
||||
def load_from_lang(file_path: str, lang_code: str = None):
|
||||
"""
|
||||
从lang文件中加载语言数据,用于简单的文本键值对
|
||||
|
||||
Args:
|
||||
file_path: lang文件路径
|
||||
lang_code: 语言代码,如果为None则从文件名中获取
|
||||
"""
|
||||
try:
|
||||
if lang_code is None:
|
||||
lang_code = os.path.basename(file_path).split(".")[0]
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
data = {}
|
||||
for line in file:
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"): # 空行或注释
|
||||
continue
|
||||
key, value = line.split("=", 1)
|
||||
data[key.strip()] = value.strip()
|
||||
if lang_code not in _language_data:
|
||||
_language_data[lang_code] = {}
|
||||
_language_data[lang_code].update(data)
|
||||
nonebot.logger.debug(f"Loaded language data from {file_path}")
|
||||
except Exception as e:
|
||||
nonebot.logger.error(f"Failed to load language data from {file_path}: {e}")
|
||||
|
||||
|
||||
def load_from_json(file_path: str, lang_code: str = None):
|
||||
"""
|
||||
从json文件中加载语言数据,可以定义一些变量
|
||||
|
||||
Args:
|
||||
lang_code: 语言代码,如果为None则从文件名中获取
|
||||
file_path: json文件路径
|
||||
"""
|
||||
try:
|
||||
if lang_code is None:
|
||||
lang_code = os.path.basename(file_path).split(".")[0]
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
data = json.load(file)
|
||||
if lang_code not in _language_data:
|
||||
_language_data[lang_code] = {}
|
||||
_language_data[lang_code].update(data)
|
||||
nonebot.logger.debug(f"Loaded language data from {file_path}")
|
||||
except Exception as e:
|
||||
nonebot.logger.error(f"Failed to load language data from {file_path}: {e}")
|
||||
|
||||
|
||||
def load_from_dir(dir_path: str):
|
||||
"""
|
||||
从目录中加载语言数据
|
||||
|
||||
Args:
|
||||
dir_path: 目录路径
|
||||
"""
|
||||
for file in os.listdir(dir_path):
|
||||
try:
|
||||
file_path = os.path.join(dir_path, file)
|
||||
if os.path.isfile(file_path):
|
||||
if file.endswith(".lang"):
|
||||
load_from_lang(file_path)
|
||||
elif file.endswith(".json"):
|
||||
load_from_json(file_path)
|
||||
except Exception as e:
|
||||
nonebot.logger.error(f"Failed to load language data from {file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
def load_from_dict(data: dict, lang_code: str):
|
||||
"""
|
||||
从字典中加载语言数据
|
||||
|
||||
Args:
|
||||
lang_code: 语言代码
|
||||
data: 字典数据
|
||||
"""
|
||||
if lang_code not in _language_data:
|
||||
_language_data[lang_code] = {}
|
||||
_language_data[lang_code].update(data)
|
||||
|
||||
|
||||
class Language:
|
||||
# 三重fallback
|
||||
# 用户语言 > 默认语言/系统语言 > zh-CN
|
||||
def __init__(self, lang_code: str = None, fallback_lang_code: str = None):
|
||||
self.lang_code = lang_code
|
||||
|
||||
if self.lang_code is None:
|
||||
self.lang_code = get_default_lang_code()
|
||||
|
||||
self.fallback_lang_code = fallback_lang_code
|
||||
if self.fallback_lang_code is None:
|
||||
self.fallback_lang_code = config.get(
|
||||
"default_language", get_system_lang_code()
|
||||
)
|
||||
|
||||
def _get(self, item: str, *args, **kwargs) -> str | Any:
|
||||
"""
|
||||
获取当前语言文本,kwargs中的default参数为默认文本
|
||||
|
||||
**请不要重写本函数**
|
||||
|
||||
Args:
|
||||
item: 文本键
|
||||
*args: 格式化参数
|
||||
**kwargs: 格式化参数
|
||||
|
||||
Returns:
|
||||
str: 当前语言的文本
|
||||
|
||||
"""
|
||||
default = kwargs.pop("default", None)
|
||||
fallback = (self.lang_code, self.fallback_lang_code, "zh-CN")
|
||||
|
||||
for lang_code in fallback:
|
||||
if lang_code in _language_data and item in _language_data[lang_code]:
|
||||
trans: str = _language_data[lang_code][item]
|
||||
try:
|
||||
return trans.format(*args, **kwargs)
|
||||
except Exception as e:
|
||||
nonebot.logger.warning(f"Failed to format language data: {e}")
|
||||
return trans
|
||||
return default or item
|
||||
|
||||
def get(self, item: str, *args, **kwargs) -> str | Any:
|
||||
"""
|
||||
获取当前语言文本,kwargs中的default参数为默认文本
|
||||
Args:
|
||||
item: 文本键
|
||||
*args: 格式化参数
|
||||
**kwargs: 格式化参数
|
||||
|
||||
Returns:
|
||||
str: 当前语言的文本
|
||||
|
||||
"""
|
||||
return self._get(item, *args, **kwargs)
|
||||
|
||||
def get_many(self, *args: str, **kwargs) -> dict[str, str]:
|
||||
"""
|
||||
获取多个文本
|
||||
Args:
|
||||
*args: 文本键
|
||||
**kwargs: 文本键和默认文本
|
||||
|
||||
Returns:
|
||||
dict: 多个文本
|
||||
"""
|
||||
args_data = {item: self.get(item) for item in args}
|
||||
kwargs_data = {
|
||||
item: self.get(item, default=default) for item, default in kwargs.items()
|
||||
}
|
||||
args_data.update(kwargs_data)
|
||||
return args_data
|
||||
|
||||
|
||||
def change_user_lang(user_id: str, lang_code: str):
|
||||
"""
|
||||
修改用户的语言,同时储存到数据库和内存中
|
||||
"""
|
||||
user = user_db.where_one(
|
||||
User(), "user_id = ?", user_id, default=User(user_id=user_id)
|
||||
)
|
||||
user.profile["lang"] = lang_code
|
||||
user_db.save(user)
|
||||
_user_lang[user_id] = lang_code
|
||||
|
||||
|
||||
def get_user_lang(user_id: str) -> Language:
|
||||
"""
|
||||
获取用户的语言实例,优先从内存中获取
|
||||
"""
|
||||
user_id = str(user_id)
|
||||
|
||||
if user_id not in _user_lang:
|
||||
nonebot.logger.debug(f"Loading user language for {user_id}")
|
||||
user = user_db.where_one(
|
||||
User(),
|
||||
"user_id = ?",
|
||||
user_id,
|
||||
default=User(user_id=user_id, username="Unknown"),
|
||||
)
|
||||
lang_code = user.profile.get("lang", get_default_lang_code())
|
||||
_user_lang[user_id] = lang_code
|
||||
|
||||
return Language(_user_lang[user_id])
|
||||
|
||||
|
||||
def get_system_lang_code() -> str:
|
||||
"""
|
||||
获取系统语言代码
|
||||
"""
|
||||
return locale.getdefaultlocale()[0].replace("_", "-")
|
||||
|
||||
|
||||
def get_default_lang_code() -> str:
|
||||
"""
|
||||
获取默认语言代码,若没有设置则使用系统语言
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return get_config("default_language", default=get_system_lang_code())
|
||||
|
||||
|
||||
def get_all_lang() -> dict[str, str]:
|
||||
"""
|
||||
获取所有语言
|
||||
Returns
|
||||
{'en': 'English'}
|
||||
"""
|
||||
d = {}
|
||||
for key in _language_data:
|
||||
d[key] = _language_data[key].get("language.name", key)
|
||||
return d
|
||||
|
@ -1,79 +1,79 @@
|
||||
import sys
|
||||
import loguru
|
||||
from typing import TYPE_CHECKING
|
||||
from .config import load_from_yaml
|
||||
from .language import Language, get_default_lang_code
|
||||
|
||||
logger = loguru.logger
|
||||
if TYPE_CHECKING:
|
||||
# avoid sphinx autodoc resolve annotation failed
|
||||
# because loguru module do not have `Logger` class actually
|
||||
from loguru import Record
|
||||
|
||||
|
||||
def default_filter(record: "Record"):
|
||||
"""默认的日志过滤器,根据 `config.log_level` 配置改变日志等级。"""
|
||||
log_level = record["extra"].get("nonebot_log_level", "INFO")
|
||||
levelno = logger.level(log_level).no if isinstance(log_level, str) else log_level
|
||||
return record["level"].no >= levelno
|
||||
|
||||
|
||||
# DEBUG日志格式
|
||||
debug_format: str = (
|
||||
"<c>{time:YYYY-MM-DD HH:mm:ss}</c> "
|
||||
"<lvl>[{level.icon}]</lvl> "
|
||||
"<c><{name}.{module}.{function}:{line}></c> "
|
||||
"{message}"
|
||||
)
|
||||
|
||||
# 默认日志格式
|
||||
default_format: str = (
|
||||
"<c>{time:MM-DD HH:mm:ss}</c> "
|
||||
"<lvl>[{level.icon}]</lvl> "
|
||||
"<c><{name}></c> "
|
||||
"{message}"
|
||||
)
|
||||
|
||||
|
||||
def get_format(level: str) -> str:
|
||||
if level == "DEBUG":
|
||||
return debug_format
|
||||
else:
|
||||
return default_format
|
||||
|
||||
|
||||
logger = loguru.logger.bind()
|
||||
|
||||
|
||||
def init_log():
|
||||
"""
|
||||
在语言加载完成后执行
|
||||
Returns:
|
||||
|
||||
"""
|
||||
global logger
|
||||
|
||||
config = load_from_yaml("config.yml")
|
||||
|
||||
logger.remove()
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
level=0,
|
||||
diagnose=False,
|
||||
filter=default_filter,
|
||||
format=get_format(config.get("log_level", "INFO")),
|
||||
)
|
||||
show_icon = config.get("log_icon", True)
|
||||
lang = Language(get_default_lang_code())
|
||||
|
||||
debug = lang.get("log.debug", default="==DEBUG")
|
||||
info = lang.get("log.info", default="===INFO")
|
||||
success = lang.get("log.success", default="SUCCESS")
|
||||
warning = lang.get("log.warning", default="WARNING")
|
||||
error = lang.get("log.error", default="==ERROR")
|
||||
|
||||
logger.level("DEBUG", color="<blue>", icon=f"{'🐛' if show_icon else ''}{debug}")
|
||||
logger.level("INFO", color="<normal>", icon=f"{'ℹ️' if show_icon else ''}{info}")
|
||||
logger.level("SUCCESS", color="<green>", icon=f"{'✅' if show_icon else ''}{success}")
|
||||
logger.level("WARNING", color="<yellow>", icon=f"{'⚠️' if show_icon else ''}{warning}")
|
||||
logger.level("ERROR", color="<red>", icon=f"{'⭕' if show_icon else ''}{error}")
|
||||
import sys
|
||||
import loguru
|
||||
from typing import TYPE_CHECKING
|
||||
from .config import load_from_yaml
|
||||
from .language import Language, get_default_lang_code
|
||||
|
||||
logger = loguru.logger
|
||||
if TYPE_CHECKING:
|
||||
# avoid sphinx autodoc resolve annotation failed
|
||||
# because loguru module do not have `Logger` class actually
|
||||
from loguru import Record
|
||||
|
||||
|
||||
def default_filter(record: "Record"):
|
||||
"""默认的日志过滤器,根据 `config.log_level` 配置改变日志等级。"""
|
||||
log_level = record["extra"].get("nonebot_log_level", "INFO")
|
||||
levelno = logger.level(log_level).no if isinstance(log_level, str) else log_level
|
||||
return record["level"].no >= levelno
|
||||
|
||||
|
||||
# DEBUG日志格式
|
||||
debug_format: str = (
|
||||
"<c>{time:YYYY-MM-DD HH:mm:ss}</c> "
|
||||
"<lvl>[{level.icon}]</lvl> "
|
||||
"<c><{name}.{module}.{function}:{line}></c> "
|
||||
"{message}"
|
||||
)
|
||||
|
||||
# 默认日志格式
|
||||
default_format: str = (
|
||||
"<c>{time:MM-DD HH:mm:ss}</c> "
|
||||
"<lvl>[{level.icon}]</lvl> "
|
||||
"<c><{name}></c> "
|
||||
"{message}"
|
||||
)
|
||||
|
||||
|
||||
def get_format(level: str) -> str:
|
||||
if level == "DEBUG":
|
||||
return debug_format
|
||||
else:
|
||||
return default_format
|
||||
|
||||
|
||||
logger = loguru.logger.bind()
|
||||
|
||||
|
||||
def init_log():
|
||||
"""
|
||||
在语言加载完成后执行
|
||||
Returns:
|
||||
|
||||
"""
|
||||
global logger
|
||||
|
||||
config = load_from_yaml("config.yml")
|
||||
|
||||
logger.remove()
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
level=0,
|
||||
diagnose=False,
|
||||
filter=default_filter,
|
||||
format=get_format(config.get("log_level", "INFO")),
|
||||
)
|
||||
show_icon = config.get("log_icon", True)
|
||||
lang = Language(get_default_lang_code())
|
||||
|
||||
debug = lang.get("log.debug", default="==DEBUG")
|
||||
info = lang.get("log.info", default="===INFO")
|
||||
success = lang.get("log.success", default="SUCCESS")
|
||||
warning = lang.get("log.warning", default="WARNING")
|
||||
error = lang.get("log.error", default="==ERROR")
|
||||
|
||||
logger.level("DEBUG", color="<blue>", icon=f"{'🐛' if show_icon else ''}{debug}")
|
||||
logger.level("INFO", color="<normal>", icon=f"{'ℹ️' if show_icon else ''}{info}")
|
||||
logger.level("SUCCESS", color="<green>", icon=f"{'✅' if show_icon else ''}{success}")
|
||||
logger.level("WARNING", color="<yellow>", icon=f"{'⚠️' if show_icon else ''}{warning}")
|
||||
logger.level("ERROR", color="<red>", icon=f"{'⭕' if show_icon else ''}{error}")
|
||||
|
@ -1,197 +1,197 @@
|
||||
"""
|
||||
liteyuki function是一种类似于mcfunction的函数,用于在liteyuki中实现一些功能,例如自定义指令等,也可与Python函数绑定
|
||||
使用 /function function_name *args **kwargs来调用
|
||||
例如 /function test/hello user_id=123456
|
||||
可以用于一些轻量级插件的编写,无需Python代码
|
||||
SnowyKami
|
||||
"""
|
||||
import asyncio
|
||||
import functools
|
||||
# cmd *args **kwargs
|
||||
# api api_name **kwargs
|
||||
import os
|
||||
from typing import Any, Awaitable, Callable, Coroutine
|
||||
|
||||
import nonebot
|
||||
from nonebot import Bot
|
||||
from nonebot.adapters.satori import bot
|
||||
from nonebot.internal.matcher import Matcher
|
||||
|
||||
ly_function_extensions = (
|
||||
"lyf",
|
||||
"lyfunction",
|
||||
"mcfunction"
|
||||
)
|
||||
|
||||
loaded_functions = dict()
|
||||
|
||||
|
||||
class LiteyukiFunction:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.functions: list[str] = list()
|
||||
self.bot: Bot = None
|
||||
self.kwargs_data = dict()
|
||||
self.args_data = list()
|
||||
self.matcher: Matcher = None
|
||||
self.end = False
|
||||
|
||||
self.sub_tasks: list[asyncio.Task] = list()
|
||||
|
||||
async def __call__(self, *args, **kwargs):
|
||||
self.kwargs_data.update(kwargs)
|
||||
self.args_data = list(set(self.args_data + list(args)))
|
||||
for i, cmd in enumerate(self.functions):
|
||||
r = await self.execute_line(cmd, i, *args, **kwargs)
|
||||
if r == 0:
|
||||
msg = f"End function {self.name} by line {i}"
|
||||
nonebot.logger.debug(msg)
|
||||
for task in self.sub_tasks:
|
||||
task.cancel(msg)
|
||||
return
|
||||
|
||||
def __str__(self):
|
||||
return f"LiteyukiFunction({self.name})"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
async def execute_line(self, cmd: str, line: int = 0, *args, **kwargs) -> Any:
|
||||
"""
|
||||
解析一行轻雪函数
|
||||
Args:
|
||||
cmd: 命令
|
||||
line: 行数
|
||||
Returns:
|
||||
"""
|
||||
|
||||
try:
|
||||
if "${" in cmd:
|
||||
# 此种情况下,{}内容不用管,只对${}内的内容进行format
|
||||
for i in range(len(cmd) - 1):
|
||||
if cmd[i] == "$" and cmd[i + 1] == "{":
|
||||
end = cmd.find("}", i)
|
||||
key = cmd[i + 2:end]
|
||||
cmd = cmd.replace(f"${{{key}}}", str(self.kwargs_data.get(key, "")))
|
||||
else:
|
||||
cmd = cmd.format(*self.args_data, **self.kwargs_data)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
no_head = cmd.split(" ", 1)[1] if len(cmd.split(" ")) > 1 else ""
|
||||
try:
|
||||
head, cmd_args, cmd_kwargs = self.get_args(cmd)
|
||||
except Exception as e:
|
||||
error_msg = f"Parsing error in {self.name} at line {line}: {e}"
|
||||
nonebot.logger.error(error_msg)
|
||||
await self.matcher.send(error_msg)
|
||||
return
|
||||
|
||||
if head == "var":
|
||||
# 变量定义
|
||||
self.kwargs_data.update(cmd_kwargs)
|
||||
|
||||
elif head == "cmd":
|
||||
# 在当前计算机上执行命令
|
||||
os.system(no_head)
|
||||
|
||||
elif head == "api":
|
||||
# 调用Bot API 需要Bot实例
|
||||
await self.bot.call_api(cmd_args[1], **cmd_kwargs)
|
||||
|
||||
elif head == "function":
|
||||
# 调用轻雪函数
|
||||
func = get_function(cmd_args[1])
|
||||
func.bot = self.bot
|
||||
func.matcher = self.matcher
|
||||
await func(*cmd_args[2:], **cmd_kwargs)
|
||||
|
||||
elif head == "sleep":
|
||||
# 等待一段时间
|
||||
await asyncio.sleep(float(cmd_args[1]))
|
||||
|
||||
elif head == "nohup":
|
||||
# 挂起运行
|
||||
task = asyncio.create_task(self.execute_line(no_head))
|
||||
self.sub_tasks.append(task)
|
||||
|
||||
elif head == "end":
|
||||
# 结束所有函数
|
||||
self.end = True
|
||||
return 0
|
||||
|
||||
|
||||
elif head == "await":
|
||||
# 等待所有协程执行完毕
|
||||
await asyncio.gather(*self.sub_tasks)
|
||||
|
||||
def get_args(self, line: str) -> tuple[str, tuple[str, ...], dict[str, Any]]:
|
||||
"""
|
||||
获取参数
|
||||
Args:
|
||||
line: 命令
|
||||
Returns:
|
||||
命令头 参数 关键字
|
||||
"""
|
||||
line = line.replace("\\=", "EQUAL_SIGN")
|
||||
head = ""
|
||||
args = list()
|
||||
kwargs = dict()
|
||||
for i, arg in enumerate(line.split(" ")):
|
||||
if "=" in arg:
|
||||
key, value = arg.split("=", 1)
|
||||
value = value.replace("EQUAL_SIGN", "=")
|
||||
try:
|
||||
value = eval(value)
|
||||
except:
|
||||
value = self.kwargs_data.get(value, value)
|
||||
kwargs[key] = value
|
||||
else:
|
||||
if i == 0:
|
||||
head = arg
|
||||
args.append(arg)
|
||||
return head, tuple(args), kwargs
|
||||
|
||||
|
||||
def get_function(name: str) -> LiteyukiFunction | None:
|
||||
"""
|
||||
获取一个轻雪函数
|
||||
Args:
|
||||
name: 函数名
|
||||
Returns:
|
||||
"""
|
||||
return loaded_functions.get(name)
|
||||
|
||||
|
||||
def load_from_dir(path: str):
|
||||
"""
|
||||
从目录及其子目录中递归加载所有轻雪函数,类似mcfunction
|
||||
|
||||
Args:
|
||||
path: 目录路径
|
||||
"""
|
||||
for f in os.listdir(path):
|
||||
f = os.path.join(path, f)
|
||||
if os.path.isfile(f):
|
||||
if f.endswith(ly_function_extensions):
|
||||
load_from_file(f)
|
||||
if os.path.isdir(f):
|
||||
load_from_dir(f)
|
||||
|
||||
|
||||
def load_from_file(path: str):
|
||||
"""
|
||||
从文件中加载轻雪函数
|
||||
Args:
|
||||
path:
|
||||
Returns:
|
||||
"""
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
name = ".".join(os.path.basename(path).split(".")[:-1])
|
||||
func = LiteyukiFunction(name)
|
||||
for i, line in enumerate(f.read().split("\n")):
|
||||
if line.startswith("#") or line.strip() == "":
|
||||
continue
|
||||
func.functions.append(line)
|
||||
loaded_functions[name] = func
|
||||
nonebot.logger.debug(f"Loaded function {name}")
|
||||
"""
|
||||
liteyuki function是一种类似于mcfunction的函数,用于在liteyuki中实现一些功能,例如自定义指令等,也可与Python函数绑定
|
||||
使用 /function function_name *args **kwargs来调用
|
||||
例如 /function test/hello user_id=123456
|
||||
可以用于一些轻量级插件的编写,无需Python代码
|
||||
SnowyKami
|
||||
"""
|
||||
import asyncio
|
||||
import functools
|
||||
# cmd *args **kwargs
|
||||
# api api_name **kwargs
|
||||
import os
|
||||
from typing import Any, Awaitable, Callable, Coroutine
|
||||
|
||||
import nonebot
|
||||
from nonebot import Bot
|
||||
from nonebot.adapters.satori import bot
|
||||
from nonebot.internal.matcher import Matcher
|
||||
|
||||
ly_function_extensions = (
|
||||
"lyf",
|
||||
"lyfunction",
|
||||
"mcfunction"
|
||||
)
|
||||
|
||||
loaded_functions = dict()
|
||||
|
||||
|
||||
class LiteyukiFunction:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.functions: list[str] = list()
|
||||
self.bot: Bot = None
|
||||
self.kwargs_data = dict()
|
||||
self.args_data = list()
|
||||
self.matcher: Matcher = None
|
||||
self.end = False
|
||||
|
||||
self.sub_tasks: list[asyncio.Task] = list()
|
||||
|
||||
async def __call__(self, *args, **kwargs):
|
||||
self.kwargs_data.update(kwargs)
|
||||
self.args_data = list(set(self.args_data + list(args)))
|
||||
for i, cmd in enumerate(self.functions):
|
||||
r = await self.execute_line(cmd, i, *args, **kwargs)
|
||||
if r == 0:
|
||||
msg = f"End function {self.name} by line {i}"
|
||||
nonebot.logger.debug(msg)
|
||||
for task in self.sub_tasks:
|
||||
task.cancel(msg)
|
||||
return
|
||||
|
||||
def __str__(self):
|
||||
return f"LiteyukiFunction({self.name})"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
async def execute_line(self, cmd: str, line: int = 0, *args, **kwargs) -> Any:
|
||||
"""
|
||||
解析一行轻雪函数
|
||||
Args:
|
||||
cmd: 命令
|
||||
line: 行数
|
||||
Returns:
|
||||
"""
|
||||
|
||||
try:
|
||||
if "${" in cmd:
|
||||
# 此种情况下,{}内容不用管,只对${}内的内容进行format
|
||||
for i in range(len(cmd) - 1):
|
||||
if cmd[i] == "$" and cmd[i + 1] == "{":
|
||||
end = cmd.find("}", i)
|
||||
key = cmd[i + 2:end]
|
||||
cmd = cmd.replace(f"${{{key}}}", str(self.kwargs_data.get(key, "")))
|
||||
else:
|
||||
cmd = cmd.format(*self.args_data, **self.kwargs_data)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
no_head = cmd.split(" ", 1)[1] if len(cmd.split(" ")) > 1 else ""
|
||||
try:
|
||||
head, cmd_args, cmd_kwargs = self.get_args(cmd)
|
||||
except Exception as e:
|
||||
error_msg = f"Parsing error in {self.name} at line {line}: {e}"
|
||||
nonebot.logger.error(error_msg)
|
||||
await self.matcher.send(error_msg)
|
||||
return
|
||||
|
||||
if head == "var":
|
||||
# 变量定义
|
||||
self.kwargs_data.update(cmd_kwargs)
|
||||
|
||||
elif head == "cmd":
|
||||
# 在当前计算机上执行命令
|
||||
os.system(no_head)
|
||||
|
||||
elif head == "api":
|
||||
# 调用Bot API 需要Bot实例
|
||||
await self.bot.call_api(cmd_args[1], **cmd_kwargs)
|
||||
|
||||
elif head == "function":
|
||||
# 调用轻雪函数
|
||||
func = get_function(cmd_args[1])
|
||||
func.bot = self.bot
|
||||
func.matcher = self.matcher
|
||||
await func(*cmd_args[2:], **cmd_kwargs)
|
||||
|
||||
elif head == "sleep":
|
||||
# 等待一段时间
|
||||
await asyncio.sleep(float(cmd_args[1]))
|
||||
|
||||
elif head == "nohup":
|
||||
# 挂起运行
|
||||
task = asyncio.create_task(self.execute_line(no_head))
|
||||
self.sub_tasks.append(task)
|
||||
|
||||
elif head == "end":
|
||||
# 结束所有函数
|
||||
self.end = True
|
||||
return 0
|
||||
|
||||
|
||||
elif head == "await":
|
||||
# 等待所有协程执行完毕
|
||||
await asyncio.gather(*self.sub_tasks)
|
||||
|
||||
def get_args(self, line: str) -> tuple[str, tuple[str, ...], dict[str, Any]]:
|
||||
"""
|
||||
获取参数
|
||||
Args:
|
||||
line: 命令
|
||||
Returns:
|
||||
命令头 参数 关键字
|
||||
"""
|
||||
line = line.replace("\\=", "EQUAL_SIGN")
|
||||
head = ""
|
||||
args = list()
|
||||
kwargs = dict()
|
||||
for i, arg in enumerate(line.split(" ")):
|
||||
if "=" in arg:
|
||||
key, value = arg.split("=", 1)
|
||||
value = value.replace("EQUAL_SIGN", "=")
|
||||
try:
|
||||
value = eval(value)
|
||||
except:
|
||||
value = self.kwargs_data.get(value, value)
|
||||
kwargs[key] = value
|
||||
else:
|
||||
if i == 0:
|
||||
head = arg
|
||||
args.append(arg)
|
||||
return head, tuple(args), kwargs
|
||||
|
||||
|
||||
def get_function(name: str) -> LiteyukiFunction | None:
|
||||
"""
|
||||
获取一个轻雪函数
|
||||
Args:
|
||||
name: 函数名
|
||||
Returns:
|
||||
"""
|
||||
return loaded_functions.get(name)
|
||||
|
||||
|
||||
def load_from_dir(path: str):
|
||||
"""
|
||||
从目录及其子目录中递归加载所有轻雪函数,类似mcfunction
|
||||
|
||||
Args:
|
||||
path: 目录路径
|
||||
"""
|
||||
for f in os.listdir(path):
|
||||
f = os.path.join(path, f)
|
||||
if os.path.isfile(f):
|
||||
if f.endswith(ly_function_extensions):
|
||||
load_from_file(f)
|
||||
if os.path.isdir(f):
|
||||
load_from_dir(f)
|
||||
|
||||
|
||||
def load_from_file(path: str):
|
||||
"""
|
||||
从文件中加载轻雪函数
|
||||
Args:
|
||||
path:
|
||||
Returns:
|
||||
"""
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
name = ".".join(os.path.basename(path).split(".")[:-1])
|
||||
func = LiteyukiFunction(name)
|
||||
for i, line in enumerate(f.read().split("\n")):
|
||||
if line.startswith("#") or line.strip() == "":
|
||||
continue
|
||||
func.functions.append(line)
|
||||
loaded_functions[name] = func
|
||||
nonebot.logger.debug(f"Loaded function {name}")
|
||||
|
@ -1,8 +1,8 @@
|
||||
from nonebot.adapters.onebot import v11, v12
|
||||
from nonebot.adapters import satori
|
||||
|
||||
T_Bot = v11.Bot | v12.Bot | satori.Bot
|
||||
T_GroupMessageEvent = v11.GroupMessageEvent | v12.GroupMessageEvent
|
||||
T_PrivateMessageEvent = v11.PrivateMessageEvent | v12.PrivateMessageEvent
|
||||
T_MessageEvent = v11.MessageEvent | v12.MessageEvent | satori.MessageEvent
|
||||
T_Message = v11.Message | v12.Message | satori.Message
|
||||
from nonebot.adapters.onebot import v11, v12
|
||||
from nonebot.adapters import satori
|
||||
|
||||
T_Bot = v11.Bot | v12.Bot | satori.Bot
|
||||
T_GroupMessageEvent = v11.GroupMessageEvent | v12.GroupMessageEvent
|
||||
T_PrivateMessageEvent = v11.PrivateMessageEvent | v12.PrivateMessageEvent
|
||||
T_MessageEvent = v11.MessageEvent | v12.MessageEvent | satori.MessageEvent
|
||||
T_Message = v11.Message | v12.Message | satori.Message
|
||||
|
@ -1,5 +1,5 @@
|
||||
from nonebot.adapters.onebot import v11
|
||||
|
||||
GROUP_ADMIN = v11.GROUP_ADMIN
|
||||
GROUP_OWNER = v11.GROUP_OWNER
|
||||
|
||||
from nonebot.adapters.onebot import v11
|
||||
|
||||
GROUP_ADMIN = v11.GROUP_ADMIN
|
||||
GROUP_OWNER = v11.GROUP_OWNER
|
||||
|
||||
|
@ -1,355 +1,355 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
from typing import Any
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
import nonebot
|
||||
import yaml
|
||||
|
||||
from .data import LiteModel
|
||||
from .language import Language, get_default_lang_code
|
||||
from .ly_function import loaded_functions
|
||||
|
||||
_loaded_resource_packs: list["ResourceMetadata"] = [] # 按照加载顺序排序
|
||||
temp_resource_root = Path("data/liteyuki/resources")
|
||||
temp_extract_root = Path("data/liteyuki/temp")
|
||||
lang = Language(get_default_lang_code())
|
||||
|
||||
|
||||
|
||||
|
||||
class ResourceMetadata(LiteModel):
|
||||
name: str = "Unknown"
|
||||
version: str = "0.0.1"
|
||||
description: str = "Unknown"
|
||||
path: str = ""
|
||||
folder: str = ""
|
||||
|
||||
|
||||
def load_resource_from_dir(path: str):
|
||||
"""
|
||||
把资源包按照文件相对路径复制到运行临时文件夹data/liteyuki/resources
|
||||
Args:
|
||||
path: 资源文件夹
|
||||
Returns:
|
||||
"""
|
||||
if os.path.exists(os.path.join(path, "metadata.yml")):
|
||||
with open(os.path.join(path, "metadata.yml"), "r", encoding="utf-8") as f:
|
||||
metadata = yaml.safe_load(f)
|
||||
elif os.path.isfile(path) and path.endswith(".zip"):
|
||||
# zip文件
|
||||
# 临时解压并读取metadata.yml
|
||||
with zipfile.ZipFile(path, "r") as zip_ref:
|
||||
# 解压至临时目录 data/liteyuki/temp/{pack_name}.zip
|
||||
zip_ref.extractall(os.path.join(temp_extract_root, os.path.basename(path)))
|
||||
with zip_ref.open("metadata.yml") as f:
|
||||
metadata = yaml.safe_load(f)
|
||||
path = os.path.join(temp_extract_root, os.path.basename(path))
|
||||
else:
|
||||
# 没有metadata.yml文件,不是一个资源包
|
||||
return
|
||||
for root, dirs, files in os.walk(path):
|
||||
for file in files:
|
||||
relative_path = os.path.relpath(os.path.join(root, file), path)
|
||||
copy_file(
|
||||
os.path.join(root, file),
|
||||
os.path.join(temp_resource_root, relative_path),
|
||||
)
|
||||
metadata["path"] = path
|
||||
metadata["folder"] = os.path.basename(path)
|
||||
|
||||
if os.path.exists(os.path.join(path, "lang")):
|
||||
# 加载语言
|
||||
from src.utils.base.language import load_from_dir
|
||||
|
||||
load_from_dir(os.path.join(path, "lang"))
|
||||
|
||||
if os.path.exists(os.path.join(path, "functions")):
|
||||
# 加载功能
|
||||
from src.utils.base.ly_function import load_from_dir
|
||||
|
||||
load_from_dir(os.path.join(path, "functions"))
|
||||
|
||||
if os.path.exists(os.path.join(path, "word_bank")):
|
||||
# 加载词库
|
||||
from src.utils.base.word_bank import load_from_dir
|
||||
|
||||
load_from_dir(os.path.join(path, "word_bank"))
|
||||
|
||||
_loaded_resource_packs.insert(0, ResourceMetadata(**metadata))
|
||||
|
||||
|
||||
def get_path(
|
||||
path: os.PathLike[str,] | Path | str,
|
||||
abs_path: bool = True,
|
||||
default: Any = None,
|
||||
debug: bool = False,
|
||||
) -> str | Any:
|
||||
"""
|
||||
获取资源包中的路径,且该路径必须存在
|
||||
Args:
|
||||
path: 相对路径
|
||||
abs_path: 是否返回绝对路径
|
||||
default: 默认解,当该路径不存在时使用
|
||||
debug: 启用调试,每次都会先重载资源
|
||||
Returns: 所需求之路径
|
||||
"""
|
||||
if debug:
|
||||
nonebot.logger.debug("Resource path debug enabled, reloading")
|
||||
load_resources()
|
||||
resource_relative_path = temp_resource_root / path
|
||||
if resource_relative_path.exists():
|
||||
return str(
|
||||
resource_relative_path.resolve() if abs_path else resource_relative_path
|
||||
)
|
||||
else:
|
||||
return default
|
||||
|
||||
|
||||
def get_resource_path(
|
||||
path: os.PathLike[str,] | Path | str,
|
||||
abs_path: bool = True,
|
||||
only_exist: bool = False,
|
||||
default: Any = None,
|
||||
debug: bool = False,
|
||||
) -> Path:
|
||||
"""
|
||||
获取资源包中的路径
|
||||
Args:
|
||||
path: 相对路径
|
||||
abs_path: 是否返回绝对路径
|
||||
only_exist: 检查该路径是否存在
|
||||
default: [当 `only_exist` 为 **真** 时启用]默认解,当该路径不存在时使用
|
||||
debug: 启用调试,每次都会先重载资源
|
||||
Returns: 所需求之路径
|
||||
"""
|
||||
if debug:
|
||||
nonebot.logger.debug("Resource path debug enabled, reloading")
|
||||
load_resources()
|
||||
resource_relative_path = (
|
||||
(temp_resource_root / path).resolve()
|
||||
if abs_path
|
||||
else (temp_resource_root / path)
|
||||
)
|
||||
if only_exist:
|
||||
if resource_relative_path.exists():
|
||||
return resource_relative_path
|
||||
else:
|
||||
return default
|
||||
else:
|
||||
return resource_relative_path
|
||||
|
||||
|
||||
def get_files(
|
||||
path: os.PathLike[str,] | Path | str, abs_path: bool = False
|
||||
) -> list[str]:
|
||||
"""
|
||||
获取资源包中一个目录的所有内容
|
||||
Args:
|
||||
path: 该目录的相对路径
|
||||
abs_path: 是否返回绝对路径
|
||||
Returns: 目录内容路径所构成之列表
|
||||
"""
|
||||
resource_relative_path = temp_resource_root / path
|
||||
if resource_relative_path.exists():
|
||||
return [
|
||||
(
|
||||
str((resource_relative_path / file_).resolve())
|
||||
if abs_path
|
||||
else str((resource_relative_path / file_))
|
||||
)
|
||||
for file_ in os.listdir(resource_relative_path)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def get_resource_files(
|
||||
path: os.PathLike[str,] | Path | str, abs_path: bool = False
|
||||
) -> list[Path]:
|
||||
"""
|
||||
获取资源包中一个目录的所有内容
|
||||
Args:
|
||||
path: 该目录的相对路径
|
||||
abs_path: 是否返回绝对路径
|
||||
Returns: 目录内容路径所构成之列表
|
||||
"""
|
||||
resource_relative_path = temp_resource_root / path
|
||||
if resource_relative_path.exists():
|
||||
return [
|
||||
(
|
||||
(resource_relative_path / file_).resolve()
|
||||
if abs_path
|
||||
else (resource_relative_path / file_)
|
||||
)
|
||||
for file_ in os.listdir(resource_relative_path)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def get_loaded_resource_packs() -> list[ResourceMetadata]:
|
||||
"""
|
||||
获取已加载的资源包,优先级从前到后
|
||||
Returns: 资源包列表
|
||||
"""
|
||||
return _loaded_resource_packs
|
||||
|
||||
|
||||
def copy_file(src, dst):
|
||||
# 获取目标文件的目录
|
||||
dst_dir = os.path.dirname(dst)
|
||||
# 如果目标目录不存在,创建它
|
||||
if not os.path.exists(dst_dir):
|
||||
os.makedirs(dst_dir)
|
||||
# 复制文件
|
||||
shutil.copy(src, dst)
|
||||
|
||||
|
||||
def load_resources():
|
||||
"""用于外部主程序调用的资源加载函数
|
||||
Returns:
|
||||
"""
|
||||
# 加载默认资源和语言
|
||||
# 清空临时资源包路径data/liteyuki/resources
|
||||
_loaded_resource_packs.clear()
|
||||
loaded_functions.clear()
|
||||
if os.path.exists(temp_resource_root):
|
||||
shutil.rmtree(temp_resource_root)
|
||||
os.makedirs(temp_resource_root, exist_ok=True)
|
||||
|
||||
# 加载内置资源
|
||||
standard_resources_path = "src/resources"
|
||||
for resource_dir in os.listdir(standard_resources_path):
|
||||
load_resource_from_dir(os.path.join(standard_resources_path, resource_dir))
|
||||
|
||||
# 加载其他资源包
|
||||
if not os.path.exists("resources"):
|
||||
os.makedirs("resources", exist_ok=True)
|
||||
|
||||
if not os.path.exists("resources/index.json"):
|
||||
json.dump([], open("resources/index.json", "w", encoding="utf-8"))
|
||||
|
||||
resource_index: list[str] = json.load(
|
||||
open("resources/index.json", "r", encoding="utf-8")
|
||||
)
|
||||
resource_index.reverse() # 优先级高的后加载,但是排在前面
|
||||
for resource in resource_index:
|
||||
load_resource_from_dir(os.path.join("resources", resource))
|
||||
|
||||
|
||||
def check_status(name: str) -> bool:
|
||||
"""
|
||||
检查资源包是否已加载
|
||||
Args:
|
||||
name: 资源包名称,文件夹名
|
||||
Returns: 是否已加载
|
||||
"""
|
||||
return name in [rp.folder for rp in get_loaded_resource_packs()]
|
||||
|
||||
|
||||
def check_exist(name: str) -> bool:
|
||||
"""
|
||||
检查资源包文件夹是否存在于resources文件夹
|
||||
Args:
|
||||
name: 资源包名称,文件夹名
|
||||
Returns: 是否存在
|
||||
"""
|
||||
path = os.path.join("resources", name)
|
||||
return os.path.exists(os.path.join(path, "metadata.yml")) or (
|
||||
os.path.isfile(path) and name.endswith(".zip")
|
||||
)
|
||||
|
||||
|
||||
def add_resource_pack(name: str) -> bool:
|
||||
"""
|
||||
添加资源包,该操作仅修改index.json文件,不会加载资源包,要生效请重载资源
|
||||
Args:
|
||||
name: 资源包名称,文件夹名
|
||||
Returns:
|
||||
"""
|
||||
if check_exist(name):
|
||||
old_index: list[str] = json.load(
|
||||
open("resources/index.json", "r", encoding="utf-8")
|
||||
)
|
||||
if name not in old_index:
|
||||
old_index.append(name)
|
||||
json.dump(old_index, open("resources/index.json", "w", encoding="utf-8"))
|
||||
load_resource_from_dir(os.path.join("resources", name))
|
||||
return True
|
||||
else:
|
||||
nonebot.logger.warning(lang.get("liteyuki.resource_loaded", name=name))
|
||||
return False
|
||||
else:
|
||||
nonebot.logger.warning(lang.get("liteyuki.resource_not_exist", name=name))
|
||||
return False
|
||||
|
||||
|
||||
def remove_resource_pack(name: str) -> bool:
|
||||
"""
|
||||
移除资源包,该操作仅修改加载索引,要生效请重载资源
|
||||
Args:
|
||||
name: 资源包名称,文件夹名
|
||||
Returns:
|
||||
"""
|
||||
if check_exist(name):
|
||||
old_index: list[str] = json.load(
|
||||
open("resources/index.json", "r", encoding="utf-8")
|
||||
)
|
||||
if name in old_index:
|
||||
old_index.remove(name)
|
||||
json.dump(old_index, open("resources/index.json", "w", encoding="utf-8"))
|
||||
return True
|
||||
else:
|
||||
nonebot.logger.warning(lang.get("liteyuki.resource_not_loaded", name=name))
|
||||
return False
|
||||
else:
|
||||
nonebot.logger.warning(lang.get("liteyuki.resource_not_exist", name=name))
|
||||
return False
|
||||
|
||||
|
||||
def change_priority(name: str, delta: int) -> bool:
|
||||
"""
|
||||
修改资源包优先级
|
||||
Args:
|
||||
name: 资源包名称,文件夹名
|
||||
delta: 优先级变化,正数表示后移,负数表示前移,0表示移到最前
|
||||
Returns:
|
||||
"""
|
||||
# 正数表示前移,负数表示后移
|
||||
old_resource_list: list[str] = json.load(
|
||||
open("resources/index.json", "r", encoding="utf-8")
|
||||
)
|
||||
new_resource_list = old_resource_list.copy()
|
||||
if name in old_resource_list:
|
||||
index = old_resource_list.index(name)
|
||||
if 0 <= index + delta < len(old_resource_list):
|
||||
new_index = index + delta
|
||||
new_resource_list.remove(name)
|
||||
new_resource_list.insert(new_index, name)
|
||||
json.dump(
|
||||
new_resource_list, open("resources/index.json", "w", encoding="utf-8")
|
||||
)
|
||||
return True
|
||||
else:
|
||||
nonebot.logger.warning("Priority change failed, out of range")
|
||||
return False
|
||||
else:
|
||||
nonebot.logger.debug("Priority change failed, resource not loaded")
|
||||
return False
|
||||
|
||||
|
||||
def get_resource_metadata(name: str) -> ResourceMetadata:
|
||||
"""
|
||||
获取资源包元数据
|
||||
Args:
|
||||
name: 资源包名称,文件夹名
|
||||
Returns:
|
||||
"""
|
||||
for rp in get_loaded_resource_packs():
|
||||
if rp.folder == name:
|
||||
return rp
|
||||
return ResourceMetadata()
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
from typing import Any
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
import nonebot
|
||||
import yaml
|
||||
|
||||
from .data import LiteModel
|
||||
from .language import Language, get_default_lang_code
|
||||
from .ly_function import loaded_functions
|
||||
|
||||
_loaded_resource_packs: list["ResourceMetadata"] = [] # 按照加载顺序排序
|
||||
temp_resource_root = Path("data/liteyuki/resources")
|
||||
temp_extract_root = Path("data/liteyuki/temp")
|
||||
lang = Language(get_default_lang_code())
|
||||
|
||||
|
||||
|
||||
|
||||
class ResourceMetadata(LiteModel):
|
||||
name: str = "Unknown"
|
||||
version: str = "0.0.1"
|
||||
description: str = "Unknown"
|
||||
path: str = ""
|
||||
folder: str = ""
|
||||
|
||||
|
||||
def load_resource_from_dir(path: str):
|
||||
"""
|
||||
把资源包按照文件相对路径复制到运行临时文件夹data/liteyuki/resources
|
||||
Args:
|
||||
path: 资源文件夹
|
||||
Returns:
|
||||
"""
|
||||
if os.path.exists(os.path.join(path, "metadata.yml")):
|
||||
with open(os.path.join(path, "metadata.yml"), "r", encoding="utf-8") as f:
|
||||
metadata = yaml.safe_load(f)
|
||||
elif os.path.isfile(path) and path.endswith(".zip"):
|
||||
# zip文件
|
||||
# 临时解压并读取metadata.yml
|
||||
with zipfile.ZipFile(path, "r") as zip_ref:
|
||||
# 解压至临时目录 data/liteyuki/temp/{pack_name}.zip
|
||||
zip_ref.extractall(os.path.join(temp_extract_root, os.path.basename(path)))
|
||||
with zip_ref.open("metadata.yml") as f:
|
||||
metadata = yaml.safe_load(f)
|
||||
path = os.path.join(temp_extract_root, os.path.basename(path))
|
||||
else:
|
||||
# 没有metadata.yml文件,不是一个资源包
|
||||
return
|
||||
for root, dirs, files in os.walk(path):
|
||||
for file in files:
|
||||
relative_path = os.path.relpath(os.path.join(root, file), path)
|
||||
copy_file(
|
||||
os.path.join(root, file),
|
||||
os.path.join(temp_resource_root, relative_path),
|
||||
)
|
||||
metadata["path"] = path
|
||||
metadata["folder"] = os.path.basename(path)
|
||||
|
||||
if os.path.exists(os.path.join(path, "lang")):
|
||||
# 加载语言
|
||||
from src.utils.base.language import load_from_dir
|
||||
|
||||
load_from_dir(os.path.join(path, "lang"))
|
||||
|
||||
if os.path.exists(os.path.join(path, "functions")):
|
||||
# 加载功能
|
||||
from src.utils.base.ly_function import load_from_dir
|
||||
|
||||
load_from_dir(os.path.join(path, "functions"))
|
||||
|
||||
if os.path.exists(os.path.join(path, "word_bank")):
|
||||
# 加载词库
|
||||
from src.utils.base.word_bank import load_from_dir
|
||||
|
||||
load_from_dir(os.path.join(path, "word_bank"))
|
||||
|
||||
_loaded_resource_packs.insert(0, ResourceMetadata(**metadata))
|
||||
|
||||
|
||||
def get_path(
|
||||
path: os.PathLike[str,] | Path | str,
|
||||
abs_path: bool = True,
|
||||
default: Any = None,
|
||||
debug: bool = False,
|
||||
) -> str | Any:
|
||||
"""
|
||||
获取资源包中的路径,且该路径必须存在
|
||||
Args:
|
||||
path: 相对路径
|
||||
abs_path: 是否返回绝对路径
|
||||
default: 默认解,当该路径不存在时使用
|
||||
debug: 启用调试,每次都会先重载资源
|
||||
Returns: 所需求之路径
|
||||
"""
|
||||
if debug:
|
||||
nonebot.logger.debug("Resource path debug enabled, reloading")
|
||||
load_resources()
|
||||
resource_relative_path = temp_resource_root / path
|
||||
if resource_relative_path.exists():
|
||||
return str(
|
||||
resource_relative_path.resolve() if abs_path else resource_relative_path
|
||||
)
|
||||
else:
|
||||
return default
|
||||
|
||||
|
||||
def get_resource_path(
|
||||
path: os.PathLike[str,] | Path | str,
|
||||
abs_path: bool = True,
|
||||
only_exist: bool = False,
|
||||
default: Any = None,
|
||||
debug: bool = False,
|
||||
) -> Path:
|
||||
"""
|
||||
获取资源包中的路径
|
||||
Args:
|
||||
path: 相对路径
|
||||
abs_path: 是否返回绝对路径
|
||||
only_exist: 检查该路径是否存在
|
||||
default: [当 `only_exist` 为 **真** 时启用]默认解,当该路径不存在时使用
|
||||
debug: 启用调试,每次都会先重载资源
|
||||
Returns: 所需求之路径
|
||||
"""
|
||||
if debug:
|
||||
nonebot.logger.debug("Resource path debug enabled, reloading")
|
||||
load_resources()
|
||||
resource_relative_path = (
|
||||
(temp_resource_root / path).resolve()
|
||||
if abs_path
|
||||
else (temp_resource_root / path)
|
||||
)
|
||||
if only_exist:
|
||||
if resource_relative_path.exists():
|
||||
return resource_relative_path
|
||||
else:
|
||||
return default
|
||||
else:
|
||||
return resource_relative_path
|
||||
|
||||
|
||||
def get_files(
|
||||
path: os.PathLike[str,] | Path | str, abs_path: bool = False
|
||||
) -> list[str]:
|
||||
"""
|
||||
获取资源包中一个目录的所有内容
|
||||
Args:
|
||||
path: 该目录的相对路径
|
||||
abs_path: 是否返回绝对路径
|
||||
Returns: 目录内容路径所构成之列表
|
||||
"""
|
||||
resource_relative_path = temp_resource_root / path
|
||||
if resource_relative_path.exists():
|
||||
return [
|
||||
(
|
||||
str((resource_relative_path / file_).resolve())
|
||||
if abs_path
|
||||
else str((resource_relative_path / file_))
|
||||
)
|
||||
for file_ in os.listdir(resource_relative_path)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def get_resource_files(
|
||||
path: os.PathLike[str,] | Path | str, abs_path: bool = False
|
||||
) -> list[Path]:
|
||||
"""
|
||||
获取资源包中一个目录的所有内容
|
||||
Args:
|
||||
path: 该目录的相对路径
|
||||
abs_path: 是否返回绝对路径
|
||||
Returns: 目录内容路径所构成之列表
|
||||
"""
|
||||
resource_relative_path = temp_resource_root / path
|
||||
if resource_relative_path.exists():
|
||||
return [
|
||||
(
|
||||
(resource_relative_path / file_).resolve()
|
||||
if abs_path
|
||||
else (resource_relative_path / file_)
|
||||
)
|
||||
for file_ in os.listdir(resource_relative_path)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def get_loaded_resource_packs() -> list[ResourceMetadata]:
|
||||
"""
|
||||
获取已加载的资源包,优先级从前到后
|
||||
Returns: 资源包列表
|
||||
"""
|
||||
return _loaded_resource_packs
|
||||
|
||||
|
||||
def copy_file(src, dst):
|
||||
# 获取目标文件的目录
|
||||
dst_dir = os.path.dirname(dst)
|
||||
# 如果目标目录不存在,创建它
|
||||
if not os.path.exists(dst_dir):
|
||||
os.makedirs(dst_dir)
|
||||
# 复制文件
|
||||
shutil.copy(src, dst)
|
||||
|
||||
|
||||
def load_resources():
|
||||
"""用于外部主程序调用的资源加载函数
|
||||
Returns:
|
||||
"""
|
||||
# 加载默认资源和语言
|
||||
# 清空临时资源包路径data/liteyuki/resources
|
||||
_loaded_resource_packs.clear()
|
||||
loaded_functions.clear()
|
||||
if os.path.exists(temp_resource_root):
|
||||
shutil.rmtree(temp_resource_root)
|
||||
os.makedirs(temp_resource_root, exist_ok=True)
|
||||
|
||||
# 加载内置资源
|
||||
standard_resources_path = "src/resources"
|
||||
for resource_dir in os.listdir(standard_resources_path):
|
||||
load_resource_from_dir(os.path.join(standard_resources_path, resource_dir))
|
||||
|
||||
# 加载其他资源包
|
||||
if not os.path.exists("resources"):
|
||||
os.makedirs("resources", exist_ok=True)
|
||||
|
||||
if not os.path.exists("resources/index.json"):
|
||||
json.dump([], open("resources/index.json", "w", encoding="utf-8"))
|
||||
|
||||
resource_index: list[str] = json.load(
|
||||
open("resources/index.json", "r", encoding="utf-8")
|
||||
)
|
||||
resource_index.reverse() # 优先级高的后加载,但是排在前面
|
||||
for resource in resource_index:
|
||||
load_resource_from_dir(os.path.join("resources", resource))
|
||||
|
||||
|
||||
def check_status(name: str) -> bool:
|
||||
"""
|
||||
检查资源包是否已加载
|
||||
Args:
|
||||
name: 资源包名称,文件夹名
|
||||
Returns: 是否已加载
|
||||
"""
|
||||
return name in [rp.folder for rp in get_loaded_resource_packs()]
|
||||
|
||||
|
||||
def check_exist(name: str) -> bool:
|
||||
"""
|
||||
检查资源包文件夹是否存在于resources文件夹
|
||||
Args:
|
||||
name: 资源包名称,文件夹名
|
||||
Returns: 是否存在
|
||||
"""
|
||||
path = os.path.join("resources", name)
|
||||
return os.path.exists(os.path.join(path, "metadata.yml")) or (
|
||||
os.path.isfile(path) and name.endswith(".zip")
|
||||
)
|
||||
|
||||
|
||||
def add_resource_pack(name: str) -> bool:
|
||||
"""
|
||||
添加资源包,该操作仅修改index.json文件,不会加载资源包,要生效请重载资源
|
||||
Args:
|
||||
name: 资源包名称,文件夹名
|
||||
Returns:
|
||||
"""
|
||||
if check_exist(name):
|
||||
old_index: list[str] = json.load(
|
||||
open("resources/index.json", "r", encoding="utf-8")
|
||||
)
|
||||
if name not in old_index:
|
||||
old_index.append(name)
|
||||
json.dump(old_index, open("resources/index.json", "w", encoding="utf-8"))
|
||||
load_resource_from_dir(os.path.join("resources", name))
|
||||
return True
|
||||
else:
|
||||
nonebot.logger.warning(lang.get("liteyuki.resource_loaded", name=name))
|
||||
return False
|
||||
else:
|
||||
nonebot.logger.warning(lang.get("liteyuki.resource_not_exist", name=name))
|
||||
return False
|
||||
|
||||
|
||||
def remove_resource_pack(name: str) -> bool:
|
||||
"""
|
||||
移除资源包,该操作仅修改加载索引,要生效请重载资源
|
||||
Args:
|
||||
name: 资源包名称,文件夹名
|
||||
Returns:
|
||||
"""
|
||||
if check_exist(name):
|
||||
old_index: list[str] = json.load(
|
||||
open("resources/index.json", "r", encoding="utf-8")
|
||||
)
|
||||
if name in old_index:
|
||||
old_index.remove(name)
|
||||
json.dump(old_index, open("resources/index.json", "w", encoding="utf-8"))
|
||||
return True
|
||||
else:
|
||||
nonebot.logger.warning(lang.get("liteyuki.resource_not_loaded", name=name))
|
||||
return False
|
||||
else:
|
||||
nonebot.logger.warning(lang.get("liteyuki.resource_not_exist", name=name))
|
||||
return False
|
||||
|
||||
|
||||
def change_priority(name: str, delta: int) -> bool:
|
||||
"""
|
||||
修改资源包优先级
|
||||
Args:
|
||||
name: 资源包名称,文件夹名
|
||||
delta: 优先级变化,正数表示后移,负数表示前移,0表示移到最前
|
||||
Returns:
|
||||
"""
|
||||
# 正数表示前移,负数表示后移
|
||||
old_resource_list: list[str] = json.load(
|
||||
open("resources/index.json", "r", encoding="utf-8")
|
||||
)
|
||||
new_resource_list = old_resource_list.copy()
|
||||
if name in old_resource_list:
|
||||
index = old_resource_list.index(name)
|
||||
if 0 <= index + delta < len(old_resource_list):
|
||||
new_index = index + delta
|
||||
new_resource_list.remove(name)
|
||||
new_resource_list.insert(new_index, name)
|
||||
json.dump(
|
||||
new_resource_list, open("resources/index.json", "w", encoding="utf-8")
|
||||
)
|
||||
return True
|
||||
else:
|
||||
nonebot.logger.warning("Priority change failed, out of range")
|
||||
return False
|
||||
else:
|
||||
nonebot.logger.debug("Priority change failed, resource not loaded")
|
||||
return False
|
||||
|
||||
|
||||
def get_resource_metadata(name: str) -> ResourceMetadata:
|
||||
"""
|
||||
获取资源包元数据
|
||||
Args:
|
||||
name: 资源包名称,文件夹名
|
||||
Returns:
|
||||
"""
|
||||
for rp in get_loaded_resource_packs():
|
||||
if rp.folder == name:
|
||||
return rp
|
||||
return ResourceMetadata()
|
||||
|
@ -1,57 +1,57 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from typing import Iterable
|
||||
|
||||
import nonebot
|
||||
|
||||
word_bank: dict[str, set[str]] = {}
|
||||
|
||||
|
||||
def load_from_file(file_path: str):
|
||||
"""
|
||||
从json文件中加载词库
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
"""
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
data = json.load(file)
|
||||
for key, value_list in data.items():
|
||||
if key not in word_bank:
|
||||
word_bank[key] = set()
|
||||
word_bank[key].update(value_list)
|
||||
|
||||
nonebot.logger.debug(f"Loaded word bank from {file_path}")
|
||||
|
||||
|
||||
def load_from_dir(dir_path: str):
|
||||
"""
|
||||
从目录中加载词库
|
||||
|
||||
Args:
|
||||
dir_path: 目录路径
|
||||
"""
|
||||
for file in os.listdir(dir_path):
|
||||
try:
|
||||
file_path = os.path.join(dir_path, file)
|
||||
if os.path.isfile(file_path):
|
||||
if file.endswith(".json"):
|
||||
load_from_file(file_path)
|
||||
except Exception as e:
|
||||
nonebot.logger.error(f"Failed to load language data from {file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
def get_reply(kws: Iterable[str]) -> str | None:
|
||||
"""
|
||||
获取回复
|
||||
Args:
|
||||
kws: 关键词
|
||||
Returns:
|
||||
"""
|
||||
for kw in kws:
|
||||
if kw in word_bank:
|
||||
return random.choice(list(word_bank[kw]))
|
||||
|
||||
return None
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from typing import Iterable
|
||||
|
||||
import nonebot
|
||||
|
||||
word_bank: dict[str, set[str]] = {}
|
||||
|
||||
|
||||
def load_from_file(file_path: str):
|
||||
"""
|
||||
从json文件中加载词库
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
"""
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
data = json.load(file)
|
||||
for key, value_list in data.items():
|
||||
if key not in word_bank:
|
||||
word_bank[key] = set()
|
||||
word_bank[key].update(value_list)
|
||||
|
||||
nonebot.logger.debug(f"Loaded word bank from {file_path}")
|
||||
|
||||
|
||||
def load_from_dir(dir_path: str):
|
||||
"""
|
||||
从目录中加载词库
|
||||
|
||||
Args:
|
||||
dir_path: 目录路径
|
||||
"""
|
||||
for file in os.listdir(dir_path):
|
||||
try:
|
||||
file_path = os.path.join(dir_path, file)
|
||||
if os.path.isfile(file_path):
|
||||
if file.endswith(".json"):
|
||||
load_from_file(file_path)
|
||||
except Exception as e:
|
||||
nonebot.logger.error(f"Failed to load language data from {file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
def get_reply(kws: Iterable[str]) -> str | None:
|
||||
"""
|
||||
获取回复
|
||||
Args:
|
||||
kws: 关键词
|
||||
Returns:
|
||||
"""
|
||||
for kw in kws:
|
||||
if kw in word_bank:
|
||||
return random.choice(list(word_bank[kw]))
|
||||
|
||||
return None
|
||||
|
@ -1 +1 @@
|
||||
from .get_info import *
|
||||
from .get_info import *
|
||||
|
@ -1,26 +1,26 @@
|
||||
from nonebot.adapters import satori
|
||||
from nonebot.adapters import onebot
|
||||
from src.utils.base.ly_typing import T_MessageEvent, T_GroupMessageEvent
|
||||
|
||||
|
||||
def get_user_id(event: T_MessageEvent):
|
||||
if isinstance(event, satori.event.Event):
|
||||
return event.user.id
|
||||
else:
|
||||
return event.user_id
|
||||
|
||||
|
||||
def get_group_id(event: T_GroupMessageEvent):
|
||||
if isinstance(event, satori.event.Event):
|
||||
return event.guild.id
|
||||
elif isinstance(event, onebot.v11.GroupMessageEvent):
|
||||
return event.group_id
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_message_type(event: T_MessageEvent) -> str:
|
||||
if isinstance(event, satori.event.Event):
|
||||
return "private" if event.guild is None else "group"
|
||||
else:
|
||||
return event.message_type
|
||||
from nonebot.adapters import satori
|
||||
from nonebot.adapters import onebot
|
||||
from src.utils.base.ly_typing import T_MessageEvent, T_GroupMessageEvent
|
||||
|
||||
|
||||
def get_user_id(event: T_MessageEvent):
|
||||
if isinstance(event, satori.event.Event):
|
||||
return event.user.id
|
||||
else:
|
||||
return event.user_id
|
||||
|
||||
|
||||
def get_group_id(event: T_GroupMessageEvent):
|
||||
if isinstance(event, satori.event.Event):
|
||||
return event.guild.id
|
||||
elif isinstance(event, onebot.v11.GroupMessageEvent):
|
||||
return event.group_id
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_message_type(event: T_MessageEvent) -> str:
|
||||
if isinstance(event, satori.event.Event):
|
||||
return "private" if event.guild is None else "group"
|
||||
else:
|
||||
return event.message_type
|
||||
|
80
src/utils/external/logo.py
vendored
80
src/utils/external/logo.py
vendored
@ -1,40 +1,40 @@
|
||||
async def get_user_icon(platform: str, user_id: str) -> str:
|
||||
"""
|
||||
获取用户头像
|
||||
Args:
|
||||
platform: qq, telegram, discord...
|
||||
user_id: 1234567890
|
||||
|
||||
Returns:
|
||||
str: 头像链接
|
||||
"""
|
||||
match platform:
|
||||
case "qq":
|
||||
return f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=640"
|
||||
case "telegram":
|
||||
return f"https://t.me/i/userpic/320/{user_id}.jpg"
|
||||
case "discord":
|
||||
return f"https://cdn.discordapp.com/avatars/{user_id}/"
|
||||
case _:
|
||||
return ""
|
||||
|
||||
|
||||
async def get_group_icon(platform: str, group_id: str) -> str:
|
||||
"""
|
||||
获取群组头像
|
||||
Args:
|
||||
platform: qq, telegram, discord...
|
||||
group_id: 1234567890
|
||||
|
||||
Returns:
|
||||
str: 头像链接
|
||||
"""
|
||||
match platform:
|
||||
case "qq":
|
||||
return f"http://p.qlogo.cn/gh/{group_id}/{group_id}/640"
|
||||
case "telegram":
|
||||
return f"https://t.me/c/{group_id}/"
|
||||
case "discord":
|
||||
return f"https://cdn.discordapp.com/icons/{group_id}/"
|
||||
case _:
|
||||
return ""
|
||||
async def get_user_icon(platform: str, user_id: str) -> str:
|
||||
"""
|
||||
获取用户头像
|
||||
Args:
|
||||
platform: qq, telegram, discord...
|
||||
user_id: 1234567890
|
||||
|
||||
Returns:
|
||||
str: 头像链接
|
||||
"""
|
||||
match platform:
|
||||
case "qq":
|
||||
return f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=640"
|
||||
case "telegram":
|
||||
return f"https://t.me/i/userpic/320/{user_id}.jpg"
|
||||
case "discord":
|
||||
return f"https://cdn.discordapp.com/avatars/{user_id}/"
|
||||
case _:
|
||||
return ""
|
||||
|
||||
|
||||
async def get_group_icon(platform: str, group_id: str) -> str:
|
||||
"""
|
||||
获取群组头像
|
||||
Args:
|
||||
platform: qq, telegram, discord...
|
||||
group_id: 1234567890
|
||||
|
||||
Returns:
|
||||
str: 头像链接
|
||||
"""
|
||||
match platform:
|
||||
case "qq":
|
||||
return f"http://p.qlogo.cn/gh/{group_id}/{group_id}/640"
|
||||
case "telegram":
|
||||
return f"https://t.me/c/{group_id}/"
|
||||
case "discord":
|
||||
return f"https://cdn.discordapp.com/icons/{group_id}/"
|
||||
case _:
|
||||
return ""
|
||||
|
@ -1,89 +1,89 @@
|
||||
import os
|
||||
import aiofiles # type: ignore
|
||||
import nonebot
|
||||
from nonebot import require
|
||||
|
||||
# require("nonebot_plugin_htmlrender")
|
||||
|
||||
from nonebot_plugin_htmlrender import ( # type: ignore
|
||||
template_to_html,
|
||||
template_to_pic,
|
||||
md_to_pic
|
||||
) # type: ignore
|
||||
|
||||
|
||||
async def template2html(
|
||||
template: str,
|
||||
templates: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Args:
|
||||
template: str: 模板文件
|
||||
**templates: dict: 模板参数
|
||||
Returns:
|
||||
HTML 正文
|
||||
"""
|
||||
template_path = os.path.dirname(template)
|
||||
template_name = os.path.basename(template)
|
||||
return await template_to_html(template_path, template_name, **templates)
|
||||
|
||||
|
||||
async def template2image(
|
||||
template: str,
|
||||
templates: dict,
|
||||
pages=None,
|
||||
wait: int = 0,
|
||||
scale_factor: float = 1,
|
||||
debug: bool = False,
|
||||
) -> bytes:
|
||||
"""
|
||||
template -> html -> image
|
||||
Args:
|
||||
debug: 输入渲染好的 html
|
||||
wait: 等待时间,单位秒
|
||||
pages: 页面参数
|
||||
template: str: 模板文件
|
||||
templates: dict: 模板参数
|
||||
scale_factor: 缩放因子,越高越清晰
|
||||
Returns:
|
||||
图片二进制数据
|
||||
"""
|
||||
|
||||
###
|
||||
if pages is None:
|
||||
pages = {
|
||||
"viewport": {
|
||||
"width" : 1080,
|
||||
"height": 10
|
||||
},
|
||||
}
|
||||
|
||||
template_path = os.path.dirname(template)
|
||||
template_name = os.path.basename(template)
|
||||
|
||||
if debug:
|
||||
# 重载资源
|
||||
raw_html = await template_to_html(
|
||||
template_name=template_name,
|
||||
template_path=template_path,
|
||||
**templates,
|
||||
)
|
||||
random_file_name = f"debug.html"
|
||||
async with aiofiles.open(
|
||||
os.path.join(template_path, random_file_name), "w", encoding="utf-8"
|
||||
) as f:
|
||||
await f.write(raw_html)
|
||||
nonebot.logger.info("Debug HTML: %s" % f"{random_file_name}")
|
||||
return await template_to_pic(
|
||||
template_name=template_name,
|
||||
template_path=template_path,
|
||||
templates=templates,
|
||||
wait=wait,
|
||||
|
||||
###
|
||||
pages=pages,
|
||||
device_scale_factor=scale_factor
|
||||
###
|
||||
)
|
||||
|
||||
|
||||
import os
|
||||
import aiofiles # type: ignore
|
||||
import nonebot
|
||||
from nonebot import require
|
||||
|
||||
# require("nonebot_plugin_htmlrender")
|
||||
|
||||
from nonebot_plugin_htmlrender import ( # type: ignore
|
||||
template_to_html,
|
||||
template_to_pic,
|
||||
md_to_pic
|
||||
) # type: ignore
|
||||
|
||||
|
||||
async def template2html(
|
||||
template: str,
|
||||
templates: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Args:
|
||||
template: str: 模板文件
|
||||
**templates: dict: 模板参数
|
||||
Returns:
|
||||
HTML 正文
|
||||
"""
|
||||
template_path = os.path.dirname(template)
|
||||
template_name = os.path.basename(template)
|
||||
return await template_to_html(template_path, template_name, **templates)
|
||||
|
||||
|
||||
async def template2image(
|
||||
template: str,
|
||||
templates: dict,
|
||||
pages=None,
|
||||
wait: int = 0,
|
||||
scale_factor: float = 1,
|
||||
debug: bool = False,
|
||||
) -> bytes:
|
||||
"""
|
||||
template -> html -> image
|
||||
Args:
|
||||
debug: 输入渲染好的 html
|
||||
wait: 等待时间,单位秒
|
||||
pages: 页面参数
|
||||
template: str: 模板文件
|
||||
templates: dict: 模板参数
|
||||
scale_factor: 缩放因子,越高越清晰
|
||||
Returns:
|
||||
图片二进制数据
|
||||
"""
|
||||
|
||||
###
|
||||
if pages is None:
|
||||
pages = {
|
||||
"viewport": {
|
||||
"width" : 1080,
|
||||
"height": 10
|
||||
},
|
||||
}
|
||||
|
||||
template_path = os.path.dirname(template)
|
||||
template_name = os.path.basename(template)
|
||||
|
||||
if debug:
|
||||
# 重载资源
|
||||
raw_html = await template_to_html(
|
||||
template_name=template_name,
|
||||
template_path=template_path,
|
||||
**templates,
|
||||
)
|
||||
random_file_name = f"debug.html"
|
||||
async with aiofiles.open(
|
||||
os.path.join(template_path, random_file_name), "w", encoding="utf-8"
|
||||
) as f:
|
||||
await f.write(raw_html)
|
||||
nonebot.logger.info("Debug HTML: %s" % f"{random_file_name}")
|
||||
return await template_to_pic(
|
||||
template_name=template_name,
|
||||
template_path=template_path,
|
||||
templates=templates,
|
||||
wait=wait,
|
||||
|
||||
###
|
||||
pages=pages,
|
||||
device_scale_factor=scale_factor
|
||||
###
|
||||
)
|
||||
|
||||
|
||||
|
@ -1,209 +1,209 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from urllib.parse import quote
|
||||
|
||||
import aiohttp
|
||||
from PIL import Image
|
||||
|
||||
from ..base.config import get_config
|
||||
from ..base.data import LiteModel
|
||||
from ..base.ly_typing import T_Bot
|
||||
|
||||
|
||||
def escape_md(text: str) -> str:
|
||||
"""
|
||||
转义Markdown特殊字符
|
||||
Args:
|
||||
text: str: 文本
|
||||
|
||||
Returns:
|
||||
str: 转义后文本
|
||||
"""
|
||||
spacial_chars = r"\`*_{}[]()#+-.!"
|
||||
for char in spacial_chars:
|
||||
text = text.replace(char, "\\\\" + char)
|
||||
return text.replace("\n", r"\n").replace('"', r'\\\"')
|
||||
|
||||
|
||||
def escape_decorator(func):
|
||||
def wrapper(text: str):
|
||||
return func(escape_md(text))
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def compile_md(comps: list[str]) -> str:
|
||||
"""
|
||||
合成Markdown文本
|
||||
Args:
|
||||
comps: list[str]: 组件列表
|
||||
|
||||
Returns:
|
||||
str: 编译后文本
|
||||
"""
|
||||
return "".join(comps)
|
||||
|
||||
|
||||
class MarkdownComponent:
|
||||
@staticmethod
|
||||
def heading(text: str, level: int = 1) -> str:
|
||||
"""标题"""
|
||||
assert 1 <= level <= 6, "标题级别应在 1-6 之间"
|
||||
return f"{'#' * level} {text}\n"
|
||||
|
||||
@staticmethod
|
||||
def bold(text: str) -> str:
|
||||
"""粗体"""
|
||||
return f"**{text}**"
|
||||
|
||||
@staticmethod
|
||||
def italic(text: str) -> str:
|
||||
"""斜体"""
|
||||
return f"*{text}*"
|
||||
|
||||
@staticmethod
|
||||
def strike(text: str) -> str:
|
||||
"""删除线"""
|
||||
return f"~~{text}~~"
|
||||
|
||||
@staticmethod
|
||||
def code(text: str) -> str:
|
||||
"""行内代码"""
|
||||
return f"`{text}`"
|
||||
|
||||
@staticmethod
|
||||
def code_block(text: str, language: str = "") -> str:
|
||||
"""代码块"""
|
||||
return f"```{language}\n{text}\n```\n"
|
||||
|
||||
@staticmethod
|
||||
def quote(text: str) -> str:
|
||||
"""引用"""
|
||||
return f"> {text}\n\n"
|
||||
|
||||
@staticmethod
|
||||
def link(text: str, url: str, symbol: bool = True) -> str:
|
||||
"""
|
||||
链接
|
||||
|
||||
Args:
|
||||
text: 链接文本
|
||||
url: 链接地址
|
||||
symbol: 是否显示链接图标, mqqapi请使用False
|
||||
"""
|
||||
return f"[{'🔗' if symbol else ''}{text}]({url})"
|
||||
|
||||
@staticmethod
|
||||
def image(url: str, *, size: tuple[int, int]) -> str:
|
||||
"""
|
||||
图片,本地图片不建议直接使用
|
||||
Args:
|
||||
url: 图片链接
|
||||
size: 图片大小
|
||||
|
||||
Returns:
|
||||
markdown格式的图片
|
||||
"""
|
||||
return f"![image #{size[0]}px #{size[1]}px]({url})"
|
||||
|
||||
@staticmethod
|
||||
async def auto_image(image: str | bytes, bot: T_Bot) -> str:
|
||||
"""
|
||||
自动获取图片大小
|
||||
Args:
|
||||
image: 本地图片路径 | 图片url http/file | 图片bytes
|
||||
bot: bot对象,用于上传图片到图床
|
||||
|
||||
Returns:
|
||||
markdown格式的图片
|
||||
"""
|
||||
if isinstance(image, bytes):
|
||||
# 传入为二进制图片
|
||||
image_obj = Image.open(BytesIO(image))
|
||||
base64_string = base64.b64encode(image_obj.tobytes()).decode("utf-8")
|
||||
url = await bot.call_api("upload_image", file=f"base64://{base64_string}")
|
||||
size = image_obj.size
|
||||
elif isinstance(image, str):
|
||||
# 传入链接或本地路径
|
||||
if image.startswith("http"):
|
||||
# 网络请求
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image) as resp:
|
||||
image_data = await resp.read()
|
||||
url = image
|
||||
size = Image.open(BytesIO(image_data)).size
|
||||
|
||||
else:
|
||||
# 本地路径/file://
|
||||
image_obj = Image.open(image.replace("file://", ""))
|
||||
base64_string = base64.b64encode(image_obj.tobytes()).decode("utf-8")
|
||||
url = await bot.call_api("upload_image", file=f"base64://{base64_string}")
|
||||
size = image_obj.size
|
||||
else:
|
||||
raise ValueError("图片类型错误")
|
||||
|
||||
return MarkdownComponent.image(url, size=size)
|
||||
|
||||
@staticmethod
|
||||
def table(data: list[list[any]]) -> str:
|
||||
"""
|
||||
表格
|
||||
Args:
|
||||
data: 表格数据,二维列表
|
||||
Returns:
|
||||
markdown格式的表格
|
||||
"""
|
||||
# 表头
|
||||
table = "|".join(map(str, data[0])) + "\n"
|
||||
table += "|".join([":-:" for _ in range(len(data[0]))]) + "\n"
|
||||
# 表内容
|
||||
for row in data[1:]:
|
||||
table += "|".join(map(str, row)) + "\n"
|
||||
return table
|
||||
|
||||
@staticmethod
|
||||
def paragraph(text: str) -> str:
|
||||
"""
|
||||
段落
|
||||
Args:
|
||||
text: 段落内容
|
||||
Returns:
|
||||
markdown格式的段落
|
||||
"""
|
||||
return f"{text}\n"
|
||||
|
||||
|
||||
class Mqqapi:
|
||||
@staticmethod
|
||||
@escape_decorator
|
||||
def cmd(text: str, cmd: str, enter: bool = True, reply: bool = False, use_cmd_start: bool = True) -> str:
|
||||
"""
|
||||
生成点击回调文本
|
||||
Args:
|
||||
text: 显示内容
|
||||
cmd: 命令
|
||||
enter: 是否自动发送
|
||||
reply: 是否回复
|
||||
use_cmd_start: 是否使用配置的命令前缀
|
||||
|
||||
Returns:
|
||||
[text](mqqapi://) markdown格式的可点击回调文本,类似于链接
|
||||
"""
|
||||
|
||||
if use_cmd_start:
|
||||
command_start = get_config("command_start", [])
|
||||
if command_start:
|
||||
# 若命令前缀不为空,则使用配置的第一个命令前缀
|
||||
cmd = f"{command_start[0]}{cmd}"
|
||||
return f"[{text}](mqqapi://aio/inlinecmd?command={quote(cmd)}&reply={str(reply).lower()}&enter={str(enter).lower()})"
|
||||
|
||||
|
||||
class RenderData(LiteModel):
|
||||
label: str
|
||||
visited_label: str
|
||||
style: int
|
||||
|
||||
|
||||
class Button(LiteModel):
|
||||
id: int
|
||||
render_data: RenderData
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from urllib.parse import quote
|
||||
|
||||
import aiohttp
|
||||
from PIL import Image
|
||||
|
||||
from ..base.config import get_config
|
||||
from ..base.data import LiteModel
|
||||
from ..base.ly_typing import T_Bot
|
||||
|
||||
|
||||
def escape_md(text: str) -> str:
|
||||
"""
|
||||
转义Markdown特殊字符
|
||||
Args:
|
||||
text: str: 文本
|
||||
|
||||
Returns:
|
||||
str: 转义后文本
|
||||
"""
|
||||
spacial_chars = r"\`*_{}[]()#+-.!"
|
||||
for char in spacial_chars:
|
||||
text = text.replace(char, "\\\\" + char)
|
||||
return text.replace("\n", r"\n").replace('"', r'\\\"')
|
||||
|
||||
|
||||
def escape_decorator(func):
|
||||
def wrapper(text: str):
|
||||
return func(escape_md(text))
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def compile_md(comps: list[str]) -> str:
|
||||
"""
|
||||
合成Markdown文本
|
||||
Args:
|
||||
comps: list[str]: 组件列表
|
||||
|
||||
Returns:
|
||||
str: 编译后文本
|
||||
"""
|
||||
return "".join(comps)
|
||||
|
||||
|
||||
class MarkdownComponent:
|
||||
@staticmethod
|
||||
def heading(text: str, level: int = 1) -> str:
|
||||
"""标题"""
|
||||
assert 1 <= level <= 6, "标题级别应在 1-6 之间"
|
||||
return f"{'#' * level} {text}\n"
|
||||
|
||||
@staticmethod
|
||||
def bold(text: str) -> str:
|
||||
"""粗体"""
|
||||
return f"**{text}**"
|
||||
|
||||
@staticmethod
|
||||
def italic(text: str) -> str:
|
||||
"""斜体"""
|
||||
return f"*{text}*"
|
||||
|
||||
@staticmethod
|
||||
def strike(text: str) -> str:
|
||||
"""删除线"""
|
||||
return f"~~{text}~~"
|
||||
|
||||
@staticmethod
|
||||
def code(text: str) -> str:
|
||||
"""行内代码"""
|
||||
return f"`{text}`"
|
||||
|
||||
@staticmethod
|
||||
def code_block(text: str, language: str = "") -> str:
|
||||
"""代码块"""
|
||||
return f"```{language}\n{text}\n```\n"
|
||||
|
||||
@staticmethod
|
||||
def quote(text: str) -> str:
|
||||
"""引用"""
|
||||
return f"> {text}\n\n"
|
||||
|
||||
@staticmethod
|
||||
def link(text: str, url: str, symbol: bool = True) -> str:
|
||||
"""
|
||||
链接
|
||||
|
||||
Args:
|
||||
text: 链接文本
|
||||
url: 链接地址
|
||||
symbol: 是否显示链接图标, mqqapi请使用False
|
||||
"""
|
||||
return f"[{'🔗' if symbol else ''}{text}]({url})"
|
||||
|
||||
@staticmethod
|
||||
def image(url: str, *, size: tuple[int, int]) -> str:
|
||||
"""
|
||||
图片,本地图片不建议直接使用
|
||||
Args:
|
||||
url: 图片链接
|
||||
size: 图片大小
|
||||
|
||||
Returns:
|
||||
markdown格式的图片
|
||||
"""
|
||||
return f"![image #{size[0]}px #{size[1]}px]({url})"
|
||||
|
||||
@staticmethod
|
||||
async def auto_image(image: str | bytes, bot: T_Bot) -> str:
|
||||
"""
|
||||
自动获取图片大小
|
||||
Args:
|
||||
image: 本地图片路径 | 图片url http/file | 图片bytes
|
||||
bot: bot对象,用于上传图片到图床
|
||||
|
||||
Returns:
|
||||
markdown格式的图片
|
||||
"""
|
||||
if isinstance(image, bytes):
|
||||
# 传入为二进制图片
|
||||
image_obj = Image.open(BytesIO(image))
|
||||
base64_string = base64.b64encode(image_obj.tobytes()).decode("utf-8")
|
||||
url = await bot.call_api("upload_image", file=f"base64://{base64_string}")
|
||||
size = image_obj.size
|
||||
elif isinstance(image, str):
|
||||
# 传入链接或本地路径
|
||||
if image.startswith("http"):
|
||||
# 网络请求
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image) as resp:
|
||||
image_data = await resp.read()
|
||||
url = image
|
||||
size = Image.open(BytesIO(image_data)).size
|
||||
|
||||
else:
|
||||
# 本地路径/file://
|
||||
image_obj = Image.open(image.replace("file://", ""))
|
||||
base64_string = base64.b64encode(image_obj.tobytes()).decode("utf-8")
|
||||
url = await bot.call_api("upload_image", file=f"base64://{base64_string}")
|
||||
size = image_obj.size
|
||||
else:
|
||||
raise ValueError("图片类型错误")
|
||||
|
||||
return MarkdownComponent.image(url, size=size)
|
||||
|
||||
@staticmethod
|
||||
def table(data: list[list[any]]) -> str:
|
||||
"""
|
||||
表格
|
||||
Args:
|
||||
data: 表格数据,二维列表
|
||||
Returns:
|
||||
markdown格式的表格
|
||||
"""
|
||||
# 表头
|
||||
table = "|".join(map(str, data[0])) + "\n"
|
||||
table += "|".join([":-:" for _ in range(len(data[0]))]) + "\n"
|
||||
# 表内容
|
||||
for row in data[1:]:
|
||||
table += "|".join(map(str, row)) + "\n"
|
||||
return table
|
||||
|
||||
@staticmethod
|
||||
def paragraph(text: str) -> str:
|
||||
"""
|
||||
段落
|
||||
Args:
|
||||
text: 段落内容
|
||||
Returns:
|
||||
markdown格式的段落
|
||||
"""
|
||||
return f"{text}\n"
|
||||
|
||||
|
||||
class Mqqapi:
|
||||
@staticmethod
|
||||
@escape_decorator
|
||||
def cmd(text: str, cmd: str, enter: bool = True, reply: bool = False, use_cmd_start: bool = True) -> str:
|
||||
"""
|
||||
生成点击回调文本
|
||||
Args:
|
||||
text: 显示内容
|
||||
cmd: 命令
|
||||
enter: 是否自动发送
|
||||
reply: 是否回复
|
||||
use_cmd_start: 是否使用配置的命令前缀
|
||||
|
||||
Returns:
|
||||
[text](mqqapi://) markdown格式的可点击回调文本,类似于链接
|
||||
"""
|
||||
|
||||
if use_cmd_start:
|
||||
command_start = get_config("command_start", [])
|
||||
if command_start:
|
||||
# 若命令前缀不为空,则使用配置的第一个命令前缀
|
||||
cmd = f"{command_start[0]}{cmd}"
|
||||
return f"[{text}](mqqapi://aio/inlinecmd?command={quote(cmd)}&reply={str(reply).lower()}&enter={str(enter).lower()})"
|
||||
|
||||
|
||||
class RenderData(LiteModel):
|
||||
label: str
|
||||
visited_label: str
|
||||
style: int
|
||||
|
||||
|
||||
class Button(LiteModel):
|
||||
id: int
|
||||
render_data: RenderData
|
||||
|
@ -1,202 +1,202 @@
|
||||
import base64
|
||||
import io
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
import nonebot
|
||||
from PIL import Image
|
||||
from nonebot.adapters.onebot import v11
|
||||
|
||||
from .html_tool import md_to_pic
|
||||
from .. import load_from_yaml
|
||||
from ..base.ly_typing import T_Bot, T_Message, T_MessageEvent
|
||||
|
||||
config = load_from_yaml("config.yml")
|
||||
|
||||
|
||||
async def broadcast_to_superusers(message: str | T_Message, markdown: bool = False):
|
||||
"""广播消息给超级用户"""
|
||||
for bot in nonebot.get_bots().values():
|
||||
for user_id in config.get("superusers", []):
|
||||
if markdown:
|
||||
await MarkdownMessage.send_md(message, bot, message_type="private", session_id=user_id)
|
||||
else:
|
||||
await bot.send_private_msg(user_id=user_id, message=message)
|
||||
|
||||
|
||||
class MarkdownMessage:
|
||||
@staticmethod
|
||||
async def send_md(
|
||||
markdown: str,
|
||||
bot: T_Bot, *,
|
||||
message_type: str = None,
|
||||
session_id: str | int = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
发送Markdown消息,支持自动转为图片发送
|
||||
Args:
|
||||
markdown:
|
||||
bot:
|
||||
message_type:
|
||||
session_id:
|
||||
Returns:
|
||||
|
||||
"""
|
||||
plain_markdown = markdown.replace("[🔗", "[")
|
||||
md_image_bytes = await md_to_pic(
|
||||
md=plain_markdown,
|
||||
width=540,
|
||||
device_scale_factor=4
|
||||
)
|
||||
print(md_image_bytes)
|
||||
data = await bot.send_msg(
|
||||
message_type=message_type,
|
||||
group_id=session_id,
|
||||
user_id=session_id,
|
||||
message=v11.MessageSegment.image(md_image_bytes),
|
||||
)
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
async def send_image(
|
||||
image: bytes | str,
|
||||
bot: T_Bot, *,
|
||||
message_type: str = None,
|
||||
session_id: str | int = None,
|
||||
event: T_MessageEvent = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
"""
|
||||
发送单张装逼大图
|
||||
Args:
|
||||
image: 图片字节流或图片本地路径,链接请使用Markdown.image_async方法获取后通过send_md发送
|
||||
bot: bot instance
|
||||
message_type: message message_type
|
||||
session_id: session id
|
||||
event: event
|
||||
kwargs: other arguments
|
||||
Returns:
|
||||
dict: response data
|
||||
"""
|
||||
if isinstance(image, str):
|
||||
async with aiofiles.open(image, "rb") as f:
|
||||
image = await f.read()
|
||||
method = 2
|
||||
if method == 2:
|
||||
base64_string = base64.b64encode(image).decode("utf-8")
|
||||
data = await bot.call_api("upload_image", file=f"base64://{base64_string}")
|
||||
await MarkdownMessage.send_md(MarkdownMessage.image(data, Image.open(io.BytesIO(image)).size), bot,
|
||||
message_type=message_type,
|
||||
session_id=session_id)
|
||||
|
||||
# 其他实现端方案
|
||||
else:
|
||||
image_message_id = (await bot.send_private_msg(
|
||||
user_id=bot.self_id,
|
||||
message=[
|
||||
v11.MessageSegment.image(file=image)
|
||||
]
|
||||
))["message_id"]
|
||||
image_url = (await bot.get_msg(message_id=image_message_id))["message"][0]["data"]["url"]
|
||||
image_size = Image.open(io.BytesIO(image)).size
|
||||
image_md = MarkdownMessage.image(image_url, image_size)
|
||||
return await MarkdownMessage.send_md(image_md, bot, message_type=message_type, session_id=session_id)
|
||||
|
||||
if data is None:
|
||||
data = await bot.send_msg(
|
||||
message_type=message_type,
|
||||
group_id=session_id,
|
||||
user_id=session_id,
|
||||
message=v11.MessageSegment.image(image),
|
||||
**kwargs
|
||||
)
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
async def get_image_url(image: bytes | str, bot: T_Bot) -> str:
|
||||
"""把图片上传到图床,返回链接
|
||||
Args:
|
||||
bot: 发送的bot
|
||||
image: 图片字节流或图片本地路径
|
||||
Returns:
|
||||
"""
|
||||
# 等林文轩修好Lagrange.OneBot再说
|
||||
|
||||
@staticmethod
|
||||
def btn_cmd(name: str, cmd: str, reply: bool = False, enter: bool = True) -> str:
|
||||
"""生成点击回调按钮
|
||||
Args:
|
||||
name: 按钮显示内容
|
||||
cmd: 发送的命令,已在函数内url编码,不需要再次编码
|
||||
reply: 是否以回复的方式发送消息
|
||||
enter: 自动发送消息则为True,否则填充到输入框
|
||||
|
||||
Returns:
|
||||
markdown格式的可点击回调按钮
|
||||
|
||||
"""
|
||||
if "" not in config.get("command_start", ["/"]) and config.get("alconna_use_command_start", False):
|
||||
cmd = f"{config['command_start'][0]}{cmd}"
|
||||
return f"[{name}](mqqapi://aio/inlinecmd?command={quote(cmd)}&reply={str(reply).lower()}&enter={str(enter).lower()})"
|
||||
|
||||
@staticmethod
|
||||
def btn_link(name: str, url: str) -> str:
|
||||
"""生成点击链接按钮
|
||||
Args:
|
||||
name: 链接显示内容
|
||||
url: 链接地址
|
||||
|
||||
Returns:
|
||||
markdown格式的链接
|
||||
|
||||
"""
|
||||
return f"[🔗{name}]({url})"
|
||||
|
||||
@staticmethod
|
||||
def image(url: str, size: tuple[int, int]) -> str:
|
||||
"""构建图片链接
|
||||
Args:
|
||||
size:
|
||||
url: 图片链接
|
||||
|
||||
Returns:
|
||||
markdown格式的图片
|
||||
|
||||
"""
|
||||
return f"![image #{size[0]}px #{size[1]}px]({url})"
|
||||
|
||||
@staticmethod
|
||||
async def image_async(url: str) -> str:
|
||||
"""获取图片,自动请求获取大小,异步
|
||||
Args:
|
||||
url: 图片链接
|
||||
|
||||
Returns:
|
||||
图片Markdown语法: 
|
||||
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as resp:
|
||||
image = Image.open(io.BytesIO(await resp.read()))
|
||||
return MarkdownMessage.image(url, image.size)
|
||||
except Exception as e:
|
||||
nonebot.logger.error(f"get image error: {e}")
|
||||
return "[Image Error]"
|
||||
|
||||
@staticmethod
|
||||
def escape(text: str) -> str:
|
||||
"""转义特殊字符
|
||||
Args:
|
||||
text: 需要转义的文本,请勿直接把整个markdown文本传入,否则会转义掉所有字符
|
||||
|
||||
Returns:
|
||||
转义后的文本
|
||||
|
||||
"""
|
||||
chars = "*[]()~_`>#+=|{}.!"
|
||||
for char in chars:
|
||||
text = text.replace(char, f"\\\\{char}")
|
||||
return text
|
||||
import base64
|
||||
import io
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
import nonebot
|
||||
from PIL import Image
|
||||
from nonebot.adapters.onebot import v11
|
||||
|
||||
from .html_tool import md_to_pic
|
||||
from .. import load_from_yaml
|
||||
from ..base.ly_typing import T_Bot, T_Message, T_MessageEvent
|
||||
|
||||
config = load_from_yaml("config.yml")
|
||||
|
||||
|
||||
async def broadcast_to_superusers(message: str | T_Message, markdown: bool = False):
|
||||
"""广播消息给超级用户"""
|
||||
for bot in nonebot.get_bots().values():
|
||||
for user_id in config.get("superusers", []):
|
||||
if markdown:
|
||||
await MarkdownMessage.send_md(message, bot, message_type="private", session_id=user_id)
|
||||
else:
|
||||
await bot.send_private_msg(user_id=user_id, message=message)
|
||||
|
||||
|
||||
class MarkdownMessage:
|
||||
@staticmethod
|
||||
async def send_md(
|
||||
markdown: str,
|
||||
bot: T_Bot, *,
|
||||
message_type: str = None,
|
||||
session_id: str | int = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
发送Markdown消息,支持自动转为图片发送
|
||||
Args:
|
||||
markdown:
|
||||
bot:
|
||||
message_type:
|
||||
session_id:
|
||||
Returns:
|
||||
|
||||
"""
|
||||
plain_markdown = markdown.replace("[🔗", "[")
|
||||
md_image_bytes = await md_to_pic(
|
||||
md=plain_markdown,
|
||||
width=540,
|
||||
device_scale_factor=4
|
||||
)
|
||||
print(md_image_bytes)
|
||||
data = await bot.send_msg(
|
||||
message_type=message_type,
|
||||
group_id=session_id,
|
||||
user_id=session_id,
|
||||
message=v11.MessageSegment.image(md_image_bytes),
|
||||
)
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
async def send_image(
|
||||
image: bytes | str,
|
||||
bot: T_Bot, *,
|
||||
message_type: str = None,
|
||||
session_id: str | int = None,
|
||||
event: T_MessageEvent = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
"""
|
||||
发送单张装逼大图
|
||||
Args:
|
||||
image: 图片字节流或图片本地路径,链接请使用Markdown.image_async方法获取后通过send_md发送
|
||||
bot: bot instance
|
||||
message_type: message message_type
|
||||
session_id: session id
|
||||
event: event
|
||||
kwargs: other arguments
|
||||
Returns:
|
||||
dict: response data
|
||||
"""
|
||||
if isinstance(image, str):
|
||||
async with aiofiles.open(image, "rb") as f:
|
||||
image = await f.read()
|
||||
method = 2
|
||||
if method == 2:
|
||||
base64_string = base64.b64encode(image).decode("utf-8")
|
||||
data = await bot.call_api("upload_image", file=f"base64://{base64_string}")
|
||||
await MarkdownMessage.send_md(MarkdownMessage.image(data, Image.open(io.BytesIO(image)).size), bot,
|
||||
message_type=message_type,
|
||||
session_id=session_id)
|
||||
|
||||
# 其他实现端方案
|
||||
else:
|
||||
image_message_id = (await bot.send_private_msg(
|
||||
user_id=bot.self_id,
|
||||
message=[
|
||||
v11.MessageSegment.image(file=image)
|
||||
]
|
||||
))["message_id"]
|
||||
image_url = (await bot.get_msg(message_id=image_message_id))["message"][0]["data"]["url"]
|
||||
image_size = Image.open(io.BytesIO(image)).size
|
||||
image_md = MarkdownMessage.image(image_url, image_size)
|
||||
return await MarkdownMessage.send_md(image_md, bot, message_type=message_type, session_id=session_id)
|
||||
|
||||
if data is None:
|
||||
data = await bot.send_msg(
|
||||
message_type=message_type,
|
||||
group_id=session_id,
|
||||
user_id=session_id,
|
||||
message=v11.MessageSegment.image(image),
|
||||
**kwargs
|
||||
)
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
async def get_image_url(image: bytes | str, bot: T_Bot) -> str:
|
||||
"""把图片上传到图床,返回链接
|
||||
Args:
|
||||
bot: 发送的bot
|
||||
image: 图片字节流或图片本地路径
|
||||
Returns:
|
||||
"""
|
||||
# 等林文轩修好Lagrange.OneBot再说
|
||||
|
||||
@staticmethod
|
||||
def btn_cmd(name: str, cmd: str, reply: bool = False, enter: bool = True) -> str:
|
||||
"""生成点击回调按钮
|
||||
Args:
|
||||
name: 按钮显示内容
|
||||
cmd: 发送的命令,已在函数内url编码,不需要再次编码
|
||||
reply: 是否以回复的方式发送消息
|
||||
enter: 自动发送消息则为True,否则填充到输入框
|
||||
|
||||
Returns:
|
||||
markdown格式的可点击回调按钮
|
||||
|
||||
"""
|
||||
if "" not in config.get("command_start", ["/"]) and config.get("alconna_use_command_start", False):
|
||||
cmd = f"{config['command_start'][0]}{cmd}"
|
||||
return f"[{name}](mqqapi://aio/inlinecmd?command={quote(cmd)}&reply={str(reply).lower()}&enter={str(enter).lower()})"
|
||||
|
||||
@staticmethod
|
||||
def btn_link(name: str, url: str) -> str:
|
||||
"""生成点击链接按钮
|
||||
Args:
|
||||
name: 链接显示内容
|
||||
url: 链接地址
|
||||
|
||||
Returns:
|
||||
markdown格式的链接
|
||||
|
||||
"""
|
||||
return f"[🔗{name}]({url})"
|
||||
|
||||
@staticmethod
|
||||
def image(url: str, size: tuple[int, int]) -> str:
|
||||
"""构建图片链接
|
||||
Args:
|
||||
size:
|
||||
url: 图片链接
|
||||
|
||||
Returns:
|
||||
markdown格式的图片
|
||||
|
||||
"""
|
||||
return f"![image #{size[0]}px #{size[1]}px]({url})"
|
||||
|
||||
@staticmethod
|
||||
async def image_async(url: str) -> str:
|
||||
"""获取图片,自动请求获取大小,异步
|
||||
Args:
|
||||
url: 图片链接
|
||||
|
||||
Returns:
|
||||
图片Markdown语法: 
|
||||
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as resp:
|
||||
image = Image.open(io.BytesIO(await resp.read()))
|
||||
return MarkdownMessage.image(url, image.size)
|
||||
except Exception as e:
|
||||
nonebot.logger.error(f"get image error: {e}")
|
||||
return "[Image Error]"
|
||||
|
||||
@staticmethod
|
||||
def escape(text: str) -> str:
|
||||
"""转义特殊字符
|
||||
Args:
|
||||
text: 需要转义的文本,请勿直接把整个markdown文本传入,否则会转义掉所有字符
|
||||
|
||||
Returns:
|
||||
转义后的文本
|
||||
|
||||
"""
|
||||
chars = "*[]()~_`>#+=|{}.!"
|
||||
for char in chars:
|
||||
text = text.replace(char, f"\\\\{char}")
|
||||
return text
|
||||
|
@ -1,101 +1,101 @@
|
||||
import nonebot
|
||||
|
||||
|
||||
def convert_duration(text: str, default) -> float:
|
||||
"""
|
||||
转换自然语言时间为秒数
|
||||
Args:
|
||||
text: 1d2h3m
|
||||
default: 出错时返回
|
||||
|
||||
Returns:
|
||||
float: 总秒数
|
||||
"""
|
||||
units = {
|
||||
"d" : 86400,
|
||||
"h" : 3600,
|
||||
"m" : 60,
|
||||
"s" : 1,
|
||||
"ms": 0.001
|
||||
}
|
||||
|
||||
duration = 0
|
||||
current_number = ''
|
||||
current_unit = ''
|
||||
try:
|
||||
for char in text:
|
||||
if char.isdigit():
|
||||
current_number += char
|
||||
else:
|
||||
if current_number:
|
||||
duration += int(current_number) * units[current_unit]
|
||||
current_number = ''
|
||||
if char in units:
|
||||
current_unit = char
|
||||
else:
|
||||
current_unit = ''
|
||||
|
||||
if current_number:
|
||||
duration += int(current_number) * units[current_unit]
|
||||
|
||||
return duration
|
||||
|
||||
except BaseException as e:
|
||||
nonebot.logger.info(f"convert_duration error: {e}")
|
||||
return default
|
||||
|
||||
|
||||
def convert_time_to_seconds(time_str):
|
||||
"""转换自然语言时长为秒数
|
||||
Args:
|
||||
time_str: 1d2m3s
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
seconds = 0
|
||||
current_number = ''
|
||||
|
||||
for char in time_str:
|
||||
if char.isdigit() or char == '.':
|
||||
current_number += char
|
||||
elif char == 'd':
|
||||
seconds += float(current_number) * 24 * 60 * 60
|
||||
current_number = ''
|
||||
elif char == 'h':
|
||||
seconds += float(current_number) * 60 * 60
|
||||
current_number = ''
|
||||
elif char == 'm':
|
||||
seconds += float(current_number) * 60
|
||||
current_number = ''
|
||||
elif char == 's':
|
||||
seconds += float(current_number)
|
||||
current_number = ''
|
||||
|
||||
return int(seconds)
|
||||
|
||||
|
||||
def convert_seconds_to_time(seconds):
|
||||
"""转换秒数为自然语言时长
|
||||
Args:
|
||||
seconds: 10000
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
d = seconds // (24 * 60 * 60)
|
||||
h = (seconds % (24 * 60 * 60)) // (60 * 60)
|
||||
m = (seconds % (60 * 60)) // 60
|
||||
s = seconds % 60
|
||||
|
||||
# 若值为0则不显示
|
||||
time_str = ''
|
||||
if d:
|
||||
time_str += f"{d}d"
|
||||
if h:
|
||||
time_str += f"{h}h"
|
||||
if m:
|
||||
time_str += f"{m}m"
|
||||
if not time_str:
|
||||
time_str = f"{s}s"
|
||||
return time_str
|
||||
import nonebot
|
||||
|
||||
|
||||
def convert_duration(text: str, default) -> float:
|
||||
"""
|
||||
转换自然语言时间为秒数
|
||||
Args:
|
||||
text: 1d2h3m
|
||||
default: 出错时返回
|
||||
|
||||
Returns:
|
||||
float: 总秒数
|
||||
"""
|
||||
units = {
|
||||
"d" : 86400,
|
||||
"h" : 3600,
|
||||
"m" : 60,
|
||||
"s" : 1,
|
||||
"ms": 0.001
|
||||
}
|
||||
|
||||
duration = 0
|
||||
current_number = ''
|
||||
current_unit = ''
|
||||
try:
|
||||
for char in text:
|
||||
if char.isdigit():
|
||||
current_number += char
|
||||
else:
|
||||
if current_number:
|
||||
duration += int(current_number) * units[current_unit]
|
||||
current_number = ''
|
||||
if char in units:
|
||||
current_unit = char
|
||||
else:
|
||||
current_unit = ''
|
||||
|
||||
if current_number:
|
||||
duration += int(current_number) * units[current_unit]
|
||||
|
||||
return duration
|
||||
|
||||
except BaseException as e:
|
||||
nonebot.logger.info(f"convert_duration error: {e}")
|
||||
return default
|
||||
|
||||
|
||||
def convert_time_to_seconds(time_str):
|
||||
"""转换自然语言时长为秒数
|
||||
Args:
|
||||
time_str: 1d2m3s
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
seconds = 0
|
||||
current_number = ''
|
||||
|
||||
for char in time_str:
|
||||
if char.isdigit() or char == '.':
|
||||
current_number += char
|
||||
elif char == 'd':
|
||||
seconds += float(current_number) * 24 * 60 * 60
|
||||
current_number = ''
|
||||
elif char == 'h':
|
||||
seconds += float(current_number) * 60 * 60
|
||||
current_number = ''
|
||||
elif char == 'm':
|
||||
seconds += float(current_number) * 60
|
||||
current_number = ''
|
||||
elif char == 's':
|
||||
seconds += float(current_number)
|
||||
current_number = ''
|
||||
|
||||
return int(seconds)
|
||||
|
||||
|
||||
def convert_seconds_to_time(seconds):
|
||||
"""转换秒数为自然语言时长
|
||||
Args:
|
||||
seconds: 10000
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
d = seconds // (24 * 60 * 60)
|
||||
h = (seconds % (24 * 60 * 60)) // (60 * 60)
|
||||
m = (seconds % (60 * 60)) // 60
|
||||
s = seconds % 60
|
||||
|
||||
# 若值为0则不显示
|
||||
time_str = ''
|
||||
if d:
|
||||
time_str += f"{d}d"
|
||||
if h:
|
||||
time_str += f"{h}h"
|
||||
if m:
|
||||
time_str += f"{m}m"
|
||||
if not time_str:
|
||||
time_str = f"{s}s"
|
||||
return time_str
|
||||
|
@ -1,99 +1,99 @@
|
||||
import random
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
|
||||
def clamp(value: float, min_value: float, max_value: float) -> float | int:
|
||||
"""将值限制在最小值和最大值之间
|
||||
|
||||
Args:
|
||||
value (float): 要限制的值
|
||||
min_value (float): 最小值
|
||||
max_value (float): 最大值
|
||||
|
||||
Returns:
|
||||
float: 限制后的值
|
||||
"""
|
||||
return max(min(value, max_value), min_value)
|
||||
|
||||
|
||||
def convert_size(size: int, precision: int = 2, add_unit: bool = True, suffix: str = " XiB") -> str | float:
|
||||
"""把字节数转换为人类可读的字符串,计算正负
|
||||
|
||||
Args:
|
||||
|
||||
add_unit: 是否添加单位,False后则suffix无效
|
||||
suffix: XiB或XB
|
||||
precision: 浮点数的小数点位数
|
||||
size (int): 字节数
|
||||
|
||||
Returns:
|
||||
|
||||
str: The human-readable string, e.g. "1.23 GB".
|
||||
"""
|
||||
is_negative = size < 0
|
||||
size = abs(size)
|
||||
for unit in ("", "K", "M", "G", "T", "P", "E", "Z"):
|
||||
if size < 1024:
|
||||
break
|
||||
size /= 1024
|
||||
if is_negative:
|
||||
size = -size
|
||||
if add_unit:
|
||||
return f"{size:.{precision}f}{suffix.replace('X', unit)}"
|
||||
else:
|
||||
return size
|
||||
|
||||
|
||||
def keywords_in_text(keywords: list[str], text: str, all_matched: bool) -> bool:
|
||||
"""
|
||||
检查关键词是否在文本中
|
||||
Args:
|
||||
keywords: 关键词列表
|
||||
text: 文本
|
||||
all_matched: 是否需要全部匹配
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if all_matched:
|
||||
for keyword in keywords:
|
||||
if keyword not in text:
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
for keyword in keywords:
|
||||
if keyword in text:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_for_package(package_name: str) -> bool:
|
||||
try:
|
||||
version(package_name)
|
||||
return True
|
||||
except PackageNotFoundError:
|
||||
return False
|
||||
|
||||
|
||||
def random_ascii_string(length: int) -> str:
|
||||
"""
|
||||
生成随机ASCII字符串
|
||||
Args:
|
||||
length:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return "".join([chr(random.randint(33, 126)) for _ in range(length)])
|
||||
|
||||
|
||||
def random_hex_string(length: int) -> str:
|
||||
"""
|
||||
生成随机十六进制字符串
|
||||
Args:
|
||||
length:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return "".join([random.choice("0123456789abcdef") for _ in range(length)])
|
||||
import random
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
|
||||
def clamp(value: float, min_value: float, max_value: float) -> float | int:
|
||||
"""将值限制在最小值和最大值之间
|
||||
|
||||
Args:
|
||||
value (float): 要限制的值
|
||||
min_value (float): 最小值
|
||||
max_value (float): 最大值
|
||||
|
||||
Returns:
|
||||
float: 限制后的值
|
||||
"""
|
||||
return max(min(value, max_value), min_value)
|
||||
|
||||
|
||||
def convert_size(size: int, precision: int = 2, add_unit: bool = True, suffix: str = " XiB") -> str | float:
|
||||
"""把字节数转换为人类可读的字符串,计算正负
|
||||
|
||||
Args:
|
||||
|
||||
add_unit: 是否添加单位,False后则suffix无效
|
||||
suffix: XiB或XB
|
||||
precision: 浮点数的小数点位数
|
||||
size (int): 字节数
|
||||
|
||||
Returns:
|
||||
|
||||
str: The human-readable string, e.g. "1.23 GB".
|
||||
"""
|
||||
is_negative = size < 0
|
||||
size = abs(size)
|
||||
for unit in ("", "K", "M", "G", "T", "P", "E", "Z"):
|
||||
if size < 1024:
|
||||
break
|
||||
size /= 1024
|
||||
if is_negative:
|
||||
size = -size
|
||||
if add_unit:
|
||||
return f"{size:.{precision}f}{suffix.replace('X', unit)}"
|
||||
else:
|
||||
return size
|
||||
|
||||
|
||||
def keywords_in_text(keywords: list[str], text: str, all_matched: bool) -> bool:
|
||||
"""
|
||||
检查关键词是否在文本中
|
||||
Args:
|
||||
keywords: 关键词列表
|
||||
text: 文本
|
||||
all_matched: 是否需要全部匹配
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if all_matched:
|
||||
for keyword in keywords:
|
||||
if keyword not in text:
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
for keyword in keywords:
|
||||
if keyword in text:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_for_package(package_name: str) -> bool:
|
||||
try:
|
||||
version(package_name)
|
||||
return True
|
||||
except PackageNotFoundError:
|
||||
return False
|
||||
|
||||
|
||||
def random_ascii_string(length: int) -> str:
|
||||
"""
|
||||
生成随机ASCII字符串
|
||||
Args:
|
||||
length:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return "".join([chr(random.randint(33, 126)) for _ in range(length)])
|
||||
|
||||
|
||||
def random_hex_string(length: int) -> str:
|
||||
"""
|
||||
生成随机十六进制字符串
|
||||
Args:
|
||||
length:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return "".join([random.choice("0123456789abcdef") for _ in range(length)])
|
||||
|
@ -1,3 +1,3 @@
|
||||
from .user_info import user_infos
|
||||
from .count_friends import count_friends
|
||||
from .count_groups import count_groups
|
||||
from .user_info import user_infos
|
||||
from .count_friends import count_friends
|
||||
from .count_groups import count_groups
|
||||
|
@ -1,13 +1,13 @@
|
||||
from nonebot.adapters import satori
|
||||
|
||||
|
||||
async def count_friends(bot: satori.Bot) -> int:
|
||||
cnt: int = 0
|
||||
|
||||
friend_response = await bot.friend_list()
|
||||
while friend_response.next is not None:
|
||||
cnt += len(friend_response.data)
|
||||
friend_response = await bot.friend_list(next_token=friend_response.next)
|
||||
|
||||
cnt += len(friend_response.data)
|
||||
return cnt - 1
|
||||
from nonebot.adapters import satori
|
||||
|
||||
|
||||
async def count_friends(bot: satori.Bot) -> int:
|
||||
cnt: int = 0
|
||||
|
||||
friend_response = await bot.friend_list()
|
||||
while friend_response.next is not None:
|
||||
cnt += len(friend_response.data)
|
||||
friend_response = await bot.friend_list(next_token=friend_response.next)
|
||||
|
||||
cnt += len(friend_response.data)
|
||||
return cnt - 1
|
||||
|
@ -1,13 +1,13 @@
|
||||
from nonebot.adapters import satori
|
||||
|
||||
|
||||
async def count_groups(bot: satori.Bot) -> int:
|
||||
cnt: int = 0
|
||||
|
||||
group_response = await bot.guild_list()
|
||||
while group_response.next is not None:
|
||||
cnt += len(group_response.data)
|
||||
group_response = await bot.friend_list(next_token=group_response.next)
|
||||
|
||||
cnt += len(group_response.data)
|
||||
return cnt - 1
|
||||
from nonebot.adapters import satori
|
||||
|
||||
|
||||
async def count_groups(bot: satori.Bot) -> int:
|
||||
cnt: int = 0
|
||||
|
||||
group_response = await bot.guild_list()
|
||||
while group_response.next is not None:
|
||||
cnt += len(group_response.data)
|
||||
group_response = await bot.friend_list(next_token=group_response.next)
|
||||
|
||||
cnt += len(group_response.data)
|
||||
return cnt - 1
|
||||
|
@ -1,64 +1,64 @@
|
||||
import nonebot
|
||||
|
||||
from nonebot.adapters import satori
|
||||
from nonebot.adapters.satori.models import User
|
||||
|
||||
|
||||
class UserInfo:
|
||||
user_infos: dict = {}
|
||||
|
||||
async def load_friends(self, bot: satori.Bot):
|
||||
nonebot.logger.info("Update user info from friends")
|
||||
friend_response = await bot.friend_list()
|
||||
while friend_response.next is not None:
|
||||
for i in friend_response.data:
|
||||
i: User = i
|
||||
self.user_infos[str(i.id)] = i
|
||||
friend_response = await bot.friend_list(next_token=friend_response.next)
|
||||
|
||||
for i in friend_response.data:
|
||||
i: User = i
|
||||
self.user_infos[str(i.id)] = i
|
||||
|
||||
nonebot.logger.info("Finish update user info")
|
||||
|
||||
async def get(self, uid: int | str) -> User | None:
|
||||
try:
|
||||
return self.user_infos[str(uid)]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
async def put(self, user: User) -> bool:
|
||||
"""
|
||||
向用户信息数据库中添加/修改一项,返回值仅代表数据是否变更,不代表操作是否成功
|
||||
Args:
|
||||
user: 要加入数据库的用户
|
||||
|
||||
Returns: 当数据库中用户信息发生变化时返回 True, 否则返回 False
|
||||
|
||||
"""
|
||||
try:
|
||||
old_user: User = self.user_infos[str(user.id)]
|
||||
attr_edited = False
|
||||
if user.name is not None:
|
||||
if old_user.name != user.name:
|
||||
attr_edited = True
|
||||
self.user_infos[str(user.id)].name = user.name
|
||||
if user.nick is not None:
|
||||
if old_user.nick != user.nick:
|
||||
attr_edited = True
|
||||
self.user_infos[str(user.id)].nick = user.nick
|
||||
if user.avatar is not None:
|
||||
if old_user.avatar != user.avatar:
|
||||
attr_edited = True
|
||||
self.user_infos[str(user.id)].avatar = user.avatar
|
||||
return attr_edited
|
||||
except KeyError:
|
||||
self.user_infos[str(user.id)] = user
|
||||
return True
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
user_infos = UserInfo()
|
||||
import nonebot
|
||||
|
||||
from nonebot.adapters import satori
|
||||
from nonebot.adapters.satori.models import User
|
||||
|
||||
|
||||
class UserInfo:
|
||||
user_infos: dict = {}
|
||||
|
||||
async def load_friends(self, bot: satori.Bot):
|
||||
nonebot.logger.info("Update user info from friends")
|
||||
friend_response = await bot.friend_list()
|
||||
while friend_response.next is not None:
|
||||
for i in friend_response.data:
|
||||
i: User = i
|
||||
self.user_infos[str(i.id)] = i
|
||||
friend_response = await bot.friend_list(next_token=friend_response.next)
|
||||
|
||||
for i in friend_response.data:
|
||||
i: User = i
|
||||
self.user_infos[str(i.id)] = i
|
||||
|
||||
nonebot.logger.info("Finish update user info")
|
||||
|
||||
async def get(self, uid: int | str) -> User | None:
|
||||
try:
|
||||
return self.user_infos[str(uid)]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
async def put(self, user: User) -> bool:
|
||||
"""
|
||||
向用户信息数据库中添加/修改一项,返回值仅代表数据是否变更,不代表操作是否成功
|
||||
Args:
|
||||
user: 要加入数据库的用户
|
||||
|
||||
Returns: 当数据库中用户信息发生变化时返回 True, 否则返回 False
|
||||
|
||||
"""
|
||||
try:
|
||||
old_user: User = self.user_infos[str(user.id)]
|
||||
attr_edited = False
|
||||
if user.name is not None:
|
||||
if old_user.name != user.name:
|
||||
attr_edited = True
|
||||
self.user_infos[str(user.id)].name = user.name
|
||||
if user.nick is not None:
|
||||
if old_user.nick != user.nick:
|
||||
attr_edited = True
|
||||
self.user_infos[str(user.id)].nick = user.nick
|
||||
if user.avatar is not None:
|
||||
if old_user.avatar != user.avatar:
|
||||
attr_edited = True
|
||||
self.user_infos[str(user.id)].avatar = user.avatar
|
||||
return attr_edited
|
||||
except KeyError:
|
||||
self.user_infos[str(user.id)] = user
|
||||
return True
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
user_infos = UserInfo()
|
||||
|
Reference in New Issue
Block a user