Coverage for mcpgateway/transports/streamablehttp_transport.py: 95%

142 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-09 11:03 +0100

1# -*- coding: utf-8 -*- 

2"""Streamable HTTP Transport Implementation. 

3 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Keval Mahajan 

7 

8This module implements the Streamable HTTP transport for MCP.

9 

10Key components include: 

11- SessionManagerWrapper: Manages the lifecycle of streamable HTTP sessions 

12- Configuration options for: 

13 1. stateful/stateless operation 

14 2. JSON response mode or SSE streams 

15- InMemoryEventStore: A simple in-memory event storage system for maintaining session state 

16 

17""" 

18 

19# Standard 

20from collections import deque 

21from contextlib import asynccontextmanager, AsyncExitStack 

22import contextvars 

23from dataclasses import dataclass 

24import logging 

25import re 

26from typing import List, Union 

27from uuid import uuid4 

28 

29# Third-Party 

30from fastapi.security.utils import get_authorization_scheme_param 

31from mcp import types 

32from mcp.server.lowlevel import Server 

33from mcp.server.streamable_http import ( 

34 EventCallback, 

35 EventId, 

36 EventMessage, 

37 EventStore, 

38 StreamId, 

39) 

40from mcp.server.streamable_http_manager import StreamableHTTPSessionManager 

41from mcp.types import JSONRPCMessage 

42from starlette.datastructures import Headers 

43from starlette.responses import JSONResponse 

44from starlette.status import HTTP_401_UNAUTHORIZED 

45from starlette.types import Receive, Scope, Send 

46 

47# First-Party 

48from mcpgateway.config import settings 

49from mcpgateway.db import SessionLocal 

50from mcpgateway.services.tool_service import ToolService 

51from mcpgateway.utils.verify_credentials import verify_credentials 

52 

53logger = logging.getLogger(__name__) 

54logging.basicConfig(level=logging.INFO) 

55 

56# Initialize ToolService and MCP Server 

57tool_service = ToolService() 

58mcp_app = Server("mcp-streamable-http-stateless") 

59 

60server_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("server_id", default=None) 

61 

62# ------------------------------ Event store ------------------------------ 

63 

64 

65@dataclass 

66class EventEntry: 

67 """ 

68 Represents an event entry in the event store. 

69 """ 

70 

71 event_id: EventId 

72 stream_id: StreamId 

73 message: JSONRPCMessage 

74 

75 

76class InMemoryEventStore(EventStore): 

77 """ 

78 Simple in-memory implementation of the EventStore interface for resumability. 

79 This is primarily intended for examples and testing, not for production use 

80 where a persistent storage solution would be more appropriate. 

81 

82 This implementation keeps only the last N events per stream for memory efficiency. 

83 """ 

84 

85 def __init__(self, max_events_per_stream: int = 100): 

86 """Initialize the event store. 

87 

88 Args: 

89 max_events_per_stream: Maximum number of events to keep per stream 

90 """ 

91 self.max_events_per_stream = max_events_per_stream 

92 # for maintaining last N events per stream 

93 self.streams: dict[StreamId, deque[EventEntry]] = {} 

94 # event_id -> EventEntry for quick lookup 

95 self.event_index: dict[EventId, EventEntry] = {} 

96 

97 async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId: 

98 """ 

99 Stores an event with a generated event ID. 

100 

101 Args: 

102 stream_id (StreamId): The ID of the stream. 

103 message (JSONRPCMessage): The message to store. 

104 

105 Returns: 

106 EventId: The ID of the stored event. 

107 """ 

108 event_id = str(uuid4()) 

109 event_entry = EventEntry(event_id=event_id, stream_id=stream_id, message=message) 

110 

111 # Get or create deque for this stream 

112 if stream_id not in self.streams: 

113 self.streams[stream_id] = deque(maxlen=self.max_events_per_stream) 

114 

115 # If deque is full, the oldest event will be automatically removed 

116 # We need to remove it from the event_index as well 

117 if len(self.streams[stream_id]) == self.max_events_per_stream: 

118 oldest_event = self.streams[stream_id][0] 

119 self.event_index.pop(oldest_event.event_id, None) 

120 

