Coverage for mcpgateway/wrapper.py: 86%

217 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-09 11:03 +0100

1# -*- coding: utf-8 -*- 

2""" 

3MCP Gateway Wrapper server. 

4 

5Copyright 2025 

6SPDX-License-Identifier: Apache-2.0 

7Authors: Keval Mahajan, Mihai Criveti, Madhav Kandukuri 

8 

9This module implements a wrapper bridge that facilitates 

10interaction between the MCP client and the MCP gateway. 

11It provides several functionalities, including listing tools, 

12invoking tools, managing resources, retrieving prompts, 

13and handling tool calls via the MCP gateway. 

14 

15A **stdio** bridge that exposes a remote MCP Gateway 

16(HTTP-/JSON-RPC APIs) as a local MCP server. All JSON-RPC 

17traffic is written to **stdout**; every log or trace message 

18is emitted on **stderr** so that protocol messages and 

19diagnostics never mix. 

20 

21Environment variables: 

22- MCP_SERVER_CATALOG_URLS: Comma-separated list of gateway catalog URLs (required) 

23- MCP_AUTH_TOKEN: Bearer token for the gateway (optional) 

24- MCP_TOOL_CALL_TIMEOUT: Seconds to wait for a gateway RPC call (default 90) 

25- MCP_WRAPPER_LOG_LEVEL: Python log level name or OFF/NONE to disable logging (default INFO) 

26 

27Example: 

28 $ export MCPGATEWAY_BEARER_TOKEN=$(python3 -m mcpgateway.utils.create_jwt_token --username admin --exp 10080 --secret my-test-key) 

29 $ export MCP_AUTH_TOKEN=${MCPGATEWAY_BEARER_TOKEN} 

30 $ export MCP_SERVER_CATALOG_URLS='http://localhost:4444/servers/UUID_OF_SERVER_1' 

31 $ export MCP_TOOL_CALL_TIMEOUT=120 

32 $ export MCP_WRAPPER_LOG_LEVEL=DEBUG # OFF to disable logging 

33 $ python3 -m mcpgateway.wrapper 

34""" 

35 

36# Standard 

37import asyncio 

38import logging 

39import os 

40import sys 

41from typing import Any, Dict, List, Optional, Union 

42from urllib.parse import urlparse 

43 

44# Third-Party 

45import httpx 

46from mcp import types 

47from mcp.server import NotificationOptions, Server 

48from mcp.server.models import InitializationOptions 

49import mcp.server.stdio 

50from pydantic import AnyUrl 

51 

52# First-Party 

53from mcpgateway import __version__ 

54 

55# ----------------------------------------------------------------------------- 

56# Configuration 

57# ----------------------------------------------------------------------------- 

ENV_SERVER_CATALOGS = "MCP_SERVER_CATALOG_URLS"
ENV_AUTH_TOKEN = "MCP_AUTH_TOKEN"  # nosec B105 - this is an *environment variable name*, not a secret
ENV_TIMEOUT = "MCP_TOOL_CALL_TIMEOUT"
ENV_LOG_LEVEL = "MCP_WRAPPER_LOG_LEVEL"

# Comma-separated catalog URLs; blank entries (e.g. a trailing comma) are dropped.
RAW_CATALOGS: str = os.getenv(ENV_SERVER_CATALOGS, "")
SERVER_CATALOG_URLS: List[str] = [u.strip() for u in RAW_CATALOGS.split(",") if u.strip()]

AUTH_TOKEN: str = os.getenv(ENV_AUTH_TOKEN, "")

# Timeout (seconds) handed to every httpx.AsyncClient this module creates.
# A malformed or non-positive value is reported on stderr and the process
# exits, instead of crashing the import with a bare ValueError traceback.
_raw_timeout = os.getenv(ENV_TIMEOUT, "90")
try:
    TOOL_CALL_TIMEOUT: int = int(_raw_timeout)
except ValueError:
    TOOL_CALL_TIMEOUT = 0
