mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2025-09-24 15:36:23 +00:00
✨ MCP 客户端功能
This commit is contained in:
@ -33,6 +33,7 @@ from nonebot import get_driver, logger # type: ignore
|
||||
|
||||
from .config import config
|
||||
from .dev import * # noqa: F403
|
||||
from .extensions.mcp_extension.client import get_mcp_list, initialize_servers
|
||||
from .marsho import * # noqa: F403
|
||||
from .metadata import metadata
|
||||
|
||||
@ -47,6 +48,9 @@ driver = get_driver()
|
||||
|
||||
@driver.on_startup
|
||||
async def _():
|
||||
if config.marshoai_enable_mcp:
|
||||
await initialize_servers()
|
||||
print(await get_mcp_list())
|
||||
logger.info("MarshoAI 已经加载~🐾")
|
||||
logger.info(f"Marsho 的插件数据存储于 : {str(store.get_plugin_data_dir())} 哦~🐾")
|
||||
if config.marshoai_token == "":
|
||||
|
@ -71,11 +71,13 @@ class ConfigModel(BaseModel):
|
||||
"""开发者模式,启用本地插件插件重载"""
|
||||
marshoai_plugins: list[str] = []
|
||||
"""marsho插件的名称列表,从pip安装的使用包名,从本地导入的使用路径"""
|
||||
marshoai_enable_mcp: bool = False
|
||||
|
||||
|
||||
yaml = YAML()
|
||||
|
||||
config_file_path = Path("config/marshoai/config.yaml").resolve()
|
||||
marsho_config_file_path = Path("config/marshoai/config.yaml").resolve()
|
||||
mcp_config_file_path = Path("config/marshoai/mcp.json").resolve()
|
||||
|
||||
destination_folder = Path("config/marshoai/")
|
||||
destination_file = destination_folder / "config.yaml"
|
||||
@ -98,7 +100,7 @@ def check_yaml_is_changed():
|
||||
"""
|
||||
检查配置文件是否需要更新
|
||||
"""
|
||||
with open(config_file_path, "r", encoding="utf-8") as f:
|
||||
with open(marsho_config_file_path, "r", encoding="utf-8") as f:
|
||||
old = yaml.load(f)
|
||||
with StringIO(dump_config_to_yaml(ConfigModel())) as f2:
|
||||
example_ = yaml.load(f2)
|
||||
@ -125,9 +127,9 @@ def merge_configs(existing_cfg, new_cfg):
|
||||
|
||||
config: ConfigModel = get_plugin_config(ConfigModel)
|
||||
if config.marshoai_use_yaml_config:
|
||||
if not config_file_path.exists():
|
||||
if not marsho_config_file_path.exists():
|
||||
logger.info("配置文件不存在,正在创建")
|
||||
config_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
marsho_config_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
write_default_config(destination_file)
|
||||
else:
|
||||
logger.info("配置文件存在,正在读取")
|
||||
@ -136,7 +138,7 @@ if config.marshoai_use_yaml_config:
|
||||
yaml_2 = YAML()
|
||||
logger.info("插件新的配置已更新, 正在更新")
|
||||
|
||||
with open(config_file_path, "r", encoding="utf-8") as f:
|
||||
with open(marsho_config_file_path, "r", encoding="utf-8") as f:
|
||||
old_config = yaml_2.load(f)
|
||||
|
||||
with StringIO(dump_config_to_yaml(ConfigModel())) as f2:
|
||||
@ -147,7 +149,7 @@ if config.marshoai_use_yaml_config:
|
||||
with open(destination_file, "w", encoding="utf-8") as f:
|
||||
yaml_2.dump(merged_config, f)
|
||||
|
||||
with open(config_file_path, "r", encoding="utf-8") as f:
|
||||
with open(marsho_config_file_path, "r", encoding="utf-8") as f:
|
||||
yaml_config = yaml_.load(f, Loader=yaml_.FullLoader)
|
||||
|
||||
config = ConfigModel(**yaml_config)
|
||||
@ -156,3 +158,10 @@ else:
|
||||
# "MarshoAI 支持新的 YAML 配置系统,若要使用,请将 MARSHOAI_USE_YAML_CONFIG 配置项设置为 true。"
|
||||
# )
|
||||
pass
|
||||
|
||||
|
||||
if config.marshoai_enable_mcp:
|
||||
if not mcp_config_file_path.exists():
|
||||
mcp_config_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(mcp_config_file_path, "w", encoding="utf-8") as f:
|
||||
f.write("{}")
|
||||
|
@ -41,7 +41,8 @@ SUPPORT_IMAGE_MODELS: list = [
|
||||
"mistral-ai/mistral-small-2503",
|
||||
]
|
||||
OPENAI_NEW_MODELS: list = [
|
||||
"openai/o4" "openai/o4-mini",
|
||||
"openai/o4",
|
||||
"openai/o4-mini",
|
||||
"openai/o3",
|
||||
"openai/o3-mini",
|
||||
"openai/o1",
|
||||
|
31
nonebot_plugin_marshoai/extensions/mcp_extension/__init__.py
Normal file
31
nonebot_plugin_marshoai/extensions/mcp_extension/__init__.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""
|
||||
Modified by Asankilp from: https://github.com/Moemu/MuiceBot with ❤
|
||||
|
||||
Modified from: https://github.com/modelcontextprotocol/python-sdk/tree/main/examples/clients/simple-chatbot
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 Anthropic, PBC
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
|
||||
from .client import cleanup_servers, get_mcp_list, handle_mcp_tool, initialize_servers
|
||||
|
||||
__all__ = ["handle_mcp_tool", "cleanup_servers", "initialize_servers", "get_mcp_list"]
|
120
nonebot_plugin_marshoai/extensions/mcp_extension/client.py
Normal file
120
nonebot_plugin_marshoai/extensions/mcp_extension/client.py
Normal file
@ -0,0 +1,120 @@
|
||||
import asyncio
|
||||
from typing import Any, Optional
|
||||
|
||||
from nonebot import logger
|
||||
|
||||
from .config import get_mcp_server_config
|
||||
from .server import Server, Tool
|
||||
|
||||
_servers: list[Server] = list()
|
||||
|
||||
|
||||
async def initialize_servers() -> None:
|
||||
"""
|
||||
初始化全部 MCP 实例
|
||||
"""
|
||||
server_config = get_mcp_server_config()
|
||||
_servers.extend(
|
||||
[Server(name, srv_config) for name, srv_config in server_config.items()]
|
||||
)
|
||||
for server in _servers:
|
||||
logger.info(f"正在初始化 MCP 服务器: {server.name}...")
|
||||
try:
|
||||
await server.initialize()
|
||||
except Exception as e:
|
||||
logger.error(f"初始化 MCP 服务器实例时出现问题: {e}")
|
||||
await cleanup_servers()
|
||||
raise
|
||||
|
||||
|
||||
async def handle_mcp_tool(
|
||||
tool: str, arguments: Optional[dict[str, Any]] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
处理 MCP Tool 调用
|
||||
"""
|
||||
logger.info(f"执行 MCP 工具: {tool} (参数: {arguments})")
|
||||
|
||||
for server in _servers:
|
||||
server_tools = await server.list_tools()
|
||||
if not any(server_tool.name == tool for server_tool in server_tools):
|
||||
continue
|
||||
|
||||
try:
|
||||
result = await server.execute_tool(tool, arguments)
|
||||
|
||||
if isinstance(result, dict) and "progress" in result:
|
||||
progress = result["progress"]
|
||||
total = result["total"]
|
||||
percentage = (progress / total) * 100
|
||||
logger.info(
|
||||
f"工具 {tool} 执行进度: {progress}/{total} ({percentage:.1f}%)"
|
||||
)
|
||||
|
||||
return f"Tool execution result: {result}"
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing tool: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
|
||||
return None # Not found.
|
||||
|
||||
|
||||
async def cleanup_servers() -> None:
|
||||
"""
|
||||
清理 MCP 实例
|
||||
"""
|
||||
cleanup_tasks = [asyncio.create_task(server.cleanup()) for server in _servers]
|
||||
if cleanup_tasks:
|
||||
try:
|
||||
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
|
||||
except Exception as e:
|
||||
logger.warning(f"清理 MCP 实例时出现错误: {e}")
|
||||
|
||||
|
||||
async def transform_json(tool: Tool) -> dict[str, Any]:
|
||||
"""
|
||||
将 MCP Tool 转换为 OpenAI 所需的 parameters 格式,并删除多余字段
|
||||
"""
|
||||
func_desc = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
if tool.input_schema:
|
||||
parameters = {
|
||||
"type": tool.input_schema.get("type", "object"),
|
||||
"properties": tool.input_schema.get("properties", {}),
|
||||
"required": tool.input_schema.get("required", []),
|
||||
}
|
||||
func_desc["parameters"] = parameters
|
||||
|
||||
output = {"type": "function", "function": func_desc}
|
||||
|
||||
return output
|
||||
|
||||
|
||||
async def get_mcp_list() -> list[dict[str, dict]]:
|
||||
"""
|
||||
获得适用于 OpenAI Tool Call 输入格式的 MCP 工具列表
|
||||
"""
|
||||
all_tools: list[dict[str, dict]] = []
|
||||
|
||||
for server in _servers:
|
||||
tools = await server.list_tools()
|
||||
all_tools.extend([await transform_json(tool) for tool in tools])
|
||||
|
||||
return all_tools
|
||||
|
||||
|
||||
async def is_mcp_tool(tool_name: str) -> bool:
|
||||
"""
|
||||
检查工具是否为 MCP 工具
|
||||
"""
|
||||
mcp_list = await get_mcp_list()
|
||||
for tool in mcp_list:
|
||||
if tool["function"]["name"] == tool_name:
|
||||
return True
|
||||
return False
|
74
nonebot_plugin_marshoai/extensions/mcp_extension/config.py
Normal file
74
nonebot_plugin_marshoai/extensions/mcp_extension/config.py
Normal file
@ -0,0 +1,74 @@
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from nonebot import logger
|
||||
from pydantic import BaseModel, Field, ValidationError, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
mcp_config_file_path = Path("config/marshoai/mcp.json").resolve()
|
||||
|
||||
|
||||
class mcpConfig(BaseModel):
|
||||
command: str = Field(default="")
|
||||
"""执行指令"""
|
||||
args: list[str] = Field(default_factory=list)
|
||||
"""命令参数"""
|
||||
env: dict[str, Any] = Field(default_factory=dict)
|
||||
"""环境配置"""
|
||||
headers: dict[str, Any] = Field(default_factory=dict)
|
||||
"""HTTP请求头(用于 `sse` 和 `streamable_http` 传输方式)"""
|
||||
type: Literal["stdio", "sse", "streamable_http"] = Field(default="stdio")
|
||||
"""传输方式: `stdio`, `sse`, `streamable_http`"""
|
||||
url: str = Field(default="")
|
||||
"""服务器 URL (用于 `sse` 和 `streamable_http` 传输方式)"""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_config(self) -> Self:
|
||||
srv_type = self.type
|
||||
command = self.command
|
||||
url = self.url
|
||||
|
||||
if srv_type == "stdio":
|
||||
if not command:
|
||||
raise ValueError("当 type 为 'stdio' 时,command 字段必须存在")
|
||||
# 检查 command 是否为可执行的命令
|
||||
elif not shutil.which(command):
|
||||
raise ValueError(f"命令 '{command}' 不存在或不可执行。")
|
||||
|
||||
elif srv_type in ["sse", "streamable_http"] and not url:
|
||||
raise ValueError(f"当 type 为 '{srv_type}' 时,url 字段必须存在")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def get_mcp_server_config() -> dict[str, mcpConfig]:
|
||||
"""
|
||||
从 MCP 配置文件 `config/mcp.json` 中获取 MCP Server 配置
|
||||
"""
|
||||
if not mcp_config_file_path.exists():
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(mcp_config_file_path, "r", encoding="utf-8") as f:
|
||||
configs = json.load(f) or {}
|
||||
except (json.JSONDecodeError, IOError, OSError) as e:
|
||||
raise RuntimeError(f"读取 MCP 配置文件时发生错误: {e}")
|
||||
|
||||
if not isinstance(configs, dict):
|
||||
raise TypeError("非预期的 MCP 配置文件格式")
|
||||
|
||||
mcp_servers = configs.get("mcpServers", {})
|
||||
if not isinstance(mcp_servers, dict):
|
||||
raise TypeError("非预期的 MCP 配置文件格式")
|
||||
|
||||
mcp_config: dict[str, mcpConfig] = {}
|
||||
for name, srv_config in mcp_servers.items():
|
||||
try:
|
||||
mcp_config[name] = mcpConfig(**srv_config)
|
||||
except (ValidationError, TypeError) as e:
|
||||
logger.warning(f"无效的MCP服务器配置 '{name}': {e}")
|
||||
continue
|
||||
|
||||
return mcp_config
|
190
nonebot_plugin_marshoai/extensions/mcp_extension/server.py
Normal file
190
nonebot_plugin_marshoai/extensions/mcp_extension/server.py
Normal file
@ -0,0 +1,190 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any, Optional
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
from .config import mcpConfig
|
||||
|
||||
|
||||
class Tool:
|
||||
"""
|
||||
MCP Tool
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, name: str, description: str, input_schema: dict[str, Any]
|
||||
) -> None:
|
||||
self.name: str = name
|
||||
self.description: str = description
|
||||
self.input_schema: dict[str, Any] = input_schema
|
||||
|
||||
def format_for_llm(self) -> str:
|
||||
"""
|
||||
为 llm 生成工具描述
|
||||
|
||||
:return: 工具描述
|
||||
"""
|
||||
args_desc = []
|
||||
if "properties" in self.input_schema:
|
||||
for param_name, param_info in self.input_schema["properties"].items():
|
||||
arg_desc = (
|
||||
f"- {param_name}: {param_info.get('description', 'No description')}"
|
||||
)
|
||||
if param_name in self.input_schema.get("required", []):
|
||||
arg_desc += " (required)"
|
||||
args_desc.append(arg_desc)
|
||||
|
||||
return (
|
||||
f"Tool: {self.name}\n"
|
||||
f"Description: {self.description}\n"
|
||||
f"Arguments:{chr(10).join(args_desc)}"
|
||||
""
|
||||
)
|
||||
|
||||
|
||||
class Server:
|
||||
"""
|
||||
管理 MCP 服务器连接和工具执行的 Server 实例
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, config: mcpConfig) -> None:
|
||||
self.name: str = name
|
||||
self.config: mcpConfig = config
|
||||
self.session: ClientSession | None = None
|
||||
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
|
||||
self.exit_stack: AsyncExitStack = AsyncExitStack()
|
||||
self._transport_initializers = {
|
||||
"stdio": self._initialize_stdio,
|
||||
"sse": self._initialize_sse,
|
||||
"streamable_http": self._initialize_streamable_http,
|
||||
}
|
||||
|
||||
async def _initialize_stdio(self) -> tuple[Any, Any]:
|
||||
"""
|
||||
初始化 stdio 传输方式
|
||||
|
||||
:return: (read, write) 元组
|
||||
"""
|
||||
server_params = StdioServerParameters(
|
||||
command=self.config.command,
|
||||
args=self.config.args,
|
||||
env={**os.environ, **self.config.env} if self.config.env else None,
|
||||
)
|
||||
transport_context = await self.exit_stack.enter_async_context(
|
||||
stdio_client(server_params)
|
||||
)
|
||||
return transport_context
|
||||
|
||||
async def _initialize_sse(self) -> tuple[Any, Any]:
|
||||
"""
|
||||
初始化 sse 传输方式
|
||||
|
||||
:return: (read, write) 元组
|
||||
"""
|
||||
transport_context = await self.exit_stack.enter_async_context(
|
||||
sse_client(self.config.url, headers=self.config.headers)
|
||||
)
|
||||
return transport_context
|
||||
|
||||
async def _initialize_streamable_http(self) -> tuple[Any, Any]:
|
||||
"""
|
||||
初始化 streamable_http 传输方式
|
||||
|
||||
:return: (read, write) 元组
|
||||
"""
|
||||
read, write, *_ = await self.exit_stack.enter_async_context(
|
||||
streamablehttp_client(self.config.url, headers=self.config.headers)
|
||||
)
|
||||
return read, write
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""
|
||||
初始化实例
|
||||
"""
|
||||
transport = self.config.type
|
||||
initializer = self._transport_initializers[transport]
|
||||
read, write = await initializer()
|
||||
session = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
|
||||
async def list_tools(self) -> list[Tool]:
|
||||
"""
|
||||
从 MCP 服务器获得可用工具列表
|
||||
|
||||
:return: 工具列表
|
||||
|
||||
:raises RuntimeError: 如果服务器未启动
|
||||
"""
|
||||
if not self.session:
|
||||
raise RuntimeError(f"Server {self.name} not initialized")
|
||||
|
||||
tools_response = await self.session.list_tools()
|
||||
tools: list[Tool] = []
|
||||
|
||||
for item in tools_response:
|
||||
if isinstance(item, tuple) and item[0] == "tools":
|
||||
tools.extend(
|
||||
Tool(tool.name, tool.description, tool.inputSchema)
|
||||
for tool in item[1]
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
async def execute_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: Optional[dict[str, Any]] = None,
|
||||
retries: int = 2,
|
||||
delay: float = 1.0,
|
||||
) -> Any:
|
||||
"""
|
||||
执行一个 MCP 工具
|
||||
|
||||
:param tool_name: 工具名称
|
||||
:param arguments: 工具参数
|
||||
:param retries: 重试次数
|
||||
:param delay: 重试间隔
|
||||
|
||||
:return: 工具执行结果
|
||||
|
||||
:raises RuntimeError: 如果服务器未初始化
|
||||
:raises Exception: 工具在所有重试中均失败
|
||||
"""
|
||||
if not self.session:
|
||||
raise RuntimeError(f"Server {self.name} not initialized")
|
||||
|
||||
attempt = 0
|
||||
while attempt < retries:
|
||||
try:
|
||||
logging.info(f"Executing {tool_name}...")
|
||||
result = await self.session.call_tool(tool_name, arguments)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
attempt += 1
|
||||
logging.warning(
|
||||
f"Error executing tool: {e}. Attempt {attempt} of {retries}."
|
||||
)
|
||||
if attempt < retries:
|
||||
logging.info(f"Retrying in {delay} seconds...")
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
logging.error("Max retries reached. Failing.")
|
||||
raise
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Clean up server resources."""
|
||||
async with self._cleanup_lock:
|
||||
try:
|
||||
await self.exit_stack.aclose()
|
||||
self.session = None
|
||||
except Exception as e:
|
||||
logging.error(f"Error during cleanup of server {self.name}: {e}")
|
@ -31,6 +31,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletio
|
||||
|
||||
from .config import config
|
||||
from .constants import SUPPORT_IMAGE_MODELS
|
||||
from .extensions.mcp_extension.client import handle_mcp_tool, is_mcp_tool
|
||||
from .instances import target_list
|
||||
from .models import MarshoContext
|
||||
from .plugin.func_call.caller import get_function_calls
|
||||
@ -148,43 +149,56 @@ class MarshoHandler:
|
||||
# pass
|
||||
tool_msg.append(choice.message)
|
||||
for tool_call in tool_calls: # type: ignore
|
||||
tool_name = tool_call.function.name
|
||||
tool_clean_name = tool_name.replace("-", ".")
|
||||
try:
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError:
|
||||
function_args = json.loads(
|
||||
tool_call.function.arguments.replace("'", '"')
|
||||
)
|
||||
if await is_mcp_tool(tool_name):
|
||||
tool_clean_name = tool_name # MCP 工具不需要替换
|
||||
# 删除args的placeholder参数
|
||||
if "placeholder" in function_args:
|
||||
del function_args["placeholder"]
|
||||
logger.info(
|
||||
f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
|
||||
f"调用工具 {tool_clean_name},参数:"
|
||||
+ "\n".join([f"{k}={v}" for k, v in function_args.items()])
|
||||
)
|
||||
await UniMessage(
|
||||
f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
|
||||
f"调用工具 {tool_clean_name}\n参数:"
|
||||
+ "\n".join([f"{k}={v}" for k, v in function_args.items()])
|
||||
).send()
|
||||
if caller := get_function_calls().get(tool_call.function.name):
|
||||
logger.debug(f"调用插件函数 {caller.full_name}")
|
||||
# 权限检查,规则检查 TODO
|
||||
# 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入
|
||||
func_return = await caller.with_ctx(
|
||||
SessionContext(
|
||||
bot=self.bot,
|
||||
event=self.event,
|
||||
matcher=self.matcher,
|
||||
state=None,
|
||||
if not await is_mcp_tool(tool_name):
|
||||
if caller := get_function_calls().get(tool_call.function.name):
|
||||
logger.debug(f"调用插件函数 {caller.full_name}")
|
||||
# 权限检查,规则检查 TODO
|
||||
# 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入
|
||||
func_return = await caller.with_ctx(
|
||||
SessionContext(
|
||||
bot=self.bot,
|
||||
event=self.event,
|
||||
matcher=self.matcher,
|
||||
state=None,
|
||||
)
|
||||
).call(**function_args)
|
||||
else:
|
||||
logger.error(
|
||||
f"未找到函数 {tool_call.function.name.replace('-', '.')}"
|
||||
)
|
||||
).call(**function_args)
|
||||
func_return = (
|
||||
f"未找到函数 {tool_call.function.name.replace('-', '.')}"
|
||||
)
|
||||
tool_msg.append(
|
||||
ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict() # type: ignore
|
||||
)
|
||||
else:
|
||||
logger.error(f"未找到函数 {tool_call.function.name.replace('-', '.')}")
|
||||
func_return = f"未找到函数 {tool_call.function.name.replace('-', '.')}"
|
||||
tool_msg.append(
|
||||
ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict() # type: ignore
|
||||
)
|
||||
# tool_msg[0]["tool_calls"][0]["type"] = "builtin_function"
|
||||
# await UniMessage(str(tool_msg)).send()
|
||||
func_return = await handle_mcp_tool(tool_name, function_args)
|
||||
tool_msg.append(
|
||||
ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict() # type: ignore
|
||||
)
|
||||
|
||||
return await self.handle_common_chat(
|
||||
user_message=user_message,
|
||||
model_name=model_name,
|
||||
|
@ -27,6 +27,7 @@ from nonebot_plugin_argot.extension import ArgotExtension # type: ignore
|
||||
|
||||
from .config import config
|
||||
from .constants import INTRODUCTION, SUPPORT_IMAGE_MODELS
|
||||
from .extensions.mcp_extension.client import get_mcp_list
|
||||
from .handler import MarshoHandler
|
||||
from .hooks import * # noqa: F403
|
||||
from .instances import client, context, model_name, target_list, tools
|
||||
@ -263,8 +264,10 @@ async def marsho(
|
||||
|
||||
usermsg = await handler.process_user_input(text, model_name)
|
||||
|
||||
tools_lists = tools.tools_list + list(
|
||||
map(lambda v: v.data(), get_function_calls().values())
|
||||
tools_lists = (
|
||||
tools.tools_list
|
||||
+ list(map(lambda v: v.data(), get_function_calls().values()))
|
||||
+ await get_mcp_list()
|
||||
)
|
||||
logger.info(f"正在获取回答,模型:{model_name}")
|
||||
await message_reaction(Emoji("66"))
|
||||
|
@ -6,7 +6,6 @@ from nonebot.adapters import Bot, Event
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.permission import Permission
|
||||
from nonebot.rule import Rule
|
||||
from nonebot.typing import T_State
|
||||
|
||||
from ..models import Plugin
|
||||
from ..typing import ASYNC_FUNCTION_CALL_FUNC, F
|
||||
@ -73,7 +72,7 @@ class Caller:
|
||||
# if self.ctx.state is None:
|
||||
# return False, "State is None"
|
||||
if self._rule and not await self._rule(
|
||||
self.ctx.bot, self.ctx.event, self.ctx.state
|
||||
self.ctx.bot, self.ctx.event, self.ctx.state or {}
|
||||
):
|
||||
return False, "告诉用户 Rule Denied 规则不匹配"
|
||||
|
||||
|
Reference in New Issue
Block a user