添加依赖注入支持,重构函数调用上下文,优化插件加载机制

This commit is contained in:
2024-12-17 13:25:30 +08:00
parent a0f657b239
commit a2c4fb220e
4 changed files with 132 additions and 42 deletions

View File

@ -18,17 +18,18 @@ from azure.ai.inference.models import (
from azure.core.credentials import AzureKeyCredential
from nonebot import get_driver, logger, on_command, on_message
from nonebot.adapters import Bot, Event, Message
from nonebot.matcher import Matcher
from nonebot.params import CommandArg
from nonebot.permission import SUPERUSER
from nonebot.rule import Rule, to_me
from nonebot.typing import T_State
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
from nonebot_plugin_marshoai.plugin.func_call.caller import get_function_calls
from .metadata import metadata
from .models import MarshoContext, MarshoTools
from .plugin import _plugins, load_plugins
from .plugin import _plugins, load_plugin, load_plugins
from .plugin.func_call.caller import get_function_calls
from .plugin.func_call.models import SessionContext
from .util import *
@ -115,10 +116,15 @@ async def _preload_plugins():
"""启动钩子加载插件"""
if config.marshoai_enable_plugins:
marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表
"""加载内置插件"""
marshoai_plugin_dirs.insert(
0, Path(__file__).parent / "plugins"
) # 预置插件目录
"""加载指定目录插件"""
load_plugins(*marshoai_plugin_dirs)
"""加载sys.path下的包"""
for package_name in config.marshoai_plugins:
load_plugin(package_name)
logger.info(
"如果启用小棉插件后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_PLUGINS 设为 false。"
)
@ -227,8 +233,10 @@ async def marsho(
event: Event,
bot: Bot,
state: T_State,
matcher: Matcher,
text: Optional[UniMsg] = None,
):
global target_list
if event.get_message().extract_plain_text() and (
not text
@ -324,7 +332,7 @@ async def marsho(
)
return
elif choice["finish_reason"] == CompletionsFinishReason.TOOL_CALLS:
# function call
# 需要获取额外信息,调用函数工具
tool_msg = []
while choice.message.tool_calls != None:
@ -360,12 +368,14 @@ async def marsho(
logger.debug(f"调用插件函数 {tool_call.function.name}")
# 权限检查,规则检查 TODO
# 实现依赖注入检查函数参数及参数注解类型对Event类型的参数进行注入
caller.event, caller.bot, caller.state = (
event,
bot,
state,
)
func_return = await caller.call(**function_args)
func_return = await caller.with_ctx(
SessionContext(
bot=bot,
event=event,
state=state,
matcher=matcher,
)
).call(**function_args)
else:
logger.error(f"未找到函数 {tool_call.function.name}")
func_return = f"未找到函数 {tool_call.function.name}"