if TOOL_CALL_TIMEOUT <= 0:
    print(
        f"Error: {ENV_TIMEOUT} must be a positive integer number of seconds, got {_raw_timeout!r}",
        file=sys.stderr,
    )
    sys.exit(1)

# Validate required configuration
if not SERVER_CATALOG_URLS:
    print(f"Error: {ENV_SERVER_CATALOGS} environment variable is required", file=sys.stderr)
    sys.exit(1)

73 

74 

75# ----------------------------------------------------------------------------- 

76# Base URL Extraction 

77# ----------------------------------------------------------------------------- 

78def _extract_base_url(url: str) -> str: 

79 """Return the gateway-level base URL. 

80 

81 The function keeps any application root path (`APP_ROOT_PATH`) that the 

82 remote gateway is mounted under (for example `/gateway`) while removing 

83 the `/servers/<id>` suffix that appears in catalog endpoints. It also 

84 discards any query string or fragment. 

85 

86 Args: 

87 url (str): Full catalog URL, e.g. 

88 `https://host.com/gateway/servers/UUID_OF_SERVER_1`. 

89 

90 Returns: 

91 str: Clean base URL suitable for building `/tools/`, `/prompts/`, 

92 or `/resources/` endpoints-for example 

93 `https://host.com/gateway`. 

94 

95 Raises: 

96 ValueError: If *url* lacks a scheme or network location. 

97 

98 Examples: 

99 >>> _extract_base_url("https://host.com/servers/UUID_OF_SERVER_2") 

100 'https://host.com' 

101 >>> _extract_base_url("https://host.com/gateway/servers/UUID_OF_SERVER_2") 

102 'https://host.com/gateway' 

103 >>> _extract_base_url("https://host.com/gateway/servers") 

104 'https://host.com/gateway' 

105 >>> _extract_base_url("https://host.com/gateway") 

106 'https://host.com/gateway' 

107 

108 Note: 

109 If the target server was started with `APP_ROOT_PATH=/gateway`, the 

110 resulting catalog URLs include that prefix. This helper preserves the 

111 prefix so the wrapper's follow-up calls remain correctly scoped. 

112 """ 

113 parsed = urlparse(url) 

114 if not parsed.scheme or not parsed.netloc: 

115 raise ValueError(f"Invalid URL provided: {url}") 

116 

117 path = parsed.path or "" 

118 if "/servers/" in path: 

119 path = path.split("/servers")[0] # ".../servers/UUID_OF_SERVER_123" -> "..." 

120 elif path.endswith("/servers"): 

121 path = path[: -len("/servers")] # ".../servers" -> "..." 

122 # otherwise keep the existing path (supports APP_ROOT_PATH) 

123 

124 return f"{parsed.scheme}://{parsed.netloc}{path}" 

125 

126 

# Derive the gateway base URL from the first catalog. A malformed URL is
# reported the same way as a missing catalog (message on stderr, exit 1)
# rather than as an uncaught ValueError raised during import.
try:
    BASE_URL: str = _extract_base_url(SERVER_CATALOG_URLS[0]) if SERVER_CATALOG_URLS else ""
except ValueError as _base_url_error:
    print(f"Error: {_base_url_error} (check {ENV_SERVER_CATALOGS})", file=sys.stderr)
    sys.exit(1)

# -----------------------------------------------------------------------------
# Logging Setup
# -----------------------------------------------------------------------------
_log_level = os.getenv(ENV_LOG_LEVEL, "INFO").upper()
if _log_level in {"OFF", "NONE", "DISABLE", "FALSE", "0"}:
    logging.disable(logging.CRITICAL)
else:
    # Unknown level names fall back to INFO; logs go to stderr so they never
    # mix with the JSON-RPC traffic written to stdout.
    logging.basicConfig(
        level=getattr(logging, _log_level, logging.INFO),
        format="%(asctime)s %(levelname)-8s %(name)s: %(message)s",
        stream=sys.stderr,
    )

