Coverage for agentos/checkpoint/engine.py: 0%
184 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:36 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:36 +0800
"""
AgentOS v1.14.7 — Fine-grained Checkpoint Engine.

LangGraph-aligned step-level checkpointing with time travel.
Tool calls, LLM calls and state transitions can each trigger a snapshot
(see ``SnapshotConfig.triggers``).

Key differences from the v1.14.6 checkpoint module:
- Step-level (not task-level) granularity
- Time travel: rewind to any checkpoint and replay from there
- Branching: fork execution from any historical checkpoint
- Delta snapshots: only store state diffs when possible
  (``SnapshotConfig.delta_snapshots``; not yet applied by ``CheckpointEngine``)
- Automatic pruning: configurable retention policies

Usage:
    engine = CheckpointEngine(checkpointer=SQLiteCheckpointer("checkpoints.db"))

    # Auto-snapshot around tool calls
    @engine.snapshot_on(SnapshotTrigger.TOOL_CALL)
    async def my_tool(...): ...

    # Time travel
    await engine.rewind("checkpoint-42")
    # Now continue execution from that point

    # Branch
    branch_id = await engine.branch("checkpoint-42", "bugfix-experiment")
"""
29from __future__ import annotations
31import asyncio
32import functools
33import logging
34import time
35import uuid
36from contextlib import asynccontextmanager
37from dataclasses import dataclass, field
38from enum import Enum, auto
39from typing import Any, Callable, Dict, List, Optional, Set, Tuple
41from agentos.checkpoint.base import (
42 Checkpoint,
43 CheckpointMetadata,
44 CheckpointBackend,
45)
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
# ── Types ────────────────────────────────────
class SnapshotTrigger(str, Enum):
    """Points in execution at which a snapshot may be taken."""

    TOOL_CALL = "tool_call"          # before/after a tool call
    LLM_CALL = "llm_call"            # before/after an LLM call
    STATE_CHANGE = "state_change"    # agent state transition
    TASK_BOUNDARY = "task_boundary"  # task start / end
    MANUAL = "manual"                # explicit call
    INTERVAL = "interval"            # timed (periodic) snapshot
class CheckpointGC(str, Enum):
    """Retention (garbage-collection) policy for old checkpoints."""

    KEEP_ALL = "keep_all"                # never prune
    KEEP_LAST_N = "keep_last_n"          # keep only the newest N
    KEEP_AGE = "keep_age"                # keep only checkpoints from the last N seconds
    KEEP_MILESTONES = "keep_milestones"  # keep first, last and evenly spaced milestones
@dataclass
class SnapshotConfig:
    """Snapshot and retention settings for ``CheckpointEngine``.

    Attributes:
        triggers: Triggers for which snapshots are actually written; any
            other trigger is ignored. Defaults to MANUAL, TOOL_CALL,
            LLM_CALL and STATE_CHANGE.
        gc_policy: Retention policy applied after each snapshot.
        gc_param: N for KEEP_LAST_N / KEEP_MILESTONES, or the maximum age in
            seconds for KEEP_AGE.
        delta_snapshots: Whether to store state diffs instead of full state
            (to reduce storage). NOTE(review): not read by CheckpointEngine
            in this module.
        max_snapshot_size_mb: Intended size cap for one snapshot, in MB.
            NOTE(review): not enforced by CheckpointEngine in this module.
    """

    triggers: Set[SnapshotTrigger] = field(
        default_factory=lambda: set([
            SnapshotTrigger.MANUAL,
            SnapshotTrigger.TOOL_CALL,
            SnapshotTrigger.LLM_CALL,
            SnapshotTrigger.STATE_CHANGE,
        ])
    )
    gc_policy: CheckpointGC = CheckpointGC.KEEP_LAST_N
    gc_param: int = 100
    delta_snapshots: bool = True
    max_snapshot_size_mb: float = 10.0
