Coverage for mcpgateway/transports/sse_transport.py: 95%
77 statements
coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
1# -*- coding: utf-8 -*-
2"""SSE Transport Implementation.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8This module implements Server-Sent Events (SSE) transport for MCP,
9providing server-to-client streaming with proper session management.
10"""
# Standard
import asyncio
from datetime import datetime
import json
import logging
from typing import Any, AsyncGenerator, Dict, Optional
import uuid

# Third-Party
from fastapi import Request
from sse_starlette.sse import EventSourceResponse

# First-Party
from mcpgateway.config import settings
from mcpgateway.transports.base import Transport
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
class SSETransport(Transport):
    """Transport implementation using Server-Sent Events with proper session management.

    Outgoing messages passed to :meth:`send_message` are buffered in an
    ``asyncio.Queue`` and streamed to the client by the event generator that
    :meth:`create_sse_response` builds. Client-to-server messages are not read
    by this class; they arrive via a separate HTTP POST endpoint whose URL is
    announced to the client in the initial ``endpoint`` event.
    """

    # Seconds to wait for an outgoing message before emitting a keepalive event.
    # Kept generous because some tools need a long time to execute.
    _KEEPALIVE_TIMEOUT: float = 30.0

    # strftime format applied to ``datetime`` values found in outgoing messages.
    _DATETIME_FORMAT: str = "%Y-%m-%d %H:%M:%S"

    def __init__(self, base_url: Optional[str] = None):
        """Initialize SSE transport.

        Args:
            base_url: Base URL for client message endpoints. When omitted (or
                empty), ``http://{settings.host}:{settings.port}`` is used.
        """
        self._base_url = base_url or f"http://{settings.host}:{settings.port}"
        self._connected = False
        self._message_queue = asyncio.Queue()
        # Set by disconnect(); the receive loop and the SSE event generator stop once it is set.
        self._client_gone = asyncio.Event()
        self._session_id = str(uuid.uuid4())

        logger.info(f"Creating SSE transport with base_url={self._base_url}, session_id={self._session_id}")

    async def connect(self) -> None:
        """Set up SSE connection by marking the transport as connected."""
        self._connected = True
        logger.info(f"SSE transport connected: {self._session_id}")

    async def disconnect(self) -> None:
        """Clean up SSE connection.

        Marks the transport as disconnected and signals ``_client_gone`` so that
        any running receive loop or event generator stops. Calling this on an
        already-disconnected transport is a no-op.
        """
        if self._connected:
            self._connected = False
            self._client_gone.set()
            logger.info(f"SSE transport disconnected: {self._session_id}")

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send a message over SSE by queueing it for the event generator.

        Args:
            message: Message to send

        Raises:
            RuntimeError: If transport is not connected
            Exception: If unable to put message to queue
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            await self._message_queue.put(message)
            logger.debug(f"Message queued for SSE: {self._session_id}, method={message.get('method', '(response)')}")
        except Exception as e:
            logger.error(f"Failed to queue message: {e}")
            raise

    async def receive_message(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Receive messages from the client over SSE transport.

        SSE is primarily a server-to-client channel: actual client messages are
        received via a separate HTTP POST endpoint (not handled here). This
        generator therefore yields a single initialize placeholder and then
        waits until the transport is disconnected.

        The method keeps running until either:
        1. The connection is explicitly disconnected (client_gone event is set)
        2. The receive loop is cancelled from outside

        Yields:
            Dict[str, Any]: JSON-RPC formatted messages. The first yielded message is always
            an initialize placeholder with the format:
            {"jsonrpc": "2.0", "method": "initialize", "id": 1}

        Raises:
            RuntimeError: If the transport is not connected when this method is called
            asyncio.CancelledError: When the SSE receive loop is cancelled externally
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        # Yield an initial initialize placeholder to keep the receive loop running;
        # real client messages are delivered via the POST endpoint.
        yield {"jsonrpc": "2.0", "method": "initialize", "id": 1}

        try:
            # Block until disconnect() sets the event (returns at once if already set),
            # instead of polling it on a timer.
            await self._client_gone.wait()
        except asyncio.CancelledError:
            logger.info(f"SSE receive loop cancelled for session {self._session_id}")
            raise
        finally:
            logger.info(f"SSE receive loop ended for session {self._session_id}")

    async def is_connected(self) -> bool:
        """Check if transport is connected.

        Returns:
            True if connected
        """
        return self._connected

    @staticmethod
    def _sse_event(event: str, data: str) -> Dict[str, Any]:
        """Build one SSE event payload in the shape expected by ``EventSourceResponse``.

        Args:
            event: SSE event name (e.g. ``"message"``, ``"keepalive"``).
            data: Already-serialized event data.

        Returns:
            Dict[str, Any]: Event dict with ``event``, ``data`` and ``retry`` keys,
            where ``retry`` comes from ``settings.sse_retry_timeout``.
        """
        return {"event": event, "data": data, "retry": settings.sse_retry_timeout}

    @classmethod
    def _json_default(cls, obj: Any) -> str:
        """Serialize values that the standard JSON encoder cannot handle.

        Used as the ``default`` hook of ``json.dumps``.

        Args:
            obj: Object ``json.dumps`` could not encode natively.

        Returns:
            str: ``obj`` formatted with ``_DATETIME_FORMAT`` when it is a ``datetime``.

        Raises:
            TypeError: For any other type, as the ``default`` hook contract requires.
        """
        if isinstance(obj, datetime):
            return obj.strftime(cls._DATETIME_FORMAT)
        # Must raise, not return: a returned exception object would be handed back
        # to the encoder, which would call this hook on it again, endlessly.
        raise TypeError(f"Type not serializable: {type(obj).__name__}")

    async def create_sse_response(self, _request: Request) -> EventSourceResponse:
        """Create SSE response for streaming.

        Args:
            _request: FastAPI request (unused)

        Returns:
            SSE response object
        """
        endpoint_url = f"{self._base_url}/message?session_id={self._session_id}"

        async def event_generator():
            """Generate SSE events for this session.

            Yields:
                dict: SSE event with ``event``, ``data`` and ``retry`` keys.

            Raises:
                asyncio.CancelledError: Re-raised when the streaming task is cancelled.
            """
            # Announce where the client must POST its messages, then send a
            # keepalive immediately to help establish the connection.
            yield self._sse_event("endpoint", endpoint_url)
            yield self._sse_event("keepalive", "{}")

            try:
                while not self._client_gone.is_set():
                    try:
                        message = await asyncio.wait_for(self._message_queue.get(), timeout=self._KEEPALIVE_TIMEOUT)
                    except asyncio.TimeoutError:
                        # Nothing to send within the timeout: keep the connection alive.
                        yield self._sse_event("keepalive", "{}")
                        continue

                    try:
                        data = json.dumps(message, default=self._json_default)
                    except Exception as e:
                        # A message that cannot be encoded is reported to the client; streaming continues.
                        logger.error(f"Error processing SSE message: {e}")
                        yield self._sse_event("error", json.dumps({"error": str(e)}))
                        continue

                    logger.debug(f"Sending SSE message: {data}")
                    yield self._sse_event("message", data)
            except asyncio.CancelledError:
                logger.info(f"SSE event generator cancelled: {self._session_id}")
                # Propagate so the cancellation actually takes effect for the caller.
                raise
            except Exception as e:
                logger.exception(f"SSE event generator error: {e}")
            finally:
                logger.info(f"SSE event generator completed: {self._session_id}")
                # We intentionally don't set client_gone here to allow queued messages to be processed

        return EventSourceResponse(
            event_generator(),
            status_code=200,
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream",
                "X-MCP-SSE": "true",
            },
        )

    async def _client_disconnected(self, _request: Request) -> bool:
        """Check if client has disconnected.

        Args:
            _request: FastAPI Request object (unused)

        Returns:
            bool: True if client disconnected
        """
        # Only the internal client_gone flag is consulted. connection_lost on the
        # request is deliberately ignored: it can be unreliable and cause
        # premature closures.
        return self._client_gone.is_set()

    @property
    def session_id(self) -> str:
        """
        Get the session ID for this transport.

        Returns:
            str: session_id
        """
        return self._session_id