121 # Add new event 

122 self.streams[stream_id].append(event_entry) 

123 self.event_index[event_id] = event_entry 

124 

125 return event_id 

126 

127 async def replay_events_after( 

128 self, 

129 last_event_id: EventId, 

130 send_callback: EventCallback, 

131 ) -> Union[StreamId, None]: 

132 """ 

133 Replays events that occurred after the specified event ID. 

134 

135 Args: 

136 last_event_id (EventId): The ID of the last received event. Replay starts after this event. 

137 send_callback (EventCallback): Async callback to send each replayed event. 

138 

139 Returns: 

140 StreamId | None: The stream ID if the event is found and replayed, otherwise None. 

141 """ 

142 if last_event_id not in self.event_index: 

143 logger.warning(f"Event ID {last_event_id} not found in store") 

144 return None 

145 

146 # Get the stream and find events after the last one 

147 last_event = self.event_index[last_event_id] 

148 stream_id = last_event.stream_id 

149 stream_events = self.streams.get(last_event.stream_id, deque()) 

150 

151 # Events in deque are already in chronological order 

152 found_last = False 

153 for event in stream_events: 

154 if found_last: 

155 await send_callback(EventMessage(event.message, event.event_id)) 

156 elif event.event_id == last_event_id: 156 ↛ 153line 156 didn't jump to line 153 because the condition on line 156 was always true

157 found_last = True 

158 

159 return stream_id 

160 

161 

162# ------------------------------ Streamable HTTP Transport ------------------------------ 

163 

164 

165@asynccontextmanager 

166async def get_db(): 

167 """ 

168 Asynchronous context manager for database sessions. 

169 

170 Yields: 

171 A database session instance from SessionLocal. 

172 Ensures the session is closed after use. 

173 """ 

174 db = SessionLocal() 

175 try: 

176 yield db 

177 finally: 

178 db.close() 

179 

180 

181@mcp_app.call_tool() 

182async def call_tool(name: str, arguments: dict) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: 

183 """ 

184 Handles tool invocation via the MCP Server. 

185 

186 Args: 

187 name (str): The name of the tool to invoke. 

188 arguments (dict): A dictionary of arguments to pass to the tool. 

189 

190 Returns: 

191 List of content (TextContent, ImageContent, or EmbeddedResource) from the tool response. 

192 Logs and returns an empty list on failure. 

193 """ 

194 try: 

195 async with get_db() as db: 

196 result = await tool_service.invoke_tool(db=db, name=name, arguments=arguments) 

197 if not result or not result.content: 

198 logger.warning(f"No content returned by tool: {name}") 

199 return [] 

200 

201 return [types.TextContent(type=result.content[0].type, text=result.content[0].text)] 

202 except Exception as e: 

203 logger.exception(f"Error calling tool '{name}': {e}") 

204 return [] 

205 

206 

207@mcp_app.list_tools() 

208async def list_tools() -> List[types.Tool]: 

209 """ 

210 Lists all tools available to the MCP Server. 

211 

212 Returns: 

213 A list of Tool objects containing metadata such as name, description, and input schema. 

214 Logs and returns an empty list on failure. 

215 """ 

216 server_id = server_id_var.get() 

217 

218 if server_id: 

219 try: 

220 async with get_db() as db: 

221 tools = await tool_service.list_server_tools(db, server_id) 

222 return [types.Tool(name=tool.name, description=tool.description, inputSchema=tool.input_schema, annotations=tool.annotations) for tool in tools] 

223 except Exception as e: 

224 logger.exception(f"Error listing tools:{e}") 

225 return [] 

226 else: 

227 try: 

228 async with get_db() as db: 

229 tools = await tool_service.list_tools(db) 

230 return [types.Tool(name=tool.name, description=tool.description, inputSchema=tool.input_schema, annotations=tool.annotations) for tool in tools] 

231 except Exception as e: 

232 logger.exception(f"Error listing tools:{e}") 

233 return [] 

234 

235 

236class SessionManagerWrapper: 

237 """ 

238 Wrapper class for managing the lifecycle of a StreamableHTTPSessionManager instance. 

239 Provides start, stop, and request handling methods. 

240 """ 

