Coverage for mcpgateway/translate.py: 83%
198 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
1# -*- coding: utf-8 -*-
2""" mcpgateway.translate - bridges local JSON-RPC/stdio servers to HTTP/SSE
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti, Manav Gupta
8You can now run the bridge in either direction:
10- stdio to SSE (expose local stdio MCP server over SSE)
11- SSE to stdio (bridge remote SSE endpoint to local stdio)
14Usage
15-----
16# 1. expose an MCP server that talks JSON-RPC on stdio at :9000/sse
17python -m mcpgateway.translate --stdio "uvx mcp-server-git" --port 9000
19# 2. from another shell / browser subscribe to the SSE stream
20curl -N http://localhost:9000/sse # receive the stream
22# 3. send a test echo request
23curl -X POST http://localhost:9000/message \\
24 -H 'Content-Type: application/json' \\
25 -d '{"jsonrpc":"2.0","id":1,"method":"echo","params":{"value":"hi"}}'
27# 4. proper MCP handshake and tool listing
28curl -X POST http://localhost:9000/message \\
29 -H 'Content-Type: application/json' \\
30 -d '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"demo","version":"0.0.1"}}}'
32curl -X POST http://localhost:9000/message \\
33 -H 'Content-Type: application/json' \\
34 -d '{"jsonrpc":"2.0","id":2,"method":"tools/list"}'
36The SSE stream now emits JSON-RPC responses as `event: message` frames and sends
37regular `event: keepalive` frames (default every 30s) so that proxies and
38clients never time out. Each client receives a unique *session-id* that is
39appended as a query parameter to the back-channel `/message` URL.
40"""
42# Future
43from __future__ import annotations
45# Standard
46import argparse
47import asyncio
48from contextlib import suppress
49import json
50import logging
51import shlex
52import signal
53import sys
54from typing import Any, AsyncIterator, Dict, List, Optional, Sequence
55import uuid
57# Third-Party
58from fastapi import FastAPI, Request, Response, status
59from fastapi.middleware.cors import CORSMiddleware
60from fastapi.responses import PlainTextResponse
61from sse_starlette.sse import EventSourceResponse
62import uvicorn
64try:
65 # Third-Party
66 import httpx
67except ImportError:
68 httpx = None # type: ignore[assignment]
LOGGER: logging.Logger = logging.getLogger("mcpgateway.translate")

# Default interval, in seconds, between SSE keepalive frames; matches the
# reference implementation.
KEEP_ALIVE_INTERVAL: int = 30

