Coverage for session_mgmt_mcp/serverless_mode.py: 15.28%
416 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-01 05:22 -0700
1#!/usr/bin/env python3
2"""Stateless/Serverless Mode for Session Management MCP Server.
4Enables request-scoped sessions with external storage backends (Redis, S3, DynamoDB).
5Allows the session management server to operate in cloud/serverless environments.
6"""
import asyncio
import gzip
import hashlib
import json
import logging
import uuid
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any
@dataclass
class SessionState:
    """Complete session state, flat and JSON-serializable by design."""

    session_id: str
    user_id: str
    project_id: str
    created_at: str
    last_activity: str
    permissions: list[str]
    conversation_history: list[dict[str, Any]]
    reflection_data: dict[str, Any]
    app_monitoring_state: dict[str, Any]
    llm_provider_configs: dict[str, Any]
    metadata: dict[str, Any]

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot suitable for JSON serialization."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SessionState":
        """Build a SessionState from a dict produced by to_dict()."""
        return cls(**data)

    def get_compressed_size(self) -> int:
        """Return the byte size of the gzip-compressed JSON form."""
        payload = json.dumps(self.to_dict()).encode("utf-8")
        return len(gzip.compress(payload))
class SessionStorage(ABC):
    """Abstract base class for session storage backends.

    Concrete backends (Redis, S3, local files) persist session state and
    support lookup by user or project. All operations are async.
    """

    def __init__(self, config: dict[str, Any]) -> None:
        # Backend-specific settings; each subclass reads its own keys.
        self.config = config
        # Per-backend logger, e.g. "serverless.redisstorage".
        self.logger = logging.getLogger(f"serverless.{self.__class__.__name__.lower()}")

    @abstractmethod
    async def store_session(
        self,
        session_state: SessionState,
        ttl_seconds: int | None = None,
    ) -> bool:
        """Store session state with optional TTL; return True on success."""

    @abstractmethod
    async def retrieve_session(self, session_id: str) -> SessionState | None:
        """Retrieve session state by ID, or None if not found."""

    @abstractmethod
    async def delete_session(self, session_id: str) -> bool:
        """Delete session state; return True if something was deleted."""

    @abstractmethod
    async def list_sessions(
        self,
        user_id: str | None = None,
        project_id: str | None = None,
    ) -> list[str]:
        """List session IDs matching criteria (all sessions when unfiltered)."""

    @abstractmethod
    async def cleanup_expired_sessions(self) -> int:
        """Clean up expired sessions, return count removed."""

    @abstractmethod
    async def is_available(self) -> bool:
        """Check if storage backend is available."""
