Coverage for agentos/protocols/mcp.py: 47%
81 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:36 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:36 +0800
"""
AgentOS v0.20 MCP (Model Context Protocol) client.

Transports registered in ``MCPClient.TRANSPORTS``: stdio (subprocess pipes)
and SSE (HTTP). The ``"ws"`` value mentioned for
``MCPServerConfig.transport`` has no implementation in this module.
"""
6from __future__ import annotations
8import asyncio
9import json
10import subprocess
11from abc import ABC, abstractmethod
12from dataclasses import dataclass, field
13from typing import Any
@dataclass
class MCPServerConfig:
    """Connection settings for one MCP server.

    Attributes:
        name: Identifier of the server; ``MCPClient`` keys the connection by
            it and embeds it in every exposed tool name.
        transport: Key into ``MCPClient.TRANSPORTS`` selecting the transport.
        command: Executable spawned by the stdio transport (``"npx"`` is used
            when this is None).
        args: Extra command-line arguments appended after ``command``.
        url: Base URL used by the SSE transport.
        env: Variables layered over the parent environment when the stdio
            transport spawns ``command``.
    """

    name: str
    transport: str = field(default="stdio")  # "stdio" | "sse" | "ws" ("ws" is not implemented here)
    command: str | None = field(default=None)
    args: list[str] = field(default_factory=list)
    url: str | None = field(default=None)
    env: dict[str, str] = field(default_factory=dict)
@dataclass
class MCPToolSchema:
    """Description of one tool exposed by an MCP server.

    ``MCPClient.connect_server`` builds these from the server's ``tools/list``
    entries, and ``MCPClient.get_mcp_tool_schemas`` converts them to OpenAI
    function specs.
    """
    name: str  # exposed name; MCPClient formats it as "mcp_<server>_<tool>"
    description: str  # tool description taken from the server's listing
    input_schema: dict  # the tool's "inputSchema", passed through as the function "parameters"
35# ── 传输层 ──────────────────────────────────────
class MCPTransport(ABC):
    """Abstract channel that carries JSON-RPC requests to one MCP server.

    Concrete subclasses must implement all three coroutines below.
    """

    @abstractmethod
    async def connect(self, config: MCPServerConfig):
        """Open the connection described by *config*."""

    @abstractmethod
    async def send(self, method: str, params: dict | None = None) -> dict:
        """Issue a request for *method* with *params* and return the decoded reply."""

    @abstractmethod
    async def close(self):
        """Release whatever resources the connection holds."""
class StdioTransport(MCPTransport):
    """Talk to a local MCP server subprocess over its stdin/stdout pipes.

    Each request is written as one JSON-RPC object per line with a fresh
    ``id``. The reply is the first stdout line that decodes to a JSON object
    with that ``id`` and no ``method``. Other stdout lines (server-initiated
    notifications/requests, non-JSON output) are skipped. The blocking pipe
    I/O runs in the event loop's default executor so awaiting ``send`` does
    not stall other tasks.
    """

    #: Seconds to wait for the process to exit after terminate() before kill().
    close_timeout: float = 5.0

    def __init__(self):
        self._proc: subprocess.Popen | None = None
        # Only one request/response exchange may use the pipes at a time.
        self._lock = asyncio.Lock()
        # Last JSON-RPC request id issued; bumped per request so replies can be matched.
        self._next_id = 0

    async def connect(self, config: MCPServerConfig):
        """Spawn ``config.command`` (``"npx"`` when None) with ``config.args``.

        ``config.env`` is layered on top of the current process environment.
        """
        import os

        self._proc = subprocess.Popen(
            [config.command or "npx", *config.args],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            # Inherit stderr: a PIPE that nobody drains can fill up and block
            # the server mid-write, deadlocking the exchange.
            stderr=None,
            env={**os.environ, **config.env},
        )

    async def send(self, method: str, params: dict | None = None) -> dict:
        """Send a JSON-RPC request and return the decoded response object.

        The whole response is returned (``jsonrpc``/``id`` plus ``result`` or
        ``error``); it is not unwrapped here.

        Raises:
            RuntimeError: if the transport is not connected.
            ConnectionError: if the server closes stdout before replying.
        """
        async with self._lock:
            proc = self._proc
            if proc is None:
                raise RuntimeError("StdioTransport is not connected; call connect() first")
            self._next_id += 1
            request = {
                "jsonrpc": "2.0",
                "method": method,
                "params": params or {},
                "id": self._next_id,
            }
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, self._roundtrip, proc, request)

    @staticmethod
    def _roundtrip(proc: subprocess.Popen, request: dict) -> dict:
        """Blocking: write *request* to *proc*, then read stdout until its reply arrives."""
        proc.stdin.write((json.dumps(request) + "\n").encode())
        proc.stdin.flush()
        while True:
            raw = proc.stdout.readline()
            if not raw:
                raise ConnectionError(
                    f"MCP server closed stdout before answering request {request['id']}"
                )
            try:
                reply = json.loads(raw)
            except ValueError:
                # Blank or non-JSON line (e.g. a log line printed to stdout): not a reply.
                continue
            # Messages carrying "method" are server-initiated; their ids are the
            # server's own and must not be mistaken for our reply.
            if isinstance(reply, dict) and "method" not in reply and reply.get("id") == request["id"]:
                return reply

    async def close(self):
        """Shut the server down: close its stdin, terminate it and reap it.

        Escalates to kill() if the process has not exited within
        ``close_timeout`` seconds. Calling close() again is a no-op.
        """
        proc, self._proc = self._proc, None
        if proc is None:
            return
        try:
            proc.stdin.close()  # EOF on stdin lets a well-behaved server exit by itself
        except OSError:
            pass  # pipe already broken because the server is gone; nothing left to flush
        if proc.poll() is None:
            proc.terminate()
        loop = asyncio.get_running_loop()
        try:
            await loop.run_in_executor(None, proc.wait, self.close_timeout)
        except subprocess.TimeoutExpired:
            proc.kill()
            await loop.run_in_executor(None, proc.wait)
        proc.stdout.close()
