Coverage for mcpgateway/wrapper.py: 86%
217 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
1# -*- coding: utf-8 -*-
2"""
3MCP Gateway Wrapper server.
5Copyright 2025
6SPDX-License-Identifier: Apache-2.0
7Authors: Keval Mahajan, Mihai Criveti, Madhav Kandukuri
9This module implements a wrapper bridge that facilitates
10interaction between the MCP client and the MCP gateway.
11It provides several functionalities, including listing tools,
12invoking tools, managing resources, retrieving prompts,
13and handling tool calls via the MCP gateway.
15A **stdio** bridge that exposes a remote MCP Gateway
16(HTTP-/JSON-RPC APIs) as a local MCP server. All JSON-RPC
17traffic is written to **stdout**; every log or trace message
18is emitted on **stderr** so that protocol messages and
19diagnostics never mix.
21Environment variables:
22- MCP_SERVER_CATALOG_URLS: Comma-separated list of gateway catalog URLs (required)
23- MCP_AUTH_TOKEN: Bearer token for the gateway (optional)
24- MCP_TOOL_CALL_TIMEOUT: Seconds to wait for a gateway RPC call (default 90)
25- MCP_WRAPPER_LOG_LEVEL: Python log level name or OFF/NONE to disable logging (default INFO)
27Example:
28 $ export MCPGATEWAY_BEARER_TOKEN=$(python3 -m mcpgateway.utils.create_jwt_token --username admin --exp 10080 --secret my-test-key)
29 $ export MCP_AUTH_TOKEN=${MCPGATEWAY_BEARER_TOKEN}
30 $ export MCP_SERVER_CATALOG_URLS='http://localhost:4444/servers/UUID_OF_SERVER_1'
31 $ export MCP_TOOL_CALL_TIMEOUT=120
32 $ export MCP_WRAPPER_LOG_LEVEL=DEBUG # OFF to disable logging
33 $ python3 -m mcpgateway.wrapper
34"""
36# Standard
37import asyncio
38import logging
39import os
40import sys
41from typing import Any, Dict, List, Optional, Union
42from urllib.parse import urlparse
44# Third-Party
45import httpx
46from mcp import types
47from mcp.server import NotificationOptions, Server
48from mcp.server.models import InitializationOptions
49import mcp.server.stdio
50from pydantic import AnyUrl
52# First-Party
53from mcpgateway import __version__
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
ENV_SERVER_CATALOGS = "MCP_SERVER_CATALOG_URLS"
ENV_AUTH_TOKEN = "MCP_AUTH_TOKEN"  # nosec B105 - this is an *environment variable name*, not a secret
ENV_TIMEOUT = "MCP_TOOL_CALL_TIMEOUT"
ENV_LOG_LEVEL = "MCP_WRAPPER_LOG_LEVEL"

# Catalog URLs arrive comma-separated; surrounding whitespace and empty entries
# (e.g. from a trailing comma) are discarded.
RAW_CATALOGS: str = os.getenv(ENV_SERVER_CATALOGS, "")
SERVER_CATALOG_URLS: List[str] = [u.strip() for u in RAW_CATALOGS.split(",") if u.strip()]

AUTH_TOKEN: str = os.getenv(ENV_AUTH_TOKEN, "")

# Validate the timeout up front so a malformed value produces a one-line error
# and exit status 1 (matching the catalog check below) instead of an uncaught
# ValueError traceback during import. A zero or negative timeout would make
# every gateway request fail, so it is rejected as well.
_raw_timeout = os.getenv(ENV_TIMEOUT, "90")
try:
    TOOL_CALL_TIMEOUT: int = int(_raw_timeout)
except ValueError:
    print(f"Error: {ENV_TIMEOUT} must be a whole number of seconds, got {_raw_timeout!r}", file=sys.stderr)
    sys.exit(1)
if TOOL_CALL_TIMEOUT <= 0:
    print(f"Error: {ENV_TIMEOUT} must be a positive number of seconds, got {TOOL_CALL_TIMEOUT}", file=sys.stderr)
    sys.exit(1)