class RedisStorage(SessionStorage):
    """Redis-based session storage.

    Session blobs are gzip-compressed JSON stored under
    ``{key_prefix}session:{session_id}``. Secondary set indexes
    (``{key_prefix}index:user:{user_id}`` / ``index:project:{project_id}``)
    support listing sessions per user or project.
    """

    def __init__(self, config: dict[str, Any]) -> None:
        super().__init__(config)
        self.host = config.get("host", "localhost")
        self.port = config.get("port", 6379)
        self.db = config.get("db", 0)
        self.password = config.get("password")
        self.key_prefix = config.get("key_prefix", "session_mgmt:")
        self._redis = None  # lazily-created async client (see _get_redis)

    async def _get_redis(self):
        """Get or lazily create the async Redis client.

        Raises:
            ImportError: If the optional ``redis`` package is not installed.
        """
        if self._redis is None:
            try:
                import redis.asyncio as redis
            except ImportError as exc:
                # Chain the original error so the real cause stays visible.
                msg = "Redis package not installed. Install with: pip install redis"
                raise ImportError(msg) from exc

            self._redis = redis.Redis(
                host=self.host,
                port=self.port,
                db=self.db,
                password=self.password,
                decode_responses=False,  # We handle encoding ourselves
            )
        return self._redis

    def _get_key(self, session_id: str) -> str:
        """Get Redis key for session."""
        return f"{self.key_prefix}session:{session_id}"

    def _get_index_key(self, index_type: str) -> str:
        """Get Redis key for index."""
        return f"{self.key_prefix}index:{index_type}"

    async def store_session(
        self,
        session_state: SessionState,
        ttl_seconds: int | None = None,
    ) -> bool:
        """Store session in Redis with optional TTL.

        Returns True on success, False on any error (logged with traceback).
        """
        try:
            redis_client = await self._get_redis()

            # Serialize and compress session state.
            serialized = json.dumps(session_state.to_dict())
            compressed = gzip.compress(serialized.encode("utf-8"))

            # Store session data (ex=None means no expiry).
            key = self._get_key(session_state.session_id)
            await redis_client.set(key, compressed, ex=ttl_seconds)

            # Update secondary indexes.
            user_index_key = self._get_index_key(f"user:{session_state.user_id}")
            project_index_key = self._get_index_key(
                f"project:{session_state.project_id}",
            )
            await redis_client.sadd(user_index_key, session_state.session_id)
            await redis_client.sadd(project_index_key, session_state.session_id)

            # Match index lifetime to the session TTL when one is given.
            if ttl_seconds:
                await redis_client.expire(user_index_key, ttl_seconds)
                await redis_client.expire(project_index_key, ttl_seconds)

            return True

        except Exception:
            # Lazy %-args; logger.exception records the traceback itself.
            self.logger.exception(
                "Failed to store session %s", session_state.session_id,
            )
            return False

    async def retrieve_session(self, session_id: str) -> SessionState | None:
        """Retrieve session from Redis, or None if missing/unreadable."""
        try:
            redis_client = await self._get_redis()
            key = self._get_key(session_id)

            compressed_data = await redis_client.get(key)
            if not compressed_data:
                return None

            # Decompress and deserialize.
            serialized = gzip.decompress(compressed_data).decode("utf-8")
            session_data = json.loads(serialized)
            return SessionState.from_dict(session_data)

        except Exception:
            self.logger.exception("Failed to retrieve session %s", session_id)
            return None

    async def delete_session(self, session_id: str) -> bool:
        """Delete session from Redis; True if the blob key was removed."""
        try:
            redis_client = await self._get_redis()

            # Fetch the session first so we know which indexes to clean.
            session_state = await self.retrieve_session(session_id)

            # Delete the session blob.
            deleted = await redis_client.delete(self._get_key(session_id))

            # Remove the ID from its secondary indexes.
            if session_state:
                user_index_key = self._get_index_key(f"user:{session_state.user_id}")
                project_index_key = self._get_index_key(
                    f"project:{session_state.project_id}",
                )
                await redis_client.srem(user_index_key, session_id)
                await redis_client.srem(project_index_key, session_id)

            return deleted > 0

        except Exception:
            self.logger.exception("Failed to delete session %s", session_id)
            return False

    async def list_sessions(
        self,
        user_id: str | None = None,
        project_id: str | None = None,
    ) -> list[str]:
        """List sessions by user or project (user filter wins if both given)."""
        try:
            redis_client = await self._get_redis()

            if user_id:
                members = await redis_client.smembers(
                    self._get_index_key(f"user:{user_id}"),
                )
                return [
                    m.decode("utf-8") if isinstance(m, bytes) else m for m in members
                ]

            if project_id:
                members = await redis_client.smembers(
                    self._get_index_key(f"project:{project_id}"),
                )
                return [
                    m.decode("utf-8") if isinstance(m, bytes) else m for m in members
                ]

            # List all sessions. SCAN is incremental and, unlike KEYS,
            # does not block the Redis server on large keyspaces.
            session_ids = []
            async for raw_key in redis_client.scan_iter(match=self._get_key("*")):
                key_text = (
                    raw_key.decode("utf-8") if isinstance(raw_key, bytes) else raw_key
                )
                # Key layout is "<prefix>session:<id>"; assumes IDs contain
                # no ":" (true for the hex IDs this server generates).
                session_ids.append(key_text.split(":")[-1])
            return session_ids

        except Exception:
            self.logger.exception("Failed to list sessions")
            return []

    async def cleanup_expired_sessions(self) -> int:
        """Remove orphaned index entries; return the count removed.

        Redis expires the session blobs itself via TTL, so this only scans
        the index sets for members whose session key no longer exists.
        """
        try:
            redis_client = await self._get_redis()
            cleaned = 0

            async for raw_index_key in redis_client.scan_iter(
                match=self._get_index_key("*"),
            ):
                index_key = (
                    raw_index_key.decode("utf-8")
                    if isinstance(raw_index_key, bytes)
                    else raw_index_key
                )
                members = await redis_client.smembers(index_key)
                for raw_member in members:
                    session_id = (
                        raw_member.decode("utf-8")
                        if isinstance(raw_member, bytes)
                        else raw_member
                    )
                    if not await redis_client.exists(self._get_key(session_id)):
                        await redis_client.srem(index_key, session_id)
                        cleaned += 1

            return cleaned

        except Exception:
            self.logger.exception("Failed to cleanup expired sessions")
            return 0

    async def is_available(self) -> bool:
        """Check if Redis answers a PING."""
        try:
            redis_client = await self._get_redis()
            await redis_client.ping()
            return True
        except Exception:
            return False
