Coverage for mcpgateway/translate.py: 83%

198 statements  

coverage.py v7.9.2, created at 2025-07-09 11:03 +0100

1# -*- coding: utf-8 -*- 

2""" mcpgateway.translate - bridges local JSON-RPC/stdio servers to HTTP/SSE 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti, Manav Gupta 

7 

8You can now run the bridge in either direction: 

9 

10- stdio to SSE (expose local stdio MCP server over SSE) 

11- SSE to stdio (bridge remote SSE endpoint to local stdio) 

12 

13 

14Usage 

15----- 

16# 1. expose an MCP server that talks JSON-RPC on stdio at :9000/sse 

17python -m mcpgateway.translate --stdio "uvx mcp-server-git" --port 9000 

18 

19# 2. from another shell / browser subscribe to the SSE stream 

20curl -N http://localhost:9000/sse # receive the stream 

21 

22# 3. send a test echo request 

23curl -X POST http://localhost:9000/message \\ 

24 -H 'Content-Type: application/json' \\ 

25 -d '{"jsonrpc":"2.0","id":1,"method":"echo","params":{"value":"hi"}}' 

26 

27# 4. proper MCP handshake and tool listing 

28curl -X POST http://localhost:9000/message \\ 

29 -H 'Content-Type: application/json' \\ 

30 -d '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"demo","version":"0.0.1"}}}' 

31 

32curl -X POST http://localhost:9000/message \\ 

33 -H 'Content-Type: application/json' \\ 

34 -d '{"jsonrpc":"2.0","id":2,"method":"tools/list"}' 

35 

36The SSE stream emits JSON-RPC responses as `event: message` frames and sends 

37regular `event: keepalive` frames (default every 30 s) so that proxies and 

38clients do not time out. Each client receives a unique *session-id* that is 

39 appended as a query parameter to the back-channel `/message` URL. 
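
A typical stream therefore looks roughly like this (illustrative; the exact
host, session id and payloads will differ):

    event: endpoint
    data: http://localhost:9000/message?session_id=<32-char hex>

    event: keepalive
    data: {}

    event: message
    data: {"jsonrpc":"2.0","id":1,"result":{...}}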

40""" 

41 

42# Future 

43from __future__ import annotations 

44 

45# Standard 

46import argparse 

47import asyncio 

48from contextlib import suppress 

49import json 

50import logging 

51import shlex 

52import signal 

53import sys 

54from typing import Any, AsyncIterator, Dict, List, Optional, Sequence 

55import uuid 

56 

57# Third-Party 

58from fastapi import FastAPI, Request, Response, status 

59from fastapi.middleware.cors import CORSMiddleware 

60from fastapi.responses import PlainTextResponse 

61from sse_starlette.sse import EventSourceResponse 

62import uvicorn 

63 

64try: 

65 # Third-Party 

66 import httpx 

67except ImportError: 

68 httpx = None # type: ignore[assignment] 

69 

70LOGGER = logging.getLogger("mcpgateway.translate") 

71KEEP_ALIVE_INTERVAL = 30 # seconds - matches the reference implementation 

72__all__ = ["main"] # for console-script entry-point 

73 

74 

75# ---------------------------------------------------------------------------# 

76# Helpers - trivial in-process Pub/Sub # 

77# ---------------------------------------------------------------------------# 

78class _PubSub: 

79 """Very small fan-out helper - one async Queue per subscriber.""" 

80 

81 def __init__(self) -> None: 

82 self._subscribers: List[asyncio.Queue[str]] = [] 

83 

84 async def publish(self, data: str) -> None: 

85 """Publish data to all subscribers. 

86 

87 Args: 

88 data: The data string to publish to all subscribers. 

89 """ 

90 dead: List[asyncio.Queue[str]] = [] 

91 for q in self._subscribers: 

92 try: 

93 q.put_nowait(data) 

94 except asyncio.QueueFull: 

95 dead.append(q) 

96 for q in dead: 

97 with suppress(ValueError): 

98 self._subscribers.remove(q) 

99 

100 def subscribe(self) -> "asyncio.Queue[str]": 

101 """Subscribe to published data. 

102 

103 Returns: 

104 asyncio.Queue[str]: A queue that will receive published data. 

105 """ 

106 q: asyncio.Queue[str] = asyncio.Queue(maxsize=1024) 

107 self._subscribers.append(q) 

108 return q 

109 

110 def unsubscribe(self, q: "asyncio.Queue[str]") -> None: 

111 """Unsubscribe from published data. 

112 

113 Args: 

114 q: The queue to unsubscribe from published data. 

115 """ 

116 with suppress(ValueError): 

117 self._subscribers.remove(q) 

118 
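# A minimal sketch of how the bridge uses _PubSub (illustrative; names match
# the class above, and the bounded queue comes from subscribe()):
#
#     pubsub = _PubSub()
#     queue = pubsub.subscribe()          # one queue per SSE client
#     await pubsub.publish('{"jsonrpc":"2.0","id":1,"result":{}}\n')
#     line = await queue.get()            # the published JSON-RPC line
#     pubsub.unsubscribe(queue)           # called when the client disconnects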

119 

120# ---------------------------------------------------------------------------# 

121# StdIO endpoint (child process ↔ async queues) # 

122# ---------------------------------------------------------------------------# 

123class StdIOEndpoint: 

124 """Wrap a child process whose stdin/stdout speak line-delimited JSON-RPC.""" 

125 

126 def __init__(self, cmd: str, pubsub: _PubSub) -> None: 

127 self._cmd = cmd 

128 self._pubsub = pubsub 

129 self._proc: Optional[asyncio.subprocess.Process] = None 

130 self._stdin: Optional[asyncio.StreamWriter] = None 

131 self._pump_task: Optional[asyncio.Task[None]] = None 

132 

133 async def start(self) -> None: 

134 """Start the stdio subprocess. 

135 

136 Creates the subprocess and starts the stdout pump task. 

137 """ 

138 LOGGER.info("Starting stdio subprocess: %s", self._cmd) 

139 self._proc = await asyncio.create_subprocess_exec( 

140 *shlex.split(self._cmd), 

141 stdin=asyncio.subprocess.PIPE, 

142 stdout=asyncio.subprocess.PIPE, 

143 stderr=sys.stderr, # passthrough for visibility 

144 ) 

145 assert self._proc.stdin and self._proc.stdout 

146 self._stdin = self._proc.stdin 

147 self._pump_task = asyncio.create_task(self._pump_stdout()) 

148 

149 async def stop(self) -> None: 

150 """Stop the stdio subprocess. 

151 

152 Terminates the subprocess and cancels the pump task. 

153 """ 

154 if self._proc is None:  [154 ↛ 155] line 154 didn't jump to line 155 because the condition on line 154 was never true

155 return 

156 LOGGER.info("Stopping subprocess (pid=%s)", self._proc.pid) 

157 self._proc.terminate() 

158 with suppress(asyncio.TimeoutError): 

159 await asyncio.wait_for(self._proc.wait(), timeout=5) 

160 if self._pump_task:  [160 ↛ exit] line 160 didn't return from function 'stop' because the condition on line 160 was always true

161 self._pump_task.cancel() 

162 

163 async def send(self, raw: str) -> None: 

164 """Send data to the subprocess stdin. 

165 

166 Args: 

167 raw: The raw data string to send to the subprocess. 

168 

169 Raises: 

170 RuntimeError: If the stdio endpoint is not started. 

171 """ 

172 if not self._stdin: 

173 raise RuntimeError("stdio endpoint not started") 

174 LOGGER.debug("→ stdio: %s", raw.strip()) 

175 self._stdin.write(raw.encode()) 

176 await self._stdin.drain() 

177 

178 async def _pump_stdout(self) -> None: 

179 """Pump stdout from subprocess to pubsub. 

180 

181 Continuously reads lines from the subprocess stdout and publishes them 

182 to the pubsub system. 

183 

184 Raises: 

185 Exception: For any other error encountered while pumping stdout. 

186 """ 

187 assert self._proc and self._proc.stdout 

188 reader = self._proc.stdout 

189 try: 

190 while True: 

191 line = await reader.readline() 

192 if not line: # EOF 

193 break 

194 text = line.decode(errors="replace") 

195 LOGGER.debug("← stdio: %s", text.strip()) 

196 await self._pubsub.publish(text) 

197 except Exception: # pragma: no cover --best-effort logging 

198 LOGGER.exception("stdout pump crashed - terminating bridge") 

199 raise 

200 
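# A minimal sketch of the StdIOEndpoint lifecycle (the command string is
# illustrative; any line-delimited JSON-RPC server on stdio works):
#
#     pubsub = _PubSub()
#     stdio = StdIOEndpoint("uvx mcp-server-git", pubsub)
#     await stdio.start()                 # spawn the child and the stdout pump
#     await stdio.send('{"jsonrpc":"2.0","id":1,"method":"tools/list"}\n')
#     # responses surface via pubsub subscribers (see get_sse below)
#     await stdio.stop()                  # terminate the child, cancel the pump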

201 

202# ---------------------------------------------------------------------------# 

203# FastAPI app exposing /sse & /message # 

204# ---------------------------------------------------------------------------# 

205 

206 

207def _build_fastapi( 

208 pubsub: _PubSub, 

209 stdio: StdIOEndpoint, 

210 keep_alive: int = KEEP_ALIVE_INTERVAL, 

211 sse_path: str = "/sse", 

212 message_path: str = "/message", 

213 cors_origins: Optional[List[str]] = None, 

214) -> FastAPI: 

215 """Build FastAPI application with SSE and message endpoints. 

216 

217 Args: 

218 pubsub: The publish/subscribe system for message routing. 

219 stdio: The stdio endpoint for subprocess communication. 

220 keep_alive: Interval in seconds for keepalive messages. Defaults to KEEP_ALIVE_INTERVAL. 

221 sse_path: Path for the SSE endpoint. Defaults to "/sse". 

222 message_path: Path for the message endpoint. Defaults to "/message". 

223 cors_origins: Optional list of CORS allowed origins. 

224 

225 Returns: 

226 FastAPI: The configured FastAPI application. 

227 """ 

228 app = FastAPI() 

229 

230 # Add CORS middleware if origins specified 

231 if cors_origins: 

232 app.add_middleware( 

233 CORSMiddleware, 

234 allow_origins=cors_origins, 

235 allow_credentials=True, 

236 allow_methods=["*"], 

237 allow_headers=["*"], 

238 ) 

239 

240 # ----- GET /sse ---------------------------------------------------------# 

241 @app.get(sse_path) 

242 async def get_sse(request: Request) -> EventSourceResponse: # noqa: D401 

243 """Stream subprocess stdout to any number of SSE clients. 

244 

245 Args: 

246 request (Request): The incoming ``GET`` request that will be 

247 upgraded to a Server-Sent Events (SSE) stream. 

248 

249 Returns: 

250 EventSourceResponse: A streaming response that forwards JSON-RPC 

251 messages from the child process and emits periodic ``keepalive`` 

252 frames so that clients and proxies do not time out. 

253 """ 

254 queue = pubsub.subscribe() 

255 session_id = uuid.uuid4().hex 

256 

257 async def event_gen() -> AsyncIterator[Dict[str, Any]]: 

258 # 1️⃣ Mandatory "endpoint" bootstrap required by the MCP spec 

259 endpoint_url = f"{str(request.base_url).rstrip('/')}{message_path}?session_id={session_id}" 

260 yield { 

261 "event": "endpoint", 

262 "data": endpoint_url, 

263 "retry": int(keep_alive * 1000), 

264 } 

265 

266 # 2️⃣ Immediate keepalive so clients know the stream is alive 

267 yield {"event": "keepalive", "data": "{}", "retry": keep_alive * 1000} 

268 

269 try: 

270 while True: 

271 if await request.is_disconnected(): 

272 break 

273 

274 try: 

275 msg = await asyncio.wait_for(queue.get(), keep_alive) 

276 yield {"event": "message", "data": msg.rstrip()} 

277 except asyncio.TimeoutError: 

278 yield { 

279 "event": "keepalive", 

280 "data": "{}", 

281 "retry": keep_alive * 1000, 

282 } 

283 finally: 

284 pubsub.unsubscribe(queue) 

285 

286 return EventSourceResponse( 

287 event_gen(), 

288 headers={ 

289 "Cache-Control": "no-cache", 

290 "Connection": "keep-alive", 

291 "X-Accel-Buffering": "no", # disable proxy buffering 

292 }, 

293 ) 

294 

295 # ----- POST /message ----------------------------------------------------# 

296 @app.post(message_path, status_code=status.HTTP_202_ACCEPTED) 

297 async def post_message(raw: Request, session_id: str | None = None) -> Response: # noqa: D401 

298 """Forward a raw JSON-RPC request to the stdio subprocess. 

299 

300 Args: 

301 raw (Request): The incoming ``POST`` request whose body contains 

302 a single JSON-RPC message. 

303 session_id (str | None): The SSE session identifier that originated 

304 this back-channel call (present when the client obtained the 

305 endpoint URL from an ``endpoint`` bootstrap frame). 

306 

307 Returns: 

308 Response: ``202 Accepted`` if the payload is forwarded successfully, 

309 or ``400 Bad Request`` when the body is not valid JSON. 

310 """ 

311 _ = session_id # Unused but required for API compatibility 

312 payload = await raw.body() 

313 try: 

314 json.loads(payload) # validate 

315 except Exception as exc: # noqa: BLE001 

316 return PlainTextResponse( 

317 f"Invalid JSON payload: {exc}", 

318 status_code=status.HTTP_400_BAD_REQUEST, 

319 ) 

320 await stdio.send(payload.decode().rstrip() + "\n") 

321 return PlainTextResponse("forwarded", status_code=status.HTTP_202_ACCEPTED) 

322 

323 # ----- Liveness ---------------------------------------------------------# 

324 @app.get("/healthz") 

325 async def health() -> Response: # noqa: D401 

326 """Health check endpoint. 

327 

328 Returns: 

329 Response: A plain text response with "ok" status. 

330 """ 

331 return PlainTextResponse("ok") 

332 

333 return app 

334 
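# Note on the request flow implemented above: POST /message only acknowledges
# receipt ("202 Accepted" with body "forwarded", or "400 Bad Request" for
# invalid JSON); the JSON-RPC response itself arrives asynchronously on the
# /sse stream as an `event: message` frame. The session_id query parameter is
# accepted but not used for routing.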

335 

336# ---------------------------------------------------------------------------# 

337# CLI & orchestration # 

338# ---------------------------------------------------------------------------# 

339 

340 

341def _parse_args(argv: Sequence[str]) -> argparse.Namespace: 

342 """Parse command line arguments. 

343 

344 Args: 

345 argv: Sequence of command line arguments. 

346 

347 Returns: 

348 argparse.Namespace: Parsed command line arguments. 

349 

350 Raises: 

351 NotImplementedError: If streamableHttp option is specified. 

352 """ 

353 p = argparse.ArgumentParser( 

354 prog="mcpgateway.translate", 

355 description="Bridges stdio JSON-RPC to SSE or SSE to stdio.", 

356 ) 

357 src = p.add_mutually_exclusive_group(required=True) 

358 src.add_argument("--stdio", help='Command to run, e.g. "uv run mcp-server-git"') 

359 src.add_argument("--sse", help="Remote SSE endpoint URL") 

360 src.add_argument("--streamableHttp", help="[NOT IMPLEMENTED]") 

361 

362 p.add_argument("--port", type=int, default=8000, help="HTTP port to bind") 

363 p.add_argument( 

364 "--logLevel", 

365 default="info", 

366 choices=["debug", "info", "warning", "error", "critical"], 

367 help="Log level", 

368 ) 

369 p.add_argument( 

370 "--cors", 

371 nargs="*", 

372 help="CORS allowed origins (e.g., --cors https://app.example.com)", 

373 ) 

374 p.add_argument( 

375 "--oauth2Bearer", 

376 help="OAuth2 Bearer token for authentication", 

377 ) 

378 

379 args = p.parse_args(argv) 

380 if args.streamableHttp: 

381 raise NotImplementedError("Only --stdio → SSE and --sse → stdio are available in this build.") 

382 return args 

383 
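# Example invocations accepted by _parse_args (values are illustrative):
#
#     python -m mcpgateway.translate --stdio "uvx mcp-server-git" --port 9000
#     python -m mcpgateway.translate --sse http://localhost:9000/sse \
#         --oauth2Bearer "$TOKEN" --logLevel debug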

384 

385async def _run_stdio_to_sse(cmd: str, port: int, log_level: str = "info", cors: Optional[List[str]] = None) -> None: 

386 """Run stdio to SSE bridge. 

387 

388 Args: 

389 cmd: The command to run as a stdio subprocess. 

390 port: The port to bind the HTTP server to. 

391 log_level: The logging level to use. Defaults to "info". 

392 cors: Optional list of CORS allowed origins. 

393 """ 

394 pubsub = _PubSub() 

395 stdio = StdIOEndpoint(cmd, pubsub) 

396 await stdio.start() 

397 

398 app = _build_fastapi(pubsub, stdio, cors_origins=cors) 

399 config = uvicorn.Config( 

400 app, 

401 host="0.0.0.0", 

402 port=port, 

403 log_level=log_level, 

404 lifespan="off", 

405 ) 

406 server = uvicorn.Server(config) 

407 

408 shutting_down = asyncio.Event() # 🔄 make shutdown idempotent 

409 

410 async def _shutdown() -> None: 

411 if shutting_down.is_set():  [411 ↛ 412] line 411 didn't jump to line 412 because the condition on line 411 was never true

412 return 

413 shutting_down.set() 

414 LOGGER.info("Shutting down ...") 

415 await stdio.stop() 

416 await server.shutdown() 

417 

418 loop = asyncio.get_running_loop() 

419 for sig in (signal.SIGINT, signal.SIGTERM): 

420 with suppress(NotImplementedError): # Windows lacks add_signal_handler 

421 loop.add_signal_handler(sig, lambda: asyncio.create_task(_shutdown())) 

422 

423 LOGGER.info("Bridge ready → http://127.0.0.1:%s/sse", port) 

424 await server.serve() 

425 await _shutdown() # final cleanup 

426 

427 

428async def _run_sse_to_stdio(url: str, oauth2_bearer: Optional[str]) -> None: 

429 """Run SSE to stdio bridge. 

430 

431 Args: 

432 url: The SSE endpoint URL to connect to. 

433 oauth2_bearer: Optional OAuth2 bearer token for authentication. 

434 

435 Raises: 

436 ImportError: If httpx package is not available. 

437 """ 

438 if not httpx:  [438 ↛ 439] line 438 didn't jump to line 439 because the condition on line 438 was never true

439 raise ImportError("httpx package is required for SSE to stdio bridging") 

440 

441 headers = {} 

442 if oauth2_bearer: 

443 headers["Authorization"] = f"Bearer {oauth2_bearer}" 

444 

445 async with httpx.AsyncClient(headers=headers, timeout=None) as client: 

446 process = await asyncio.create_subprocess_shell( 

447 "cat", # Placeholder command; replace with actual stdio server command if needed 

448 stdin=asyncio.subprocess.PIPE, 

449 stdout=asyncio.subprocess.PIPE, 

450 stderr=sys.stderr, 

451 ) 

452 

453 async def read_stdout() -> None: 

454 assert process.stdout 

455 while True: 

456 line = await process.stdout.readline() 

457 if not line:  [457 ↛ 459] line 457 didn't jump to line 459 because the condition on line 457 was always true

458 break 

459 print(line.decode().rstrip()) 

460 

461 async def pump_sse_to_stdio() -> None: 

462 async with client.stream("GET", url) as response: 

463 async for line in response.aiter_lines(): 

464 if line.startswith("data: "):  [464 ↛ 465] line 464 didn't jump to line 465 because the condition on line 464 was never true

465 data = line[6:] 

466 if data and data != "{}": 

467 if process.stdin: 

468 process.stdin.write((data + "\n").encode()) 

469 await process.stdin.drain() 

470 

471 await asyncio.gather(read_stdout(), pump_sse_to_stdio()) 

472 
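# As written, the reverse bridge forwards every non-empty SSE `data:` payload
# (keepalive "{}" frames are skipped) to the child's stdin and echoes the
# child's stdout to the console; "cat" above is only a placeholder for a real
# stdio MCP server command.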

473 

474def start_stdio(cmd: str, port: int, log_level: str, cors: Optional[List[str]]) -> None: 

475 """Start stdio bridge. 

476 

477 Args: 

478 cmd: The command to run as a stdio subprocess. 

479 port: The port to bind the HTTP server to. 

480 log_level: The logging level to use. 

481 cors: Optional list of CORS allowed origins. 

482 

483 Returns: 

484 None: This function does not return a value. 

485 """ 

486 return asyncio.run(_run_stdio_to_sse(cmd, port, log_level, cors)) 

487 

488 

489def start_sse(url: str, bearer: Optional[str]) -> None: 

490 """Start SSE bridge. 

491 

492 Args: 

493 url: The SSE endpoint URL to connect to. 

494 bearer: Optional OAuth2 bearer token for authentication. 

495 

496 Returns: 

497 None: This function does not return a value. 

498 """ 

499 return asyncio.run(_run_sse_to_stdio(url, bearer)) 

500 

501 

502def main(argv: Optional[Sequence[str]] | None = None) -> None: 

503 """Entry point for the translate module. 

504 

505 Args: 

506 argv: Optional sequence of command line arguments. If None, uses sys.argv[1:]. 

507 """ 

508 args = _parse_args(argv or sys.argv[1:]) 

509 logging.basicConfig( 

510 level=getattr(logging, args.logLevel.upper(), logging.INFO), 

511 format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", 

512 ) 

513 try: 

514 if args.stdio: 

515 start_stdio(args.stdio, args.port, args.logLevel, args.cors) 

516 elif args.sse:  [516 ↛ exit] line 516 didn't return from function 'main' because the condition on line 516 was always true

517 start_sse(args.sse, args.oauth2Bearer) 

518 except KeyboardInterrupt: 

519 print("") # restore shell prompt 

520 sys.exit(0) 

521 except NotImplementedError as exc: 

522 print(exc, file=sys.stderr) 

523 sys.exit(1) 

524 

525 

526if __name__ == "__main__": # python -m mcpgateway.translate ... 

527 main()