logger = logging.getLogger("mcpgateway.wrapper")
logger.info(f"Starting MCP wrapper {__version__}: base_url={BASE_URL}, timeout={TOOL_CALL_TIMEOUT}")

144 

145 

146# ----------------------------------------------------------------------------- 

147# HTTP Helpers 

148# ----------------------------------------------------------------------------- 

async def fetch_url(url: str) -> httpx.Response:
    """
    GET *url* asynchronously and hand back the response once it is known to be successful.

    A bearer ``Authorization`` header is sent only when ``AUTH_TOKEN`` is non-empty,
    and a fresh client bounded by ``TOOL_CALL_TIMEOUT`` is used for the single request.

    Args:
        url: The target URL to fetch.

    Returns:
        The successful ``httpx.Response`` object.

    Raises:
        httpx.RequestError: If a network problem occurs while making the request.
        httpx.HTTPStatusError: If the server returns a 4xx or 5xx response.
    """
    request_headers: Dict[str, str] = {}
    if AUTH_TOKEN:
        request_headers["Authorization"] = f"Bearer {AUTH_TOKEN}"

    async with httpx.AsyncClient(timeout=TOOL_CALL_TIMEOUT) as http:
        try:
            reply = await http.get(url, headers=request_headers)
            reply.raise_for_status()
        except httpx.RequestError as failure:
            logger.error(f"Network error while fetching {url}: {failure}")
            raise
        except httpx.HTTPStatusError as failure:
            logger.error(f"HTTP {failure.response.status_code} returned for {url}: {failure}")
            raise
        return reply

175 

176 

177# ----------------------------------------------------------------------------- 

178# Metadata Helpers 

179# ----------------------------------------------------------------------------- 

async def get_tools_from_mcp_server(catalog_urls: List[str]) -> List[str]:
    """
    Retrieve associated tool IDs from the MCP gateway server catalogs.

    The server ID of each catalog URL is the last non-empty segment of its
    *path*. Trailing slashes, query strings and fragments therefore do not
    corrupt it: ``.../servers/UUID/`` and ``.../servers/UUID?x=1`` both yield
    ``UUID``. The gateway's ``/servers/`` listing is then scanned, and the
    ``associatedTools`` of every matching server are concatenated.

    Args:
        catalog_urls (List[str]): List of catalog endpoint URLs.

    Returns:
        List[str]: Entries of ``associatedTools`` from the matching servers,
        in the order the gateway lists them (duplicates are kept).
    """
    server_ids = [urlparse(url).path.rstrip("/").split("/")[-1] for url in catalog_urls]
    response = await fetch_url(f"{BASE_URL}/servers/")
    catalog = response.json()
    tool_ids: List[str] = []
    for entry in catalog:
        if str(entry.get("id")) in server_ids:
            # `or []` also tolerates an explicit null in the gateway payload.
            tool_ids.extend(entry.get("associatedTools") or [])
    return tool_ids

199 

200 

async def tools_metadata(tool_ids: List[str]) -> List[Dict[str, Any]]:
    """
    Fetch metadata for a selection of MCP tools.

    Tools are matched on their ``name`` field, whereas ``prompts_metadata`` and
    ``resources_metadata`` match on ``id``. The sentinel ``["0"]`` returns every
    tool the gateway lists.

    Args:
        tool_ids (List[str]): Tool names to keep, or ``["0"]`` for all tools.

    Returns:
        List[Dict[str, Any]]: Metadata dictionaries of the selected tools, in the
        order the gateway returned them. An empty input returns ``[]`` without
        contacting the gateway. Entries lacking a ``name`` are skipped rather
        than raising ``KeyError``.
    """
    if not tool_ids:
        return []
    response = await fetch_url(f"{BASE_URL}/tools/")
    data: List[Dict[str, Any]] = response.json()
    if tool_ids == ["0"]:
        return data

    # .get() so that one malformed entry cannot abort the whole listing.
    return [tool for tool in data if tool.get("name") in tool_ids]

220 

221 