# Validate required configuration
if not SERVER_CATALOG_URLS:
    print(f"Error: {ENV_SERVER_CATALOGS} environment variable is required", file=sys.stderr)
    sys.exit(1)
75# -----------------------------------------------------------------------------
76# Base URL Extraction
77# -----------------------------------------------------------------------------
78def _extract_base_url(url: str) -> str:
79 """Return the gateway-level base URL.
81 The function keeps any application root path (`APP_ROOT_PATH`) that the
82 remote gateway is mounted under (for example `/gateway`) while removing
83 the `/servers/<id>` suffix that appears in catalog endpoints. It also
84 discards any query string or fragment.
86 Args:
87 url (str): Full catalog URL, e.g.
88 `https://host.com/gateway/servers/UUID_OF_SERVER_1`.
90 Returns:
91 str: Clean base URL suitable for building `/tools/`, `/prompts/`,
92 or `/resources/` endpoints-for example
93 `https://host.com/gateway`.
95 Raises:
96 ValueError: If *url* lacks a scheme or network location.
98 Examples:
99 >>> _extract_base_url("https://host.com/servers/UUID_OF_SERVER_2")
100 'https://host.com'
101 >>> _extract_base_url("https://host.com/gateway/servers/UUID_OF_SERVER_2")
102 'https://host.com/gateway'
103 >>> _extract_base_url("https://host.com/gateway/servers")
104 'https://host.com/gateway'
105 >>> _extract_base_url("https://host.com/gateway")
106 'https://host.com/gateway'
108 Note:
109 If the target server was started with `APP_ROOT_PATH=/gateway`, the
110 resulting catalog URLs include that prefix. This helper preserves the
111 prefix so the wrapper's follow-up calls remain correctly scoped.
112 """
113 parsed = urlparse(url)
114 if not parsed.scheme or not parsed.netloc:
115 raise ValueError(f"Invalid URL provided: {url}")
117 path = parsed.path or ""
118 if "/servers/" in path:
119 path = path.split("/servers")[0] # ".../servers/UUID_OF_SERVER_123" -> "..."
120 elif path.endswith("/servers"):
121 path = path[: -len("/servers")] # ".../servers" -> "..."
122 # otherwise keep the existing path (supports APP_ROOT_PATH)
124 return f"{parsed.scheme}://{parsed.netloc}{path}"
# Derive the gateway root from the first catalog URL; empty when none were given.
if SERVER_CATALOG_URLS:
    BASE_URL: str = _extract_base_url(SERVER_CATALOG_URLS[0])
else:
    BASE_URL = ""
# -----------------------------------------------------------------------------
# Logging Setup
# -----------------------------------------------------------------------------
_log_level = os.getenv(ENV_LOG_LEVEL, "INFO").strip().upper()
if _log_level in {"OFF", "NONE", "DISABLE", "FALSE", "0"}:
    # Suppress every logging call at CRITICAL and below, i.e. all of them.
    logging.disable(logging.CRITICAL)
else:
    # Resolve a level name ("DEBUG", "WARNING", ...) or a numeric level ("10").
    # Unknown names fall back to INFO. Unlike getattr(logging, name), this never
    # picks up a non-level module attribute (e.g. "BASIC_FORMAT", a string) that
    # would make basicConfig raise.
    if _log_level.isdigit():
        _resolved_level = int(_log_level)
    else:
        _resolved_level = logging.getLevelName(_log_level)
        if not isinstance(_resolved_level, int):
            _resolved_level = logging.INFO
    logging.basicConfig(
        level=_resolved_level,
        format="%(asctime)s %(levelname)-8s %(name)s: %(message)s",
        stream=sys.stderr,
    )

