Coverage for session_buddy / shutdown_manager.py: 87.65%

144 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-04 00:43 -0800

1"""Graceful shutdown manager for session-mgmt-mcp. 

2 

3Provides signal handling, cleanup task registration, and resource cleanup 

4for clean server shutdown. 

5 

6Phase 10.2: Production Hardening - Graceful Shutdown 

7""" 

8 

9from __future__ import annotations 

10 

11import asyncio 

12import atexit 

13import signal 

14import typing as t 

15from contextlib import suppress 

16from dataclasses import dataclass 

17 

18from session_buddy.utils.logging import get_session_logger 

19 

20 

def _get_logger() -> t.Any:
    """Return the session logger, or a stdlib logger if it is unavailable.

    The lookup happens on every call instead of at import time, so importing
    this module never depends on the session logger being configured yet.
    Any exception raised by ``get_session_logger()`` selects the fallback.
    """
    try:
        logger = get_session_logger()
    except Exception:
        import logging

        logger = logging.getLogger(__name__)
    return logger

29 

30 

31@dataclass 

32class CleanupTask: 

33 """Represents a cleanup task to be executed during shutdown.""" 

34 

35 name: str 

36 """Human-readable name for the cleanup task.""" 

37 

38 callback: t.Callable[[], t.Awaitable[None] | None] 

39 """Cleanup function (sync or async).""" 

40 

41 priority: int = 0 

42 """Priority for execution order (higher = earlier).""" 

43 

44 timeout_seconds: float = 30.0 

45 """Maximum time allowed for cleanup task.""" 

46 

47 critical: bool = False 

48 """Whether failure of this task should stop other cleanups.""" 

49 

50 

@dataclass
class ShutdownStats:
    """Counters and timing collected by a shutdown run.

    Attributes:
        tasks_registered: Number of cleanup tasks registered.
        tasks_executed: Number of tasks that completed without error.
        tasks_failed: Number of tasks that raised an exception.
        tasks_timeout: Number of tasks that exceeded their timeout.
        total_duration_ms: Total shutdown duration, in milliseconds.
    """

    tasks_registered: int = 0
    tasks_executed: int = 0
    tasks_failed: int = 0
    tasks_timeout: int = 0
    total_duration_ms: float = 0.0

69 

70 

