Coverage for agentos/protocols/mcp.py: 47%

81 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 16:36 +0800

"""
AgentOS v0.20 MCP (Model Context Protocol) client.

Implemented transports: stdio (JSON-RPC over a child process's stdin/stdout)
and SSE (JSON-RPC POSTed to a remote HTTP server). MCPServerConfig also lists
"ws" (WebSocket), but this module defines no WebSocket transport.
"""

5 

6from __future__ import annotations 

7 

8import asyncio 

9import json 

10import subprocess 

11from abc import ABC, abstractmethod 

12from dataclasses import dataclass, field 

13from typing import Any 

14 

15 

@dataclass
class MCPServerConfig:
    """Connection settings for a single MCP server.

    ``name`` identifies the server inside an MCP client and ``transport``
    selects how it is reached. For ``"stdio"`` the ``command``/``args``/``env``
    fields describe the child process to launch; URL-based transports read
    ``url`` instead.
    """

    name: str

    # One of "stdio" | "sse" | "ws".
    transport: str = "stdio"

    # Executable for the stdio transport; None leaves the choice to the transport.
    command: str | None = None

    # Command-line arguments appended after `command`.
    args: list[str] = field(default_factory=list)

    # Endpoint for network transports.
    url: str | None = None

    # Extra environment variables merged over the parent environment.
    env: dict[str, str] = field(default_factory=dict)

25 

26 

@dataclass
class MCPToolSchema:
    """Description of one tool advertised by an MCP server.

    ``input_schema`` is the tool's argument schema exactly as the server sent
    it in its ``inputSchema`` field.
    """

    # Tool identifier (MCPClient stores the namespaced "mcp_<server>_<tool>" form).
    name: str

    # Human-readable description, copied into the OpenAI function schema.
    description: str

    # Argument schema, passed through unchanged.
    input_schema: dict

33 

34 

35# ── 传输层 ────────────────────────────────────── 

36 

class MCPTransport(ABC):
    """Abstract channel that carries JSON-RPC messages to one MCP server.

    Concrete transports implement the three coroutines below; this base class
    cannot be instantiated on its own.
    """

    @abstractmethod
    async def connect(self, config: MCPServerConfig):
        """Open the connection described by *config*."""

    @abstractmethod
    async def send(self, method: str, params: dict | None = None) -> dict:
        """Issue a request for *method* with *params* and return the decoded reply."""

    @abstractmethod
    async def close(self):
        """Release the connection."""

49 

50 

