Coverage for agentos/mcp/__init__.py: 25%

312 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

"""MCP (Model Context Protocol) client implementation for AgentOS.

Provides a JSON-RPC 2.0 MCP client that performs the ``initialize``
handshake and discovers a server's tools, resources and prompts. Two
transports are available: stdio (a local child process) and HTTP+SSE (a
remote server). ``MCPClient`` can be used as an async context manager; it
disconnects all servers on exit. The MCP server implementation from
``agentos.mcp.server`` is re-exported from this package.

v1.14.0: Sampling, Resource Templates, Logging and Roots support was added.
NOTE(review): none of those names are defined or imported in this module --
confirm which module provides them.
"""

9 

10from __future__ import annotations 

11 

12import asyncio 

13import json 

14import logging 

15import subprocess 

16import uuid 

17from abc import ABC, abstractmethod 

18from dataclasses import dataclass, field 

19from typing import Any, AsyncIterator, Dict, List, Optional 

20 

21import httpx 

22 

# Module-wide logger, named after this module (``__name__``).
logger: logging.Logger = logging.getLogger(__name__)

24 

25# ── Data Models ──────────────────────────── 

26 

27 

@dataclass
class MCPServerConfig:
    """Settings describing how to reach a single MCP server.

    Attributes:
        name: Identifier under which a client registers the server.
        transport: Transport key, ``"stdio"`` or ``"sse"``.
        command: Executable to launch for the stdio transport.
        args: Extra command-line arguments passed after ``command``.
        url: Base URL of the server for the SSE transport.
        env: Extra environment variables for a launched server process.
        timeout: Timeout in seconds; how it is applied is up to each transport.
        capabilities: Client capabilities advertised in the ``initialize`` request.
    """

    name: str
    transport: str = field(default="stdio")  # one of: stdio, sse
    command: Optional[str] = field(default=None)
    args: List[str] = field(default_factory=list)
    url: Optional[str] = field(default=None)
    env: Dict[str, str] = field(default_factory=dict)
    timeout: int = field(default=30)
    capabilities: Dict[str, Any] = field(default_factory=dict)

39 

40 

@dataclass
class MCPToolInfo:
    """A tool advertised by an MCP server.

    Attributes:
        name: Tool name as reported by the server.
        description: Human-readable description; empty when not provided.
        input_schema: The server's ``inputSchema`` object for the tool arguments.
        server_name: Name of the server that exposes the tool.
    """

    name: str
    description: str = field(default="")
    input_schema: Dict[str, Any] = field(default_factory=dict)
    server_name: str = field(default="")

48 

49 

@dataclass
class MCPResourceInfo:
    """A resource advertised by an MCP server.

    Attributes:
        uri: URI identifying the resource.
        name: Display name; empty when not provided.
        description: Human-readable description; empty when not provided.
        mime_type: The server's ``mimeType`` value; empty when not provided.
        server_name: Name of the server that exposes the resource.
    """

    uri: str
    name: str = field(default="")
    description: str = field(default="")
    mime_type: str = field(default="")
    server_name: str = field(default="")

58 

59 

@dataclass
class MCPPromptInfo:
    """A prompt template advertised by an MCP server.

    Attributes:
        name: Prompt name as reported by the server.
        description: Human-readable description; empty when not provided.
        arguments: The server's argument descriptors for the prompt.
        server_name: Name of the server that exposes the prompt.
    """

    name: str
    description: str = field(default="")
    arguments: List[Dict[str, Any]] = field(default_factory=list)
    server_name: str = field(default="")

67 

68 

69# ── JSON-RPC 2.0 Transport ────────────────── 

70 

71 

class MCPTransport(ABC):
    """Interface implemented by every MCP transport.

    A transport moves JSON-RPC 2.0 messages between the client and a single
    server; concrete subclasses decide the channel (child-process pipes,
    HTTP + Server-Sent Events, ...).
    """

    @abstractmethod
    async def connect(self, config: MCPServerConfig) -> None:
        """Open the channel described by *config*."""

    @abstractmethod
    async def send_request(self, method: str, params: Optional[Dict] = None) -> Dict[str, Any]:
        """Send a request for *method* and return the server's result payload."""

    @abstractmethod
    async def send_notification(self, method: str, params: Optional[Dict] = None) -> None:
        """Send a notification for *method*; no response is awaited."""

    @abstractmethod
    async def close(self) -> None:
        """Release every resource held by the transport."""