# Only the console-script entry point is part of the public API.
__all__: List[str] = ["main"]
75# ---------------------------------------------------------------------------#
76# Helpers - trivial in-process Pub/Sub #
77# ---------------------------------------------------------------------------#
78class _PubSub:
79 """Very small fan-out helper - one async Queue per subscriber."""
81 def __init__(self) -> None:
82 self._subscribers: List[asyncio.Queue[str]] = []
84 async def publish(self, data: str) -> None:
85 """Publish data to all subscribers.
87 Args:
88 data: The data string to publish to all subscribers.
89 """
90 dead: List[asyncio.Queue[str]] = []
91 for q in self._subscribers:
92 try:
93 q.put_nowait(data)
94 except asyncio.QueueFull:
95 dead.append(q)
96 for q in dead:
97 with suppress(ValueError):
98 self._subscribers.remove(q)
100 def subscribe(self) -> "asyncio.Queue[str]":
101 """Subscribe to published data.
103 Returns:
104 asyncio.Queue[str]: A queue that will receive published data.
105 """
106 q: asyncio.Queue[str] = asyncio.Queue(maxsize=1024)
107 self._subscribers.append(q)
108 return q
110 def unsubscribe(self, q: "asyncio.Queue[str]") -> None:
111 """Unsubscribe from published data.
113 Args:
114 q: The queue to unsubscribe from published data.
115 """
116 with suppress(ValueError):
117 self._subscribers.remove(q)
120# ---------------------------------------------------------------------------#
121# StdIO endpoint (child process ↔ async queues) #
122# ---------------------------------------------------------------------------#
class StdIOEndpoint:
    """Wrap a child process whose stdin/stdout speak line-delimited JSON-RPC.

    Text passed to :meth:`send` is written to the child's stdin. Every line
    the child writes to stdout is decoded and published, newline included, to
    the shared :class:`_PubSub`. The child's stderr is this process's stderr,
    so its diagnostics stay visible.
    """

    def __init__(self, cmd: str, pubsub: _PubSub) -> None:
        """Remember the command and the pub/sub sink; nothing is started yet.

        Args:
            cmd: Command line, split with :func:`shlex.split` and executed
                directly (no shell is involved).
            pubsub: Fan-out helper that receives every stdout line.
        """
        self._cmd = cmd
        self._pubsub = pubsub
        self._proc: Optional[asyncio.subprocess.Process] = None
        self._stdin: Optional[asyncio.StreamWriter] = None
        self._pump_task: Optional[asyncio.Task[None]] = None

    async def start(self) -> None:
        """Start the stdio subprocess.

        Creates the subprocess and starts the stdout pump task.
        """
        LOGGER.info("Starting stdio subprocess: %s", self._cmd)
        self._proc = await asyncio.create_subprocess_exec(
            *shlex.split(self._cmd),
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=sys.stderr,  # passthrough for visibility
        )
        # PIPE always yields stream objects; this only narrows the types.
        assert self._proc.stdin and self._proc.stdout
        self._stdin = self._proc.stdin
        self._pump_task = asyncio.create_task(self._pump_stdout())

    async def stop(self) -> None:
        """Stop the stdio subprocess and the stdout pump.

        Sends terminate() if the child is still running and waits up to 5 s
        for it to exit. If it has not exited by then, it is killed. Calling
        this again, or after the child has already exited, is safe.
        """
        if self._proc is None:
            return
        LOGGER.info("Stopping subprocess (pid=%s)", self._proc.pid)
        if self._proc.returncode is None:
            # asyncio raises ProcessLookupError once the child has exited and
            # its transport has closed; the child may also exit between the
            # returncode check and the signal.
            with suppress(ProcessLookupError):
                self._proc.terminate()
            try:
                await asyncio.wait_for(self._proc.wait(), timeout=5)
            except asyncio.TimeoutError:
                LOGGER.warning("Subprocess (pid=%s) did not exit after terminate(); killing it", self._proc.pid)
                with suppress(ProcessLookupError):
                    self._proc.kill()
                await self._proc.wait()
        if self._pump_task:
            self._pump_task.cancel()
            # Await the task so cancellation completes before returning. Any
            # exception it ended with was already logged by _pump_stdout.
            await asyncio.gather(self._pump_task, return_exceptions=True)

    async def send(self, raw: str) -> None:
        """Send data to the subprocess stdin.

        Args:
            raw: The raw data string to send to the subprocess.

        Raises:
            RuntimeError: If the stdio endpoint is not started.
        """
        if not self._stdin:
            raise RuntimeError("stdio endpoint not started")
        LOGGER.debug("→ stdio: %s", raw.strip())
        self._stdin.write(raw.encode())
        await self._stdin.drain()

    async def _pump_stdout(self) -> None:
        """Publish each stdout line of the child to the pub/sub until EOF.

        Lines are decoded with ``errors="replace"`` and published with their
        trailing newline intact.

        Raises:
            Exception: Any error while reading or publishing is logged and
                re-raised, which ends this pump task.
        """
        assert self._proc and self._proc.stdout
        reader = self._proc.stdout
        try:
            while line := await reader.readline():  # b"" signals EOF
                text = line.decode(errors="replace")
                LOGGER.debug("← stdio: %s", text.strip())
                await self._pubsub.publish(text)
        except Exception:  # pragma: no cover --best-effort logging
            LOGGER.exception("stdout pump crashed - terminating bridge")
            raise