async def get_prompts_from_mcp_server(catalog_urls: List[str]) -> List[str]:
    """
    Retrieve associated prompt IDs from the MCP gateway server catalogs.

    The server ID of each catalog URL is the last non-empty segment of its
    *path*. Trailing slashes, query strings and fragments therefore do not
    corrupt it. The gateway's ``/servers/`` listing is scanned, and the
    ``associatedPrompts`` of every matching server are concatenated.

    Args:
        catalog_urls (List[str]): List of catalog endpoint URLs.

    Returns:
        List[str]: Entries of ``associatedPrompts`` from the matching servers,
        in the order the gateway lists them (duplicates are kept).
    """
    server_ids = [urlparse(url).path.rstrip("/").split("/")[-1] for url in catalog_urls]
    response = await fetch_url(f"{BASE_URL}/servers/")
    catalog = response.json()
    prompt_ids: List[str] = []
    for entry in catalog:
        if str(entry.get("id")) in server_ids:
            # `or []` also tolerates an explicit null in the gateway payload.
            prompt_ids.extend(entry.get("associatedPrompts") or [])
    return prompt_ids

241 

242 

async def prompts_metadata(prompt_ids: List[str]) -> List[Dict[str, Any]]:
    """
    Fetch metadata for the requested MCP prompts.

    Prompts are matched by comparing ``str(prompt["id"])`` against *prompt_ids*.

    Args:
        prompt_ids (List[str]): Prompt IDs to keep; the sentinel ``["0"]`` keeps
            every prompt the gateway lists.

    Returns:
        List[Dict[str, Any]]: Metadata dictionaries for the selected prompts, in
        gateway order. An empty *prompt_ids* returns ``[]`` without a request.
    """
    if not prompt_ids:
        return []

    all_prompts: List[Dict[str, Any]] = (await fetch_url(f"{BASE_URL}/prompts/")).json()
    if prompt_ids == ["0"]:
        return all_prompts

    selected: List[Dict[str, Any]] = []
    for prompt in all_prompts:
        if str(prompt.get("id")) in prompt_ids:
            selected.append(prompt)
    return selected

261 

262 

async def get_resources_from_mcp_server(catalog_urls: List[str]) -> List[str]:
    """
    Retrieve associated resource IDs from the MCP gateway server catalogs.

    The server ID of each catalog URL is the last non-empty segment of its
    *path*. Trailing slashes, query strings and fragments therefore do not
    corrupt it. The gateway's ``/servers/`` listing is scanned, and the
    ``associatedResources`` of every matching server are concatenated.

    Args:
        catalog_urls (List[str]): List of catalog endpoint URLs.

    Returns:
        List[str]: Entries of ``associatedResources`` from the matching servers,
        in the order the gateway lists them (duplicates are kept).
    """
    server_ids = [urlparse(url).path.rstrip("/").split("/")[-1] for url in catalog_urls]
    response = await fetch_url(f"{BASE_URL}/servers/")
    catalog = response.json()
    resource_ids: List[str] = []
    for entry in catalog:
        if str(entry.get("id")) in server_ids:
            # `or []` also tolerates an explicit null in the gateway payload.
            resource_ids.extend(entry.get("associatedResources") or [])
    return resource_ids

282 

283 

async def resources_metadata(resource_ids: List[str]) -> List[Dict[str, Any]]:
    """
    Fetch metadata for the requested MCP resources.

    Resources are matched by comparing ``str(resource["id"])`` against *resource_ids*.

    Args:
        resource_ids (List[str]): Resource IDs to keep; the sentinel ``["0"]``
            keeps every resource the gateway lists.

    Returns:
        List[Dict[str, Any]]: Metadata dictionaries for the selected resources,
        in gateway order. An empty *resource_ids* returns ``[]`` without a request.
    """
    if not resource_ids:
        return []

    listing: List[Dict[str, Any]] = (await fetch_url(f"{BASE_URL}/resources/")).json()
    if resource_ids == ["0"]:
        return listing

    return list(filter(lambda item: str(item.get("id")) in resource_ids, listing))

302 

