mirror of
https://github.com/TriM-Organization/LiteyukiBot-TriM.git
synced 2025-09-07 04:36:23 +00:00
🔀手动Merge轻雪主仓库a77f97f
This commit is contained in:
@ -1,11 +1,12 @@
|
||||
from liteyuki.bot import (
|
||||
LiteyukiBot,
|
||||
get_bot
|
||||
get_bot,
|
||||
get_config,
|
||||
get_config_with_compat
|
||||
)
|
||||
|
||||
from liteyuki.comm import (
|
||||
Channel,
|
||||
chan,
|
||||
Event
|
||||
)
|
||||
|
||||
@ -15,7 +16,27 @@ from liteyuki.plugin import (
|
||||
)
|
||||
|
||||
from liteyuki.log import (
|
||||
logger,
|
||||
init_log
|
||||
|
||||
init_log,
|
||||
logger
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"LiteyukiBot",
|
||||
"get_bot",
|
||||
"get_config",
|
||||
"get_config_with_compat",
|
||||
"Channel",
|
||||
"Event",
|
||||
"load_plugin",
|
||||
"load_plugins",
|
||||
"init_log",
|
||||
"logger",
|
||||
]
|
||||
|
||||
__version__ = "6.3.9" # 测试版本号
|
||||
# 6.3.9
|
||||
# 更改了on语法
|
||||
|
||||
# 6.3.8
|
||||
# 1. 初步添加对聊天的支持
|
||||
# 2. 优化了通道的性能
|
||||
|
@ -1,103 +1,102 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import os
|
||||
import platform
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
from watchdog.observers import Observer
|
||||
|
||||
from liteyuki.bot.lifespan import LIFESPAN_FUNC, Lifespan
|
||||
from liteyuki.core import IS_MAIN_PROCESS
|
||||
from liteyuki.bot.lifespan import LIFESPAN_FUNC, Lifespan, PROCESS_LIFESPAN_FUNC
|
||||
from liteyuki.comm.channel import get_channel
|
||||
from liteyuki.core.manager import ProcessManager
|
||||
from liteyuki.core.spawn_process import mb_run, nb_run
|
||||
from liteyuki.log import init_log, logger
|
||||
from liteyuki.plugin import load_plugins
|
||||
from liteyuki.plugin import load_plugin
|
||||
from liteyuki.utils import IS_MAIN_PROCESS
|
||||
|
||||
__all__ = ["LiteyukiBot", "get_bot"]
|
||||
__all__ = [
|
||||
"LiteyukiBot",
|
||||
"get_bot",
|
||||
"get_config",
|
||||
"get_config_with_compat",
|
||||
]
|
||||
|
||||
|
||||
class LiteyukiBot:
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
"""
|
||||
初始化轻雪实例
|
||||
Args:
|
||||
**kwargs: 配置
|
||||
"""
|
||||
"""常规操作"""
|
||||
print_logo()
|
||||
global _BOT_INSTANCE
|
||||
_BOT_INSTANCE = self # 引用
|
||||
|
||||
"""配置"""
|
||||
self.config: dict[str, Any] = kwargs
|
||||
|
||||
"""初始化"""
|
||||
self.init(**self.config) # 初始化
|
||||
logger.info("尹灵温 正在初始化…")
|
||||
|
||||
self.lifespan: Lifespan = Lifespan()
|
||||
"""生命周期管理"""
|
||||
self.lifespan = Lifespan()
|
||||
self.process_manager: ProcessManager = ProcessManager(lifespan=self.lifespan)
|
||||
|
||||
self.process_manager: ProcessManager = ProcessManager(bot=self)
|
||||
"""事件循环"""
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
self.loop_thread = threading.Thread(target=self.loop.run_forever, daemon=True)
|
||||
self.stop_event = threading.Event()
|
||||
self.call_restart_count = 0
|
||||
|
||||
"""加载插件加载器"""
|
||||
load_plugin("liteyuki.plugins.plugin_loader") # 加载轻雪插件
|
||||
|
||||
async def _run(self):
|
||||
"""
|
||||
启动逻辑
|
||||
"""
|
||||
await self.lifespan.before_start() # 启动前钩子
|
||||
await self.lifespan.after_start() # 启动后钩子
|
||||
await self.keep_alive()
|
||||
|
||||
def run(self):
|
||||
load_plugins("liteyuki/plugins") # 加载轻雪插件
|
||||
"""
|
||||
外部启动接口
|
||||
"""
|
||||
self.process_manager.start_all()
|
||||
try:
|
||||
asyncio.run(self._run())
|
||||
except KeyboardInterrupt:
|
||||
logger.opt(colors=True).info("<y>尹灵温 关闭中…</y>")
|
||||
self.stop()
|
||||
logger.opt(colors=True).info("<y>尹灵温 已关停</y>")
|
||||
|
||||
self.loop_thread.start() # 启动事件循环
|
||||
asyncio.run(self.lifespan.before_start()) # 启动前钩子
|
||||
|
||||
self.process_manager.add_target("nonebot", nb_run, **self.config)
|
||||
self.process_manager.start("nonebot")
|
||||
|
||||
self.process_manager.add_target("melobot", mb_run, **self.config)
|
||||
self.process_manager.start("melobot")
|
||||
|
||||
asyncio.run(self.lifespan.after_start()) # 启动后钩子
|
||||
|
||||
self.start_watcher() # 启动文件监视器
|
||||
|
||||
def start_watcher(self):
|
||||
if self.config.get("debug", False):
|
||||
|
||||
src_directories = (
|
||||
"liteyuki",
|
||||
"src/liteyuki_main",
|
||||
"src/liteyuki_plugins",
|
||||
"src/nonebot_plugins",
|
||||
"src/utils",
|
||||
)
|
||||
src_excludes_extensions = ("pyc",)
|
||||
|
||||
logger.debug("轻雪重载 已启用,正在加载文件修改监测……")
|
||||
restart = self.restart_process
|
||||
|
||||
class CodeModifiedHandler(FileSystemEventHandler):
|
||||
"""
|
||||
Handler for code file changes
|
||||
"""
|
||||
|
||||
def on_modified(self, event):
|
||||
if (
|
||||
event.src_path.endswith(src_excludes_extensions)
|
||||
or event.is_directory
|
||||
or "__pycache__" in event.src_path
|
||||
):
|
||||
return
|
||||
logger.info(f"文件 {event.src_path} 已修改,机器人自动重启……")
|
||||
restart()
|
||||
|
||||
code_modified_handler = CodeModifiedHandler()
|
||||
|
||||
observer = Observer()
|
||||
for directory in src_directories:
|
||||
observer.schedule(code_modified_handler, directory, recursive=True)
|
||||
observer.start()
|
||||
async def keep_alive(self):
|
||||
"""
|
||||
保持轻雪运行
|
||||
"""
|
||||
logger.info("尹灵温 持续运行中…")
|
||||
try:
|
||||
while not self.stop_event.is_set():
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception:
|
||||
logger.info("尹灵温 现退停…")
|
||||
self.stop()
|
||||
|
||||
def restart(self, delay: int = 0):
|
||||
"""
|
||||
重启轻雪本体
|
||||
Returns:
|
||||
|
||||
Args:
|
||||
delay ([`int`](https%3A//docs.python.org/3/library/functions.html#int), optional): 延迟重启时间. Defaults to 0.
|
||||
"""
|
||||
|
||||
if self.call_restart_count < 1:
|
||||
executable = sys.executable
|
||||
args = sys.argv
|
||||
logger.info("正在重启 尹灵温...")
|
||||
logger.info("正在重启 尹灵温机器人框架")
|
||||
time.sleep(delay)
|
||||
if platform.system() == "Windows":
|
||||
cmd = "start"
|
||||
@ -110,7 +109,9 @@ class LiteyukiBot:
|
||||
self.process_manager.terminate_all()
|
||||
# 进程退出后重启
|
||||
threading.Thread(
|
||||
target=os.system, args=(f"{cmd} {executable} {' '.join(args)}",)
|
||||
target=os.system,
|
||||
args=(f"{cmd} {executable} {' '.join(args)}",),
|
||||
daemon=True,
|
||||
).start()
|
||||
sys.exit(0)
|
||||
self.call_restart_count += 1
|
||||
@ -119,44 +120,46 @@ class LiteyukiBot:
|
||||
"""
|
||||
停止轻雪
|
||||
Args:
|
||||
name: 进程名称, 默认为None, 所有进程
|
||||
name ([`Optional`](https%3A//docs.python.org/3/library/typing.html#typing.Optional)[[`str`](https%3A//docs.python.org/3/library/stdtypes.html#str)]): 进程名. Defaults to None.
|
||||
Returns:
|
||||
|
||||
"""
|
||||
logger.info("Stopping LiteyukiBot...")
|
||||
|
||||
self.loop.create_task(self.lifespan.before_shutdown()) # 重启前钩子
|
||||
self.loop.create_task(self.lifespan.before_shutdown()) # 停止前钩子
|
||||
|
||||
if name:
|
||||
self.process_manager.terminate(name)
|
||||
if name is not None:
|
||||
chan_active = get_channel(f"{name}-active")
|
||||
chan_active.send(1)
|
||||
else:
|
||||
self.process_manager.terminate_all()
|
||||
for process_name in self.process_manager.processes:
|
||||
chan_active = get_channel(f"{process_name}-active")
|
||||
chan_active.send(1)
|
||||
|
||||
def init(self, *args, **kwargs):
|
||||
"""
|
||||
初始化轻雪, 自动调用
|
||||
Returns:
|
||||
|
||||
Args:
|
||||
*args: 参数
|
||||
**kwargs: 关键字参数
|
||||
"""
|
||||
self.init_config()
|
||||
self.init_logger()
|
||||
|
||||
def init_logger(self):
|
||||
# 修改nonebot的日志配置
|
||||
"""
|
||||
初始化日志
|
||||
"""
|
||||
init_log(config=self.config)
|
||||
|
||||
def init_config(self):
|
||||
pass
|
||||
def stop(self):
|
||||
"""
|
||||
停止轻雪
|
||||
"""
|
||||
self.process_manager.terminate_all()
|
||||
self.stop_event.set()
|
||||
|
||||
def on_before_start(self, func: LIFESPAN_FUNC):
|
||||
def on_before_start(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
"""
|
||||
注册启动前的函数
|
||||
Args:
|
||||
func:
|
||||
|
||||
func ([`LIFESPAN_FUNC`](./lifespan#var-lifespan-func)): 生命周期函数
|
||||
Returns:
|
||||
|
||||
[`LIFESPAN_FUNC`](./lifespan#var-lifespan-func): 生命周期函数
|
||||
"""
|
||||
return self.lifespan.on_before_start(func)
|
||||
|
||||
@ -164,81 +167,128 @@ class LiteyukiBot:
|
||||
"""
|
||||
注册启动后的函数
|
||||
Args:
|
||||
func:
|
||||
|
||||
func ([`LIFESPAN_FUNC`](./lifespan#var-lifespan-func)): 生命周期函数
|
||||
Returns:
|
||||
|
||||
[`LIFESPAN_FUNC`](./lifespan#var-lifespan-func): 生命周期函数
|
||||
"""
|
||||
return self.lifespan.on_after_start(func)
|
||||
|
||||
def on_before_shutdown(self, func: LIFESPAN_FUNC):
|
||||
"""
|
||||
注册停止前的函数,为子进程停止时调用
|
||||
Args:
|
||||
func:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return self.lifespan.on_before_shutdown(func)
|
||||
|
||||
def on_after_shutdown(self, func: LIFESPAN_FUNC):
|
||||
"""
|
||||
注册停止后的函数:未实现
|
||||
Args:
|
||||
func:
|
||||
|
||||
func ([`LIFESPAN_FUNC`](./lifespan#var-lifespan-func)): 生命周期函数
|
||||
Returns:
|
||||
|
||||
[`LIFESPAN_FUNC`](./lifespan#var-lifespan-func): 生命周期函数
|
||||
"""
|
||||
return self.lifespan.on_after_shutdown(func)
|
||||
|
||||
def on_before_restart(self, func: LIFESPAN_FUNC):
|
||||
def on_before_process_shutdown(self, func: PROCESS_LIFESPAN_FUNC):
|
||||
"""
|
||||
注册重启前的函数,为子进程重启时调用
|
||||
注册进程停止前的函数,为子进程停止时调用
|
||||
Args:
|
||||
func:
|
||||
|
||||
func ([`PROCESS_LIFESPAN_FUNC`](./lifespan#var-process-lifespan-func)): 生命周期函数
|
||||
Returns:
|
||||
[`PROCESS_LIFESPAN_FUNC`](./lifespan#var-process-lifespan-func): 生命周期函数
|
||||
"""
|
||||
return self.lifespan.on_before_process_shutdown(func)
|
||||
|
||||
def on_before_process_restart(
|
||||
self, func: PROCESS_LIFESPAN_FUNC
|
||||
) -> PROCESS_LIFESPAN_FUNC:
|
||||
"""
|
||||
注册进程重启前的函数,为子进程重启时调用
|
||||
Args:
|
||||
func ([`PROCESS_LIFESPAN_FUNC`](./lifespan#var-process-lifespan-func)): 生命周期函数
|
||||
Returns:
|
||||
[`PROCESS_LIFESPAN_FUNC`](./lifespan#var-process-lifespan-func): 生命周期函数
|
||||
"""
|
||||
|
||||
return self.lifespan.on_before_restart(func)
|
||||
return self.lifespan.on_before_process_restart(func)
|
||||
|
||||
def on_after_restart(self, func: LIFESPAN_FUNC):
|
||||
"""
|
||||
注册重启后的函数:未实现
|
||||
Args:
|
||||
func:
|
||||
|
||||
func ([`LIFESPAN_FUNC`](./lifespan#var-lifespan-func)): 生命周期函数
|
||||
Returns:
|
||||
|
||||
[`LIFESPAN_FUNC`](./lifespan#var-lifespan-func): 生命周期函数
|
||||
"""
|
||||
return self.lifespan.on_after_restart(func)
|
||||
|
||||
def on_after_nonebot_init(self, func: LIFESPAN_FUNC):
|
||||
"""
|
||||
注册nonebot初始化后的函数
|
||||
Args:
|
||||
func:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return self.lifespan.on_after_nonebot_init(func)
|
||||
_BOT_INSTANCE: LiteyukiBot
|
||||
|
||||
|
||||
_BOT_INSTANCE: Optional[LiteyukiBot] = None
|
||||
|
||||
|
||||
def get_bot() -> Optional[LiteyukiBot]:
|
||||
def get_bot() -> LiteyukiBot:
|
||||
"""
|
||||
获取轻雪实例
|
||||
Returns:
|
||||
LiteyukiBot: 当前的轻雪实例
|
||||
[`LiteyukiBot`](#class-liteyukibot): 轻雪实例
|
||||
"""
|
||||
|
||||
if IS_MAIN_PROCESS:
|
||||
if _BOT_INSTANCE is None:
|
||||
raise RuntimeError("尹灵温 实例未初始化")
|
||||
return _BOT_INSTANCE
|
||||
else:
|
||||
# 从多进程上下文中获取
|
||||
pass
|
||||
raise RuntimeError("无法在子进程中获取机器人实例")
|
||||
|
||||
|
||||
def get_config(key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
获取配置
|
||||
Args:
|
||||
key ([`str`](https%3A//docs.python.org/3/library/stdtypes.html#str)): 配置键
|
||||
default ([`Any`](https%3A//docs.python.org/3/library/functions.html#any), optional): 默认值. Defaults to None.
|
||||
Returns:
|
||||
[`Any`](https%3A//docs.python.org/3/library/functions.html#any): 配置值
|
||||
"""
|
||||
return get_bot().config.get(key, default)
|
||||
|
||||
|
||||
def get_config_with_compat(
|
||||
key: str, compat_keys: tuple[str], default: Any = None
|
||||
) -> Any:
|
||||
"""
|
||||
获取配置,兼容旧版本
|
||||
Args:
|
||||
key ([`str`](https%3A//docs.python.org/3/library/stdtypes.html#str)): 配置键
|
||||
compat_keys ([`tuple`](https%3A//docs.python.org/3/library/stdtypes.html#tuple)[`str`](https%3A//docs.python.org/3/library/stdtypes.html#str)): 兼容键
|
||||
default ([`Any`](https%3A//docs.python.org/3/library/functions.html#any), optional): 默认值. Defaults to None.
|
||||
|
||||
Returns:
|
||||
[`Any`](https%3A//docs.python.org/3/library/functions.html#any): 配置值
|
||||
"""
|
||||
if key in get_bot().config:
|
||||
return get_bot().config[key]
|
||||
for compat_key in compat_keys:
|
||||
if compat_key in get_bot().config:
|
||||
logger.warning(f'配置键 "{compat_key}" 即将被 "{key}" 取代,请及时更新')
|
||||
return get_bot().config[compat_key]
|
||||
return default
|
||||
|
||||
|
||||
def print_logo():
|
||||
"""@litedoc-hide"""
|
||||
print(
|
||||
"\033[34m"
|
||||
+ r"""
|
||||
▅▅▅▅▅▅▅▅▅▅▅▅▅▅██ ▅▅▅▅▅▅▅▅▅▅▅▅▅▅██ ██ ▅▅▅▅▅▅▅▅▅▅█™
|
||||
▛ ██ ██ ▛ ██ ███ ██ ██
|
||||
██ ██ ███████████████ ██ ████████▅ ██
|
||||
███████████████ ██ ███ ██ ██
|
||||
██ ██ ▅██████████████▛ ██ ████████████
|
||||
██ ██ ███ ███
|
||||
████████████████ ██▅ ███ ██ ▅▅▅▅▅▅▅▅▅▅▅██
|
||||
███ █ ▜███████ ██ ███ ██ ██ ██ ██
|
||||
███ ███ █████▛ ██ ██ ██ ██ ██
|
||||
███ ██ ███ █ ██ ██ ██ ██ ██
|
||||
███ █████ ██████ ███ ██████████████
|
||||
商业标记 版权所有 © 2024 金羿Eilles
|
||||
机器软件 版权所有 © 2020-2024 神羽SnowyKami & 金羿Eilles\\
|
||||
会同 LiteyukiStudio & 睿乐组织
|
||||
保留所有权利
|
||||
"""
|
||||
+ "\033[0m"
|
||||
)
|
||||
|
@ -8,14 +8,19 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
@File : lifespan.py
|
||||
@Software: PyCharm
|
||||
"""
|
||||
from typing import Any, Awaitable, Callable, TypeAlias
|
||||
import asyncio
|
||||
from typing import Any, Awaitable, Callable, TypeAlias, Sequence
|
||||
|
||||
from liteyuki.log import logger
|
||||
from liteyuki.utils import is_coroutine_callable
|
||||
from liteyuki.utils import is_coroutine_callable, async_wrapper
|
||||
|
||||
SYNC_LIFESPAN_FUNC: TypeAlias = Callable[[], Any]
|
||||
ASYNC_LIFESPAN_FUNC: TypeAlias = Callable[[], Awaitable[Any]]
|
||||
LIFESPAN_FUNC: TypeAlias = SYNC_LIFESPAN_FUNC | ASYNC_LIFESPAN_FUNC
|
||||
SYNC_LIFESPAN_FUNC: TypeAlias = Callable[[], Any] # 同步生命周期函数
|
||||
ASYNC_LIFESPAN_FUNC: TypeAlias = Callable[[], Awaitable[Any]] # 异步生命周期函数
|
||||
LIFESPAN_FUNC: TypeAlias = SYNC_LIFESPAN_FUNC | ASYNC_LIFESPAN_FUNC # 生命周期函数
|
||||
|
||||
SYNC_PROCESS_LIFESPAN_FUNC: TypeAlias = Callable[[str], Any] # 同步进程生命周期函数
|
||||
ASYNC_PROCESS_LIFESPAN_FUNC: TypeAlias = Callable[[str], Awaitable[Any]] # 异步进程生命周期函数
|
||||
PROCESS_LIFESPAN_FUNC: TypeAlias = SYNC_PROCESS_LIFESPAN_FUNC | ASYNC_PROCESS_LIFESPAN_FUNC # 进程函数
|
||||
|
||||
|
||||
class Lifespan:
|
||||
@ -23,41 +28,35 @@ class Lifespan:
|
||||
"""
|
||||
轻雪生命周期管理,启动、停止、重启
|
||||
"""
|
||||
|
||||
self.life_flag: int = 0 # 0: 启动前,1: 启动后,2: 停止前,3: 停止后
|
||||
self.life_flag: int = 0
|
||||
|
||||
self._before_start_funcs: list[LIFESPAN_FUNC] = []
|
||||
self._after_start_funcs: list[LIFESPAN_FUNC] = []
|
||||
|
||||
self._before_shutdown_funcs: list[LIFESPAN_FUNC] = []
|
||||
self._before_process_shutdown_funcs: list[PROCESS_LIFESPAN_FUNC] = []
|
||||
self._after_shutdown_funcs: list[LIFESPAN_FUNC] = []
|
||||
|
||||
self._before_restart_funcs: list[LIFESPAN_FUNC] = []
|
||||
self._before_process_restart_funcs: list[PROCESS_LIFESPAN_FUNC] = []
|
||||
self._after_restart_funcs: list[LIFESPAN_FUNC] = []
|
||||
|
||||
self._after_nonebot_init_funcs: list[LIFESPAN_FUNC] = []
|
||||
|
||||
@staticmethod
|
||||
async def _run_funcs(funcs: list[LIFESPAN_FUNC]) -> None:
|
||||
async def run_funcs(funcs: Sequence[LIFESPAN_FUNC | PROCESS_LIFESPAN_FUNC], *args, **kwargs) -> None:
|
||||
"""
|
||||
运行函数
|
||||
并发运行异步函数
|
||||
Args:
|
||||
funcs:
|
||||
funcs ([`Sequence`](https%3A//docs.python.org/3/library/typing.html#typing.Sequence)[[`ASYNC_LIFESPAN_FUNC`](#var-lifespan-func) | [`PROCESS_LIFESPAN_FUNC`](#var-process-lifespan-func)]): 函数列表
|
||||
Returns:
|
||||
"""
|
||||
for func in funcs:
|
||||
if is_coroutine_callable(func):
|
||||
await func()
|
||||
else:
|
||||
func()
|
||||
tasks = [func(*args, **kwargs) if is_coroutine_callable(func) else async_wrapper(func)(*args, **kwargs) for func in funcs]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def on_before_start(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
"""
|
||||
注册启动时的函数
|
||||
Args:
|
||||
func:
|
||||
func ([`LIFESPAN_FUNC`](#var-lifespan-func)): 生命周期函数
|
||||
Returns:
|
||||
LIFESPAN_FUNC:
|
||||
[`LIFESPAN_FUNC`](#var-lifespan-func): 生命周期函数
|
||||
"""
|
||||
self._before_start_funcs.append(func)
|
||||
return func
|
||||
@ -66,124 +65,95 @@ class Lifespan:
|
||||
"""
|
||||
注册启动时的函数
|
||||
Args:
|
||||
func:
|
||||
func ([`LIFESPAN_FUNC`](#var-lifespan-func)): 生命周期函数
|
||||
Returns:
|
||||
LIFESPAN_FUNC:
|
||||
[`LIFESPAN_FUNC`](#var-lifespan-func): 生命周期函数
|
||||
"""
|
||||
self._after_start_funcs.append(func)
|
||||
return func
|
||||
|
||||
def on_before_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
def on_before_process_shutdown(self, func: PROCESS_LIFESPAN_FUNC) -> PROCESS_LIFESPAN_FUNC:
|
||||
"""
|
||||
注册停止前的函数
|
||||
注册进程停止前的函数
|
||||
Args:
|
||||
func:
|
||||
func ([`PROCESS_LIFESPAN_FUNC`](#var-process-lifespan-func)): 进程生命周期函数
|
||||
Returns:
|
||||
LIFESPAN_FUNC:
|
||||
[`PROCESS_LIFESPAN_FUNC`](#var-process-lifespan-func): 进程生命周期函数
|
||||
"""
|
||||
self._before_shutdown_funcs.append(func)
|
||||
self._before_process_shutdown_funcs.append(func)
|
||||
return func
|
||||
|
||||
def on_after_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
"""
|
||||
注册停止后的函数
|
||||
Args:
|
||||
func:
|
||||
|
||||
func ([`LIFESPAN_FUNC`](#var-lifespan-func)): 生命周期函数
|
||||
Returns:
|
||||
LIFESPAN_FUNC:
|
||||
|
||||
[`LIFESPAN_FUNC`](#var-lifespan-func): 生命周期函数
|
||||
"""
|
||||
self._after_shutdown_funcs.append(func)
|
||||
return func
|
||||
|
||||
def on_before_restart(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
def on_before_process_restart(self, func: PROCESS_LIFESPAN_FUNC) -> PROCESS_LIFESPAN_FUNC:
|
||||
"""
|
||||
注册重启时的函数
|
||||
注册进程重启前的函数
|
||||
Args:
|
||||
func:
|
||||
func ([`PROCESS_LIFESPAN_FUNC`](#var-process-lifespan-func)): 进程生命周期函数
|
||||
Returns:
|
||||
LIFESPAN_FUNC:
|
||||
[`PROCESS_LIFESPAN_FUNC`](#var-process-lifespan-func): 进程生命周期函数
|
||||
"""
|
||||
self._before_restart_funcs.append(func)
|
||||
self._before_process_restart_funcs.append(func)
|
||||
return func
|
||||
|
||||
def on_after_restart(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
"""
|
||||
注册重启后的函数
|
||||
Args:
|
||||
func:
|
||||
func ([`LIFESPAN_FUNC`](#var-lifespan-func)): 生命周期函数
|
||||
Returns:
|
||||
LIFESPAN_FUNC:
|
||||
[`LIFESPAN_FUNC`](#var-lifespan-func): 生命周期函数
|
||||
"""
|
||||
self._after_restart_funcs.append(func)
|
||||
return func
|
||||
|
||||
def on_after_nonebot_init(self, func):
|
||||
"""
|
||||
注册 NoneBot 初始化后的函数
|
||||
Args:
|
||||
func:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
self._after_nonebot_init_funcs.append(func)
|
||||
return func
|
||||
|
||||
async def before_start(self) -> None:
|
||||
"""
|
||||
启动前
|
||||
Returns:
|
||||
启动前钩子
|
||||
"""
|
||||
logger.debug("正在运行 before_start 之函数")
|
||||
await self._run_funcs(self._before_start_funcs)
|
||||
logger.debug("运行 before_start 函数")
|
||||
await self.run_funcs(self._before_start_funcs)
|
||||
|
||||
async def after_start(self) -> None:
|
||||
"""
|
||||
启动后
|
||||
Returns:
|
||||
启动后钩子
|
||||
"""
|
||||
logger.debug("正在运行 after_start 之函数")
|
||||
await self._run_funcs(self._after_start_funcs)
|
||||
logger.debug("运行 after_start 函数")
|
||||
await self.run_funcs(self._after_start_funcs)
|
||||
|
||||
async def before_shutdown(self) -> None:
|
||||
async def before_process_shutdown(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
停止前
|
||||
Returns:
|
||||
停止前钩子
|
||||
"""
|
||||
logger.debug("正在运行 before_shutdown 之函数")
|
||||
await self._run_funcs(self._before_shutdown_funcs)
|
||||
logger.debug("运行 before_shutdown 函数")
|
||||
await self.run_funcs(self._before_process_shutdown_funcs, *args, **kwargs)
|
||||
|
||||
async def after_shutdown(self) -> None:
|
||||
"""
|
||||
停止后
|
||||
Returns:
|
||||
停止后钩子 未实现
|
||||
"""
|
||||
logger.debug("正在运行 after_shutdown 之函数")
|
||||
await self._run_funcs(self._after_shutdown_funcs)
|
||||
logger.debug("运行 after_shutdown 函数")
|
||||
await self.run_funcs(self._after_shutdown_funcs)
|
||||
|
||||
async def before_restart(self) -> None:
|
||||
async def before_process_restart(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
重启前
|
||||
Returns:
|
||||
重启前钩子
|
||||
"""
|
||||
logger.debug("正在运行 before_restart 之函数")
|
||||
await self._run_funcs(self._before_restart_funcs)
|
||||
logger.debug("运行 before_restart 函数")
|
||||
await self.run_funcs(self._before_process_restart_funcs, *args, **kwargs)
|
||||
|
||||
async def after_restart(self) -> None:
|
||||
"""
|
||||
重启后
|
||||
Returns:
|
||||
|
||||
重启后钩子 未实现
|
||||
"""
|
||||
logger.debug("正在运行 after_restart 之函数")
|
||||
await self._run_funcs(self._after_restart_funcs)
|
||||
|
||||
async def after_nonebot_init(self) -> None:
|
||||
"""
|
||||
NoneBot 初始化后
|
||||
Returns:
|
||||
"""
|
||||
logger.debug("正在运行 after_nonebot_init 之函数")
|
||||
await self._run_funcs(self._after_nonebot_init_funcs)
|
||||
logger.debug("运行 after_restart 函数")
|
||||
await self.run_funcs(self._after_restart_funcs)
|
||||
|
@ -1,30 +1,38 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
|
||||
@Time : 2024/7/26 下午10:36
|
||||
@Author : snowykami
|
||||
@Email : snowykami@outlook.com
|
||||
@File : __init__.py
|
||||
@Software: PyCharm
|
||||
该模块用于轻雪主进程和Nonebot子进程之间的通信
|
||||
依赖关系
|
||||
event -> _
|
||||
storage -> channel_
|
||||
rpc -> channel_, storage
|
||||
"""
|
||||
from liteyuki.comm.channel import (
|
||||
Channel,
|
||||
chan,
|
||||
get_channel,
|
||||
set_channel,
|
||||
set_channels,
|
||||
get_channels
|
||||
get_channels,
|
||||
active_channel,
|
||||
passive_channel
|
||||
)
|
||||
from liteyuki.comm.event import Event
|
||||
|
||||
__all__ = [
|
||||
"Channel",
|
||||
"chan",
|
||||
"Event",
|
||||
"get_channel",
|
||||
"set_channel",
|
||||
"set_channels",
|
||||
"get_channels"
|
||||
"get_channels",
|
||||
"active_channel",
|
||||
"passive_channel"
|
||||
]
|
||||
|
||||
from liteyuki.utils import IS_MAIN_PROCESS
|
||||
|
||||
# 第一次引用必定为赋值
|
||||
_ref_count = 0
|
||||
if not IS_MAIN_PROCESS:
|
||||
if (active_channel is None or passive_channel is None) and _ref_count > 0:
|
||||
raise RuntimeError("无法在子进程中初始化 Channel")
|
||||
_ref_count += 1
|
||||
|
@ -1,219 +1,305 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
|
||||
@Time : 2024/7/26 下午11:21
|
||||
@Author : snowykami
|
||||
@Email : snowykami@outlook.com
|
||||
@File : channel.py
|
||||
@Software: PyCharm
|
||||
|
||||
本模块定义了一个通用的通道类,用于进程间通信
|
||||
"""
|
||||
import functools
|
||||
import multiprocessing
|
||||
import threading
|
||||
import asyncio
|
||||
from multiprocessing import Pipe
|
||||
from typing import Any, Optional, Callable, Awaitable, List, TypeAlias
|
||||
from uuid import uuid4
|
||||
from typing import Any, Callable, Coroutine, Generic, Optional, TypeAlias, TypeVar, get_args
|
||||
|
||||
from liteyuki.utils import is_coroutine_callable, run_coroutine
|
||||
from liteyuki.log import logger
|
||||
from liteyuki.utils import IS_MAIN_PROCESS, is_coroutine_callable
|
||||
|
||||
SYNC_ON_RECEIVE_FUNC: TypeAlias = Callable[[Any], Any]
|
||||
ASYNC_ON_RECEIVE_FUNC: TypeAlias = Callable[[Any], Awaitable[Any]]
|
||||
ON_RECEIVE_FUNC: TypeAlias = SYNC_ON_RECEIVE_FUNC | ASYNC_ON_RECEIVE_FUNC
|
||||
T = TypeVar("T")
|
||||
|
||||
SYNC_FILTER_FUNC: TypeAlias = Callable[[Any], bool]
|
||||
ASYNC_FILTER_FUNC: TypeAlias = Callable[[Any], Awaitable[bool]]
|
||||
FILTER_FUNC: TypeAlias = SYNC_FILTER_FUNC | ASYNC_FILTER_FUNC
|
||||
SYNC_ON_RECEIVE_FUNC: TypeAlias = Callable[[T], Any] # 同步接收函数
|
||||
ASYNC_ON_RECEIVE_FUNC: TypeAlias = Callable[[T], Coroutine[Any, Any, Any]] # 异步接收函数
|
||||
ON_RECEIVE_FUNC: TypeAlias = SYNC_ON_RECEIVE_FUNC | ASYNC_ON_RECEIVE_FUNC # 接收函数
|
||||
|
||||
IS_MAIN_PROCESS = multiprocessing.current_process().name == "MainProcess"
|
||||
SYNC_FILTER_FUNC: TypeAlias = Callable[[T], bool] # 同步过滤函数
|
||||
ASYNC_FILTER_FUNC: TypeAlias = Callable[[T], Coroutine[Any, Any, bool]] # 异步过滤函数
|
||||
FILTER_FUNC: TypeAlias = SYNC_FILTER_FUNC | ASYNC_FILTER_FUNC # 过滤函数
|
||||
|
||||
_func_id: int = 0
|
||||
_channel: dict[str, "Channel"] = {}
|
||||
_callback_funcs: dict[str, ON_RECEIVE_FUNC] = {}
|
||||
_callback_funcs: dict[int, ON_RECEIVE_FUNC] = {}
|
||||
|
||||
|
||||
class Channel:
|
||||
class Channel(Generic[T]):
|
||||
"""
|
||||
通道类,用于进程间通信,进程内不可用,仅限主进程和子进程之间通信
|
||||
通道类,可以在进程间和进程内通信,双向但同时只能有一个发送者和一个接收者
|
||||
有两种接收工作方式,但是只能选择一种,主动接收和被动接收,主动接收使用 `receive` 方法,被动接收使用 `on_receive` 装饰器
|
||||
"""
|
||||
|
||||
def __init__(self, _id: str):
|
||||
self.main_send_conn, self.sub_receive_conn = Pipe()
|
||||
self.sub_send_conn, self.main_receive_conn = Pipe()
|
||||
self._closed = False
|
||||
self._on_main_receive_funcs: list[str] = []
|
||||
self._on_sub_receive_funcs: list[str] = []
|
||||
self.name: str = _id
|
||||
def __init__(self, name: str, type_check: Optional[bool] = None):
|
||||
"""
|
||||
初始化通道
|
||||
Args:
|
||||
name: 通道ID
|
||||
type_check: 是否开启类型检查, 若为空,则传入泛型默认开启,否则默认关闭
|
||||
"""
|
||||
|
||||
self.is_main_receive_loop_running = False
|
||||
self.is_sub_receive_loop_running = False
|
||||
self.conn_send, self.conn_recv = Pipe()
|
||||
self._conn_send_inner, self._conn_recv_inner = Pipe() # 内部通道,用于子进程通信
|
||||
self._closed = False
|
||||
self._on_main_receive_func_ids: list[int] = []
|
||||
self._on_sub_receive_func_ids: list[int] = []
|
||||
self.name: str = name
|
||||
|
||||
self.is_receive_loop_running = False
|
||||
|
||||
if type_check is None:
|
||||
# 若传入泛型则默认开启类型检查
|
||||
type_check = self._get_generic_type() is not None
|
||||
|
||||
elif type_check:
|
||||
if self._get_generic_type() is None:
|
||||
raise TypeError("Type hint 是强制类型检查之所必须")
|
||||
self.type_check = type_check
|
||||
if name in _channel:
|
||||
raise ValueError(f"Channel {name} 已存在")
|
||||
|
||||
if IS_MAIN_PROCESS:
|
||||
if name in _channel:
|
||||
raise ValueError(f"Channel {name} 已存在")
|
||||
_channel[name] = self
|
||||
logger.debug(f"Channel {name} 已在主进程中初始化")
|
||||
else:
|
||||
logger.debug(f"Channel {name} 已初始化于子进程中,之后应于主进程中手动设置为妙")
|
||||
|
||||
def _get_generic_type(self) -> Optional[type]:
|
||||
"""
|
||||
获取通道传递泛型类型
|
||||
Returns:
|
||||
Optional[type]: 泛型类型
|
||||
"""
|
||||
if hasattr(self, '__orig_class__'):
|
||||
return get_args(self.__orig_class__)[0]
|
||||
return None
|
||||
|
||||
def _validate_structure(self, data: Any, structure: type) -> bool:
|
||||
"""
|
||||
验证数据结构
|
||||
Args:
|
||||
data: 数据
|
||||
structure: 结构
|
||||
Returns:
|
||||
bool: 是否通过验证
|
||||
"""
|
||||
if isinstance(structure, type):
|
||||
return isinstance(data, structure)
|
||||
elif isinstance(structure, tuple):
|
||||
if not isinstance(data, tuple) or len(data) != len(structure):
|
||||
return False
|
||||
return all(self._validate_structure(d, s) for d, s in zip(data, structure))
|
||||
elif isinstance(structure, list):
|
||||
if not isinstance(data, list):
|
||||
return False
|
||||
return all(self._validate_structure(d, structure[0]) for d in data)
|
||||
elif isinstance(structure, dict):
|
||||
if not isinstance(data, dict):
|
||||
return False
|
||||
return all(k in data and self._validate_structure(data[k], structure[k]) for k in structure)
|
||||
return False
|
||||
|
||||
def __str__(self):
|
||||
return f"Channel({self.name})"
|
||||
|
||||
def send(self, data: Any):
|
||||
def send(self, data: T):
|
||||
"""
|
||||
发送数据
|
||||
发送数据,发送函数为同步函数,没有异步的必要
|
||||
Args:
|
||||
data: 数据
|
||||
data (T): 数据
|
||||
"""
|
||||
if self._closed:
|
||||
raise RuntimeError("无法发送至已关闭的通道中")
|
||||
if IS_MAIN_PROCESS:
|
||||
print("主进程发送数据:", data)
|
||||
self.main_send_conn.send(data)
|
||||
else:
|
||||
print("子进程发送数据:", data)
|
||||
self.sub_send_conn.send(data)
|
||||
if self.type_check:
|
||||
_type = self._get_generic_type()
|
||||
if _type is not None and not self._validate_structure(data, _type):
|
||||
raise TypeError(f"该数据必须为 {_type} 实例,而非 {type(data)}")
|
||||
|
||||
def receive(self) -> Any:
|
||||
if self._closed:
|
||||
raise RuntimeError("数据无法向已关闭的 Channel 中发送")
|
||||
self.conn_send.send(data)
|
||||
|
||||
def receive(self) -> T:
|
||||
"""
|
||||
接收数据
|
||||
Args:
|
||||
同步接收数据,会阻塞线程
|
||||
Returns:
|
||||
T: 数据
|
||||
"""
|
||||
if self._closed:
|
||||
raise RuntimeError("无法从已关闭的通道中接收")
|
||||
raise RuntimeError("无法在已关闭的 Channel 中接取数据")
|
||||
|
||||
while True:
|
||||
# 判断receiver是否为None或者receiver是否等于接收者,是则接收数据,否则不动数据
|
||||
if IS_MAIN_PROCESS:
|
||||
data = self.main_receive_conn.recv()
|
||||
print("主进程接收数据:", data)
|
||||
else:
|
||||
data = self.sub_receive_conn.recv()
|
||||
print("子进程接收数据:", data)
|
||||
|
||||
data = self.conn_recv.recv()
|
||||
return data
|
||||
|
||||
def close(self):
|
||||
async def async_receive(self) -> T:
|
||||
"""
|
||||
关闭通道
|
||||
异步接收数据,会挂起等待
|
||||
Returns:
|
||||
T: 数据
|
||||
"""
|
||||
self._closed = True
|
||||
self.sub_receive_conn.close()
|
||||
self.main_send_conn.close()
|
||||
self.sub_send_conn.close()
|
||||
self.main_receive_conn.close()
|
||||
loop = asyncio.get_running_loop()
|
||||
data = await loop.run_in_executor(None, self.receive)
|
||||
return data
|
||||
|
||||
def on_receive(self, filter_func: Optional[FILTER_FUNC] = None) -> Callable[[ON_RECEIVE_FUNC], ON_RECEIVE_FUNC]:
|
||||
def on_receive(self, filter_func: Optional[FILTER_FUNC] = None) -> Callable[[Callable[[T], Any]], Callable[[T], Any]]:
|
||||
"""
|
||||
接收数据并执行函数
|
||||
Args:
|
||||
filter_func: 过滤函数,为None则不过滤
|
||||
filter_func ([`Optional`](https%3A//docs.python.org/3/library/typing.html#typing.Optional)[[`FILTER_FUNC`](#var-FILTER_FUNC)], optional): 过滤函数. Defaults to None.
|
||||
Returns:
|
||||
装饰器,装饰一个函数在接收到数据后执行
|
||||
Callable[[Callable[[T], Any]], Callable[[T], Any]]: 装饰器
|
||||
"""
|
||||
if (not self.is_sub_receive_loop_running) and not IS_MAIN_PROCESS:
|
||||
threading.Thread(target=self._start_sub_receive_loop).start()
|
||||
if not IS_MAIN_PROCESS:
|
||||
raise RuntimeError("on_receive 仅可用于主进程内")
|
||||
|
||||
if (not self.is_main_receive_loop_running) and IS_MAIN_PROCESS:
|
||||
threading.Thread(target=self._start_main_receive_loop).start()
|
||||
def decorator(func: Callable[[T], Any]) -> Callable[[T], Any]:
|
||||
global _func_id
|
||||
|
||||
def decorator(func: ON_RECEIVE_FUNC) -> ON_RECEIVE_FUNC:
|
||||
async def wrapper(data: Any) -> Any:
|
||||
async def wrapper(data: T) -> Any:
|
||||
if filter_func is not None:
|
||||
if is_coroutine_callable(filter_func):
|
||||
if not await filter_func(data):
|
||||
if not (await filter_func(data)): # type: ignore
|
||||
return
|
||||
else:
|
||||
if not filter_func(data):
|
||||
return
|
||||
return await func(data)
|
||||
|
||||
function_id = str(uuid4())
|
||||
_callback_funcs[function_id] = wrapper
|
||||
if is_coroutine_callable(func):
|
||||
return await func(data)
|
||||
else:
|
||||
return func(data)
|
||||
|
||||
_callback_funcs[_func_id] = wrapper
|
||||
if IS_MAIN_PROCESS:
|
||||
self._on_main_receive_funcs.append(function_id)
|
||||
self._on_main_receive_func_ids.append(_func_id)
|
||||
else:
|
||||
self._on_sub_receive_funcs.append(function_id)
|
||||
self._on_sub_receive_func_ids.append(_func_id)
|
||||
_func_id += 1
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def _run_on_main_receive_funcs(self, data: Any):
|
||||
async def _run_on_receive_funcs(self, data: Any):
|
||||
"""
|
||||
运行接收函数
|
||||
Args:
|
||||
data: 数据
|
||||
"""
|
||||
for func_id in self._on_main_receive_funcs:
|
||||
func = _callback_funcs[func_id]
|
||||
run_coroutine(func(data))
|
||||
|
||||
def _run_on_sub_receive_funcs(self, data: Any):
|
||||
"""
|
||||
运行接收函数
|
||||
Args:
|
||||
data: 数据
|
||||
"""
|
||||
for func_id in self._on_sub_receive_funcs:
|
||||
func = _callback_funcs[func_id]
|
||||
run_coroutine(func(data))
|
||||
|
||||
def _start_main_receive_loop(self):
|
||||
"""
|
||||
开始接收数据
|
||||
"""
|
||||
self.is_main_receive_loop_running = True
|
||||
while not self._closed:
|
||||
data = self.main_receive_conn.recv()
|
||||
self._run_on_main_receive_funcs(data)
|
||||
|
||||
def _start_sub_receive_loop(self):
|
||||
"""
|
||||
开始接收数据
|
||||
"""
|
||||
self.is_sub_receive_loop_running = True
|
||||
while not self._closed:
|
||||
data = self.sub_receive_conn.recv()
|
||||
self._run_on_sub_receive_funcs(data)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self) -> Any:
|
||||
return self.receive()
|
||||
if IS_MAIN_PROCESS:
|
||||
[asyncio.create_task(_callback_funcs[func_id](data)) for func_id in self._on_main_receive_func_ids]
|
||||
else:
|
||||
[asyncio.create_task(_callback_funcs[func_id](data)) for func_id in self._on_sub_receive_func_ids]
|
||||
|
||||
|
||||
"""默认通道实例,可直接从模块导入使用"""
|
||||
chan = Channel("default")
|
||||
"""子进程可用的主动和被动通道"""
|
||||
active_channel: Channel = Channel(name="active_channel") # 主动通道
|
||||
passive_channel: Channel = Channel(name="passive_channel") # 被动通道
|
||||
publish_channel: Channel[tuple[str, dict[str, Any]]] = Channel(name="publish_channel") # 发布通道
|
||||
"""通道传递通道,主进程创建单例,子进程初始化时实例化"""
|
||||
channel_deliver_active_channel: Channel[Channel[Any]] # 主动通道传递通道
|
||||
channel_deliver_passive_channel: Channel[tuple[str, dict[str, Any]]] # 被动通道传递通道
|
||||
|
||||
if IS_MAIN_PROCESS:
|
||||
channel_deliver_active_channel = Channel(name="channel_deliver_active_channel") # 主动通道传递通道
|
||||
channel_deliver_passive_channel = Channel(name="channel_deliver_passive_channel") # 被动通道传递通道
|
||||
|
||||
|
||||
def set_channel(name: str, channel: Channel):
|
||||
@channel_deliver_passive_channel.on_receive(filter_func=lambda data: data[0] == "set_channel")
|
||||
def on_set_channel(data: tuple[str, dict[str, Any]]):
|
||||
name, channel = data[1]["name"], data[1]["channel_"]
|
||||
set_channel(name, channel)
|
||||
|
||||
|
||||
@channel_deliver_passive_channel.on_receive(filter_func=lambda data: data[0] == "get_channel")
|
||||
def on_get_channel(data: tuple[str, dict[str, Any]]):
|
||||
name, recv_chan = data[1]["name"], data[1]["recv_chan"]
|
||||
recv_chan.send(get_channel(name))
|
||||
|
||||
|
||||
@channel_deliver_passive_channel.on_receive(filter_func=lambda data: data[0] == "get_channels")
|
||||
def on_get_channels(data: tuple[str, dict[str, Any]]):
|
||||
recv_chan = data[1]["recv_chan"]
|
||||
recv_chan.send(get_channels())
|
||||
|
||||
|
||||
def set_channel(name: str, channel: "Channel"):
|
||||
"""
|
||||
设置通道实例
|
||||
Args:
|
||||
name: 通道名称
|
||||
channel: 通道实例
|
||||
name ([`str`](https%3A//docs.python.org/3/library/stdtypes.html#str)): 通道名称
|
||||
channel ([`Channel`](#class-channel-generic-t)): 通道实例
|
||||
"""
|
||||
_channel[name] = channel
|
||||
if not isinstance(channel, Channel):
|
||||
raise TypeError(f"channel_ 必须为 Channel 实例,而非 {type(channel)}")
|
||||
|
||||
if IS_MAIN_PROCESS:
|
||||
if name in _channel:
|
||||
raise ValueError(f"Channel {name} 已存在")
|
||||
_channel[name] = channel
|
||||
else:
|
||||
# 请求主进程设置通道
|
||||
channel_deliver_passive_channel.send(
|
||||
(
|
||||
"set_channel", {
|
||||
"name" : name,
|
||||
"channel_": channel,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def set_channels(channels: dict[str, Channel]):
|
||||
def set_channels(channels: dict[str, "Channel"]):
|
||||
"""
|
||||
设置通道实例
|
||||
Args:
|
||||
channels: 通道名称
|
||||
channels ([`dict`](https%3A//docs.python.org/3/library/stdtypes.html#dict)[[`str`](https%3A//docs.python.org/3/library/stdtypes.html#str), [`Channel`](#class-channel-generic-t)]): 通道实例
|
||||
"""
|
||||
for name, channel in channels.items():
|
||||
_channel[name] = channel
|
||||
set_channel(name, channel)
|
||||
|
||||
|
||||
def get_channel(name: str) -> Optional[Channel]:
|
||||
def get_channel(name: str) -> "Channel":
|
||||
"""
|
||||
获取通道实例
|
||||
Args:
|
||||
name: 通道名称
|
||||
name ([`str`](https%3A//docs.python.org/3/library/stdtypes.html#str)): 通道名称
|
||||
Returns:
|
||||
[`Channel`](#class-channel-generic-t): 通道实例
|
||||
"""
|
||||
return _channel.get(name, None)
|
||||
if IS_MAIN_PROCESS:
|
||||
return _channel[name]
|
||||
|
||||
else:
|
||||
recv_chan = Channel[Channel[Any]]("recv_chan")
|
||||
channel_deliver_passive_channel.send(
|
||||
(
|
||||
"get_channel",
|
||||
{
|
||||
"name" : name,
|
||||
"recv_chan": recv_chan
|
||||
}
|
||||
)
|
||||
)
|
||||
return recv_chan.receive()
|
||||
|
||||
|
||||
def get_channels() -> dict[str, Channel]:
|
||||
def get_channels() -> dict[str, "Channel"]:
|
||||
"""
|
||||
获取通道实例
|
||||
获取通道实例们
|
||||
Returns:
|
||||
[`dict`](https%3A//docs.python.org/3/library/stdtypes.html#dict)[[`str`](https%3A//docs.python.org/3/library/stdtypes.html#str), [`Channel`](#class-channel-generic-t)]: 通道实例
|
||||
"""
|
||||
return _channel
|
||||
if IS_MAIN_PROCESS:
|
||||
return _channel
|
||||
else:
|
||||
recv_chan = Channel[dict[str, Channel[Any]]]("recv_chan")
|
||||
channel_deliver_passive_channel.send(
|
||||
(
|
||||
"get_channels",
|
||||
{
|
||||
"recv_chan": recv_chan
|
||||
}
|
||||
)
|
||||
)
|
||||
return recv_chan.receive()
|
||||
|
@ -1,12 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
|
||||
@Time : 2024/7/26 下午10:47
|
||||
@Author : snowykami
|
||||
@Email : snowykami@outlook.com
|
||||
@File : event.py
|
||||
@Software: PyCharm
|
||||
本模块用于轻雪主进程和子进程之间的通信的事件类
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
|
26
liteyuki/comm/rpc.py
Normal file
26
liteyuki/comm/rpc.py
Normal file
@ -0,0 +1,26 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
本模块用于实现RPC(基于IPC)通信
|
||||
"""
|
||||
|
||||
from typing import TypeAlias, Callable, Any
|
||||
|
||||
from liteyuki.comm.channel import Channel
|
||||
|
||||
ON_CALLING_FUNC: TypeAlias = Callable[[tuple, dict], Any]
|
||||
|
||||
|
||||
class RPC:
|
||||
"""
|
||||
RPC类
|
||||
"""
|
||||
|
||||
def __init__(self, on_calling: ON_CALLING_FUNC) -> None:
|
||||
self.on_calling = on_calling
|
||||
|
||||
def call(self, args: tuple, kwargs: dict) -> Any:
|
||||
"""
|
||||
调用
|
||||
"""
|
||||
# 获取self.calling函数名
|
||||
return self.on_calling(args, kwargs)
|
48
liteyuki/comm/socks_channel.py
Normal file
48
liteyuki/comm/socks_channel.py
Normal file
@ -0,0 +1,48 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
基于socket的通道
|
||||
"""
|
||||
|
||||
|
||||
class SocksChannel:
|
||||
"""
|
||||
通道类,可以在进程间和进程内通信,双向但同时只能有一个发送者和一个接收者
|
||||
有两种接收工作方式,但是只能选择一种,主动接收和被动接收,主动接收使用 `receive` 方法,被动接收使用 `on_receive` 装饰器
|
||||
"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
"""
|
||||
初始化通道
|
||||
Args:
|
||||
name: 通道ID
|
||||
"""
|
||||
|
||||
self._name = name
|
||||
self._conn_send = None
|
||||
self._conn_recv = None
|
||||
self._closed = False
|
||||
|
||||
def send(self, data):
|
||||
"""
|
||||
发送数据
|
||||
Args:
|
||||
data: 数据
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def receive(self):
|
||||
"""
|
||||
接收数据
|
||||
Returns:
|
||||
data: 数据
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
关闭通道
|
||||
"""
|
||||
|
||||
pass
|
247
liteyuki/comm/storage.py
Normal file
247
liteyuki/comm/storage.py
Normal file
@ -0,0 +1,247 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
共享内存模块。类似于redis,但是更加轻量级并且线程安全
|
||||
"""
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from liteyuki.comm import channel
|
||||
from liteyuki.comm.channel import ASYNC_ON_RECEIVE_FUNC, Channel, ON_RECEIVE_FUNC
|
||||
from liteyuki.utils import (
|
||||
IS_MAIN_PROCESS,
|
||||
is_coroutine_callable,
|
||||
run_coroutine_in_thread,
|
||||
)
|
||||
|
||||
if IS_MAIN_PROCESS:
|
||||
_locks = {}
|
||||
|
||||
_on_main_subscriber_receive_funcs: dict[str, list[ASYNC_ON_RECEIVE_FUNC]] = {} # type: ignore
|
||||
"""主进程订阅者接收函数"""
|
||||
_on_sub_subscriber_receive_funcs: dict[str, list[ASYNC_ON_RECEIVE_FUNC]] = {} # type: ignore
|
||||
"""子进程订阅者接收函数"""
|
||||
|
||||
|
||||
def _get_lock(key) -> threading.Lock:
|
||||
"""
|
||||
获取锁
|
||||
"""
|
||||
if IS_MAIN_PROCESS:
|
||||
if key not in _locks:
|
||||
_locks[key] = threading.Lock()
|
||||
return _locks[key]
|
||||
else:
|
||||
raise RuntimeError("无法在子进程中获取线程锁")
|
||||
|
||||
|
||||
class KeyValueStore:
|
||||
def __init__(self):
|
||||
self._store = {}
|
||||
self.active_chan = Channel[tuple[str, Optional[dict[str, Any]]]](
|
||||
name="shared_memory-active"
|
||||
)
|
||||
self.passive_chan = Channel[tuple[str, Optional[dict[str, Any]]]](
|
||||
name="shared_memory-passive"
|
||||
)
|
||||
|
||||
self.publish_channel = Channel[tuple[str, Any]](name="shared_memory-publish")
|
||||
|
||||
self.is_main_receive_loop_running = False
|
||||
self.is_sub_receive_loop_running = False
|
||||
|
||||
def set(self, key: str, value: Any) -> None:
|
||||
"""
|
||||
设置键值对
|
||||
Args:
|
||||
key: 键
|
||||
value: 值
|
||||
|
||||
"""
|
||||
if IS_MAIN_PROCESS:
|
||||
lock = _get_lock(key)
|
||||
with lock:
|
||||
self._store[key] = value
|
||||
else:
|
||||
# 向主进程发送请求拿取
|
||||
self.passive_chan.send(("set", {"key": key, "value": value}))
|
||||
|
||||
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
|
||||
"""
|
||||
获取键值对
|
||||
Args:
|
||||
key: 键
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 值
|
||||
"""
|
||||
if IS_MAIN_PROCESS:
|
||||
lock = _get_lock(key)
|
||||
with lock:
|
||||
return self._store.get(key, default)
|
||||
else:
|
||||
recv_chan = Channel[Optional[Any]]("recv_chan")
|
||||
self.passive_chan.send(
|
||||
("get", {"key": key, "default": default, "recv_chan": recv_chan})
|
||||
)
|
||||
return recv_chan.receive()
|
||||
|
||||
def delete(self, key: str, ignore_key_error: bool = True) -> None:
|
||||
"""
|
||||
删除键值对
|
||||
Args:
|
||||
key: 键
|
||||
ignore_key_error: 是否忽略键不存在的错误
|
||||
|
||||
Returns:
|
||||
"""
|
||||
if IS_MAIN_PROCESS:
|
||||
lock = _get_lock(key)
|
||||
with lock:
|
||||
if key in self._store:
|
||||
try:
|
||||
del self._store[key]
|
||||
del _locks[key]
|
||||
except KeyError as e:
|
||||
if not ignore_key_error:
|
||||
raise e
|
||||
else:
|
||||
# 向主进程发送请求删除
|
||||
self.passive_chan.send(("delete", {"key": key}))
|
||||
|
||||
def get_all(self) -> dict[str, Any]:
|
||||
"""
|
||||
获取所有键值对
|
||||
Returns:
|
||||
dict[str, Any]: 键值对
|
||||
"""
|
||||
if IS_MAIN_PROCESS:
|
||||
return self._store
|
||||
else:
|
||||
recv_chan = Channel[dict[str, Any]]("recv_chan")
|
||||
self.passive_chan.send(("get_all", {"recv_chan": recv_chan}))
|
||||
return recv_chan.receive()
|
||||
|
||||
def publish(self, channel_: str, data: Any) -> None:
|
||||
"""
|
||||
发布消息
|
||||
Args:
|
||||
channel_: 频道
|
||||
data: 数据
|
||||
|
||||
Returns:
|
||||
"""
|
||||
self.active_chan.send(("publish", {"channel": channel_, "data": data}))
|
||||
|
||||
def on_subscriber_receive(
|
||||
self, channel_: str
|
||||
) -> Callable[[ON_RECEIVE_FUNC], ON_RECEIVE_FUNC]:
|
||||
"""
|
||||
订阅者接收消息时的回调
|
||||
Args:
|
||||
channel_: 频道
|
||||
|
||||
Returns:
|
||||
装饰器
|
||||
"""
|
||||
if not IS_MAIN_PROCESS:
|
||||
raise RuntimeError("无法订阅一子线程消息")
|
||||
|
||||
def decorator(func: ON_RECEIVE_FUNC) -> ON_RECEIVE_FUNC:
|
||||
async def wrapper(data: Any):
|
||||
if is_coroutine_callable(func):
|
||||
await func(data)
|
||||
else:
|
||||
func(data)
|
||||
|
||||
if IS_MAIN_PROCESS:
|
||||
if channel_ not in _on_main_subscriber_receive_funcs:
|
||||
_on_main_subscriber_receive_funcs[channel_] = []
|
||||
_on_main_subscriber_receive_funcs[channel_].append(wrapper)
|
||||
else:
|
||||
if channel_ not in _on_sub_subscriber_receive_funcs:
|
||||
_on_sub_subscriber_receive_funcs[channel_] = []
|
||||
_on_sub_subscriber_receive_funcs[channel_].append(wrapper)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@staticmethod
|
||||
async def run_subscriber_receive_funcs(channel_: str, data: Any):
|
||||
"""
|
||||
运行订阅者接收函数
|
||||
Args:
|
||||
channel_: 频道
|
||||
data: 数据
|
||||
"""
|
||||
[
|
||||
asyncio.create_task(func(data))
|
||||
for func in _on_main_subscriber_receive_funcs[channel_]
|
||||
]
|
||||
|
||||
async def start_receive_loop(self):
|
||||
"""
|
||||
启动发布订阅接收器循环,在主进程中运行,若有子进程订阅则推送给子进程
|
||||
"""
|
||||
|
||||
if not IS_MAIN_PROCESS:
|
||||
raise RuntimeError("无法在子进程中启用订阅接收器循环")
|
||||
while True:
|
||||
data = await self.active_chan.async_receive()
|
||||
if data[0] == "publish":
|
||||
# 运行主进程订阅函数
|
||||
await self.run_subscriber_receive_funcs(
|
||||
data[1]["channel"], data[1]["data"]
|
||||
)
|
||||
# 推送给子进程
|
||||
self.publish_channel.send(data)
|
||||
|
||||
|
||||
class GlobalKeyValueStore:
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = KeyValueStore()
|
||||
return cls._instance
|
||||
|
||||
|
||||
shared_memory: KeyValueStore = GlobalKeyValueStore.get_instance() # 共享内存对象
|
||||
|
||||
# 全局单例访问点
|
||||
if IS_MAIN_PROCESS:
|
||||
|
||||
@shared_memory.passive_chan.on_receive(lambda d: d[0] == "get")
|
||||
def on_get(data: tuple[str, dict[str, Any]]):
|
||||
key = data[1]["key"]
|
||||
default = data[1]["default"]
|
||||
recv_chan = data[1]["recv_chan"]
|
||||
recv_chan.send(shared_memory.get(key, default))
|
||||
|
||||
@shared_memory.passive_chan.on_receive(lambda d: d[0] == "set")
|
||||
def on_set(data: tuple[str, dict[str, Any]]):
|
||||
key = data[1]["key"]
|
||||
value = data[1]["value"]
|
||||
shared_memory.set(key, value)
|
||||
|
||||
@shared_memory.passive_chan.on_receive(lambda d: d[0] == "delete")
|
||||
def on_delete(data: tuple[str, dict[str, Any]]):
|
||||
key = data[1]["key"]
|
||||
shared_memory.delete(key)
|
||||
|
||||
@shared_memory.passive_chan.on_receive(lambda d: d[0] == "get_all")
|
||||
def on_get_all(data: tuple[str, dict[str, Any]]):
|
||||
recv_chan = data[1]["recv_chan"]
|
||||
recv_chan.send(shared_memory.get_all())
|
||||
|
||||
|
||||
_ref_count = 0 # import 引用计数, 防止获取空指针
|
||||
if not IS_MAIN_PROCESS:
|
||||
if (shared_memory is None) and _ref_count > 1:
|
||||
raise RuntimeError("共享内存未初始化")
|
||||
_ref_count += 1
|
@ -1,49 +1,132 @@
|
||||
"""
|
||||
该模块用于常用配置文件的加载
|
||||
多配置文件编写原则:
|
||||
1. 尽量不要冲突: 一个键不要多次出现
|
||||
2. 分工明确: 每个配置文件给一个或一类服务提供配置
|
||||
3. 扁平化编写: 配置文件尽量扁平化,不要出现过多的嵌套
|
||||
4. 注意冲突时的优先级: 项目目录下的配置文件优先级高于config目录下的配置文件
|
||||
5. 请不要将需要动态加载的内容写入配置文件,你应该使用其他储存方式
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import nonebot
|
||||
import json
|
||||
import copy
|
||||
import toml
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from typing import Any
|
||||
|
||||
from liteyuki.log import logger
|
||||
|
||||
_SUPPORTED_CONFIG_FORMATS = (".yaml", ".yml", ".json", ".toml")
|
||||
|
||||
|
||||
config = {} # 主进程全局配置,确保加载后读取
|
||||
def flat_config(config: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
扁平化配置文件
|
||||
|
||||
{a:{b:{c:1}}} -> {"a.b.c": 1}
|
||||
Args:
|
||||
config: 配置项目
|
||||
|
||||
Returns:
|
||||
扁平化后的配置文件,但也包含原有的键值对
|
||||
"""
|
||||
new_config = copy.deepcopy(config)
|
||||
for key, value in config.items():
|
||||
if isinstance(value, dict):
|
||||
for k, v in flat_config(value).items():
|
||||
new_config[f"{key}.{k}"] = v
|
||||
return new_config
|
||||
|
||||
|
||||
class SatoriNodeConfig(BaseModel):
|
||||
host: str = ""
|
||||
port: str = "5500"
|
||||
path: str = ""
|
||||
token: str = ""
|
||||
def load_from_yaml(file_: str) -> dict[str, Any]:
|
||||
"""
|
||||
Load config from yaml file
|
||||
|
||||
"""
|
||||
logger.debug("正在从 {} 中加载 YAML 配置".format(file_))
|
||||
config = yaml.safe_load(open(file_, "r", encoding="utf-8"))
|
||||
return flat_config(config if config is not None else {})
|
||||
|
||||
|
||||
class SatoriConfig(BaseModel):
|
||||
comment: str = "此皆正处于开发之中,切勿在生产环境中启用。"
|
||||
enable: bool = False
|
||||
hosts: List[SatoriNodeConfig] = [SatoriNodeConfig()]
|
||||
def load_from_json(file_: str) -> dict[str, Any]:
|
||||
"""
|
||||
Load config from json file
|
||||
"""
|
||||
logger.debug("正在从 {} 中加载 JSON 配置".format(file_))
|
||||
config = json.load(open(file_, "r", encoding="utf-8"))
|
||||
return flat_config(config if config is not None else {})
|
||||
|
||||
|
||||
class BasicConfig(BaseModel):
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 20247
|
||||
superusers: list[str] = []
|
||||
command_start: list[str] = ["/", ""]
|
||||
nickname: list[str] = [f"灵温"]
|
||||
satori: SatoriConfig = SatoriConfig()
|
||||
data_path: str = "data/liteyuki"
|
||||
def load_from_toml(file_: str) -> dict[str, Any]:
|
||||
"""
|
||||
Load config from toml file
|
||||
"""
|
||||
logger.debug("正在从 {} 中加载 TOML 配置".format(file_))
|
||||
config = toml.load(open(file_, "r", encoding="utf-8"))
|
||||
return flat_config(config if config is not None else {})
|
||||
|
||||
|
||||
def load_from_yaml(file: str) -> dict:
|
||||
global config
|
||||
nonebot.logger.debug("Loading config from %s" % file)
|
||||
if not os.path.exists(file):
|
||||
nonebot.logger.warning(f"未找到配置文件 {file} ,已创建默认配置,请修改后重启。")
|
||||
with open(file, "w", encoding="utf-8") as f:
|
||||
yaml.dump(BasicConfig().dict(), f, default_flow_style=False)
|
||||
def load_from_files(*files: str, no_warning: bool = False) -> dict[str, Any]:
|
||||
"""
|
||||
从指定文件加载配置项,会自动识别文件格式
|
||||
默认执行扁平化选项
|
||||
"""
|
||||
config = {}
|
||||
for file in files:
|
||||
if os.path.exists(file):
|
||||
if file.endswith((".yaml", "yml")):
|
||||
config.update(load_from_yaml(file))
|
||||
elif file.endswith(".json"):
|
||||
config.update(load_from_json(file))
|
||||
elif file.endswith(".toml"):
|
||||
config.update(load_from_toml(file))
|
||||
else:
|
||||
if not no_warning:
|
||||
logger.warning(f"不支持配置文件 {file} 的类型")
|
||||
else:
|
||||
if not no_warning:
|
||||
logger.warning(f"配置文件 {file} 未寻得")
|
||||
return config
|
||||
|
||||
with open(file, "r", encoding="utf-8") as f:
|
||||
conf = yaml.load(f, Loader=yaml.FullLoader)
|
||||
config = conf
|
||||
if conf is None:
|
||||
nonebot.logger.warning(f"配置文件 {file} 为空,已创建默认配置,请修改后重启。")
|
||||
conf = BasicConfig().dict()
|
||||
return conf
|
||||
|
||||
def load_configs_from_dirs(
|
||||
*directories: str, no_waring: bool = False
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
从目录下加载配置文件,不递归
|
||||
按照读取文件的优先级反向覆盖
|
||||
默认执行扁平化选项
|
||||
"""
|
||||
config = {}
|
||||
for directory in directories:
|
||||
if not os.path.exists(directory):
|
||||
if not no_waring:
|
||||
logger.warning(f"目录 {directory} 未寻得")
|
||||
continue
|
||||
for file in os.listdir(directory):
|
||||
if file.endswith(_SUPPORTED_CONFIG_FORMATS):
|
||||
config.update(
|
||||
load_from_files(os.path.join(directory, file), no_warning=no_waring)
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def load_config_in_default(no_waring: bool = False) -> dict[str, Any]:
|
||||
"""
|
||||
从一个标准的轻雪项目加载配置文件
|
||||
项目目录下的config.*和config目录下的所有配置文件
|
||||
项目目录下的配置文件优先
|
||||
"""
|
||||
config = load_configs_from_dirs("config", no_waring=no_waring)
|
||||
config.update(
|
||||
load_from_files(
|
||||
"config.yaml",
|
||||
"config.toml",
|
||||
"config.json",
|
||||
"config.yml",
|
||||
no_warning=no_waring,
|
||||
)
|
||||
)
|
||||
return config
|
||||
|
@ -1,11 +1,2 @@
|
||||
import multiprocessing
|
||||
|
||||
from .spawn_process import *
|
||||
from .manager import *
|
||||
|
||||
__all__ = [
|
||||
"IS_MAIN_PROCESS"
|
||||
]
|
||||
|
||||
IS_MAIN_PROCESS = multiprocessing.current_process().name == "MainProcess"
|
||||
|
||||
|
@ -9,92 +9,175 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
@Software: PyCharm
|
||||
"""
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import threading
|
||||
from multiprocessing import Process
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Any, Callable, TYPE_CHECKING, TypeAlias
|
||||
|
||||
from liteyuki.comm import Channel, get_channel, set_channels
|
||||
from liteyuki.log import logger
|
||||
from liteyuki.utils import IS_MAIN_PROCESS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from liteyuki.bot import LiteyukiBot
|
||||
from liteyuki.bot.lifespan import Lifespan
|
||||
from liteyuki.comm.storage import KeyValueStore
|
||||
|
||||
from liteyuki.comm import Channel
|
||||
|
||||
if IS_MAIN_PROCESS:
|
||||
from liteyuki.comm.channel import get_channel, publish_channel, get_channels
|
||||
from liteyuki.comm.storage import shared_memory
|
||||
from liteyuki.comm.channel import (
|
||||
channel_deliver_active_channel,
|
||||
channel_deliver_passive_channel,
|
||||
)
|
||||
else:
|
||||
from liteyuki.comm import channel
|
||||
from liteyuki.comm import storage
|
||||
|
||||
TARGET_FUNC: TypeAlias = Callable[..., Any]
|
||||
TIMEOUT = 10
|
||||
|
||||
__all__ = ["ProcessManager"]
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
|
||||
|
||||
class ChannelDeliver:
|
||||
def __init__(
|
||||
self,
|
||||
active: Channel[Any],
|
||||
passive: Channel[Any],
|
||||
channel_deliver_active: Channel[Channel[Any]],
|
||||
channel_deliver_passive: Channel[tuple[str, dict]],
|
||||
publish: Channel[tuple[str, Any]],
|
||||
):
|
||||
self.active = active
|
||||
self.passive = passive
|
||||
self.channel_deliver_active = channel_deliver_active
|
||||
self.channel_deliver_passive = channel_deliver_passive
|
||||
self.publish = publish
|
||||
|
||||
|
||||
# 函数处理一些跨进程通道的
|
||||
def _delivery_channel_wrapper(
|
||||
func: TARGET_FUNC, cd: ChannelDeliver, sm: "KeyValueStore", *args, **kwargs
|
||||
):
|
||||
"""
|
||||
子进程入口函数
|
||||
处理一些操作
|
||||
"""
|
||||
# 给子进程设置通道
|
||||
if IS_MAIN_PROCESS:
|
||||
raise RuntimeError("函数仅可在子进程中被调用")
|
||||
|
||||
channel.active_channel = cd.active # 子进程主动通道
|
||||
channel.passive_channel = cd.passive # 子进程被动通道
|
||||
channel.channel_deliver_active_channel = (
|
||||
cd.channel_deliver_active
|
||||
) # 子进程通道传递主动通道
|
||||
channel.channel_deliver_passive_channel = (
|
||||
cd.channel_deliver_passive
|
||||
) # 子进程通道传递被动通道
|
||||
channel.publish_channel = cd.publish # 子进程发布通道
|
||||
|
||||
# 给子进程创建共享内存实例
|
||||
|
||||
storage.shared_memory = sm
|
||||
|
||||
func(*args, **kwargs)
|
||||
|
||||
|
||||
class ProcessManager:
|
||||
"""
|
||||
在主进程中被调用
|
||||
进程管理器
|
||||
"""
|
||||
|
||||
def __init__(self, bot: "LiteyukiBot"):
|
||||
self.bot = bot
|
||||
self.targets: dict[str, tuple[callable, tuple, dict]] = {}
|
||||
def __init__(self, lifespan: "Lifespan"):
|
||||
self.lifespan = lifespan
|
||||
self.targets: dict[str, tuple[Callable, tuple, dict]] = {}
|
||||
self.processes: dict[str, Process] = {}
|
||||
|
||||
set_channels(
|
||||
{
|
||||
"nonebot-active": Channel(_id="nonebot-active"),
|
||||
"melobot-active": Channel(_id="melobot-active"),
|
||||
"nonebot-passive": Channel(_id="nonebot-passive"),
|
||||
"melobot-passive": Channel(_id="melobot-passive"),
|
||||
}
|
||||
)
|
||||
|
||||
def start(self, name: str, delay: int = 0):
|
||||
def _run_process(self, name: str):
|
||||
"""
|
||||
开启后自动监控进程,并添加到进程字典中
|
||||
开启后自动监控进程,并添加到进程字典中,会阻塞,请创建task
|
||||
Args:
|
||||
name:
|
||||
delay:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if name not in self.targets:
|
||||
raise KeyError(f"未有 Process {name} 之存在")
|
||||
raise KeyError(f"Process {name} 未寻得")
|
||||
|
||||
def _start():
|
||||
should_exit = False
|
||||
while not should_exit:
|
||||
chan_active = get_channel(f"{name}-active")
|
||||
chan_passive = get_channel(f"{name}-passive")
|
||||
process = Process(
|
||||
target=self.targets[name][0],
|
||||
args=(chan_active, chan_passive, *self.targets[name][1]),
|
||||
kwargs=self.targets[name][2],
|
||||
)
|
||||
self.processes[name] = process
|
||||
process.start()
|
||||
while not should_exit:
|
||||
# 0退出 1重启
|
||||
data = chan_active.receive()
|
||||
if data == 1:
|
||||
logger.info(f"重启 {name} 进程")
|
||||
asyncio.run(self.bot.lifespan.before_shutdown())
|
||||
asyncio.run(self.bot.lifespan.before_restart())
|
||||
self.terminate(name)
|
||||
break
|
||||
chan_active = get_channel(f"{name}-active")
|
||||
|
||||
elif data == 0:
|
||||
logger.info(f"关停 {name} 进程")
|
||||
asyncio.run(self.bot.lifespan.before_shutdown())
|
||||
should_exit = True
|
||||
self.terminate(name)
|
||||
else:
|
||||
logger.warning("数据未知,省略:{}".format(data))
|
||||
def _start_process():
|
||||
process = Process(
|
||||
target=self.targets[name][0],
|
||||
args=self.targets[name][1],
|
||||
kwargs=self.targets[name][2],
|
||||
daemon=True,
|
||||
)
|
||||
self.processes[name] = process
|
||||
process.start()
|
||||
|
||||
if delay:
|
||||
threading.Timer(delay, _start).start()
|
||||
else:
|
||||
threading.Thread(target=_start).start()
|
||||
# 启动进程并监听信号
|
||||
_start_process()
|
||||
while True:
|
||||
data = chan_active.receive()
|
||||
if data == 0:
|
||||
# 停止
|
||||
logger.info(f"正在关停 Process {name}")
|
||||
self.terminate(name)
|
||||
break
|
||||
elif data == 1:
|
||||
# 重启
|
||||
logger.info(f"正在重启 Process {name}")
|
||||
self.terminate(name)
|
||||
_start_process()
|
||||
continue
|
||||
else:
|
||||
logger.warning("接收到未知信号数据 {} ,已忽略".format(data))
|
||||
|
||||
def add_target(self, name: str, target, *args, **kwargs):
|
||||
self.targets[name] = (target, args, kwargs)
|
||||
def start_all(self):
|
||||
"""
|
||||
对外启动方法,启动所有进程,创建asyncio task
|
||||
"""
|
||||
# [asyncio.create_task(self._run_process(name)) for name in self.targets]
|
||||
|
||||
def join(self):
|
||||
for name in self.targets:
|
||||
logger.debug(f"正在启动 Process {name}")
|
||||
threading.Thread(
|
||||
target=self._run_process, args=(name,), daemon=True
|
||||
).start()
|
||||
|
||||
def add_target(self, name: str, target: TARGET_FUNC, args: tuple = (), kwargs=None):
|
||||
"""
|
||||
添加进程
|
||||
Args:
|
||||
name: 进程名,用于获取和唯一标识
|
||||
target: 进程函数
|
||||
args: 进程函数参数
|
||||
kwargs: 进程函数关键字参数,通常会默认传入chan_active和chan_passive
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
chan_active: Channel = Channel(name=f"{name}-active")
|
||||
chan_passive: Channel = Channel(name=f"{name}-passive")
|
||||
|
||||
channel_deliver = ChannelDeliver(
|
||||
active=chan_active,
|
||||
passive=chan_passive,
|
||||
channel_deliver_active=channel_deliver_active_channel,
|
||||
channel_deliver_passive=channel_deliver_passive_channel,
|
||||
publish=publish_channel,
|
||||
)
|
||||
|
||||
self.targets[name] = (
|
||||
_delivery_channel_wrapper,
|
||||
(target, channel_deliver, shared_memory, *args),
|
||||
kwargs,
|
||||
)
|
||||
# 主进程通道
|
||||
|
||||
def join_all(self):
|
||||
for name, process in self.targets:
|
||||
process.join()
|
||||
|
||||
@ -107,14 +190,29 @@ class ProcessManager:
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if name not in self.targets:
|
||||
raise logger.warning(f"未有 Process {name} 之存在")
|
||||
if name not in self.processes:
|
||||
logger.warning(f"Process {name} 未寻得")
|
||||
return
|
||||
process = self.processes[name]
|
||||
process.terminate()
|
||||
process.join(TIMEOUT)
|
||||
if process.is_alive():
|
||||
process.kill()
|
||||
logger.success(f"Process {name} 已迫令终止")
|
||||
|
||||
def terminate_all(self):
|
||||
for name in self.targets:
|
||||
self.terminate(name)
|
||||
|
||||
def is_process_alive(self, name: str) -> bool:
|
||||
"""
|
||||
检查进程是否存活
|
||||
Args:
|
||||
name:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if name not in self.targets:
|
||||
logger.warning(f"Process {name} 未寻得")
|
||||
return self.processes[name].is_alive()
|
||||
|
@ -1,14 +0,0 @@
|
||||
from . import (
|
||||
satori,
|
||||
onebot
|
||||
)
|
||||
|
||||
|
||||
def init(config: dict):
|
||||
onebot.init()
|
||||
satori.init(config)
|
||||
|
||||
|
||||
def register():
|
||||
onebot.register()
|
||||
satori.register()
|
@ -1,12 +0,0 @@
|
||||
import nonebot
|
||||
from nonebot.adapters.onebot import v11, v12
|
||||
|
||||
|
||||
def init():
|
||||
pass
|
||||
|
||||
|
||||
def register():
|
||||
driver = nonebot.get_driver()
|
||||
driver.register_adapter(v11.Adapter)
|
||||
driver.register_adapter(v12.Adapter)
|
@ -1,26 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import nonebot
|
||||
from nonebot.adapters import satori
|
||||
|
||||
|
||||
def init(config: dict):
|
||||
if config.get("satori", None) is None:
|
||||
nonebot.logger.info("未查见 Satori 的配置文档,将跳过 Satori 初始化")
|
||||
return None
|
||||
satori_config = config.get("satori")
|
||||
if not satori_config.get("enable", False):
|
||||
nonebot.logger.info("未启用 Satori ,将跳过 Satori 初始化")
|
||||
return None
|
||||
if os.getenv("SATORI_CLIENTS", None) is not None:
|
||||
nonebot.logger.info("Satori 客户端已设入环境变量,跳过此步。")
|
||||
os.environ["SATORI_CLIENTS"] = json.dumps(satori_config.get("hosts", []), ensure_ascii=False)
|
||||
config['satori_clients'] = satori_config.get("hosts", [])
|
||||
return
|
||||
|
||||
|
||||
def register():
|
||||
if os.getenv("SATORI_CLIENTS", None) is not None:
|
||||
driver = nonebot.get_driver()
|
||||
driver.register_adapter(satori.Adapter)
|
@ -1,6 +0,0 @@
|
||||
from .auto_set_env import auto_set_env
|
||||
|
||||
|
||||
def init(config: dict):
|
||||
auto_set_env(config)
|
||||
return
|
@ -1,20 +0,0 @@
|
||||
import os
|
||||
|
||||
import dotenv
|
||||
import nonebot
|
||||
|
||||
from .defines import *
|
||||
|
||||
|
||||
def auto_set_env(config: dict):
|
||||
dotenv.load_dotenv(".env")
|
||||
if os.getenv("DRIVER", None) is not None:
|
||||
nonebot.logger.info("Driver 已设入环境变量中,将跳过自动配置环节。")
|
||||
return
|
||||
if config.get("satori", {'enable': False}).get("enable", False):
|
||||
os.environ["DRIVER"] = get_driver_string(ASGI_DRIVER, HTTPX_DRIVER, WEBSOCKETS_DRIVER)
|
||||
nonebot.logger.info("已启用 Satori,将 driver 设为 ASGI+HTTPX+WEBSOCKETS")
|
||||
else:
|
||||
os.environ["DRIVER"] = get_driver_string(ASGI_DRIVER)
|
||||
nonebot.logger.info("已禁用 Satori,将 driver 设为 ASGI")
|
||||
return
|
@ -1,17 +0,0 @@
|
||||
ASGI_DRIVER = "~fastapi"
|
||||
HTTPX_DRIVER = "~httpx"
|
||||
WEBSOCKETS_DRIVER = "~websockets"
|
||||
|
||||
|
||||
def get_driver_string(*argv):
|
||||
output_string = ""
|
||||
if ASGI_DRIVER in argv:
|
||||
output_string += ASGI_DRIVER
|
||||
for arg in argv:
|
||||
if arg != ASGI_DRIVER:
|
||||
output_string = f"{output_string}+{arg}"
|
||||
return output_string
|
||||
|
||||
|
||||
def get_driver_full_string(*argv):
|
||||
return f"DRIVER={get_driver_string(argv)}"
|
@ -1,56 +0,0 @@
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import nonebot
|
||||
|
||||
from liteyuki.core.nb import adapter_manager, driver_manager
|
||||
from liteyuki.comm.channel import set_channel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from liteyuki.comm.channel import Channel
|
||||
|
||||
timeout_limit: int = 20
|
||||
|
||||
"""导出对象,用于主进程与nonebot通信"""
|
||||
_channels = {}
|
||||
|
||||
|
||||
def nb_run(chan_active: "Channel", chan_passive: "Channel", *args, **kwargs):
|
||||
"""
|
||||
初始化NoneBot并运行在子进程
|
||||
Args:
|
||||
|
||||
chan_active:
|
||||
chan_passive:
|
||||
**kwargs:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
set_channel("nonebot-active", chan_active)
|
||||
set_channel("nonebot-passive", chan_passive)
|
||||
nonebot.init(**kwargs)
|
||||
driver_manager.init(config=kwargs)
|
||||
adapter_manager.init(kwargs)
|
||||
adapter_manager.register()
|
||||
nonebot.load_plugin("src.liteyuki_main")
|
||||
nonebot.run()
|
||||
|
||||
|
||||
def mb_run(chan_active: "Channel", chan_passive: "Channel", *args, **kwargs):
|
||||
"""
|
||||
初始化MeloBot并运行在子进程
|
||||
Args:
|
||||
chan_active
|
||||
chan_passive
|
||||
*args:
|
||||
**kwargs:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
set_channel("melobot-active", chan_active)
|
||||
set_channel("melobot-passive", chan_passive)
|
||||
|
||||
# bot = MeloBot(__name__)
|
||||
# bot.init(AbstractConnector(cd_time=0))
|
||||
# bot.run()
|
4
liteyuki/dev/__init__.py
Normal file
4
liteyuki/dev/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
该模块用于存放一些开发工具
|
||||
"""
|
90
liteyuki/dev/observer.py
Normal file
90
liteyuki/dev/observer.py
Normal file
@ -0,0 +1,90 @@
|
||||
"""
|
||||
此模块用于注册观察者函数,使用watchdog监控文件变化并重启bot
|
||||
启用该模块需要在配置文件中设置`dev_mode`为True
|
||||
"""
|
||||
import time
|
||||
from typing import Callable, TypeAlias
|
||||
|
||||
from watchdog.events import FileSystemEvent, FileSystemEventHandler
|
||||
from watchdog.observers import Observer
|
||||
|
||||
from liteyuki import get_bot, get_config_with_compat, logger
|
||||
|
||||
liteyuki_bot = get_bot()
|
||||
|
||||
CALLBACK_FUNC: TypeAlias = Callable[[FileSystemEvent], None] # 位置1为FileSystemEvent
|
||||
FILTER_FUNC: TypeAlias = Callable[[FileSystemEvent], bool] # 位置1为FileSystemEvent
|
||||
observer = Observer()
|
||||
|
||||
|
||||
def debounce(wait):
|
||||
"""
|
||||
防抖函数
|
||||
"""
|
||||
def decorator(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
nonlocal last_call_time
|
||||
current_time = time.time()
|
||||
if (current_time - last_call_time) > wait:
|
||||
last_call_time = current_time
|
||||
return func(*args, **kwargs)
|
||||
|
||||
last_call_time = None
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
if get_config_with_compat("liteyuki.reload", ("dev_mode",), False):
|
||||
logger.debug("Liteyuki Reload 已启用,正在监视文件更新")
|
||||
observer.start()
|
||||
|
||||
|
||||
class CodeModifiedHandler(FileSystemEventHandler):
|
||||
"""
|
||||
Handler for code file changes
|
||||
"""
|
||||
|
||||
@debounce(1)
|
||||
def on_modified(self, event):
|
||||
raise NotImplementedError("on_modified 函数在继承后必须实现")
|
||||
|
||||
def on_created(self, event):
|
||||
self.on_modified(event)
|
||||
|
||||
def on_deleted(self, event):
|
||||
self.on_modified(event)
|
||||
|
||||
def on_moved(self, event):
|
||||
self.on_modified(event)
|
||||
|
||||
def on_any_event(self, event):
|
||||
self.on_modified(event)
|
||||
|
||||
|
||||
def on_file_system_event(directories: tuple[str], recursive: bool = True, event_filter: FILTER_FUNC = None) -> Callable[[CALLBACK_FUNC], CALLBACK_FUNC]:
|
||||
"""
|
||||
注册文件系统变化监听器
|
||||
Args:
|
||||
directories: 监听目录们
|
||||
recursive: 是否递归监听子目录
|
||||
event_filter: 事件过滤器, 返回True则执行回调函数
|
||||
Returns:
|
||||
装饰器,装饰一个函数在接收到数据后执行
|
||||
"""
|
||||
|
||||
def decorator(func: CALLBACK_FUNC) -> CALLBACK_FUNC:
|
||||
def wrapper(event: FileSystemEvent):
|
||||
|
||||
if event_filter is not None and not event_filter(event):
|
||||
return
|
||||
func(event)
|
||||
|
||||
code_modified_handler = CodeModifiedHandler()
|
||||
code_modified_handler.on_modified = wrapper
|
||||
for directory in directories:
|
||||
observer.schedule(code_modified_handler, directory, recursive=recursive)
|
||||
|
||||
return func
|
||||
|
||||
return decorator
|
28
liteyuki/dev/plugin.py
Normal file
28
liteyuki/dev/plugin.py
Normal file
@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
|
||||
@Time : 2024/8/18 上午5:04
|
||||
@Author : snowykami
|
||||
@Email : snowykami@outlook.com
|
||||
@File : plugin.py
|
||||
@Software: PyCharm
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
from liteyuki.bot import LiteyukiBot
|
||||
from liteyuki.config import load_config_in_default
|
||||
|
||||
|
||||
def run_plugins(*module_path: str | Path):
|
||||
"""
|
||||
运行插件,无需手动初始化bot
|
||||
Args:
|
||||
module_path: 插件路径,参考`liteyuki.load_plugin`的函数签名
|
||||
"""
|
||||
cfg = load_config_in_default()
|
||||
plugins = cfg.get("liteyuki.plugins", [])
|
||||
plugins.extend(module_path)
|
||||
cfg["liteyuki.plugins"] = plugins
|
||||
bot = LiteyukiBot(**cfg)
|
||||
bot.run()
|
@ -9,22 +9,10 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
@Software: PyCharm
|
||||
"""
|
||||
import sys
|
||||
|
||||
import loguru
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
logger = loguru.logger
|
||||
if TYPE_CHECKING:
|
||||
# avoid sphinx autodoc resolve annotation failed
|
||||
# because loguru module do not have `Logger` class actually
|
||||
from loguru import Record
|
||||
|
||||
|
||||
def default_filter(record: "Record"):
|
||||
"""默认的日志过滤器,根据 `config.log_level` 配置改变日志等级。"""
|
||||
log_level = record["extra"].get("nonebot_log_level", "INFO")
|
||||
levelno = logger.level(log_level).no if isinstance(log_level, str) else log_level
|
||||
return record["level"].no >= levelno
|
||||
|
||||
|
||||
# DEBUG日志格式
|
||||
debug_format: str = (
|
||||
@ -50,35 +38,26 @@ def get_format(level: str) -> str:
|
||||
return default_format
|
||||
|
||||
|
||||
logger = loguru.logger.bind()
|
||||
|
||||
|
||||
def init_log(config: dict):
|
||||
"""
|
||||
在语言加载完成后执行
|
||||
Returns:
|
||||
|
||||
"""
|
||||
global logger
|
||||
|
||||
logger.remove()
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
level=0,
|
||||
diagnose=False,
|
||||
filter=default_filter,
|
||||
format=get_format(config.get("log_level", "INFO")),
|
||||
)
|
||||
show_icon = config.get("log_icon", True)
|
||||
logger.level("DEBUG", color="<blue>", icon=f"{'🐛' if show_icon else ''}试")
|
||||
logger.level("INFO", color="<normal>", icon=f"{'ℹ️' if show_icon else ''}讯")
|
||||
logger.level("SUCCESS", color="<green>", icon=f"{'✅' if show_icon else ''}警")
|
||||
logger.level("WARNING", color="<yellow>", icon=f"{'⚠️' if show_icon else ''}误")
|
||||
logger.level("ERROR", color="<red>", icon=f"{'⭕' if show_icon else ''}成")
|
||||
|
||||
# debug = lang.get("log.debug", default="==DEBUG")
|
||||
# info = lang.get("log.info", default="===INFO")
|
||||
# success = lang.get("log.success", default="SUCCESS")
|
||||
# warning = lang.get("log.warning", default="WARNING")
|
||||
# error = lang.get("log.error", default="==ERROR")
|
||||
#
|
||||
# logger.level("DEBUG", color="<blue>", icon=f"{'🐛' if show_icon else ''}{debug}")
|
||||
# logger.level("INFO", color="<normal>", icon=f"{'ℹ️' if show_icon else ''}{info}")
|
||||
# logger.level("SUCCESS", color="<green>", icon=f"{'✅' if show_icon else ''}{success}")
|
||||
# logger.level("WARNING", color="<yellow>", icon=f"{'⚠️' if show_icon else ''}{warning}")
|
||||
# logger.level("ERROR", color="<red>", icon=f"{'⭕' if show_icon else ''}{error}")
|
||||
|
||||
init_log(config={})
|
||||
|
10
liteyuki/message/__init__.py
Normal file
10
liteyuki/message/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
|
||||
@Time : 2024/8/19 下午10:44
|
||||
@Author : snowykami
|
||||
@Email : snowykami@outlook.com
|
||||
@File : __init__.py.py
|
||||
@Software: PyCharm
|
||||
"""
|
88
liteyuki/message/event.py
Normal file
88
liteyuki/message/event.py
Normal file
@ -0,0 +1,88 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
|
||||
@Time : 2024/8/19 下午10:47
|
||||
@Author : snowykami
|
||||
@Email : snowykami@outlook.com
|
||||
@File : event.py
|
||||
@Software: PyCharm
|
||||
"""
|
||||
from typing import Any, Optional
|
||||
|
||||
from liteyuki import Channel
|
||||
from liteyuki.comm.storage import shared_memory
|
||||
|
||||
|
||||
class MessageEvent:
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
bot_id: str,
|
||||
message: list[dict[str, Any]] | str,
|
||||
message_type: str,
|
||||
raw_message: str,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
session_type: str,
|
||||
receive_channel: Optional[Channel["MessageEvent"]] = None,
|
||||
data: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
轻雪抽象消息事件
|
||||
Args:
|
||||
|
||||
bot_id: 机器人ID
|
||||
message: 消息,消息段数组[{type: str, data: dict[str, Any]}]
|
||||
raw_message: 原始消息(通常为纯文本的格式)
|
||||
message_type: 消息类型(private, group, other)
|
||||
|
||||
session_id: 会话ID(私聊通常为用户ID,群聊通常为群ID)
|
||||
session_type: 会话类型(private, group)
|
||||
receive_channel: 接收频道(用于回复消息)
|
||||
|
||||
data: 附加数据
|
||||
"""
|
||||
|
||||
if data is None:
|
||||
data = {}
|
||||
self.message_type = message_type
|
||||
self.data = data
|
||||
self.bot_id = bot_id
|
||||
|
||||
self.message = message
|
||||
self.raw_message = raw_message
|
||||
|
||||
self.session_id = session_id
|
||||
self.session_type = session_type
|
||||
self.user_id = user_id
|
||||
|
||||
self.receive_channel = receive_channel
|
||||
|
||||
def __str__(self):
|
||||
return (f"Event(message_type={self.message_type}, data={self.data}, bot_id={self.bot_id}, "
|
||||
f"session_id={self.session_id}, session_type={self.session_type})")
|
||||
|
||||
def reply(self, message: str | dict[str, Any]):
|
||||
"""
|
||||
回复消息
|
||||
Args:
|
||||
message:
|
||||
Returns:
|
||||
"""
|
||||
reply_event = MessageEvent(
|
||||
message_type=self.session_type,
|
||||
message=message,
|
||||
raw_message="",
|
||||
data={
|
||||
"message": message
|
||||
},
|
||||
bot_id=self.bot_id,
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
session_type=self.session_type,
|
||||
receive_channel=None
|
||||
)
|
||||
# shared_memory.publish(self.receive_channel, reply_event)
|
||||
if self.receive_channel:
|
||||
self.receive_channel.send(reply_event)
|
63
liteyuki/message/matcher.py
Normal file
63
liteyuki/message/matcher.py
Normal file
@ -0,0 +1,63 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
|
||||
@Time : 2024/8/19 下午10:51
|
||||
@Author : snowykami
|
||||
@Email : snowykami@outlook.com
|
||||
@File : matcher.py
|
||||
@Software: PyCharm
|
||||
"""
|
||||
import traceback
|
||||
from typing import Any, TypeAlias, Callable, Coroutine
|
||||
|
||||
from liteyuki.message.event import MessageEvent
|
||||
from liteyuki.message.rule import Rule
|
||||
|
||||
EventHandler: TypeAlias = Callable[[MessageEvent], Coroutine[None, None, Any]]
|
||||
|
||||
|
||||
class Matcher:
|
||||
def __init__(self, rule: Rule, priority: int, block: bool):
|
||||
"""
|
||||
匹配器
|
||||
Args:
|
||||
rule: 规则
|
||||
priority: 优先级 >= 0
|
||||
block: 是否阻断后续优先级更低的匹配器
|
||||
"""
|
||||
self.rule = rule
|
||||
self.priority = priority
|
||||
self.block = block
|
||||
self.handlers: list[EventHandler] = []
|
||||
|
||||
def __str__(self):
|
||||
return f"Matcher(rule={self.rule}, priority={self.priority}, block={self.block})"
|
||||
|
||||
def handle(self) -> Callable[[EventHandler], EventHandler]:
|
||||
"""
|
||||
添加处理函数,装饰器
|
||||
Returns:
|
||||
装饰器 handler
|
||||
"""
|
||||
def decorator(handler: EventHandler) -> EventHandler:
|
||||
self.handlers.append(handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
async def run(self, event: MessageEvent) -> None:
|
||||
"""
|
||||
运行处理函数
|
||||
Args:
|
||||
event:
|
||||
Returns:
|
||||
"""
|
||||
if not await self.rule(event):
|
||||
return
|
||||
|
||||
for handler in self.handlers:
|
||||
try:
|
||||
await handler(event)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
62
liteyuki/message/on.py
Normal file
62
liteyuki/message/on.py
Normal file
@ -0,0 +1,62 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
|
||||
@Time : 2024/8/19 下午10:52
|
||||
@Author : snowykami
|
||||
@Email : snowykami@outlook.com
|
||||
@File : on.py
|
||||
@Software: PyCharm
|
||||
"""
|
||||
|
||||
from queue import Queue
|
||||
|
||||
from liteyuki.comm.storage import shared_memory
|
||||
from liteyuki.log import logger
|
||||
from liteyuki.message.event import MessageEvent
|
||||
from liteyuki.message.matcher import Matcher
|
||||
from liteyuki.message.rule import Rule, empty_rule
|
||||
|
||||
_matcher_list: list[Matcher] = []
|
||||
_queue: Queue = Queue()
|
||||
|
||||
|
||||
@shared_memory.on_subscriber_receive("event_to_liteyuki")
|
||||
async def _(event: MessageEvent):
|
||||
print("AA")
|
||||
current_priority = -1
|
||||
for i, matcher in enumerate(_matcher_list):
|
||||
logger.info(f"为 Event {event} 运行 Matcher {matcher}")
|
||||
await matcher.run(event)
|
||||
# 同优先级不阻断,不同优先级阻断
|
||||
if current_priority != matcher.priority:
|
||||
current_priority = matcher.priority
|
||||
if matcher.block:
|
||||
break
|
||||
else:
|
||||
logger.info(f"无 Matcher 适配于 Event {event}")
|
||||
print("BB")
|
||||
|
||||
|
||||
def add_matcher(matcher: Matcher):
|
||||
for i, m in enumerate(_matcher_list):
|
||||
if m.priority < matcher.priority:
|
||||
_matcher_list.insert(i, matcher)
|
||||
break
|
||||
else:
|
||||
_matcher_list.append(matcher)
|
||||
|
||||
|
||||
def on_message(rule: Rule = empty_rule, priority: int = 0, block: bool = False) -> Matcher:
|
||||
matcher = Matcher(rule, priority, block)
|
||||
# 按照优先级插入
|
||||
add_matcher(matcher)
|
||||
return matcher
|
||||
|
||||
|
||||
def on_keywords(keywords: list[str], rule=empty_rule, priority: int = 0, block: bool = False) -> Matcher:
|
||||
@Rule
|
||||
async def on_keywords_rule(event: MessageEvent):
|
||||
return any(keyword in event.raw_message for keyword in keywords)
|
||||
|
||||
return on_message(on_keywords_rule & rule, priority, block)
|
51
liteyuki/message/rule.py
Normal file
51
liteyuki/message/rule.py
Normal file
@ -0,0 +1,51 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
|
||||
@Time : 2024/8/19 下午10:55
|
||||
@Author : snowykami
|
||||
@Email : snowykami@outlook.com
|
||||
@File : rule.py
|
||||
@Software: PyCharm
|
||||
"""
|
||||
import inspect
|
||||
from typing import Optional, TypeAlias, Callable, Coroutine
|
||||
|
||||
from liteyuki.message.event import MessageEvent
|
||||
from liteyuki import get_config
|
||||
|
||||
_superusers: list[str] = get_config("liteyuki.superusers", [])
|
||||
|
||||
RuleHandlerFunc: TypeAlias = Callable[[MessageEvent], Coroutine[None, None, bool]]
|
||||
"""规则函数签名"""
|
||||
|
||||
|
||||
class Rule:
|
||||
def __init__(self, handler: RuleHandlerFunc):
|
||||
self.handler = handler
|
||||
|
||||
def __or__(self, other: "Rule") -> "Rule":
|
||||
async def combined_handler(event: MessageEvent) -> bool:
|
||||
return await self.handler(event) or await other.handler(event)
|
||||
|
||||
return Rule(combined_handler)
|
||||
|
||||
def __and__(self, other: "Rule") -> "Rule":
|
||||
async def combined_handler(event: MessageEvent) -> bool:
|
||||
return await self.handler(event) and await other.handler(event)
|
||||
|
||||
return Rule(combined_handler)
|
||||
|
||||
async def __call__(self, event: MessageEvent) -> bool:
|
||||
if self.handler is None:
|
||||
return True
|
||||
return await self.handler(event)
|
||||
|
||||
|
||||
@Rule
|
||||
async def empty_rule(event: MessageEvent) -> bool:
|
||||
return True
|
||||
|
||||
@Rule
|
||||
async def is_su_rule(event: MessageEvent) -> bool:
|
||||
return str(event.user_id) in _superusers
|
@ -2,9 +2,9 @@
|
||||
"""
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
|
||||
@Time : 2024/8/10 下午5:18
|
||||
@Time : 2024/8/19 下午10:47
|
||||
@Author : snowykami
|
||||
@Email : snowykami@outlook.com
|
||||
@File : reloader_monitor.py
|
||||
@File : session.py
|
||||
@Software: PyCharm
|
||||
"""
|
357
liteyuki/mkdoc.py
Normal file
357
liteyuki/mkdoc.py
Normal file
@ -0,0 +1,357 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
|
||||
@Time : 2024/8/19 上午6:23
|
||||
@Author : snowykami
|
||||
@Email : snowykami@outlook.com
|
||||
@File : mkdoc.py
|
||||
@Software: PyCharm
|
||||
"""
|
||||
|
||||
import ast
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
|
||||
NO_TYPE_ANY = "Any"
|
||||
NO_TYPE_HINT = "NoTypeHint"
|
||||
|
||||
|
||||
class DefType(Enum):
|
||||
FUNCTION = "function"
|
||||
METHOD = "method"
|
||||
STATIC_METHOD = "staticmethod"
|
||||
CLASS_METHOD = "classmethod"
|
||||
PROPERTY = "property"
|
||||
|
||||
|
||||
class FunctionInfo(BaseModel):
|
||||
name: str
|
||||
args: list[tuple[str, str]]
|
||||
return_type: str
|
||||
docstring: str
|
||||
source_code: str = ""
|
||||
|
||||
type: DefType
|
||||
"""若为类中def,则有"""
|
||||
is_async: bool
|
||||
|
||||
|
||||
class AttributeInfo(BaseModel):
|
||||
name: str
|
||||
type: str
|
||||
value: Any = None
|
||||
docstring: str = ""
|
||||
|
||||
|
||||
class ClassInfo(BaseModel):
|
||||
name: str
|
||||
docstring: str
|
||||
methods: list[FunctionInfo]
|
||||
attributes: list[AttributeInfo]
|
||||
inherit: list[str]
|
||||
|
||||
|
||||
class ModuleInfo(BaseModel):
|
||||
module_path: str
|
||||
"""点分割模块路径 例如 liteyuki.bot"""
|
||||
|
||||
functions: list[FunctionInfo]
|
||||
classes: list[ClassInfo]
|
||||
attributes: list[AttributeInfo]
|
||||
docstring: str
|
||||
|
||||
|
||||
def get_relative_path(base_path: str, target_path: str) -> str:
|
||||
"""
|
||||
获取相对路径
|
||||
Args:
|
||||
base_path: 基础路径
|
||||
target_path: 目标路径
|
||||
"""
|
||||
return os.path.relpath(target_path, base_path)
|
||||
|
||||
|
||||
def write_to_files(file_data: dict[str, str]):
|
||||
"""
|
||||
输出文件
|
||||
Args:
|
||||
file_data: 文件数据 相对路径
|
||||
"""
|
||||
|
||||
for rp, data in file_data.items():
|
||||
|
||||
if not os.path.exists(os.path.dirname(rp)):
|
||||
os.makedirs(os.path.dirname(rp))
|
||||
with open(rp, 'w', encoding='utf-8') as f:
|
||||
f.write(data)
|
||||
|
||||
|
||||
def get_file_list(module_folder: str):
|
||||
file_list = []
|
||||
for root, dirs, files in os.walk(module_folder):
|
||||
for file in files:
|
||||
if file.endswith((".py", ".pyi")):
|
||||
file_list.append(os.path.join(root, file))
|
||||
return file_list
|
||||
|
||||
|
||||
def get_module_info_normal(file_path: str, ignore_private: bool = True) -> ModuleInfo:
|
||||
"""
|
||||
获取函数和类
|
||||
Args:
|
||||
file_path: Python 文件路径
|
||||
ignore_private: 忽略私有函数和类
|
||||
Returns:
|
||||
模块信息
|
||||
"""
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
file_content = file.read()
|
||||
tree = ast.parse(file_content)
|
||||
|
||||
dot_sep_module_path = file_path.replace(os.sep, '.').replace(".py", "").replace(".pyi", "")
|
||||
module_docstring = ast.get_docstring(tree)
|
||||
|
||||
module_info = ModuleInfo(
|
||||
module_path=dot_sep_module_path,
|
||||
functions=[],
|
||||
classes=[],
|
||||
attributes=[],
|
||||
docstring=module_docstring if module_docstring else ""
|
||||
)
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
# 模块函数 且不在类中 若ignore_private=True则忽略私有函数
|
||||
if not any(isinstance(parent, ast.ClassDef) for parent in ast.iter_child_nodes(node)) and (not ignore_private or not node.name.startswith('_')):
|
||||
|
||||
# 判断第一个参数是否为self或cls,后期用其他办法优化
|
||||
if node.args.args:
|
||||
first_arg = node.args.args[0]
|
||||
if first_arg.arg in ("self", "cls"):
|
||||
continue
|
||||
|
||||
function_docstring = ast.get_docstring(node)
|
||||
|
||||
func_info = FunctionInfo(
|
||||
name=node.name,
|
||||
args=[(arg.arg, ast.unparse(arg.annotation) if arg.annotation else NO_TYPE_ANY) for arg in node.args.args],
|
||||
return_type=ast.unparse(node.returns) if node.returns else "None",
|
||||
docstring=function_docstring if function_docstring else "",
|
||||
type=DefType.FUNCTION,
|
||||
is_async=isinstance(node, ast.AsyncFunctionDef),
|
||||
source_code=ast.unparse(node)
|
||||
)
|
||||
module_info.functions.append(func_info)
|
||||
|
||||
elif isinstance(node, ast.ClassDef):
|
||||
# 模块类
|
||||
class_docstring = ast.get_docstring(node)
|
||||
|
||||
class_info = ClassInfo(
|
||||
name=node.name,
|
||||
docstring=class_docstring if class_docstring else "",
|
||||
methods=[],
|
||||
attributes=[],
|
||||
inherit=[ast.unparse(base) for base in node.bases]
|
||||
)
|
||||
|
||||
for class_node in node.body:
|
||||
# methods [instance, static, class, property],保留__init__方法
|
||||
if isinstance(class_node, ast.FunctionDef) and (not ignore_private or not class_node.name.startswith('_') or class_node.name == "__init__"):
|
||||
method_docstring = ast.get_docstring(class_node)
|
||||
def_type = DefType.METHOD
|
||||
if class_node.decorator_list:
|
||||
if any(isinstance(decorator, ast.Name) and decorator.id == "staticmethod" for decorator in class_node.decorator_list):
|
||||
def_type = DefType.STATIC_METHOD
|
||||
elif any(isinstance(decorator, ast.Name) and decorator.id == "classmethod" for decorator in class_node.decorator_list):
|
||||
def_type = DefType.CLASS_METHOD
|
||||
elif any(isinstance(decorator, ast.Name) and decorator.id == "property" for decorator in class_node.decorator_list):
|
||||
def_type = DefType.PROPERTY
|
||||
class_info.methods.append(FunctionInfo(
|
||||
name=class_node.name,
|
||||
args=[(arg.arg, ast.unparse(arg.annotation) if arg.annotation else NO_TYPE_ANY) for arg in class_node.args.args],
|
||||
return_type=ast.unparse(class_node.returns) if class_node.returns else "None",
|
||||
docstring=method_docstring if method_docstring else "",
|
||||
type=def_type,
|
||||
is_async=isinstance(class_node, ast.AsyncFunctionDef),
|
||||
source_code=ast.unparse(class_node)
|
||||
))
|
||||
# attributes
|
||||
elif isinstance(class_node, ast.Assign):
|
||||
for target in class_node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
class_info.attributes.append(AttributeInfo(
|
||||
name=target.id,
|
||||
type=ast.unparse(class_node.value)
|
||||
))
|
||||
module_info.classes.append(class_info)
|
||||
|
||||
elif isinstance(node, ast.Assign):
|
||||
# 检查是否在类或函数中
|
||||
if not any(isinstance(parent, (ast.ClassDef, ast.FunctionDef)) for parent in ast.iter_child_nodes(node)):
|
||||
# 模块属性变量
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name) and (not ignore_private or not target.id.startswith('_')):
|
||||
attr_type = NO_TYPE_HINT
|
||||
if isinstance(node.value, ast.AnnAssign) and node.value.annotation:
|
||||
attr_type = ast.unparse(node.value.annotation)
|
||||
module_info.attributes.append(AttributeInfo(
|
||||
name=target.id,
|
||||
type=attr_type,
|
||||
value=ast.unparse(node.value) if node.value else None
|
||||
))
|
||||
|
||||
return module_info
|
||||
|
||||
|
||||
def generate_markdown(module_info: ModuleInfo, front_matter=None, lang: str = "zh-CN") -> str:
|
||||
"""
|
||||
生成模块的Markdown
|
||||
你可在此自定义生成的Markdown格式
|
||||
Args:
|
||||
module_info: 模块信息
|
||||
front_matter: 自定义选项title, index, icon, category
|
||||
lang: 语言
|
||||
Returns:
|
||||
Markdown 字符串
|
||||
"""
|
||||
|
||||
content = ""
|
||||
|
||||
front_matter = "---\n" + "\n".join([f"{k}: {v}" for k, v in front_matter.items()]) + "\n---\n\n"
|
||||
|
||||
content += front_matter
|
||||
|
||||
# 模块函数
|
||||
for func in module_info.functions:
|
||||
args_with_type = [f"{arg[0]}: {arg[1]}" if arg[1] else arg[0] for arg in func.args]
|
||||
content += f"### ***{'async ' if func.is_async else ''}def*** `{func.name}({', '.join(args_with_type)}) -> {func.return_type}`\n\n"
|
||||
|
||||
func.docstring = func.docstring.replace("\n", "\n\n")
|
||||
content += f"{func.docstring}\n\n"
|
||||
|
||||
# 函数源代码可展开区域
|
||||
content += f"<details>\n<summary>源代码</summary>\n\n```python\n{func.source_code}\n```\n</details>\n\n"
|
||||
|
||||
# 类
|
||||
for cls in module_info.classes:
|
||||
if cls.inherit:
|
||||
inherit = f"({', '.join(cls.inherit)})" if cls.inherit else ""
|
||||
content += f"### ***class*** `{cls.name}{inherit}`\n\n"
|
||||
else:
|
||||
content += f"### ***class*** `{cls.name}`\n\n"
|
||||
|
||||
cls.docstring = cls.docstring.replace("\n", "\n\n")
|
||||
content += f"{cls.docstring}\n\n"
|
||||
for method in cls.methods:
|
||||
# 类函数
|
||||
|
||||
if method.type != DefType.METHOD:
|
||||
args_with_type = [f"{arg[0]}: {arg[1]}" if arg[1] else arg[0] for arg in method.args]
|
||||
content += f"###   ***@{method.type.value}***\n"
|
||||
else:
|
||||
# self不加类型提示
|
||||
args_with_type = [f"{arg[0]}: {arg[1]}" if arg[1] and arg[0] != "self" else arg[0] for arg in method.args]
|
||||
content += f"###   ***{'async ' if method.is_async else ''}def*** `{method.name}({', '.join(args_with_type)}) -> {method.return_type}`\n\n"
|
||||
|
||||
method.docstring = method.docstring.replace("\n", "\n\n")
|
||||
content += f" {method.docstring}\n\n"
|
||||
# 函数源代码可展开区域
|
||||
|
||||
if lang == "zh-CN":
|
||||
TEXT_SOURCE_CODE = "源代码"
|
||||
else:
|
||||
TEXT_SOURCE_CODE = "Source Code"
|
||||
|
||||
content += f"<details>\n<summary>{TEXT_SOURCE_CODE}</summary>\n\n```python\n{method.source_code}\n```\n</details>\n\n"
|
||||
for attr in cls.attributes:
|
||||
content += f"###   ***attr*** `{attr.name}: {attr.type}`\n\n"
|
||||
|
||||
# 模块属性
|
||||
for attr in module_info.attributes:
|
||||
if attr.type == NO_TYPE_HINT:
|
||||
content += f"### ***var*** `{attr.name} = {attr.value}`\n\n"
|
||||
else:
|
||||
content += f"### ***var*** `{attr.name}: {attr.type} = {attr.value}`\n\n"
|
||||
|
||||
attr.docstring = attr.docstring.replace("\n", "\n\n")
|
||||
content += f"{attr.docstring}\n\n"
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def generate_docs(module_folder: str, output_dir: str, with_top: bool = False, lang: str = "zh-CN", ignored_paths=None):
|
||||
"""
|
||||
生成文档
|
||||
Args:
|
||||
module_folder: 模块文件夹
|
||||
output_dir: 输出文件夹
|
||||
with_top: 是否包含顶层文件夹 False时例如docs/api/module_a, docs/api/module_b, True时例如docs/api/module/module_a.md, docs/api/module/module_b.md
|
||||
ignored_paths: 忽略的路径
|
||||
lang: 语言
|
||||
"""
|
||||
if ignored_paths is None:
|
||||
ignored_paths = []
|
||||
file_data: dict[str, str] = {} # 路径 -> 字串
|
||||
|
||||
file_list = get_file_list(module_folder)
|
||||
|
||||
# 清理输出目录
|
||||
shutil.rmtree(output_dir, ignore_errors=True)
|
||||
os.mkdir(output_dir)
|
||||
|
||||
replace_data = {
|
||||
"__init__": "README",
|
||||
".py" : ".md",
|
||||
}
|
||||
|
||||
for pyfile_path in file_list:
|
||||
if any(ignored_path.replace("\\", "/") in pyfile_path.replace("\\", "/") for ignored_path in ignored_paths):
|
||||
continue
|
||||
|
||||
no_module_name_pyfile_path = get_relative_path(module_folder, pyfile_path) # 去头路径
|
||||
|
||||
# markdown相对路径
|
||||
rel_md_path = pyfile_path if with_top else no_module_name_pyfile_path
|
||||
for rk, rv in replace_data.items():
|
||||
rel_md_path = rel_md_path.replace(rk, rv)
|
||||
|
||||
abs_md_path = os.path.join(output_dir, rel_md_path)
|
||||
|
||||
# 获取模块信息
|
||||
module_info = get_module_info_normal(pyfile_path)
|
||||
|
||||
# 生成markdown
|
||||
|
||||
if "README" in abs_md_path:
|
||||
front_matter = {
|
||||
"title" : module_info.module_path.replace(".__init__", "").replace("_", "\\n"),
|
||||
"index" : "true",
|
||||
"icon" : "laptop-code",
|
||||
"category": "API"
|
||||
}
|
||||
else:
|
||||
front_matter = {
|
||||
"title" : module_info.module_path.replace("_", "\\n"),
|
||||
"order" : "1",
|
||||
"icon" : "laptop-code",
|
||||
"category": "API"
|
||||
}
|
||||
|
||||
md_content = generate_markdown(module_info, front_matter)
|
||||
print(f"Generate {pyfile_path} -> {abs_md_path}")
|
||||
file_data[abs_md_path] = md_content
|
||||
|
||||
write_to_files(file_data)
|
||||
|
||||
|
||||
# 入口脚本
|
||||
if __name__ == '__main__':
|
||||
# 这里填入你的模块路径
|
||||
generate_docs('liteyuki', 'docs/dev/api', with_top=False, ignored_paths=["liteyuki/plugins"], lang="zh-CN")
|
||||
generate_docs('liteyuki', 'docs/en/dev/api', with_top=False, ignored_paths=["liteyuki/plugins"], lang="en")
|
@ -1,11 +1,12 @@
|
||||
from liteyuki.plugin.model import Plugin, PluginMetadata
|
||||
from liteyuki.plugin.model import Plugin, PluginMetadata, PluginType
|
||||
from liteyuki.plugin.load import load_plugin, load_plugins, _plugins
|
||||
|
||||
__all__ = [
|
||||
"PluginMetadata",
|
||||
"Plugin",
|
||||
"load_plugin",
|
||||
"load_plugins",
|
||||
"PluginMetadata",
|
||||
"Plugin",
|
||||
"PluginType",
|
||||
"load_plugin",
|
||||
"load_plugins",
|
||||
]
|
||||
|
||||
|
||||
|
@ -10,14 +10,12 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
"""
|
||||
import os
|
||||
import traceback
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from nonebot import logger
|
||||
|
||||
from liteyuki.plugin.model import Plugin, PluginMetadata
|
||||
from importlib import import_module
|
||||
|
||||
from liteyuki.log import logger
|
||||
from liteyuki.plugin.model import Plugin, PluginMetadata, PluginType
|
||||
from liteyuki.utils import path_to_module_name
|
||||
|
||||
_plugins: dict[str, Plugin] = {}
|
||||
@ -25,6 +23,7 @@ _plugins: dict[str, Plugin] = {}
|
||||
__all__ = [
|
||||
"load_plugin",
|
||||
"load_plugins",
|
||||
"_plugins",
|
||||
]
|
||||
|
||||
|
||||
@ -35,45 +34,92 @@ def load_plugin(module_path: str | Path) -> Optional[Plugin]:
|
||||
module_path: 插件名称 `path.to.your.plugin`
|
||||
或插件路径 `pathlib.Path(path/to/your/plugin)`
|
||||
"""
|
||||
module_path = path_to_module_name(Path(module_path)) if isinstance(module_path, Path) else module_path
|
||||
module_path = (
|
||||
path_to_module_name(Path(module_path))
|
||||
if isinstance(module_path, Path)
|
||||
else module_path
|
||||
)
|
||||
try:
|
||||
module = import_module(module_path)
|
||||
_plugins[module.__name__] = Plugin(
|
||||
name=module.__name__,
|
||||
module=module,
|
||||
module_name=module_path,
|
||||
metadata=module.__dict__.get("__plugin_metadata__", None)
|
||||
)
|
||||
if module.__dict__.get("__plugin_metadata__", None):
|
||||
metadata: "PluginMetadata" = module.__dict__["__plugin_metadata__"]
|
||||
display_name = module.__name__.split(".")[-1]
|
||||
elif module.__dict__.get("__liteyuki_plugin_meta__", None):
|
||||
metadata: "PluginMetadata" = module.__dict__["__liteyuki_plugin_meta__"]
|
||||
display_name = format_display_name(
|
||||
f"{metadata.name}({module.__name__.split('.')[-1]})", metadata.type
|
||||
)
|
||||
elif module.__dict__.get("__plugin_meta__", None):
|
||||
metadata: "PluginMetadata" = module.__dict__["__plugin_meta__"]
|
||||
display_name = format_display_name(
|
||||
f"{metadata.name}({module.__name__.split('.')[-1]})", metadata.type
|
||||
)
|
||||
else:
|
||||
|
||||
logger.opt(colors=True).warning(
|
||||
f'轻雪插件 "{module.__name__}" 的元信息未指定,将采用空的元信息'
|
||||
)
|
||||
|
||||
metadata = PluginMetadata(
|
||||
name=module.__name__,
|
||||
)
|
||||
display_name = module.__name__.split(".")[-1]
|
||||
|
||||
_plugins[module.__name__].metadata = metadata
|
||||
|
||||
logger.opt(colors=True).success(
|
||||
f'成功加载 轻雪插件 "<y>{module.__name__.split(".")[-1]}</y>"'
|
||||
f'成功加载轻雪插件 "{display_name}"'
|
||||
)
|
||||
return _plugins[module.__name__]
|
||||
|
||||
except Exception as e:
|
||||
logger.opt(colors=True).success(
|
||||
f'未能加载 轻雪插件 "<r>{module_path}</r>"'
|
||||
f'无法加载轻雪插件 "<r>{module_path}</r>"'
|
||||
)
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
def load_plugins(*plugin_dir: str) -> set[Plugin]:
|
||||
def load_plugins(*plugin_dir: str, ignore_warning: bool = True) -> set[Plugin]:
|
||||
"""导入文件夹下多个插件
|
||||
|
||||
参数:
|
||||
plugin_dir: 文件夹路径
|
||||
ignore_warning: 是否忽略警告,通常是目录不存在或目录为空
|
||||
"""
|
||||
plugins = set()
|
||||
for dir_path in plugin_dir:
|
||||
# 遍历每一个文件夹下的py文件和包含__init__.py的文件夹,不递归
|
||||
if not os.path.exists(dir_path):
|
||||
if not ignore_warning:
|
||||
logger.warning(f"插件目录 '{dir_path}' 不存在")
|
||||
continue
|
||||
|
||||
if not os.listdir(dir_path):
|
||||
if not ignore_warning:
|
||||
logger.warning(f"插件目录 '{dir_path}' 为空")
|
||||
continue
|
||||
|
||||
if not os.path.isdir(dir_path):
|
||||
if not ignore_warning:
|
||||
logger.warning(f"本应是插件目录的路径 '{dir_path}' 并非如此")
|
||||
continue
|
||||
|
||||
for f in os.listdir(dir_path):
|
||||
path = Path(os.path.join(dir_path, f))
|
||||
|
||||
module_name = None
|
||||
if os.path.isfile(path) and f.endswith('.py') and f != '__init__.py':
|
||||
if os.path.isfile(path) and f.endswith(".py") and f != "__init__.py":
|
||||
module_name = f"{path_to_module_name(Path(dir_path))}.{f[:-3]}"
|
||||
|
||||
elif os.path.isdir(path) and os.path.exists(os.path.join(path, '__init__.py')):
|
||||
elif os.path.isdir(path) and os.path.exists(
|
||||
os.path.join(path, "__init__.py")
|
||||
):
|
||||
module_name = path_to_module_name(path)
|
||||
|
||||
if module_name:
|
||||
@ -81,3 +127,27 @@ def load_plugins(*plugin_dir: str) -> set[Plugin]:
|
||||
if _plugins.get(module_name):
|
||||
plugins.add(_plugins[module_name])
|
||||
return plugins
|
||||
|
||||
|
||||
def format_display_name(display_name: str, plugin_type: PluginType) -> str:
|
||||
"""
|
||||
设置插件名称颜色,根据不同类型插件设置颜色
|
||||
Args:
|
||||
display_name: 插件名称
|
||||
plugin_type: 插件类型
|
||||
|
||||
Returns:
|
||||
str: 设置后的插件名称 <y>name</y>
|
||||
"""
|
||||
color = "y"
|
||||
match plugin_type:
|
||||
case PluginType.APPLICATION:
|
||||
color = "m"
|
||||
case PluginType.TEST:
|
||||
color = "g"
|
||||
case PluginType.MODULE:
|
||||
color = "e"
|
||||
case PluginType.SERVICE:
|
||||
color = "c"
|
||||
|
||||
return f"<{color}>{display_name} [{plugin_type.name}]</{color}>"
|
||||
|
@ -1,9 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All rights reserved
|
||||
|
||||
版权所有 © 2020-2024 神羽SnowyKami & 金羿Eilles with LiteyukiStudio & TriM Org.
|
||||
保留所有权利
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
|
||||
@Time : 2024/7/23 下午11:59
|
||||
@Author : snowykami
|
||||
|
@ -8,22 +8,61 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
@File : model.py
|
||||
@Software: PyCharm
|
||||
"""
|
||||
from enum import Enum
|
||||
from types import ModuleType
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PluginType(Enum):
|
||||
"""
|
||||
插件类型枚举值
|
||||
"""
|
||||
APPLICATION = "application"
|
||||
"""应用端:例如NoneBot"""
|
||||
|
||||
SERVICE = "service"
|
||||
"""服务端:例如AI绘画后端"""
|
||||
|
||||
MODULE = "module"
|
||||
"""模块:导出对象给其他插件使用"""
|
||||
|
||||
UNCLASSIFIED = "unclassified"
|
||||
"""未分类:默认值"""
|
||||
|
||||
TEST = "test"
|
||||
"""测试:测试插件"""
|
||||
|
||||
|
||||
class PluginMetadata(BaseModel):
|
||||
"""
|
||||
轻雪插件元数据,由插件编写者提供,name为必填项
|
||||
Attributes:
|
||||
----------
|
||||
|
||||
name: str
|
||||
插件名称
|
||||
description: str
|
||||
插件描述
|
||||
usage: str
|
||||
插件使用方法
|
||||
type: str
|
||||
插件类型
|
||||
author: str
|
||||
插件作者
|
||||
homepage: str
|
||||
插件主页
|
||||
extra: dict[str, Any]
|
||||
额外信息
|
||||
"""
|
||||
name: str
|
||||
description: str = ""
|
||||
usage: str = ""
|
||||
type: str = ""
|
||||
type: PluginType = PluginType.UNCLASSIFIED
|
||||
author: str = ""
|
||||
homepage: str = ""
|
||||
running_in_main: bool = True # 是否在主进程运行
|
||||
extra: dict[str, Any] = {}
|
||||
|
||||
|
||||
class Plugin(BaseModel):
|
||||
|
5
liteyuki/plugins/__init__.py
Normal file
5
liteyuki/plugins/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
此模块为内置插件文件夹,用于存放内置插件。
|
||||
This module is the built-in plugin folder, used to store built-in plugins.
|
||||
"""
|
@ -1,55 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
#
|
||||
# @Time : 2024/7/22 上午11:25
|
||||
# @Author : snowykami
|
||||
# @Email : snowykami@outlook.com
|
||||
# @File : asa.py
|
||||
# @Software: PyCharm
|
||||
import asyncio
|
||||
|
||||
from liteyuki.plugin import PluginMetadata
|
||||
from liteyuki import get_bot, logger
|
||||
from liteyuki.comm.channel import get_channel
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="lifespan_monitor",
|
||||
)
|
||||
|
||||
bot = get_bot()
|
||||
nbp_chan = get_channel("nonebot-passive")
|
||||
mbp_chan = get_channel("melobot-passive")
|
||||
|
||||
|
||||
@bot.on_before_start
|
||||
def _():
|
||||
logger.info("生命周期监控器:启动前")
|
||||
|
||||
|
||||
@bot.on_before_shutdown
|
||||
def _():
|
||||
print(get_channel("main"))
|
||||
logger.info("生命周期监控器:停止前")
|
||||
|
||||
|
||||
@bot.on_before_restart
|
||||
def _():
|
||||
logger.info("生命周期监控器:重启前")
|
||||
|
||||
|
||||
@bot.on_after_start
|
||||
def _():
|
||||
logger.info("生命周期监控器:启动后")
|
||||
|
||||
|
||||
@bot.on_after_start
|
||||
async def _():
|
||||
logger.info("生命周期监控器:启动后")
|
||||
|
||||
|
||||
|
||||
# @mbp_chan.on_receive()
|
||||
# @nbp_chan.on_receive()
|
||||
# async def _(data):
|
||||
# print("主进程收到数据", data)
|
19
liteyuki/plugins/liteecho.py
Normal file
19
liteyuki/plugins/liteecho.py
Normal file
@ -0,0 +1,19 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
|
||||
@Time : 2024/8/22 下午12:31
|
||||
@Author : snowykami
|
||||
@Email : snowykami@outlook.com
|
||||
@File : liteecho.py
|
||||
@Software: PyCharm
|
||||
"""
|
||||
|
||||
from liteyuki.message.on import on_startswith
|
||||
from liteyuki.message.event import MessageEvent
|
||||
from liteyuki.message.rule import is_su_rule
|
||||
|
||||
|
||||
@on_startswith(["ryounecho", "ryeco"], rule=is_su_rule).handle()
|
||||
async def liteecho(event: MessageEvent):
|
||||
event.reply(event.raw_message.strip()[8:].strip())
|
32
liteyuki/plugins/plugin_loader/__init__.py
Normal file
32
liteyuki/plugins/plugin_loader/__init__.py
Normal file
@ -0,0 +1,32 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
|
||||
@Time : 2024/8/11 下午10:02
|
||||
@Author : snowykami
|
||||
@Email : snowykami@outlook.com
|
||||
@File : __init__.py.py
|
||||
@Software: PyCharm
|
||||
"""
|
||||
from liteyuki import get_config, load_plugin
|
||||
from liteyuki.plugin import PluginMetadata, load_plugins, PluginType
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="外部轻雪插件加载器",
|
||||
description="插件加载器,用于加载轻雪原生插件",
|
||||
type=PluginType.SERVICE
|
||||
)
|
||||
|
||||
|
||||
def default_plugins_loader():
|
||||
"""
|
||||
默认插件加载器,应在初始化时调用
|
||||
"""
|
||||
for plugin in get_config("liteyuki.plugins", []):
|
||||
load_plugin(plugin)
|
||||
|
||||
for plugin_dir in get_config("liteyuki.plugin_dirs", ["src/liteyuki_plugins"]):
|
||||
load_plugins(plugin_dir)
|
||||
|
||||
|
||||
default_plugins_loader()
|
@ -1,57 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||
|
||||
@Time : 2024/8/10 下午11:25
|
||||
@Author : snowykami
|
||||
@Email : snowykami@outlook.com
|
||||
@File : register_service.py
|
||||
@Software: PyCharm
|
||||
"""
|
||||
import json
|
||||
import os.path
|
||||
import platform
|
||||
|
||||
import requests
|
||||
from git import Repo
|
||||
from liteyuki.plugin import PluginMetadata
|
||||
from liteyuki import get_bot, logger
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="注册服务",
|
||||
)
|
||||
|
||||
liteyuki = get_bot()
|
||||
commit_hash = Repo(".").head.commit.hexsha
|
||||
|
||||
|
||||
def register_bot():
|
||||
url = "https://api.liteyuki.icu/register"
|
||||
data = {
|
||||
"name" : "尹灵温|轻雪-睿乐",
|
||||
"version" : "即时更新",
|
||||
"hash" : commit_hash,
|
||||
"version_i": 99,
|
||||
"python" : f"{platform.python_implementation()} {platform.python_version()}",
|
||||
"os" : f"{platform.system()} {platform.version()} {platform.machine()}"
|
||||
}
|
||||
try:
|
||||
logger.info("正在等待 Liteyuki 注册服务器……")
|
||||
resp = requests.post(url, json=data, timeout=(10, 15))
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
if liteyuki_id := data.get("liteyuki_id"):
|
||||
with open("data/liteyuki/liteyuki.json", "wb") as f:
|
||||
f.write(json.dumps(data).encode("utf-8"))
|
||||
logger.success("成功将 {} 注册到 Liteyuki 服务器".format(liteyuki_id))
|
||||
else:
|
||||
raise ValueError(f"无法向 Liteyuki 服务器注册:{data}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"虽然向 Liteyuki 服务器注册失败,但无关紧要:{e}")
|
||||
|
||||
|
||||
@liteyuki.on_before_start
|
||||
async def _():
|
||||
if not os.path.exists("data/liteyuki/liteyuki.json"):
|
||||
register_bot()
|
@ -4,11 +4,15 @@
|
||||
"""
|
||||
import asyncio
|
||||
import inspect
|
||||
import multiprocessing
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
from liteyuki.log import logger
|
||||
|
||||
IS_MAIN_PROCESS = multiprocessing.current_process().name == "MainProcess"
|
||||
|
||||
|
||||
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
|
||||
"""
|
||||
@ -39,7 +43,7 @@ def run_coroutine(*coro: Coroutine):
|
||||
# 检测是否有现有的事件循环
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
if loop.is_running():
|
||||
# 如果事件循环正在运行,创建任务
|
||||
for c in coro:
|
||||
@ -59,6 +63,18 @@ def run_coroutine(*coro: Coroutine):
|
||||
logger.error(f"协程异常:{e}")
|
||||
|
||||
|
||||
def run_coroutine_in_thread(*coro: Coroutine):
|
||||
"""
|
||||
在新线程中运行协程
|
||||
Args:
|
||||
coro:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
threading.Thread(target=run_coroutine, args=coro, daemon=True).start()
|
||||
|
||||
|
||||
def path_to_module_name(path: Path) -> str:
|
||||
"""
|
||||
转换路径为模块名
|
||||
@ -72,3 +88,19 @@ def path_to_module_name(path: Path) -> str:
|
||||
return ".".join(rel_path.parts[:-1])
|
||||
else:
|
||||
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
|
||||
|
||||
|
||||
def async_wrapper(func: Callable[..., Any]) -> Callable[..., Coroutine]:
|
||||
"""
|
||||
异步包装器
|
||||
Args:
|
||||
func: Sync Callable
|
||||
Returns:
|
||||
Coroutine: Asynchronous Callable
|
||||
"""
|
||||
|
||||
async def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
wrapper.__signature__ = inspect.signature(func)
|
||||
return wrapper
|
||||
|
Reference in New Issue
Block a user