Coverage for src / documint_mcp / utils / cache.py: 0%

280 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 22:30 -0400

1"""High-performance caching module with memory management optimisations.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import json 

7import logging 

8import time 

9from collections import OrderedDict 

10from dataclasses import dataclass 

11from sys import getsizeof 

12from types import ModuleType 

13from typing import Any 

14from urllib.parse import urlparse 

15 

16from ..config import CACHE_CONFIG, settings 

17 

18redis: ModuleType | None 

19try: 

20 import redis.asyncio as redis 

21except ImportError: 

22 redis = None 

23 

24logger = logging.getLogger(__name__) 

25 

26 

@dataclass
class CacheEntry:
    """A single cached value plus the bookkeeping needed to manage it."""

    value: Any
    timestamp: float    # creation time (time.time() seconds)
    ttl: int            # lifetime in seconds
    access_count: int = 0
    last_access: float = 0.0

    def is_expired(self) -> bool:
        """Return True once the entry has outlived its TTL."""
        age = time.time() - self.timestamp
        return age > self.ttl

    def update_access(self) -> None:
        """Record one more read of this entry."""
        self.last_access = time.time()
        self.access_count += 1

45 

46 

class MemoryCache:
    """
    In-memory cache with LRU eviction and memory management.

    Performance Optimization #2: Efficient memory usage
    - LRU eviction policy
    - Memory usage monitoring
    - Automatic cleanup of expired entries
    """

    def __init__(self, max_size: int = 1000, default_ttl: int = 3600):
        self.max_size = max_size        # maximum number of entries
        self.default_ttl = default_ttl  # seconds, used when set() gets no ttl
        self._cache: dict[str, CacheEntry] = {}
        # LRU bookkeeping: keys ordered least-recently-used first (values unused).
        self._access_order: OrderedDict[str, None] = OrderedDict()
        self._lock = asyncio.Lock()
        self._cleanup_task: asyncio.Task[None] | None = None
        self._memory_threshold = 100 * 1024 * 1024  # 100MB
        self._cleanup_started = False
        self._hits = 0
        self._misses = 0

    def _start_cleanup_task(self) -> None:
        """Start background cleanup task if not already started."""
        if self._cleanup_started:
            return
        try:
            # Only start cleanup if there's a running event loop
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No event loop running yet, will start on first use
            return

        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = loop.create_task(self._cleanup_loop())
            self._cleanup_started = True

    async def _cleanup_loop(self) -> None:
        """Background cleanup loop to prevent memory leaks."""
        while True:
            try:
                await asyncio.sleep(60)  # Run every minute
                await self._cleanup_expired()
                await self._check_memory_usage()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Fix: do NOT clear _cleanup_started here. The loop keeps
                # running after an unexpected error, so clearing the flag
                # misreported the task as stopped.
                logger.error(f"Cleanup loop error: {e}")

    async def _cleanup_expired(self) -> None:
        """Remove expired entries from cache."""
        async with self._lock:
            expired_keys = [
                key for key, entry in self._cache.items() if entry.is_expired()
            ]

            for key in expired_keys:
                del self._cache[key]
                self._access_order.pop(key, None)

            if expired_keys:
                logger.debug(f"Cleaned up {len(expired_keys)} expired cache entries")

    async def _check_memory_usage(self) -> None:
        """Monitor memory usage and evict entries if needed."""
        async with self._lock:
            estimated_size = self._estimate_memory_usage()

            # Evict ~20% of entries (at least one) when over the threshold.
            if estimated_size > self._memory_threshold and self._cache:
                await self._evict_lru_entries(max(1, int(len(self._cache) * 0.2)))

    async def _evict_lru_entries(self, count: int) -> None:
        """Evict least recently used entries. Caller must hold self._lock."""
        if not self._cache:
            return

        evicted_count = 0
        while evicted_count < count and self._access_order:
            key, _ = self._access_order.popitem(last=False)  # oldest first
            if key in self._cache:
                del self._cache[key]
                evicted_count += 1

        if evicted_count:
            logger.debug(f"Evicted {evicted_count} LRU cache entries")

    def _estimate_memory_usage(self) -> int:
        """Estimate memory usage of the in-memory cache.

        Note: getsizeof() is shallow, so deeply nested values are
        underestimated; this is only a heuristic for the eviction trigger.
        """
        size = getsizeof(self._cache) + getsizeof(self._access_order)
        for key, entry in self._cache.items():
            size += getsizeof(key)
            size += getsizeof(entry)
            size += getsizeof(entry.value)
        for key in self._access_order.keys():
            size += getsizeof(key)
        return size

    async def get(self, key: str) -> Any | None:
        """Get value from cache, or None on a miss or expired entry."""
        # Start cleanup task on first use
        self._start_cleanup_task()

        async with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                self._misses += 1
                return None

            if entry.is_expired():
                # Lazily drop the stale entry and count it as a miss.
                del self._cache[key]
                self._access_order.pop(key, None)
                self._misses += 1
                return None

            entry.update_access()
            # Move the key to the most-recently-used end.
            self._access_order.pop(key, None)
            self._access_order[key] = None
            self._hits += 1
            return entry.value

    async def set(self, key: str, value: Any, ttl: int | None = None) -> None:
        """Set value in cache (ttl defaults to self.default_ttl)."""
        # Start cleanup task on first use
        self._start_cleanup_task()

        async with self._lock:
            # Fix: only evict when inserting a NEW key at capacity.
            # Overwriting an existing key must not push out another entry.
            if key not in self._cache and len(self._cache) >= self.max_size:
                await self._evict_lru_entries(1)

            entry = CacheEntry(
                value=value, timestamp=time.time(), ttl=ttl or self.default_ttl
            )

            self._cache[key] = entry
            self._access_order.pop(key, None)
            self._access_order[key] = None

    async def delete(self, key: str) -> bool:
        """Delete key from cache; return True if it was present."""
        async with self._lock:
            if key in self._cache:
                del self._cache[key]
                self._access_order.pop(key, None)
                return True
            return False

    async def clear(self) -> None:
        """Clear all cache entries and reset hit/miss counters."""
        async with self._lock:
            self._cache.clear()
            self._access_order.clear()
            self._hits = 0
            self._misses = 0

    def get_stats(self) -> dict[str, Any]:
        """Get cache statistics."""
        return {
            "size": len(self._cache),
            "max_size": self.max_size,
            "memory_usage": self._estimate_memory_usage(),
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": self._calculate_hit_rate(),
            "cleanup_active": bool(
                self._cleanup_task and not self._cleanup_task.done()
            ),
        }

    def _calculate_hit_rate(self) -> float:
        """Calculate cache hit rate (0.0 when no requests were made)."""
        total_requests = self._hits + self._misses
        if total_requests == 0:
            return 0.0
        return self._hits / total_requests

    async def close(self) -> None:
        """Close cache and cleanup resources."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass

        await self.clear()
        self._cleanup_started = False

    async def force_cleanup(self) -> None:
        """Run cleanup routines immediately.

        Exposed primarily for unit testing and emergency maintenance tasks where we
        want deterministic cleanup behaviour without waiting for the background
        task interval.
        """
        await self._cleanup_expired()
        await self._check_memory_usage()

