Coverage for agentos/background/task_manager.py: 32%
328 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
Background Task Manager — v1.11.0

Production-grade long-running task execution with:
- Submit task → get task_id → poll progress → retrieve result
- Persistent task state (SQLite/Postgres)
- Progress milestones with phase tracking
- Graceful pause/resume/cancel
- Timeout and resource budget enforcement
- Crash recovery via full checkpoint integration

Usage:
    mgr = BackgroundTaskManager(store=SqliteStore("tasks.db"))
    task_id = await mgr.submit("Analyze 10GB dataset", task="...",
                               loop_factory=my_loop, config=...)
    while True:
        progress = await mgr.get_progress(task_id)
        print(f"{progress.current_phase}: {progress.percent:.0f}%")
        if progress.status.is_terminal:
            break
        await asyncio.sleep(1.0)  # yield to the event loop so the task can run
    result = await mgr.get_result(task_id)
21"""
from __future__ import annotations

import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Coroutine, Optional
35# ── Enums ────────────────────────────────────────────────────────
class BackgroundTaskStatus(str, Enum):
    """States in the life of a background task.

    Members subclass ``str``; their ``.value`` is the lowercase name used
    when a task is serialised.
    """

    # Unfinished states.
    QUEUED = "queued"        # accepted, not started yet
    RUNNING = "running"      # currently executing
    PAUSED = "paused"        # suspended by the user or the system

    # Finished states.
    COMPLETED = "completed"  # ended successfully
    FAILED = "failed"        # ended with an error
    CANCELLED = "cancelled"  # stopped on user request
    TIMED_OUT = "timed_out"  # ran past its time budget

    @property
    def is_terminal(self) -> bool:
        """True for COMPLETED, FAILED, CANCELLED and TIMED_OUT."""
        cls = type(self)
        finished = (cls.COMPLETED, cls.FAILED, cls.CANCELLED, cls.TIMED_OUT)
        return any(self is state for state in finished)

    @property
    def is_active(self) -> bool:
        """True for QUEUED and RUNNING; PAUSED is neither active nor terminal."""
        cls = type(self)
        return self is cls.QUEUED or self is cls.RUNNING
57# ── Data Models ──────────────────────────────────────────────────
@dataclass
class ProgressPhase:
    """One named stage of a task, with its timing and completion state."""

    # Machine-readable identifier, e.g. "data_loading", "analysis", "reporting".
    name: str
    # Display text; to_dict() falls back to `name` when this is empty.
    label: str = ""
    # Relative weight used when overall percent is derived from phases.
    weight: float = 1.0
    # Start / finish times; 0.0 is the "not set" default.
    started_at: float = 0.0
    finished_at: float = 0.0
    completed: bool = False
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return the phase as a plain dict, substituting `name` for an empty label."""
        display = self.label if self.label else self.name
        return dict(
            name=self.name,
            label=display,
            weight=self.weight,
            started_at=self.started_at,
            finished_at=self.finished_at,
            completed=self.completed,
            metadata=self.metadata,
        )

    @classmethod
    def from_dict(cls, d: dict) -> "ProgressPhase":
        """Build a phase from a dict; only ``"name"`` is required."""
        lookup = d.get
        return cls(
            name=d["name"],
            label=lookup("label", ""),
            weight=lookup("weight", 1.0),
            started_at=lookup("started_at", 0.0),
            finished_at=lookup("finished_at", 0.0),
            completed=lookup("completed", False),
            metadata=lookup("metadata", {}),
        )