logger = logging.getLogger("mcpgateway.wrapper")
logger.info("Starting MCP wrapper %s: base_url=%s, timeout=%s", __version__, BASE_URL, TOOL_CALL_TIMEOUT)
# -----------------------------------------------------------------------------
# HTTP Helpers
# -----------------------------------------------------------------------------
async def fetch_url(url: str) -> httpx.Response:
    """
    Issue a GET request against *url* and hand back the successful response.

    A new ``httpx.AsyncClient`` bounded by ``TOOL_CALL_TIMEOUT`` is opened for
    every call, and an ``Authorization: Bearer`` header is sent only when
    ``AUTH_TOKEN`` is non-empty. Failures are logged before being re-raised.

    Args:
        url: The target URL to fetch.

    Returns:
        The ``httpx.Response`` object, whose status is not 4xx/5xx.

    Raises:
        httpx.RequestError: If a network problem occurs while making the request.
        httpx.HTTPStatusError: If the server returns a 4xx or 5xx response.
    """
    request_headers = {}
    if AUTH_TOKEN:
        request_headers["Authorization"] = f"Bearer {AUTH_TOKEN}"

    async with httpx.AsyncClient(timeout=TOOL_CALL_TIMEOUT) as http:
        try:
            reply = await http.get(url, headers=request_headers)
            reply.raise_for_status()
        except httpx.RequestError as network_err:
            logger.error(f"Network error while fetching {url}: {network_err}")
            raise
        except httpx.HTTPStatusError as status_err:
            logger.error(f"HTTP {status_err.response.status_code} returned for {url}: {status_err}")
            raise
        return reply
# -----------------------------------------------------------------------------
# Metadata Helpers
# -----------------------------------------------------------------------------
async def get_tools_from_mcp_server(catalog_urls: List[str]) -> List[str]:
    """
    Retrieve associated tool IDs from the MCP gateway server catalogs.

    The last path segment of each catalog URL is treated as a server ID; a
    trailing slash, query string or fragment is ignored. The gateway's
    ``/servers/`` listing is fetched once, and the ``associatedTools`` of every
    matching server entry are concatenated in listing order.

    Args:
        catalog_urls (List[str]): List of catalog endpoint URLs.

    Returns:
        List[str]: A list of tool ID strings extracted from the server catalog.

    Raises:
        httpx.RequestError: Propagated from ``fetch_url`` on network failure.
        httpx.HTTPStatusError: Propagated from ``fetch_url`` on a 4xx/5xx reply.
    """
    # ".../servers/UUID/" or ".../servers/UUID?x=1" must still yield "UUID".
    server_ids = {urlparse(catalog_url).path.rstrip("/").split("/")[-1] for catalog_url in catalog_urls}
    response = await fetch_url(f"{BASE_URL}/servers/")
    catalog = response.json()
    tool_ids: List[str] = []
    for entry in catalog:
        if str(entry.get("id")) in server_ids:
            # A null associatedTools is treated like a missing one.
            tool_ids.extend(entry.get("associatedTools") or [])
    return tool_ids
async def tools_metadata(tool_ids: List[str]) -> List[Dict[str, Any]]:
    """
    Fetch metadata for a list of MCP tools by their IDs.

    The gateway's ``/tools/`` listing is fetched once and filtered locally.
    Tools are matched on their ``name`` field (compared as strings); entries
    without a name are skipped rather than raising ``KeyError``.

    Args:
        tool_ids (List[str]): Tool identifiers to keep, or the sentinel
            ``["0"]`` to return every tool unfiltered.

    Returns:
        List[Dict[str, Any]]: A list of metadata dictionaries for each tool;
        empty when *tool_ids* is empty (no request is made in that case).
    """
    if not tool_ids:
        return []
    response = await fetch_url(f"{BASE_URL}/tools/")
    data: List[Dict[str, Any]] = response.json()
    if tool_ids == ["0"]:
        return data

    wanted = {str(tool_id) for tool_id in tool_ids}
    return [tool for tool in data if tool.get("name") is not None and str(tool["name"]) in wanted]