@dataclass
class TimeTravelResult:
    """Outcome of a time-travel (rewind) operation."""

    checkpoint: Checkpoint       # checkpoint the thread was moved back to
    thread_id: str               # thread that was rewound
    rewind_depth: int            # number of steps rolled back
    snapshot_count_before: int   # step count before the rewind
    can_replay: bool = True
# ── Checkpoint Engine ────────────────────────
class CheckpointEngine:
    """Fine-grained checkpoint engine.

    Provides per-step snapshots, time travel (rewind), branching and merging
    on top of a ``CheckpointBackend``.

    Backend methods used: ``put``, ``get``, ``get_latest``,
    ``list_checkpoints(thread_id, limit=...)``, ``delete_thread`` and
    ``delete_before(thread_id, step)``. NOTE(review): ``delete_before`` is
    treated as "delete this thread's checkpoints whose step is lower than
    ``step``" — confirm against the backend implementation.
    """

    # Maximum number of checkpoint records listed per thread when scanning
    # for rewind/GC. A listing that reaches this size may be truncated.
    _SCAN_LIMIT = 500

    def __init__(
        self,
        checkpointer: CheckpointBackend,
        config: Optional[SnapshotConfig] = None,
    ):
        """Create an engine.

        Args:
            checkpointer: Storage backend that persists checkpoints.
            config: Snapshot/GC settings; ``None`` means ``SnapshotConfig()``.
        """
        self._checkpointer = checkpointer
        self._config = config or SnapshotConfig()
        # thread_id -> last step number handed out for that thread.
        self._snapshot_counters: Dict[str, int] = {}
        # thread_id -> last full state; reserved for delta snapshots and not
        # read anywhere in this class yet.
        self._last_delta: Dict[str, Dict[str, Any]] = {}

    # ── Snapshot API ────────────────────────

    async def snapshot(
        self,
        thread_id: str,
        messages: List[Dict[str, Any]],
        state: Dict[str, Any],
        tools_result: Dict[str, Any],
        trigger: SnapshotTrigger = SnapshotTrigger.MANUAL,
        parent_checkpoint_id: Optional[str] = None,
        next_node: str = "",
    ) -> str:
        """Persist one snapshot and return its checkpoint id.

        ``trigger`` may be a ``SnapshotTrigger`` or its string value
        (e.g. ``"tool_call"``). Returns ``""`` without writing anything when
        the trigger is not enabled in ``config.triggers``.

        Backend errors from ``get_latest``/``put``/GC propagate; use
        ``snapshot_safe`` for a non-raising variant.
        """
        if trigger not in self._config.triggers:
            return ""  # trigger not enabled
        # Normalise a plain string to the enum so ``.value`` is available.
        trigger = SnapshotTrigger(trigger)
        return await self._write_snapshot(
            thread_id,
            messages,
            state,
            tools_result,
            tags=[trigger.value],
            parent_checkpoint_id=parent_checkpoint_id,
            next_node=next_node,
        )

    async def snapshot_safe(
        self,
        thread_id: str,
        messages: List[Dict[str, Any]],
        state: Dict[str, Any],
        tools_result: Dict[str, Any],
        trigger: SnapshotTrigger = SnapshotTrigger.MANUAL,
        parent_checkpoint_id: Optional[str] = None,
        next_node: str = "",
    ) -> str:
        """Like ``snapshot`` but never raises: failures are logged and ``""`` is returned."""
        try:
            return await self.snapshot(
                thread_id, messages, state, tools_result,
                trigger, parent_checkpoint_id, next_node,
            )
        except Exception as e:  # deliberate best-effort boundary
            logger.exception("Snapshot failed (non-blocking): %s", e)
            return ""

    # ── Time Travel API ─────────────────────

    async def rewind(
        self,
        checkpoint_id: str,
    ) -> TimeTravelResult:
        """Time travel: move a thread back to ``checkpoint_id``.

        Checkpoints of the same thread with a higher step than the target are
        removed, and the thread's step counter is reset to the target step so
        the next snapshot continues from there. The target and everything
        before it are kept.

        Raises:
            ValueError: if ``checkpoint_id`` does not exist.
        """
        target = await self._checkpointer.get(checkpoint_id)
        if not target:
            raise ValueError(f"Checkpoint {checkpoint_id} not found")

        thread_id = target.metadata.thread_id
        target_step = target.metadata.step

        metas = await self._checkpointer.list_checkpoints(
            thread_id, limit=self._SCAN_LIMIT
        )
        # The in-memory counter is empty after a restart, so also consult
        # the persisted history to find the current position.
        current_step = max(
            [self._snapshot_counters.get(thread_id, 0)] + [m.step for m in metas]
        )
        rewind_depth = current_step - target_step

        deleted = await self._retain_only(
            thread_id, metas, lambda m: m.step <= target_step
        )

        self._snapshot_counters[thread_id] = target_step

        logger.info(
            "Time travel: rewound %s by %s steps to checkpoint %s (step %s), "
            "deleted %s later checkpoints",
            thread_id, rewind_depth, checkpoint_id, target_step, deleted,
        )

        return TimeTravelResult(
            checkpoint=target,
            thread_id=thread_id,
            rewind_depth=rewind_depth,
            snapshot_count_before=current_step,
        )

    async def time_travel_to_step(
        self,
        thread_id: str,
        step: int,
    ) -> Optional[TimeTravelResult]:
        """Rewind to the latest checkpoint whose step is <= ``step``.

        Returns ``None`` when no such checkpoint exists.
        """
        checkpoints = await self._checkpointer.list_checkpoints(
            thread_id, limit=self._SCAN_LIMIT
        )
        matching = [cp for cp in checkpoints if cp.step <= step]
        if not matching:
            return None
        target = max(matching, key=lambda c: c.step)
        return await self.rewind(target.checkpoint_id)

    async def list_time_travel_points(
        self,
        thread_id: str,
        limit: int = 50,
    ) -> List[CheckpointMetadata]:
        """List checkpoint metadata of ``thread_id`` usable as rewind targets."""
        return await self._checkpointer.list_checkpoints(thread_id, limit=limit)

    # ── Branch API ──────────────────────────

    async def branch(
        self,
        from_checkpoint_id: str,
        branch_name: str,
    ) -> str:
        """Fork a new thread from a historical checkpoint.

        The new thread starts with a step-0 checkpoint that copies the
        source's messages, state and tool results.

        Returns:
            The new branch thread id.

        Raises:
            ValueError: if the source checkpoint does not exist.
        """
        source = await self._checkpointer.get(from_checkpoint_id)
        if not source:
            raise ValueError(f"Source checkpoint {from_checkpoint_id} not found")

        branch_thread_id = f"{source.metadata.thread_id}-branch-{branch_name}-{uuid.uuid4().hex[:4]}"

        branch_checkpoint = Checkpoint(
            metadata=CheckpointMetadata(
                thread_id=branch_thread_id,
                checkpoint_id=f"ckpt-{branch_thread_id}-0",
                parent_checkpoint_id=from_checkpoint_id,
                step=0,
                tags=["branch", branch_name],
                summary=f"Branch '{branch_name}' from {from_checkpoint_id}",
            ),
            messages=list(source.messages),
            state=dict(source.state),
            tools_result=dict(source.tools_result),
            next_node="",
        )

        await self._checkpointer.put(branch_checkpoint)
        self._snapshot_counters[branch_thread_id] = 0

        logger.info("Created branch: %s from %s", branch_thread_id, from_checkpoint_id)
        return branch_thread_id

    async def merge_branch(
        self,
        branch_thread_id: str,
        into_thread_id: str,
    ) -> str:
        """Merge a branch into another thread.

        Writes a new checkpoint on ``into_thread_id`` holding the branch's
        latest messages/state/tool results, parented to that branch
        checkpoint. The write is unconditional: it does not depend on
        MANUAL being an enabled trigger.

        Returns:
            The id of the merge checkpoint.

        Raises:
            ValueError: if the branch has no checkpoints.
        """
        branch_latest = await self._checkpointer.get_latest(branch_thread_id)
        if not branch_latest:
            raise ValueError(f"Branch {branch_thread_id} has no checkpoints")

        merge_id = await self._write_snapshot(
            into_thread_id,
            branch_latest.messages,
            branch_latest.state,
            branch_latest.tools_result,
            tags=[SnapshotTrigger.MANUAL.value],
            parent_checkpoint_id=branch_latest.metadata.checkpoint_id,
            next_node="",
        )

        logger.info(
            "Merged branch %s → %s (merge ckpt: %s)",
            branch_thread_id, into_thread_id, merge_id,
        )
        return merge_id

    # ── Decorator API ───────────────────────

    def snapshot_on(self, trigger: SnapshotTrigger):
        """Decorator: snapshot before and after every call of an async function.

        The wrapper consumes two optional keyword arguments that are NOT
        forwarded to the wrapped function:

        - ``_checkpoint_thread_id``: thread to snapshot into (default ``"default"``).
        - ``_checkpoint_state``: state stored in both snapshots (default ``{}``).

        Snapshots go through ``snapshot_safe``, so snapshot failures never
        break the call. If the wrapped function raises, no "after" snapshot
        is taken.

        Usage:
            engine = CheckpointEngine(...)

            @engine.snapshot_on(SnapshotTrigger.TOOL_CALL)
            async def search_database(query: str): ...
        """
        def decorator(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                thread_id = kwargs.pop("_checkpoint_thread_id", "default")
                state = kwargs.pop("_checkpoint_state", {})

                # Before: record the call (keyword arguments only).
                await self.snapshot_safe(
                    thread_id=thread_id,
                    messages=[{"role": "tool_call", "content": f"{func.__name__}({kwargs})"}],
                    state=state,
                    tools_result={},
                    trigger=trigger,
                )

                result = await func(*args, **kwargs)

                # After: record a truncated string form of the result.
                await self.snapshot_safe(
                    thread_id=thread_id,
                    messages=[{"role": "tool_result", "content": str(result)[:500]}],
                    state=state,
                    tools_result={"result": str(result)[:1000]},
                    trigger=trigger,
                )

                return result
            return wrapper
        return decorator

    @asynccontextmanager
    async def snapshot_scope(
        self,
        thread_id: str,
        state: Dict[str, Any],
        trigger: SnapshotTrigger = SnapshotTrigger.STATE_CHANGE,
    ):
        """Async context manager: snapshot on entering and on leaving the scope.

        The exit snapshot runs in ``finally``, i.e. also when the body raises.
        Both snapshots use ``snapshot_safe`` and never raise.

        Usage:
            async with engine.snapshot_scope("thread-1", state):
                await execute_workflow(...)
        """
        await self.snapshot_safe(
            thread_id=thread_id,
            messages=[{"role": "system", "content": f"Enter scope ({trigger.value})"}],
            state=state,
            tools_result={},
            trigger=trigger,
        )
        try:
            yield
        finally:
            await self.snapshot_safe(
                thread_id=thread_id,
                messages=[{"role": "system", "content": f"Exit scope ({trigger.value})"}],
                state=state,
                tools_result={},
                trigger=trigger,
            )

    # ── Query API ───────────────────────────

    async def get_latest(self, thread_id: str) -> Optional[Checkpoint]:
        """Return the backend's latest checkpoint for ``thread_id`` (or ``None``)."""
        return await self._checkpointer.get_latest(thread_id)

    async def get_checkpoint_tree(
        self, thread_id: str, limit: int = 200
    ) -> Dict[str, Any]:
        """Return the thread's checkpoints as a node/edge graph for visualisation.

        An edge is emitted only when a checkpoint's parent is among the
        listed checkpoints of the same thread.
        """
        checkpoints = await self._checkpointer.list_checkpoints(thread_id, limit=limit)

        known_ids = {cp.checkpoint_id for cp in checkpoints}
        nodes: List[Dict] = [
            {
                "id": cp.checkpoint_id,
                "step": cp.step,
                "tags": cp.tags,
                "summary": cp.summary,
                "created_at": cp.created_at,
            }
            for cp in checkpoints
        ]
        edges: List[Dict] = [
            {"from": cp.parent_checkpoint_id, "to": cp.checkpoint_id}
            for cp in checkpoints
            if cp.parent_checkpoint_id and cp.parent_checkpoint_id in known_ids
        ]

        return {
            "thread_id": thread_id,
            "total_checkpoints": len(checkpoints),
            "nodes": nodes,
            "edges": edges,
        }

    # ── Internal ────────────────────────────

    async def _write_snapshot(
        self,
        thread_id: str,
        messages: List[Dict[str, Any]],
        state: Dict[str, Any],
        tools_result: Dict[str, Any],
        tags: List[str],
        parent_checkpoint_id: Optional[str],
        next_node: str,
    ) -> str:
        """Write one checkpoint unconditionally (no trigger filter), then run GC."""
        if thread_id not in self._snapshot_counters:
            # First snapshot for this thread in this process: continue after
            # the persisted history instead of restarting at step 1.
            # Assumes get_latest returns the highest-step checkpoint — TODO confirm.
            latest = await self._checkpointer.get_latest(thread_id)
            self._snapshot_counters.setdefault(
                thread_id, latest.metadata.step if latest else 0
            )

        # Read-increment-store with no await in between, so concurrent
        # snapshots on one thread get distinct steps.
        step = self._snapshot_counters[thread_id] + 1
        self._snapshot_counters[thread_id] = step

        checkpoint_id = f"ckpt-{thread_id}-{step}-{uuid.uuid4().hex[:6]}"

        checkpoint = Checkpoint(
            metadata=CheckpointMetadata(
                thread_id=thread_id,
                checkpoint_id=checkpoint_id,
                parent_checkpoint_id=parent_checkpoint_id,
                step=step,
                tags=list(tags),
                summary=self._auto_summary(messages, state),
            ),
            messages=list(messages),
            state=dict(state),
            tools_result=dict(tools_result),
            next_node=next_node,
        )

        await self._checkpointer.put(checkpoint)
        await self._maybe_gc(thread_id)
        return checkpoint_id

    async def _retain_only(
        self,
        thread_id: str,
        metas: List[CheckpointMetadata],
        keep: Callable[[CheckpointMetadata], bool],
    ) -> int:
        """Remove the listed checkpoints of ``thread_id`` for which ``keep`` is false.

        The backend API used here has no per-checkpoint delete, so the thread
        is rebuilt. Surviving checkpoints are loaded, the thread is dropped
        with ``delete_thread``, and the survivors are re-``put`` in step order.
        NOTE(review): this is not atomic. A failure between the drop and the
        re-puts loses survivors not yet re-written. It also assumes ``put``
        stores the metadata as given; confirm against the backend.

        Returns:
            Number of checkpoints removed. Returns 0 when nothing matched, or
            when the listing may be truncated and a rebuild could drop
            unlisted checkpoints.
        """
        doomed = [m for m in metas if not keep(m)]
        if not doomed:
            return 0
        if len(metas) >= self._SCAN_LIMIT:
            logger.warning(
                "Skipping deletion of %d checkpoints in %s: listing reached "
                "the scan limit (%d) and may be incomplete",
                len(doomed), thread_id, self._SCAN_LIMIT,
            )
            return 0

        survivors: List[Checkpoint] = []
        for meta in sorted((m for m in metas if keep(m)), key=lambda m: m.step):
            checkpoint = await self._checkpointer.get(meta.checkpoint_id)
            if checkpoint is not None:  # already gone: nothing to restore
                survivors.append(checkpoint)

        await self._checkpointer.delete_thread(thread_id)
        for checkpoint in survivors:
            await self._checkpointer.put(checkpoint)
        return len(doomed)

    @staticmethod
    def _created_ts(meta: CheckpointMetadata) -> Optional[float]:
        """Parse ``meta.created_at`` (ISO-8601) into a POSIX timestamp; ``None`` if unparseable.

        Naive datetimes are interpreted in local time by ``datetime.timestamp``.
        """
        from datetime import datetime

        try:
            return datetime.fromisoformat(meta.created_at).timestamp()
        except (TypeError, ValueError):
            return None

    def _auto_summary(
        self, messages: List[Dict[str, Any]], state: Dict[str, Any]
    ) -> str:
        """Build a short summary: last message as ``[role] content[:100]``, else the state key count."""
        if messages:
            last = messages[-1]
            role = last.get("role", "unknown")
            content = str(last.get("content", ""))[:100]
            return f"[{role}] {content}"
        return f"State: {len(state)} keys"

    async def _maybe_gc(self, thread_id: str):
        """Prune old checkpoints of ``thread_id`` according to ``config.gc_policy``.

        Only checkpoints returned by one listing (up to ``_SCAN_LIMIT``) are
        considered.
        """
        policy = self._config.gc_policy
        if policy == CheckpointGC.KEEP_ALL:
            return

        checkpoints = await self._checkpointer.list_checkpoints(
            thread_id, limit=self._SCAN_LIMIT
        )

        if policy == CheckpointGC.KEEP_LAST_N:
            keep_n = max(self._config.gc_param, 0)
            excess = len(checkpoints) - keep_n
            if excess > 0:
                ordered = sorted(checkpoints, key=lambda c: c.step)
                # One delete_before covers every step up to the newest doomed one.
                await self._checkpointer.delete_before(
                    thread_id, ordered[excess - 1].step + 1
                )
                logger.debug("GC: removed %d old checkpoints from %s", excess, thread_id)

        elif policy == CheckpointGC.KEEP_AGE:
            cutoff = time.time() - self._config.gc_param
            boundary: Optional[int] = None
            expired = 0
            # Walk oldest-first and stop at the first checkpoint that is not
            # known to be expired, so delete_before never removes a fresh one.
            for cp in sorted(checkpoints, key=lambda c: c.step):
                created = self._created_ts(cp)
                if created is None or created >= cutoff:
                    break
                boundary = cp.step
                expired += 1
            if boundary is not None:
                await self._checkpointer.delete_before(thread_id, boundary + 1)
                logger.debug("GC: removed %d expired checkpoints from %s", expired, thread_id)

        elif policy == CheckpointGC.KEEP_MILESTONES:
            if len(checkpoints) > self._config.gc_param:
                # Keep first, last and evenly spaced milestones.
                sorted_cps = sorted(checkpoints, key=lambda c: c.step)
                keep = {sorted_cps[0].step, sorted_cps[-1].step}

                n_milestones = max(2, self._config.gc_param - 2)
                step_size = max(1, len(sorted_cps) // n_milestones)
                for i in range(1, n_milestones):
                    idx = i * step_size
                    if idx < len(sorted_cps):
                        keep.add(sorted_cps[idx].step)

                await self._retain_only(thread_id, checkpoints, lambda c: c.step in keep)
                logger.debug(
                    "GC milestones: kept %d of %d in %s",
                    len(keep), len(sorted_cps), thread_id,
                )
# ── Quick Start ──────────────────────────────
async def create_checkpoint_engine(
    backend: str = "sqlite",
    db_path: str = "checkpoints.db",
    config: Optional[SnapshotConfig] = None,
) -> CheckpointEngine:
    """Create a ``CheckpointEngine`` with a freshly built checkpointer.

    Args:
        backend: Backend name passed to ``create_checkpointer``.
        db_path: Passed to ``create_checkpointer`` as ``db_path``.
        config: Optional snapshot/GC settings forwarded to the engine;
            ``None`` keeps the engine default (``SnapshotConfig()``).

    Returns:
        A new ``CheckpointEngine``.
    """
    # Imported lazily; presumably to avoid import cycles or loading backend
    # dependencies at module import time — TODO confirm.
    from agentos.checkpoint.factory import create_checkpointer

    checkpointer = create_checkpointer(backend, db_path=db_path)
    return CheckpointEngine(checkpointer, config=config)