Coverage for agentos/checkpoint/sqlite.py: 39%

59 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 16:36 +0800

1""" 

2SQLite Checkpointer — 零依赖本地持久化。 

3 

4适用场景: 单机部署、开发调试、POC。 

5生产多机部署请使用 PostgresCheckpointer。 

6""" 

7 

8from __future__ import annotations 

9 

10import asyncio 

11import json 

12import os 

13import sqlite3 

14from datetime import datetime, timezone 

15from typing import Any 

16 

17from agentos.checkpoint.base import ( 

18 Checkpoint, 

19 CheckpointBackend, 

20 CheckpointMetadata, 

21) 

22 

23__all__ = ["SQLiteCheckpointer"] 

24 

# DDL for the single ``checkpoints`` table and its secondary indexes.
# Every statement uses IF NOT EXISTS, so the script can be re-run against an
# already-initialised database without error.
#
# - ``checkpoint_id`` is declared UNIQUE; SQLite backs that constraint with an
#   automatic unique index, so no explicit index on ``checkpoint_id`` is
#   created here (a second one would only add write overhead).
# - ``tags``, ``messages_blob``, ``state_blob`` and ``tools_blob`` hold
#   JSON-encoded text; their defaults are the empty JSON list/object.
# - ``idx_thread_step`` serves "latest / list checkpoints for a thread" lookups
#   that filter on ``thread_id`` and order by ``step DESC``.
# - ``idx_parent`` serves lookups by ``parent_id``.
_SCHEMA = """
CREATE TABLE IF NOT EXISTS checkpoints (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    thread_id TEXT NOT NULL,
    checkpoint_id TEXT NOT NULL UNIQUE,
    parent_id TEXT,
    step INTEGER NOT NULL,
    created_at TEXT NOT NULL,
    tags TEXT NOT NULL DEFAULT '[]',
    summary TEXT NOT NULL DEFAULT '',
    messages_blob TEXT NOT NULL DEFAULT '[]',
    state_blob TEXT NOT NULL DEFAULT '{}',
    tools_blob TEXT NOT NULL DEFAULT '{}',
    next_node TEXT NOT NULL DEFAULT ''
);

CREATE INDEX IF NOT EXISTS idx_thread_step ON checkpoints(thread_id, step DESC);
CREATE INDEX IF NOT EXISTS idx_parent ON checkpoints(parent_id);
"""

45 

46 