async def get_prompts_from_mcp_server(catalog_urls: List[str]) -> List[str]:
    """
    Retrieve associated prompt IDs from the MCP gateway server catalogs.

    The last path segment of each catalog URL is treated as a server ID; a
    trailing slash, query string or fragment is ignored. The gateway's
    ``/servers/`` listing is fetched once, and the ``associatedPrompts`` of
    every matching server entry are concatenated in listing order.

    Args:
        catalog_urls (List[str]): List of catalog endpoint URLs.

    Returns:
        List[str]: A list of prompt IDs, as they appear in the catalog JSON.

    Raises:
        httpx.RequestError: Propagated from ``fetch_url`` on network failure.
        httpx.HTTPStatusError: Propagated from ``fetch_url`` on a 4xx/5xx reply.
    """
    # ".../servers/UUID/" or ".../servers/UUID?x=1" must still yield "UUID".
    server_ids = {urlparse(catalog_url).path.rstrip("/").split("/")[-1] for catalog_url in catalog_urls}
    response = await fetch_url(f"{BASE_URL}/servers/")
    catalog = response.json()
    prompt_ids: List[str] = []
    for entry in catalog:
        if str(entry.get("id")) in server_ids:
            # A null associatedPrompts is treated like a missing one.
            prompt_ids.extend(entry.get("associatedPrompts") or [])
    return prompt_ids
async def prompts_metadata(prompt_ids: List[str]) -> List[Dict[str, Any]]:
    """
    Fetch metadata for a list of MCP prompts by their IDs.

    The gateway's ``/prompts/`` listing is fetched once and filtered locally
    by ``id``. Both sides are compared as strings, because IDs taken from the
    catalog JSON may be numbers while the original check compared
    ``str(id)`` against the raw values.

    Args:
        prompt_ids (List[str]): Prompt IDs to keep, or the sentinel ``["0"]``
            to return every prompt unfiltered.

    Returns:
        List[Dict[str, Any]]: A list of metadata dictionaries for each prompt;
        empty when *prompt_ids* is empty (no request is made in that case).
    """
    if not prompt_ids:
        return []
    response = await fetch_url(f"{BASE_URL}/prompts/")
    data: List[Dict[str, Any]] = response.json()
    if prompt_ids == ["0"]:
        return data

    wanted = {str(prompt_id) for prompt_id in prompt_ids}
    return [pr for pr in data if str(pr.get("id")) in wanted]
async def get_resources_from_mcp_server(catalog_urls: List[str]) -> List[str]:
    """
    Retrieve associated resource IDs from the MCP gateway server catalogs.

    The last path segment of each catalog URL is treated as a server ID; a
    trailing slash, query string or fragment is ignored. The gateway's
    ``/servers/`` listing is fetched once, and the ``associatedResources`` of
    every matching server entry are concatenated in listing order.

    Args:
        catalog_urls (List[str]): List of catalog endpoint URLs.

    Returns:
        List[str]: A list of resource IDs, as they appear in the catalog JSON.

    Raises:
        httpx.RequestError: Propagated from ``fetch_url`` on network failure.
        httpx.HTTPStatusError: Propagated from ``fetch_url`` on a 4xx/5xx reply.
    """
    # ".../servers/UUID/" or ".../servers/UUID?x=1" must still yield "UUID".
    server_ids = {urlparse(catalog_url).path.rstrip("/").split("/")[-1] for catalog_url in catalog_urls}
    response = await fetch_url(f"{BASE_URL}/servers/")
    catalog = response.json()
    resource_ids: List[str] = []
    for entry in catalog:
        if str(entry.get("id")) in server_ids:
            # A null associatedResources is treated like a missing one.
            resource_ids.extend(entry.get("associatedResources") or [])
    return resource_ids