202# ---------------------------------------------------------------------------#
203# FastAPI app exposing /sse & /message #
204# ---------------------------------------------------------------------------#
def _to_stdio_line(payload: bytes) -> str:
    """Validate an HTTP request body as JSON and frame it as exactly one stdio line.

    Args:
        payload: Raw ``POST`` body.

    Returns:
        str: The decoded JSON text, with any leading UTF-8 BOM removed, every
        raw CR/LF replaced by a space, trailing whitespace stripped, and a
        single ``"\\n"`` appended.

    Raises:
        ValueError: If the body is not UTF-8 or is not valid JSON
            (``UnicodeDecodeError`` and ``json.JSONDecodeError`` both derive
            from ``ValueError``).
        RecursionError: If the JSON is nested too deeply to parse.
    """
    text = payload.decode("utf-8-sig")
    json.loads(text)  # validation only - the client's own text is forwarded
    # Strict json.loads rejects raw control characters inside string literals,
    # so any CR/LF left in a valid document is insignificant whitespace between
    # tokens. Collapsing it keeps a pretty-printed request from being split
    # into several lines on the child's line-delimited stdin.
    return text.replace("\r", " ").replace("\n", " ").rstrip() + "\n"


def _build_fastapi(
    pubsub: _PubSub,
    stdio: StdIOEndpoint,
    keep_alive: int = KEEP_ALIVE_INTERVAL,
    sse_path: str = "/sse",
    message_path: str = "/message",
    cors_origins: Optional[List[str]] = None,
) -> FastAPI:
    """Build FastAPI application with SSE and message endpoints.

    Every SSE client subscribes to the same pub/sub, so each client receives
    every line the child prints. The per-client ``session_id`` is advertised
    in the endpoint URL but is not used for routing.

    Args:
        pubsub: The publish/subscribe system for message routing.
        stdio: The stdio endpoint for subprocess communication.
        keep_alive: Interval in seconds for keepalive messages. Defaults to KEEP_ALIVE_INTERVAL.
        sse_path: Path for the SSE endpoint. Defaults to "/sse".
        message_path: Path for the message endpoint. Defaults to "/message".
        cors_origins: Optional list of CORS allowed origins.

    Returns:
        FastAPI: The configured FastAPI application.
    """
    app = FastAPI()
    # SSE ``retry`` must be an integer number of milliseconds. Computing it once
    # keeps every frame consistent even when keep_alive is a float.
    retry_ms = int(keep_alive * 1000)

    # Add CORS middleware if origins specified
    if cors_origins:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # ----- GET /sse ---------------------------------------------------------#
    @app.get(sse_path)
    async def get_sse(request: Request) -> EventSourceResponse:  # noqa: D401
        """Stream subprocess stdout to any number of SSE clients.

        Args:
            request (Request): The incoming ``GET`` request that will be
                upgraded to a Server-Sent Events (SSE) stream.

        Returns:
            EventSourceResponse: A streaming response that forwards JSON-RPC
            messages from the child process and emits periodic ``keepalive``
            frames so that clients and proxies do not time out.
        """
        session_id = uuid.uuid4().hex

        async def event_gen() -> AsyncIterator[Dict[str, Any]]:
            # Subscribe inside the generator. The ``finally`` below then always
            # pairs with it, even if the stream is never consumed.
            queue = pubsub.subscribe()
            try:
                # 1. Mandatory "endpoint" bootstrap required by the MCP spec
                endpoint_url = f"{str(request.base_url).rstrip('/')}{message_path}?session_id={session_id}"
                yield {"event": "endpoint", "data": endpoint_url, "retry": retry_ms}

                # 2. Immediate keepalive so clients know the stream is alive
                yield {"event": "keepalive", "data": "{}", "retry": retry_ms}

                while not await request.is_disconnected():
                    try:
                        msg = await asyncio.wait_for(queue.get(), keep_alive)
                    except asyncio.TimeoutError:
                        yield {"event": "keepalive", "data": "{}", "retry": retry_ms}
                    else:
                        yield {"event": "message", "data": msg.rstrip()}
            finally:
                pubsub.unsubscribe(queue)

        return EventSourceResponse(
            event_gen(),
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",  # disable proxy buffering
            },
        )

    # ----- POST /message ----------------------------------------------------#
    @app.post(message_path, status_code=status.HTTP_202_ACCEPTED)
    async def post_message(raw: Request, session_id: str | None = None) -> Response:  # noqa: D401
        """Forward a raw JSON-RPC request to the stdio subprocess.

        Args:
            raw (Request): The incoming ``POST`` request whose body contains
                a single JSON-RPC message.
            session_id (str | None): The SSE session identifier that originated
                this back-channel call (present when the client obtained the
                endpoint URL from an ``endpoint`` bootstrap frame).

        Returns:
            Response: ``202 Accepted`` if the payload is forwarded successfully,
            or ``400 Bad Request`` when the body is not valid UTF-8 JSON.
        """
        _ = session_id  # Unused but required for API compatibility
        payload = await raw.body()
        try:
            line = _to_stdio_line(payload)
        except (ValueError, RecursionError) as exc:
            return PlainTextResponse(
                f"Invalid JSON payload: {exc}",
                status_code=status.HTTP_400_BAD_REQUEST,
            )
        await stdio.send(line)
        return PlainTextResponse("forwarded", status_code=status.HTTP_202_ACCEPTED)

    # ----- Liveness ---------------------------------------------------------#
    @app.get("/healthz")
    async def health() -> Response:  # noqa: D401
        """Health check endpoint.

        Returns:
            Response: A plain text response with "ok" status.
        """
        return PlainTextResponse("ok")

    return app