class S3Storage(SessionStorage):
    """S3-based session storage.

    Sessions are stored as gzip-compressed JSON objects under
    ``{key_prefix}{session_id}.json.gz``, with user/project IDs mirrored
    into object metadata so listings can filter without downloading bodies.
    """

    def __init__(self, config: dict[str, Any]) -> None:
        super().__init__(config)
        self.bucket_name = config.get("bucket_name", "session-mgmt-mcp")
        self.region = config.get("region", "us-east-1")
        self.key_prefix = config.get("key_prefix", "sessions/")
        self.access_key_id = config.get("access_key_id")
        self.secret_access_key = config.get("secret_access_key")
        self._s3_client = None  # lazily-created boto3 client

    async def _get_s3_client(self):
        """Get or lazily create the boto3 S3 client.

        Raises:
            ImportError: If the optional ``boto3`` package is not installed.
        """
        if self._s3_client is None:
            try:
                import boto3
                from botocore.client import Config
            except ImportError as exc:
                # Chain the original error so the real cause stays visible.
                msg = "Boto3 package not installed. Install with: pip install boto3"
                raise ImportError(msg) from exc

            session = boto3.Session(
                aws_access_key_id=self.access_key_id,
                aws_secret_access_key=self.secret_access_key,
                region_name=self.region,
            )
            self._s3_client = session.client(
                "s3",
                config=Config(retries={"max_attempts": 3}, max_pool_connections=50),
            )
        return self._s3_client

    def _get_key(self, session_id: str) -> str:
        """Get S3 key for session."""
        return f"{self.key_prefix}{session_id}.json.gz"

    async def store_session(
        self,
        session_state: SessionState,
        ttl_seconds: int | None = None,
    ) -> bool:
        """Store session in S3.

        NOTE(review): ``Expires`` is an HTTP caching header — S3 does not
        delete objects when it passes. Actual expiry relies on bucket
        lifecycle rules or cleanup_expired_sessions; confirm intent.
        """
        try:
            s3_client = await self._get_s3_client()

            # Serialize and compress session state.
            serialized = json.dumps(session_state.to_dict())
            compressed = gzip.compress(serialized.encode("utf-8"))

            # Mirror filterable fields into object metadata.
            metadata = {
                "user_id": session_state.user_id,
                "project_id": session_state.project_id,
                "created_at": session_state.created_at,
                "last_activity": session_state.last_activity,
            }

            put_args = {
                "Bucket": self.bucket_name,
                "Key": self._get_key(session_state.session_id),
                "Body": compressed,
                "ContentType": "application/json",
                "ContentEncoding": "gzip",
                "Metadata": metadata,
            }
            if ttl_seconds:
                # Timezone-aware now; datetime.utcnow() is deprecated.
                put_args["Expires"] = datetime.now(timezone.utc) + timedelta(
                    seconds=ttl_seconds,
                )

            # boto3 is synchronous; run it in the default executor.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, lambda: s3_client.put_object(**put_args))

            return True

        except Exception:
            self.logger.exception(
                "Failed to store session %s", session_state.session_id,
            )
            return False

    async def retrieve_session(self, session_id: str) -> SessionState | None:
        """Retrieve session from S3, or None if missing/unreadable."""
        try:
            s3_client = await self._get_s3_client()
            key = self._get_key(session_id)

            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: s3_client.get_object(Bucket=self.bucket_name, Key=key),
            )

            # Decompress and deserialize.
            compressed_data = response["Body"].read()
            serialized = gzip.decompress(compressed_data).decode("utf-8")
            return SessionState.from_dict(json.loads(serialized))

        except Exception:
            self.logger.exception("Failed to retrieve session %s", session_id)
            return None

    async def delete_session(self, session_id: str) -> bool:
        """Delete session from S3 (delete_object succeeds even if absent)."""
        try:
            s3_client = await self._get_s3_client()
            key = self._get_key(session_id)

            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                lambda: s3_client.delete_object(Bucket=self.bucket_name, Key=key),
            )
            return True

        except Exception:
            self.logger.exception("Failed to delete session %s", session_id)
            return False

    async def list_sessions(
        self,
        user_id: str | None = None,
        project_id: str | None = None,
    ) -> list[str]:
        """List sessions in S3, optionally filtered via object metadata.

        NOTE(review): list_objects_v2 returns at most 1000 keys per call;
        buckets with more sessions need pagination — confirm scale.
        """
        try:
            s3_client = await self._get_s3_client()

            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: s3_client.list_objects_v2(
                    Bucket=self.bucket_name,
                    Prefix=self.key_prefix,
                ),
            )

            session_ids = []
            for obj in response.get("Contents", []):
                key = obj["Key"]
                session_id = key.replace(self.key_prefix, "").replace(".json.gz", "")

                if user_id or project_id:
                    # HEAD the object to read its metadata for filtering.
                    # Bind `key` as a default arg so the lambda does not
                    # late-bind the loop variable.
                    head_response = await loop.run_in_executor(
                        None,
                        lambda key=key: s3_client.head_object(
                            Bucket=self.bucket_name,
                            Key=key,
                        ),
                    )
                    metadata = head_response.get("Metadata", {})

                    if user_id and metadata.get("user_id") != user_id:
                        continue
                    if project_id and metadata.get("project_id") != project_id:
                        continue

                session_ids.append(session_id)

            return session_ids

        except Exception:
            self.logger.exception("Failed to list sessions")
            return []

    async def cleanup_expired_sessions(self) -> int:
        """Delete session objects older than 30 days; return count removed.

        S3 lifecycle policies are the preferred mechanism; this is a
        best-effort fallback sweep.
        """
        try:
            s3_client = await self._get_s3_client()
            now = datetime.now(timezone.utc)
            cleaned = 0

            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: s3_client.list_objects_v2(
                    Bucket=self.bucket_name,
                    Prefix=self.key_prefix,
                ),
            )

            for obj in response.get("Contents", []):
                # LastModified is timezone-aware; compare aware-to-aware
                # instead of stripping tzinfo against a naive utcnow().
                age_days = (now - obj["LastModified"]).days

                if age_days > 30:  # Cleanup sessions older than 30 days
                    key = obj["Key"]
                    await loop.run_in_executor(
                        None,
                        lambda key=key: s3_client.delete_object(
                            Bucket=self.bucket_name,
                            Key=key,
                        ),
                    )
                    cleaned += 1

            return cleaned

        except Exception:
            self.logger.exception("Failed to cleanup expired sessions")
            return 0

    async def is_available(self) -> bool:
        """Check if the configured bucket is reachable (HEAD bucket)."""
        try:
            s3_client = await self._get_s3_client()

            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                lambda: s3_client.head_bucket(Bucket=self.bucket_name),
            )
            return True
        except Exception:
            return False
