Coverage for agentos/swarm/execution_trace.py: 37%

185 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2v1.9.6: Execution Trace — full observability into Agent task execution. 

3 

4Captures every sub-task, gate evaluation, retry, and timing detail. 

5Supports timeline visualization, bottleneck detection, and debugging. 

6""" 

7 

8from __future__ import annotations 

9 

10import json 

11import time 

12import uuid 

13from collections import defaultdict 

14from dataclasses import dataclass, field 

15from enum import Enum 

16from typing import Any, Optional 

17 

18 

class TraceEvent(str, Enum):
    """Kinds of events a span in an execution trace can represent.

    The enum mixes in ``str``, so every member compares equal to its raw
    string value (for example ``TraceEvent.RETRY == "retry"``).
    """

    # Task and sub-task lifecycle
    TASK_START = "task_start"
    TASK_END = "task_end"
    SUBTASK_START = "subtask_start"
    SUBTASK_END = "subtask_end"

    # Splitting work and merging results
    DECOMPOSE = "decompose"
    FUSE = "fuse"

    # Recovery paths
    RETRY = "retry"
    FALLBACK = "fallback"

    # Gates, human-in-the-loop pauses, sandboxed runs
    GATE_CHECK = "gate_check"
    HITL_BREAK = "hitl_break"
    HITL_RESUME = "hitl_resume"
    SANDBOX_RUN = "sandbox_run"

    # Failure signals
    ERROR = "error"
    ABORT = "abort"

36 

37 

@dataclass
class TraceSpan:
    """One node of an execution trace tree.

    A span records what happened (``event``/``name``), how it ended
    (``status``), its timing in milliseconds, free-form ``data``/``tags``
    and any nested child spans.
    """

    # Short random identifier: first 8 hex characters of a UUID4.
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:8])
    parent_id: str = ""
    event: TraceEvent = TraceEvent.TASK_START
    name: str = ""
    status: str = "started"  # started, done, failed, aborted
    # Wall-clock creation time from time.time(), converted to milliseconds.
    start_ms: float = field(default_factory=lambda: time.time() * 1000)
    end_ms: float = 0.0
    duration_ms: float = 0.0
    data: dict[str, Any] = field(default_factory=dict)
    tags: list[str] = field(default_factory=list)
    children: list[TraceSpan] = field(default_factory=list)

    @property
    def is_leaf(self) -> bool:
        """Return True when this span has no child spans."""
        return not self.children

    def to_dict(self) -> dict:
        """Return a nested, display-oriented dict for this span and its subtree.

        Timing fields are rendered as strings with one decimal place, a
        falsy ``end_ms`` is shown as ``"-"``, and every ``data`` value is
        converted with ``str`` and cut to at most 100 characters.
        """
        end_text = f"{self.end_ms:.1f}" if self.end_ms else "-"
        clipped_data = {key: str(value)[:100] for key, value in self.data.items()}
        nested = [child.to_dict() for child in self.children]
        return {
            "id": self.id,
            "parent_id": self.parent_id,
            "event": self.event.value,
            "name": self.name,
            "status": self.status,
            "start_ms": f"{self.start_ms:.1f}",
            "end_ms": end_text,
            "duration_ms": f"{self.duration_ms:.1f}",
            "data": clipped_data,
            "tags": self.tags,
            "children": nested,
        }

    def to_flat_list(self) -> list[dict]:
        """Return this span followed by all descendants, in pre-order.

        Each row is the span's ``to_dict()`` output, so a row still carries
        its own nested ``"children"`` entry.
        """
        flattened = [self.to_dict()]
        for node in self.children:
            flattened += node.to_flat_list()
        return flattened

79 

80 

@dataclass
class ExecutionTrace:
    """
    Full execution trace for a single task execution.

    Holds a tree of spans (rooted at ``root_span``) plus running counters
    for spans, retries, errors and fallbacks. Spans are opened with
    ``start_span`` and closed with ``end_span``; ``add_event`` records an
    already-finished leaf in one call.

    Usage:
        trace = ExecutionTrace(task_name="research_query")

        span = trace.start_span(TraceEvent.TASK_START, name="main")
        # ... do work ...
        trace.end_span(span.id, status="done")

        print(trace.summary())
        print(trace.to_json())
    """

    # Short random identifier: first 12 hex characters of a UUID4.
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    task_name: str = ""
    root_span: TraceSpan | None = None
    # Every span created through start_span, keyed by span id.
    _span_map: dict[str, TraceSpan] = field(default_factory=dict)
    # Ids of spans that are still open, most recently opened last.
    _current_spans: list[str] = field(default_factory=list)  # stack
    total_spans: int = 0
    total_retries: int = 0
    total_errors: int = 0
    total_fallbacks: int = 0
    created_at: float = field(default_factory=time.time)

    def start_span(
        self,
        event: TraceEvent,
        name: str = "",
        data: dict | None = None,
        tags: list[str] | None = None,
    ) -> TraceSpan:
        """Start a new span and add it to the trace tree.

        Parent selection:
          * if any span is still open, the most recently opened one is the
            parent;
          * otherwise the first span ever started becomes the root;
          * otherwise the span is attached as a child of the root.

        Args:
            event: Event type of the span; RETRY, ERROR and FALLBACK also
                bump the matching counter.
            name: Human-readable label.
            data: Optional payload; a falsy value is replaced by a new dict.
            tags: Optional tags; a falsy value is replaced by a new list.

        Returns:
            The newly created span, left open on the stack.
        """
        span = TraceSpan(
            event=event,
            name=name,
            data=data or {},
            tags=tags or [],
        )

        # Determine parent
        if self._current_spans:
            span.parent_id = self._current_spans[-1]
            parent = self._span_map.get(span.parent_id)
            # If the id on the stack is unknown to the map the span keeps
            # the parent_id but is not linked into the tree.
            if parent:
                parent.children.append(span)
        elif self.root_span is None:
            self.root_span = span
        else:
            # Attach to root as sibling
            span.parent_id = self.root_span.id
            self.root_span.children.append(span)

        self._span_map[span.id] = span
        self._current_spans.append(span.id)
        self.total_spans += 1

        if event == TraceEvent.RETRY:
            self.total_retries += 1
        elif event == TraceEvent.ERROR:
            self.total_errors += 1
        elif event == TraceEvent.FALLBACK:
            self.total_fallbacks += 1

        return span

    def end_span(self, span_id: str, status: str = "done", data: dict | None = None) -> TraceSpan | None:
        """End a span and record its duration.

        Args:
            span_id: Id of a span previously returned by ``start_span``.
            status: Final status to store on the span.
            data: Optional extra entries merged into the span's ``data``.

        Returns:
            The updated span, or ``None`` when ``span_id`` is unknown.
        """
        span = self._span_map.get(span_id)
        if span is None:
            return None

        now_ms = time.time() * 1000
        span.end_ms = now_ms
        span.duration_ms = now_ms - span.start_ms
        span.status = status
        if data:
            span.data.update(data)

        self._remove_from_stack(span_id)
        return span

    def add_event(
        self,
        event: TraceEvent,
        name: str = "",
        data: dict | None = None,
        tags: list[str] | None = None,
        duration_ms: float = 0.0,
    ) -> TraceSpan:
        """Quick-add a leaf event (start+end in one call).

        The span is created through ``start_span`` (so parenting and
        counters behave the same), then immediately marked ``"done"`` with
        the supplied ``duration_ms`` and taken off the open-span stack.
        """
        span = self.start_span(event, name, data, tags)
        span.end_ms = span.start_ms + duration_ms
        span.duration_ms = duration_ms
        span.status = "done"

        self._remove_from_stack(span.id)
        return span

    def _remove_from_stack(self, span_id: str) -> None:
        """Drop ``span_id`` from the open-span stack, wherever it sits.

        Removing only a matching top entry would leave a span that was
        closed out of order on the stack forever, and later spans would be
        parented to that already-closed span.
        """
        stack = self._current_spans
        if stack and stack[-1] == span_id:
            stack.pop()  # common case: spans closed in LIFO order
        elif span_id in stack:
            stack.remove(span_id)

    def _all_spans(self) -> list[TraceSpan]:
        """Return every span reachable from the root in pre-order.

        Implemented with an explicit stack so very deep trees do not hit
        the interpreter's recursion limit.
        """
        if self.root_span is None:
            return []
        ordered: list[TraceSpan] = []
        pending = [self.root_span]
        while pending:
            node = pending.pop()
            ordered.append(node)
            # Reverse so the first child is popped (visited) first.
            pending.extend(reversed(node.children))
        return ordered

    def to_dict(self) -> dict:
        """Return the trace counters plus the nested root span as a dict.

        ``"root"`` is an empty dict when no span has been started.
        """
        return {
            "trace_id": self.id,
            "task_name": self.task_name,
            "total_spans": self.total_spans,
            "total_retries": self.total_retries,
            "total_errors": self.total_errors,
            "total_fallbacks": self.total_fallbacks,
            "created_at": self.created_at,
            "root": self.root_span.to_dict() if self.root_span else {},
        }

    def to_json(self, indent: int = 2) -> str:
        """Serialize ``to_dict()`` to JSON without escaping non-ASCII text."""
        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)

    def to_tree_string(self, span: TraceSpan | None = None, indent: int = 0) -> str:
        """Render trace as indented ASCII tree.

        Args:
            span: Subtree to render; defaults to the root span.
            indent: Indent level of ``span``; each nesting level adds one.

        Returns:
            One line per span in pre-order, or ``"(empty trace)"`` when
            there is nothing to render.
        """
        if span is None:
            span = self.root_span
        if span is None:
            return "(empty trace)"

        icons = {"done": "O", "failed": "X", "started": ">", "aborted": "!"}
        lines = []
        # Explicit stack of (span, depth) instead of recursion.
        pending = [(span, indent)]
        while pending:
            node, depth = pending.pop()
            prefix = " " * depth
            status_icon = icons.get(node.status, "?")
            lines.append(
                f"{prefix}{status_icon} [{node.event.value}] {node.name} "
                f"({node.duration_ms:.0f}ms) [{node.status}]"
            )
            pending.extend((child, depth + 1) for child in reversed(node.children))
        return "\n".join(lines)

    def summary(self) -> str:
        """One-line summary of trace.

        The total time is the root span's ``duration_ms`` (0 when there is
        no root or it has not been ended).
        """
        total_ms = 0.0
        if self.root_span and self.root_span.duration_ms:
            total_ms = self.root_span.duration_ms
        return (
            f"Trace[{self.id}] '{self.task_name}' "
            f"{self.total_spans} spans, "
            f"{self.total_retries} retries, "
            f"{self.total_errors} errors, "
            f"{total_ms:.0f}ms total"
        )

    def bottlenecks(self, top_n: int = 5) -> list[dict]:
        """Find slowest spans — helps identify bottlenecks.

        Args:
            top_n: Maximum number of spans to return; values below 1 yield
                an empty list.

        Returns:
            Up to ``top_n`` span summaries ordered by descending
            ``duration_ms`` (ties keep pre-order).
        """
        limit = max(top_n, 0)  # a negative slice bound would drop from the end
        slowest = sorted(self._all_spans(), key=lambda s: s.duration_ms, reverse=True)
        return [
            {
                "name": s.name,
                "event": s.event.value,
                "duration_ms": round(s.duration_ms, 1),
                "status": s.status,
                "tags": s.tags,
            }
            for s in slowest[:limit]
        ]

    def errors_list(self) -> list[dict]:
        """List all error spans.

        A span counts as an error when its event is ``TraceEvent.ERROR`` or
        its status is ``"failed"``. ``data`` values are stringified and cut
        to 100 characters.
        """
        return [
            {
                "name": s.name,
                "data": {k: str(v)[:100] for k, v in s.data.items()},
                "tags": s.tags,
            }
            for s in self._all_spans()
            if s.event == TraceEvent.ERROR or s.status == "failed"
        ]

    def timeline(self) -> list[dict]:
        """Generate a flat timeline of all spans sorted by start time."""
        ordered = sorted(self._all_spans(), key=lambda s: s.start_ms)
        return [
            {
                "time_ms": f"{s.start_ms:.1f}",
                "event": s.event.value,
                "name": s.name,
                "duration_ms": f"{s.duration_ms:.1f}",
                "status": s.status,
            }
            for s in ordered
        ]

302 

303 

class TraceCollector:
    """
    Collects multiple execution traces and generates aggregate reports.

    At most ``max_traces`` traces are retained; when the cap is exceeded
    the earliest-inserted traces are dropped first.

    Usage:
        collector = TraceCollector()
        trace1 = await run_task("query_a")
        collector.add(trace1)
        trace2 = await run_task("query_b")
        collector.add(trace2)

        print(collector.stats())
    """

    def __init__(self, max_traces: int = 100):
        """Create an empty collector that retains up to ``max_traces`` traces."""
        # Keyed by trace id; dict insertion order doubles as age order.
        self._traces: dict[str, ExecutionTrace] = {}
        self.max_traces = max_traces

    def add(self, trace: ExecutionTrace) -> None:
        """Store ``trace`` under its id, evicting the oldest entries if needed.

        Eviction loops until the cap holds, so the limit is still enforced
        when ``max_traces`` has been lowered after construction.
        """
        self._traces[trace.id] = trace
        while self._traces and len(self._traces) > self.max_traces:
            oldest = next(iter(self._traces))
            del self._traces[oldest]

    def get(self, trace_id: str) -> ExecutionTrace | None:
        """Return the stored trace with ``trace_id``, or ``None`` if absent."""
        return self._traces.get(trace_id)

    def stats(self) -> dict:
        """Aggregate statistics across all traces.

        Returns ``{"count": 0}`` when nothing is stored. Otherwise returns
        summed counters, retry/error rates per span (rounded to 3 places)
        and avg/max/min of the root span durations (rounded to 1 place).
        Traces without a root span or with a zero root duration are left
        out of the duration figures.
        """
        traces = list(self._traces.values())
        if not traces:
            return {"count": 0}

        total_spans = sum(t.total_spans for t in traces)
        total_retries = sum(t.total_retries for t in traces)
        total_errors = sum(t.total_errors for t in traces)
        total_fallbacks = sum(t.total_fallbacks for t in traces)

        durations = [
            t.root_span.duration_ms
            for t in traces
            if t.root_span and t.root_span.duration_ms
        ]
        # Guard the divisor so traces with zero spans do not divide by zero.
        span_divisor = max(total_spans, 1)

        return {
            "count": len(traces),
            "total_spans": total_spans,
            "total_retries": total_retries,
            "total_errors": total_errors,
            "total_fallbacks": total_fallbacks,
            "retry_rate": round(total_retries / span_divisor, 3),
            "error_rate": round(total_errors / span_divisor, 3),
            "avg_duration_ms": round(sum(durations) / len(durations), 1) if durations else 0,
            "max_duration_ms": round(max(durations), 1) if durations else 0,
            "min_duration_ms": round(min(durations), 1) if durations else 0,
        }

    def failed_tasks(self) -> list[dict]:
        """List tasks whose root span ended as ``"failed"`` or ``"aborted"``."""
        return [
            {
                "trace_id": t.id,
                "task_name": t.task_name,
                "status": t.root_span.status,
                "errors": t.total_errors,
            }
            for t in self._traces.values()
            if t.root_span and t.root_span.status in ("failed", "aborted")
        ]