class SQLiteCheckpointer(CheckpointBackend):
    """SQLite-backed checkpointer.

    Every operation opens a short-lived connection to ``db_path`` and closes
    it before returning, so no connection or file handle outlives a call.

    Usage:
        cp = SQLiteCheckpointer(db_path="data/checkpoints.db")
        await cp.put(checkpoint)
        latest = await cp.get_latest("thread_abc")

    NOTE(review): the ``async`` methods run blocking ``sqlite3`` calls
    directly on the calling thread; nothing is offloaded to an executor.
    """

    def __init__(self, db_path: str = "checkpoints.db"):
        """Create the parent directory if needed and initialise the schema.

        Args:
            db_path: Path of the SQLite database file.
        """
        self._db_path = db_path
        # A bare filename has an empty dirname; fall back to the current dir.
        os.makedirs(os.path.dirname(db_path) or ".", exist_ok=True)
        self._init_db()

    def _init_db(self) -> None:
        """Run the schema script against the database, then close the connection."""
        conn = sqlite3.connect(self._db_path)
        try:
            conn.executescript(_SCHEMA)
            conn.commit()
        finally:
            conn.close()

    def _get_conn(self) -> sqlite3.Connection:
        """Open a new connection whose rows support access by column name.

        The caller owns the returned connection and must close it.
        """
        conn = sqlite3.connect(self._db_path)
        conn.row_factory = sqlite3.Row
        return conn

    def _fetch_one(self, sql: str, params: tuple[Any, ...]) -> sqlite3.Row | None:
        """Run a read query and return its first row, or ``None`` if empty."""
        conn = self._get_conn()
        try:
            return conn.execute(sql, params).fetchone()
        finally:
            conn.close()

    def _fetch_all(self, sql: str, params: tuple[Any, ...]) -> list[sqlite3.Row]:
        """Run a read query and return all of its rows."""
        conn = self._get_conn()
        try:
            return conn.execute(sql, params).fetchall()
        finally:
            conn.close()

    def _write(self, sql: str, params: tuple[Any, ...]) -> int:
        """Run one write statement in a transaction and return its row count."""
        conn = self._get_conn()
        try:
            # ``with conn`` commits on success and rolls back on error, but it
            # does not close the connection -- hence the explicit ``finally``.
            with conn:
                return conn.execute(sql, params).rowcount
        finally:
            conn.close()

    def _row_to_metadata(self, row: sqlite3.Row) -> CheckpointMetadata:
        """Build a ``CheckpointMetadata`` from a ``checkpoints`` row."""
        return CheckpointMetadata(
            thread_id=row["thread_id"],
            checkpoint_id=row["checkpoint_id"],
            parent_checkpoint_id=row["parent_id"],
            step=row["step"],
            created_at=row["created_at"],
            tags=json.loads(row["tags"]),
            summary=row["summary"],
        )

    def _row_to_checkpoint(self, row: sqlite3.Row) -> Checkpoint:
        """Build a full ``Checkpoint`` (metadata + JSON payloads) from a row."""
        return Checkpoint(
            metadata=self._row_to_metadata(row),
            messages=json.loads(row["messages_blob"]),
            state=json.loads(row["state_blob"]),
            tools_result=json.loads(row["tools_blob"]),
            next_node=row["next_node"],
        )

    async def put(self, checkpoint: Checkpoint) -> str:
        """Store a checkpoint and return its ``checkpoint_id``.

        An existing row with the same ``checkpoint_id`` is replaced
        (``INSERT OR REPLACE`` on the UNIQUE column).
        """
        meta = checkpoint.metadata
        self._write(
            """INSERT OR REPLACE INTO checkpoints
               (thread_id, checkpoint_id, parent_id, step, created_at, tags, summary,
                messages_blob, state_blob, tools_blob, next_node)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            (
                meta.thread_id,
                meta.checkpoint_id,
                meta.parent_checkpoint_id,
                meta.step,
                meta.created_at,
                json.dumps(meta.tags),
                meta.summary,
                json.dumps(checkpoint.messages, ensure_ascii=False),
                json.dumps(checkpoint.state, ensure_ascii=False),
                json.dumps(checkpoint.tools_result, ensure_ascii=False),
                checkpoint.next_node,
            ),
        )
        return meta.checkpoint_id

    async def get(self, checkpoint_id: str) -> Checkpoint | None:
        """Return the checkpoint with this id, or ``None`` if it does not exist."""
        row = self._fetch_one(
            "SELECT * FROM checkpoints WHERE checkpoint_id = ?", (checkpoint_id,)
        )
        return self._row_to_checkpoint(row) if row else None

    async def get_latest(self, thread_id: str) -> Checkpoint | None:
        """Return the highest-``step`` checkpoint of a thread, or ``None``."""
        row = self._fetch_one(
            "SELECT * FROM checkpoints WHERE thread_id = ? ORDER BY step DESC LIMIT 1",
            (thread_id,),
        )
        return self._row_to_checkpoint(row) if row else None

    async def list_threads(
        self, limit: int = 50, offset: int = 0
    ) -> list[CheckpointMetadata]:
        """Return metadata of each thread's highest-``step`` checkpoint.

        Results are ordered by ``created_at`` descending and paged with
        ``limit`` / ``offset``.
        """
        # NOTE(review): the sub-query relies on SQLite's handling of bare
        # columns in an aggregate query that contains a single MAX(); it is
        # not portable SQL.
        rows = self._fetch_all(
            """SELECT * FROM checkpoints
               WHERE checkpoint_id IN (
                   SELECT checkpoint_id FROM checkpoints
                   GROUP BY thread_id HAVING step = MAX(step)
               )
               ORDER BY created_at DESC LIMIT ? OFFSET ?""",
            (limit, offset),
        )
        return [self._row_to_metadata(r) for r in rows]

    async def list_checkpoints(
        self, thread_id: str, limit: int = 100, offset: int = 0
    ) -> list[CheckpointMetadata]:
        """Return metadata of a thread's checkpoints, newest step first."""
        rows = self._fetch_all(
            "SELECT * FROM checkpoints WHERE thread_id = ? ORDER BY step DESC LIMIT ? OFFSET ?",
            (thread_id, limit, offset),
        )
        return [self._row_to_metadata(r) for r in rows]

    async def delete_thread(self, thread_id: str) -> int:
        """Delete every checkpoint of a thread; return the number of rows removed."""
        return self._write(
            "DELETE FROM checkpoints WHERE thread_id = ?", (thread_id,)
        )

    async def delete_before(self, thread_id: str, before_step: int) -> int:
        """Delete a thread's checkpoints with ``step`` strictly below ``before_step``.

        Returns:
            The number of rows removed.
        """
        return self._write(
            "DELETE FROM checkpoints WHERE thread_id = ? AND step < ?",
            (thread_id, before_step),
        )