重构Caller类,移除泛型参数;添加函数签名复制装饰器

This commit is contained in:
2024-12-15 17:08:02 +08:00
parent af9a5e3c96
commit 0379789bec
8 changed files with 208 additions and 72 deletions

View File

@ -23,6 +23,8 @@ from nonebot.permission import SUPERUSER
from nonebot.rule import Rule, to_me
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
from nonebot_plugin_marshoai.plugin.func_call.caller import get_function_calls
from .metadata import metadata
from .models import MarshoContext, MarshoTools
from .plugin import _plugins, load_plugins
@ -103,10 +105,9 @@ async def _preload_tools():
@driver.on_startup
async def _preload_plugins():
"""启动钩子加载插件"""
marshoai_plugin_dirs = config.marshoai_plugin_dirs
marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins")
marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表
marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins") # 预置插件目录
load_plugins(*marshoai_plugin_dirs)
logger.opt(colors=True).info(f"已加载 <c>{len(_plugins)}</c> 个小棉插件")
@add_usermsg_cmd.handle()
@ -266,7 +267,10 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
client=client,
model_name=model_name,
msg=context_msg + [UserMessage(content=usermsg)], # type: ignore
tools=tools.get_tools_list(),
tools=tools.get_tools_list()
+ list(
map(lambda v: v.data(), get_function_calls().values())
), # TODO 临时追加函数,后期优化
)
# await UniMessage(str(response)).send()
choice = response.choices[0]
@ -315,9 +319,23 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
await UniMessage(
f"调用函数 {tool_call.function.name} ,参数为 {function_args}"
).send()
func_return = await tools.call(
tool_call.function.name, function_args
) # 获取返回值
# TODO 临时追加插件函数,若工具中没有则调用插件函数
if tools.has_function(tool_call.function.name):
logger.debug(f"调用工具函数 {tool_call.function.name}")
func_return = await tools.call(
tool_call.function.name, function_args
) # 获取返回值
else:
if caller := get_function_calls().get(
tool_call.function.name
):
logger.debug(f"调用插件函数 {tool_call.function.name}")
# 实现依赖注入检查函数参数及参数注解类型对Event类型的参数进行注入
caller.event = event
func_return = await caller.call(**function_args)
else:
logger.error(f"未找到函数 {tool_call.function.name}")
func_return = f"未找到函数 {tool_call.function.name}"
tool_msg.append(
ToolMessage(tool_call_id=tool_call.id, content=func_return) # type: ignore
)