Coverage for agentos/background/supervisor.py: 35%

283 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Agent Supervision Tree — v1.11.0 

3 

4Resource-bounded agent hierarchy with monitoring, quotas, and auto-recovery. 

5Inspired by OS process supervision (systemd, supervisord). 

6 

7Features: 

8- Hierarchical supervision tree: parent monitors children 

9- Resource quotas per agent (time, cost, tokens, concurrency) 

10- Heartbeat-based health monitoring 

11- Auto-kill for runaway agents exceeding quotas 

12- Graceful degradation: kill child, preserve parent 

13- Aggregate progress across the tree 

14- Supervision events for external monitoring 

15 

16Usage: 

17 sup = AgentSupervisor() 

18 child_id = await sup.spawn( 

19 name="data_analyzer", 

20 loop_factory=lambda: AgentLoop(...), 

21 quotas=AgentQuota(max_duration_seconds=600, max_cost_usd=2.0), 

22 ) 

23 result = await sup.await_child(child_id, timeout=600) 

24""" 

25 

26from __future__ import annotations 

27 

28import asyncio 

29import time 

30import uuid 

31from dataclasses import dataclass, field 

32from enum import Enum 

33from typing import Any, Callable, Coroutine, Optional 

34 

35 

36# ── Enums ──────────────────────────────────────────────────────── 

37 

class SupervisionEventType(str, Enum):
    """Kinds of events the supervisor emits about its children.

    Mixes in ``str`` so each member compares equal to its plain string value.
    """

    # Lifecycle
    SPAWNED = "spawned"
    STARTED = "started"
    # Liveness / progress reports
    HEARTBEAT = "heartbeat"
    PROGRESS = "progress"
    # Quota and health problems
    QUOTA_WARNING = "quota_warning"    # approaching a quota limit
    QUOTA_EXCEEDED = "quota_exceeded"  # quota hit; the child is killed
    HEARTBEAT_LOST = "heartbeat_lost"  # child stopped responding
    # Terminal outcomes
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
    KILLED = "killed"                  # terminated by the supervisor

51 

52 

53@dataclass 

54class SupervisionEvent: 

55 """Event emitted by the supervision tree.""" 

56 type: SupervisionEventType 

57 child_id: str 

58 child_name: str 

59 timestamp: float = field(default_factory=time.time) 

60 data: dict[str, Any] = field(default_factory=dict) 

61 message: str = "" 

62 

63 def to_dict(self) -> dict: 

64 return { 

65 "type": self.type.value, "child_id": self.child_id, 

66 "child_name": self.child_name, "timestamp": self.timestamp, 

67 "data": self.data, "message": self.message, 

68 } 

69 

70 

71# ── Data Models ────────────────────────────────────────────────── 

72 

73@dataclass 

74class AgentQuota: 

75 """Resource limits for a supervised agent.""" 

76 max_duration_seconds: float = 3600.0 # Wall-clock time budget 

77 max_cost_usd: float = 10.0 # Cost budget 

78 max_tokens: int = 1_000_000 # Token budget 

79 max_iterations: int = 500 # Max loop iterations 

80 heartbeat_interval: float = 10.0 # Seconds between heartbeats 

81 heartbeat_timeout: float = 30.0 # Seconds before considered dead 

82 max_retries: int = 0 # Auto-restart on failure (0=no restart) 

83 retry_delay: float = 5.0 # Delay before restart 

84 cooldown_period: float = 60.0 # Rate limit on restarts 

85 

86 

87@dataclass 

88class AgentQuotaUsage: 

89 """Current resource consumption of a supervised agent.""" 

90 elapsed_seconds: float = 0.0 

91 cost_usd: float = 0.0 

92 tokens_used: int = 0 

93 iterations: int = 0 

94 heartbeats_received: int = 0 

95 last_heartbeat: float = 0.0 

96 restarts: int = 0 

97 last_restart: float = 0.0 

98 

99 @property 

100 def duration_percent(self) -> float: 

101 return 0.0 # Set externally with quota context 

102 

103 @property 

104 def cost_percent(self) -> float: 

105 return 0.0 

106 

107 def to_dict(self) -> dict: 

108 return { 

109 "elapsed_seconds": self.elapsed_seconds, 

110 "cost_usd": self.cost_usd, "tokens_used": self.tokens_used, 

111 "iterations": self.iterations, "heartbeats_received": self.heartbeats_received, 

112 "last_heartbeat": self.last_heartbeat, "restarts": self.restarts, 

113 } 

114 

115 

116@dataclass 

117class SupervisedAgent: 

118 """An agent running under supervision.""" 

119 id: str = field(default_factory=lambda: uuid.uuid4().hex[:12]) 

120 name: str = "" 

121 quotas: AgentQuota = field(default_factory=AgentQuota) 

122 usage: AgentQuotaUsage = field(default_factory=AgentQuotaUsage) 

123 status: str = "pending" # pending/running/paused/completed/failed/killed 

124 started_at: float = 0.0 

125 finished_at: float = 0.0 

126 result: Any = None 

127 error: str = "" 

128 metadata: dict[str, Any] = field(default_factory=dict) 

129 

130 # Internal 

131 _task: asyncio.Task | None = field(default=None, repr=False) 

132 _heartbeat_task: asyncio.Task | None = field(default=None, repr=False) 

133 _pause_event: asyncio.Event | None = field(default=None, repr=False) 

134 _kill_event: asyncio.Event | None = field(default=None, repr=False) 

135 

136 @property 

137 def is_alive(self) -> bool: 

138 return self.status in ("running", "paused") 

139 

140 @property 

141 def duration_seconds(self) -> float: 

142 end = self.finished_at or time.time() 

143 return end - self.started_at if self.started_at else 0.0 

144 

145 def to_dict(self) -> dict: 

146 return { 

147 "id": self.id, "name": self.name, 

148 "quotas": { 

149 "max_duration_seconds": self.quotas.max_duration_seconds, 

150 "max_cost_usd": self.quotas.max_cost_usd, 

151 "max_tokens": self.quotas.max_tokens, 

152 "max_iterations": self.quotas.max_iterations, 

153 }, 

154 "usage": self.usage.to_dict(), 

155 "status": self.status, "started_at": self.started_at, 

156 "finished_at": self.finished_at, "error": self.error, 

157 "metadata": self.metadata, 

158 } 

159 

160 

161@dataclass 

162class SupervisorConfig: 

163 """Global supervisor configuration.""" 

164 max_children: int = 20 

165 monitor_interval: float = 1.0 # Seconds between health checks 

166 event_history_size: int = 500 # Max events to retain 

167 auto_kill_on_quota: bool = True 

168 log_events: bool = True 

169 

170 

# ── Callback types ───────────────────────────────────────────────

# Observer called synchronously (not awaited) with every SupervisionEvent the
# supervisor emits; its return value is ignored.
EventCallback = Callable[[SupervisionEvent], None]

174 

175 

176# ── Agent Supervisor ───────────────────────────────────────────── 

177 

class AgentSupervisor:
    """
    Hierarchical supervision tree for long-running multi-agent tasks.

    Monitors children for:
    - Wall-clock duration, checked by a background monitor task and by a hook
      installed as the child loop's ``on_iteration``
    - Heartbeat loss, for children spawned with an ``on_heartbeat`` callback

    Actions:
    - Kill children that exceed their duration quota (when
      ``SupervisorConfig.auto_kill_on_quota`` is set) or lose their heartbeat
    - Restart failed runs up to ``AgentQuota.max_retries`` times
    - Emit ``SupervisionEvent``s to an optional callback and a bounded history

    Cost and token usage are recorded when a run completes; ``max_cost_usd``,
    ``max_tokens`` and ``max_iterations`` are not enforced here.
    """

    class _QuotaExceeded(asyncio.TimeoutError):
        """Raised from the iteration hook when a child exceeds its duration quota.

        Derives from ``asyncio.TimeoutError`` so agent loops that catch
        timeouts still recognize it.
        """

    def __init__(
        self,
        config: SupervisorConfig | None = None,
        on_event: EventCallback | None = None,
    ):
        self.config = config if config is not None else SupervisorConfig()
        self._on_event = on_event
        self._children: dict[str, SupervisedAgent] = {}
        self._events: list[SupervisionEvent] = []
        self._monitor_task: asyncio.Task | None = None
        self._lock = asyncio.Lock()
        # Children that already received the one-off 90%-duration warning.
        self._duration_warned: set[str] = set()
        # Set by shutdown(); spawn() refuses new children afterwards.
        self._shutting_down = False

    # ── Public API ───────────────────────────────────────────────

    async def spawn(
        self,
        name: str,
        task: str = "",
        loop_factory: Callable[[], Any] | None = None,
        agent_loop: Any = None,
        quotas: AgentQuota | None = None,
        on_heartbeat: Callable[[], Coroutine] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """
        Spawn a new child agent under supervision.

        Args:
            name: Label carried on every event for this child.
            task: Passed as the first argument to the loop's ``run()``.
            loop_factory: Zero-argument callable building the agent loop; it is
                called again for each restart so every attempt gets a fresh loop.
            agent_loop: Pre-built loop, used when no ``loop_factory`` is given
                (reused across restarts).
            quotas: Resource limits; defaults to ``AgentQuota()``.
            on_heartbeat: Async callback awaited every
                ``quotas.heartbeat_interval`` seconds; enables heartbeat-loss
                detection for this child.
            metadata: Free-form data stored on the child record.

        Returns:
            The child id, for use with the other control methods.

        Raises:
            RuntimeError: If the supervisor is shutting down, or once
                ``config.max_children`` children have been spawned (finished
                children still count towards the limit).
        """
        async with self._lock:
            if self._shutting_down:
                raise RuntimeError("Supervisor is shutting down")
            if len(self._children) >= self.config.max_children:
                raise RuntimeError(f"Max children ({self.config.max_children}) reached")

            child = SupervisedAgent(
                name=name, quotas=quotas or AgentQuota(),
                metadata=metadata or {},
            )
            child._pause_event = asyncio.Event()
            child._pause_event.set()  # set == not paused
            child._kill_event = asyncio.Event()
            # Mark the child live *before* any task is scheduled: the heartbeat
            # loop and kill_child() both gate on ``is_alive``, which is False
            # for "pending" children.
            child.status = "running"
            child.started_at = time.time()

            self._children[child.id] = child
            self._emit(SupervisionEvent(SupervisionEventType.SPAWNED, child.id, name))

            if self._monitor_task is None or self._monitor_task.done():
                self._monitor_task = asyncio.create_task(self._monitor_loop())

            child._task = asyncio.create_task(
                self._run_child(child, task, loop_factory, agent_loop)
            )
            if on_heartbeat:
                child._heartbeat_task = asyncio.create_task(
                    self._heartbeat_loop(child, on_heartbeat)
                )

        return child.id

    async def get_child(self, child_id: str) -> SupervisedAgent | None:
        """Return the child record for ``child_id``, or None if unknown."""
        return self._children.get(child_id)

    async def list_children(self) -> list[SupervisedAgent]:
        """Return every child ever spawned, live or finished."""
        return list(self._children.values())

    async def await_child(self, child_id: str, timeout: float | None = None) -> Any:
        """
        Wait for a child to finish and return its result.

        Returns ``None`` for children that failed or were killed. Cancelling
        the awaiting coroutine does not cancel the child.

        Raises:
            KeyError: Unknown ``child_id``.
            RuntimeError: The child has no task.
            asyncio.TimeoutError: ``timeout`` elapsed; the child is killed first.
        """
        child = self._children.get(child_id)
        if child is None:
            raise KeyError(f"Child {child_id} not found")
        task = child._task
        if task is None:
            raise RuntimeError(f"Child {child_id} has no running task")

        # asyncio.wait() neither cancels the task on timeout nor re-raises its
        # outcome: a timeout goes through kill_child() (recording the reason
        # and emitting KILLED), and a child cancelled before it ever ran does
        # not surface as a CancelledError in the caller.
        done, _ = await asyncio.wait({task}, timeout=timeout)
        if not done:
            await self.kill_child(child_id, reason="await timeout")
            raise asyncio.TimeoutError(f"Child {child_id} did not finish within {timeout}s")
        if task.cancelled():
            return None
        return task.result()

    async def pause_child(self, child_id: str) -> bool:
        """Pause a live child at its next iteration hook; False if not possible."""
        child = self._children.get(child_id)
        if child is None or not child.is_alive or child._pause_event is None:
            return False
        child._pause_event.clear()
        child.status = "paused"
        self._emit(SupervisionEvent(SupervisionEventType.PROGRESS, child.id, child.name,
                                    message="Paused"))
        return True

    async def resume_child(self, child_id: str) -> bool:
        """Resume a paused child; False if it is not paused."""
        child = self._children.get(child_id)
        if child is None or child.status != "paused" or child._pause_event is None:
            return False
        child._pause_event.set()
        child.status = "running"
        self._emit(SupervisionEvent(SupervisionEventType.PROGRESS, child.id, child.name,
                                    message="Resumed"))
        return True

    async def kill_child(self, child_id: str, reason: str = "") -> bool:
        """Force-kill a live child and stop its heartbeat; False if not alive."""
        child = self._children.get(child_id)
        if child is None or not child.is_alive:
            return False

        if child._kill_event is not None:
            child._kill_event.set()
        # Record the outcome before cancelling so the task's own cancellation
        # handler sees "killed" and keeps this reason.
        child.status = "killed"
        child.finished_at = time.time()
        child.error = reason
        for handle in (child._task, child._heartbeat_task):
            if handle is not None and not handle.done():
                handle.cancel()

        self._emit(SupervisionEvent(
            SupervisionEventType.KILLED, child.id, child.name,
            message=reason,
        ))
        return True

    async def aggregate_progress(self) -> dict[str, Any]:
        """Aggregate status counts, cost and tokens across all children."""
        children = list(self._children.values())
        total = len(children)
        completed = sum(1 for c in children if c.status == "completed")
        failed = sum(1 for c in children if c.status in ("failed", "killed"))
        running = sum(1 for c in children if c.status == "running")
        paused = sum(1 for c in children if c.status == "paused")

        total_cost = sum(c.usage.cost_usd for c in children)
        total_tokens = sum(c.usage.tokens_used for c in children)

        return {
            "total_children": total,
            "completed": completed,
            "failed": failed,
            "running": running,
            "paused": paused,
            "total_cost_usd": total_cost,
            "total_tokens": total_tokens,
            "percent_complete": (completed / total * 100) if total > 0 else 0,
            "children": [c.to_dict() for c in children],
        }

    async def shutdown(self, timeout: float = 10.0):
        """Graceful shutdown: refuse new spawns, kill live children, wait for tasks.

        Waits up to ``timeout`` seconds in total for child, heartbeat and
        monitor tasks; any still running afterwards remain cancelled but are
        not awaited.
        """
        self._shutting_down = True
        for child_id in list(self._children):
            await self.kill_child(child_id, reason="supervisor shutdown")

        monitor = self._monitor_task
        if monitor is not None and not monitor.done():
            monitor.cancel()

        current = asyncio.current_task()
        pending = [
            handle
            for c in self._children.values()
            for handle in (c._task, c._heartbeat_task)
            if handle is not None and not handle.done() and handle is not current
        ]
        if monitor is not None and not monitor.done() and monitor is not current:
            pending.append(monitor)
        if pending:
            await asyncio.wait(pending, timeout=max(0.0, timeout))

    # ── Internal ─────────────────────────────────────────────────

    async def _run_child(
        self, child: SupervisedAgent, task: str,
        loop_factory: Callable[[], Any] | None,
        agent_loop: Any,
    ):
        """Run a child's agent loop with retries, recording the outcome on ``child``.

        Returns ``child.result`` on success. Failures, kills and quota
        violations are recorded on ``child`` (and emitted) instead of raised,
        and the coroutine then returns ``None``.
        """
        # spawn() normally marks the child running already; cover direct callers.
        if child.status == "pending":
            child.status = "running"
        if not child.started_at:
            child.started_at = time.time()
        self._emit(SupervisionEvent(SupervisionEventType.STARTED, child.id, child.name))

        try:
            if not loop_factory and not agent_loop:
                raise ValueError("Must provide loop_factory or agent_loop")

            max_retries = max(child.quotas.max_retries, 0)
            hooked_loop = None
            for attempt in range(max_retries + 1):
                # A factory yields a fresh loop per attempt; a pre-built loop is reused.
                loop = loop_factory() if loop_factory else agent_loop
                if loop is not hooked_loop:
                    # Hook each loop object once so a reused loop is not wrapped
                    # in nested supervision hooks on every retry.
                    self._install_iteration_hook(child, loop)
                    hooked_loop = loop
                try:
                    result = await loop.run(task, session_id=child.id)
                except (asyncio.CancelledError, asyncio.TimeoutError):
                    raise  # kills, quota violations and timeouts are never retried
                except Exception:
                    if attempt < max_retries:
                        await self._wait_before_restart(child)
                        continue
                    raise  # reported once, by the outer handler

                self._record_success(child, result)
                return child.result

        except self._QuotaExceeded as exc:
            child.status = "killed"
            child.finished_at = time.time()
            child.error = str(exc)
            self._emit(SupervisionEvent(
                SupervisionEventType.QUOTA_EXCEEDED, child.id, child.name,
                message=str(exc),
            ))
            self._emit(SupervisionEvent(
                SupervisionEventType.KILLED, child.id, child.name,
                message=str(exc),
            ))
        except asyncio.CancelledError:
            # kill_child() normally recorded status/reason/event already; only
            # fill in what is missing when the cancellation came from elsewhere.
            child.status = "killed"
            if not child.finished_at:
                child.finished_at = time.time()
        except Exception as e:
            child.status = "failed"
            child.finished_at = time.time()
            child.error = str(e)
            self._emit(SupervisionEvent(
                SupervisionEventType.FAILED, child.id, child.name,
                message=str(e),
            ))
        finally:
            # The heartbeat has nothing left to report once the run is over.
            heartbeat = child._heartbeat_task
            if heartbeat is not None and not heartbeat.done():
                heartbeat.cancel()
        return None

    def _install_iteration_hook(self, child: SupervisedAgent, loop: Any) -> None:
        """Wrap ``loop.on_iteration`` with kill/pause handling and the duration quota."""
        import inspect  # local import: only this hook needs it

        original = getattr(loop, "on_iteration", None)

        async def supervised_on_iteration(iteration: int, tool_results: list):
            if child._kill_event is not None and child._kill_event.is_set():
                raise asyncio.CancelledError("Killed by supervisor")
            # Blocks here while the child is paused (event cleared).
            if child._pause_event is not None:
                await child._pause_event.wait()

            child.usage.iterations = iteration
            child.usage.elapsed_seconds = time.time() - child.started_at

            limit = child.quotas.max_duration_seconds
            if child.usage.elapsed_seconds > limit:
                if self.config.auto_kill_on_quota:
                    if child._kill_event is not None:
                        child._kill_event.set()
                    raise self._QuotaExceeded("Duration quota exceeded")
                self._emit(SupervisionEvent(
                    SupervisionEventType.QUOTA_WARNING, child.id, child.name,
                    message=f"Duration at {child.usage.elapsed_seconds:.0f}s / {limit}s",
                ))

            if callable(original):
                outcome = original(iteration, tool_results)
                # Support both plain and async hooks; async ones must be awaited.
                if inspect.isawaitable(outcome):
                    await outcome

        loop.on_iteration = supervised_on_iteration

    def _record_success(self, child: SupervisedAgent, result: Any) -> None:
        """Copy a finished run's output and usage figures onto ``child``."""
        child.result = result.output if hasattr(result, "output") else result
        child.usage.cost_usd = getattr(result, "cost_usd", 0.0) or 0.0
        tokens = getattr(result, "tokens_used", 0)
        # tokens_used may be a mapping (its values are summed) or a plain count.
        if hasattr(tokens, "values"):
            tokens = sum(tokens.values())
        child.usage.tokens_used = tokens or 0
        child.status = "completed"
        child.finished_at = time.time()
        self._emit(SupervisionEvent(SupervisionEventType.COMPLETED, child.id, child.name))

    async def _wait_before_restart(self, child: SupervisedAgent) -> None:
        """Honor the restart cooldown and retry delay, then count the restart."""
        since_last = time.time() - child.usage.last_restart
        if child.usage.restarts > 0 and since_last < child.quotas.cooldown_period:
            await asyncio.sleep(child.quotas.cooldown_period - since_last)
        child.usage.restarts += 1
        child.usage.last_restart = time.time()
        await asyncio.sleep(child.quotas.retry_delay)

    async def _heartbeat_loop(
        self, child: SupervisedAgent,
        on_heartbeat: Callable[[], Coroutine],
    ):
        """Await ``on_heartbeat`` every ``heartbeat_interval`` seconds while alive.

        Each successful call updates ``child.usage`` and emits HEARTBEAT. A
        failing callback is deliberately ignored here: the missing update is
        what the monitor's heartbeat-timeout check detects.
        """
        interval = child.quotas.heartbeat_interval
        while child.is_alive:
            try:
                await asyncio.sleep(interval)
                if not child.is_alive:
                    break
                await on_heartbeat()
                child.usage.heartbeats_received += 1
                child.usage.last_heartbeat = time.time()
                self._emit(SupervisionEvent(
                    SupervisionEventType.HEARTBEAT, child.id, child.name,
                    data={"heartbeats": child.usage.heartbeats_received},
                ))
            except asyncio.CancelledError:
                break
            except Exception:
                pass

    async def _monitor_loop(self):
        """Check every live child each ``monitor_interval`` until none remain."""
        while True:
            try:
                await asyncio.sleep(self.config.monitor_interval)
                alive = [c for c in self._children.values() if c.is_alive]
                if not alive:
                    # Nothing left to watch; spawn() starts a new monitor when needed.
                    break
                now = time.time()
                for child in alive:
                    await self._check_child(child, now)
            except asyncio.CancelledError:
                break
            except Exception:
                # Best effort: one failing pass must not stop monitoring.
                pass

    async def _check_child(self, child: SupervisedAgent, now: float) -> None:
        """Apply the heartbeat and duration checks to one live child."""
        hb_timeout = child.quotas.heartbeat_timeout
        if hb_timeout > 0 and child._heartbeat_task is not None:
            # Before the first heartbeat, measure silence from the start time
            # so a callback that never succeeds is still caught.
            last_seen = child.usage.last_heartbeat or child.started_at
            silence = now - last_seen
            if last_seen and silence > hb_timeout:
                self._emit(SupervisionEvent(
                    SupervisionEventType.HEARTBEAT_LOST, child.id, child.name,
                    message=f"No heartbeat for {silence:.0f}s",
                ))
                await self.kill_child(child.id, reason="heartbeat lost")
                return

        elapsed = now - child.started_at if child.started_at else 0.0
        limit = child.quotas.max_duration_seconds
        if elapsed > limit and self.config.auto_kill_on_quota:
            # Enforced here as well as in the iteration hook so a child stuck
            # inside a single iteration is still stopped.
            message = f"Duration quota exceeded: {elapsed:.0f}s / {limit}s"
            self._emit(SupervisionEvent(
                SupervisionEventType.QUOTA_EXCEEDED, child.id, child.name,
                message=message,
            ))
            await self.kill_child(child.id, reason="duration quota exceeded")
            return
        if elapsed > limit * 0.9 and child.id not in self._duration_warned:
            self._duration_warned.add(child.id)
            self._emit(SupervisionEvent(
                SupervisionEventType.QUOTA_WARNING, child.id, child.name,
                message=f"90% duration used: {elapsed:.0f}s / {limit}s",
            ))

    def _emit(self, event: SupervisionEvent):
        """Record ``event`` in the bounded history (if enabled) and notify the observer."""
        if self.config.log_events:
            self._events.append(event)
            # Explicit overflow count: slicing with [-size:] keeps everything
            # when size is 0.
            overflow = len(self._events) - max(self.config.event_history_size, 0)
            if overflow > 0:
                del self._events[:overflow]

        if self._on_event is not None:
            try:
                self._on_event(event)
            except Exception:
                # A faulty observer must not break supervision.
                pass