class SSETransport(MCPTransport):
    """Send JSON-RPC requests to a remote MCP server over HTTP.

    Each request is POSTed to ``<config.url>/message``. The HTTP response
    body is decoded as the JSON-RPC reply and returned without unwrapping.

    NOTE(review): despite the name, no Server-Sent-Events stream is opened.
    This only works with servers that answer in the POST response body;
    confirm against the target servers.
    """

    def __init__(self):
        self._client = None  # httpx.AsyncClient while connected, else None
        # Last JSON-RPC request id issued; bumped per request so ids stay unique.
        self._next_id = 0

    async def connect(self, config: MCPServerConfig):
        """Create the HTTP client for ``config.url`` (30 s timeout).

        Raises:
            ValueError: if ``config.url`` is empty or None.
        """
        import httpx

        if not config.url:
            raise ValueError(f"MCP server {config.name!r}: the SSE transport requires config.url")
        self._client = httpx.AsyncClient(base_url=config.url, timeout=30)

    async def send(self, method: str, params: dict | None = None) -> dict:
        """POST a JSON-RPC request and return the decoded response object.

        Raises:
            RuntimeError: if the transport is not connected.
            httpx.HTTPStatusError: if the server answers with a 4xx/5xx status.
        """
        if self._client is None:
            raise RuntimeError("SSETransport is not connected; call connect() first")
        self._next_id += 1
        resp = await self._client.post(
            "/message",
            json={"jsonrpc": "2.0", "method": method, "params": params or {}, "id": self._next_id},
        )
        # Surface HTTP failures instead of trying to parse an error page as JSON-RPC.
        resp.raise_for_status()
        return resp.json()

    async def close(self):
        """Close the HTTP client; calling close() again is a no-op."""
        client, self._client = self._client, None
        if client is not None:
            await client.aclose()
