Coverage for session_buddy / backends / redis_backend.py: 13.99%

123 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-04 00:43 -0800

1"""Redis-based session storage backend. 

2 

3**DEPRECATED**: This module is deprecated and will be removed in v1.0. 

4Use ServerlessStorageAdapter instead, which provides Oneiric storage backends 

5and standardized lifecycle handling. 

6 

7Migration: 

8 Old: RedisStorage(config) 

9 New: ServerlessStorageAdapter(config, backend="memory") for caching 

10 

11This module provides a Redis implementation of the SessionStorage interface 

12for storing and retrieving session state in Redis with TTL support. 

13""" 

14 

15from __future__ import annotations 

16 

17import gzip 

18import json 

19import warnings 

20from typing import Any 

21 

22from session_buddy.backends.base import SessionState, SessionStorage 

23 

24 

class RedisStorage(SessionStorage):
    """Redis-based session storage.

    Each session is stored as gzip-compressed JSON under the key
    ``<key_prefix>session:<session_id>``. Two Redis sets,
    ``<key_prefix>index:user:<user_id>`` and
    ``<key_prefix>index:project:<project_id>``, record which session IDs
    belong to a given user or project.

    .. deprecated:: 0.9.3
        RedisStorage is deprecated. Use ``ServerlessStorageAdapter`` for
        Oneiric storage.

    """

    def __init__(self, config: dict[str, Any]) -> None:
        """Read connection settings from *config*; no connection is opened here.

        Recognized keys and their defaults: ``host`` (``"localhost"``),
        ``port`` (``6379``), ``db`` (``0``), ``password`` (``None``) and
        ``key_prefix`` (``"session_mgmt:"``).

        Emits a ``DeprecationWarning`` attributed to the caller.
        """
        warnings.warn(
            "RedisStorage is deprecated and will be removed in v1.0. "
            "Use ServerlessStorageAdapter for Oneiric storage instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(config)
        self.host = config.get("host", "localhost")
        self.port = config.get("port", 6379)
        self.db = config.get("db", 0)
        self.password = config.get("password")
        self.key_prefix = config.get("key_prefix", "session_mgmt:")
        # Created lazily by _get_redis() so that constructing this object
        # never requires the optional ``redis`` package.
        self._redis: Any = None

    async def _get_redis(self) -> Any:
        """Return the cached Redis client, creating it on first use.

        Raises:
            ImportError: If the ``redis`` package is not installed.
        """
        if self._redis is None:
            # Only the import is guarded: a failure while constructing the
            # client must not be reported as a missing package.
            try:
                import redis.asyncio as redis
            except ImportError as exc:
                msg = "Redis package not installed. Install with: pip install redis"
                raise ImportError(msg) from exc

            self._redis = redis.Redis(
                host=self.host,
                port=self.port,
                db=self.db,
                password=self.password,
                decode_responses=False,  # We handle encoding ourselves
            )
        return self._redis

    @staticmethod
    def _decode(value: Any) -> Any:
        """Return *value* decoded as UTF-8 if it is ``bytes``, else unchanged."""
        return value.decode("utf-8") if isinstance(value, bytes) else value

    def _get_key(self, session_id: str) -> str:
        """Get Redis key for session."""
        return f"{self.key_prefix}session:{session_id}"

    def _get_index_key(self, index_type: str) -> str:
        """Get Redis key for index."""
        return f"{self.key_prefix}index:{index_type}"

    def _session_id_from_key(self, key: str) -> str:
        """Recover the session ID from a key produced by ``_get_key``.

        The known prefix is stripped instead of splitting on ``":"`` so that
        session IDs which themselves contain a colon are returned intact.
        """
        prefix = self._get_key("")
        if key.startswith(prefix):
            return key[len(prefix):]
        # Unexpected key shape: fall back to the last colon-separated field.
        return key.split(":")[-1]

    async def _scan_keys(self, redis_client: Any, pattern: str) -> list[str]:
        """Collect the keys matching *pattern* using incremental SCAN.

        SCAN is used rather than KEYS so the server is not blocked while the
        whole keyspace is walked. SCAN may report a key more than once, so
        results are de-duplicated while preserving first-seen order.
        """
        unique: dict[str, None] = {}
        async for raw_key in redis_client.scan_iter(match=pattern):
            unique[self._decode(raw_key)] = None
        return list(unique)

    async def _list_index_members(
        self,
        redis_client: Any,
        index_type: str,
    ) -> list[str]:
        """Return the decoded members of the index set named by *index_type*."""
        members = await redis_client.smembers(self._get_index_key(index_type))
        return [self._decode(member) for member in members]

    async def store_session(
        self,
        session_state: SessionState,
        ttl_seconds: int | None = None,
    ) -> bool:
        """Store session in Redis with optional TTL.

        The session is serialized via ``to_dict()``, JSON-encoded, gzipped and
        written under its session key; its ID is then added to the user and
        project index sets.

        Returns:
            ``True`` on success; ``False`` if any step raised (the error is
            logged, not propagated).
        """
        try:
            redis_client = await self._get_redis()

            # Serialize and compress session state
            serialized = json.dumps(session_state.to_dict())
            compressed = gzip.compress(serialized.encode("utf-8"))

            # Store session data
            key = self._get_key(session_state.session_id)
            await redis_client.set(key, compressed, ex=ttl_seconds)

            # Update indexes
            user_index_key = self._get_index_key(f"user:{session_state.user_id}")
            project_index_key = self._get_index_key(
                f"project:{session_state.project_id}",
            )

            await redis_client.sadd(user_index_key, session_state.session_id)
            await redis_client.sadd(project_index_key, session_state.session_id)

            # Set TTL on indexes if specified.
            # NOTE(review): this resets the TTL of the whole index set to this
            # session's TTL, so an index shared with a longer-lived session
            # may expire before that session does - confirm this is intended.
            if ttl_seconds:
                await redis_client.expire(user_index_key, ttl_seconds)
                await redis_client.expire(project_index_key, ttl_seconds)

            return True

        except Exception as e:
            self.logger.exception(
                f"Failed to store session {session_state.session_id}: {e}",
            )
            return False

    async def retrieve_session(self, session_id: str) -> SessionState | None:
        """Retrieve session from Redis.

        Returns:
            The reconstructed ``SessionState``, or ``None`` when the key is
            missing/empty or when reading or decoding fails (failures are
            logged, not propagated).
        """
        try:
            redis_client = await self._get_redis()
            key = self._get_key(session_id)

            compressed_data = await redis_client.get(key)
            if not compressed_data:
                return None

            # Decompress and deserialize
            serialized = gzip.decompress(compressed_data).decode("utf-8")
            session_data = json.loads(serialized)

            return SessionState.from_dict(session_data)

        except Exception as e:
            self.logger.exception(f"Failed to retrieve session {session_id}: {e}")
            return None

    async def delete_session(self, session_id: str) -> bool:
        """Delete session from Redis.

        Returns:
            ``True`` if the session key existed and was removed; ``False`` if
            it did not exist or an error occurred (errors are logged).
        """
        try:
            redis_client = await self._get_redis()

            # Load the session first: its user/project IDs are needed to know
            # which index sets reference it, and they are gone after DELETE.
            session_state = await self.retrieve_session(session_id)

            # Delete session data
            key = self._get_key(session_id)
            deleted_result = await redis_client.delete(key)
            deleted = int(deleted_result) if deleted_result is not None else 0

            # Clean up indexes
            if session_state:
                user_index_key = self._get_index_key(f"user:{session_state.user_id}")
                project_index_key = self._get_index_key(
                    f"project:{session_state.project_id}",
                )

                await redis_client.srem(user_index_key, session_id)
                await redis_client.srem(project_index_key, session_id)

            return deleted > 0

        except Exception as e:
            self.logger.exception(f"Failed to delete session {session_id}: {e}")
            return False

    async def list_sessions(
        self,
        user_id: str | None = None,
        project_id: str | None = None,
    ) -> list[str]:
        """List sessions by user or project.

        With *user_id*, returns the members of that user's index; otherwise
        with *project_id*, the members of that project's index; with neither,
        every stored session ID found by scanning the session keyspace.

        NOTE(review): when both filters are given only *user_id* is applied
        and *project_id* is ignored - confirm callers expect this.

        Returns:
            A list of session IDs; an empty list if an error occurred (the
            error is logged, not propagated).
        """
        try:
            redis_client = await self._get_redis()

            if user_id:
                return await self._list_index_members(
                    redis_client,
                    f"user:{user_id}",
                )

            if project_id:
                return await self._list_index_members(
                    redis_client,
                    f"project:{project_id}",
                )

            # List all sessions (expensive operation: walks the keyspace)
            keys = await self._scan_keys(redis_client, self._get_key("*"))
            return [self._session_id_from_key(key) for key in keys]

        except Exception as e:
            self.logger.exception(f"Failed to list sessions: {e}")
            return []

    async def cleanup_expired_sessions(self) -> int:
        """Clean up expired sessions.

        Redis expires the session keys itself via TTL; this method removes
        index entries whose session key no longer exists.

        Returns:
            The number of index entries removed, or ``0`` if an error occurred
            (the error is logged, not propagated).
        """
        try:
            redis_client = await self._get_redis()
            index_keys = await self._get_index_keys(redis_client)

            cleaned = 0
            for index_key in index_keys:
                cleaned += await self._cleanup_index_key(redis_client, index_key)

            return cleaned

        except Exception as e:
            self.logger.exception(f"Failed to cleanup expired sessions: {e}")
            return 0

    async def _get_index_keys(self, redis_client: Any) -> list[str]:
        """Get all index keys for cleanup."""
        return await self._scan_keys(redis_client, self._get_index_key("*"))

    async def _cleanup_index_key(self, redis_client: Any, index_key: str) -> int:
        """Clean up orphaned sessions from a single index key.

        Returns the number of members removed from the set at *index_key*.
        """
        session_ids = await redis_client.smembers(index_key)
        cleaned = 0

        for session_id in session_ids:
            if await self._is_orphaned_session(redis_client, session_id):
                await redis_client.srem(index_key, session_id)
                cleaned += 1

        return cleaned

    async def _is_orphaned_session(self, redis_client: Any, session_id: Any) -> bool:
        """Check if a session ID refers to an orphaned session.

        A session ID is orphaned when its session key does not exist.
        *session_id* may be ``bytes`` (as returned by the client) or ``str``.
        """
        session_key = self._get_key(self._decode(session_id))
        return not await redis_client.exists(session_key)

    async def is_available(self) -> bool:
        """Check if Redis is available.

        Returns ``True`` if a client can be obtained and answers PING;
        any exception is treated as "not available".
        """
        try:
            redis_client = await self._get_redis()
            await redis_client.ping()
            return True
        except Exception:
            return False