336# ---------------------------------------------------------------------------#
337# CLI & orchestration #
338# ---------------------------------------------------------------------------#
341def _parse_args(argv: Sequence[str]) -> argparse.Namespace:
342 """Parse command line arguments.
344 Args:
345 argv: Sequence of command line arguments.
347 Returns:
348 argparse.Namespace: Parsed command line arguments.
350 Raises:
351 NotImplementedError: If streamableHttp option is specified.
352 """
353 p = argparse.ArgumentParser(
354 prog="mcpgateway.translate",
355 description="Bridges stdio JSON-RPC to SSE or SSE to stdio.",
356 )
357 src = p.add_mutually_exclusive_group(required=True)
358 src.add_argument("--stdio", help='Command to run, e.g. "uv run mcp-server-git"')
359 src.add_argument("--sse", help="Remote SSE endpoint URL")
360 src.add_argument("--streamableHttp", help="[NOT IMPLEMENTED]")
362 p.add_argument("--port", type=int, default=8000, help="HTTP port to bind")
363 p.add_argument(
364 "--logLevel",
365 default="info",
366 choices=["debug", "info", "warning", "error", "critical"],
367 help="Log level",
368 )
369 p.add_argument(
370 "--cors",
371 nargs="*",
372 help="CORS allowed origins (e.g., --cors https://app.example.com)",
373 )
374 p.add_argument(
375 "--oauth2Bearer",
376 help="OAuth2 Bearer token for authentication",
377 )
379 args = p.parse_args(argv)
380 if args.streamableHttp:
381 raise NotImplementedError("Only --stdio → SSE and --sse → stdio are available in this build.")
382 return args
async def _run_stdio_to_sse(cmd: str, port: int, log_level: str = "info", cors: Optional[List[str]] = None) -> None:
    """Run stdio to SSE bridge.

    Starts the child process, serves the SSE/message app with uvicorn on
    ``0.0.0.0:<port>``, and always stops the child when serving ends, whether
    normally, by signal, or by exception.

    Args:
        cmd: The command to run as a stdio subprocess.
        port: The port to bind the HTTP server to.
        log_level: The logging level to use. Defaults to "info".
        cors: Optional list of CORS allowed origins.
    """
    pubsub = _PubSub()
    stdio = StdIOEndpoint(cmd, pubsub)
    await stdio.start()

    try:
        app = _build_fastapi(pubsub, stdio, cors_origins=cors)
        config = uvicorn.Config(
            app,
            host="0.0.0.0",
            port=port,
            log_level=log_level,
            lifespan="off",
        )
        server = uvicorn.Server(config)

        def _request_shutdown() -> None:
            # Runs on the event loop. It only raises uvicorn's exit flag, so
            # serve() leaves its main loop and performs its own graceful
            # shutdown exactly once. Setting the flag repeatedly is harmless.
            server.should_exit = True

        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            with suppress(NotImplementedError):  # Windows lacks add_signal_handler
                loop.add_signal_handler(sig, _request_shutdown)

        LOGGER.info("Bridge ready → http://127.0.0.1:%s/sse", port)
        await server.serve()
    finally:
        # Final cleanup. This runs even if app construction or serve() raises,
        # so the child process is never left behind.
        LOGGER.info("Shutting down ...")
        await stdio.stop()