303 

304# ----------------------------------------------------------------------------- 

305# Server Handlers 

306# ----------------------------------------------------------------------------- 

# Name under which this low-level MCP server instance is created; the
# decorated handlers in this module register themselves on `server`.
_WRAPPER_SERVER_NAME = "mcpgateway-wrapper"
server: Server = Server(_WRAPPER_SERVER_NAME)

308 

309 

@server.list_tools()
async def handle_list_tools() -> List[types.Tool]:
    """
    List all available MCP tools exposed by the gateway.

    When the first catalog URL is the gateway base URL itself, every tool is
    requested (sentinel ``["0"]``). Otherwise the tools associated with the
    configured server catalogs are resolved first. Entries without a truthy
    ``name`` are skipped.

    Returns:
        List[types.Tool]: A list of Tool instances including name, description, and input schema.

    Raises:
        RuntimeError: If an error occurs during fetching or processing.
    """
    try:
        if SERVER_CATALOG_URLS[0] == BASE_URL:
            wanted = ["0"]
        else:
            wanted = await get_tools_from_mcp_server(SERVER_CATALOG_URLS)

        exposed: List[types.Tool] = []
        for entry in await tools_metadata(wanted):
            entry_name = entry.get("name")
            if not entry_name:
                continue  # a tool cannot be exposed without a valid name
            exposed.append(
                types.Tool(
                    name=str(entry_name),
                    description=entry.get("description", ""),
                    inputSchema=entry.get("inputSchema", {}),
                    annotations=entry.get("annotations", {}),
                )
            )
        return exposed
    except Exception as exc:
        logger.exception("Error listing tools")
        raise RuntimeError(f"Error listing tools: {exc}")

343 

344 

@server.call_tool()
async def handle_call_tool(name: str, arguments: Optional[Dict[str, Any]] = None) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    """
    Invoke a named MCP tool via the gateway's RPC endpoint.

    A JSON-RPC request whose ``method`` is the tool name is POSTed to
    ``{BASE_URL}/rpc/``. Transport failures are reported as ``RuntimeError``:
    timeouts, connection errors, non-2xx statuses and undecodable bodies. An
    ``error`` object in the JSON-RPC reply is reported as ``ValueError``. The
    ``result`` member, or the whole reply when it has none, is stringified into
    a single text content item.

    Args:
        name (str): The name of the tool to invoke.
        arguments (Optional[Dict[str, Any]]): The arguments to pass to the tool method.

    Returns:
        List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
            A list of content objects returned by the tool.

    Raises:
        ValueError: If the gateway answers with a JSON-RPC error.
        RuntimeError: If the HTTP request fails, times out, or returns an error status.
    """
    if arguments is None:
        arguments = {}

    logger.info(f"Calling tool {name} with args {arguments}")
    payload = {"jsonrpc": "2.0", "id": 2, "method": name, "params": arguments}
    headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} if AUTH_TOKEN else {}

    try:
        async with httpx.AsyncClient(timeout=TOOL_CALL_TIMEOUT) as client:
            resp = await client.post(f"{BASE_URL}/rpc/", json=payload, headers=headers)
            resp.raise_for_status()
            result = resp.json()
    except httpx.TimeoutException as exc:
        logger.error(f"Timeout calling tool {name}: {exc}")
        raise RuntimeError(f"Tool call timeout: {exc}") from exc
    except Exception as exc:
        logger.exception(f"Error calling tool {name}")
        raise RuntimeError(f"Error calling tool: {exc}") from exc

    # The JSON-RPC error check lives outside the try block above so that the
    # documented ValueError is not re-wrapped into a RuntimeError by the
    # generic handler. `"error": null` is treated as "no error".
    if isinstance(result, dict) and result.get("error") is not None:
        error = result["error"]
        error_msg = error.get("message", "Unknown error") if isinstance(error, dict) else str(error)
        logger.error(f"Tool {name} returned an error: {error_msg}")
        raise ValueError(f"Tool call failed: {error_msg}")

    # Non-dict JSON bodies (lists, scalars) are passed through verbatim.
    tool_result = result.get("result", result) if isinstance(result, dict) else result
    return [types.TextContent(type="text", text=str(tool_result))]

