Coverage for agentos/background/task_manager.py: 32%

328 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Background Task Manager — v1.11.0 

3 

4Production-grade long-running task execution with: 

5- Submit task → get task_id → poll progress → retrieve result 

6- Persistent task state (SQLite/Postgres) 

7- Progress milestones with phase tracking 

8- Graceful pause/resume/cancel 

9- Timeout and resource budget enforcement 

10- Crash recovery via full checkpoint integration 

11 

12Usage: 

13 mgr = BackgroundTaskManager(store=SqliteStore("tasks.db")) 

14 task_id = await mgr.submit("Analyze 10GB dataset", task="...", loop_factory=my_loop, config=...) 

15 while True: 

16 progress = await mgr.get_progress(task_id) 

17 print(f"{progress.current_phase}: {progress.percent:.0f}%") 

18 if progress.status.is_terminal: 

19 break 

20 result = await mgr.get_result(task_id) 

21""" 

22 

from __future__ import annotations

import asyncio
import inspect
import json
import logging
import time
import uuid
from dataclasses import asdict, dataclass, field, fields
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Coroutine, Optional

33 

34 

35# ── Enums ──────────────────────────────────────────────────────── 

36 

class BackgroundTaskStatus(str, Enum):
    """Lifecycle states of a background task.

    Members are also ``str`` instances, so each one compares equal to its
    raw value (``BackgroundTaskStatus.QUEUED == "queued"``).
    """

    QUEUED = "queued"        # accepted, waiting to start
    RUNNING = "running"      # actively executing
    PAUSED = "paused"        # paused by user or system
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"        # finished with error
    CANCELLED = "cancelled"  # cancelled by user
    TIMED_OUT = "timed_out"  # exceeded time budget

    @property
    def is_terminal(self) -> bool:
        """Whether this is a final state: completed, failed, cancelled or timed out."""
        final_states = {
            BackgroundTaskStatus.COMPLETED,
            BackgroundTaskStatus.FAILED,
            BackgroundTaskStatus.CANCELLED,
            BackgroundTaskStatus.TIMED_OUT,
        }
        return self in final_states

    @property
    def is_active(self) -> bool:
        """Whether the task is queued or running (a paused task is *not* active)."""
        return (
            self is BackgroundTaskStatus.QUEUED
            or self is BackgroundTaskStatus.RUNNING
        )

55 

56 

57# ── Data Models ────────────────────────────────────────────────── 

58 

@dataclass
class ProgressPhase:
    """A named phase within task execution.

    Instances round-trip through :meth:`to_dict` / :meth:`from_dict`. Both
    directions copy ``metadata`` so a serialized snapshot and the live phase
    never share the same dict.
    """

    name: str                 # machine identifier, e.g. "data_loading", "analysis", "reporting"
    label: str = ""           # human-readable label; to_dict() falls back to ``name`` when empty
    weight: float = 1.0       # relative weight for percent calculation
    started_at: float = 0.0   # 0.0 until a start time is recorded
    finished_at: float = 0.0  # 0.0 until a finish time is recorded
    completed: bool = False
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return a plain-dict snapshot of this phase.

        ``label`` falls back to ``name`` when it is empty. ``metadata`` is a
        shallow copy, so mutating the returned dict does not alter the phase.
        """
        return {
            "name": self.name,
            "label": self.label or self.name,
            "weight": self.weight,
            "started_at": self.started_at,
            "finished_at": self.finished_at,
            "completed": self.completed,
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, d: dict) -> "ProgressPhase":
        """Build a phase from a dict such as the one produced by :meth:`to_dict`.

        Only ``"name"`` is required (``KeyError`` if missing); every other key
        falls back to the field default. A missing or ``None`` ``"metadata"``
        becomes an empty dict, and a supplied one is copied rather than aliased.
        """
        return cls(
            name=d["name"],
            label=d.get("label", ""),
            weight=d.get("weight", 1.0),
            started_at=d.get("started_at", 0.0),
            finished_at=d.get("finished_at", 0.0),
            completed=d.get("completed", False),
            metadata=dict(d.get("metadata") or {}),
        )

