Coverage for agentos/swarm/execution_trace.py: 37%
185 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2v1.9.6: Execution Trace — full observability into Agent task execution.
4Captures every sub-task, gate evaluation, retry, and timing detail.
5Supports timeline visualization, bottleneck detection, and debugging.
6"""
8from __future__ import annotations
10import json
11import time
12import uuid
13from collections import defaultdict
14from dataclasses import dataclass, field
15from enum import Enum
16from typing import Any, Optional
class TraceEvent(str, Enum):
    """Kinds of events that a span in an execution trace can record.

    Members also subclass ``str``, so each one compares equal to its
    plain string value (for example ``TraceEvent.RETRY == "retry"``).
    """

    # Task / sub-task lifecycle
    TASK_START = "task_start"
    TASK_END = "task_end"
    SUBTASK_START = "subtask_start"
    SUBTASK_END = "subtask_end"

    # Splitting work and merging results
    DECOMPOSE = "decompose"
    FUSE = "fuse"

    # Recovery paths
    RETRY = "retry"
    FALLBACK = "fallback"

    # Gates, human-in-the-loop pauses, sandboxed runs
    GATE_CHECK = "gate_check"
    HITL_BREAK = "hitl_break"
    HITL_RESUME = "hitl_resume"
    SANDBOX_RUN = "sandbox_run"

    # Failure signals
    ERROR = "error"
    ABORT = "abort"
@dataclass
class TraceSpan:
    """One node in the tree of an execution trace.

    A span stores what happened (``event`` / ``name``), how it ended
    (``status``), its timing in milliseconds (``start_ms``, ``end_ms``,
    ``duration_ms``), free-form ``data`` and ``tags``, and the spans nested
    beneath it (``children``).
    """

    # Short random identifier: first 8 hex characters of a UUID4.
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:8])
    parent_id: str = ""
    event: TraceEvent = TraceEvent.TASK_START
    name: str = ""
    status: str = "started"  # started, done, failed, aborted
    # Wall-clock creation time (time.time()) converted to milliseconds.
    start_ms: float = field(default_factory=lambda: time.time() * 1000)
    end_ms: float = 0.0  # 0.0 until an end time is recorded
    duration_ms: float = 0.0
    data: dict[str, Any] = field(default_factory=dict)
    tags: list[str] = field(default_factory=list)
    children: list[TraceSpan] = field(default_factory=list)

    @property
    def is_leaf(self) -> bool:
        """True when this span has no child spans."""
        return not self.children

    def to_dict(self) -> dict:
        """Serialise this span and, recursively, its children.

        Timings are rendered as strings with one decimal place; an unset
        (zero) ``end_ms`` is rendered as ``"-"``. Each ``data`` value is
        converted with ``str`` and cut to at most 100 characters.
        """
        end_text = f"{self.end_ms:.1f}" if self.end_ms else "-"
        clipped_data = {key: str(value)[:100] for key, value in self.data.items()}
        nested = [child.to_dict() for child in self.children]
        return {
            "id": self.id,
            "parent_id": self.parent_id,
            "event": self.event.value,
            "name": self.name,
            "status": self.status,
            "start_ms": f"{self.start_ms:.1f}",
            "end_ms": end_text,
            "duration_ms": f"{self.duration_ms:.1f}",
            "data": clipped_data,
            "tags": self.tags,
            "children": nested,
        }

    def to_flat_list(self) -> list[dict]:
        """Return the subtree as a pre-order list of ``to_dict()`` rows.

        This span's row comes first, followed by the rows of each child's
        subtree in child order. Every row is the full ``to_dict()`` output,
        so it still carries its own nested ``"children"`` entry.
        """
        return [
            self.to_dict(),
            *(row for child in self.children for row in child.to_flat_list()),
        ]
@dataclass
class ExecutionTrace:
    """
    Full execution trace for a single task execution.

    Captures a tree of spans representing every step: decomposition,
    sub-task execution, fusion, retries, gate checks, etc.

    Parent/child relationships are derived from a stack of currently open
    span ids: a new span becomes a child of the most recently started span
    that has not been ended yet.

    Usage:
        trace = ExecutionTrace(task_name="research_query")

        span = trace.start_span(TraceEvent.TASK_START, name="main")
        # ... do work ...
        trace.end_span(span.id, status="done")

        print(trace.summary())
        print(trace.to_json())
    """

    # Short random identifier: first 12 hex characters of a UUID4.
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    task_name: str = ""
    root_span: TraceSpan | None = None
    # Every span ever started, keyed by span id.
    _span_map: dict[str, TraceSpan] = field(default_factory=dict)
    # Ids of spans that are started but not yet ended; last item is innermost.
    _current_spans: list[str] = field(default_factory=list)  # stack
    total_spans: int = 0
    total_retries: int = 0
    total_errors: int = 0
    total_fallbacks: int = 0
    created_at: float = field(default_factory=time.time)

    def start_span(
        self,
        event: TraceEvent,
        name: str = "",
        data: dict | None = None,
        tags: list[str] | None = None,
    ) -> TraceSpan:
        """Start a new span and add it to the trace tree.

        Args:
            event: Kind of event the span represents.
            name: Human-readable label.
            data: Optional payload stored on the span.
            tags: Optional list of tags stored on the span.

        Returns:
            The newly created span, which is now the innermost open span.
        """
        span = TraceSpan(
            event=event,
            name=name,
            data=data or {},
            tags=tags or [],
        )

        # Determine parent
        if self._current_spans:
            span.parent_id = self._current_spans[-1]
            parent = self._span_map.get(span.parent_id)
            if parent is not None:
                parent.children.append(span)
        elif self.root_span is None:
            self.root_span = span
        else:
            # Nothing is open but a root exists: attach under the root so
            # the span is still reachable from the tree.
            span.parent_id = self.root_span.id
            self.root_span.children.append(span)

        self._span_map[span.id] = span
        self._current_spans.append(span.id)
        self.total_spans += 1

        if event == TraceEvent.RETRY:
            self.total_retries += 1
        elif event == TraceEvent.ERROR:
            self.total_errors += 1
        elif event == TraceEvent.FALLBACK:
            self.total_fallbacks += 1

        return span

    def end_span(self, span_id: str, status: str = "done", data: dict | None = None) -> TraceSpan | None:
        """End a span and record its duration.

        Args:
            span_id: Id of the span to end.
            status: Final status to store on the span.
            data: Optional extra payload merged into the span's ``data``.

        Returns:
            The ended span, or ``None`` when ``span_id`` is unknown.
        """
        span = self._span_map.get(span_id)
        if span is None:
            return None

        span.end_ms = time.time() * 1000
        span.duration_ms = span.end_ms - span.start_ms
        span.status = status
        if data:
            span.data.update(data)

        # Drop the span from the open-span stack wherever it sits. Spans may
        # be ended out of order (e.g. a parent ended before a child); leaving
        # an ended span on the stack would make it the parent of unrelated
        # spans started later. Other open spans are left untouched.
        if self._current_spans and self._current_spans[-1] == span_id:
            self._current_spans.pop()
        elif span_id in self._current_spans:
            self._current_spans.remove(span_id)

        return span

    def add_event(
        self,
        event: TraceEvent,
        name: str = "",
        data: dict | None = None,
        tags: list[str] | None = None,
        duration_ms: float = 0.0,
    ) -> TraceSpan:
        """Quick-add a leaf event (start+end in one call).

        The span is created through ``start_span`` (so counters and
        parenting apply), marked ``"done"`` immediately, and given the
        caller-supplied ``duration_ms`` instead of a measured one.
        """
        span = self.start_span(event, name, data, tags)
        span.end_ms = span.start_ms + duration_ms
        span.duration_ms = duration_ms
        span.status = "done"

        # start_span pushed this span; take it off again so it never
        # becomes a parent.
        if self._current_spans and self._current_spans[-1] == span.id:
            self._current_spans.pop()

        return span

    def to_dict(self) -> dict:
        """Serialise counters, metadata and the span tree to a dict.

        ``"root"`` is an empty dict when no span has been started.
        """
        return {
            "trace_id": self.id,
            "task_name": self.task_name,
            "total_spans": self.total_spans,
            "total_retries": self.total_retries,
            "total_errors": self.total_errors,
            "total_fallbacks": self.total_fallbacks,
            "created_at": self.created_at,
            "root": self.root_span.to_dict() if self.root_span else {},
        }

    def to_json(self, indent: int = 2) -> str:
        """Return ``to_dict()`` as a JSON string (non-ASCII kept verbatim)."""
        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)

    def to_tree_string(self, span: TraceSpan | None = None, indent: int = 0) -> str:
        """Render trace as indented ASCII tree.

        Args:
            span: Subtree root to render; defaults to the trace root.
            indent: Nesting depth, one space of prefix per level.

        Returns:
            One line per span, or ``"(empty trace)"`` when there is no root.
        """
        if span is None:
            span = self.root_span
        if span is None:
            return "(empty trace)"

        prefix = " " * indent
        status_icon = {"done": "O", "failed": "X", "started": ">", "aborted": "!"}.get(span.status, "?")
        lines = [
            f"{prefix}{status_icon} [{span.event.value}] {span.name} "
            f"({span.duration_ms:.0f}ms) [{span.status}]"
        ]
        lines.extend(self.to_tree_string(child, indent + 1) for child in span.children)
        return "\n".join(lines)

    def summary(self) -> str:
        """One-line summary of trace.

        The total time is the root span's recorded duration, or 0 when the
        root is missing or has no duration yet.
        """
        total_ms = 0.0
        if self.root_span and self.root_span.duration_ms:
            total_ms = self.root_span.duration_ms
        return (
            f"Trace[{self.id}] '{self.task_name}' "
            f"{self.total_spans} spans, "
            f"{self.total_retries} retries, "
            f"{self.total_errors} errors, "
            f"{total_ms:.0f}ms total"
        )

    def _all_spans(self) -> list[TraceSpan]:
        """Return every span reachable from the root, parent before children."""
        if self.root_span is None:
            return []
        ordered: list[TraceSpan] = []
        pending = [self.root_span]
        while pending:
            span = pending.pop()
            ordered.append(span)
            # Reverse so the first child is popped (and visited) first.
            pending.extend(reversed(span.children))
        return ordered

    def bottlenecks(self, top_n: int = 5) -> list[dict]:
        """Find slowest spans — helps identify bottlenecks.

        Returns up to ``top_n`` spans ordered by descending ``duration_ms``;
        spans with equal duration keep their tree (pre-order) order.
        """
        slowest = sorted(self._all_spans(), key=lambda s: s.duration_ms, reverse=True)
        return [
            {
                "name": s.name,
                "event": s.event.value,
                "duration_ms": round(s.duration_ms, 1),
                "status": s.status,
                "tags": s.tags,
            }
            for s in slowest[:top_n]
        ]

    def errors_list(self) -> list[dict]:
        """List all error spans.

        A span counts as an error when its event is ``TraceEvent.ERROR`` or
        its status is ``"failed"``. Data values are stringified and cut to
        100 characters.
        """
        return [
            {
                "name": s.name,
                "data": {k: str(v)[:100] for k, v in s.data.items()},
                "tags": s.tags,
            }
            for s in self._all_spans()
            if s.event == TraceEvent.ERROR or s.status == "failed"
        ]

    def timeline(self) -> list[dict]:
        """Generate a flat timeline of all spans sorted by start time."""
        chronological = sorted(self._all_spans(), key=lambda s: s.start_ms)
        return [
            {
                "time_ms": f"{s.start_ms:.1f}",
                "event": s.event.value,
                "name": s.name,
                "duration_ms": f"{s.duration_ms:.1f}",
                "status": s.status,
            }
            for s in chronological
        ]
class TraceCollector:
    """
    Collects multiple execution traces and generates aggregate reports.

    Traces are kept in insertion order and the oldest ones are evicted once
    more than ``max_traces`` are stored.

    Usage:
        collector = TraceCollector()
        trace1 = await run_task("query_a")
        collector.add(trace1)
        trace2 = await run_task("query_b")
        collector.add(trace2)

        print(collector.stats())
    """

    def __init__(self, max_traces: int = 100):
        """Create an empty collector holding at most ``max_traces`` traces."""
        self._traces: dict[str, ExecutionTrace] = {}
        self.max_traces = max_traces

    def add(self, trace: ExecutionTrace) -> None:
        """Store ``trace`` under its id, evicting the oldest traces if needed.

        Adding a trace whose id is already stored replaces the stored object
        without changing its position in the eviction order.
        """
        self._traces[trace.id] = trace
        # Evict in a loop, not just once: ``max_traces`` is a public attribute
        # and may have been lowered since the previous add. Clamp at 0 so a
        # negative limit cannot make the loop run on an empty dict.
        limit = max(self.max_traces, 0)
        while len(self._traces) > limit:
            oldest = next(iter(self._traces))
            del self._traces[oldest]

    def get(self, trace_id: str) -> ExecutionTrace | None:
        """Return the stored trace with ``trace_id``, or ``None`` if absent."""
        return self._traces.get(trace_id)

    def stats(self) -> dict:
        """Aggregate statistics across all traces.

        Returns ``{"count": 0}`` when nothing is stored. Otherwise returns
        summed counters, retry/error rates per span, and avg/max/min of the
        root-span durations. Traces without a root span or with a zero root
        duration are left out of the duration figures.
        """
        if not self._traces:
            return {"count": 0}

        traces = list(self._traces.values())
        total_spans = sum(t.total_spans for t in traces)
        total_retries = sum(t.total_retries for t in traces)
        total_errors = sum(t.total_errors for t in traces)
        total_fallbacks = sum(t.total_fallbacks for t in traces)

        durations = [
            t.root_span.duration_ms
            for t in traces
            if t.root_span and t.root_span.duration_ms
        ]
        # Guard the rate denominators against traces that recorded no spans.
        span_denominator = max(total_spans, 1)

        return {
            "count": len(traces),
            "total_spans": total_spans,
            "total_retries": total_retries,
            "total_errors": total_errors,
            "total_fallbacks": total_fallbacks,
            "retry_rate": round(total_retries / span_denominator, 3),
            "error_rate": round(total_errors / span_denominator, 3),
            "avg_duration_ms": round(sum(durations) / len(durations), 1) if durations else 0,
            "max_duration_ms": round(max(durations), 1) if durations else 0,
            "min_duration_ms": round(min(durations), 1) if durations else 0,
        }

    def failed_tasks(self) -> list[dict]:
        """List tasks whose root span ended as ``"failed"`` or ``"aborted"``."""
        return [
            {
                "trace_id": t.id,
                "task_name": t.task_name,
                "status": t.root_span.status,
                "errors": t.total_errors,
            }
            for t in self._traces.values()
            if t.root_span and t.root_span.status in ("failed", "aborted")
        ]