Coverage for mcpgateway/transports/sse_transport.py: 95%

77 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-09 11:03 +0100

1# -*- coding: utf-8 -*- 

2"""SSE Transport Implementation. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8This module implements Server-Sent Events (SSE) transport for MCP, 

9providing server-to-client streaming with proper session management. 

10""" 

11 

12# Standard 

13import asyncio 

14from datetime import datetime 

15import json 

16import logging 

17from typing import Any, AsyncGenerator, Dict 

18import uuid 

19 

20# Third-Party 

21from fastapi import Request 

22from sse_starlette.sse import EventSourceResponse 

23 

24# First-Party 

25from mcpgateway.config import settings 

26from mcpgateway.transports.base import Transport 

27 

logger = logging.getLogger(__name__)


class SSETransport(Transport):
    """Transport implementation using Server-Sent Events with proper session management.

    Outgoing (server-to-client) messages are placed on an in-memory queue by
    :meth:`send_message` and streamed to the client by the generator built in
    :meth:`create_sse_response`. Incoming client messages are not read from the
    SSE stream; they arrive via a separate HTTP POST endpoint whose URL is
    announced to the client in the first ``endpoint`` event.
    """

    def __init__(self, base_url: str = None):
        """Initialize SSE transport.

        Args:
            base_url: Base URL for client message endpoints. When falsy, defaults to
                ``http://{settings.host}:{settings.port}``.
        """
        self._base_url = base_url or f"http://{settings.host}:{settings.port}"
        self._connected = False
        # Outbound messages: filled by send_message(), drained by the SSE event generator.
        self._message_queue = asyncio.Queue()
        # Set by disconnect(); checked by the receive loop and the event generator to decide when to stop.
        self._client_gone = asyncio.Event()
        # Per-transport identifier, embedded in the advertised POST endpoint URL
        # (presumably used by the POST handler to route messages to this session).
        self._session_id = str(uuid.uuid4())

        logger.info("Creating SSE transport with base_url=%s, session_id=%s", self._base_url, self._session_id)

    async def connect(self) -> None:
        """Mark the SSE transport as connected."""
        self._connected = True
        logger.info("SSE transport connected: %s", self._session_id)

    async def disconnect(self) -> None:
        """Mark the transport as disconnected and signal waiting loops to stop.

        Calling this on a transport that is not connected is a no-op.
        """
        if self._connected:
            self._connected = False
            self._client_gone.set()
            logger.info("SSE transport disconnected: %s", self._session_id)

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Queue a message for delivery to the client over SSE.

        The message is serialized later, when the event generator dequeues it.

        Args:
            message: Message dict to send (e.g. a JSON-RPC request or response).

        Raises:
            RuntimeError: If transport is not connected
            Exception: If unable to put message to queue
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            await self._message_queue.put(message)
            logger.debug("Message queued for SSE: %s, method=%s", self._session_id, message.get("method", "(response)"))
        except Exception as e:
            logger.error("Failed to queue message: %s", e)
            raise

    async def receive_message(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Receive messages from the client over SSE transport.

        SSE is a server-to-client channel, so real client messages are delivered
        through a separate HTTP POST endpoint (not handled here). This generator
        yields a single initialize placeholder and then waits until the transport
        is disconnected (``client_gone`` is set) or the waiting task is cancelled.

        Yields:
            Dict[str, Any]: JSON-RPC formatted messages. The only yielded message is
            an initialize placeholder with the format:
            {"jsonrpc": "2.0", "method": "initialize", "id": 1}

        Raises:
            RuntimeError: If the transport is not connected when this method is called
            asyncio.CancelledError: When the SSE receive loop is cancelled externally
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        # Most messages come via the POST endpoint; this placeholder keeps the
        # caller's receive loop running.
        yield {"jsonrpc": "2.0", "method": "initialize", "id": 1}

        try:
            # Block until disconnect() sets the event (returns immediately if it is
            # already set). Awaiting the event replaces a 1-second polling loop.
            await self._client_gone.wait()
        except asyncio.CancelledError:
            logger.info("SSE receive loop cancelled for session %s", self._session_id)
            raise
        finally:
            logger.info("SSE receive loop ended for session %s", self._session_id)

    async def is_connected(self) -> bool:
        """Check if transport is connected.

        Returns:
            True if connected
        """
        return self._connected

    @staticmethod
    def _json_default(obj: Any) -> str:
        """Serialize values the ``json`` module cannot handle natively.

        Used as the ``default`` hook of :func:`json.dumps` when encoding queued messages.

        Args:
            obj: Object the JSON encoder could not serialize.

        Returns:
            str: ``obj`` formatted as ``"%Y-%m-%d %H:%M:%S"`` when it is a ``datetime``.

        Raises:
            TypeError: For any other type.
        """
        if isinstance(obj, datetime):
            return obj.strftime("%Y-%m-%d %H:%M:%S")
        # json.dumps requires `default` to RAISE TypeError. Returning an exception
        # instance makes the encoder try to serialize that object, call `default`
        # again, and recurse until RecursionError.
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    async def create_sse_response(self, _request: Request) -> EventSourceResponse:
        """Create SSE response for streaming.

        The stream starts with an ``endpoint`` event carrying the URL the client
        should POST messages to, followed by an immediate ``keepalive``. It then
        forwards queued messages as ``message`` events. A ``keepalive`` is emitted
        whenever no message arrives within 30 seconds. Streaming continues until the
        transport is disconnected or the response is cancelled.

        Args:
            _request: FastAPI request (unused).

        Returns:
            SSE response object
        """
        endpoint_url = f"{self._base_url}/message?session_id={self._session_id}"

        def _event(name: str, data: str) -> Dict[str, Any]:
            """Build one SSE event dict carrying the configured client retry interval."""
            return {"event": name, "data": data, "retry": settings.sse_retry_timeout}

        async def event_generator() -> AsyncGenerator[Dict[str, Any], None]:
            """Generate SSE events.

            Yields:
                SSE event dicts with ``event``, ``data`` and ``retry`` keys.

            Raises:
                asyncio.CancelledError: Re-raised after logging when the stream is cancelled.
            """
            # Send the endpoint event first so the client learns where to POST messages
            yield _event("endpoint", endpoint_url)

            # Send keepalive immediately to help establish connection
            yield _event("keepalive", "{}")

            try:
                # client_gone is only re-checked after a message is sent or a keepalive timeout.
                while not self._client_gone.is_set():
                    try:
                        # Wait for messages with a timeout for keepalives
                        # (30 seconds; some tools require more time for execution)
                        message = await asyncio.wait_for(self._message_queue.get(), timeout=30.0)
                        data = json.dumps(message, default=self._json_default)
                        logger.debug("Sending SSE message: %s", data)
                        yield _event("message", data)
                    except asyncio.TimeoutError:
                        # Nothing to send: emit a keepalive to hold the connection open
                        yield _event("keepalive", "{}")
                    except Exception as e:
                        # Report the failure to the client and keep streaming
                        logger.error("Error processing SSE message: %s", e)
                        yield _event("error", json.dumps({"error": str(e)}))
            except asyncio.CancelledError:
                logger.info("SSE event generator cancelled: %s", self._session_id)
                # Propagate cancellation so the surrounding task / cancel scope actually stops
                raise
            except Exception as e:
                logger.exception("SSE event generator error: %s", e)
            finally:
                logger.info("SSE event generator completed: %s", self._session_id)
                # We intentionally don't set client_gone here to allow queued messages to be processed

        return EventSourceResponse(
            event_generator(),
            status_code=200,
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream",
                "X-MCP-SSE": "true",
            },
        )

    async def _client_disconnected(self, _request: Request) -> bool:
        """Check if client has disconnected.

        Args:
            _request: FastAPI Request object (unused).

        Returns:
            bool: True if client disconnected
        """
        # Only the internal client_gone flag is consulted; connection_lost on the
        # request is intentionally ignored because it can be unreliable and cause
        # premature closures.
        return self._client_gone.is_set()

    @property
    def session_id(self) -> str:
        """Get the session ID for this transport.

        Returns:
            str: session_id
        """
        return self._session_id