class StdioTransport(MCPTransport):
    """Talk to an MCP server through the stdin/stdout pipes of a child process.

    Each message is one line of JSON-RPC 2.0. Pipe I/O is blocking, so it runs
    in the event loop's default executor instead of on the loop itself, and
    concurrent ``send`` calls are serialised.
    """

    # Seconds close() waits for the child to exit after terminate() before killing it.
    close_timeout: float = 5.0
    # Number of trailing stderr lines kept to explain unexpected exits.
    stderr_tail_lines: int = 50

    def __init__(self):
        import collections
        import threading

        self._proc: subprocess.Popen | None = None
        # Serialises requests issued by coroutines.
        self._lock = asyncio.Lock()
        # Serialises the pipe I/O itself. The worker thread holds it, so a
        # request whose awaiting coroutine was cancelled still reads its own
        # reply before the next request touches the pipes.
        self._io_lock = threading.Lock()
        # JSON-RPC request id, incremented per request so replies can be matched.
        self._next_id = 0
        # Most recent stderr lines of the child process.
        self._stderr_tail = collections.deque(maxlen=self.stderr_tail_lines)

    async def connect(self, config: MCPServerConfig):
        """Spawn ``config.command`` (default ``npx``) with ``config.args``.

        The child inherits this process's environment overlaid with
        ``config.env``. Connecting again first closes the previous child.
        """
        import os
        import threading

        if self._proc is not None:
            await self.close()
        self._stderr_tail.clear()
        self._proc = subprocess.Popen(
            [config.command or "npx", *config.args],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env={**os.environ, **config.env},
        )
        # stderr must be read continuously: if nobody drains it, a chatty
        # server fills the OS pipe buffer and then blocks forever on a write.
        threading.Thread(
            target=self._drain_stderr,
            args=(self._proc.stderr,),
            name=f"mcp-stderr-{config.name}",
            daemon=True,
        ).start()

    def _drain_stderr(self, stream) -> None:
        """Consume *stream* until EOF, remembering only the most recent lines."""
        try:
            for raw in stream:
                self._stderr_tail.append(raw.decode("utf-8", "replace").rstrip())
        except (OSError, ValueError):
            # The pipe was closed underneath us during shutdown.
            pass
        finally:
            stream.close()

    async def send(self, method: str, params: dict | None = None) -> dict:
        """Send a JSON-RPC request and return the response object with the same id.

        Notifications, server-initiated requests and replies to other ids that
        arrive in the meantime are skipped.

        Raises:
            RuntimeError: if called before ``connect()`` or after ``close()``.
            ConnectionError: if the server closes stdout or the pipe breaks
                before the reply arrives.
        """
        async with self._lock:
            proc = self._proc
            if proc is None:
                raise RuntimeError("StdioTransport is not connected; call connect() first")
            self._next_id += 1
            request = {"jsonrpc": "2.0", "method": method, "params": params or {}, "id": self._next_id}
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, self._roundtrip, proc, request)

    def _roundtrip(self, proc: subprocess.Popen, request: dict) -> dict:
        """Blocking helper: write *request*, then read lines until its reply arrives."""
        with self._io_lock:
            proc.stdin.write((json.dumps(request) + "\n").encode())
            proc.stdin.flush()
            while True:
                line = proc.stdout.readline()
                if not line:
                    tail = "\n".join(self._stderr_tail.copy())
                    raise ConnectionError(
                        f"MCP server closed stdout (exit code {proc.poll()}) "
                        f"while waiting for a reply to {request['method']!r}"
                        + (f"; last stderr lines:\n{tail}" if tail else "")
                    )
                if not line.strip():
                    continue
                message = json.loads(line)
                # Only responses lack "method"; match them on the request id.
                if (
                    isinstance(message, dict)
                    and "method" not in message
                    and message.get("id") == request["id"]
                ):
                    return message

    async def close(self):
        """Terminate the child process, wait for it to exit and release its pipes.

        The child gets ``close_timeout`` seconds after ``terminate()`` before it
        is killed. Calling ``close`` more than once is harmless.
        """
        proc, self._proc = self._proc, None
        if proc is None:
            return
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._shutdown, proc, self.close_timeout)

    @staticmethod
    def _shutdown(proc: subprocess.Popen, timeout: float) -> None:
        """Blocking helper: stop *proc*, reap it, then close its stdin/stdout."""
        if proc.poll() is None:
            proc.terminate()
        try:
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
        try:
            proc.stdin.close()
        except OSError:
            # Flushing buffered bytes into a pipe whose reader has exited fails; irrelevant now.
            pass
        proc.stdout.close()

78 

79 