83 

84 

@dataclass
class TaskProgress:
    """Structured progress report for a background task.

    Serializable via :meth:`to_dict` and restorable via :meth:`from_dict`.
    """

    task_id: str
    status: BackgroundTaskStatus = BackgroundTaskStatus.QUEUED
    phases: list[ProgressPhase] = field(default_factory=list)
    current_phase: str = ""
    current_step: int = 0
    total_steps: int = 0
    percent: float = 0.0
    elapsed_seconds: float = 0.0
    estimated_remaining_seconds: float = 0.0
    last_update: float = 0.0
    message: str = ""

    def to_dict(self) -> dict:
        """Return a plain-dict form; ``status`` is stored as its string value."""
        payload: dict = {
            "task_id": self.task_id,
            "status": self.status.value,
            "phases": [phase.to_dict() for phase in self.phases],
        }
        # The remaining fields are plain scalars and are copied as-is, in order.
        for attr in (
            "current_phase",
            "current_step",
            "total_steps",
            "percent",
            "elapsed_seconds",
            "estimated_remaining_seconds",
            "last_update",
            "message",
        ):
            payload[attr] = getattr(self, attr)
        return payload

    @classmethod
    def from_dict(cls, d: dict) -> "TaskProgress":
        """Rebuild a progress report from a dict.

        ``"task_id"`` is required (``KeyError`` if missing). Every other key is
        optional and falls back to the field default. An unknown ``"status"``
        string raises ``ValueError`` from the enum lookup.
        """
        read = d.get
        return cls(
            task_id=d["task_id"],
            status=BackgroundTaskStatus(read("status", "queued")),
            phases=list(map(ProgressPhase.from_dict, read("phases", []))),
            current_phase=read("current_phase", ""),
            current_step=read("current_step", 0),
            total_steps=read("total_steps", 0),
            percent=read("percent", 0.0),
            elapsed_seconds=read("elapsed_seconds", 0.0),
            estimated_remaining_seconds=read("estimated_remaining_seconds", 0.0),
            last_update=read("last_update", 0.0),
            message=read("message", ""),
        )

122 

123 

@dataclass
class BackgroundTaskConfig:
    """Per-task settings supplied when a background task is submitted.

    NOTE(review): several fields (cost budget, retries, auto-resume,
    notification flags, report interval) are declared here but are not read
    by the code visible alongside this class; confirm where they are enforced.
    """

    # Wall-clock budget for the task, in seconds (default: one hour).
    max_duration_seconds: float = 3600.0
    # Spending budget in US dollars.
    max_cost_usd: float = 10.0
    # Whether checkpointing is enabled for the task.
    enable_checkpoints: bool = True
    # Number of iterations between checkpoints.
    checkpoint_interval: int = 20
    # Whether progress tracking is enabled for the task.
    enable_progress: bool = True
    # Seconds between progress updates.
    progress_report_interval: float = 5.0
    # Auto-resume from checkpoint on restart.
    auto_resume: bool = True
    # Retries on transient failure.
    max_retries: int = 2
    pause_on_cost_warning: bool = True
    notify_on_completion: bool = False
    # Free-form extra settings; each instance gets its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)

138 

139 