90 

91 

class StdioTransport(MCPTransport):
    """MCP transport over a child process's stdin/stdout.

    The server is launched as a subprocess and exchanges newline-delimited
    JSON-RPC 2.0 messages with the client. Requests are written to its stdin;
    a background task reads its stdout and resolves the pending request whose
    id matches each response. The child's stderr is drained on a daemon
    thread and forwarded to this module's logger at DEBUG level, so a chatty
    server cannot fill the pipe and stall.
    """

    def __init__(self) -> None:
        self._proc: Optional[subprocess.Popen] = None
        # Serializes writes so concurrent messages never interleave on stdin.
        self._lock = asyncio.Lock()
        self._request_id = 0
        self._pending: Dict[int, asyncio.Future] = {}
        self._reader_task: Optional[asyncio.Task] = None
        # Seconds to wait for a response; replaced by config.timeout in connect().
        self._timeout: float = 30

    async def connect(self, config: MCPServerConfig) -> None:
        """Launch the server process and start consuming its output.

        ``config.command`` defaults to ``"npx"``; ``config.args`` follow it.
        The child inherits this process's environment with ``config.env``
        layered on top. ``config.timeout`` becomes the per-request timeout.
        """
        import os
        import threading

        argv = [config.command or "npx", *config.args]
        env = {**os.environ, **config.env}
        self._timeout = config.timeout

        self._proc = subprocess.Popen(
            argv,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
            text=False,
        )
        # stderr must be consumed, otherwise the child blocks once the pipe
        # buffer fills. A daemon thread keeps this off the executor pool.
        threading.Thread(
            target=self._drain_stderr,
            args=(self._proc.stderr,),
            name=f"mcp-stderr-{config.name}",
            daemon=True,
        ).start()
        self._reader_task = asyncio.create_task(self._read_loop(self._proc))

    @staticmethod
    def _drain_stderr(stream) -> None:
        """Forward the child's stderr lines to the logger until EOF."""
        try:
            for raw in iter(stream.readline, b""):
                logger.debug(
                    "MCP server stderr: %s",
                    raw.decode("utf-8", "replace").rstrip(),
                )
        except (OSError, ValueError):
            # The stream was closed underneath us; nothing left to forward.
            return

    async def _read_loop(self, proc: subprocess.Popen) -> None:
        """Read stdout line by line until EOF and route every message.

        Reading stops at EOF rather than as soon as the process has exited,
        so responses written just before exit are still delivered. Requests
        still pending when the stream ends fail immediately instead of
        waiting out their timeout.
        """
        loop = asyncio.get_running_loop()
        try:
            while True:
                try:
                    line = await loop.run_in_executor(None, proc.stdout.readline)
                except (OSError, ValueError):
                    break  # pipe closed underneath us
                if not line:
                    break  # EOF: the server closed stdout or exited
                self._dispatch_line(line)
        finally:
            self._fail_pending("MCP server connection closed")

    def _dispatch_line(self, line: bytes) -> None:
        """Resolve the pending request answered by one line of server output."""
        try:
            message = json.loads(line.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError):
            logger.debug("Ignoring non-JSON output from MCP server: %r", line[:200])
            return
        if not isinstance(message, dict):
            return
        if "method" in message:
            # Server-initiated request or notification. Its "id" (if any)
            # belongs to the server's id space and must not be matched
            # against our pending requests.
            logger.debug("Ignoring server-initiated MCP message %r", message.get("method"))
            return
        try:
            future = self._pending.pop(message.get("id"))
        except (KeyError, TypeError):
            return  # unknown or late id, or an unhashable id
        if future.done():
            return  # the waiter already gave up (timeout / cancellation)
        error = message.get("error")
        if error is not None:
            if not isinstance(error, dict):
                error = {}
            future.set_exception(MCPError(
                error.get("code", -1),
                error.get("message", "Unknown error"),
                error.get("data"),
            ))
        else:
            future.set_result(message.get("result", {}))

    def _fail_pending(self, reason: str) -> None:
        """Fail every outstanding request with an MCPError carrying *reason*."""
        pending, self._pending = self._pending, {}
        for future in pending.values():
            if not future.done():
                future.set_exception(MCPError(-32000, reason))

    @staticmethod
    def _encode(message: Dict[str, Any]) -> bytes:
        """Serialize *message* as one newline-terminated JSON line.

        ``json.dumps`` escapes newlines inside strings, so the payload never
        contains a raw newline before the terminator.
        """
        return json.dumps(message).encode("utf-8") + b"\n"

    @staticmethod
    def _write_payload(stream, payload: bytes) -> None:
        """Write and flush *payload* (runs in an executor thread)."""
        stream.write(payload)
        stream.flush()

    async def send_request(self, method: str, params: Optional[Dict] = None) -> Dict[str, Any]:
        """Send a JSON-RPC 2.0 request and await its result.

        Raises:
            MCPError: If the process is not running (-32000), the server
                replies with a JSON-RPC error, the connection closes before a
                reply arrives, or no reply arrives within the configured
                timeout (-32001).
        """
        proc = self._proc
        if proc is None or proc.poll() is not None:
            raise MCPError(-32000, "MCP server process is not running")

        loop = asyncio.get_running_loop()
        self._request_id += 1
        req_id = self._request_id
        future: asyncio.Future = loop.create_future()
        self._pending[req_id] = future
        payload = self._encode({
            "jsonrpc": "2.0",
            "id": req_id,
            "method": method,
            "params": params or {},
        })

        try:
            async with self._lock:
                await loop.run_in_executor(None, self._write_payload, proc.stdin, payload)
            return await asyncio.wait_for(future, timeout=self._timeout)
        except asyncio.TimeoutError:
            raise MCPError(-32001, "Request timed out") from None
        finally:
            # Forget the id on every exit path: response, timeout, write
            # failure or cancellation.
            self._pending.pop(req_id, None)

    async def send_notification(self, method: str, params: Optional[Dict] = None) -> None:
        """Send a JSON-RPC 2.0 notification; silently skipped if the process is gone."""
        proc = self._proc
        if proc is None or proc.poll() is not None:
            return

        payload = self._encode({
            "jsonrpc": "2.0",
            "method": method,
            "params": params or {},
        })
        loop = asyncio.get_running_loop()
        async with self._lock:
            await loop.run_in_executor(None, self._write_payload, proc.stdin, payload)

    async def close(self) -> None:
        """Stop the server process and the reader task.

        Calls ``terminate()`` and waits up to 5 seconds, off the event loop,
        for the process to exit; if it is still alive, ``kill()`` is used.
        Requests still awaiting a reply fail with MCPError.
        """
        proc, self._proc = self._proc, None
        if proc is not None:
            proc.terminate()
            loop = asyncio.get_running_loop()
            try:
                await loop.run_in_executor(None, proc.wait, 5)
            except subprocess.TimeoutExpired:
                proc.kill()
                await loop.run_in_executor(None, proc.wait)

        if self._reader_task is not None:
            self._reader_task.cancel()
            try:
                await self._reader_task
            except asyncio.CancelledError:
                pass
            self._reader_task = None

        self._fail_pending("MCP server connection closed")

