Coverage for mcpgateway/transports/streamablehttp_transport.py: 95%
142 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-09 11:03 +0100
1# -*- coding: utf-8 -*-
2"""Streamable HTTP Transport Implementation.
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Keval Mahajan
8This module implements the Streamable HTTP transport for MCP.
10Key components include:
11- SessionManagerWrapper: Manages the lifecycle of streamable HTTP sessions
12- Configuration options for:
13 1. stateful/stateless operation
14 2. JSON response mode or SSE streams
15- InMemoryEventStore: A simple in-memory event storage system for maintaining session state
17"""
19# Standard
20from collections import deque
21from contextlib import asynccontextmanager, AsyncExitStack
22import contextvars
23from dataclasses import dataclass
24import logging
25import re
26from typing import List, Union
27from uuid import uuid4
29# Third-Party
30from fastapi.security.utils import get_authorization_scheme_param
31from mcp import types
32from mcp.server.lowlevel import Server
33from mcp.server.streamable_http import (
34 EventCallback,
35 EventId,
36 EventMessage,
37 EventStore,
38 StreamId,
39)
40from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
41from mcp.types import JSONRPCMessage
42from starlette.datastructures import Headers
43from starlette.responses import JSONResponse
44from starlette.status import HTTP_401_UNAUTHORIZED
45from starlette.types import Receive, Scope, Send
47# First-Party
48from mcpgateway.config import settings
49from mcpgateway.db import SessionLocal
50from mcpgateway.services.tool_service import ToolService
51from mcpgateway.utils.verify_credentials import verify_credentials
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig() configures the root logger for the whole process
# at import time; consider moving this to the application entry point.
logging.basicConfig(level=logging.INFO)

# Initialize ToolService and MCP Server
tool_service = ToolService()
mcp_app = Server("mcp-streamable-http-stateless")

# Server ID parsed from a ``/servers/<id>/mcp`` request path for the current
# request context. The default is ``None`` (request not scoped to a server), so
# the value type must admit ``None``.
server_id_var: contextvars.ContextVar[Union[str, None]] = contextvars.ContextVar("server_id", default=None)
62# ------------------------------ Event store ------------------------------
@dataclass
class EventEntry:
    """
    A single stored event: the message together with the identifiers needed
    to look it up by event ID and to group it by stream.
    """

    # Unique identifier assigned when the event is stored.
    event_id: EventId
    # Stream the event belongs to; replay is limited to this stream.
    stream_id: StreamId
    # The JSON-RPC message that will be re-sent on replay.
    message: JSONRPCMessage
class InMemoryEventStore(EventStore):
    """
    Simple in-memory implementation of the EventStore interface for resumability.

    This is primarily intended for examples and testing, not for production use
    where a persistent storage solution would be more appropriate.

    Only the most recent ``max_events_per_stream`` events are kept for each
    stream. When an event is evicted from a stream's buffer it is also removed
    from the global event-ID index.
    """

    def __init__(self, max_events_per_stream: int = 100):
        """Initialize the event store.

        Args:
            max_events_per_stream: Maximum number of events to keep per stream.
                Must be at least 1.

        Raises:
            ValueError: If ``max_events_per_stream`` is less than 1.
        """
        # With a capacity of 0 the "buffer is full" branch in store_event()
        # would index into an empty deque, and deque() rejects negative sizes;
        # fail fast here instead of on the first stored event.
        if max_events_per_stream < 1:
            raise ValueError(f"max_events_per_stream must be >= 1, got {max_events_per_stream}")
        self.max_events_per_stream = max_events_per_stream
        # stream_id -> bounded buffer holding that stream's most recent events
        self.streams: dict[StreamId, deque[EventEntry]] = {}
        # event_id -> EventEntry for quick lookup
        self.event_index: dict[EventId, EventEntry] = {}

    async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId:
        """
        Stores an event with a generated event ID.

        Args:
            stream_id (StreamId): The ID of the stream.
            message (JSONRPCMessage): The message to store.

        Returns:
            EventId: The ID of the stored event (a random UUID4 string).
        """
        event_id = str(uuid4())
        event_entry = EventEntry(event_id=event_id, stream_id=stream_id, message=message)

        stream = self.streams.get(stream_id)
        if stream is None:
            stream = deque(maxlen=self.max_events_per_stream)
            self.streams[stream_id] = stream

        # A full bounded deque silently drops its oldest entry on append, so
        # remove that entry from the index first to keep both structures in sync.
        # Compare against the deque's own maxlen so this stays correct even if
        # max_events_per_stream is changed after the deque was created.
        if len(stream) == stream.maxlen:
            evicted = stream[0]
            self.event_index.pop(evicted.event_id, None)

        stream.append(event_entry)
        self.event_index[event_id] = event_entry

        return event_id

    async def replay_events_after(
        self,
        last_event_id: EventId,
        send_callback: EventCallback,
    ) -> Union[StreamId, None]:
        """
        Replays events that occurred after the specified event ID.

        Only events from the same stream as ``last_event_id`` are replayed, in
        the order they were stored.

        Args:
            last_event_id (EventId): The ID of the last received event. Replay starts after this event.
            send_callback (EventCallback): Async callback to send each replayed event.

        Returns:
            StreamId | None: The stream ID if the event is found and replayed, otherwise None.
        """
        last_event = self.event_index.get(last_event_id)
        if last_event is None:
            logger.warning("Event ID %s not found in store", last_event_id)
            return None

        stream_id = last_event.stream_id
        # Iterate over a snapshot: send_callback is awaited inside the loop, and
        # a concurrent store_event() on this stream would otherwise mutate the
        # deque mid-iteration (RuntimeError).
        snapshot = list(self.streams.get(stream_id, ()))

        found_last = False
        for event in snapshot:
            if found_last:
                await send_callback(EventMessage(event.message, event.event_id))
            elif event.event_id == last_event_id:
                found_last = True

        return stream_id