428async def _run_sse_to_stdio(url: str, oauth2_bearer: Optional[str]) -> None:
429 """Run SSE to stdio bridge.
431 Args:
432 url: The SSE endpoint URL to connect to.
433 oauth2_bearer: Optional OAuth2 bearer token for authentication.
435 Raises:
436 ImportError: If httpx package is not available.
437 """
438 if not httpx: 438 ↛ 439line 438 didn't jump to line 439 because the condition on line 438 was never true
439 raise ImportError("httpx package is required for SSE to stdio bridging")
441 headers = {}
442 if oauth2_bearer:
443 headers["Authorization"] = f"Bearer {oauth2_bearer}"
445 async with httpx.AsyncClient(headers=headers, timeout=None) as client:
446 process = await asyncio.create_subprocess_shell(
447 "cat", # Placeholder command; replace with actual stdio server command if needed
448 stdin=asyncio.subprocess.PIPE,
449 stdout=asyncio.subprocess.PIPE,
450 stderr=sys.stderr,
451 )
453 async def read_stdout() -> None:
454 assert process.stdout
455 while True:
456 line = await process.stdout.readline()
457 if not line: 457 ↛ 459line 457 didn't jump to line 459 because the condition on line 457 was always true
458 break
459 print(line.decode().rstrip())
461 async def pump_sse_to_stdio() -> None:
462 async with client.stream("GET", url) as response:
463 async for line in response.aiter_lines():
464 if line.startswith("data: "): 464 ↛ 465line 464 didn't jump to line 465 because the condition on line 464 was never true
465 data = line[6:]
466 if data and data != "{}":
467 if process.stdin:
468 process.stdin.write((data + "\n").encode())
469 await process.stdin.drain()
471 await asyncio.gather(read_stdout(), pump_sse_to_stdio())
def start_stdio(cmd: str, port: int, log_level: str, cors: Optional[List[str]]) -> None:
    """Run the stdio → SSE bridge to completion on a fresh event loop.

    Args:
        cmd: The command to run as a stdio subprocess.
        port: The port to bind the HTTP server to.
        log_level: The logging level to use.
        cors: Optional list of CORS allowed origins.

    Returns:
        None: Blocks until the bridge coroutine finishes.
    """
    bridge = _run_stdio_to_sse(cmd, port, log_level, cors)
    return asyncio.run(bridge)
def start_sse(url: str, bearer: Optional[str]) -> None:
    """Run the SSE → stdio bridge to completion on a fresh event loop.

    Args:
        url: The SSE endpoint URL to connect to.
        bearer: Optional OAuth2 bearer token for authentication.

    Returns:
        None: Blocks until the bridge coroutine finishes.
    """
    bridge = _run_sse_to_stdio(url, bearer)
    return asyncio.run(bridge)
def main(argv: Optional[Sequence[str]] = None) -> None:
    """Entry point for the translate module.

    Parses the command line, configures root logging, and runs the selected
    bridge until it exits.

    Args:
        argv: Command-line arguments without the program name. ``None`` (the
            default) means ``sys.argv[1:]``; an explicit empty sequence is
            parsed as given.

    Raises:
        SystemExit: Exit status 0 after Ctrl-C, 1 when an unimplemented mode
            (``--streamableHttp``) is requested, or argparse's own status for
            invalid arguments.
    """
    try:
        # Parsing must happen inside the ``try``: _parse_args reports an
        # unimplemented mode with NotImplementedError, which is handled below.
        args = _parse_args(sys.argv[1:] if argv is None else argv)
        logging.basicConfig(
            level=getattr(logging, args.logLevel.upper(), logging.INFO),
            format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        )
        if args.stdio:
            start_stdio(args.stdio, args.port, args.logLevel, args.cors)
        elif args.sse:
            start_sse(args.sse, args.oauth2Bearer)
    except KeyboardInterrupt:
        print("")  # restore shell prompt
        sys.exit(0)
    except NotImplementedError as exc:
        print(exc, file=sys.stderr)
        sys.exit(1)
# Allow ``python -m mcpgateway.translate ...`` in addition to the console script.
if __name__ == "__main__":
    main()