202 

203 

class SSETransport(MCPTransport):
    """MCP transport over HTTP with Server-Sent Events ("HTTP+SSE").

    The client keeps a long-lived ``GET <url>/sse`` event stream open. The
    server announces, in an ``endpoint`` event, the URI to which the client
    must POST its JSON-RPC messages. Responses then arrive asynchronously on
    the stream and are matched to pending requests by id. If the server
    never sends an ``endpoint`` event, ``<url>/message`` is used as a
    fallback.
    """

    def __init__(self) -> None:
        self._client: Optional[httpx.AsyncClient] = None
        self._request_id = 0
        self._pending: Dict[int, asyncio.Future] = {}
        self._sse_task: Optional[asyncio.Task] = None
        self._message_endpoint: str = ""
        self._sse_endpoint: str = ""
        # Seconds; replaced by config.timeout in connect().
        self._timeout: float = 30
        # Created in connect(). Set when the server announces its POST
        # endpoint, or when a stream attempt fails, so connect() stops waiting.
        self._endpoint_ready: Optional[asyncio.Event] = None

    async def connect(self, config: MCPServerConfig) -> None:
        """Open the event stream and wait for the server's POST endpoint.

        Waits at most ``config.timeout`` seconds for the ``endpoint`` event.

        Raises:
            MCPError: If ``config.url`` is missing.
        """
        if not config.url:
            raise MCPError(-32602, "URL required for SSE transport")

        base_url = config.url.rstrip("/")
        self._sse_endpoint = base_url + "/sse"
        # Fallback only; replaced when the server sends an ``endpoint`` event.
        self._message_endpoint = base_url + "/message"
        self._timeout = config.timeout
        self._endpoint_ready = asyncio.Event()
        self._client = httpx.AsyncClient(timeout=config.timeout)
        self._sse_task = asyncio.create_task(self._sse_loop())

        try:
            await asyncio.wait_for(self._endpoint_ready.wait(), timeout=config.timeout)
        except asyncio.TimeoutError:
            logger.warning(
                "MCP SSE server at %s sent no 'endpoint' event within %ss; posting to %s",
                self._sse_endpoint,
                config.timeout,
                self._message_endpoint,
            )

    async def _sse_loop(self) -> None:
        """Keep the event stream open, reconnecting after errors or disconnects."""
        # The stream can sit idle for long periods, so disable its read
        # timeout while keeping the connect/write/pool limits.
        stream_timeout = httpx.Timeout(self._timeout, read=None)
        while self._client is not None:
            try:
                async with self._client.stream(
                    "GET",
                    self._sse_endpoint,
                    headers={"Accept": "text/event-stream"},
                    timeout=stream_timeout,
                ) as response:
                    response.raise_for_status()
                    await self._consume_events(response)
            except Exception as exc:  # network/HTTP errors: log and retry
                logger.debug("MCP SSE stream %s failed: %s", self._sse_endpoint, exc)
                if self._endpoint_ready is not None:
                    # Let connect() stop waiting and use the fallback endpoint.
                    self._endpoint_ready.set()
            await asyncio.sleep(1)

    async def _consume_events(self, response) -> None:
        """Parse the SSE byte stream into (event type, data) pairs.

        Follows the SSE framing rules: ``data`` lines accumulate until a blank
        line dispatches the event, ``event`` sets its type (default
        ``"message"``), lines starting with ``:`` are comments, and one
        leading space after the colon is stripped.
        """
        event_type = "message"
        data_lines: List[str] = []
        async for line in response.aiter_lines():
            if not line:
                if data_lines:
                    self._handle_event(event_type, "\n".join(data_lines))
                event_type, data_lines = "message", []
                continue
            if line.startswith(":"):
                continue  # comment / keep-alive
            field_name, _, value = line.partition(":")
            if value.startswith(" "):
                value = value[1:]
            if field_name == "event":
                event_type = value
            elif field_name == "data":
                data_lines.append(value)

    def _handle_event(self, event_type: str, data: str) -> None:
        """Act on one dispatched SSE event."""
        if event_type == "endpoint":
            from urllib.parse import urljoin

            # The announced URI may be relative (e.g. "/messages/?session_id=...").
            self._message_endpoint = urljoin(self._sse_endpoint, data.strip())
            if self._endpoint_ready is not None:
                self._endpoint_ready.set()
            return
        try:
            message = json.loads(data)
        except json.JSONDecodeError:
            logger.debug("Ignoring non-JSON SSE event data: %r", data[:200])
            return
        self._route_message(message)

    def _route_message(self, message: Any) -> None:
        """Resolve the pending request answered by *message*, if any."""
        if not isinstance(message, dict):
            return
        if "method" in message:
            # Server-initiated request or notification; its "id" (if any)
            # belongs to the server's id space, not ours.
            logger.debug("Ignoring server-initiated MCP message %r", message.get("method"))
            return
        try:
            future = self._pending.pop(message.get("id"))
        except (KeyError, TypeError):
            return  # unknown or late id, or an unhashable id
        if future.done():
            return
        error = message.get("error")
        if error is not None:
            if not isinstance(error, dict):
                error = {}
            future.set_exception(MCPError(
                error.get("code", -1),
                error.get("message", ""),
                error.get("data"),
            ))
        else:
            future.set_result(message.get("result", {}))

    def _fail_pending(self, reason: str) -> None:
        """Fail every outstanding request with an MCPError carrying *reason*."""
        pending, self._pending = self._pending, {}
        for future in pending.values():
            if not future.done():
                future.set_exception(MCPError(-32000, reason))

    async def send_request(self, method: str, params: Optional[Dict] = None) -> Dict[str, Any]:
        """POST a JSON-RPC 2.0 request and await its response from the stream.

        Raises:
            MCPError: If not connected (-32000), the server replies with a
                JSON-RPC error, or no reply arrives within the configured
                timeout (-32001).
            httpx.HTTPStatusError: If the POST itself is rejected.
        """
        if not self._client:
            raise MCPError(-32000, "SSE transport not connected")

        self._request_id += 1
        req_id = self._request_id
        request = {
            "jsonrpc": "2.0",
            "id": req_id,
            "method": method,
            "params": params or {},
        }
        future: asyncio.Future = asyncio.get_running_loop().create_future()
        self._pending[req_id] = future

        try:
            resp = await self._client.post(self._message_endpoint, json=request)
            resp.raise_for_status()
            return await asyncio.wait_for(future, timeout=self._timeout)
        except asyncio.TimeoutError:
            raise MCPError(-32001, "SSE request timed out") from None
        finally:
            # Forget the id on every exit path, including POST failures.
            self._pending.pop(req_id, None)

    async def send_notification(self, method: str, params: Optional[Dict] = None) -> None:
        """POST a JSON-RPC 2.0 notification; skipped when not connected."""
        if not self._client:
            return
        notification = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params or {},
        }
        await self._client.post(self._message_endpoint, json=notification)

    async def close(self) -> None:
        """Stop the event stream, close the HTTP client and fail pending requests."""
        client, self._client = self._client, None
        if self._sse_task is not None:
            self._sse_task.cancel()
            try:
                await self._sse_task
            except asyncio.CancelledError:
                pass
            self._sse_task = None
        if client is not None:
            await client.aclose()
        self._fail_pending("SSE transport closed")

