MCP 客户端功能

This commit is contained in:
2025-09-05 20:37:15 +08:00
parent 7eb22743d8
commit b2914be3c1
18 changed files with 978 additions and 123 deletions

View File

@ -0,0 +1,31 @@
"""
Modified by Asankilp from: https://github.com/Moemu/MuiceBot with ❤
Modified from: https://github.com/modelcontextprotocol/python-sdk/tree/main/examples/clients/simple-chatbot
MIT License
Copyright (c) 2024 Anthropic, PBC
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from .client import cleanup_servers, get_mcp_list, handle_mcp_tool, initialize_servers
__all__ = ["handle_mcp_tool", "cleanup_servers", "initialize_servers", "get_mcp_list"]

View File

@ -0,0 +1,120 @@
import asyncio
from typing import Any, Optional
from nonebot import logger
from .config import get_mcp_server_config
from .server import Server, Tool
_servers: list[Server] = list()
async def initialize_servers() -> None:
"""
初始化全部 MCP 实例
"""
server_config = get_mcp_server_config()
_servers.extend(
[Server(name, srv_config) for name, srv_config in server_config.items()]
)
for server in _servers:
logger.info(f"正在初始化 MCP 服务器: {server.name}...")
try:
await server.initialize()
except Exception as e:
logger.error(f"初始化 MCP 服务器实例时出现问题: {e}")
await cleanup_servers()
raise
async def handle_mcp_tool(
tool: str, arguments: Optional[dict[str, Any]] = None
) -> Optional[str]:
"""
处理 MCP Tool 调用
"""
logger.info(f"执行 MCP 工具: {tool} (参数: {arguments})")
for server in _servers:
server_tools = await server.list_tools()
if not any(server_tool.name == tool for server_tool in server_tools):
continue
try:
result = await server.execute_tool(tool, arguments)
if isinstance(result, dict) and "progress" in result:
progress = result["progress"]
total = result["total"]
percentage = (progress / total) * 100
logger.info(
f"工具 {tool} 执行进度: {progress}/{total} ({percentage:.1f}%)"
)
return f"Tool execution result: {result}"
except Exception as e:
error_msg = f"Error executing tool: {str(e)}"
logger.error(error_msg)
return error_msg
return None # Not found.
async def cleanup_servers() -> None:
"""
清理 MCP 实例
"""
cleanup_tasks = [asyncio.create_task(server.cleanup()) for server in _servers]
if cleanup_tasks:
try:
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
except Exception as e:
logger.warning(f"清理 MCP 实例时出现错误: {e}")
async def transform_json(tool: Tool) -> dict[str, Any]:
"""
将 MCP Tool 转换为 OpenAI 所需的 parameters 格式,并删除多余字段
"""
func_desc = {
"name": tool.name,
"description": tool.description,
"parameters": {},
"required": [],
}
if tool.input_schema:
parameters = {
"type": tool.input_schema.get("type", "object"),
"properties": tool.input_schema.get("properties", {}),
"required": tool.input_schema.get("required", []),
}
func_desc["parameters"] = parameters
output = {"type": "function", "function": func_desc}
return output
async def get_mcp_list() -> list[dict[str, dict]]:
"""
获得适用于 OpenAI Tool Call 输入格式的 MCP 工具列表
"""
all_tools: list[dict[str, dict]] = []
for server in _servers:
tools = await server.list_tools()
all_tools.extend([await transform_json(tool) for tool in tools])
return all_tools
async def is_mcp_tool(tool_name: str) -> bool:
"""
检查工具是否为 MCP 工具
"""
mcp_list = await get_mcp_list()
for tool in mcp_list:
if tool["function"]["name"] == tool_name:
return True
return False

View File

@ -0,0 +1,74 @@
import json
import shutil
from pathlib import Path
from typing import Any, Literal
from nonebot import logger
from pydantic import BaseModel, Field, ValidationError, model_validator
from typing_extensions import Self
mcp_config_file_path = Path("config/marshoai/mcp.json").resolve()
class mcpConfig(BaseModel):
command: str = Field(default="")
"""执行指令"""
args: list[str] = Field(default_factory=list)
"""命令参数"""
env: dict[str, Any] = Field(default_factory=dict)
"""环境配置"""
headers: dict[str, Any] = Field(default_factory=dict)
"""HTTP请求头(用于 `sse` 和 `streamable_http` 传输方式)"""
type: Literal["stdio", "sse", "streamable_http"] = Field(default="stdio")
"""传输方式: `stdio`, `sse`, `streamable_http`"""
url: str = Field(default="")
"""服务器 URL (用于 `sse` 和 `streamable_http` 传输方式)"""
@model_validator(mode="after")
def validate_config(self) -> Self:
srv_type = self.type
command = self.command
url = self.url
if srv_type == "stdio":
if not command:
raise ValueError("当 type 为 'stdio'command 字段必须存在")
# 检查 command 是否为可执行的命令
elif not shutil.which(command):
raise ValueError(f"命令 '{command}' 不存在或不可执行。")
elif srv_type in ["sse", "streamable_http"] and not url:
raise ValueError(f"当 type 为 '{srv_type}'url 字段必须存在")
return self
def get_mcp_server_config() -> dict[str, mcpConfig]:
"""
从 MCP 配置文件 `config/mcp.json` 中获取 MCP Server 配置
"""
if not mcp_config_file_path.exists():
return {}
try:
with open(mcp_config_file_path, "r", encoding="utf-8") as f:
configs = json.load(f) or {}
except (json.JSONDecodeError, IOError, OSError) as e:
raise RuntimeError(f"读取 MCP 配置文件时发生错误: {e}")
if not isinstance(configs, dict):
raise TypeError("非预期的 MCP 配置文件格式")
mcp_servers = configs.get("mcpServers", {})
if not isinstance(mcp_servers, dict):
raise TypeError("非预期的 MCP 配置文件格式")
mcp_config: dict[str, mcpConfig] = {}
for name, srv_config in mcp_servers.items():
try:
mcp_config[name] = mcpConfig(**srv_config)
except (ValidationError, TypeError) as e:
logger.warning(f"无效的MCP服务器配置 '{name}': {e}")
continue
return mcp_config

View File

@ -0,0 +1,190 @@
import asyncio
import logging
import os
from contextlib import AsyncExitStack
from typing import Any, Optional
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from .config import mcpConfig
class Tool:
"""
MCP Tool
"""
def __init__(
self, name: str, description: str, input_schema: dict[str, Any]
) -> None:
self.name: str = name
self.description: str = description
self.input_schema: dict[str, Any] = input_schema
def format_for_llm(self) -> str:
"""
为 llm 生成工具描述
:return: 工具描述
"""
args_desc = []
if "properties" in self.input_schema:
for param_name, param_info in self.input_schema["properties"].items():
arg_desc = (
f"- {param_name}: {param_info.get('description', 'No description')}"
)
if param_name in self.input_schema.get("required", []):
arg_desc += " (required)"
args_desc.append(arg_desc)
return (
f"Tool: {self.name}\n"
f"Description: {self.description}\n"
f"Arguments:{chr(10).join(args_desc)}"
""
)
class Server:
"""
管理 MCP 服务器连接和工具执行的 Server 实例
"""
def __init__(self, name: str, config: mcpConfig) -> None:
self.name: str = name
self.config: mcpConfig = config
self.session: ClientSession | None = None
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
self.exit_stack: AsyncExitStack = AsyncExitStack()
self._transport_initializers = {
"stdio": self._initialize_stdio,
"sse": self._initialize_sse,
"streamable_http": self._initialize_streamable_http,
}
async def _initialize_stdio(self) -> tuple[Any, Any]:
"""
初始化 stdio 传输方式
:return: (read, write) 元组
"""
server_params = StdioServerParameters(
command=self.config.command,
args=self.config.args,
env={**os.environ, **self.config.env} if self.config.env else None,
)
transport_context = await self.exit_stack.enter_async_context(
stdio_client(server_params)
)
return transport_context
async def _initialize_sse(self) -> tuple[Any, Any]:
"""
初始化 sse 传输方式
:return: (read, write) 元组
"""
transport_context = await self.exit_stack.enter_async_context(
sse_client(self.config.url, headers=self.config.headers)
)
return transport_context
async def _initialize_streamable_http(self) -> tuple[Any, Any]:
"""
初始化 streamable_http 传输方式
:return: (read, write) 元组
"""
read, write, *_ = await self.exit_stack.enter_async_context(
streamablehttp_client(self.config.url, headers=self.config.headers)
)
return read, write
async def initialize(self) -> None:
"""
初始化实例
"""
transport = self.config.type
initializer = self._transport_initializers[transport]
read, write = await initializer()
session = await self.exit_stack.enter_async_context(ClientSession(read, write))
await session.initialize()
self.session = session
async def list_tools(self) -> list[Tool]:
"""
从 MCP 服务器获得可用工具列表
:return: 工具列表
:raises RuntimeError: 如果服务器未启动
"""
if not self.session:
raise RuntimeError(f"Server {self.name} not initialized")
tools_response = await self.session.list_tools()
tools: list[Tool] = []
for item in tools_response:
if isinstance(item, tuple) and item[0] == "tools":
tools.extend(
Tool(tool.name, tool.description, tool.inputSchema)
for tool in item[1]
)
return tools
async def execute_tool(
self,
tool_name: str,
arguments: Optional[dict[str, Any]] = None,
retries: int = 2,
delay: float = 1.0,
) -> Any:
"""
执行一个 MCP 工具
:param tool_name: 工具名称
:param arguments: 工具参数
:param retries: 重试次数
:param delay: 重试间隔
:return: 工具执行结果
:raises RuntimeError: 如果服务器未初始化
:raises Exception: 工具在所有重试中均失败
"""
if not self.session:
raise RuntimeError(f"Server {self.name} not initialized")
attempt = 0
while attempt < retries:
try:
logging.info(f"Executing {tool_name}...")
result = await self.session.call_tool(tool_name, arguments)
return result
except Exception as e:
attempt += 1
logging.warning(
f"Error executing tool: {e}. Attempt {attempt} of {retries}."
)
if attempt < retries:
logging.info(f"Retrying in {delay} seconds...")
await asyncio.sleep(delay)
else:
logging.error("Max retries reached. Failing.")
raise
async def cleanup(self) -> None:
"""Clean up server resources."""
async with self._cleanup_lock:
try:
await self.exit_stack.aclose()
self.session = None
except Exception as e:
logging.error(f"Error during cleanup of server {self.name}: {e}")