async def resources_metadata(resource_ids: List[str]) -> List[Dict[str, Any]]:
    """
    Fetch metadata for a list of MCP resources by their IDs.

    The gateway's ``/resources/`` listing is fetched once and filtered locally
    by ``id``. Both sides are compared as strings, because IDs taken from the
    catalog JSON may be numbers while the original check compared
    ``str(id)`` against the raw values.

    Args:
        resource_ids (List[str]): Resource IDs to keep, or the sentinel
            ``["0"]`` to return every resource unfiltered.

    Returns:
        List[Dict[str, Any]]: A list of metadata dictionaries for each resource;
        empty when *resource_ids* is empty (no request is made in that case).
    """
    if not resource_ids:
        return []
    response = await fetch_url(f"{BASE_URL}/resources/")
    data: List[Dict[str, Any]] = response.json()
    if resource_ids == ["0"]:
        return data

    wanted = {str(resource_id) for resource_id in resource_ids}
    return [res for res in data if str(res.get("id")) in wanted]
# -----------------------------------------------------------------------------
# Server Handlers
# -----------------------------------------------------------------------------
# The handlers below attach themselves to this instance via its decorators.
server: Server = Server(name="mcpgateway-wrapper")
@server.list_tools()
async def handle_list_tools() -> List[types.Tool]:
    """
    List all available MCP tools exposed by the gateway.

    When the first catalog URL equals the bare gateway URL, every tool is
    listed (the ``"0"`` sentinel). Otherwise tool IDs are collected from the
    configured server catalogs. Tools without a name are skipped.

    Returns:
        List[types.Tool]: A list of Tool instances including name, description, and input schema.

    Raises:
        RuntimeError: If an error occurs during fetching or processing; the
            underlying exception is chained as ``__cause__``.
    """
    try:
        tool_ids = ["0"] if SERVER_CATALOG_URLS[0] == BASE_URL else await get_tools_from_mcp_server(SERVER_CATALOG_URLS)
        metadata = await tools_metadata(tool_ids)
        tools: List[types.Tool] = []
        for tool in metadata:
            tool_name = tool.get("name")
            if not tool_name:  # Only include tools with valid names
                continue
            tools.append(
                types.Tool(
                    name=str(tool_name),
                    description=tool.get("description", ""),
                    # A null schema would fail model validation and abort the whole listing.
                    inputSchema=tool.get("inputSchema") or {},
                    annotations=tool.get("annotations", {}),
                )
            )
        return tools
    except Exception as exc:
        logger.exception("Error listing tools")
        raise RuntimeError(f"Error listing tools: {exc}") from exc
@server.call_tool()
async def handle_call_tool(name: str, arguments: Optional[Dict[str, Any]] = None) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    """
    Invoke a named MCP tool via the gateway's RPC endpoint.

    A JSON-RPC 2.0 request whose ``method`` is the tool name is POSTed to
    ``{BASE_URL}/rpc/``.

    Args:
        name (str): The name of the tool to invoke.
        arguments (Optional[Dict[str, Any]]): The arguments to pass to the tool method.

    Returns:
        List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
            A single TextContent holding ``str()`` of the response's ``result``
            member, or of the whole response body when it has no ``result``.

    Raises:
        RuntimeError: If the request times out, the HTTP request fails or
            returns an error status, the body is not JSON, or the gateway
            reports a JSON-RPC error. The original exception is chained as
            ``__cause__``.
    """
    if arguments is None:
        arguments = {}

    logger.info(f"Calling tool {name} with args {arguments}")
    payload = {"jsonrpc": "2.0", "id": 2, "method": name, "params": arguments}
    headers = {"Authorization": f"Bearer {AUTH_TOKEN}"} if AUTH_TOKEN else {}

    try:
        async with httpx.AsyncClient(timeout=TOOL_CALL_TIMEOUT) as client:
            resp = await client.post(f"{BASE_URL}/rpc/", json=payload, headers=headers)
            resp.raise_for_status()
            result = resp.json()

        if isinstance(result, dict):
            if "error" in result:
                error = result["error"]
                # JSON-RPC errors are objects with a "message"; tolerate a bare value too.
                error_msg = error.get("message", "Unknown error") if isinstance(error, dict) else str(error)
                raise ValueError(f"Tool call failed: {error_msg}")
            tool_result = result.get("result", result)
        else:
            # Non-object bodies (e.g. a list) are returned as-is instead of
            # failing with an opaque AttributeError.
            tool_result = result
        return [types.TextContent(type="text", text=str(tool_result))]

    except httpx.TimeoutException as exc:
        logger.error(f"Timeout calling tool {name}: {exc}")
        raise RuntimeError(f"Tool call timeout: {exc}") from exc
    except Exception as exc:
        # Also wraps the ValueError raised above for gateway-reported errors.
        logger.exception(f"Error calling tool {name}")
        raise RuntimeError(f"Error calling tool: {exc}") from exc