246 

247 

248class RedisCache: 

249 """ 

250 Redis-based distributed cache with connection pooling. 

251 

252 Performance Optimization #3: Connection pooling and async operations 

253 - Connection pooling for better performance 

254 - Async operations to prevent blocking 

255 - Automatic reconnection handling 

256 """ 

257 

258 def __init__(self, redis_url: str | None = None) -> None: 

259 self.redis_url = redis_url or settings.redis_url 

260 self.pool: Any | None = None 

261 self.client: Any | None = None 

262 self._connected = False 

263 self._last_error: str | None = None 

264 

265 async def connect(self) -> None: 

266 """Connect to Redis with connection pooling.""" 

267 if redis is None: 

268 logger.warning("Redis not available, falling back to memory cache") 

269 return 

270 

271 try: 

272 self.pool = redis.ConnectionPool.from_url( 

273 self.redis_url, 

274 max_connections=CACHE_CONFIG["max_connections"], 

275 retry_on_timeout=CACHE_CONFIG["retry_on_timeout"], 

276 socket_connect_timeout=CACHE_CONFIG["socket_connect_timeout"], 

277 socket_timeout=CACHE_CONFIG["socket_timeout"], 

278 ) 

279 

280 self.client = redis.Redis(connection_pool=self.pool) 

281 

282 # Test connection 

283 await self.client.ping() 

284 self._connected = True 

285 self._last_error = None 

286 logger.info("Connected to Redis cache") 

287 

288 except Exception as e: 

289 logger.error(f"Failed to connect to Redis: {e}") 

290 self._connected = False 

291 self._last_error = str(e) 

292 

293 async def get(self, key: str) -> Any | None: 

294 """Get value from Redis cache.""" 

295 if not self._connected or not self.client: 

296 return None 

297 

298 try: 

299 value = await self.client.get(key) 

300 if value: 

301 return json.loads(value) 

302 return None 

303 except Exception as e: 

