Coverage for agentos/checkpoint/engine.py: 0%

184 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 16:36 +0800

1""" 

2AgentOS v1.14.7 — Fine-grained Checkpoint Engine. 

3 

4LangGraph-aligned step-level checkpointing with time travel. 

5Every tool_call, llm_call, and state transition triggers a snapshot. 

6 

7Key differences from v1.14.6 checkpoint module: 

8- Step-level (not task-level) granularity 

9- Time travel: rewind to any checkpoint and replay from there 

10- Branching: fork execution from any historical checkpoint 

11- Delta snapshots: only store state diffs when possible 

12- Automatic pruning: configurable retention policies 

13 

14Usage: 

15 engine = CheckpointEngine(checkpointer=SQLiteCheckpointer("checkpoints.db")) 

16 

17 # Auto-snapshot around tool calls 

18 @engine.snapshot_on("tool_call") 

19 async def my_tool(...): ... 

20 

21 # Time travel 

22 await engine.rewind("checkpoint-42") 

23 # Now continue execution from that point 

24 

25 # Branch 

26 branch_id = await engine.branch("checkpoint-42", "bugfix-experiment") 

27""" 

28 

29from __future__ import annotations 

30 

31import asyncio 

32import functools 

33import logging 

34import time 

35import uuid 

36from contextlib import asynccontextmanager 

37from dataclasses import dataclass, field 

38from enum import Enum, auto 

39from typing import Any, Callable, Dict, List, Optional, Set, Tuple 

40 

41from agentos.checkpoint.base import ( 

42 Checkpoint, 

43 CheckpointMetadata, 

44 CheckpointBackend, 

45) 

46 

47logger = logging.getLogger(__name__) 

48 

49 

50# ── Types ──────────────────────────────────── 

51 

52 

class SnapshotTrigger(str, Enum):
    """Events that may cause a snapshot to be taken.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value.
    """

    # Before and after a tool call.
    TOOL_CALL = "tool_call"
    # Before and after an LLM call.
    LLM_CALL = "llm_call"
    # Agent state change.
    STATE_CHANGE = "state_change"
    # Task start / task end.
    TASK_BOUNDARY = "task_boundary"
    # Explicit call by the user of the engine.
    MANUAL = "manual"
    # Timed (periodic) snapshot.
    INTERVAL = "interval"

61 

62 

class CheckpointGC(str, Enum):
    """Garbage-collection (retention) policies for stored checkpoints."""

    # Never delete anything.
    KEEP_ALL = "keep_all"
    # Keep only the most recent N checkpoints.
    KEEP_LAST_N = "keep_last_n"
    # Keep only checkpoints created within the last N seconds.
    KEEP_AGE = "keep_age"
    # Keep only the first and last checkpoints plus N evenly spaced ones.
    KEEP_MILESTONES = "keep_milestones"

69 

70 

def _default_triggers() -> Set[SnapshotTrigger]:
    """Build a fresh default trigger set for each ``SnapshotConfig`` instance."""
    return {
        SnapshotTrigger.MANUAL,
        SnapshotTrigger.TOOL_CALL,
        SnapshotTrigger.LLM_CALL,
        SnapshotTrigger.STATE_CHANGE,
    }


@dataclass
class SnapshotConfig:
    """Settings that control when snapshots are taken and how they are retained."""

    # Triggers for which a snapshot is actually written.
    triggers: Set[SnapshotTrigger] = field(default_factory=_default_triggers)
    # Retention policy applied to stored checkpoints.
    gc_policy: CheckpointGC = CheckpointGC.KEEP_LAST_N
    # Policy parameter: the N of keep_last_n, or the number of seconds for keep_age.
    gc_param: int = 100
    # Whether to use incremental (delta) snapshots to reduce storage.
    delta_snapshots: bool = True
    # Upper bound for a single snapshot, in megabytes.
    max_snapshot_size_mb: float = 10.0

84 

85 

@dataclass
class TimeTravelResult:
    """Outcome of a time-travel (rewind) operation."""

    # The checkpoint that execution was rewound to.
    checkpoint: Checkpoint
    # Thread the checkpoint belongs to.
    thread_id: str
    # How many checkpoints were stepped back.
    rewind_depth: int
    # Number of snapshots that existed before the replay.
    snapshot_count_before: int
    # Whether execution can be replayed from this point.
    can_replay: bool = True