299 

300 

301# ── Error ──────────────────────────────────── 

302 

303 

class MCPError(Exception):
    """MCP protocol error.

    Attributes:
        code: JSON-RPC error code.
        message: Human-readable error message.
        data: Optional extra error payload (``None`` when absent).
    """

    def __init__(self, code: int, message: str, data: Any = None):
        self.code = code
        self.message = message
        self.data = data
        super().__init__(f"MCP Error [{code}]: {message}")

    def __reduce__(self):
        # ``self.args`` holds only the formatted string, which does not match
        # __init__'s signature; without this, pickle/copy fail to rebuild it.
        return (type(self), (self.code, self.message, self.data))

311 

312 

313# ── Full Client ───────────────────────────── 

314 

315 

class MCPClient:
    """Client that connects to and uses one or more MCP servers.

    Each connected server gets its own transport: ``"stdio"`` launches a local
    child process, ``"sse"`` talks to a remote HTTP endpoint. After the
    ``initialize`` handshake the client discovers the server's tools,
    resources and prompts and caches them for the ``list_*`` helpers and
    ``get_tool_schemas``. Used as an async context manager, it disconnects
    every server on exit.

    Usage:
        async with MCPClient() as client:
            await client.connect_server(MCPServerConfig(
                name="filesystem",
                command="npx",
                args=["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
            ))
            tools = client.list_tools()
            result = await client.call_tool("filesystem", "read_file", {"path": "/tmp/test.txt"})
    """

    TRANSPORTS = {
        "stdio": StdioTransport,
        "sse": SSETransport,
    }

    def __init__(self) -> None:
        self._servers: Dict[str, MCPTransport] = {}
        self._server_configs: Dict[str, MCPServerConfig] = {}
        # Keyed by the exported tool name "mcp__<server>__<tool>".
        self._tools: Dict[str, MCPToolInfo] = {}
        # Keyed by "<server>__<uri>" so servers exposing the same URI do not
        # overwrite each other's entries.
        self._resources: Dict[str, MCPResourceInfo] = {}
        # Keyed by "<server>__<prompt name>".
        self._prompts: Dict[str, MCPPromptInfo] = {}
        self._server_capabilities: Dict[str, Dict[str, Any]] = {}

    async def __aenter__(self) -> "MCPClient":
        return self

    async def __aexit__(self, *args) -> None:
        await self.close_all()

    async def connect_server(self, config: MCPServerConfig) -> Dict[str, Any]:
        """Connect to an MCP server, run the handshake and discover its features.

        If a server with the same name is already connected, the old
        connection is closed (and its cached tools/resources/prompts dropped)
        once the new one has completed its handshake. If the handshake fails,
        the new transport is closed before the error propagates.

        Args:
            config: Connection settings; ``config.name`` identifies the server.

        Returns:
            The complete ``initialize`` result returned by the server; its
            ``"capabilities"`` entry is also stored for
            ``get_server_capabilities``.

        Raises:
            MCPError: For an unknown transport or a JSON-RPC error during the
                handshake. Transport-level errors propagate unchanged.
        """
        transport_cls = self.TRANSPORTS.get(config.transport)
        if not transport_cls:
            raise MCPError(-32601, f"Unknown transport: {config.transport}")

        transport = transport_cls()
        try:
            await transport.connect(config)
            # MCP initialize handshake, followed by the "initialized" notification.
            init_result = await transport.send_request("initialize", {
                "protocolVersion": "2024-11-05",
                "capabilities": config.capabilities or {},
                "clientInfo": {
                    "name": "agentos-mcp-client",
                    "version": "1.0.0",
                },
            })
            await transport.send_notification("notifications/initialized")
        except BaseException:
            # Don't leak the child process / HTTP client of a failed handshake.
            await self._close_quietly(transport)
            raise

        if config.name in self._servers:
            # Re-registering a name: shut the previous connection down rather
            # than dropping (and leaking) its transport.
            try:
                await self.disconnect_server(config.name)
            except Exception:
                logger.warning(
                    "Error closing previous connection to MCP server '%s'",
                    config.name,
                    exc_info=True,
                )

        self._servers[config.name] = transport
        self._server_configs[config.name] = config
        self._server_capabilities[config.name] = init_result.get("capabilities", {})

        await self._discover_server(config.name)
        return init_result

    @staticmethod
    async def _close_quietly(transport: MCPTransport) -> None:
        """Close *transport*, logging instead of raising any error."""
        try:
            await transport.close()
        except Exception:
            logger.debug("Error while closing MCP transport", exc_info=True)

    @staticmethod
    async def _list_all(
        transport: MCPTransport,
        method: str,
        key: str,
    ) -> List[Dict[str, Any]]:
        """Collect every item of a paginated ``*/list`` call.

        Follows ``nextCursor`` from page to page, passing it back as the
        ``cursor`` parameter. Stops at the last page, or if the server repeats
        a cursor, which would otherwise loop forever.
        """
        items: List[Dict[str, Any]] = []
        seen_cursors = set()
        cursor = None
        while True:
            result = await transport.send_request(
                method, {"cursor": cursor} if cursor else None
            )
            items.extend(result.get(key) or [])
            cursor = result.get("nextCursor")
            if not cursor or cursor in seen_cursors:
                return items
            seen_cursors.add(cursor)

    async def _discover_server(self, server_name: str) -> None:
        """Populate the tool/resource/prompt caches for a connected server.

        Only categories declared in the server's capabilities are queried. A
        JSON-RPC error from a list call is logged at DEBUG level and skipped.
        """
        transport = self._servers[server_name]
        caps = self._server_capabilities.get(server_name) or {}

        # A capability is declared by the presence of its key. Its value is an
        # options object that may be empty ({}), so truthiness is not enough.
        if caps.get("tools") is not None:
            try:
                tools = await self._list_all(transport, "tools/list", "tools")
            except MCPError:
                logger.debug("Server '%s' tools/list not supported", server_name)
            else:
                for tool in tools:
                    self._tools[f"mcp__{server_name}__{tool['name']}"] = MCPToolInfo(
                        name=tool["name"],
                        description=tool.get("description", ""),
                        input_schema=tool.get("inputSchema", {}),
                        server_name=server_name,
                    )

        if caps.get("resources") is not None:
            try:
                resources = await self._list_all(transport, "resources/list", "resources")
            except MCPError:
                logger.debug("Server '%s' resources/list not supported", server_name)
            else:
                for res in resources:
                    self._resources[f"{server_name}__{res['uri']}"] = MCPResourceInfo(
                        uri=res["uri"],
                        name=res.get("name", ""),
                        description=res.get("description", ""),
                        mime_type=res.get("mimeType", ""),
                        server_name=server_name,
                    )

        if caps.get("prompts") is not None:
            try:
                prompts = await self._list_all(transport, "prompts/list", "prompts")
            except MCPError:
                logger.debug("Server '%s' prompts/list not supported", server_name)
            else:
                for prompt in prompts:
                    self._prompts[f"{server_name}__{prompt['name']}"] = MCPPromptInfo(
                        name=prompt["name"],
                        description=prompt.get("description", ""),
                        arguments=prompt.get("arguments", []),
                        server_name=server_name,
                    )

    def _require_transport(self, server_name: str) -> MCPTransport:
        """Return the transport for *server_name* or raise MCPError(-32602)."""
        transport = self._servers.get(server_name)
        if transport is None:
            raise MCPError(-32602, f"Server '{server_name}' not connected")
        return transport

    async def call_tool(
        self,
        server_name: str,
        tool_name: str,
        arguments: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """Call a tool on a connected MCP server.

        Args:
            server_name: Name of the MCP server.
            tool_name: Name of the tool to call.
            arguments: Tool arguments dict.

        Returns:
            The text of the result's content blocks joined with newlines.
            Resource blocks are rendered as ``[Resource: <uri>]`` and image
            blocks as a truncated ``[Image: ...]`` placeholder. Returns ``""``
            when the result has no content, and the raw content list when no
            block has a recognised type. The result's ``isError`` flag is not
            inspected, so tool-level failures come back as ordinary text.

        Raises:
            MCPError: If the server is not connected or replies with an error.
        """
        transport = self._require_transport(server_name)
        result = await transport.send_request("tools/call", {
            "name": tool_name,
            "arguments": arguments or {},
        })

        content = result.get("content", [])
        if not content:
            return ""

        texts = []
        for block in content:
            kind = block.get("type")
            if kind == "text":
                texts.append(block.get("text", ""))
            elif kind == "resource":
                texts.append(f"[Resource: {block.get('resource', {}).get('uri', '')}]")
            elif kind == "image":
                texts.append(f"[Image: {block.get('data', '')[:50]}...]")

        return "\n".join(texts) if texts else content

    async def read_resource(self, server_name: str, uri: str) -> Dict[str, Any]:
        """Read a resource and return its first content entry ({} if none)."""
        transport = self._require_transport(server_name)
        result = await transport.send_request("resources/read", {"uri": uri})
        contents = result.get("contents")
        return contents[0] if contents else {}

    async def get_prompt(
        self,
        server_name: str,
        prompt_name: str,
        arguments: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """Fetch a prompt from a connected server and return the raw result."""
        transport = self._require_transport(server_name)
        return await transport.send_request("prompts/get", {
            "name": prompt_name,
            "arguments": arguments or {},
        })

    def list_tools(self, server_name: Optional[str] = None) -> List[MCPToolInfo]:
        """List discovered tools, optionally filtered by server."""
        tools = list(self._tools.values())
        if server_name:
            tools = [t for t in tools if t.server_name == server_name]
        return tools

    def list_resources(self, server_name: Optional[str] = None) -> List[MCPResourceInfo]:
        """List discovered resources, optionally filtered by server."""
        resources = list(self._resources.values())
        if server_name:
            resources = [r for r in resources if r.server_name == server_name]
        return resources

    def list_prompts(self, server_name: Optional[str] = None) -> List[MCPPromptInfo]:
        """List discovered prompts, optionally filtered by server."""
        prompts = list(self._prompts.values())
        if server_name:
            prompts = [p for p in prompts if p.server_name == server_name]
        return prompts

    def get_server_capabilities(self, server_name: str) -> Dict[str, Any]:
        """Return the capabilities a server reported during ``initialize`` ({} if unknown)."""
        return self._server_capabilities.get(server_name, {})

    def get_tool_schemas(
        self,
        server_name: Optional[str] = None,
        format: str = "openai",
    ) -> List[Dict[str, Any]]:
        """Export tool schemas in OpenAI or Anthropic function format.

        Tool names are exported as ``mcp__<server>__<tool>``.

        Args:
            server_name: Optional filter by server.
            format: ``'openai'`` or ``'anthropic'``. Any other value logs a
                warning and yields an empty list.

        Returns:
            List of function/tool schema dicts.
        """
        if format not in ("openai", "anthropic"):
            logger.warning(
                "Unknown tool schema format %r; expected 'openai' or 'anthropic'", format
            )
            return []

        schemas = []
        for tool in self.list_tools(server_name):
            exported_name = f"mcp__{tool.server_name}__{tool.name}"
            if format == "openai":
                schemas.append({
                    "type": "function",
                    "function": {
                        "name": exported_name,
                        "description": tool.description,
                        "parameters": tool.input_schema,
                    },
                })
            else:
                schemas.append({
                    "name": exported_name,
                    "description": tool.description,
                    "input_schema": tool.input_schema,
                })
        return schemas

    @property
    def connected_servers(self) -> List[str]:
        """Names of the currently connected servers."""
        return list(self._servers.keys())

    async def disconnect_server(self, server_name: str) -> None:
        """Disconnect from a server and drop everything cached for it.

        The bookkeeping is cleared even if closing the transport raises; the
        error still propagates to the caller.
        """
        transport = self._servers.pop(server_name, None)
        if transport is None:
            return
        try:
            await transport.close()
        finally:
            self._server_configs.pop(server_name, None)
            self._server_capabilities.pop(server_name, None)
            self._tools = {
                k: v for k, v in self._tools.items() if v.server_name != server_name
            }
            self._resources = {
                k: v for k, v in self._resources.items() if v.server_name != server_name
            }
            self._prompts = {
                k: v for k, v in self._prompts.items() if v.server_name != server_name
            }

    async def close_all(self) -> None:
        """Disconnect from every server.

        Each server is attempted even if an earlier one fails; the first
        error is re-raised after all servers have been processed.
        """
        first_error: Optional[BaseException] = None
        for name in list(self._servers):
            try:
                await self.disconnect_server(name)
            except Exception as exc:
                logger.warning("Error disconnecting MCP server '%s'", name, exc_info=True)
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

594 

595 

596# ── Convenience Functions ──────────────────── 

597 

598 

async def connect_mcp_servers(
    configs: List[MCPServerConfig],
) -> MCPClient:
    """Connect to multiple MCP servers at once.

    Servers are connected sequentially in the given order. If any connection
    fails, the servers connected so far are disconnected before the original
    error is re-raised, so no child processes or HTTP clients leak.

    Usage:
        client = await connect_mcp_servers([
            MCPServerConfig(name="filesystem", command="npx",
                            args=["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]),
            MCPServerConfig(name="github", command="npx",
                            args=["-y", "@modelcontextprotocol/server-github"],
                            env={"GITHUB_PERSONAL_ACCESS_TOKEN": os.environ["GITHUB_TOKEN"]}),
        ])
    """
    client = MCPClient()
    try:
        for config in configs:
            await client.connect_server(config)
    except BaseException:
        try:
            await client.close_all()
        except Exception:
            # Keep the original failure as the one the caller sees.
            logger.warning("Error cleaning up after failed MCP connect", exc_info=True)
        raise
    return client

617 

618 

619# ── MCP Server (v1.5.2) ───────────────────── 

620 

621from agentos.mcp.server import ( 

622 MCPServer, 

623 MCPToolDef, 

624 MCPResource, 

625 MCPPromptDef, 

626 create_default_server, 

627 start_mcp_server, 

628) 

629 

# Public API of the package. The v1.14.0 entries below (sampling, resource
# templates, logging, roots) are not defined or imported anywhere in this
# module. Listing undefined names in ``__all__`` makes
# ``from agentos.mcp import *`` raise AttributeError, so only names that
# exist at import time are exported.
# NOTE(review): import the v1.14.0 names from their defining module(s).
_EXPORTS = [
    "MCPServerConfig",
    "MCPToolInfo",
    "MCPResourceInfo",
    "MCPPromptInfo",
    "MCPError",
    "MCPTransport",
    "StdioTransport",
    "SSETransport",
    "MCPClient",
    "connect_mcp_servers",
    # MCP Server (v1.5.2)
    "MCPServer",
    "MCPToolDef",
    "MCPResource",
    "MCPPromptDef",
    "create_default_server",
    "start_mcp_server",
    # MCP Sampling, Resource Templates, Logging, Roots (v1.14.0)
    "MCPClientSampling",
    "SamplingRequest",
    "SamplingResponse",
    "SamplingMessage",
    "SamplingContentBlock",
    "SamplingRole",
    "SamplingError",
    "mock_llm_call",
    "MCPResourceTemplate",
    "MCPLogLevel",
    "MCPLoggingHandler",
    "MCPRoot",
]

__all__ = [name for name in _EXPORTS if name in globals()]