96# ── MCP 客户端 ──────────────────────────────────
class MCPClient:
    """MCP protocol client that manages connections to several MCP servers.

    Each server's tools (from ``tools/list``) are registered under the name
    ``mcp_<server name>_<tool name>``. Those names are what
    :meth:`get_mcp_tool_schemas` advertises and :meth:`call_tool` accepts.
    """

    TRANSPORTS = {"stdio": StdioTransport, "sse": SSETransport}

    def __init__(self):
        self._servers: dict[str, MCPTransport] = {}
        self._tools: dict[str, MCPToolSchema] = {}
        # Exposed tool name -> (server name, tool name as the server knows it).
        # Recorded at registration so routing never re-parses the exposed name,
        # which is ambiguous when one server name is a prefix of another.
        self._routes: dict[str, tuple[str, str]] = {}

    async def connect_server(self, config: MCPServerConfig):
        """Connect to the server described by *config* and register its tools.

        Connecting again under an existing ``config.name`` first closes the
        previous connection and drops its tools. If listing the tools fails,
        the new connection is closed and the error is re-raised.

        Raises:
            ValueError: if ``config.transport`` is not a key of ``TRANSPORTS``.
        """
        transport_cls = self.TRANSPORTS.get(config.transport)
        if transport_cls is None:
            # Reject unknown values instead of guessing: a stdio fallback would
            # spawn "npx" for e.g. transport="ws".
            raise ValueError(
                f"Unsupported MCP transport {config.transport!r} for server "
                f"{config.name!r}; expected one of {sorted(self.TRANSPORTS)}"
            )
        if config.name in self._servers:
            await self._disconnect(config.name)

        transport = transport_cls()
        await transport.connect(config)
        self._servers[config.name] = transport
        try:
            # NOTE(review): no MCP "initialize" request is sent before tools/list;
            # confirm the target servers accept that.
            listing = await self._list_tools(config.name)
        except BaseException:
            # Do not leave a half-registered connection (or a running process) behind.
            await self._disconnect(config.name)
            raise

        prefix = f"mcp_{config.name}_"
        for tool in listing.get("tools", []):
            schema = MCPToolSchema(
                name=prefix + tool["name"],
                description=tool.get("description") or "",
                input_schema=tool.get("inputSchema") or {},
            )
            self._tools[schema.name] = schema
            self._routes[schema.name] = (config.name, tool["name"])

    async def _list_tools(self, server_name: str) -> dict:
        """Fetch every page of ``tools/list`` from *server_name*.

        Returns ``{"tools": [...]}`` holding the tools of all pages. It follows
        ``nextCursor`` until the server stops returning one.
        """
        transport = self._servers[server_name]
        tools: list = []
        seen_cursors: set = set()
        cursor = None
        while True:
            response = await transport.send("tools/list", {"cursor": cursor} if cursor else None)
            page = self._unwrap(response)
            tools.extend(page.get("tools") or [])
            cursor = page.get("nextCursor")
            # Stop on the last page, or if a misbehaving server repeats a cursor.
            if not cursor or cursor in seen_cursors:
                return {"tools": tools}
            seen_cursors.add(cursor)

    async def call_tool(self, full_name: str, arguments: dict) -> Any:
        """Invoke the tool exposed as *full_name* with *arguments*.

        Returns the ``text`` of every text item in the result's ``content``,
        joined with newlines (``""`` when there is none).

        Raises:
            ValueError: if *full_name* does not belong to a connected server.
            RuntimeError: if the server answers with a JSON-RPC error.
        """
        route = self._routes.get(full_name) or self._route_by_prefix(full_name)
        if route is None or route[0] not in self._servers:
            raise ValueError(f"Unknown MCP tool: {full_name}")
        server_name, tool_name = route
        response = await self._servers[server_name].send(
            "tools/call", {"name": tool_name, "arguments": arguments or {}}
        )
        result = self._unwrap(response)
        texts = [
            item["text"]
            for item in result.get("content") or []
            if isinstance(item, dict) and isinstance(item.get("text"), str)
        ]
        return "\n".join(texts)

    def _route_by_prefix(self, full_name: str) -> tuple[str, str] | None:
        """Route a name that never appeared in ``tools/list`` by its server prefix.

        The longest matching server name wins, so ``mcp_a_b_x`` goes to server
        ``a_b`` rather than ``a`` when both are connected.
        """
        best = None
        for name in self._servers:
            prefix = f"mcp_{name}_"
            if full_name.startswith(prefix) and len(full_name) > len(prefix):
                if best is None or len(name) > len(best):
                    best = name
        if best is None:
            return None
        return best, full_name[len(f"mcp_{best}_"):]

    @staticmethod
    def _unwrap(response: dict) -> dict:
        """Return the ``result`` of a JSON-RPC response, raising on ``error``.

        A dict without a ``result`` key is returned unchanged, so transports
        that already unwrap the envelope keep working.

        Raises:
            RuntimeError: if the response carries an ``error`` member.
        """
        error = response.get("error")
        if error is not None:
            if isinstance(error, dict):
                raise RuntimeError(f"MCP error {error.get('code')}: {error.get('message', '')}")
            raise RuntimeError(f"MCP error: {error}")
        result = response.get("result", response)
        return result if isinstance(result, dict) else {}

    def get_mcp_tool_schemas(self) -> list[dict]:
        """Return the registered tools in OpenAI function-calling format."""
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.input_schema,
                },
            }
            for tool in self._tools.values()
        ]

    async def _disconnect(self, server_name: str):
        """Forget *server_name*, drop the tools routed to it, and close its transport."""
        transport = self._servers.pop(server_name)
        stale = [tool for tool, (srv, _) in self._routes.items() if srv == server_name]
        for tool in stale:
            del self._routes[tool]
            self._tools.pop(tool, None)
        await transport.close()

    async def close_all(self):
        """Close every server connection and forget all registered tools.

        Every transport is closed even if an earlier one fails. The first
        failure is re-raised once all of them have been attempted.
        """
        first_error = None
        for name in list(self._servers):
            try:
                await self._disconnect(name)
            except Exception as exc:
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error