class SSETransport(MCPTransport):
    """Talk to a remote MCP server over HTTP.

    Each request is POSTed as JSON to ``<url>/message`` and the JSON-RPC reply
    is read from the HTTP response body.

    NOTE(review): the MCP "HTTP with SSE" transport normally answers a POST with
    ``202 Accepted`` and delivers the reply on a separate event stream, which
    this class never opens. Only servers that return the reply inline work;
    other replies raise ``ConnectionError`` from ``send``.
    """

    def __init__(self):
        # httpx.AsyncClient once connected; None before connect() / after close().
        self._client = None
        # JSON-RPC request id, incremented per request.
        self._next_id = 0

    async def connect(self, config: MCPServerConfig):
        """Create the HTTP client for ``config.url``.

        Raises:
            ValueError: if ``config.url`` is not set.
        """
        import httpx

        if not config.url:
            raise ValueError(f"MCP server {config.name!r} uses the SSE transport but has no url")
        if self._client is not None:
            await self._client.aclose()
        self._client = httpx.AsyncClient(base_url=config.url, timeout=30)

    async def send(self, method: str, params: dict | None = None) -> dict:
        """POST a JSON-RPC request and return the decoded JSON reply.

        JSON bodies are returned as-is whatever the HTTP status, so JSON-RPC
        error objects reach the caller.

        Raises:
            RuntimeError: if called before ``connect()`` or after ``close()``.
            httpx.HTTPStatusError: for a 4xx/5xx status with a non-JSON body.
            ConnectionError: for a 2xx status with an empty or non-JSON body.
        """
        if self._client is None:
            raise RuntimeError("SSETransport is not connected; call connect() first")
        self._next_id += 1
        resp = await self._client.post(
            "/message",
            json={"jsonrpc": "2.0", "method": method, "params": params or {}, "id": self._next_id},
        )
        try:
            return resp.json()
        except ValueError:
            # Report HTTP failures (e.g. an HTML error page) as such first.
            resp.raise_for_status()
            raise ConnectionError(
                f"MCP server replied to {method!r} with HTTP {resp.status_code} and a "
                "non-JSON body; replies delivered over an SSE stream are not supported"
            ) from None

    async def close(self):
        """Close the HTTP client; calling this more than once is harmless."""
        client, self._client = self._client, None
        if client is not None:
            await client.aclose()

94 

95 

96# ── MCP 客户端 ────────────────────────────────── 

97 

