Coverage for mcpgateway/cache/session_registry.py: 58%
417 statements
# -*- coding: utf-8 -*-
"""Session Registry with optional distributed state.

Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti

This module provides a registry for SSE sessions with support for distributed deployment
using Redis or SQLAlchemy as optional backends for shared state between workers.
"""

# Standard
import asyncio
import json
import logging
import time
from typing import Any, Dict, Optional

# Third-Party
from fastapi import HTTPException, status
import httpx

# First-Party
from mcpgateway.config import settings
from mcpgateway.db import get_db, SessionMessageRecord, SessionRecord
from mcpgateway.models import Implementation, InitializeResult, ServerCapabilities
from mcpgateway.services import PromptService, ResourceService, ToolService
from mcpgateway.transports import SSETransport

logger = logging.getLogger(__name__)

tool_service = ToolService()
resource_service = ResourceService()
prompt_service = PromptService()

try:
    # Third-Party
    from redis.asyncio import Redis

    REDIS_AVAILABLE = True
except ImportError:
    REDIS_AVAILABLE = False

try:
    # Third-Party
    from sqlalchemy import func

    SQLALCHEMY_AVAILABLE = True
except ImportError:
    SQLALCHEMY_AVAILABLE = False
class SessionBackend:
    """Configuration and connections for the session state backend."""
    def __init__(
        self,
        backend: str = "memory",
        redis_url: Optional[str] = None,
        database_url: Optional[str] = None,
        session_ttl: int = 3600,  # 1 hour
        message_ttl: int = 600,  # 10 min
    ):
        """Initialize session registry.

        Args:
            backend: "memory", "redis", "database", or "none"
            redis_url: Redis connection URL (required for redis backend)
            database_url: Database connection URL (required for database backend)
            session_ttl: Session time-to-live in seconds
            message_ttl: Message time-to-live in seconds

        Raises:
            ValueError: If backend is invalid or a required URL is missing
        """
        self._backend = backend.lower()
        self._session_ttl = session_ttl
        self._message_ttl = message_ttl

        # Set up backend-specific components
        if self._backend == "memory":
            # Nothing special needed for the memory backend
            self._session_message = None

        elif self._backend == "none":
            # No session tracking - this is just a dummy registry
            logger.info("Session registry initialized with 'none' backend - session tracking disabled")

        elif self._backend == "redis":
            if not REDIS_AVAILABLE:
                raise ValueError("Redis backend requested but redis package not installed")
            if not redis_url:
                raise ValueError("Redis backend requires redis_url")

            self._redis = Redis.from_url(redis_url)
            self._pubsub = self._redis.pubsub()

        elif self._backend == "database":
            if not SQLALCHEMY_AVAILABLE:
                raise ValueError("Database backend requested but SQLAlchemy not installed")
            if not database_url:
                raise ValueError("Database backend requires database_url")
        else:
            raise ValueError(f"Invalid backend: {backend}")
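
    # Backend selection sketch (connection URLs are illustrative, not shipped
    # defaults):
    #
    #     SessionBackend(backend="memory")                                   # no extra dependencies
    #     SessionBackend(backend="redis", redis_url="redis://localhost:6379/0")
    #     SessionBackend(backend="database", database_url="postgresql://user:pass@localhost/mcp")
    #     SessionBackend(backend="invalid")                                  # raises ValueError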
class SessionRegistry(SessionBackend):
    """Registry for SSE sessions with optional distributed state.

    Supports four backend modes:
    - memory: In-memory storage (default, no dependencies)
    - redis: Redis-backed shared storage
    - database: SQLAlchemy-backed shared storage
    - none: No session tracking (dummy registry)

    In distributed mode (redis/database), session existence is tracked in the shared
    backend while transports themselves remain local to each worker process.
    """
    def __init__(
        self,
        backend: str = "memory",
        redis_url: Optional[str] = None,
        database_url: Optional[str] = None,
        session_ttl: int = 3600,  # 1 hour
        message_ttl: int = 600,  # 10 min
    ):
        """Initialize session registry.

        Args:
            backend: "memory", "redis", "database", or "none"
            redis_url: Redis connection URL (required for redis backend)
            database_url: Database connection URL (required for database backend)
            session_ttl: Session time-to-live in seconds
            message_ttl: Message time-to-live in seconds
        """
        super().__init__(backend=backend, redis_url=redis_url, database_url=database_url, session_ttl=session_ttl, message_ttl=message_ttl)
        self._sessions: Dict[str, Any] = {}  # Local transport cache
        self._lock = asyncio.Lock()
        self._cleanup_task = None
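
    # Lifecycle sketch (assumes an async context such as FastAPI startup and
    # shutdown hooks; names are illustrative):
    #
    #     registry = SessionRegistry(backend="memory")
    #     await registry.initialize()   # starts the backend's cleanup task
    #     ...
    #     await registry.shutdown()     # cancels the task, closes Redis connections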
    async def initialize(self) -> None:
        """Initialize the registry with async setup.

        Call this during application startup.
        """
        logger.info(f"Initializing session registry with backend: {self._backend}")

        if self._backend == "database":
            # Start database cleanup task
            self._cleanup_task = asyncio.create_task(self._db_cleanup_task())
            logger.info("Database cleanup task started")

        elif self._backend == "redis":
            await self._pubsub.subscribe("mcp_session_events")

        elif self._backend == "none":
            # Nothing to initialize for the none backend
            pass

        # Memory backend needs session cleanup
        elif self._backend == "memory":
            self._cleanup_task = asyncio.create_task(self._memory_cleanup_task())
            logger.info("Memory cleanup task started")
    async def shutdown(self) -> None:
        """Shutdown the registry.

        Call this during application shutdown.
        """
        logger.info("Shutting down session registry")

        # Cancel cleanup task
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass

        # Close Redis connections
        if self._backend == "redis":
            try:
                await self._pubsub.aclose()
                await self._redis.aclose()
            except Exception as e:
                logger.error(f"Error closing Redis connection: {e}")
    async def add_session(self, session_id: str, transport: SSETransport) -> None:
        """Add a session to the registry.

        Args:
            session_id: Unique session identifier
            transport: Transport session
        """
        # Skip for none backend
        if self._backend == "none":
            return

        async with self._lock:
            self._sessions[session_id] = transport

        if self._backend == "redis":
            # Store session marker in Redis
            try:
                await self._redis.setex(f"mcp:session:{session_id}", self._session_ttl, "1")
                # Publish event to notify other workers
                await self._redis.publish("mcp_session_events", json.dumps({"type": "add", "session_id": session_id, "timestamp": time.time()}))
            except Exception as e:
                logger.error(f"Redis error adding session {session_id}: {e}")

        elif self._backend == "database":
            # Store session in database
            try:

                def _db_add():
                    db_session = next(get_db())
                    try:
                        session_record = SessionRecord(session_id=session_id)
                        db_session.add(session_record)
                        db_session.commit()
                    except Exception as ex:
                        db_session.rollback()
                        raise ex
                    finally:
                        db_session.close()

                await asyncio.to_thread(_db_add)
            except Exception as e:
                logger.error(f"Database error adding session {session_id}: {e}")

        logger.info(f"Added session: {session_id}")
    async def get_session(self, session_id: str) -> Any:
        """Get session by ID.

        Args:
            session_id: Session identifier

        Returns:
            Transport object or None if not found
        """
        # Skip for none backend
        if self._backend == "none":
            return None

        # First check local cache
        async with self._lock:
            transport = self._sessions.get(session_id)
            if transport:
                logger.info(f"Session {session_id} exists in local cache")
                return transport

        # If not in local cache, check if it exists in shared backend
        if self._backend == "redis":
            try:
                exists = await self._redis.exists(f"mcp:session:{session_id}")
                session_exists = bool(exists)
                if session_exists:
                    logger.info(f"Session {session_id} exists in Redis but not in local cache")
                return None  # We don't have the transport locally
            except Exception as e:
                logger.error(f"Redis error checking session {session_id}: {e}")
                return None

        elif self._backend == "database":
            try:

                def _db_check():
                    db_session = next(get_db())
                    try:
                        record = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()
                        return record is not None
                    finally:
                        db_session.close()

                exists = await asyncio.to_thread(_db_check)
                if exists:
                    logger.info(f"Session {session_id} exists in database but not in local cache")
                return None
            except Exception as e:
                logger.error(f"Database error checking session {session_id}: {e}")
                return None

        return None
    async def remove_session(self, session_id: str) -> None:
        """Remove a session from the registry.

        Args:
            session_id: Session identifier
        """
        # Skip for none backend
        if self._backend == "none":
            return

        # Clean up local transport
        transport = None
        async with self._lock:
            if session_id in self._sessions:
                transport = self._sessions.pop(session_id)

        # Disconnect transport if found
        if transport:
            try:
                await transport.disconnect()
            except Exception as e:
                logger.error(f"Error disconnecting transport for session {session_id}: {e}")

        # Remove from shared backend
        if self._backend == "redis":
            try:
                await self._redis.delete(f"mcp:session:{session_id}")
                # Notify other workers
                await self._redis.publish("mcp_session_events", json.dumps({"type": "remove", "session_id": session_id, "timestamp": time.time()}))
            except Exception as e:
                logger.error(f"Redis error removing session {session_id}: {e}")

        elif self._backend == "database":
            try:

                def _db_remove():
                    db_session = next(get_db())
                    try:
                        db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).delete()
                        db_session.commit()
                    except Exception as ex:
                        db_session.rollback()
                        raise ex
                    finally:
                        db_session.close()

                await asyncio.to_thread(_db_remove)
            except Exception as e:
                logger.error(f"Database error removing session {session_id}: {e}")

        logger.info(f"Removed session: {session_id}")
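
    # Session lifecycle sketch (IDs illustrative; the transport is whatever
    # SSETransport instance the SSE endpoint created for the client):
    #
    #     await registry.add_session("sess-123", transport)
    #     local = await registry.get_session("sess-123")   # transport if held locally, else None
    #     await registry.remove_session("sess-123")        # disconnects and removes the marker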
    async def broadcast(self, session_id: str, message: dict) -> None:
        """Broadcast a message to the channel associated with a session.

        Args:
            session_id: Session ID
            message: Message to broadcast
        """
        # Skip for none backend (no session tracking)
        if self._backend == "none":
            return

        if self._backend == "memory":
            if isinstance(message, (dict, list)):
                msg_json = json.dumps(message)
            else:
                msg_json = json.dumps(str(message))

            self._session_message = {"session_id": session_id, "message": msg_json}

        elif self._backend == "redis":
            try:
                if isinstance(message, (dict, list)):
                    msg_json = json.dumps(message)
                else:
                    msg_json = json.dumps(str(message))

                await self._redis.publish(session_id, json.dumps({"type": "message", "message": msg_json, "timestamp": time.time()}))
            except Exception as e:
                logger.error(f"Redis error during broadcast: {e}")
        elif self._backend == "database":
            try:
                if isinstance(message, (dict, list)):
                    msg_json = json.dumps(message)
                else:
                    msg_json = json.dumps(str(message))

                def _db_add():
                    db_session = next(get_db())
                    try:
                        message_record = SessionMessageRecord(session_id=session_id, message=msg_json)
                        db_session.add(message_record)
                        db_session.commit()
                    except Exception as ex:
                        db_session.rollback()
                        raise ex
                    finally:
                        db_session.close()

                await asyncio.to_thread(_db_add)
            except Exception as e:
                logger.error(f"Database error during broadcast: {e}")
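
    # Broadcast sketch: any worker can enqueue a JSON-RPC request for a session,
    # regardless of which worker holds the transport (payload illustrative):
    #
    #     await registry.broadcast(
    #         session_id="sess-123",
    #         message={"jsonrpc": "2.0", "method": "tools/list", "id": 1},
    #     )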
    def get_session_sync(self, session_id: str) -> Any:
        """Get session synchronously (without checking the shared backend).

        This is a non-blocking method for handlers that need quick access.
        It only checks the local cache, not the shared backend.

        Args:
            session_id: Session identifier

        Returns:
            Transport object or None if not found
        """
        # Skip for none backend
        if self._backend == "none":
            return None

        return self._sessions.get(session_id)
    async def respond(
        self,
        server_id: Optional[str],
        user: dict,
        session_id: str,
        base_url: str,
    ) -> None:
        """Respond to a broadcast message if the transport for session_id is held locally.

        Args:
            server_id: Server ID
            user: User information
            session_id: Session ID
            base_url: Base URL for the FastAPI request
        """
        if self._backend == "none":
            pass

        elif self._backend == "memory":
            if self._session_message:
                transport = self.get_session_sync(session_id)
                if transport:
                    message = json.loads(self._session_message.get("message"))
                    await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url)
        elif self._backend == "redis":
            await self._pubsub.subscribe(session_id)

            try:
                async for msg in self._pubsub.listen():
                    if msg["type"] != "message":
                        continue
                    data = json.loads(msg["data"])
                    message = data.get("message", {})
                    if isinstance(message, str):
                        message = json.loads(message)
                    transport = self.get_session_sync(session_id)
                    if transport:
                        await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url)
            except asyncio.CancelledError:
                logger.info(f"PubSub listener for session {session_id} cancelled")
            finally:
                await self._pubsub.unsubscribe(session_id)
                logger.info(f"Cleaned up pubsub for session {session_id}")
        elif self._backend == "database":

            def _db_read_session(session_id):
                db_session = next(get_db())
                try:
                    # Check whether the session record still exists
                    result = db_session.query(SessionRecord).filter_by(session_id=session_id).first()
                    return result
                except Exception as ex:
                    db_session.rollback()
                    raise ex
                finally:
                    db_session.close()

            def _db_read(session_id):
                db_session = next(get_db())
                try:
                    # Read the next pending message for this session
                    result = db_session.query(SessionMessageRecord).filter_by(session_id=session_id).first()
                    return result
                except Exception as ex:
                    db_session.rollback()
                    raise ex
                finally:
                    db_session.close()

            def _db_remove(session_id, message):
                db_session = next(get_db())
                try:
                    db_session.query(SessionMessageRecord).filter(SessionMessageRecord.session_id == session_id).filter(SessionMessageRecord.message == message).delete()
                    db_session.commit()
                    logger.info("Removed message from mcp_messages table")
                except Exception as ex:
                    db_session.rollback()
                    raise ex
                finally:
                    db_session.close()

            async def message_check_loop(session_id):
                while True:
                    record = await asyncio.to_thread(_db_read, session_id)

                    if record:
                        message = json.loads(record.message)
                        transport = self.get_session_sync(session_id)
                        if transport:
                            logger.info("Ready to respond")
                            await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url)

                        await asyncio.to_thread(_db_remove, session_id, record.message)

                    session_exists = await asyncio.to_thread(_db_read_session, session_id)
                    if not session_exists:
                        break

                    await asyncio.sleep(0.1)

            asyncio.create_task(message_check_loop(session_id))
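
    # broadcast() and respond() together form the cross-worker hand-off:
    # broadcast() writes a message to the shared backend (memory slot, Redis
    # channel, or mcp_messages table), and respond() on the worker that holds
    # the transport picks it up and calls generate_response(). A sketch of the
    # receiving side (arguments illustrative):
    #
    #     await registry.respond(
    #         server_id=None,
    #         user={"token": "<jwt>"},
    #         session_id="sess-123",
    #         base_url="http://localhost:4444",
    #     )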
    async def _refresh_redis_sessions(self) -> None:
        """Refresh TTLs for Redis sessions and clean up disconnected sessions."""
        try:
            # Check all local sessions
            local_transports = {}
            async with self._lock:
                local_transports = self._sessions.copy()

            for session_id, transport in local_transports.items():
                try:
                    if await transport.is_connected():
                        # Refresh TTL in Redis
                        await self._redis.expire(f"mcp:session:{session_id}", self._session_ttl)
                    else:
                        # Remove disconnected session
                        await self.remove_session(session_id)
                except Exception as e:
                    logger.error(f"Error refreshing session {session_id}: {e}")

        except Exception as e:
            logger.error(f"Error in Redis session refresh: {e}")
    async def _db_cleanup_task(self) -> None:
        """Periodically clean up expired database sessions."""
        logger.info("Starting database cleanup task")
        while True:
            try:
                # Clean up expired sessions every 5 minutes
                def _db_cleanup():
                    db_session = next(get_db())
                    try:
                        # Delete sessions that haven't been accessed for TTL seconds
                        # (func.make_interval is PostgreSQL-specific)
                        expiry_time = func.now() - func.make_interval(seconds=self._session_ttl)  # pylint: disable=not-callable
                        result = db_session.query(SessionRecord).filter(SessionRecord.last_accessed < expiry_time).delete()
                        db_session.commit()
                        return result
                    except Exception as ex:
                        db_session.rollback()
                        raise ex
                    finally:
                        db_session.close()

                deleted = await asyncio.to_thread(_db_cleanup)
                if deleted > 0:
                    logger.info(f"Cleaned up {deleted} expired database sessions")

                # Check local sessions against database
                local_transports = {}
                async with self._lock:
                    local_transports = self._sessions.copy()

                for session_id, transport in local_transports.items():
                    try:
                        if not await transport.is_connected():
                            await self.remove_session(session_id)
                            continue

                        # Refresh session in database
                        def _refresh_session():
                            db_session = next(get_db())
                            try:
                                session = db_session.query(SessionRecord).filter(SessionRecord.session_id == session_id).first()

                                if session:
                                    # Update last_accessed
                                    session.last_accessed = func.now()  # pylint: disable=not-callable
                                    db_session.commit()
                                    return True
                                return False
                            except Exception as ex:
                                db_session.rollback()
                                raise ex
                            finally:
                                db_session.close()

                        session_exists = await asyncio.to_thread(_refresh_session)
                        if not session_exists:
                            # Session no longer in database, remove locally
                            await self.remove_session(session_id)

                    except Exception as e:
                        logger.error(f"Error checking session {session_id}: {e}")

                await asyncio.sleep(300)  # Run every 5 minutes

            except asyncio.CancelledError:
                logger.info("Database cleanup task cancelled")
                break
            except Exception as e:
                logger.error(f"Error in database cleanup task: {e}")
                await asyncio.sleep(600)  # Sleep longer on error
    async def _memory_cleanup_task(self) -> None:
        """Periodically clean up disconnected sessions."""
        logger.info("Starting memory cleanup task")
        while True:
            try:
                # Check all local sessions
                local_transports = {}
                async with self._lock:
                    local_transports = self._sessions.copy()

                for session_id, transport in local_transports.items():
                    try:
                        if not await transport.is_connected():
                            await self.remove_session(session_id)
                    except Exception as e:
                        logger.error(f"Error checking session {session_id}: {e}")
                        await self.remove_session(session_id)

                await asyncio.sleep(60)  # Run every minute

            except asyncio.CancelledError:
                logger.info("Memory cleanup task cancelled")
                break
            except Exception as e:
                logger.error(f"Error in memory cleanup task: {e}")
                await asyncio.sleep(300)  # Sleep longer on error
    # Handle initialize logic
    async def handle_initialize_logic(self, body: dict) -> InitializeResult:
        """Validate the protocol version from the request body and return an
        InitializeResult with server capabilities and info.

        Args:
            body (dict): The incoming request body.

        Raises:
            HTTPException: If the protocol version is missing or unsupported.

        Returns:
            InitializeResult: Initialization result with protocol version, capabilities, and server info.
        """
        protocol_version = body.get("protocol_version") or body.get("protocolVersion")

        if not protocol_version:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Missing protocol version",
                headers={"MCP-Error-Code": "-32002"},
            )

        if protocol_version != settings.protocol_version:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported protocol version: {protocol_version}",
                headers={"MCP-Error-Code": "-32003"},
            )

        return InitializeResult(
            protocolVersion=settings.protocol_version,
            capabilities=ServerCapabilities(
                prompts={"listChanged": True},
                resources={"subscribe": True, "listChanged": True},
                tools={"listChanged": True},
                logging={},
                roots={"listChanged": True},
                sampling={},
            ),
            serverInfo=Implementation(name=settings.app_name, version="1.0.0"),
            instructions=("MCP Gateway providing federated tools, resources and prompts. Use /admin interface for configuration."),
        )
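
    # Example body accepted by handle_initialize_logic (values illustrative;
    # both "protocol_version" and "protocolVersion" keys are recognized):
    #
    #     {
    #         "protocolVersion": settings.protocol_version,
    #         "capabilities": {},
    #         "clientInfo": {"name": "example-client", "version": "0.1.0"},
    #     }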
    async def generate_response(self, message: dict, transport: SSETransport, server_id: Optional[str], user: dict, base_url: str):
        """Generate a response to a JSON-RPC message per the SSE specification.

        Args:
            message: Message JSON
            transport: Transport on which the response should be sent
            server_id: Server ID
            user: User information
            base_url: Base URL for the FastAPI request
        """
        result = {}

        if "method" in message and "id" in message:
            method = message["method"]
            params = message.get("params", {})
            req_id = message["id"]
            db = next(get_db())
            if method == "initialize":
                init_result = await self.handle_initialize_logic(params)
                response = {
                    "jsonrpc": "2.0",
                    "result": init_result.model_dump(by_alias=True, exclude_none=True),
                    "id": req_id,
                }
                await transport.send_message(response)
                await transport.send_message(
                    {
                        "jsonrpc": "2.0",
                        "method": "notifications/initialized",
                        "params": {},
                    }
                )
                notifications = [
                    "tools/list_changed",
                    "resources/list_changed",
                    "prompts/list_changed",
                ]
                for notification in notifications:
                    await transport.send_message(
                        {
                            "jsonrpc": "2.0",
                            "method": f"notifications/{notification}",
                            "params": {},
                        }
                    )
            elif method == "tools/list":
                if server_id:
                    tools = await tool_service.list_server_tools(db, server_id=server_id)
                else:
                    tools = await tool_service.list_tools(db)
                result = {"tools": [t.model_dump(by_alias=True, exclude_none=True) for t in tools]}
            elif method == "resources/list":
                if server_id:
                    resources = await resource_service.list_server_resources(db, server_id=server_id)
                else:
                    resources = await resource_service.list_resources(db)
                result = {"resources": [r.model_dump(by_alias=True, exclude_none=True) for r in resources]}
            elif method == "prompts/list":
                if server_id:
                    prompts = await prompt_service.list_server_prompts(db, server_id=server_id)
                else:
                    prompts = await prompt_service.list_prompts(db)
                result = {"prompts": [p.model_dump(by_alias=True, exclude_none=True) for p in prompts]}
            elif method == "ping":
                result = {}
            elif method == "tools/call":
                rpc_input = {
                    "jsonrpc": "2.0",
                    "method": message["params"]["name"],
                    "params": message["params"]["arguments"],
                    "id": 1,
                }
                headers = {"Authorization": f"Bearer {user['token']}", "Content-Type": "application/json"}
                rpc_url = base_url + "/rpc"
                async with httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify) as client:
                    rpc_response = await client.post(
                        url=rpc_url,
                        json=rpc_input,
                        headers=headers,
                    )
                    result = rpc_response.json()
            else:
                result = {}

            response = {"jsonrpc": "2.0", "result": result, "id": req_id}
            logger.info(f"Sending SSE message: {response}")
            await transport.send_message(response)
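
# Example of a message generate_response() handles: a JSON-RPC "tools/call"
# request. The call is forwarded to the gateway's /rpc endpoint with the
# user's bearer token, and the JSON result is sent back over the SSE
# transport (values illustrative):
#
#     {
#         "jsonrpc": "2.0",
#         "id": 42,
#         "method": "tools/call",
#         "params": {"name": "get_time", "arguments": {"tz": "UTC"}},
#     }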