Coverage for agentos/checkpoint/base.py: 93%

42 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Checkpointer 抽象基类与数据结构。 

3""" 

4 

5from __future__ import annotations 

6 

7from abc import ABC, abstractmethod 

8from dataclasses import dataclass, field 

9from datetime import datetime, timezone 

10from typing import Any 

11 

12 

@dataclass
class CheckpointMetadata:
    """Metadata describing a single checkpoint.

    Attributes:
        thread_id: ID of the conversation thread.
        checkpoint_id: Unique ID of the checkpoint.
        step: Step sequence number.
        parent_checkpoint_id: ID of the parent checkpoint (used for
            branching / backtracking); ``None`` by default.
        created_at: Creation time as an ISO-8601 string; when omitted it is
            filled with the current UTC time at instantiation.
        tags: Tags; each instance gets its own fresh empty list by default.
        summary: Optional summary; empty string by default.
    """

    thread_id: str
    checkpoint_id: str
    step: int
    parent_checkpoint_id: str | None = None
    created_at: str = field(
        default_factory=lambda: datetime.now(tz=timezone.utc).isoformat(),
    )
    tags: list[str] = field(default_factory=list)
    summary: str = ""

23 

24 

@dataclass
class Checkpoint:
    """A single checkpoint: a complete snapshot of the runtime state.

    Attributes:
        metadata: Checkpoint metadata.
        messages: Conversation messages (already serialized).
        state: Agent runtime state.
        tools_result: Tool call results.
        next_node: Next node to execute; empty string by default.
    """

    metadata: CheckpointMetadata
    messages: list[dict[str, Any]]
    state: dict[str, Any]
    tools_result: dict[str, Any]
    next_node: str = ""

    def to_dict(self) -> dict[str, Any]:
        """Return this checkpoint as a nested plain dictionary.

        The metadata is flattened into its own ``"metadata"`` sub-dict.
        Containers (tags, messages, state, tools_result) are placed in the
        result by reference; they are not copied.
        """
        info = self.metadata
        metadata_part = {
            "thread_id": info.thread_id,
            "checkpoint_id": info.checkpoint_id,
            "parent_checkpoint_id": info.parent_checkpoint_id,
            "step": info.step,
            "created_at": info.created_at,
            "tags": info.tags,
            "summary": info.summary,
        }
        return {
            "metadata": metadata_part,
            "messages": self.messages,
            "state": self.state,
            "tools_result": self.tools_result,
            "next_node": self.next_node,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "Checkpoint":
        """Build a checkpoint from a dictionary shaped like ``to_dict()`` output.

        ``d["metadata"]`` and its ``thread_id``, ``checkpoint_id``, ``step``
        and ``created_at`` entries are read by subscript, so a missing one
        raises ``KeyError`` when ``d`` is a plain dict. Every other entry is
        optional and falls back to ``None`` (parent_checkpoint_id), an empty
        container, or an empty string.
        """
        header = d["metadata"]
        metadata = CheckpointMetadata(
            thread_id=header["thread_id"],
            checkpoint_id=header["checkpoint_id"],
            parent_checkpoint_id=header.get("parent_checkpoint_id"),
            step=header["step"],
            created_at=header["created_at"],
            tags=header.get("tags", []),
            summary=header.get("summary", ""),
        )
        return cls(
            metadata=metadata,
            messages=d.get("messages", []),
            state=d.get("state", {}),
            tools_result=d.get("tools_result", {}),
            next_node=d.get("next_node", ""),
        )

69 

70 

class CheckpointBackend(ABC):
    """Abstract base class for checkpoint storage backends.

    Every operation is an abstract coroutine; a concrete backend must
    implement all of them before it can be instantiated.
    """

    @abstractmethod
    async def put(self, checkpoint: Checkpoint) -> str:
        """Save ``checkpoint`` and return its checkpoint_id."""

    @abstractmethod
    async def get(self, checkpoint_id: str) -> Checkpoint | None:
        """Fetch a checkpoint by its ID.

        The annotation allows ``None``; presumably that signals an unknown
        ID -- confirm against the concrete backends.
        """

    @abstractmethod
    async def get_latest(self, thread_id: str) -> Checkpoint | None:
        """Fetch the latest checkpoint of the thread ``thread_id``."""

    @abstractmethod
    async def list_threads(self, limit: int = 50, offset: int = 0) -> list[CheckpointMetadata]:
        """List the latest checkpoint metadata of every thread.

        ``limit`` (default 50) and ``offset`` (default 0) are presumably
        paging controls; their exact semantics are left to the backend.
        """

    @abstractmethod
    async def list_checkpoints(self, thread_id: str, limit: int = 100, offset: int = 0) -> list[CheckpointMetadata]:
        """List all checkpoints of a thread (supports backtracking / time travel).

        ``limit`` (default 100) and ``offset`` (default 0) are presumably
        paging controls; their exact semantics are left to the backend.
        """

    @abstractmethod
    async def delete_thread(self, thread_id: str) -> int:
        """Delete every checkpoint of a thread and return the number deleted."""

    @abstractmethod
    async def delete_before(self, thread_id: str, before_step: int) -> int:
        """Delete all checkpoints of a thread that come before ``before_step``.

        Returns an int; presumably the number of deleted checkpoints, as in
        ``delete_thread`` -- confirm against the concrete backends.
        """