class LocalFileStorage(SessionStorage):
    """Local file-based session storage (for development/testing).

    Each session lives in ``{storage_dir}/{session_id}.json.gz``.
    """

    def __init__(self, config: dict[str, Any]) -> None:
        super().__init__(config)
        self.storage_dir = Path(
            config.get("storage_dir", Path.home() / ".claude" / "data" / "sessions"),
        )
        self.storage_dir.mkdir(parents=True, exist_ok=True)

    def _get_session_file(self, session_id: str) -> Path:
        """Get file path for session."""
        return self.storage_dir / f"{session_id}.json.gz"

    async def store_session(
        self,
        session_state: SessionState,
        ttl_seconds: int | None = None,
    ) -> bool:
        """Store session as a gzip-compressed JSON file.

        ``ttl_seconds`` is accepted for interface parity but not enforced
        here; expiry is age-based in cleanup_expired_sessions.
        """
        try:
            serialized = json.dumps(session_state.to_dict())
            compressed = gzip.compress(serialized.encode("utf-8"))
            self._get_session_file(session_state.session_id).write_bytes(compressed)
            return True

        except Exception:
            self.logger.exception(
                "Failed to store session %s", session_state.session_id,
            )
            return False

    async def retrieve_session(self, session_id: str) -> SessionState | None:
        """Retrieve session from file, or None if absent/unreadable."""
        try:
            session_file = self._get_session_file(session_id)
            if not session_file.exists():
                return None

            serialized = gzip.decompress(session_file.read_bytes()).decode("utf-8")
            return SessionState.from_dict(json.loads(serialized))

        except Exception:
            self.logger.exception("Failed to retrieve session %s", session_id)
            return None

    async def delete_session(self, session_id: str) -> bool:
        """Delete session file; True only if a file was actually removed."""
        try:
            session_file = self._get_session_file(session_id)
            if session_file.exists():
                session_file.unlink()
                return True
            return False

        except Exception:
            self.logger.exception("Failed to delete session %s", session_id)
            return False

    async def list_sessions(
        self,
        user_id: str | None = None,
        project_id: str | None = None,
    ) -> list[str]:
        """List session IDs, optionally filtered by user or project."""
        try:
            session_ids = []
            for session_file in self.storage_dir.glob("*.json.gz"):
                # removesuffix strips exactly the trailing extension;
                # stem.replace(".json", "") would also mangle IDs that
                # merely contain ".json".
                session_id = session_file.name.removesuffix(".json.gz")

                if user_id or project_id:
                    # Filtering requires loading the full state from disk.
                    session_state = await self.retrieve_session(session_id)
                    if not session_state:
                        continue
                    if user_id and session_state.user_id != user_id:
                        continue
                    if project_id and session_state.project_id != project_id:
                        continue

                session_ids.append(session_id)

            return session_ids

        except Exception:
            self.logger.exception("Failed to list sessions")
            return []

    async def cleanup_expired_sessions(self) -> int:
        """Remove session files older than 7 days; return count removed."""
        try:
            now = datetime.now()
            cleaned = 0

            for session_file in self.storage_dir.glob("*.json.gz"):
                file_age = now - datetime.fromtimestamp(session_file.stat().st_mtime)
                if file_age.days > 7:  # Cleanup sessions older than 7 days
                    session_file.unlink()
                    cleaned += 1

            return cleaned

        except Exception:
            self.logger.exception("Failed to cleanup expired sessions")
            return 0

    async def is_available(self) -> bool:
        """Check that the storage directory exists and is a directory."""
        return self.storage_dir.exists() and self.storage_dir.is_dir()
