mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2025-08-01 16:39:52 +00:00
🛠️添加小棉工具功能,移除MARSHOAI_ENABLE_TIME_PROMPT配置项
This commit is contained in:
@ -1,15 +1,18 @@
|
||||
import contextlib
|
||||
import traceback
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
from arclet.alconna import Alconna, Args, AllParam
|
||||
from azure.ai.inference.models import (
|
||||
UserMessage,
|
||||
AssistantMessage,
|
||||
ToolMessage,
|
||||
TextContentItem,
|
||||
ImageContentItem,
|
||||
ImageUrl,
|
||||
CompletionsFinishReason,
|
||||
ChatCompletionsToolCall,
|
||||
)
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from nonebot import on_command, logger
|
||||
@ -18,11 +21,12 @@ from nonebot.params import CommandArg
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot_plugin_alconna import on_alconna, MsgTarget
|
||||
from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg
|
||||
import nonebot_plugin_localstore as store
|
||||
from nonebot import get_driver
|
||||
|
||||
from .constants import *
|
||||
from .metadata import metadata
|
||||
from .models import MarshoContext
|
||||
from .models import MarshoContext, MarshoTools
|
||||
from .util import *
|
||||
|
||||
driver = get_driver()
|
||||
@ -53,11 +57,18 @@ refresh_data_cmd = on_command("refresh_data", permission=SUPERUSER)
|
||||
|
||||
model_name = config.marshoai_default_model
|
||||
context = MarshoContext()
|
||||
tools = MarshoTools()
|
||||
token = config.marshoai_token
|
||||
endpoint = config.marshoai_azure_endpoint
|
||||
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
|
||||
target_list = [] # 记录需保存历史上下文的列表
|
||||
|
||||
@driver.on_startup
|
||||
async def _preload_tools():
|
||||
tools_dir = store.get_plugin_data_dir() / "tools"
|
||||
os.makedirs(tools_dir, exist_ok=True)
|
||||
tools.load_tools(Path(__file__).parent / "tools")
|
||||
tools.load_tools(store.get_plugin_data_dir() / "tools")
|
||||
|
||||
@add_usermsg_cmd.handle()
|
||||
async def add_usermsg(target: MsgTarget, arg: Message = CommandArg()):
|
||||
@ -77,6 +88,7 @@ async def add_assistantmsg(target: MsgTarget, arg: Message = CommandArg()):
|
||||
|
||||
@praises_cmd.handle()
|
||||
async def praises():
|
||||
#await UniMessage(await tools.call("marshoai-weather.get_weather", {"location":"杭州"})).send()
|
||||
await praises_cmd.finish(build_praises())
|
||||
|
||||
|
||||
@ -200,24 +212,45 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
|
||||
client=client,
|
||||
model_name=model_name,
|
||||
msg=context_msg + [UserMessage(content=usermsg)],
|
||||
tools=tools.get_tools_list()
|
||||
)
|
||||
# await UniMessage(str(response)).send()
|
||||
choice = response.choices[0]
|
||||
if (
|
||||
choice["finish_reason"] == CompletionsFinishReason.STOPPED
|
||||
): # 当对话成功时,将dict的上下文添加到上下文类中
|
||||
if (choice["finish_reason"] == CompletionsFinishReason.STOPPED): # 当对话成功时,将dict的上下文添加到上下文类中
|
||||
context.append(
|
||||
UserMessage(content=usermsg).as_dict(), target.id, target.private
|
||||
)
|
||||
context.append(choice.message.as_dict(), target.id, target.private)
|
||||
if [target.id, target.private] not in target_list:
|
||||
target_list.append([target.id, target.private])
|
||||
await UniMessage(str(choice.message.content)).send(reply_to=True)
|
||||
elif choice["finish_reason"] == CompletionsFinishReason.CONTENT_FILTERED:
|
||||
await UniMessage("*已被内容过滤器过滤。请调整聊天内容后重试。").send(
|
||||
reply_to=True
|
||||
)
|
||||
await UniMessage("*已被内容过滤器过滤。请调整聊天内容后重试。").send(reply_to=True)
|
||||
return
|
||||
await UniMessage(str(choice.message.content)).send(reply_to=True)
|
||||
elif choice["finish_reason"] == CompletionsFinishReason.TOOL_CALLS:
|
||||
tool_msg = []
|
||||
while choice.message.tool_calls != None:
|
||||
tool_msg.append(AssistantMessage(tool_calls=response.choices[0].message.tool_calls))
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if isinstance(tool_call, ChatCompletionsToolCall):
|
||||
function_args = json.loads(tool_call.function.arguments.replace("'", '"'))
|
||||
logger.info(f"调用函数 {tool_call.function.name} ,参数为 {function_args}")
|
||||
await UniMessage(f"调用函数 {tool_call.function.name} ,参数为 {function_args}").send()
|
||||
func_return = await tools.call(tool_call.function.name, function_args)
|
||||
tool_msg.append(ToolMessage(tool_call_id=tool_call.id, content=func_return))
|
||||
response = await make_chat(
|
||||
client=client,
|
||||
model_name=model_name,
|
||||
msg = context_msg + [UserMessage(content=usermsg)] + tool_msg,
|
||||
tools=tools.get_tools_list()
|
||||
)
|
||||
choice = response.choices[0]
|
||||
context.append(
|
||||
UserMessage(content=usermsg).as_dict(), target.id, target.private
|
||||
)
|
||||
#context.append(tool_msg, target.id, target.private)
|
||||
context.append(choice.message.as_dict(), target.id, target.private)
|
||||
await UniMessage(str(choice.message.content)).send(reply_to=True)
|
||||
except Exception as e:
|
||||
await UniMessage(str(e) + suggest_solution(str(e))).send()
|
||||
traceback.print_exc()
|
||||
|
Reference in New Issue
Block a user