388 

389 

@server.list_resources()
async def handle_list_resources() -> List[types.Resource]:
    """
    List all available MCP resources exposed by the gateway.

    When the first catalog URL is the gateway base URL itself, every resource
    is requested (sentinel ``["0"]``). Otherwise the resources associated with
    the configured server catalogs are resolved first. Entries with a missing
    URI, or whose data cannot be turned into a ``types.Resource``, are logged
    as warnings and skipped.

    Returns:
        List[types.Resource]: A list of Resource objects including URI, name, description, and MIME type.

    Raises:
        RuntimeError: If an error occurs during fetching or processing.
    """
    try:
        if SERVER_CATALOG_URLS[0] == BASE_URL:
            wanted = ["0"]
        else:
            wanted = await get_resources_from_mcp_server(SERVER_CATALOG_URLS)

        collected: List[types.Resource] = []
        for entry in await resources_metadata(wanted):
            entry_uri = entry.get("uri")
            if not entry_uri:
                logger.warning(f"Resource missing URI, skipping: {entry}")
                continue
            try:
                resource = types.Resource(
                    uri=AnyUrl(entry_uri),
                    name=entry.get("name", ""),
                    description=entry.get("description", ""),
                    mimeType=entry.get("mimeType", "text/plain"),
                )
            except Exception as bad_entry:
                logger.warning(f"Invalid resource URI {entry_uri}: {bad_entry}")
                continue
            collected.append(resource)
        return collected
    except Exception as exc:
        logger.exception("Error listing resources")
        raise RuntimeError(f"Error listing resources: {exc}")

429 

430 

@server.read_resource()
async def handle_read_resource(uri: AnyUrl) -> str:
    """
    Read and return the content of a resource by its URI.

    The URI is fetched directly with an HTTP GET. The gateway bearer token is
    attached only when the URI lies inside the configured gateway: same scheme,
    same host/port, and a path at or below the path of ``BASE_URL``. Any other
    URI is requested without credentials, so the token is never forwarded to a
    third-party host.

    Args:
        uri (AnyUrl): The URI of the resource to read.

    Returns:
        str: The body text of the fetched resource.

    Raises:
        ValueError: If the resource cannot be fetched.
    """
    try:
        target = str(uri)
        gateway = urlparse(BASE_URL)
        candidate = urlparse(target)
        gateway_root = gateway.path.rstrip("/") + "/"
        inside_gateway = (
            bool(gateway.netloc)
            and candidate.scheme.lower() == gateway.scheme.lower()
            and candidate.netloc.lower() == gateway.netloc.lower()
            and (candidate.path.rstrip("/") + "/").startswith(gateway_root)
        )

        if inside_gateway:
            response = await fetch_url(target)
        else:
            # Anonymous request: do not leak AUTH_TOKEN outside the gateway.
            async with httpx.AsyncClient(timeout=TOOL_CALL_TIMEOUT) as client:
                response = await client.get(target)
                response.raise_for_status()
        return response.text
    except Exception as exc:
        logger.exception(f"Error reading resource {uri}")
        raise ValueError(f"Failed to read resource at {uri}: {exc}") from exc

451 

452 

@server.list_prompts()
async def handle_list_prompts() -> List[types.Prompt]:
    """
    List all available MCP prompts exposed by the gateway.

    When the first catalog URL is the gateway base URL itself, every prompt is
    requested (sentinel ``["0"]``). Otherwise the prompts associated with the
    configured server catalogs are resolved first. Entries without a truthy
    ``name`` are skipped.

    Returns:
        List[types.Prompt]: A list of Prompt objects including name, description, and arguments.

    Raises:
        RuntimeError: If an error occurs during fetching or processing.
    """
    try:
        if SERVER_CATALOG_URLS[0] == BASE_URL:
            wanted = ["0"]
        else:
            wanted = await get_prompts_from_mcp_server(SERVER_CATALOG_URLS)

        available: List[types.Prompt] = []
        for entry in await prompts_metadata(wanted):
            entry_name = entry.get("name")
            if not entry_name:
                continue  # a prompt cannot be exposed without a valid name
            available.append(
                types.Prompt(
                    name=str(entry_name),
                    description=entry.get("description", ""),
                    arguments=entry.get("arguments", []),
                )
            )
        return available
    except Exception as exc:
        logger.exception("Error listing prompts")
        raise RuntimeError(f"Error listing prompts: {exc}")