94 

95 

96# ── Checkpoint Engine ──────────────────────── 

97 

98 

class CheckpointEngine:
    """Step-level checkpoint engine built on top of a ``CheckpointBackend``.

    Provides per-step snapshots, rewinding ("time travel"), branching and
    merging of threads, decorator / context-manager helpers, and retention
    (GC) driven by ``SnapshotConfig``.
    """

    def __init__(
        self,
        checkpointer: CheckpointBackend,
        config: Optional[SnapshotConfig] = None,
    ):
        """Store the backend and configuration.

        Args:
            checkpointer: Backend used for every read/write of checkpoints.
            config: Snapshot settings; a default ``SnapshotConfig`` is used
                when omitted.
        """
        self._checkpointer = checkpointer
        self._config = config or SnapshotConfig()
        # thread_id -> number of the latest step taken through this engine.
        # NOTE(review): this lives in memory only, so a freshly constructed
        # engine numbers an already-persisted thread from 1 again — confirm
        # that this is intended.
        self._snapshot_counters: Dict[str, int] = {}
        # thread_id -> last full state. Nothing in this class reads or
        # writes it yet.
        self._last_delta: Dict[str, Dict[str, Any]] = {}

    # ── Snapshot API ────────────────────────

    async def snapshot(
        self,
        thread_id: str,
        messages: List[Dict[str, Any]],
        state: Dict[str, Any],
        tools_result: Dict[str, Any],
        trigger: SnapshotTrigger = SnapshotTrigger.MANUAL,
        parent_checkpoint_id: Optional[str] = None,
        next_node: str = "",
    ) -> str:
        """Create one snapshot and return its checkpoint id.

        Args:
            thread_id: Thread the snapshot belongs to.
            messages: Messages to store (shallow-copied).
            state: State mapping to store (shallow-copied).
            tools_result: Tool results to store (shallow-copied).
            trigger: What caused the snapshot. A plain string equal to a
                ``SnapshotTrigger`` value is accepted as well.
            parent_checkpoint_id: Optional id of the parent checkpoint.
            next_node: Stored verbatim on the checkpoint.

        Returns:
            The new checkpoint id, or ``""`` when ``trigger`` is not enabled
            in the configuration (nothing is written in that case).
        """
        if trigger not in self._config.triggers:
            return ""  # trigger is not enabled

        # The membership test passed, so the value matches a member; this
        # turns a plain string such as "tool_call" into the enum member so
        # that ``.value`` below works.
        trigger = SnapshotTrigger(trigger)

        # The counter is bumped before the first ``await`` in this method.
        step = self._snapshot_counters.get(thread_id, 0) + 1
        self._snapshot_counters[thread_id] = step

        checkpoint_id = f"ckpt-{thread_id}-{step}-{uuid.uuid4().hex[:6]}"

        metadata = CheckpointMetadata(
            thread_id=thread_id,
            checkpoint_id=checkpoint_id,
            parent_checkpoint_id=parent_checkpoint_id,
            step=step,
            tags=[trigger.value],
            summary=self._auto_summary(messages, state),
        )
        checkpoint = Checkpoint(
            metadata=metadata,
            messages=list(messages),
            state=dict(state),
            tools_result=dict(tools_result),
            next_node=next_node,
        )

        await self._checkpointer.put(checkpoint)
        await self._maybe_gc(thread_id)

        return checkpoint_id

    async def snapshot_safe(
        self,
        thread_id: str,
        messages: List[Dict[str, Any]],
        state: Dict[str, Any],
        tools_result: Dict[str, Any],
        trigger: SnapshotTrigger = SnapshotTrigger.MANUAL,
        parent_checkpoint_id: Optional[str] = None,
        next_node: str = "",
    ) -> str:
        """Like :meth:`snapshot`, but never raises.

        Any exception is logged and ``""`` is returned, so a failing snapshot
        does not interrupt the caller's main flow.
        """
        try:
            return await self.snapshot(
                thread_id, messages, state, tools_result,
                trigger, parent_checkpoint_id, next_node,
            )
        except Exception as e:
            # Deliberately broad: snapshotting is best-effort here.
            logger.error("Snapshot failed (non-blocking): %s", e)
            return ""

    # ── Time Travel API ─────────────────────

    async def rewind(
        self,
        checkpoint_id: str,
    ) -> TimeTravelResult:
        """Rewind a thread to ``checkpoint_id``.

        Checkpoints of the same thread whose step is greater than the
        target's step are removed, and the thread's step counter is reset to
        the target step. The target and all earlier checkpoints are kept.

        Raises:
            ValueError: If the checkpoint does not exist.
        """
        target = await self._checkpointer.get(checkpoint_id)
        if not target:
            raise ValueError(f"Checkpoint {checkpoint_id} not found")

        thread_id = target.metadata.thread_id
        target_step = target.metadata.step

        # Depth is measured against this engine's in-memory counter.
        current_step = self._snapshot_counters.get(thread_id, 0)
        rewind_depth = current_step - target_step

        # Drop only what comes after the target. The previous code called
        # delete_thread() here, which receives no checkpoint identifier and
        # therefore cannot single out the later checkpoints.
        listed = await self._checkpointer.list_checkpoints(thread_id)
        doomed = {meta.checkpoint_id for meta in listed if meta.step > target_step}
        deleted = await self._delete_checkpoints(thread_id, listed, doomed)

        # If the listing did not include the target (e.g. it was cut off by
        # a backend limit), the rebuild above did not restore it; write it
        # back explicitly.
        target_id = target.metadata.checkpoint_id
        if deleted and all(meta.checkpoint_id != target_id for meta in listed):
            await self._checkpointer.put(target)

        self._snapshot_counters[thread_id] = target_step

        logger.info(
            "Time travel: rewound %s by %s steps "
            "to checkpoint %s (step %s), deleted %s later checkpoints",
            thread_id, rewind_depth, checkpoint_id, target_step, deleted,
        )

        return TimeTravelResult(
            checkpoint=target,
            thread_id=thread_id,
            rewind_depth=rewind_depth,
            snapshot_count_before=current_step,
        )

    async def time_travel_to_step(
        self,
        thread_id: str,
        step: int,
    ) -> Optional[TimeTravelResult]:
        """Rewind to the checkpoint with the highest step not above ``step``.

        Returns ``None`` when the thread has no checkpoint at or below
        ``step`` (among the first 500 listed).
        """
        checkpoints = await self._checkpointer.list_checkpoints(thread_id, limit=500)

        candidates = [cp for cp in checkpoints if cp.step <= step]
        if not candidates:
            return None

        # max() returns the first of equal-step entries, the same element the
        # previous stable reverse sort selected.
        target = max(candidates, key=lambda c: c.step)
        return await self.rewind(target.checkpoint_id)

    async def list_time_travel_points(
        self,
        thread_id: str,
        limit: int = 50,
    ) -> List[CheckpointMetadata]:
        """Return metadata of up to ``limit`` checkpoints of the thread."""
        return await self._checkpointer.list_checkpoints(thread_id, limit=limit)

    # ── Branch API ──────────────────────────

    async def branch(
        self,
        from_checkpoint_id: str,
        branch_name: str,
    ) -> str:
        """Fork a new thread from a historical checkpoint.

        A step-0 checkpoint holding shallow copies of the source's messages,
        state and tool results is written to a newly named thread.

        Returns:
            The id of the new branch thread.

        Raises:
            ValueError: If the source checkpoint does not exist.
        """
        source = await self._checkpointer.get(from_checkpoint_id)
        if not source:
            raise ValueError(f"Source checkpoint {from_checkpoint_id} not found")

        branch_thread_id = f"{source.metadata.thread_id}-branch-{branch_name}-{uuid.uuid4().hex[:4]}"

        branch_metadata = CheckpointMetadata(
            thread_id=branch_thread_id,
            checkpoint_id=f"ckpt-{branch_thread_id}-0",
            parent_checkpoint_id=from_checkpoint_id,
            step=0,
            tags=["branch", branch_name],
            summary=f"Branch '{branch_name}' from {from_checkpoint_id}",
        )
        branch_checkpoint = Checkpoint(
            metadata=branch_metadata,
            messages=list(source.messages),
            state=dict(source.state),
            tools_result=dict(source.tools_result),
            next_node="",
        )

        await self._checkpointer.put(branch_checkpoint)
        self._snapshot_counters[branch_thread_id] = 0

        logger.info("Created branch: %s from %s", branch_thread_id, from_checkpoint_id)
        return branch_thread_id

    async def merge_branch(
        self,
        branch_thread_id: str,
        into_thread_id: str,
    ) -> str:
        """Merge a branch into another thread.

        The branch's latest checkpoint is copied into ``into_thread_id`` as a
        new MANUAL snapshot whose parent is that branch checkpoint.

        Returns:
            The id of the merge checkpoint, or ``""`` when the MANUAL trigger
            is disabled in the configuration and nothing was written.

        Raises:
            ValueError: If the branch has no checkpoints.
        """
        branch_latest = await self._checkpointer.get_latest(branch_thread_id)
        if not branch_latest:
            raise ValueError(f"Branch {branch_thread_id} has no checkpoints")

        merge_id = await self.snapshot(
            thread_id=into_thread_id,
            messages=branch_latest.messages,
            state=branch_latest.state,
            tools_result=branch_latest.tools_result,
            trigger=SnapshotTrigger.MANUAL,
            parent_checkpoint_id=branch_latest.metadata.checkpoint_id,
        )

        if not merge_id:
            # snapshot() returns "" only when the trigger is not enabled.
            logger.warning(
                "Merge of branch %s into %s wrote no checkpoint: "
                "the MANUAL trigger is disabled",
                branch_thread_id, into_thread_id,
            )
            return merge_id

        logger.info(
            "Merged branch %s → %s (merge ckpt: %s)",
            branch_thread_id, into_thread_id, merge_id,
        )
        return merge_id

    # ── Decorator API ───────────────────────

    def snapshot_on(self, trigger: SnapshotTrigger):
        """Decorator: snapshot before and after each call of an async function.

        The wrapper consumes two optional keyword arguments before calling
        the wrapped function: ``_checkpoint_thread_id`` (default
        ``"default"``) and ``_checkpoint_state`` (default ``{}``). Snapshots
        are taken with :meth:`snapshot_safe`, so snapshot failures never
        affect the wrapped call. If the wrapped function raises, no "after"
        snapshot is taken.

        Usage:
            engine = CheckpointEngine(...)

            @engine.snapshot_on(SnapshotTrigger.TOOL_CALL)
            async def search_database(query: str): ...
        """
        def decorator(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                thread_id = kwargs.pop("_checkpoint_thread_id", "default")
                state = kwargs.pop("_checkpoint_state", {})

                # Before snapshot
                await self.snapshot_safe(
                    thread_id=thread_id,
                    messages=[{"role": "tool_call", "content": f"{func.__name__}({kwargs})"}],
                    state=state,
                    tools_result={},
                    trigger=trigger,
                )

                result = await func(*args, **kwargs)

                # After snapshot; the stored result text is truncated.
                result_text = str(result)
                await self.snapshot_safe(
                    thread_id=thread_id,
                    messages=[{"role": "tool_result", "content": result_text[:500]}],
                    state=state,
                    tools_result={"result": result_text[:1000]},
                    trigger=trigger,
                )

                return result
            return wrapper
        return decorator

    @asynccontextmanager
    async def snapshot_scope(
        self,
        thread_id: str,
        state: Dict[str, Any],
        trigger: SnapshotTrigger = SnapshotTrigger.STATE_CHANGE,
    ):
        """Context manager: snapshot on entering and on leaving the scope.

        The exit snapshot is taken in a ``finally`` clause, so it also
        happens when the body raises. Both snapshots use
        :meth:`snapshot_safe`.

        Usage:
            async with engine.snapshot_scope("thread-1", state):
                await execute_workflow(...)
        """
        # Accept a plain string value as well as an enum member.
        trigger = SnapshotTrigger(trigger)

        await self.snapshot_safe(
            thread_id=thread_id,
            messages=[{"role": "system", "content": f"Enter scope ({trigger.value})"}],
            state=state,
            tools_result={},
            trigger=trigger,
        )
        try:
            yield
        finally:
            await self.snapshot_safe(
                thread_id=thread_id,
                messages=[{"role": "system", "content": f"Exit scope ({trigger.value})"}],
                state=state,
                tools_result={},
                trigger=trigger,
            )

    # ── Query API ───────────────────────────

    async def get_latest(self, thread_id: str) -> Optional[Checkpoint]:
        """Return whatever the backend reports as the thread's latest checkpoint."""
        return await self._checkpointer.get_latest(thread_id)

    async def get_checkpoint_tree(
        self, thread_id: str, limit: int = 200
    ) -> Dict[str, Any]:
        """Return the thread's checkpoints as nodes and parent→child edges.

        Intended for visualisation. An edge is emitted only when the parent
        is itself among the listed checkpoints.
        """
        checkpoints = await self._checkpointer.list_checkpoints(thread_id, limit=limit)

        known_ids = {cp.checkpoint_id for cp in checkpoints}
        nodes: List[Dict] = [
            {
                "id": cp.checkpoint_id,
                "step": cp.step,
                "tags": cp.tags,
                "summary": cp.summary,
                "created_at": cp.created_at,
            }
            for cp in checkpoints
        ]
        edges: List[Dict] = [
            {"from": cp.parent_checkpoint_id, "to": cp.checkpoint_id}
            for cp in checkpoints
            if cp.parent_checkpoint_id and cp.parent_checkpoint_id in known_ids
        ]

        return {
            "thread_id": thread_id,
            "total_checkpoints": len(checkpoints),
            "nodes": nodes,
            "edges": edges,
        }

    # ── Internal ────────────────────────────

    def _auto_summary(
        self, messages: List[Dict[str, Any]], state: Dict[str, Any]
    ) -> str:
        """Build a short summary: last message (first 100 chars) or state size."""
        if not messages:
            return f"State: {len(state)} keys"
        last = messages[-1]
        role = last.get("role", "unknown")
        content = str(last.get("content", ""))[:100]
        return f"[{role}] {content}"

    async def _delete_checkpoints(
        self,
        thread_id: str,
        listed: List[CheckpointMetadata],
        doomed_ids: Set[str],
    ) -> int:
        """Remove the checkpoints named in ``doomed_ids`` from a thread.

        The backend calls used in this class offer no single-checkpoint
        delete (only ``delete_thread`` and ``delete_before``), so the thread
        is rebuilt: every listed survivor is loaded, the thread is cleared,
        and the survivors are written back in ascending step order.

        All survivors are loaded *before* anything is deleted, so a failing
        read aborts without data loss. The rebuild itself is not atomic.
        Checkpoints that are not part of ``listed`` are not restored.

        Returns:
            The number of checkpoints removed (``len(doomed_ids)``), or 0
            when there was nothing to remove.
        """
        if not doomed_ids:
            return 0

        survivors = []
        for meta in sorted(listed, key=lambda m: m.step):
            if meta.checkpoint_id in doomed_ids:
                continue
            loaded = await self._checkpointer.get(meta.checkpoint_id)
            if loaded:
                survivors.append(loaded)

        await self._checkpointer.delete_thread(thread_id)
        for checkpoint in survivors:
            await self._checkpointer.put(checkpoint)

        return len(doomed_ids)

    @staticmethod
    def _created_timestamp(meta: CheckpointMetadata) -> Optional[float]:
        """Return ``meta.created_at`` as a POSIX timestamp, or ``None`` if unusable.

        NOTE(review): a naive ISO string is interpreted as local time by
        ``datetime.timestamp()`` — confirm how the backend writes created_at.
        """
        from datetime import datetime

        try:
            return datetime.fromisoformat(meta.created_at).timestamp()
        except Exception:
            # Best-effort, as before: an unreadable timestamp never expires.
            return None

    async def _maybe_gc(self, thread_id: str):
        """Apply the configured retention policy to one thread.

        Only the first 500 listed checkpoints are considered.
        """
        policy = self._config.gc_policy
        if policy == CheckpointGC.KEEP_ALL:
            return

        checkpoints = await self._checkpointer.list_checkpoints(thread_id, limit=500)

        if policy == CheckpointGC.KEEP_LAST_N:
            await self._gc_keep_last_n(thread_id, checkpoints)
        elif policy == CheckpointGC.KEEP_AGE:
            await self._gc_keep_age(thread_id, checkpoints)
        elif policy == CheckpointGC.KEEP_MILESTONES:
            await self._gc_keep_milestones(thread_id, checkpoints)

    async def _gc_keep_last_n(
        self, thread_id: str, checkpoints: List[CheckpointMetadata]
    ) -> None:
        """KEEP_LAST_N: drop the lowest-step checkpoints beyond ``gc_param``."""
        keep_count = self._config.gc_param
        if len(checkpoints) <= keep_count:
            return

        ordered = sorted(checkpoints, key=lambda c: c.step)
        to_delete = ordered[:len(ordered) - keep_count]
        if to_delete:
            # One call with the highest doomed step replaces the former
            # per-checkpoint calls with increasing bounds; ``step + 1`` is
            # the bound the previous code used to include that step.
            await self._checkpointer.delete_before(thread_id, to_delete[-1].step + 1)
        logger.debug("GC: removed %s old checkpoints from %s", len(to_delete), thread_id)

    async def _gc_keep_age(
        self, thread_id: str, checkpoints: List[CheckpointMetadata]
    ) -> None:
        """KEEP_AGE: drop checkpoints older than ``gc_param`` seconds.

        Deletion goes through ``delete_before`` with a bound that is never
        above the lowest step of a non-expired checkpoint, so no listed fresh
        checkpoint is removed. An expired checkpoint whose step is not below
        that bound is left in place.
        """
        cutoff = time.time() - self._config.gc_param

        expired_steps: List[int] = []
        fresh_steps: List[int] = []
        for cp in checkpoints:
            created = self._created_timestamp(cp)
            if created is not None and created < cutoff:
                expired_steps.append(cp.step)
            else:
                fresh_steps.append(cp.step)

        if not expired_steps:
            return

        bound = min(fresh_steps) if fresh_steps else max(expired_steps) + 1
        deleted = sum(1 for step in expired_steps if step < bound)
        if deleted:
            await self._checkpointer.delete_before(thread_id, bound)
            logger.debug("GC: removed %s expired checkpoints from %s", deleted, thread_id)

    async def _gc_keep_milestones(
        self, thread_id: str, checkpoints: List[CheckpointMetadata]
    ) -> None:
        """KEEP_MILESTONES: keep first, last and evenly spaced checkpoints.

        Runs only when more than ``gc_param`` checkpoints are listed. The
        removal rebuilds the thread (see :meth:`_delete_checkpoints`), which
        costs one read and one write per surviving checkpoint.
        """
        if not checkpoints or len(checkpoints) <= self._config.gc_param:
            return

        sorted_cps = sorted(checkpoints, key=lambda c: c.step)
        keep = {sorted_cps[0].step, sorted_cps[-1].step}

        n_milestones = max(2, self._config.gc_param - 2)
        step_size = max(1, len(sorted_cps) // n_milestones)
        for i in range(1, n_milestones):
            idx = i * step_size
            if idx >= len(sorted_cps):
                break  # idx only grows, so no later index can fit either
            keep.add(sorted_cps[idx].step)

        doomed = {cp.checkpoint_id for cp in sorted_cps if cp.step not in keep}
        await self._delete_checkpoints(thread_id, sorted_cps, doomed)
        logger.debug(
            "GC milestones: kept %s of %s in %s",
            len(keep), len(sorted_cps), thread_id,
        )

475 

476 

477# ── Quick Start ────────────────────────────── 

478 

479 

async def create_checkpoint_engine(
    backend: str = "sqlite",
    db_path: str = "checkpoints.db",
) -> CheckpointEngine:
    """Convenience factory: build a ``CheckpointEngine`` with default settings.

    ``backend`` and ``db_path`` are handed to
    ``agentos.checkpoint.factory.create_checkpointer``; the backend it returns
    is wrapped in an engine using the default ``SnapshotConfig``.
    """
    # Imported at call time, as in the original.
    from agentos.checkpoint.factory import create_checkpointer

    return CheckpointEngine(create_checkpointer(backend, db_path=db_path))