class MCPClient:
    """MCP protocol client managing connections to several MCP servers.

    Tools discovered on a server are exposed under the namespaced name
    ``mcp_<server name>_<tool name>`` so tools from different servers do not
    collide.
    """

    # Maps MCPServerConfig.transport to its implementation. MCPServerConfig
    # also mentions "ws", which has no implementation in this module.
    TRANSPORTS = {"stdio": StdioTransport, "sse": SSETransport}

    def __init__(self):
        self._servers: dict[str, MCPTransport] = {}
        self._tools: dict[str, MCPToolSchema] = {}
        # Namespaced tool name -> (server name, tool name as the server knows it).
        # call_tool uses this instead of re-parsing the namespaced name, which is
        # ambiguous when server or tool names contain "_".
        self._routes: dict[str, tuple[str, str]] = {}

    async def connect_server(self, config: MCPServerConfig):
        """Connect to the server described by *config* and register its tools.

        Connecting again under an existing ``config.name`` replaces the old
        connection and its tools. If fetching the tool list fails, the new
        connection is closed before the error propagates.

        Raises:
            ValueError: if ``config.transport`` is not a key of ``TRANSPORTS``.
        """
        transport_cls = self.TRANSPORTS.get(config.transport)
        if transport_cls is None:
            # Silently falling back to stdio would launch `npx` for e.g. "ws" configs.
            raise ValueError(
                f"Unsupported MCP transport {config.transport!r} for server "
                f"{config.name!r}; expected one of {sorted(self.TRANSPORTS)}"
            )
        if config.name in self._servers:
            await self._disconnect(config.name)
        transport = transport_cls()
        await transport.connect(config)
        self._servers[config.name] = transport
        # NOTE(review): the MCP lifecycle expects an "initialize" request plus a
        # "notifications/initialized" notification before other requests; none
        # is sent here, so servers enforcing it will reject tools/list.
        try:
            await self._register_tools(config.name)
        except BaseException:
            await self._disconnect(config.name)
            raise

    async def _register_tools(self, server_name: str) -> None:
        """Fetch every page of *server_name*'s tool list and register the tools."""
        cursor = None
        seen_cursors = set()
        while True:
            result = await self._list_tools(server_name, cursor)
            for tool in result.get("tools") or []:
                schema = MCPToolSchema(
                    name=f"mcp_{server_name}_{tool['name']}",
                    description=tool.get("description") or "",
                    input_schema=tool.get("inputSchema", {}),
                )
                self._tools[schema.name] = schema
                self._routes[schema.name] = (server_name, tool["name"])
            cursor = result.get("nextCursor")
            # Stop when there are no more pages, or if the server repeats a cursor.
            if not cursor or cursor in seen_cursors:
                break
            seen_cursors.add(cursor)

    async def _list_tools(self, server_name: str, cursor: str | None = None) -> dict:
        """Request one page of *server_name*'s tool list and return its JSON-RPC result."""
        transport = self._servers[server_name]
        params = {"cursor": cursor} if cursor else None
        return self._unwrap(await transport.send("tools/list", params))

    @staticmethod
    def _unwrap(response: dict) -> dict:
        """Return the ``result`` member of a JSON-RPC response.

        A response without a ``result`` key is returned unchanged, so transports
        (or test doubles) that already hand back the bare result keep working.

        Raises:
            RuntimeError: if the response carries a JSON-RPC ``error``.
        """
        if response.get("error") is not None:
            raise RuntimeError(f"MCP server returned an error: {response['error']!r}")
        result = response.get("result", response)
        return result if isinstance(result, dict) else {}

    def _resolve(self, full_name: str) -> tuple[str, str]:
        """Map a namespaced tool name to ``(server name, server-side tool name)``.

        Raises:
            ValueError: if no connected server owns *full_name*.
        """
        route = self._routes.get(full_name)
        if route is not None and route[0] in self._servers:
            return route
        # Not advertised by tools/list: fall back to the name prefix, trying the
        # longest server name first so "mcp_a_b_x" prefers server "a_b" over "a".
        for name in sorted(self._servers, key=len, reverse=True):
            prefix = f"mcp_{name}_"
            if full_name.startswith(prefix) and len(full_name) > len(prefix):
                return name, full_name[len(prefix):]
        raise ValueError(f"Unknown MCP tool: {full_name}")

    async def call_tool(self, full_name: str, arguments: dict) -> Any:
        """Invoke the namespaced tool *full_name* with *arguments*.

        Returns:
            The ``text`` of the result's content items joined with newlines,
            or "" when the result has no text content.

        Raises:
            ValueError: if *full_name* does not belong to a connected server.
            RuntimeError: if the server answers with a JSON-RPC error.
        """
        server_name, tool_name = self._resolve(full_name)
        response = await self._servers[server_name].send(
            "tools/call", {"name": tool_name, "arguments": arguments}
        )
        content = self._unwrap(response).get("content") or []
        # Non-text items (images, embedded resources) carry no "text" and are skipped.
        return "\n".join(
            item["text"]
            for item in content
            if isinstance(item, dict) and isinstance(item.get("text"), str)
        )

    def get_mcp_tool_schemas(self) -> list[dict]:
        """Return the registered tools in OpenAI function-calling format."""
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.input_schema,
                },
            }
            for tool in self._tools.values()
        ]

    async def _disconnect(self, server_name: str) -> None:
        """Close *server_name*'s transport and forget the tools registered for it."""
        transport = self._servers.pop(server_name, None)
        stale = [name for name, (owner, _) in self._routes.items() if owner == server_name]
        for name in stale:
            del self._routes[name]
            self._tools.pop(name, None)
        if transport is not None:
            await transport.close()

    async def close_all(self):
        """Close every server connection and forget all registered tools.

        All transports are closed even if one of them fails; the first failure
        is re-raised once every transport has been attempted.
        """
        transports = list(self._servers.values())
        self._servers.clear()
        self._tools.clear()
        self._routes.clear()
        first_error: Exception | None = None
        for transport in transports:
            try:
                await transport.close()
            except Exception as exc:  # keep closing the remaining transports
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error