@dataclass
class TaskProgress:
    """Snapshot of how far a background task has progressed."""

    task_id: str
    status: BackgroundTaskStatus = BackgroundTaskStatus.QUEUED
    phases: list[ProgressPhase] = field(default_factory=list)
    current_phase: str = ""
    current_step: int = 0
    total_steps: int = 0
    percent: float = 0.0
    elapsed_seconds: float = 0.0
    estimated_remaining_seconds: float = 0.0
    last_update: float = 0.0
    message: str = ""

    def to_dict(self) -> dict:
        """Return a plain-dict form: status as its string value, phases as dicts."""
        payload: dict = {"task_id": self.task_id, "status": self.status.value}
        payload["phases"] = [phase.to_dict() for phase in self.phases]
        # Remaining fields are copied through unchanged, in declaration order.
        for key in (
            "current_phase",
            "current_step",
            "total_steps",
            "percent",
            "elapsed_seconds",
            "estimated_remaining_seconds",
            "last_update",
            "message",
        ):
            payload[key] = getattr(self, key)
        return payload

    @classmethod
    def from_dict(cls, d: dict) -> "TaskProgress":
        """Rebuild a progress record from a dict; only ``"task_id"`` is required."""
        lookup = d.get
        return cls(
            task_id=d["task_id"],
            status=BackgroundTaskStatus(lookup("status", "queued")),
            phases=list(map(ProgressPhase.from_dict, lookup("phases", []))),
            current_phase=lookup("current_phase", ""),
            current_step=lookup("current_step", 0),
            total_steps=lookup("total_steps", 0),
            percent=lookup("percent", 0.0),
            elapsed_seconds=lookup("elapsed_seconds", 0.0),
            estimated_remaining_seconds=lookup("estimated_remaining_seconds", 0.0),
            last_update=lookup("last_update", 0.0),
            message=lookup("message", ""),
        )
