mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2025-08-01 16:39:52 +00:00
✨ 添加依赖注入支持,重构函数调用上下文,优化插件加载机制
This commit is contained in:
@ -18,17 +18,18 @@ from azure.ai.inference.models import (
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from nonebot import get_driver, logger, on_command, on_message
|
||||
from nonebot.adapters import Bot, Event, Message
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.params import CommandArg
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot.rule import Rule, to_me
|
||||
from nonebot.typing import T_State
|
||||
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
|
||||
|
||||
from nonebot_plugin_marshoai.plugin.func_call.caller import get_function_calls
|
||||
|
||||
from .metadata import metadata
|
||||
from .models import MarshoContext, MarshoTools
|
||||
from .plugin import _plugins, load_plugins
|
||||
from .plugin import _plugins, load_plugin, load_plugins
|
||||
from .plugin.func_call.caller import get_function_calls
|
||||
from .plugin.func_call.models import SessionContext
|
||||
from .util import *
|
||||
|
||||
|
||||
@ -115,10 +116,15 @@ async def _preload_plugins():
|
||||
"""启动钩子加载插件"""
|
||||
if config.marshoai_enable_plugins:
|
||||
marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表
|
||||
"""加载内置插件"""
|
||||
marshoai_plugin_dirs.insert(
|
||||
0, Path(__file__).parent / "plugins"
|
||||
) # 预置插件目录
|
||||
"""加载指定目录插件"""
|
||||
load_plugins(*marshoai_plugin_dirs)
|
||||
"""加载sys.path下的包"""
|
||||
for package_name in config.marshoai_plugins:
|
||||
load_plugin(package_name)
|
||||
logger.info(
|
||||
"如果启用小棉插件后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_PLUGINS 设为 false。"
|
||||
)
|
||||
@ -227,8 +233,10 @@ async def marsho(
|
||||
event: Event,
|
||||
bot: Bot,
|
||||
state: T_State,
|
||||
matcher: Matcher,
|
||||
text: Optional[UniMsg] = None,
|
||||
):
|
||||
|
||||
global target_list
|
||||
if event.get_message().extract_plain_text() and (
|
||||
not text
|
||||
@ -324,7 +332,7 @@ async def marsho(
|
||||
)
|
||||
return
|
||||
elif choice["finish_reason"] == CompletionsFinishReason.TOOL_CALLS:
|
||||
|
||||
# function call
|
||||
# 需要获取额外信息,调用函数工具
|
||||
tool_msg = []
|
||||
while choice.message.tool_calls != None:
|
||||
@ -360,12 +368,14 @@ async def marsho(
|
||||
logger.debug(f"调用插件函数 {tool_call.function.name}")
|
||||
# 权限检查,规则检查 TODO
|
||||
# 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入
|
||||
caller.event, caller.bot, caller.state = (
|
||||
event,
|
||||
bot,
|
||||
state,
|
||||
)
|
||||
func_return = await caller.call(**function_args)
|
||||
func_return = await caller.with_ctx(
|
||||
SessionContext(
|
||||
bot=bot,
|
||||
event=event,
|
||||
state=state,
|
||||
matcher=matcher,
|
||||
)
|
||||
).call(**function_args)
|
||||
else:
|
||||
logger.error(f"未找到函数 {tool_call.function.name}")
|
||||
func_return = f"未找到函数 {tool_call.function.name}"
|
||||
|
Reference in New Issue
Block a user