485 

486 

@server.get_prompt()
async def handle_get_prompt(name: str, arguments: Optional[Dict[str, str]] = None) -> types.GetPromptResult:
    """
    Retrieve and format a single prompt template with provided arguments.

    The prompt is fetched from ``{BASE_URL}/prompts/<name>``. The name is
    percent-encoded so that characters such as ``?``, ``#``, ``%`` or ``/``
    cannot change which route is requested. Its ``template`` field is filled
    with ``str.format(**arguments)`` and returned as a single user message.

    Args:
        name (str): The unique name of the prompt to fetch.
        arguments (Optional[Dict[str, str]]): A mapping of placeholder names to replacement values.

    Returns:
        types.GetPromptResult: Contains description and list of formatted PromptMessage instances.

    Raises:
        ValueError: If fetching or formatting fails.

    Example:
        From inside a running event loop::

            result = await handle_get_prompt("greet", {"username": "Alice"})
    """
    # Standard
    from urllib.parse import quote  # local import keeps this handler self-contained

    try:
        url = f"{BASE_URL}/prompts/{quote(name, safe='')}"
        response = await fetch_url(url)
        prompt_data = response.json()

        # `or ""` also covers an explicit null template in the payload.
        template = prompt_data.get("template") or ""
        try:
            formatted = template.format(**(arguments or {}))
        except KeyError as exc:
            raise ValueError(f"Missing placeholder in arguments: {exc}") from exc
        except Exception as exc:
            raise ValueError(f"Error formatting prompt: {exc}") from exc

        return types.GetPromptResult(
            description=prompt_data.get("description", ""),
            messages=[
                types.PromptMessage(
                    role="user",
                    content=types.TextContent(type="text", text=formatted),
                )
            ],
        )
    except ValueError:
        raise
    except Exception as exc:
        logger.exception(f"Error getting prompt {name}")
        raise ValueError(f"Failed to fetch prompt '{name}': {exc}") from exc

532 

533 

async def main() -> None:
    """
    Run the wrapper as an MCP server over standard input/output.

    Opens the stdio transport and advertises the capabilities of the handlers
    registered on ``server``. JSON-RPC messages are then served until the
    transport closes. Any failure is logged and re-raised as ``RuntimeError``.

    Raises:
        RuntimeError: If the server fails to start.

    Example:
        if __name__ == "__main__":
            asyncio.run(main())
    """
    try:
        async with mcp.server.stdio.stdio_server() as streams:
            read_stream, write_stream = streams
            init_options = InitializationOptions(
                server_name="mcpgateway-wrapper",
                server_version=__version__,
                capabilities=server.get_capabilities(
                    notification_options=NotificationOptions(),
                    experimental_capabilities={},
                ),
            )
            await server.run(read_stream, write_stream, init_options)
    except Exception as exc:
        logger.exception("Server failed to start")
        raise RuntimeError(f"Server startup failed: {exc}")

564 

565 

if __name__ == "__main__":
    # Run the stdio server. Ctrl-C is a normal shutdown; any other failure is
    # logged and turns into exit status 1. The shutdown message is logged in
    # every case before the interpreter exits.
    _exit_status = 0
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Server interrupted by user")
    except Exception:
        logger.exception("Server failed")
        _exit_status = 1
    finally:
        logger.info("Wrapper shutdown complete")
    if _exit_status:
        sys.exit(_exit_status)