Coverage for agentos/checkpoint/postgres.py: 31%

68 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Postgres Checkpointer — 生产级持久化后端。 

3 

4需安装: pip install asyncpg 

5 

6参考 LangGraph PostgresSaver 的 schema 设计。 

7""" 

8 

9from __future__ import annotations 

10 

11import json 

12from typing import Any 

13 

14from agentos.checkpoint.base import ( 

15 Checkpoint, 

16 CheckpointBackend, 

17 CheckpointMetadata, 

18) 

19 

# Public API of this module: only the checkpointer class is exported.
__all__ = [
    "PostgresCheckpointer",
]

21 

# DDL executed when the connection pool is first created. Every statement is
# idempotent (IF NOT EXISTS), so re-running it against an existing database is
# harmless.
# No dedicated index on checkpoint_id is created: the UNIQUE constraint on that
# column already builds one, and a second identical index would only add write
# and storage overhead.
_SCHEMA = """
CREATE TABLE IF NOT EXISTS checkpoints (
    id BIGSERIAL PRIMARY KEY,
    thread_id TEXT NOT NULL,
    checkpoint_id TEXT NOT NULL UNIQUE,
    parent_id TEXT,
    step INTEGER NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    tags JSONB NOT NULL DEFAULT '[]',
    summary TEXT NOT NULL DEFAULT '',
    messages_blob JSONB NOT NULL DEFAULT '[]',
    state_blob JSONB NOT NULL DEFAULT '{}',
    tools_blob JSONB NOT NULL DEFAULT '{}',
    next_node TEXT NOT NULL DEFAULT ''
);