@server.list_resources()
async def handle_list_resources() -> List[types.Resource]:
    """
    List all available MCP resources exposed by the gateway.

    When the first catalog URL equals the bare gateway URL, every resource is
    listed (the ``"0"`` sentinel). Otherwise resource IDs are collected from
    the configured catalogs. Entries without a ``uri``, or whose fields fail
    to build a ``types.Resource``, are logged as warnings and skipped instead
    of failing the whole listing.

    Returns:
        List[types.Resource]: A list of Resource objects including URI, name, description, and MIME type.

    Raises:
        RuntimeError: If an error occurs during fetching or processing; the
            underlying exception is chained as ``__cause__``.
    """
    try:
        ids = ["0"] if SERVER_CATALOG_URLS[0] == BASE_URL else await get_resources_from_mcp_server(SERVER_CATALOG_URLS)
        meta = await resources_metadata(ids)
        resources: List[types.Resource] = []
        for r in meta:
            uri = r.get("uri")
            if not uri:
                logger.warning(f"Resource missing URI, skipping: {r}")
                continue
            try:
                resources.append(
                    types.Resource(
                        uri=AnyUrl(uri),
                        name=r.get("name", ""),
                        description=r.get("description", ""),
                        mimeType=r.get("mimeType", "text/plain"),
                    )
                )
            except Exception as e:
                # Best-effort: one bad entry must not hide the others.
                logger.warning(f"Invalid resource URI {uri}: {e}")
        return resources
    except Exception as exc:
        logger.exception("Error listing resources")
        raise RuntimeError(f"Error listing resources: {exc}") from exc
@server.read_resource()
async def handle_read_resource(uri: AnyUrl) -> str:
    """
    Read and return the content of a resource by its URI.

    The URI itself is fetched with ``fetch_url`` (an HTTP GET), and the
    response body is returned as text.

    NOTE(review): ``fetch_url`` attaches the gateway bearer token whenever
    ``AUTH_TOKEN`` is set, so the token is sent to whatever host *uri* names,
    not only to the gateway. Non-HTTP schemes presumably fail inside httpx.
    Confirm this is intended, or route reads through the gateway instead.

    Args:
        uri (AnyUrl): The URI of the resource to read.

    Returns:
        str: The body text of the fetched resource.

    Raises:
        ValueError: If the resource cannot be fetched; the underlying error is
            chained as ``__cause__``.
    """
    try:
        response = await fetch_url(str(uri))
        return response.text
    except Exception as exc:
        logger.exception(f"Error reading resource {uri}")
        raise ValueError(f"Failed to read resource at {uri}: {exc}") from exc
@server.list_prompts()
async def handle_list_prompts() -> List[types.Prompt]:
    """
    List all available MCP prompts exposed by the gateway.

    When the first catalog URL equals the bare gateway URL, every prompt is
    listed (the ``"0"`` sentinel). Otherwise prompt IDs are collected from the
    configured catalogs. Prompts without a name are skipped.

    Returns:
        List[types.Prompt]: A list of Prompt objects including name, description, and arguments.

    Raises:
        RuntimeError: If an error occurs during fetching or processing; the
            underlying exception is chained as ``__cause__``.
    """
    try:
        ids = ["0"] if SERVER_CATALOG_URLS[0] == BASE_URL else await get_prompts_from_mcp_server(SERVER_CATALOG_URLS)
        meta = await prompts_metadata(ids)
        prompts: List[types.Prompt] = []
        for p in meta:
            prompt_name = p.get("name")
            if not prompt_name:  # Only include prompts with valid names
                continue
            prompts.append(
                types.Prompt(
                    name=str(prompt_name),
                    description=p.get("description", ""),
                    arguments=p.get("arguments", []),
                )
            )
        return prompts
    except Exception as exc:
        logger.exception("Error listing prompts")
        raise RuntimeError(f"Error listing prompts: {exc}") from exc