162# ------------------------------ Streamable HTTP Transport ------------------------------
@asynccontextmanager
async def get_db():
    """
    Provide a database session as an async context manager.

    A new session is created from ``SessionLocal`` on entry and is closed on
    exit, whether or not the body raised.

    Yields:
        The session object returned by ``SessionLocal()``.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Always release the session, including on error paths.
        session.close()
@mcp_app.call_tool()
async def call_tool(name: str, arguments: dict) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
    """
    Handles tool invocation via the MCP Server.

    The tool is invoked through ``tool_service.invoke_tool`` within a fresh
    database session. Every content item of type ``"text"`` in the result is
    returned as a ``types.TextContent``. Items of any other type are skipped
    with a warning.

    Args:
        name (str): The name of the tool to invoke.
        arguments (dict): A dictionary of arguments to pass to the tool.

    Returns:
        List of TextContent built from the text items of the tool response.
        Logs and returns an empty list when the tool returns no content or
        when any error occurs.
    """
    try:
        async with get_db() as db:
            result = await tool_service.invoke_tool(db=db, name=name, arguments=arguments)
            if not result or not result.content:
                logger.warning(f"No content returned by tool: {name}")
                return []

            # Convert every text item, not just the first one, so multi-part
            # tool output is not silently truncated.
            contents: List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]] = []
            for item in result.content:
                item_type = getattr(item, "type", None)
                if item_type != "text":
                    logger.warning(f"Skipping unsupported content type {item_type!r} returned by tool: {name}")
                    continue
                contents.append(types.TextContent(type=item_type, text=item.text))
            return contents
    except Exception as e:
        logger.exception(f"Error calling tool '{name}': {e}")
        return []
@mcp_app.list_tools()
async def list_tools() -> List[types.Tool]:
    """
    Lists the tools exposed through the MCP Server.

    If ``server_id_var`` holds a server ID for the current request, only that
    server's tools are listed (``tool_service.list_server_tools``). Otherwise
    all tools from ``tool_service.list_tools`` are returned.

    Returns:
        A list of Tool objects containing metadata such as name, description, input schema and annotations.
        Logs and returns an empty list on failure.
    """
    server_id = server_id_var.get()

    try:
        async with get_db() as db:
            # Only the data source differs between the scoped and global cases.
            if server_id:
                tools = await tool_service.list_server_tools(db, server_id)
            else:
                tools = await tool_service.list_tools(db)
            return [types.Tool(name=tool.name, description=tool.description, inputSchema=tool.input_schema, annotations=tool.annotations) for tool in tools]
    except Exception as e:
        logger.exception(f"Error listing tools:{e}")
        return []
class SessionManagerWrapper:
    """
    Wrapper class for managing the lifecycle of a StreamableHTTPSessionManager instance.

    Provides start (``initialize``), stop (``shutdown``) and request handling
    (``handle_streamable_http``) methods.
    """

    def __init__(self) -> None:
        """
        Initializes the session manager and the exit stack used for managing its lifecycle.

        When ``settings.use_stateful_sessions`` is true, the manager is stateful
        and uses an ``InMemoryEventStore``. Otherwise it is stateless and has no
        event store. ``settings.json_response_enabled`` selects JSON responses
        instead of SSE streams.
        """
        if settings.use_stateful_sessions:
            event_store = InMemoryEventStore()
            stateless = False
        else:
            event_store = None
            stateless = True

        self.session_manager = StreamableHTTPSessionManager(
            app=mcp_app,
            event_store=event_store,
            json_response=settings.json_response_enabled,
            stateless=stateless,
        )
        # Owns the session manager's run() context between initialize() and shutdown().
        self.stack = AsyncExitStack()

    async def initialize(self) -> None:
        """
        Starts the Streamable HTTP session manager context.
        """
        logger.info("Initializing Streamable HTTP service")
        await self.stack.enter_async_context(self.session_manager.run())

    async def shutdown(self) -> None:
        """
        Gracefully shuts down the Streamable HTTP session manager.
        """
        logger.info("Stopping Streamable HTTP Session Manager...")
        await self.stack.aclose()

    async def handle_streamable_http(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Forwards an incoming ASGI request to the streamable HTTP session manager.

        If the request path contains ``/servers/<digits>/mcp``, that server ID is
        stored in ``server_id_var`` for the duration of the request. Otherwise
        the variable is set to ``None``. It is restored afterwards.

        Args:
            scope (Scope): ASGI scope object containing connection information.
            receive (Receive): ASGI receive callable.
            send (Send): ASGI send callable.

        Raises:
            Exception: Any exception raised during request handling is logged and re-raised.
        """
        # Prefer the path recorded under "modified_path" in the scope; fall back
        # to the raw ASGI path instead of raising KeyError when it is absent.
        path = scope.get("modified_path") or scope.get("path", "")
        match = re.search(r"/servers/(?P<server_id>\d+)/mcp", path)

        # Always set the variable for this request so a server ID from an
        # earlier request handled in the same context cannot leak into this one.
        token = server_id_var.set(match.group("server_id") if match else None)
        try:
            await self.session_manager.handle_request(scope, receive, send)
        except Exception as e:
            logger.exception(f"Error handling streamable HTTP request: {e}")
            raise
        finally:
            server_id_var.reset(token)