304 logger.error(f"Redis get error: {e}") 

305 self._last_error = str(e) 

306 return None 

307 

308 async def set(self, key: str, value: Any, ttl: int | None = None) -> bool: 

309 """Set value in Redis cache.""" 

310 if not self._connected or not self.client: 

311 return False 

312 

313 try: 

314 serialized = json.dumps(value) 

315 if ttl: 

316 await self.client.setex(key, ttl, serialized) 

317 else: 

318 await self.client.set(key, serialized) 

319 return True 

320 except Exception as e: 

321 logger.error(f"Redis set error: {e}") 

322 self._last_error = str(e) 

323 return False 

324 

325 async def delete(self, key: str) -> bool: 

326 """Delete key from Redis cache.""" 

327 if not self._connected or not self.client: 

328 return False 

329 

330 try: 

331 result = await self.client.delete(key) 

332 return result > 0 

333 except Exception as e: 

334 logger.error(f"Redis delete error: {e}") 

335 self._last_error = str(e) 

336 return False 

337 

338 async def clear(self) -> None: 

339 """Clear all cache entries.""" 

340 if not self._connected or not self.client: 

341 return 

342 

343 try: 

344 await self.client.flushdb() 

345 except Exception as e: 

346 logger.error(f"Redis clear error: {e}") 

347 self._last_error = str(e) 

348 

349 async def close(self) -> None: 

350 """Close Redis connection.""" 

351 if self.client: 

352 await self.client.close() 

353 if self.pool: 

354 await self.pool.disconnect() 

355 self._connected = False 

356 self._last_error = None 

357 

358 def get_stats(self) -> dict[str, Any]: 

359 """Expose Redis cache health information without leaking secrets.""" 

360 

361 redacted_url = self._redact_url(self.redis_url) 

362 stats: dict[str, Any] = {"connected": self._connected, "url": redacted_url} 

363 if self._last_error: 

364 stats["last_error"] = self._last_error 

365 return stats 

366 

367 def _redact_url(self, raw_url: str) -> str: 

368 parsed = urlparse(raw_url) 

369 hostname = parsed.hostname or "localhost" 

370 port = f":{parsed.port}" if parsed.port else "" 

371 path = parsed.path or "" 

372 return f"{parsed.scheme}://{hostname}{port}{path}" 

373 

374 

class CacheManager:
    """
    Multi-level cache manager with automatic fallback.

    Performance Optimization #4: Multi-level caching strategy
    - Memory cache for hot data
    - Redis cache for shared data
    - Automatic fallback between cache levels
    """

    def __init__(self) -> None:
        self.memory_cache = MemoryCache()
        self.redis_cache = RedisCache()
        self._initialized = False

    async def initialize(self) -> None:
        """Connect the backing stores; safe to call repeatedly."""
        if self._initialized:
            return

        await self.redis_cache.connect()
        self._initialized = True

    async def get(self, key: str) -> Any | None:
        """Look up *key*, preferring the fast in-process tier."""
        # Hot path: in-memory copy.
        hit = await self.memory_cache.get(key)
        if hit is not None:
            return hit

        # Fall back to the shared Redis tier.
        hit = await self.redis_cache.get(key)
        if hit is None:
            return None

        # Promote to the memory tier so repeat reads stay local.
        await self.memory_cache.set(key, hit)
        return hit

    async def set(self, key: str, value: Any, ttl: int | None = None) -> None:
        """Write *key* to both cache tiers."""
        await self.memory_cache.set(key, value, ttl)
        await self.redis_cache.set(key, value, ttl)

    async def delete(self, key: str) -> None:
        """Remove *key* from both cache tiers."""
        await self.memory_cache.delete(key)
        await self.redis_cache.delete(key)

    async def clear(self) -> None:
        """Empty both cache tiers."""
        await self.memory_cache.clear()
        await self.redis_cache.clear()

    def get_stats(self) -> dict[str, Any]:
        """Combined statistics from both tiers."""
        return {
            "memory_cache": self.memory_cache.get_stats(),
            "redis_cache": self.redis_cache.get_stats(),
        }

    async def close(self) -> None:
        """Shut down both tiers and mark the manager uninitialized."""
        await self.memory_cache.close()
        await self.redis_cache.close()
        self._initialized = False

444 

445 

# Global cache manager instance shared process-wide; obtain it via get_cache()
# so Redis initialization happens lazily on first use.
cache_manager = CacheManager()

448 

449 

async def get_cache() -> CacheManager:
    """Return the process-wide cache manager, lazily initialized."""
    manager = cache_manager
    if not manager._initialized:
        await manager.initialize()
    return manager