Coverage for agentos/background/supervisor.py: 35%
283 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2Agent Supervision Tree — v1.11.0
4Resource-bounded agent hierarchy with monitoring, quotas, and auto-recovery.
5Inspired by OS process supervision (systemd, supervisord).
7Features:
8- Hierarchical supervision tree: parent monitors children
9- Resource quotas per agent (time, cost, tokens, concurrency)
10- Heartbeat-based health monitoring
11- Auto-kill for runaway agents exceeding quotas
12- Graceful degradation: kill child, preserve parent
13- Aggregate progress across the tree
14- Supervision events for external monitoring
16Usage:
17 sup = AgentSupervisor()
18 child_id = await sup.spawn(
19 name="data_analyzer",
20 loop_factory=lambda: AgentLoop(...),
21 quotas=AgentQuota(max_duration_seconds=600, max_cost_usd=2.0),
22 )
23 result = await sup.await_child(child_id, timeout=600)
24"""
26from __future__ import annotations
28import asyncio
29import time
30import uuid
31from dataclasses import dataclass, field
32from enum import Enum
33from typing import Any, Callable, Coroutine, Optional
36# ── Enums ────────────────────────────────────────────────────────
class SupervisionEventType(str, Enum):
    """Types of supervision events.

    Members are ``str`` subclasses; ``SupervisionEvent.to_dict`` serialises
    an event's type through ``.value``.
    """
    # Lifecycle start
    SPAWNED = "spawned"
    STARTED = "started"
    # Liveness / progress reporting
    HEARTBEAT = "heartbeat"
    PROGRESS = "progress"
    # Quota and health problems
    QUOTA_WARNING = "quota_warning"  # Nearing quota limit
    QUOTA_EXCEEDED = "quota_exceeded"  # Quota hit, killed
    HEARTBEAT_LOST = "heartbeat_lost"  # Child unresponsive
    # Terminal outcomes
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
    KILLED = "killed"  # Killed by supervisor
@dataclass
class SupervisionEvent:
    """A single notification produced by the supervision tree.

    Attributes:
        type: Which kind of event this is.
        child_id: Identifier of the child the event concerns.
        child_name: Human-readable name of that child.
        timestamp: ``time.time()`` at construction unless given explicitly.
        data: Free-form structured payload (empty by default).
        message: Optional human-readable detail.
    """
    type: SupervisionEventType
    child_id: str
    child_name: str
    timestamp: float = field(default_factory=time.time)
    data: dict[str, Any] = field(default_factory=dict)
    message: str = ""

    def to_dict(self) -> dict:
        """Return a plain-dict view of the event, with ``type`` as its string value."""
        payload = dict(
            type=self.type.value,
            child_id=self.child_id,
            child_name=self.child_name,
            timestamp=self.timestamp,
            data=self.data,
            message=self.message,
        )
        return payload
71# ── Data Models ──────────────────────────────────────────────────
@dataclass
class AgentQuota:
    """Resource limits for a supervised agent.

    All fields have defaults, so ``AgentQuota()`` yields a usable quota.
    Durations are in seconds.
    """
    max_duration_seconds: float = 3600.0  # Wall-clock time budget
    max_cost_usd: float = 10.0  # Cost budget
    max_tokens: int = 1_000_000  # Token budget
    max_iterations: int = 500  # Max loop iterations
    heartbeat_interval: float = 10.0  # Seconds between heartbeats
    # Seconds of heartbeat silence before the child is considered dead;
    # the monitor skips this check when the value is <= 0.
    heartbeat_timeout: float = 30.0
    # Extra attempts after a failed run (0 = no restart); the child is run
    # at most ``max_retries + 1`` times.
    max_retries: int = 0
    retry_delay: float = 5.0  # Delay before restart
    cooldown_period: float = 60.0  # Minimum spacing between consecutive restarts
@dataclass
class AgentQuotaUsage:
    """Current resource consumption of a supervised agent.

    Every counter starts at zero. Timestamps (``last_heartbeat``,
    ``last_restart``) are ``time.time()`` values, with ``0.0`` meaning
    "never happened".
    """
    elapsed_seconds: float = 0.0
    cost_usd: float = 0.0
    tokens_used: int = 0
    iterations: int = 0
    heartbeats_received: int = 0
    last_heartbeat: float = 0.0
    restarts: int = 0
    last_restart: float = 0.0

    @property
    def duration_percent(self) -> float:
        """Placeholder that always returns ``0.0``.

        A usage record has no reference to its quota, so no percentage can
        be derived here. The property is read-only, so callers needing a
        real figure must compute it from the quota themselves.
        """
        return 0.0

    @property
    def cost_percent(self) -> float:
        """Placeholder that always returns ``0.0`` (see ``duration_percent``)."""
        return 0.0

    def to_dict(self) -> dict:
        """Return all usage counters and timestamps as a plain dict."""
        return {
            "elapsed_seconds": self.elapsed_seconds,
            "cost_usd": self.cost_usd,
            "tokens_used": self.tokens_used,
            "iterations": self.iterations,
            "heartbeats_received": self.heartbeats_received,
            "last_heartbeat": self.last_heartbeat,
            "restarts": self.restarts,
            # Previously omitted, so the restart timestamp was lost on
            # serialisation even though ``restarts`` was reported.
            "last_restart": self.last_restart,
        }
@dataclass
class SupervisedAgent:
    """Bookkeeping record for one agent running under supervision.

    Holds the agent's identity, its quota and usage, its lifecycle status
    and outcome, plus the asyncio handles the supervisor uses to control it.
    """
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    name: str = ""
    quotas: AgentQuota = field(default_factory=AgentQuota)
    usage: AgentQuotaUsage = field(default_factory=AgentQuotaUsage)
    status: str = "pending"  # pending/running/paused/completed/failed/killed
    started_at: float = 0.0
    finished_at: float = 0.0
    result: Any = None
    error: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    # Supervisor-owned control handles (hidden from repr)
    _task: asyncio.Task | None = field(default=None, repr=False)
    _heartbeat_task: asyncio.Task | None = field(default=None, repr=False)
    _pause_event: asyncio.Event | None = field(default=None, repr=False)
    _kill_event: asyncio.Event | None = field(default=None, repr=False)

    @property
    def is_alive(self) -> bool:
        """True while the agent is running or paused."""
        return self.status == "running" or self.status == "paused"

    @property
    def duration_seconds(self) -> float:
        """Seconds since start, up to ``finished_at`` or now; 0.0 if never started."""
        if not self.started_at:
            return 0.0
        stop = self.finished_at if self.finished_at else time.time()
        return stop - self.started_at

    def to_dict(self) -> dict:
        """Return a serialisable summary (the result and control handles are left out)."""
        limit_names = (
            "max_duration_seconds",
            "max_cost_usd",
            "max_tokens",
            "max_iterations",
        )
        quota_view = {limit: getattr(self.quotas, limit) for limit in limit_names}
        summary = {"id": self.id, "name": self.name, "quotas": quota_view}
        summary["usage"] = self.usage.to_dict()
        summary["status"] = self.status
        summary["started_at"] = self.started_at
        summary["finished_at"] = self.finished_at
        summary["error"] = self.error
        summary["metadata"] = self.metadata
        return summary
@dataclass
class SupervisorConfig:
    """Global supervisor configuration."""
    max_children: int = 20  # spawn() raises RuntimeError once this limit is reached
    monitor_interval: float = 1.0  # Seconds between health checks
    event_history_size: int = 500  # Max events to retain
    # When True, a child that exceeds its duration quota is stopped;
    # when False, only a quota warning event is emitted.
    auto_kill_on_quota: bool = True
    # When True, emitted events are kept in the supervisor's in-memory
    # history (bounded by ``event_history_size``).
    log_events: bool = True
171# ── Callback types ───────────────────────────────────────────────
# Observer signature accepted by AgentSupervisor(on_event=...): it is called
# with each SupervisionEvent as it is emitted and its return value is ignored.
EventCallback = Callable[[SupervisionEvent], None]
176# ── Agent Supervisor ─────────────────────────────────────────────
class AgentSupervisor:
    """
    Hierarchical supervision tree for long-running multi-agent tasks.

    Monitors children for:
    - Wall-clock duration quota violations
    - Heartbeat loss (crash/hang detection)

    Actions:
    - Auto-kill runaway agents
    - Optional restart after failure (``AgentQuota.max_retries``)
    - Event emission for external monitoring

    Child status moves from ``pending`` to ``running``/``paused`` and ends in
    one of the terminal states ``completed``, ``failed`` or ``killed``.
    """

    # Statuses after which a child will never run again.
    _TERMINAL_STATUSES = frozenset({"completed", "failed", "killed"})

    def __init__(
        self,
        config: SupervisorConfig | None = None,
        on_event: EventCallback | None = None,
    ):
        """
        Args:
            config: Supervisor settings; a default ``SupervisorConfig`` is
                created when omitted.
            on_event: Optional observer called synchronously with every
                emitted event.
        """
        self.config = config or SupervisorConfig()
        self._on_event = on_event
        self._children: dict[str, SupervisedAgent] = {}
        self._events: list[SupervisionEvent] = []
        self._monitor_task: asyncio.Task | None = None
        self._lock = asyncio.Lock()
        # Ids of children already warned about nearing their duration quota,
        # so the monitor warns once instead of on every tick.
        self._duration_warned: set[str] = set()

    # ── Public API ───────────────────────────────────────────────

    async def spawn(
        self,
        name: str,
        task: str = "",
        loop_factory: Callable[[], Any] | None = None,
        agent_loop: Any = None,
        quotas: AgentQuota | None = None,
        on_heartbeat: Callable[[], Coroutine] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """
        Spawn a new child agent under supervision.

        Args:
            name: Display name for the child.
            task: Task text passed to the agent loop's ``run``.
            loop_factory: Zero-argument callable building the agent loop;
                takes precedence over ``agent_loop``.
            agent_loop: Ready-made agent loop, used when no factory is given.
            quotas: Limits for this child (defaults to ``AgentQuota()``).
            on_heartbeat: Optional coroutine function awaited every
                ``heartbeat_interval`` seconds while the child is active.
            metadata: Free-form data stored on the child record.

        Returns:
            The child id, for monitoring/control.

        Raises:
            RuntimeError: If ``max_children`` unfinished children already exist.
        """
        async with self._lock:
            # Only unfinished children occupy a slot; counting finished ones
            # would cap the supervisor's lifetime spawn count instead.
            active = sum(
                1
                for existing in self._children.values()
                if existing.status not in self._TERMINAL_STATUSES
            )
            if active >= self.config.max_children:
                raise RuntimeError(f"Max children ({self.config.max_children}) reached")

            child = SupervisedAgent(
                name=name,
                quotas=quotas or AgentQuota(),
                metadata=metadata or {},
            )
            child._pause_event = asyncio.Event()
            child._pause_event.set()  # Set means "not paused"
            child._kill_event = asyncio.Event()

            self._children[child.id] = child
            self._emit(SupervisionEvent(SupervisionEventType.SPAWNED, child.id, name))

            # Start (or restart) the background monitor
            if not self._monitor_task or self._monitor_task.done():
                self._monitor_task = asyncio.create_task(self._monitor_loop())

            # Start heartbeat task
            if on_heartbeat:
                child._heartbeat_task = asyncio.create_task(
                    self._heartbeat_loop(child, on_heartbeat)
                )

            # Start execution
            child._task = asyncio.create_task(
                self._run_child(child, task, loop_factory, agent_loop)
            )

            return child.id

    async def get_child(self, child_id: str) -> SupervisedAgent | None:
        """Return the child with this id, or ``None`` if unknown."""
        return self._children.get(child_id)

    async def list_children(self) -> list[SupervisedAgent]:
        """Return every child ever spawned, finished ones included."""
        return list(self._children.values())

    async def await_child(self, child_id: str, timeout: float | None = None) -> Any:
        """
        Wait for a child to finish and return its result.

        A child that failed or was killed while running yields ``None``.

        Raises:
            KeyError: Unknown ``child_id``.
            RuntimeError: The child has no task.
            asyncio.TimeoutError: The child did not finish within ``timeout``
                seconds; it is killed before the error is raised.
        """
        child = self._children.get(child_id)
        if not child:
            raise KeyError(f"Child {child_id} not found")
        if not child._task:
            raise RuntimeError(f"Child {child_id} has no running task")

        # asyncio.wait never cancels the task itself. wait_for would cancel
        # it, and because the child runner absorbs cancellation the timeout
        # could surface as a plain ``None`` result without any KILLED event.
        done, _ = await asyncio.wait({child._task}, timeout=timeout)
        if not done:
            await self.kill_child(child_id, reason="await timeout")
            raise asyncio.TimeoutError(
                f"Child {child_id} did not finish within {timeout}s"
            )
        return child._task.result()

    async def pause_child(self, child_id: str) -> bool:
        """Pause a running child; returns False if it is unknown or not alive."""
        child = self._children.get(child_id)
        if not child or not child.is_alive or not child._pause_event:
            return False
        child._pause_event.clear()
        child.status = "paused"
        return True

    async def resume_child(self, child_id: str) -> bool:
        """Resume a paused child; returns False if it is not currently paused."""
        child = self._children.get(child_id)
        if not child or child.status != "paused" or not child._pause_event:
            return False
        child._pause_event.set()
        child.status = "running"
        self._emit(SupervisionEvent(
            SupervisionEventType.PROGRESS, child.id, child.name,
            message="Resumed",
        ))
        return True

    async def kill_child(self, child_id: str, reason: str = "") -> bool:
        """
        Force-kill a child agent.

        Works on pending children too, so a child spawned moments before a
        shutdown is still stopped.

        Returns:
            True if the child was killed, False if it is unknown or already
            in a terminal state.
        """
        child = self._children.get(child_id)
        if not child or child.status in self._TERMINAL_STATUSES:
            return False

        if child._kill_event:
            child._kill_event.set()
        if child._task and not child._task.done():
            child._task.cancel()
        # Stop the heartbeat now rather than letting it sleep out its interval.
        if child._heartbeat_task and not child._heartbeat_task.done():
            child._heartbeat_task.cancel()

        child.status = "killed"
        child.finished_at = time.time()
        child.error = reason

        self._emit(SupervisionEvent(
            SupervisionEventType.KILLED, child.id, child.name,
            message=reason,
        ))
        return True

    async def aggregate_progress(self) -> dict[str, Any]:
        """
        Aggregate progress across all children.

        ``failed`` counts both failed and killed children; ``running`` does
        not include paused ones. ``percent_complete`` is 0 with no children.
        """
        children = list(self._children.values())
        total = len(children)
        completed = sum(1 for c in children if c.status == "completed")
        failed = sum(1 for c in children if c.status in ("failed", "killed"))
        running = sum(1 for c in children if c.status == "running")

        return {
            "total_children": total,
            "completed": completed,
            "failed": failed,
            "running": running,
            "total_cost_usd": sum(c.usage.cost_usd for c in children),
            "total_tokens": sum(c.usage.tokens_used for c in children),
            "percent_complete": (completed / total * 100) if total > 0 else 0,
            "children": [c.to_dict() for c in children],
        }

    async def shutdown(self, timeout: float = 10.0):
        """
        Stop the supervisor.

        Kills every unfinished child, cancels the monitor, then waits up to
        ``timeout`` seconds in total for the child tasks to unwind.
        """
        for child_id in list(self._children):
            await self.kill_child(child_id, reason="supervisor shutdown")

        if self._monitor_task and not self._monitor_task.done():
            self._monitor_task.cancel()

        unfinished = {
            child._task
            for child in self._children.values()
            if child._task and not child._task.done()
        }
        # asyncio.wait rejects an empty set, and unlike awaiting each task it
        # neither re-raises the children's cancellation nor hides our own.
        if unfinished:
            await asyncio.wait(unfinished, timeout=max(0.0, timeout))

    # ── Internal ─────────────────────────────────────────────────

    @staticmethod
    def _count_tokens(tokens_used: Any) -> int:
        """Normalise a result's token usage (mapping of counts, number, or None) to an int."""
        if not tokens_used:
            return 0
        if isinstance(tokens_used, dict):
            return sum(tokens_used.values())
        try:
            return int(tokens_used)
        except (TypeError, ValueError):
            # Unknown shape: report nothing rather than fail a finished run.
            return 0

    def _make_iteration_hook(self, child: SupervisedAgent, original_on_iteration: Any):
        """Build the ``on_iteration`` hook applying kill/pause/quota checks to ``child``."""

        async def supervised_on_iteration(iteration: int, tool_results: list):
            # Kill signal
            if child._kill_event and child._kill_event.is_set():
                raise asyncio.CancelledError("Killed by supervisor")

            # Pause signal: blocks here until resumed
            if child._pause_event:
                await child._pause_event.wait()

            # Usage bookkeeping
            child.usage.iterations = iteration
            child.usage.elapsed_seconds = time.time() - child.started_at

            # Duration quota
            if child.usage.elapsed_seconds > child.quotas.max_duration_seconds:
                detail = (
                    f"Duration at {child.usage.elapsed_seconds:.0f}s / "
                    f"{child.quotas.max_duration_seconds}s"
                )
                if self.config.auto_kill_on_quota:
                    if child._kill_event:
                        child._kill_event.set()
                    self._emit(SupervisionEvent(
                        SupervisionEventType.QUOTA_EXCEEDED, child.id, child.name,
                        message=detail,
                    ))
                    raise asyncio.TimeoutError("Duration quota exceeded")
                self._emit(SupervisionEvent(
                    SupervisionEventType.QUOTA_WARNING, child.id, child.name,
                    message=detail,
                ))

            if original_on_iteration:
                outcome = original_on_iteration(iteration, tool_results)
                # The loop's own hook may be sync or async; an un-awaited
                # coroutine would silently never run.
                if asyncio.iscoroutine(outcome) or asyncio.isfuture(outcome):
                    await outcome

        return supervised_on_iteration

    async def _wait_before_restart(self, child: SupervisedAgent):
        """Honour the restart cooldown, record the restart, then sleep ``retry_delay``."""
        since_last = time.time() - child.usage.last_restart
        if child.usage.restarts > 0 and since_last < child.quotas.cooldown_period:
            await asyncio.sleep(child.quotas.cooldown_period - since_last)
        child.usage.restarts += 1
        child.usage.last_restart = time.time()
        await asyncio.sleep(child.quotas.retry_delay)

    async def _run_child(
        self,
        child: SupervisedAgent,
        task: str,
        loop_factory: Callable[[], Any] | None,
        agent_loop: Any,
    ):
        """
        Execute a child agent with full supervision.

        Installs the supervising ``on_iteration`` hook on the agent loop
        (presumably the loop awaits it once per iteration — confirm against
        the loop implementation), then runs the loop with up to
        ``max_retries`` restarts.

        Returns:
            The child's result on success, ``None`` if it failed or was
            killed. Failures are recorded on the child and reported through
            a single FAILED event rather than raised.
        """
        child.status = "running"
        child.started_at = time.time()
        self._emit(SupervisionEvent(SupervisionEventType.STARTED, child.id, child.name))

        try:
            if loop_factory:
                loop = loop_factory()
            elif agent_loop:
                loop = agent_loop
            else:
                raise ValueError("Must provide loop_factory or agent_loop")

            loop.on_iteration = self._make_iteration_hook(
                child, getattr(loop, "on_iteration", None)
            )

            for attempt in range(child.quotas.max_retries + 1):
                try:
                    result = await loop.run(task, session_id=child.id)
                except (asyncio.CancelledError, asyncio.TimeoutError):
                    # Kills and quota aborts are never retried.
                    raise
                except Exception:
                    if attempt >= child.quotas.max_retries:
                        # Out of attempts: the outer handler records the
                        # failure, so FAILED is emitted exactly once.
                        raise
                    await self._wait_before_restart(child)
                    continue

                child.result = result.output if hasattr(result, "output") else result
                child.usage.cost_usd = getattr(result, "cost_usd", 0.0)
                child.usage.tokens_used = self._count_tokens(
                    getattr(result, "tokens_used", None)
                )
                child.status = "completed"
                child.finished_at = time.time()
                self._emit(SupervisionEvent(
                    SupervisionEventType.COMPLETED, child.id, child.name,
                ))
                return child.result

        except asyncio.CancelledError:
            # Absorbed so awaiting a killed child yields None.
            # kill_child() has normally recorded the kill already; only
            # fill it in when the cancellation came from elsewhere.
            if child.status != "killed":
                child.status = "killed"
                child.finished_at = time.time()
                self._emit(SupervisionEvent(
                    SupervisionEventType.KILLED, child.id, child.name,
                    message="cancelled",
                ))
        except Exception as e:
            child.status = "failed"
            child.finished_at = time.time()
            child.error = str(e)
            self._emit(SupervisionEvent(
                SupervisionEventType.FAILED, child.id, child.name,
                message=str(e),
            ))
        finally:
            child.usage.elapsed_seconds = child.duration_seconds
            heartbeat = child._heartbeat_task
            if heartbeat and not heartbeat.done():
                heartbeat.cancel()

    async def _heartbeat_loop(
        self,
        child: SupervisedAgent,
        on_heartbeat: Callable[[], Coroutine],
    ):
        """
        Call ``on_heartbeat`` every ``heartbeat_interval`` seconds and record it.

        The loop keeps going while the child is pending, running or paused.
        It starts before the child task has flipped the status to
        ``running``, so treating ``pending`` as "stop" would end it before
        the first beat.
        """
        interval = child.quotas.heartbeat_interval
        while child.status not in self._TERMINAL_STATUSES:
            try:
                await asyncio.sleep(interval)
                if child.status in self._TERMINAL_STATUSES:
                    break
                await on_heartbeat()
                child.usage.heartbeats_received += 1
                child.usage.last_heartbeat = time.time()
                self._emit(SupervisionEvent(
                    SupervisionEventType.HEARTBEAT, child.id, child.name,
                    data={"heartbeats": child.usage.heartbeats_received},
                ))
            except asyncio.CancelledError:
                break
            except Exception:
                # Best effort: a failing callback just means no heartbeat is
                # recorded, which the monitor later reports as heartbeat loss.
                continue

    async def _monitor_loop(self):
        """
        Check all live children every ``monitor_interval`` seconds.

        Kills children whose last heartbeat is older than
        ``heartbeat_timeout`` and warns once per child when 90% of its
        duration quota has elapsed. Runs until cancelled.
        """
        while True:
            try:
                await asyncio.sleep(self.config.monitor_interval)
                now = time.time()

                for child in list(self._children.values()):
                    if not child.is_alive:
                        continue

                    # Heartbeat timeout (only once a first heartbeat was seen)
                    silence = now - child.usage.last_heartbeat
                    if (child.quotas.heartbeat_timeout > 0
                            and child.usage.last_heartbeat > 0
                            and silence > child.quotas.heartbeat_timeout):
                        self._emit(SupervisionEvent(
                            SupervisionEventType.HEARTBEAT_LOST, child.id, child.name,
                            message=f"No heartbeat for {silence:.0f}s",
                        ))
                        await self.kill_child(child.id, reason="heartbeat lost")
                        # The child is dead now; no quota warning for it.
                        continue

                    # Duration warning, emitted a single time per child
                    elapsed = now - child.started_at if child.started_at else 0
                    if (elapsed > child.quotas.max_duration_seconds * 0.9
                            and child.id not in self._duration_warned):
                        self._duration_warned.add(child.id)
                        self._emit(SupervisionEvent(
                            SupervisionEventType.QUOTA_WARNING, child.id, child.name,
                            message=f"90% duration used: {elapsed:.0f}s / {child.quotas.max_duration_seconds}s",
                        ))

            except asyncio.CancelledError:
                break
            except Exception:
                # The monitor must outlive a single bad tick.
                continue

    def _emit(self, event: SupervisionEvent):
        """
        Record an event and notify the observer.

        With ``log_events`` on, the event is appended to the history, which
        is trimmed to the newest ``event_history_size`` entries. Exceptions
        from the observer are suppressed so it cannot break supervision.
        """
        if self.config.log_events:
            self._events.append(event)
            overflow = len(self._events) - self.config.event_history_size
            if overflow > 0:
                # Slicing with [-size:] keeps everything when size is 0.
                del self._events[:overflow]

        if self._on_event:
            try:
                self._on_event(event)
            except Exception:
                pass