@dataclass
class BackgroundTask:
    """Complete background task record.

    Holds identity, configuration, live progress, outcome and accounting for
    one submitted task, and converts to/from a plain dict for persistence.
    """

    id: str = field(default_factory=lambda: uuid.uuid4().hex[:16])  # 16 hex chars of a uuid4
    name: str = ""
    task_description: str = ""
    status: BackgroundTaskStatus = BackgroundTaskStatus.QUEUED
    config: BackgroundTaskConfig = field(default_factory=BackgroundTaskConfig)
    progress: TaskProgress = field(default_factory=lambda: TaskProgress(task_id=""))
    result: Any = None
    error: str = ""
    created_at: float = field(default_factory=time.time)
    started_at: float = 0.0   # 0.0 until execution starts
    finished_at: float = 0.0  # 0.0 until the task reaches a final state
    cost_usd: float = 0.0
    tokens_used: int = 0
    checkpoint_id: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # The default progress object is created with an empty task_id because
        # the generated id is not known to the factory; link it up here.
        if not self.progress.task_id:
            self.progress.task_id = self.id

    @property
    def duration_seconds(self) -> float:
        """Seconds from start (or creation, if never started) to finish (or now)."""
        end = self.finished_at or time.time()
        start = self.started_at or self.created_at
        return end - start

    def to_dict(self) -> dict:
        """Return a plain-dict form of the task.

        Every ``BackgroundTaskConfig`` field is included so that a
        persist/load round trip preserves the whole configuration. The six
        keys emitted previously are still present.

        NOTE: ``config.metadata`` is now part of the output, so it must be
        serializable by whatever encoder the caller uses (e.g. JSON).
        """
        return {
            "id": self.id,
            "name": self.name,
            "task_description": self.task_description,
            "status": self.status.value,
            "config": asdict(self.config),
            "progress": self.progress.to_dict(),
            "result": self.result,
            "error": self.error,
            "created_at": self.created_at,
            "started_at": self.started_at,
            "finished_at": self.finished_at,
            "cost_usd": self.cost_usd,
            "tokens_used": self.tokens_used,
            "checkpoint_id": self.checkpoint_id,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "BackgroundTask":
        """Rebuild a task from a dict such as the one produced by :meth:`to_dict`.

        ``"id"`` is required (``KeyError`` if missing). Config keys that are
        absent fall back to the ``BackgroundTaskConfig`` defaults, and keys
        that are not config fields are ignored, so older or newer payloads
        still load.
        """
        cfg_d = d.get("config") or {}
        known = {f.name for f in fields(BackgroundTaskConfig)}
        config = BackgroundTaskConfig(
            **{key: value for key, value in cfg_d.items() if key in known}
        )
        return cls(
            id=d["id"],
            name=d.get("name", ""),
            task_description=d.get("task_description", ""),
            status=BackgroundTaskStatus(d.get("status", "queued")),
            config=config,
            progress=TaskProgress.from_dict(d.get("progress", {"task_id": d["id"]})),
            result=d.get("result"),
            error=d.get("error", ""),
            created_at=d.get("created_at", 0.0),
            started_at=d.get("started_at", 0.0),
            finished_at=d.get("finished_at", 0.0),
            cost_usd=d.get("cost_usd", 0.0),
            tokens_used=d.get("tokens_used", 0),
            checkpoint_id=d.get("checkpoint_id", ""),
            metadata=d.get("metadata", {}),
        )

210 

211 

212# ── Callback types ─────────────────────────────────────────────── 

213 

# Signatures of the callbacks accepted by BackgroundTaskManager.on_progress()
# and on_completion(). They are called as plain functions with the task's
# TaskProgress / BackgroundTask, and an exception raised inside a callback
# does not propagate to the code that triggered it.
ProgressCallback = Callable[[TaskProgress], None]
CompletionCallback = Callable[[BackgroundTask], None]

216 

217 

218# ── Background Task Manager ────────────────────────────────────── 

219 

class BackgroundTaskManager:
    """
    Manages long-running background agent tasks.

    Features:
    - Async task submission with configurable budgets
    - Persistent task state (in-memory + optional DB store)
    - Progress tracking with named phases
    - Pause/resume/cancel by task ID
    - Concurrent task execution with configurable max workers

    NOTE(review): checkpoint-based crash recovery is advertised for this
    module, but no replay logic lives in this class; confirm where it is
    implemented.

    Progress and completion callbacks are called synchronously as plain
    functions. An exception raised by a callback is logged and does not
    affect the task.
    """

    _log = logging.getLogger(__name__)

    def __init__(
        self,
        max_concurrent: int = 5,
        store: Any = None,  # Optional CheckpointStore-like persistence
    ):
        """Create a manager.

        Args:
            max_concurrent: Maximum number of tasks executing at once; extra
                submissions stay queued until a slot frees up.
            store: Optional persistence backend. If it has an async
                ``save(key, dict)`` method, task snapshots are written to it;
                its async ``load(key)`` is used to look up unknown task ids.
        """
        self.max_concurrent = max_concurrent
        self._store = store
        self._semaphore = asyncio.Semaphore(max_concurrent)
        self._tasks: dict[str, BackgroundTask] = {}
        self._running: dict[str, asyncio.Task] = {}
        self._progress_callbacks: dict[str, list[ProgressCallback]] = {}
        self._completion_callbacks: dict[str, list[CompletionCallback]] = {}
        # Ids whose end-of-life bookkeeping already ran. Guards against
        # notifying completion twice when cancel() and the runner both finish
        # the same task.
        self._finalized: set[str] = set()

    # ── Public API ───────────────────────────────────────────────

    async def submit(
        self,
        name: str,
        task: str,
        loop_factory: Callable[[], Any] | None = None,
        agent_loop: Any = None,
        config: BackgroundTaskConfig | None = None,
        phases: list[ProgressPhase] | None = None,
    ) -> str:
        """Submit a task for background execution. Returns task_id.

        Args:
            name: Display name of the task.
            task: Task description handed to the loop's ``run()``.
            loop_factory: Zero-argument callable returning a fresh loop object.
                Takes precedence over ``agent_loop``.
            agent_loop: Pre-built loop object to run the task on. Avoid
                sharing one instance between concurrently running tasks: its
                ``on_iteration`` attribute is replaced for the duration of a run.
            config: Task configuration; defaults to ``BackgroundTaskConfig()``.
            phases: Optional pre-declared progress phases.

        If neither ``loop_factory`` nor ``agent_loop`` is given, the call
        still returns an id and the task ends up ``FAILED``.
        """
        bt = BackgroundTask(
            name=name,
            task_description=task,
            config=config or BackgroundTaskConfig(),
        )
        if phases:
            bt.progress.phases = phases
            bt.progress.total_steps = len(phases)

        bt.progress.last_update = time.time()
        self._tasks[bt.id] = bt
        await self._persist(bt)

        # Start in background. The coroutine does not begin executing until
        # the event loop next runs, so the _running entry is in place first.
        self._running[bt.id] = asyncio.create_task(
            self._run_task(bt, loop_factory, agent_loop)
        )
        return bt.id

    async def get_task(self, task_id: str) -> BackgroundTask | None:
        """Get full task record from memory, falling back to the store."""
        known = self._tasks.get(task_id)
        if known is not None:
            return known
        if self._store:
            return await self._load(task_id)
        return None

    async def get_progress(self, task_id: str) -> TaskProgress | None:
        """Get current progress for a task, or None if the task is unknown."""
        t = await self.get_task(task_id)
        return t.progress if t else None

    async def get_result(self, task_id: str) -> Any:
        """Get task result, waiting for the task to finish if it is still in flight.

        Waits whenever the task still has a live runner, which includes
        paused tasks.

        Raises:
            KeyError: The task id is unknown.
            RuntimeError: The task finished with status ``FAILED``.
        """
        t = await self.get_task(task_id)
        if not t:
            raise KeyError(f"Task {task_id} not found")
        runner = self._running.get(task_id)
        if runner is not None and not runner.done():
            # asyncio.wait() neither re-raises the runner's exception nor
            # cancels the runner if *this* caller gets cancelled.
            await asyncio.wait({runner})
            t = self._tasks.get(task_id, t)
        if t.status == BackgroundTaskStatus.FAILED:
            raise RuntimeError(f"Task {task_id} failed: {t.error}")
        return t.result

    async def pause(self, task_id: str) -> bool:
        """Pause a queued or running task. Returns False if it cannot be paused.

        The pause takes effect the next time the loop calls its
        ``on_iteration`` hook.
        """
        t = self._tasks.get(task_id)
        if t is None or not t.status.is_active:
            return False
        self._set_status(t, BackgroundTaskStatus.PAUSED)
        await self._update_progress(task_id)
        return True

    async def resume(self, task_id: str) -> bool:
        """Resume a paused task. Returns False if the task is not paused."""
        t = self._tasks.get(task_id)
        if t is None or t.status != BackgroundTaskStatus.PAUSED:
            return False
        self._set_status(t, BackgroundTaskStatus.RUNNING)
        await self._update_progress(task_id)
        return True

    async def cancel(self, task_id: str) -> bool:
        """Cancel a task.

        Returns False if the task is unknown or has already reached a final
        state. A finished task keeps its outcome and completion callbacks are
        not fired again.
        """
        t = self._tasks.get(task_id)
        if t is None or t.status.is_terminal:
            return False
        # Mark first, so the runner's own handlers see a final state and
        # cannot overwrite it.
        self._set_status(t, BackgroundTaskStatus.CANCELLED)
        runner = self._running.pop(task_id, None)
        if runner is not None and not runner.done():
            runner.cancel()
        await self._finalize(t)
        return True

    async def list_tasks(
        self,
        status: BackgroundTaskStatus | None = None,
        limit: int = 50,
    ) -> list[BackgroundTask]:
        """List in-memory tasks, newest first, optionally filtered by status."""
        tasks = list(self._tasks.values())
        if status is not None:
            tasks = [t for t in tasks if t.status == status]
        return sorted(tasks, key=lambda t: t.created_at, reverse=True)[:limit]

    def on_progress(self, task_id: str, callback: ProgressCallback):
        """Register a progress callback for a task."""
        self._progress_callbacks.setdefault(task_id, []).append(callback)

    def on_completion(self, task_id: str, callback: CompletionCallback):
        """Register a completion callback for a task."""
        self._completion_callbacks.setdefault(task_id, []).append(callback)

    # ── Progress Reporting ───────────────────────────────────────

    async def update_phase(
        self, task_id: str, phase_name: str,
        completed: bool = False, step: int = 0, message: str = "",
    ):
        """Update a named phase in the task progress.

        Creates the phase if it does not exist yet, stamps start/finish
        times, recomputes ``percent`` from phase weights and fires progress
        callbacks. Does nothing for unknown tasks or when the task's
        ``enable_progress`` is off.

        ``percent`` becomes 100.0 once every phase is completed. Until then
        it is capped at 99.9. The unfinished current phase gets partial
        credit of ``step / checkpoint_interval`` of its weight, clamped to
        [0, 1].
        """
        t = self._tasks.get(task_id)
        if t is None or not t.config.enable_progress:
            return

        prog = t.progress
        # Find or create phase
        phase = next((p for p in prog.phases if p.name == phase_name), None)
        if phase is None:
            phase = ProgressPhase(name=phase_name, label=phase_name)
            prog.phases.append(phase)
            prog.total_steps = len(prog.phases)

        now = time.time()
        if completed:
            phase.completed = True
            phase.finished_at = now
        elif not phase.started_at:
            phase.started_at = now

        prog.current_phase = phase_name
        if step:
            prog.current_step = step
        if message:
            prog.message = message

        # Calculate percent from phase weights
        total_weight = sum(p.weight for p in prog.phases)
        done_weight = sum(p.weight for p in prog.phases if p.completed)
        if all(p.completed for p in prog.phases):
            prog.percent = 100.0
        elif total_weight > 0:
            if not phase.completed and phase.weight > 0:
                # Partial credit for current phase
                fraction = min(max(step, 0) / max(t.config.checkpoint_interval, 1), 1.0)
                done_weight += phase.weight * fraction
            prog.percent = min(done_weight / total_weight * 100, 99.9)

        prog.last_update = now
        elapsed = now - (t.started_at or t.created_at)
        prog.elapsed_seconds = elapsed
        if prog.percent > 0:
            # Linear extrapolation: total ≈ elapsed / fraction_done.
            prog.estimated_remaining_seconds = elapsed / (prog.percent / 100) - elapsed

        await self._update_progress(task_id)

    # ── Internal ─────────────────────────────────────────────────

    @staticmethod
    def _set_status(bt: BackgroundTask, status: BackgroundTaskStatus, error: str = "") -> bool:
        """Set ``status`` on the task and its progress; never overwrite a final state.

        Returns True if the status was applied.
        """
        if bt.status.is_terminal:
            return False
        bt.status = status
        bt.progress.status = status
        if error:
            bt.error = error
        return True

    @staticmethod
    def _total_tokens(usage: Any) -> int:
        """Reduce a result's ``tokens_used`` (mapping of counts, number, or None) to a total."""
        if not usage:
            return 0
        if hasattr(usage, "values"):
            return sum(usage.values())
        return usage

    async def _run_task(
        self,
        bt: BackgroundTask,
        loop_factory: Callable[[], Any] | None,
        agent_loop: Any,
    ):
        """Execute a background task with full lifecycle management.

        Waits for a worker slot, runs the task, and always performs the
        end-of-life bookkeeping, including when the task is cancelled while
        still waiting for a slot.
        """
        try:
            async with self._semaphore:
                await self._execute(bt, loop_factory, agent_loop)
        except asyncio.CancelledError:
            # Deliberately not re-raised: the runner ends normally with the
            # task recorded as CANCELLED.
            self._set_status(bt, BackgroundTaskStatus.CANCELLED)
        finally:
            await self._finalize(bt)
            self._running.pop(bt.id, None)

    async def _execute(
        self,
        bt: BackgroundTask,
        loop_factory: Callable[[], Any] | None,
        agent_loop: Any,
    ):
        """Run the task on its loop and record the outcome on ``bt``."""
        # A task paused while still queued stays paused; it then blocks at
        # the first on_iteration call until resumed.
        if bt.status == BackgroundTaskStatus.QUEUED:
            self._set_status(bt, BackgroundTaskStatus.RUNNING)
        bt.started_at = time.time()
        await self._update_progress(bt.id)
        await self._persist(bt)

        loop: Any = None
        hook = None
        previous_hook = None
        had_hook = False
        try:
            # Timeout enforcement
            timeout = bt.config.max_duration_seconds
            start = time.time()

            def out_of_time() -> bool:
                return timeout is not None and time.time() - start > timeout

            if loop_factory:
                loop = loop_factory()
            elif agent_loop:
                loop = agent_loop
            else:
                raise ValueError("Must provide loop_factory or agent_loop")

            # Inject progress callback into the loop
            had_hook = hasattr(loop, "on_iteration")
            previous_hook = getattr(loop, "on_iteration", None)

            async def progress_on_iteration(iteration: int, tool_results: list):
                if out_of_time():
                    raise asyncio.TimeoutError("Task exceeded max duration")
                # Poll for resume; paused time counts against the budget.
                while bt.status == BackgroundTaskStatus.PAUSED:
                    await asyncio.sleep(0.5)
                    if out_of_time():
                        raise asyncio.TimeoutError("Task timed out while paused")
                # max(..., 1) keeps a zero interval from dividing by zero.
                interval = max(bt.config.checkpoint_interval, 1)
                if bt.config.enable_checkpoints and iteration % interval == 0:
                    bt.progress.current_step = iteration
                    await self.update_phase(bt.id, "execution", step=iteration,
                                            message=f"Step {iteration}")
                if previous_hook:
                    outcome = previous_hook(iteration, tool_results)
                    # The previous hook may itself be async.
                    if inspect.isawaitable(outcome):
                        await outcome

            hook = progress_on_iteration
            loop.on_iteration = hook

            # Run. wait_for enforces the budget even if the loop never calls
            # on_iteration.
            result = await asyncio.wait_for(
                loop.run(bt.task_description, session_id=bt.id),
                timeout=timeout,
            )
            bt.result = result.output if hasattr(result, "output") else result
            bt.cost_usd = getattr(result, "cost_usd", 0.0)
            bt.tokens_used = self._total_tokens(getattr(result, "tokens_used", None))
            if self._set_status(bt, BackgroundTaskStatus.COMPLETED):
                bt.progress.percent = 100.0

        except asyncio.TimeoutError:
            self._set_status(
                bt,
                BackgroundTaskStatus.TIMED_OUT,
                f"Exceeded max duration of {bt.config.max_duration_seconds}s",
            )
        except asyncio.CancelledError:
            # Handled by _run_task.
            raise
        except Exception as e:
            self._set_status(bt, BackgroundTaskStatus.FAILED, str(e))
        finally:
            # Put the loop's hook back so a reused agent_loop does not keep a
            # wrapper bound to this task. Skipped if the hook was replaced
            # again in the meantime.
            if hook is not None and getattr(loop, "on_iteration", None) is hook:
                if had_hook:
                    loop.on_iteration = previous_hook
                else:
                    try:
                        del loop.on_iteration
                    except AttributeError:
                        pass

    async def _finalize(self, bt: BackgroundTask):
        """Stamp finish time, fire callbacks and persist, at most once per task."""
        if bt.id in self._finalized:
            return
        self._finalized.add(bt.id)
        bt.finished_at = time.time()
        bt.progress.last_update = bt.finished_at
        # started_at is still 0.0 if the task never left the queue.
        bt.progress.elapsed_seconds = bt.finished_at - (bt.started_at or bt.created_at)
        await self._update_progress(bt.id)
        await self._notify_completion(bt.id)
        await self._persist(bt)

    def _dispatch(self, callbacks: Any, payload: Any, kind: str, task_id: str):
        """Call each callback with ``payload``; log failures instead of propagating them."""
        # Iterate over a copy: a callback may register further callbacks.
        for cb in list(callbacks):
            try:
                cb(payload)
            except Exception:
                self._log.exception("%s callback failed for task %s", kind, task_id)

    async def _update_progress(self, task_id: str):
        """Fire progress callbacks."""
        t = self._tasks.get(task_id)
        if t is not None:
            self._dispatch(self._progress_callbacks.get(task_id, ()), t.progress,
                           "Progress", task_id)

    async def _notify_completion(self, task_id: str):
        """Fire completion callbacks."""
        t = self._tasks.get(task_id)
        if t is not None:
            self._dispatch(self._completion_callbacks.get(task_id, ()), t,
                           "Completion", task_id)

    async def _persist(self, bt: BackgroundTask):
        """Persist task to store (best effort: failures are logged, not raised)."""
        if not self._store or not hasattr(self._store, "save"):
            return
        try:
            data = json.dumps(bt.to_dict())
            await self._store.save(f"bg_task:{bt.id}", {"data": data})
        except Exception:
            self._log.warning("Could not persist background task %s", bt.id, exc_info=True)

    async def _load(self, task_id: str) -> BackgroundTask | None:
        """Load task from store; returns None if absent or unreadable."""
        if not self._store:
            return None
        try:
            snap = await self._store.load(f"bg_task:{task_id}")
            if snap and "data" in snap:
                return BackgroundTask.from_dict(json.loads(snap["data"]))
        except Exception:
            self._log.warning("Could not load background task %s", task_id, exc_info=True)
        return None