class ServerlessSessionManager:
    """Main session manager for serverless/stateless operation.

    Wraps a SessionStorage backend with an in-memory, request-scoped
    cache. Instances are expected to be short-lived (per request); the
    cache is never invalidated across requests.
    """

    def __init__(self, storage_backend: SessionStorage) -> None:
        self.storage = storage_backend
        self.logger = logging.getLogger("serverless.session_manager")
        self.session_cache: dict[str, SessionState] = {}

    async def create_session(
        self,
        user_id: str,
        project_id: str,
        session_data: dict[str, Any] | None = None,
        ttl_hours: int = 24,
    ) -> str:
        """Create a new session and return its ID.

        Raises:
            RuntimeError: If the storage backend rejects the write.
        """
        session_id = self._generate_session_id(user_id, project_id)

        # One timestamp so created_at and last_activity start identical.
        now = datetime.now().isoformat()
        session_state = SessionState(
            session_id=session_id,
            user_id=user_id,
            project_id=project_id,
            created_at=now,
            last_activity=now,
            permissions=[],
            conversation_history=[],
            reflection_data={},
            app_monitoring_state={},
            llm_provider_configs={},
            metadata=session_data or {},
        )

        success = await self.storage.store_session(session_state, ttl_hours * 3600)
        if success:
            self.session_cache[session_id] = session_state
            return session_id
        msg = "Failed to create session"
        raise RuntimeError(msg)

    async def get_session(self, session_id: str) -> SessionState | None:
        """Get session state, consulting the request-scoped cache first."""
        if session_id in self.session_cache:
            return self.session_cache[session_id]

        session_state = await self.storage.retrieve_session(session_id)
        if session_state:
            self.session_cache[session_id] = session_state
        return session_state

    async def update_session(
        self,
        session_id: str,
        updates: dict[str, Any],
        ttl_hours: int | None = None,
    ) -> bool:
        """Apply field updates to a session and persist it.

        Keys in ``updates`` that are not SessionState attributes are
        silently ignored. NOTE(review): when ttl_hours is None the state
        is re-stored without a TTL, which on Redis clears any previous
        expiry — confirm this is intended.
        """
        session_state = await self.get_session(session_id)
        if not session_state:
            return False

        for key, value in updates.items():
            if hasattr(session_state, key):
                setattr(session_state, key, value)

        session_state.last_activity = datetime.now().isoformat()

        ttl_seconds = ttl_hours * 3600 if ttl_hours else None
        success = await self.storage.store_session(session_state, ttl_seconds)
        if success:
            self.session_cache[session_id] = session_state
        return success

    async def delete_session(self, session_id: str) -> bool:
        """Delete session from cache and storage."""
        self.session_cache.pop(session_id, None)
        return await self.storage.delete_session(session_id)

    async def list_user_sessions(self, user_id: str) -> list[str]:
        """List sessions for user."""
        return await self.storage.list_sessions(user_id=user_id)

    async def list_project_sessions(self, project_id: str) -> list[str]:
        """List sessions for project."""
        return await self.storage.list_sessions(project_id=project_id)

    async def cleanup_sessions(self) -> int:
        """Clean up expired sessions via the backend."""
        return await self.storage.cleanup_expired_sessions()

    def _generate_session_id(self, user_id: str, project_id: str) -> str:
        """Generate a unique 16-hex-char session ID.

        A random UUID is mixed in so two sessions created for the same
        user/project within one timestamp tick cannot collide.
        """
        timestamp = datetime.now().isoformat()
        data = f"{user_id}:{project_id}:{timestamp}:{uuid.uuid4().hex}"
        return hashlib.sha256(data.encode()).hexdigest()[:16]

    def get_session_stats(self) -> dict[str, Any]:
        """Get session statistics; config entries whose name contains
        "key" (credentials, prefixes) are filtered out of the report."""
        return {
            "cached_sessions": len(self.session_cache),
            "storage_backend": self.storage.__class__.__name__,
            "storage_config": {
                k: v for k, v in self.storage.config.items() if "key" not in k.lower()
            },
        }