241 

242 def __init__(self) -> None: 

243 """ 

244 Initializes the session manager and the exit stack used for managing its lifecycle. 

245 """ 

246 

247 if settings.use_stateful_sessions: 247 ↛ 248line 247 didn't jump to line 248 because the condition on line 247 was never true

248 event_store = InMemoryEventStore() 

249 stateless = False 

250 else: 

251 event_store = None 

252 stateless = True 

253 

254 self.session_manager = StreamableHTTPSessionManager( 

255 app=mcp_app, 

256 event_store=event_store, 

257 json_response=settings.json_response_enabled, 

258 stateless=stateless, 

259 ) 

260 self.stack = AsyncExitStack() 

261 

262 async def initialize(self) -> None: 

263 """ 

264 Starts the Streamable HTTP session manager context. 

265 """ 

266 logger.info("Initializing Streamable HTTP service") 

267 await self.stack.enter_async_context(self.session_manager.run()) 

268 

269 async def shutdown(self) -> None: 

270 """ 

271 Gracefully shuts down the Streamable HTTP session manager. 

272 """ 

273 logger.info("Stopping Streamable HTTP Session Manager...") 

274 await self.stack.aclose() 

275 

276 async def handle_streamable_http(self, scope: Scope, receive: Receive, send: Send) -> None: 

277 """ 

278 Forwards an incoming ASGI request to the streamable HTTP session manager. 

279 

280 Args: 

281 scope (Scope): ASGI scope object containing connection information. 

282 receive (Receive): ASGI receive callable. 

283 send (Send): ASGI send callable. 

284 

285 Raises: 

286 Exception: Any exception raised during request handling is logged. 

287 

288 Logs any exceptions that occur during request handling. 

289 """ 

290 

291 path = scope["modified_path"] 

292 match = re.search(r"/servers/(?P<server_id>\d+)/mcp", path) 

293 

294 if match: 294 ↛ 298line 294 didn't jump to line 298 because the condition on line 294 was always true

295 server_id = match.group("server_id") 

296 server_id_var.set(server_id) 

297 

298 try: 

299 await self.session_manager.handle_request(scope, receive, send) 

300 except Exception as e: 

301 logger.exception(f"Error handling streamable HTTP request: {e}") 

302 raise 

303 

304 

305# ------------------------- Authentication for /mcp routes ------------------------------ 

306 

307 

308async def streamable_http_auth(scope, receive, send): 

309 """ 

310 Perform authentication check in middleware context (ASGI scope). 

311 

312 This function is intended to be used in middleware wrapping ASGI apps. 

313 It authenticates only requests targeting paths ending in "/mcp" or "/mcp/". 

314 

315 Behavior: 

316 - If the path does not end with "/mcp", authentication is skipped. 

317 - If there is no Authorization header, the request is allowed. 

318 - If a Bearer token is present, it is verified using `verify_credentials`. 

319 - If verification fails, a 401 Unauthorized JSON response is sent. 

320 

321 Args: 

322 scope: The ASGI scope dictionary, which includes request metadata. 

323 receive: ASGI receive callable used to receive events. 

324 send: ASGI send callable used to send events (e.g. a 401 response). 

325 

326 Returns: 

327 bool: True if authentication passes or is skipped. 

328 False if authentication fails and a 401 response is sent. 

329 """ 

330 

331 path = scope.get("path", "") 

332 if not path.endswith("/mcp") and not path.endswith("/mcp/"): 

333 # No auth needed for other paths in this middleware usage 

334 return True 

335 

336 headers = Headers(scope=scope) 

337 authorization = headers.get("authorization") 

338 

339 token = None 

340 if authorization: 

341 scheme, credentials = get_authorization_scheme_param(authorization) 

342 if scheme.lower() == "bearer" and credentials: 

343 token = credentials 

344 try: 

345 await verify_credentials(token) 

346 except Exception: 

347 response = JSONResponse( 

348 {"detail": "Authentication failed"}, 

349 status_code=HTTP_401_UNAUTHORIZED, 

350 headers={"WWW-Authenticate": "Bearer"}, 

351 ) 

352 await response(scope, receive, send) 

353 return False 

354 

355 return True