CREATE INDEX IF NOT EXISTS idx_thread_step ON checkpoints(thread_id, step DESC);
CREATE INDEX IF NOT EXISTS idx_parent ON checkpoints(parent_id);
CREATE INDEX IF NOT EXISTS idx_created_at ON checkpoints(created_at DESC);
"""

43 

44 

class PostgresCheckpointer(CheckpointBackend):
    """Postgres-backed checkpointer, recommended for production use.

    Requires the optional ``asyncpg`` package, which is imported lazily the
    first time the database is actually used.

    Usage::

        cp = PostgresCheckpointer(dsn="postgresql://user:pass@localhost:5432/agentos")
        await cp.put(checkpoint)
        latest = await cp.get_latest("thread_abc")
        await cp.close()
    """

    def __init__(self, dsn: str = "", **kwargs: Any):
        """Configure the checkpointer without opening any connections.

        Args:
            dsn: Postgres connection string. An empty string falls back to
                ``postgresql://localhost:5432/agentos``.
            **kwargs: Extra keyword arguments forwarded verbatim to
                ``asyncpg.create_pool``.
        """
        self._dsn = dsn or "postgresql://localhost:5432/agentos"
        self._kwargs = kwargs
        self._pool: Any = None
        # Set to True once a pool exists AND the schema DDL has run successfully.
        self._initialized = False

    async def _ensure_pool(self) -> None:
        """Lazily create the connection pool and make sure the schema exists.

        The pool is only published on ``self._pool`` after the schema DDL has
        succeeded. A failed first attempt (unreachable DB, missing privileges,
        ...) therefore does not leave behind a pool that would skip schema
        creation forever; the next call simply retries.
        """
        if self._pool is not None:
            return
        import asyncpg  # optional dependency, only needed when this backend is used

        pool = await asyncpg.create_pool(dsn=self._dsn, **self._kwargs)
        try:
            async with pool.acquire() as conn:
                await conn.execute(_SCHEMA)
        except BaseException:
            # terminate() is synchronous, so it is safe even while a
            # cancellation is propagating.
            pool.terminate()
            raise
        if self._pool is not None:
            # Another coroutine finished initialising while we were awaiting;
            # keep its pool and discard ours so no pool is leaked.
            await pool.close()
            return
        self._pool = pool
        self._initialized = True

    async def put(self, checkpoint: Checkpoint) -> str:
        """Insert *checkpoint*, or overwrite the row with the same checkpoint_id.

        On conflict the step, tags, summary, blobs and next_node are replaced;
        thread_id, parent_id and created_at keep their originally stored values.

        Returns:
            The checkpoint_id that was written.

        Raises:
            ValueError: if ``metadata.created_at`` is a non-empty string that is
                not a valid ISO-8601 timestamp.
        """
        meta = checkpoint.metadata
        # Convert before touching the DB so a bad timestamp fails fast.
        created_at = self._coerce_timestamp(meta.created_at)
        await self._ensure_pool()
        async with self._pool.acquire() as conn:
            await conn.execute(
                """INSERT INTO checkpoints
                       (thread_id, checkpoint_id, parent_id, step, created_at, tags, summary,
                        messages_blob, state_blob, tools_blob, next_node)
                   VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
                   ON CONFLICT (checkpoint_id) DO UPDATE SET
                       step = EXCLUDED.step,
                       tags = EXCLUDED.tags,
                       summary = EXCLUDED.summary,
                       messages_blob = EXCLUDED.messages_blob,
                       state_blob = EXCLUDED.state_blob,
                       tools_blob = EXCLUDED.tools_blob,
                       next_node = EXCLUDED.next_node""",
                meta.thread_id,
                meta.checkpoint_id,
                meta.parent_checkpoint_id,
                meta.step,
                created_at,
                json.dumps(meta.tags, ensure_ascii=False),
                meta.summary,
                json.dumps(checkpoint.messages, ensure_ascii=False),
                json.dumps(checkpoint.state, ensure_ascii=False),
                json.dumps(checkpoint.tools_result, ensure_ascii=False),
                checkpoint.next_node,
            )
        return meta.checkpoint_id

    async def get(self, checkpoint_id: str) -> Checkpoint | None:
        """Return the checkpoint with *checkpoint_id*, or None if there is none."""
        await self._ensure_pool()
        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT * FROM checkpoints WHERE checkpoint_id = $1", checkpoint_id
            )
        return self._row_to_checkpoint(row) if row else None

    async def get_latest(self, thread_id: str) -> Checkpoint | None:
        """Return the highest-step checkpoint of *thread_id*, or None if the thread is empty.

        Rows with equal steps are broken by insertion order (``id``), so the
        most recently inserted one wins deterministically.
        """
        await self._ensure_pool()
        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT * FROM checkpoints WHERE thread_id = $1 "
                "ORDER BY step DESC, id DESC LIMIT 1",
                thread_id,
            )
        return self._row_to_checkpoint(row) if row else None

    async def list_threads(
        self, limit: int = 50, offset: int = 0
    ) -> list[CheckpointMetadata]:
        """Return the metadata of each thread's highest-step checkpoint.

        Threads are ordered by thread_id; *limit*/*offset* page over threads.
        """
        await self._ensure_pool()
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                """SELECT DISTINCT ON (thread_id) *
                   FROM checkpoints
                   ORDER BY thread_id, step DESC, id DESC
                   LIMIT $1 OFFSET $2""",
                limit, offset,
            )
        return [self._row_to_metadata(r) for r in rows]

    async def list_checkpoints(
        self, thread_id: str, limit: int = 100, offset: int = 0
    ) -> list[CheckpointMetadata]:
        """Return checkpoint metadata for *thread_id*, newest step first."""
        await self._ensure_pool()
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                "SELECT * FROM checkpoints WHERE thread_id = $1 "
                "ORDER BY step DESC, id DESC LIMIT $2 OFFSET $3",
                thread_id, limit, offset,
            )
        return [self._row_to_metadata(r) for r in rows]

    async def delete_thread(self, thread_id: str) -> int:
        """Delete every checkpoint of *thread_id*; return the number of rows deleted."""
        await self._ensure_pool()
        async with self._pool.acquire() as conn:
            status = await conn.execute(
                "DELETE FROM checkpoints WHERE thread_id = $1", thread_id
            )
        return self._affected_rows(status)

    async def delete_before(self, thread_id: str, before_step: int) -> int:
        """Delete checkpoints of *thread_id* with step < *before_step*; return rows deleted."""
        await self._ensure_pool()
        async with self._pool.acquire() as conn:
            status = await conn.execute(
                "DELETE FROM checkpoints WHERE thread_id = $1 AND step < $2",
                thread_id, before_step,
            )
        return self._affected_rows(status)

    async def close(self) -> None:
        """Close the connection pool, if any. A later call re-initialises lazily."""
        # Detach first so no other caller picks up a pool that is shutting down.
        pool, self._pool = self._pool, None
        self._initialized = False
        if pool is not None:
            await pool.close()

    @staticmethod
    def _affected_rows(status: str | None) -> int:
        """Parse the row count from a command status tag such as ``"DELETE 3"``."""
        return int(status.split()[-1]) if status else 0

    @staticmethod
    def _coerce_timestamp(value: Any) -> Any:
        """Convert a ``created_at`` value into something asyncpg binds to TIMESTAMPTZ.

        asyncpg only accepts ``datetime``/``date`` objects for TIMESTAMPTZ
        parameters, while ``_row_to_metadata`` returns ``created_at`` as a
        string. Without this conversion, re-putting a checkpoint loaded from this
        backend would fail.

        - A string is parsed as ISO-8601; a trailing ``Z`` is accepted.
        - ``None`` or a blank string falls back to the current UTC time,
          mirroring the column's ``DEFAULT NOW()``.
        - Any other value is passed through unchanged.
        """
        from datetime import datetime, timezone

        if value is None:
            return datetime.now(timezone.utc)
        if isinstance(value, str):
            text = value.strip()
            if not text:
                return datetime.now(timezone.utc)
            if text.endswith(("Z", "z")):
                # fromisoformat() only understands "Z" from Python 3.11 on.
                text = text[:-1] + "+00:00"
            return datetime.fromisoformat(text)
        return value

    @staticmethod
    def _load_json(value: Any, expected: type) -> Any:
        """Decode a JSONB column value unless the driver already returned *expected*.

        Without a custom codec asyncpg returns JSONB as ``str``; if a codec was
        installed (e.g. through pool ``init`` kwargs) the value is already decoded.
        """
        return value if isinstance(value, expected) else json.loads(value)

    @staticmethod
    def _row_to_metadata(row: Any) -> CheckpointMetadata:
        """Build CheckpointMetadata from a ``checkpoints`` row; created_at becomes ``str``."""
        return CheckpointMetadata(
            thread_id=row["thread_id"],
            checkpoint_id=row["checkpoint_id"],
            parent_checkpoint_id=row["parent_id"],
            step=row["step"],
            created_at=str(row["created_at"]),
            tags=PostgresCheckpointer._load_json(row["tags"], list),
            summary=row["summary"],
        )

    @staticmethod
    def _row_to_checkpoint(row: Any) -> Checkpoint:
        """Build a full Checkpoint (metadata plus decoded blobs) from a ``checkpoints`` row."""
        load = PostgresCheckpointer._load_json
        return Checkpoint(
            metadata=PostgresCheckpointer._row_to_metadata(row),
            messages=load(row["messages_blob"], list),
            state=load(row["state_blob"], dict),
            tools_result=load(row["tools_blob"], dict),
            next_node=row["next_node"],
        )