class ServerlessConfigManager:
    """Manages configuration for serverless mode."""

    @staticmethod
    def load_config(config_path: str | None = None) -> dict[str, Any]:
        """Load serverless configuration.

        File values override defaults. Per-backend settings under
        "backends" are merged key-by-key, so a partial config file no
        longer wipes the defaults of backends it does not mention.
        Unreadable or invalid files are ignored (best-effort fallback
        to defaults).
        """
        default_config = {
            "storage_backend": "local",
            "session_ttl_hours": 24,
            "cleanup_interval_hours": 6,
            "backends": {
                "redis": {
                    "host": "localhost",
                    "port": 6379,
                    "db": 0,
                    "key_prefix": "session_mgmt:",
                },
                "s3": {
                    "bucket_name": "session-mgmt-mcp",
                    "region": "us-east-1",
                    "key_prefix": "sessions/",
                },
                "local": {
                    "storage_dir": str(Path.home() / ".claude" / "data" / "sessions"),
                },
            },
        }

        if config_path and Path(config_path).exists():
            try:
                file_config = json.loads(Path(config_path).read_text())
            except (OSError, json.JSONDecodeError):
                return default_config

            # Merge "backends" one level deep; shallow-update the rest.
            file_backends = file_config.pop("backends", None)
            default_config.update(file_config)
            if isinstance(file_backends, dict):
                for name, overrides in file_backends.items():
                    merged = dict(default_config["backends"].get(name, {}))
                    merged.update(overrides)
                    default_config["backends"][name] = merged

        return default_config

    @staticmethod
    def create_storage_backend(config: dict[str, Any]) -> SessionStorage:
        """Create a storage backend instance from config.

        Raises:
            ValueError: If "storage_backend" names an unknown backend.
        """
        backend_type = config.get("storage_backend", "local")
        backend_config = config.get("backends", {}).get(backend_type, {})

        if backend_type == "redis":
            return RedisStorage(backend_config)
        if backend_type == "s3":
            return S3Storage(backend_config)
        if backend_type == "local":
            return LocalFileStorage(backend_config)
        msg = f"Unsupported storage backend: {backend_type}"
        raise ValueError(msg)

    @staticmethod
    async def test_storage_backends(config: dict[str, Any]) -> dict[str, bool]:
        """Test all configured storage backends.

        Returns a mapping of backend name -> availability; unknown backend
        names and any probe errors map to False.
        """
        results = {}

        for backend_name, backend_config in config.get("backends", {}).items():
            try:
                if backend_name == "redis":
                    storage = RedisStorage(backend_config)
                elif backend_name == "s3":
                    storage = S3Storage(backend_config)
                elif backend_name == "local":
                    storage = LocalFileStorage(backend_config)
                else:
                    results[backend_name] = False
                    continue

                results[backend_name] = await storage.is_available()

            except Exception:
                results[backend_name] = False

        return results