305# ------------------------- Authentication for /mcp routes ------------------------------
async def streamable_http_auth(scope, receive, send):
    """
    Perform authentication check in middleware context (ASGI scope).

    This function is intended to be used in middleware wrapping ASGI apps.
    It authenticates only requests targeting paths ending in "/mcp" or "/mcp/".

    Behavior:
    - If the path does not end with "/mcp" or "/mcp/", authentication is skipped.
    - Otherwise a Bearer token is required in the Authorization header and is
      verified using `verify_credentials`.
    - If the header is missing, does not use the Bearer scheme, has an empty
      token, or verification fails, a 401 Unauthorized JSON response is sent.

    Args:
        scope: The ASGI scope dictionary, which includes request metadata.
        receive: ASGI receive callable used to receive events.
        send: ASGI send callable used to send events (e.g. a 401 response).

    Returns:
        bool: True if authentication passes or is skipped.
              False if authentication fails and a 401 response is sent.
    """
    path = scope.get("path", "")
    if not path.endswith(("/mcp", "/mcp/")):
        # No auth needed for other paths in this middleware usage
        return True

    headers = Headers(scope=scope)
    authorization = headers.get("authorization")

    token = None
    if authorization:
        scheme, credentials = get_authorization_scheme_param(authorization)
        if scheme.lower() == "bearer" and credentials:
            token = credentials

    # Reject a missing/non-Bearer token explicitly instead of handing None to
    # verify_credentials() and relying on it to fail.
    authenticated = False
    if token is not None:
        try:
            await verify_credentials(token)
            authenticated = True
        except Exception:
            # Every verification failure is reported uniformly as 401 below.
            logger.debug("Bearer token verification failed", exc_info=True)

    if not authenticated:
        response = JSONResponse(
            {"detail": "Authentication failed"},
            status_code=HTTP_401_UNAUTHORIZED,
            headers={"WWW-Authenticate": "Bearer"},
        )
        await response(scope, receive, send)
        return False

    return True