class ShutdownManager:
    """Manages graceful shutdown with cleanup task coordination.

    Features:
    - Signal handler registration (SIGTERM, SIGINT, SIGQUIT)
    - Cleanup task registration with priorities
    - Async/sync cleanup task support (a plain callable that returns an
      awaitable is awaited too)
    - Timeout enforcement per task
    - Comprehensive error handling
    - Shutdown statistics tracking

    Example:
        >>> shutdown_mgr = ShutdownManager()
        >>>
        >>> # Register cleanup tasks
        >>> async def cleanup_database():
        ...     await db.close()
        >>>
        >>> shutdown_mgr.register_cleanup(
        ...     "database_cleanup", cleanup_database, priority=100
        ... )
        >>>
        >>> # Setup signal handlers
        >>> shutdown_mgr.setup_signal_handlers()
        >>>
        >>> # Cleanup happens automatically on shutdown

    """

    def __init__(self) -> None:
        """Initialize shutdown manager with no tasks and zeroed statistics."""
        self._cleanup_tasks: list[CleanupTask] = []
        self._shutdown_initiated = False
        # Serializes concurrent shutdown() calls so cleanups run at most once.
        self._shutdown_lock = asyncio.Lock()
        # Signal number -> handler that was installed before ours.
        self._original_handlers: dict[int, t.Any] = {}
        self._stats = ShutdownStats()
        # Strong references to shutdown tasks scheduled from the signal
        # handler. The event loop only keeps weak references to tasks, so an
        # otherwise unreferenced task could be garbage-collected mid-run.
        self._pending_shutdown_tasks: set[asyncio.Task[ShutdownStats]] = set()
        # Ensures the atexit hook is registered only once per manager.
        self._atexit_registered = False

    def register_cleanup(
        self,
        name: str,
        callback: t.Callable[[], t.Awaitable[None] | None],
        priority: int = 0,
        timeout_seconds: float = 30.0,
        critical: bool = False,
    ) -> None:
        """Register a cleanup task to be executed during shutdown.

        Args:
            name: Human-readable name for logging
            callback: Cleanup function (async or sync)
            priority: Execution priority (higher = earlier), default 0
            timeout_seconds: Maximum execution time, default 30s
            critical: If True, failure stops other cleanups, default False

        Example:
            >>> async def close_database():
            ...     await db.close()
            >>>
            >>> shutdown_mgr.register_cleanup(
            ...     "database", close_database, priority=100, critical=True
            ... )

        """
        if self._shutdown_initiated:
            # shutdown() takes a sorted snapshot of the task list once, so a
            # task added now is recorded but will never be executed.
            _get_logger().warning(
                f"Cleanup task registered after shutdown was initiated "
                f"and will not run: {name}",
            )
        task = CleanupTask(
            name=name,
            callback=callback,
            priority=priority,
            timeout_seconds=timeout_seconds,
            critical=critical,
        )
        self._cleanup_tasks.append(task)
        self._stats.tasks_registered += 1
        _get_logger().debug(f"Registered cleanup task: {name} (priority: {priority})")

    def setup_signal_handlers(self) -> None:
        """Setup signal handlers for graceful shutdown.

        Handles:
        - SIGTERM: Graceful termination (e.g., systemd stop)
        - SIGINT: Keyboard interrupt (Ctrl+C)
        - SIGQUIT: Quit signal, where the platform defines it

        Note:
            The handler each signal had before the *first* successful
            installation is saved for ``restore_signal_handlers()``. Calling
            this method again does not overwrite those saved originals, and
            the atexit fallback is registered only once.

        """
        signals_to_handle = [
            (signal.SIGTERM, "SIGTERM"),
            (signal.SIGINT, "SIGINT"),
        ]

        # SIGQUIT only exists on Unix-like platforms.
        if hasattr(signal, "SIGQUIT"):
            signals_to_handle.append((signal.SIGQUIT, "SIGQUIT"))

        for sig, name in signals_to_handle:
            try:
                original = signal.getsignal(sig)
                signal.signal(sig, self._signal_handler)
            except (OSError, ValueError) as e:
                # e.g. ValueError when not called from the main thread.
                _get_logger().warning(f"Could not register handler for {name}: {e}")
                continue
            # setdefault: on a repeated call getsignal() returns our own
            # handler, and saving it would make restore reinstall it.
            self._original_handlers.setdefault(sig, original)
            _get_logger().debug(f"Registered signal handler for {name}")

        # Register atexit handler as final fallback (once).
        if not self._atexit_registered:
            atexit.register(self._atexit_handler)
            self._atexit_registered = True
            _get_logger().debug("Registered atexit handler")

    def restore_signal_handlers(self) -> None:
        """Restore original signal handlers.

        Useful for cleanup or testing. A saved original of ``None`` (a handler
        that was not installed from Python) cannot be reinstalled with
        ``signal.signal()``. For such a signal a warning is logged and the
        current handler is left in place.
        """
        for sig, original in self._original_handlers.items():
            if original is None:
                _get_logger().warning(
                    f"Could not restore signal {sig}: "
                    "original handler was not installed from Python",
                )
                continue
            try:
                signal.signal(sig, original)
            except (OSError, ValueError, TypeError) as e:
                _get_logger().warning(f"Could not restore signal {sig}: {e}")

        self._original_handlers.clear()
        _get_logger().debug("Restored original signal handlers")

    def _signal_handler(self, signum: int, frame: t.Any) -> None:
        """Internal signal handler that triggers shutdown.

        If an asyncio loop is running in this thread, shutdown is scheduled
        on it as a task and this handler returns immediately. Otherwise
        ``asyncio.run()`` drives shutdown to completion before returning.

        Args:
            signum: Signal number
            frame: Current stack frame (unused)

        """
        sig_name = signal.Signals(signum).name
        _get_logger().info(f"Received signal {sig_name}, initiating graceful shutdown")

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop, run in new loop
            asyncio.run(self.shutdown())
            return

        shutdown_task = loop.create_task(self.shutdown())
        # Keep a strong reference until the task finishes (see __init__).
        self._pending_shutdown_tasks.add(shutdown_task)
        shutdown_task.add_done_callback(self._pending_shutdown_tasks.discard)
        # A handler installed via signal.signal() does not wake a loop that
        # is blocked waiting for I/O. Queue a no-op through the thread-safe
        # path, which wakes the loop, so the new task starts promptly.
        loop.call_soon_threadsafe(lambda: None)

    def _atexit_handler(self) -> None:
        """Final cleanup handler registered with atexit.

        Runs shutdown only if it was not already initiated (e.g. by a
        signal). A ``RuntimeError`` from ``asyncio.run()`` is suppressed.
        """
        if not self._shutdown_initiated:
            _get_logger().info("atexit handler triggered, running final cleanup")
            with suppress(RuntimeError):
                asyncio.run(self.shutdown())

    async def _execute_cleanup_task(self, task: CleanupTask) -> None:
        """Execute a single cleanup task with timeout enforcement.

        Coroutine functions are awaited directly. Any other callable runs in
        the loop's default executor so it cannot block the loop. If that
        callable returns an awaitable (for example a lambda or
        ``functools.partial`` wrapping a coroutine function), the awaitable is
        awaited as well. ``task.timeout_seconds`` bounds the whole sequence.
        On timeout the executor thread is not interrupted: only the wait is
        abandoned.

        Args:
            task: Cleanup task to execute

        Raises:
            TimeoutError: If task exceeds timeout
            Exception: If task execution fails

        """
        import inspect

        _get_logger().debug(
            f"Executing cleanup task: {task.name} "
            f"(priority: {task.priority}, timeout: {task.timeout_seconds}s)",
        )

        callback = task.callback

        async def _invoke() -> None:
            if inspect.iscoroutinefunction(callback):
                await callback()
                return
            # Sync function - run in executor
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(None, callback)
            # Don't drop a coroutine/awaitable returned by a plain callable.
            if inspect.isawaitable(result):
                await result

        await asyncio.wait_for(_invoke(), timeout=task.timeout_seconds)

    def _should_stop_cleanup(self, task: CleanupTask) -> bool:
        """Return True (and log it) when a failed task is critical."""
        if not task.critical:
            return False
        _get_logger().critical(
            f"Critical task failed: {task.name}, stopping cleanup",
        )
        return True

    def _handle_task_timeout(self, task: CleanupTask) -> bool:
        """Handle cleanup task timeout.

        Args:
            task: Task that timed out

        Returns:
            True if should stop cleanup (critical task), False otherwise

        """
        self._stats.tasks_timeout += 1
        _get_logger().error(
            f"Cleanup task timed out after {task.timeout_seconds}s: {task.name}",
        )
        return self._should_stop_cleanup(task)

    def _handle_task_failure(self, task: CleanupTask, error: Exception) -> bool:
        """Handle cleanup task failure.

        Must be called from inside the ``except`` block so ``exc_info=True``
        captures the active traceback.

        Args:
            task: Task that failed
            error: Exception that occurred

        Returns:
            True if should stop cleanup (critical task), False otherwise

        """
        self._stats.tasks_failed += 1
        _get_logger().error(
            f"Cleanup task failed: {task.name} - {error}",
            exc_info=True,
        )
        return self._should_stop_cleanup(task)

    def _finalize_shutdown(
        self,
        sorted_tasks: list[CleanupTask],
        start_time: float,
    ) -> None:
        """Finalize shutdown and log results.

        Args:
            sorted_tasks: List of tasks that were scheduled for execution
            start_time: When shutdown started (from time.perf_counter())

        """
        import time

        # Calculate total duration
        self._stats.total_duration_ms = (time.perf_counter() - start_time) * 1000

        _get_logger().info(
            f"Shutdown complete: {self._stats.tasks_executed}/{len(sorted_tasks)} tasks succeeded "
            f"in {self._stats.total_duration_ms:.2f}ms",
        )

        if self._stats.tasks_failed > 0 or self._stats.tasks_timeout > 0:
            _get_logger().warning(
                f"Shutdown had issues: {self._stats.tasks_failed} failed, "
                f"{self._stats.tasks_timeout} timed out",
            )

    async def shutdown(self) -> ShutdownStats:
        """Execute all cleanup tasks in priority order.

        Returns:
            ShutdownStats with execution details

        Features:
        - Executes tasks by priority (highest first; equal priorities keep
          registration order)
        - Enforces per-task timeouts
        - Handles both async and sync cleanup functions
        - Continues on non-critical failures
        - Tracks comprehensive statistics
        - Only the first call runs the tasks; later calls return the
          existing stats

        """
        import time

        start_time = time.perf_counter()

        # Prevent multiple simultaneous shutdowns
        async with self._shutdown_lock:
            if self._shutdown_initiated:
                _get_logger().debug("Shutdown already initiated, skipping")
                return self._stats

            self._shutdown_initiated = True
            _get_logger().info(
                f"Starting graceful shutdown with {len(self._cleanup_tasks)} tasks",
            )

            # Highest priority first; sorted() stays stable with reverse=True.
            sorted_tasks = sorted(
                self._cleanup_tasks,
                key=lambda task: task.priority,
                reverse=True,
            )

            for task in sorted_tasks:
                try:
                    await self._execute_cleanup_task(task)
                except TimeoutError:
                    if self._handle_task_timeout(task):
                        break
                except Exception as e:
                    if self._handle_task_failure(task, e):
                        break
                else:
                    self._stats.tasks_executed += 1
                    _get_logger().debug(f"Cleanup task completed: {task.name}")

            self._finalize_shutdown(sorted_tasks, start_time)
            return self._stats

    def get_stats(self) -> ShutdownStats:
        """Get current shutdown statistics.

        Returns:
            The live ShutdownStats object (not a copy)

        """
        return self._stats

    def is_shutdown_initiated(self) -> bool:
        """Check if shutdown has been initiated.

        Returns:
            True if shutdown is in progress or complete

        """
        return self._shutdown_initiated

395 

396 

397# Global shutdown manager instance 

398_global_shutdown_manager: ShutdownManager | None = None 

399 

400 

401def get_shutdown_manager() -> ShutdownManager: 

402 """Get the global shutdown manager instance. 

403 

404 Returns: 

405 Global ShutdownManager singleton 

406 

407 Example: 

408 >>> from session_buddy.shutdown_manager import get_shutdown_manager 

409 >>> 

410 >>> shutdown_mgr = get_shutdown_manager() 

411 >>> shutdown_mgr.register_cleanup("my_cleanup", cleanup_func) 

412 

413 """ 

414 global _global_shutdown_manager 

415 

416 if _global_shutdown_manager is None: 

417 _global_shutdown_manager = ShutdownManager() 

418 

419 return _global_shutdown_manager 

420 

421 

# Names exported by ``from session_buddy.shutdown_manager import *``.
__all__ = ["CleanupTask", "ShutdownManager", "ShutdownStats", "get_shutdown_manager"]