@dataclass
class BackgroundTaskConfig:
    """Per-task execution settings for a background task."""

    # Time budget in seconds (default: one hour).
    max_duration_seconds: float = 3600.0
    # Spending budget in US dollars.
    max_cost_usd: float = 10.0
    enable_checkpoints: bool = True
    # Number of iterations between checkpoints.
    checkpoint_interval: int = 20
    enable_progress: bool = True
    # Seconds between progress updates.
    progress_report_interval: float = 5.0
    # Resume automatically from a checkpoint after a restart.
    auto_resume: bool = True
    # Retry count for transient failures.
    max_retries: int = 2
    pause_on_cost_warning: bool = True
    notify_on_completion: bool = False
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class BackgroundTask:
    """Complete record of one background task: inputs, state, and outcome."""

    # Random 16-hex-character identifier.
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:16])
    name: str = ""
    task_description: str = ""
    status: BackgroundTaskStatus = BackgroundTaskStatus.QUEUED
    config: BackgroundTaskConfig = field(default_factory=BackgroundTaskConfig)
    progress: TaskProgress = field(default_factory=lambda: TaskProgress(task_id=""))
    result: Any = None
    error: str = ""
    created_at: float = field(default_factory=time.time)
    started_at: float = 0.0
    finished_at: float = 0.0
    cost_usd: float = 0.0
    tokens_used: int = 0
    checkpoint_id: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    # Config attributes written by to_dict() and read back by from_dict().
    # Deliberately un-annotated so the dataclass does not treat it as a field.
    # Every BackgroundTaskConfig field is listed so a round trip is lossless.
    _CONFIG_KEYS = (
        "max_duration_seconds",
        "max_cost_usd",
        "enable_checkpoints",
        "checkpoint_interval",
        "enable_progress",
        "progress_report_interval",
        "auto_resume",
        "max_retries",
        "pause_on_cost_warning",
        "notify_on_completion",
        "metadata",
    )

    def __post_init__(self):
        # The default progress object is created with an empty task_id;
        # bind it to this task once the id is known.
        if not self.progress.task_id:
            self.progress.task_id = self.id

    @property
    def duration_seconds(self) -> float:
        """Seconds from start (or creation, if never started) to finish (or now)."""
        end = self.finished_at or time.time()
        start = self.started_at or self.created_at
        return end - start

    def to_dict(self) -> dict:
        """Return the task as a plain dict with nested config and progress dicts."""
        return {
            "id": self.id,
            "name": self.name,
            "task_description": self.task_description,
            "status": self.status.value,
            "config": {key: getattr(self.config, key) for key in self._CONFIG_KEYS},
            "progress": self.progress.to_dict(),
            "result": self.result,
            "error": self.error,
            "created_at": self.created_at,
            "started_at": self.started_at,
            "finished_at": self.finished_at,
            "cost_usd": self.cost_usd,
            "tokens_used": self.tokens_used,
            "checkpoint_id": self.checkpoint_id,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "BackgroundTask":
        """Rebuild a task from a dict produced by :meth:`to_dict`.

        Only ``"id"`` is required. Missing config keys fall back to the
        ``BackgroundTaskConfig`` defaults, so dicts written before all config
        fields were serialised still load.

        Raises:
            KeyError: if ``"id"`` is missing.
            ValueError: if ``"status"`` is not a known status value.
        """
        cfg_d = d.get("config") or {}
        # Take fallbacks from a default instance rather than repeating literals.
        defaults = BackgroundTaskConfig()
        config = BackgroundTaskConfig(
            **{key: cfg_d.get(key, getattr(defaults, key)) for key in cls._CONFIG_KEYS}
        )
        # Guarantee a task_id even when the stored progress dict lacks one.
        progress_d = {"task_id": d["id"], **(d.get("progress") or {})}
        return cls(
            id=d["id"],
            name=d.get("name", ""),
            task_description=d.get("task_description", ""),
            status=BackgroundTaskStatus(d.get("status", "queued")),
            config=config,
            progress=TaskProgress.from_dict(progress_d),
            result=d.get("result"),
            error=d.get("error", ""),
            created_at=d.get("created_at", 0.0),
            started_at=d.get("started_at", 0.0),
            finished_at=d.get("finished_at", 0.0),
            cost_usd=d.get("cost_usd", 0.0),
            tokens_used=d.get("tokens_used", 0),
            checkpoint_id=d.get("checkpoint_id", ""),
            metadata=d.get("metadata", {}),
        )
212# ── Callback types ───────────────────────────────────────────────
# Progress listener: called with the task's TaskProgress object.
ProgressCallback = Callable[[TaskProgress], None]
# Completion listener: called with the BackgroundTask record.
CompletionCallback = Callable[[BackgroundTask], None]
218# ── Background Task Manager ──────────────────────────────────────
class BackgroundTaskManager:
    """Run, track and control long-running agent tasks in the background.

    Features:
    - Async task submission with per-task configuration
    - Task state kept in memory and mirrored to an optional store
    - Progress tracking with named phases
    - Pause/resume/cancel by task ID
    - Bounded concurrency (at most ``max_concurrent`` tasks execute at once)

    Pausing and the duration limit are cooperative: both are checked inside
    the ``on_iteration`` hook this manager installs on the agent loop, so they
    only take effect if the loop calls that hook.
    """

    _log = logging.getLogger(__name__)

    def __init__(
        self,
        max_concurrent: int = 5,
        store: Any = None,  # Optional CheckpointStore-like persistence
    ):
        """Create a manager.

        Args:
            max_concurrent: Maximum number of tasks executing at the same time.
            store: Optional object with async ``save(key, payload)`` and
                ``load(key)`` methods used to persist task records.
        """
        self.max_concurrent = max_concurrent
        self._store = store
        self._semaphore = asyncio.Semaphore(max_concurrent)
        self._tasks: dict[str, BackgroundTask] = {}
        self._running: dict[str, asyncio.Task] = {}
        self._progress_callbacks: dict[str, list[ProgressCallback]] = {}
        self._completion_callbacks: dict[str, list[CompletionCallback]] = {}

    # ── Public API ───────────────────────────────────────────────

    async def submit(
        self,
        name: str,
        task: str,
        loop_factory: Callable[[], Any] | None = None,
        agent_loop: Any = None,
        config: BackgroundTaskConfig | None = None,
        phases: list[ProgressPhase] | None = None,
    ) -> str:
        """Register a task and start it in the background.

        Args:
            name: Short task name.
            task: Task description handed to the agent loop's ``run``.
            loop_factory: Zero-argument callable returning a fresh agent loop.
            agent_loop: Existing agent loop, used when no factory is given.
            config: Task settings; defaults are used when omitted.
            phases: Optional predefined progress phases.

        Returns:
            The new task's id. If neither ``loop_factory`` nor ``agent_loop``
            is supplied the task is still created and ends up FAILED.
        """
        bt = BackgroundTask(
            name=name,
            task_description=task,
            config=config or BackgroundTaskConfig(),
        )
        if phases:
            bt.progress.phases = phases
            bt.progress.total_steps = len(phases)

        bt.progress.last_update = time.time()
        self._tasks[bt.id] = bt

        if self._store is not None:
            await self._persist(bt)

        # Register the worker before returning so cancel()/get_result() can find it.
        self._running[bt.id] = asyncio.create_task(
            self._run_task(bt, loop_factory, agent_loop)
        )
        return bt.id

    async def get_task(self, task_id: str) -> BackgroundTask | None:
        """Return the task record from memory, else from the store, else None."""
        cached = self._tasks.get(task_id)
        if cached is not None:
            return cached
        if self._store is not None:
            return await self._load(task_id)
        return None

    async def get_progress(self, task_id: str) -> TaskProgress | None:
        """Return the task's progress object, or None for an unknown task."""
        t = await self.get_task(task_id)
        return t.progress if t else None

    async def get_result(self, task_id: str) -> Any:
        """Return the task's result, waiting for it to finish if necessary.

        Raises:
            KeyError: if the task is unknown, or was active but has no
                in-memory record to wait on.
            RuntimeError: if the task finished with status FAILED.
        """
        t = await self.get_task(task_id)
        if not t:
            raise KeyError(f"Task {task_id} not found")
        if not t.status.is_terminal:
            worker = self._running.get(task_id)
            if worker is not None and not worker.done():
                # Shield the worker: cancelling or timing out this caller
                # (e.g. via asyncio.wait_for) must not cancel the task itself.
                await asyncio.shield(worker)
            latest = self._tasks.get(task_id)
            if latest is not None:
                t = latest
            elif t.status.is_active:
                raise KeyError(f"Task {task_id} vanished")
        if t.status == BackgroundTaskStatus.FAILED:
            raise RuntimeError(f"Task {task_id} failed: {t.error}")
        return t.result

    async def pause(self, task_id: str) -> bool:
        """Mark a queued or running task as paused. Returns False if not possible."""
        t = self._tasks.get(task_id)
        if t is None or not t.status.is_active:
            return False
        self._set_status(t, BackgroundTaskStatus.PAUSED)
        await self._update_progress(task_id)
        return True

    async def resume(self, task_id: str) -> bool:
        """Resume a paused task. Returns False if the task is not paused."""
        t = self._tasks.get(task_id)
        if t is None or t.status != BackgroundTaskStatus.PAUSED:
            return False
        # A task paused before it ever started goes back to the queue.
        resumed = (
            BackgroundTaskStatus.RUNNING if t.started_at else BackgroundTaskStatus.QUEUED
        )
        self._set_status(t, resumed)
        await self._update_progress(task_id)
        return True

    async def cancel(self, task_id: str) -> bool:
        """Cancel an unfinished task.

        Returns:
            True if the task was cancelled; False if it is unknown or has
            already reached a terminal state (its record is left untouched).
        """
        t = self._tasks.get(task_id)
        if t is None or t.status.is_terminal:
            return False
        self._set_status(t, BackgroundTaskStatus.CANCELLED)
        # A non-zero finished_at also tells _run_task that finalisation
        # (callbacks, persistence) has been handled here.
        t.finished_at = time.time()
        t.progress.last_update = t.finished_at
        t.progress.elapsed_seconds = t.finished_at - (t.started_at or t.created_at)
        worker = self._running.pop(task_id, None)
        if worker is not None and not worker.done():
            worker.cancel()
        await self._update_progress(task_id)
        await self._notify_completion(task_id)
        if self._store is not None:
            await self._persist(t)
        return True

    async def list_tasks(
        self,
        status: BackgroundTaskStatus | None = None,
        limit: int = 50,
    ) -> list[BackgroundTask]:
        """Return in-memory tasks, newest first, optionally filtered by status."""
        tasks = list(self._tasks.values())
        if status is not None:
            tasks = [t for t in tasks if t.status == status]
        return sorted(tasks, key=lambda t: t.created_at, reverse=True)[:limit]

    def on_progress(self, task_id: str, callback: ProgressCallback):
        """Register a callback invoked with the task's progress on each update."""
        self._progress_callbacks.setdefault(task_id, []).append(callback)

    def on_completion(self, task_id: str, callback: CompletionCallback):
        """Register a callback invoked with the task record when it finishes."""
        self._completion_callbacks.setdefault(task_id, []).append(callback)

    # ── Progress Reporting ───────────────────────────────────────

    async def update_phase(
        self, task_id: str, phase_name: str,
        completed: bool = False, step: int = 0, message: str = "",
    ):
        """Update (or create) a named phase and recompute overall progress.

        Does nothing for unknown tasks or when the task's config has
        ``enable_progress`` switched off.

        Args:
            task_id: Task to update.
            phase_name: Phase to update; appended if it does not exist yet.
            completed: Mark the phase as finished.
            step: Current step; stored when non-zero and used for the
                partial credit of an unfinished phase.
            message: Status message; stored when non-empty.
        """
        t = self._tasks.get(task_id)
        if t is None or not t.config.enable_progress:
            return

        prog = t.progress
        phase = next((p for p in prog.phases if p.name == phase_name), None)
        if phase is None:
            phase = ProgressPhase(name=phase_name, label=phase_name)
            prog.phases.append(phase)
            prog.total_steps = len(prog.phases)

        now = time.time()
        if completed:
            phase.completed = True
            phase.finished_at = now
        elif not phase.started_at:
            phase.started_at = now

        prog.current_phase = phase_name
        if step:
            prog.current_step = step
        if message:
            prog.message = message

        # Percent is the weight of completed phases over the total weight.
        total_weight = sum(p.weight for p in prog.phases)
        completed_weight = sum(p.weight for p in prog.phases if p.completed)
        if prog.current_phase and total_weight > 0:
            if not phase.completed and phase.weight > 0:
                # Partial credit for the current phase: step measured against
                # the checkpoint interval, clamped to the range [0, 1].
                interval = max(t.config.checkpoint_interval, 1)
                fraction = min(max(step, 0) / interval, 1.0)
                completed_weight += phase.weight * fraction
            # Capped below 100; _run_task sets 100.0 on successful completion.
            prog.percent = min(completed_weight / total_weight * 100, 99.9)
        elif completed_weight >= total_weight:
            prog.percent = 100.0

        prog.last_update = now
        elapsed = now - (t.started_at or t.created_at)
        prog.elapsed_seconds = elapsed
        if prog.percent > 0:
            # Linear extrapolation from the fraction done so far.
            prog.estimated_remaining_seconds = elapsed / (prog.percent / 100) - elapsed

        await self._update_progress(task_id)

    # ── Internal ─────────────────────────────────────────────────

    @staticmethod
    def _set_status(bt: BackgroundTask, status: BackgroundTaskStatus) -> None:
        """Set the status on both the task record and its progress object."""
        bt.status = status
        bt.progress.status = status

    @staticmethod
    def _count_tokens(usage: Any) -> int:
        """Total a token-usage value given as a mapping of counts or a number.

        Non-numeric entries are ignored; anything unrecognised counts as 0,
        so odd usage data cannot turn a successful run into a failure.
        """
        if not usage:
            return 0
        if hasattr(usage, "values"):
            return int(sum(v for v in usage.values() if isinstance(v, (int, float))))
        if isinstance(usage, (int, float)):
            return int(usage)
        return 0

    async def _run_task(
        self,
        bt: BackgroundTask,
        loop_factory: Callable[[], Any] | None,
        agent_loop: Any,
    ):
        """Execute one task: wait for a slot, run the loop, record the outcome.

        Cancellation is recorded as CANCELLED and not re-raised, so awaiting
        the worker task completes normally. If ``cancel()`` has already
        finalised the task, its status, callbacks and persistence are left
        alone here so completion callbacks fire only once.
        """
        outcome: BackgroundTaskStatus | None = None
        error = ""
        loop: Any = None
        hook: Any = None
        had_hook = False
        previous_hook: Any = None

        try:
            # The semaphore wait sits inside the try so that cancellation
            # while queued is recorded like any other cancellation.
            async with self._semaphore:
                # Keep a pause requested while queued; the hook below blocks
                # on it at the first iteration.
                if bt.status != BackgroundTaskStatus.PAUSED:
                    self._set_status(bt, BackgroundTaskStatus.RUNNING)
                bt.started_at = time.time()
                await self._update_progress(bt.id)
                if self._store is not None:
                    await self._persist(bt)

                timeout = bt.config.max_duration_seconds
                start = time.time()

                if loop_factory:
                    loop = loop_factory()
                elif agent_loop:
                    loop = agent_loop
                else:
                    raise ValueError("Must provide loop_factory or agent_loop")

                had_hook = hasattr(loop, "on_iteration")
                previous_hook = getattr(loop, "on_iteration", None)

                async def progress_on_iteration(iteration: int, tool_results: list):
                    if time.time() - start > timeout:
                        raise asyncio.TimeoutError("Task exceeded max duration")
                    # Poll until resumed; time spent paused counts toward the budget.
                    while bt.status == BackgroundTaskStatus.PAUSED:
                        await asyncio.sleep(0.5)
                        if time.time() - start > timeout:
                            raise asyncio.TimeoutError("Task timed out while paused")
                    # max(..., 1) avoids dividing by a zero interval.
                    interval = max(bt.config.checkpoint_interval, 1)
                    if bt.config.enable_checkpoints and iteration % interval == 0:
                        bt.progress.current_step = iteration
                        await self.update_phase(bt.id, "execution", step=iteration,
                                                message=f"Step {iteration}")
                    if previous_hook:
                        pending = previous_hook(iteration, tool_results)
                        # The loop's own hook may be async; make sure it runs.
                        if asyncio.iscoroutine(pending) or asyncio.isfuture(pending):
                            await pending

                loop.on_iteration = progress_on_iteration
                hook = progress_on_iteration

                result = await loop.run(bt.task_description, session_id=bt.id)
                bt.result = result.output if hasattr(result, 'output') else result
                bt.cost_usd = getattr(result, 'cost_usd', 0.0)
                bt.tokens_used = self._count_tokens(getattr(result, 'tokens_used', None))
                outcome = BackgroundTaskStatus.COMPLETED

        except asyncio.TimeoutError:
            outcome = BackgroundTaskStatus.TIMED_OUT
            error = f"Exceeded max duration of {bt.config.max_duration_seconds}s"
        except asyncio.CancelledError:
            outcome = BackgroundTaskStatus.CANCELLED
        except Exception as e:
            outcome = BackgroundTaskStatus.FAILED
            # Some exceptions have an empty message; fall back to their repr.
            error = str(e) or repr(e)
        finally:
            # Put the loop's hook back so a reused agent_loop does not keep
            # this task's deadline and pause state. Only undo our own hook.
            if hook is not None and getattr(loop, "on_iteration", None) is hook:
                if had_hook:
                    loop.on_iteration = previous_hook
                else:
                    try:
                        del loop.on_iteration
                    except AttributeError:
                        pass  # attribute already removed; nothing to restore
            self._running.pop(bt.id, None)
            # A non-zero finished_at means cancel() already finalised the task.
            if not bt.finished_at:
                if outcome is not None:
                    self._set_status(bt, outcome)
                if outcome == BackgroundTaskStatus.COMPLETED:
                    bt.progress.percent = 100.0
                if error:
                    bt.error = error
                bt.finished_at = time.time()
                bt.progress.last_update = bt.finished_at
                bt.progress.elapsed_seconds = bt.finished_at - (
                    bt.started_at or bt.created_at
                )
                await self._update_progress(bt.id)
                await self._notify_completion(bt.id)
                if self._store is not None:
                    await self._persist(bt)

    async def _update_progress(self, task_id: str):
        """Call the task's progress callbacks; failures are logged, not raised."""
        callbacks = self._progress_callbacks.get(task_id)
        if not callbacks:
            return
        t = self._tasks.get(task_id)
        if t is None:
            return
        for cb in list(callbacks):
            try:
                cb(t.progress)
            except Exception:
                self._log.exception("Progress callback failed for task %s", task_id)

    async def _notify_completion(self, task_id: str):
        """Call the task's completion callbacks; failures are logged, not raised."""
        callbacks = self._completion_callbacks.get(task_id)
        if not callbacks:
            return
        t = self._tasks.get(task_id)
        if t is None:
            return
        for cb in list(callbacks):
            try:
                cb(t)
            except Exception:
                self._log.exception("Completion callback failed for task %s", task_id)

    async def _persist(self, bt: BackgroundTask):
        """Save the task to the store as JSON under ``bg_task:<id>`` (best effort)."""
        store = self._store
        if store is None or not hasattr(store, 'save'):
            return
        try:
            data = json.dumps(bt.to_dict())
            await store.save(f"bg_task:{bt.id}", {"data": data})
        except Exception:
            # Persistence must not break task execution, but should be visible.
            self._log.warning("Could not persist task %s", bt.id, exc_info=True)

    async def _load(self, task_id: str) -> BackgroundTask | None:
        """Load a task from the store; returns None if missing or unreadable."""
        if self._store is None:
            return None
        try:
            snap = await self._store.load(f"bg_task:{task_id}")
            if snap and "data" in snap:
                return BackgroundTask.from_dict(json.loads(snap["data"]))
        except Exception:
            self._log.debug("Could not load task %s", task_id, exc_info=True)
        return None