@server.get_prompt()
async def handle_get_prompt(name: str, arguments: Optional[Dict[str, str]] = None) -> types.GetPromptResult:
    """
    Retrieve and format a single prompt template with provided arguments.

    The prompt is fetched from ``{BASE_URL}/prompts/<name>``, and its
    ``template`` field is rendered with ``str.format(**arguments)``.

    NOTE(review): rendering uses ``str.format``. If the gateway stores
    templates in another placeholder syntax (e.g. ``{{ var }}``), they will not
    render as intended. Confirm against the gateway's template format.

    Args:
        name (str): The unique name of the prompt to fetch.
        arguments (Optional[Dict[str, str]]): A mapping of placeholder names to replacement values.

    Returns:
        types.GetPromptResult: Contains description and list of formatted PromptMessage instances.

    Raises:
        ValueError: If fetching or formatting fails; the underlying error is
            chained as ``__cause__``.

    Example:
        From async code (not a doctest; it needs a running gateway)::

            result = await handle_get_prompt("greet", {"username": "Alice"})
    """
    from urllib.parse import quote  # local import keeps the module import block unchanged

    try:
        # Percent-encode the name so characters such as "?", "#" or spaces
        # cannot turn part of it into a query string or fragment.
        url = f"{BASE_URL}/prompts/{quote(name, safe='')}"
        response = await fetch_url(url)
        prompt_data = response.json()

        # A missing or null template renders as an empty message.
        template = prompt_data.get("template") or ""
        try:
            formatted = template.format(**(arguments or {}))
        except KeyError as exc:
            raise ValueError(f"Missing placeholder in arguments: {exc}") from exc
        except Exception as exc:
            raise ValueError(f"Error formatting prompt: {exc}") from exc

        return types.GetPromptResult(
            description=prompt_data.get("description", ""),
            messages=[
                types.PromptMessage(
                    role="user",
                    content=types.TextContent(type="text", text=formatted),
                )
            ],
        )
    except ValueError:
        # Formatting errors above already carry a specific message.
        raise
    except Exception as exc:
        logger.exception(f"Error getting prompt {name}")
        raise ValueError(f"Failed to fetch prompt '{name}': {exc}") from exc
async def main() -> None:
    """
    Main entry point to start the MCP stdio server.

    Opens the stdio transport and runs ``server`` on its reader/writer pair.
    The ``InitializationOptions`` advertise the server name, the package
    version, and the capabilities returned by ``server.get_capabilities``.

    This function should only be called in a script context.

    Raises:
        RuntimeError: If the server fails to start or stops with an error; the
            original exception is chained as ``__cause__``.

    Example:
        if __name__ == "__main__":
            asyncio.run(main())
    """
    try:
        async with mcp.server.stdio.stdio_server() as (reader, writer):
            await server.run(
                reader,
                writer,
                InitializationOptions(
                    server_name="mcpgateway-wrapper",
                    server_version=__version__,
                    capabilities=server.get_capabilities(notification_options=NotificationOptions(), experimental_capabilities={}),
                ),
            )
    except Exception as exc:
        logger.exception("Server failed to start")
        raise RuntimeError(f"Server startup failed: {exc}") from exc
if __name__ == "__main__":
    # Remember the failure and exit only after the shutdown message has been
    # logged; an interrupt by the user still ends with status 0.
    _exit_code = 0
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Server interrupted by user")
    except Exception:
        logger.exception("Server failed")
        _exit_code = 1
    finally:
        logger.info("Wrapper shutdown complete")
    if _exit_code:
        sys.exit(_exit_code)