Coverage for agentos/checkpoint/postgres.py: 31%
68 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2Postgres Checkpointer — 生产级持久化后端。
4需安装: pip install asyncpg
6参考 LangGraph PostgresSaver 的 schema 设计。
7"""
9from __future__ import annotations
11import json
12from typing import Any
14from agentos.checkpoint.base import (
15 Checkpoint,
16 CheckpointBackend,
17 CheckpointMetadata,
18)
# Public API of this module: only the checkpointer class is exported.
__all__ = ["PostgresCheckpointer"]
# DDL run when the checkpointer opens its connection pool. Every statement uses
# IF NOT EXISTS, so running it against an already-initialised database is a no-op.
# There is deliberately no separate index on checkpoint_id: its UNIQUE constraint
# already makes PostgreSQL create a unique btree index, and a second identical
# index would only add write overhead on every insert/update.
_SCHEMA = """
CREATE TABLE IF NOT EXISTS checkpoints (
    id BIGSERIAL PRIMARY KEY,
    thread_id TEXT NOT NULL,
    checkpoint_id TEXT NOT NULL UNIQUE,
    parent_id TEXT,
    step INTEGER NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    tags JSONB NOT NULL DEFAULT '[]',
    summary TEXT NOT NULL DEFAULT '',
    messages_blob JSONB NOT NULL DEFAULT '[]',
    state_blob JSONB NOT NULL DEFAULT '{}',
    tools_blob JSONB NOT NULL DEFAULT '{}',
    next_node TEXT NOT NULL DEFAULT ''
);

CREATE INDEX IF NOT EXISTS idx_thread_step ON checkpoints(thread_id, step DESC);
CREATE INDEX IF NOT EXISTS idx_parent ON checkpoints(parent_id);
CREATE INDEX IF NOT EXISTS idx_created_at ON checkpoints(created_at DESC);
"""
class PostgresCheckpointer(CheckpointBackend):
    """Postgres-backed checkpointer, recommended for production.

    Connections come from an ``asyncpg`` pool that is created lazily on first
    use; the ``checkpoints`` table and its indexes are created (``IF NOT
    EXISTS``) at that moment.

    Usage:
        cp = PostgresCheckpointer(dsn="postgresql://user:pass@localhost:5432/agentos")
        await cp.put(checkpoint)
        latest = await cp.get_latest("thread_abc")
    """

    def __init__(self, dsn: str = "", **kwargs: Any):
        """Store connection settings; no connection is opened here.

        Args:
            dsn: Postgres connection string. An empty string falls back to
                ``postgresql://localhost:5432/agentos``.
            **kwargs: Extra keyword arguments forwarded verbatim to
                ``asyncpg.create_pool``.
        """
        self._dsn = dsn or "postgresql://localhost:5432/agentos"
        self._kwargs = kwargs
        self._pool: Any = None
        # Serializes first-time pool creation. Created lazily inside a running
        # event loop rather than here, where there may be no loop yet.
        self._pool_lock: Any = None
        # True once the pool exists and the schema has been applied.
        self._initialized = False

    async def _ensure_pool(self) -> None:
        """Create the connection pool and apply the schema on first use.

        The lock keeps concurrent first callers from each creating (and
        leaking) a pool. ``self._pool`` is published only after the schema
        statement succeeds, so a failed initialisation is retried on the next
        call instead of leaving behind an open pool with no tables.
        """
        if self._pool is not None:
            return
        if self._pool_lock is None:
            import asyncio

            self._pool_lock = asyncio.Lock()
        async with self._pool_lock:
            if self._pool is not None:  # another coroutine finished first
                return
            import asyncpg

            pool = await asyncpg.create_pool(dsn=self._dsn, **self._kwargs)
            try:
                async with pool.acquire() as conn:
                    await conn.execute(_SCHEMA)
            except BaseException:
                # terminate() is synchronous, so it is safe to call even while
                # a CancelledError is propagating.
                pool.terminate()
                raise
            self._pool = pool
            self._initialized = True

    async def put(self, checkpoint: Checkpoint) -> str:
        """Insert *checkpoint*, or refresh it if its checkpoint_id already exists.

        On conflict only ``step`` and the payload columns (``messages_blob``,
        ``state_blob``, ``tools_blob``, ``next_node``) are overwritten;
        ``thread_id``, ``parent_id``, ``created_at``, ``tags`` and ``summary``
        keep the values from the first insert.

        Returns:
            The checkpoint's ``checkpoint_id``.
        """
        await self._ensure_pool()
        meta = checkpoint.metadata
        async with self._pool.acquire() as conn:
            await conn.execute(
                """INSERT INTO checkpoints
                   (thread_id, checkpoint_id, parent_id, step, created_at, tags, summary,
                    messages_blob, state_blob, tools_blob, next_node)
                   VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
                   ON CONFLICT (checkpoint_id) DO UPDATE SET
                   step=$4, messages_blob=$8, state_blob=$9, tools_blob=$10, next_node=$11""",
                meta.thread_id,
                meta.checkpoint_id,
                meta.parent_checkpoint_id,
                meta.step,
                # asyncpg binds TIMESTAMPTZ parameters only from datetime
                # objects, not strings.
                self._coerce_timestamp(meta.created_at),
                json.dumps(meta.tags, ensure_ascii=False),
                meta.summary,
                json.dumps(checkpoint.messages, ensure_ascii=False),
                json.dumps(checkpoint.state, ensure_ascii=False),
                json.dumps(checkpoint.tools_result, ensure_ascii=False),
                checkpoint.next_node,
            )
        return meta.checkpoint_id

    async def get(self, checkpoint_id: str) -> Checkpoint | None:
        """Return the checkpoint with *checkpoint_id*, or ``None`` if absent."""
        await self._ensure_pool()
        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT * FROM checkpoints WHERE checkpoint_id = $1", checkpoint_id
            )
        return self._row_to_checkpoint(row) if row else None

    async def get_latest(self, thread_id: str) -> Checkpoint | None:
        """Return the highest-``step`` checkpoint of *thread_id*, or ``None``."""
        await self._ensure_pool()
        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT * FROM checkpoints WHERE thread_id = $1 ORDER BY step DESC LIMIT 1",
                thread_id,
            )
        return self._row_to_checkpoint(row) if row else None

    async def list_threads(
        self, limit: int = 50, offset: int = 0
    ) -> list[CheckpointMetadata]:
        """List one entry per thread: the metadata of its highest-``step`` checkpoint.

        Threads are paged in ascending ``thread_id`` order (required by
        ``DISTINCT ON``), not by recency.
        """
        await self._ensure_pool()
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                """SELECT DISTINCT ON (thread_id) *
                   FROM checkpoints
                   ORDER BY thread_id, step DESC
                   LIMIT $1 OFFSET $2""",
                limit, offset,
            )
        return [self._row_to_metadata(r) for r in rows]

    async def list_checkpoints(
        self, thread_id: str, limit: int = 100, offset: int = 0
    ) -> list[CheckpointMetadata]:
        """List metadata of *thread_id*'s checkpoints, newest ``step`` first."""
        await self._ensure_pool()
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                "SELECT * FROM checkpoints WHERE thread_id = $1 ORDER BY step DESC LIMIT $2 OFFSET $3",
                thread_id, limit, offset,
            )
        return [self._row_to_metadata(r) for r in rows]

    async def delete_thread(self, thread_id: str) -> int:
        """Delete every checkpoint of *thread_id*; return the number of rows removed."""
        await self._ensure_pool()
        async with self._pool.acquire() as conn:
            status = await conn.execute(
                "DELETE FROM checkpoints WHERE thread_id = $1", thread_id
            )
        return self._affected_rows(status)

    async def delete_before(self, thread_id: str, before_step: int) -> int:
        """Delete *thread_id*'s checkpoints with ``step < before_step``; return the count."""
        await self._ensure_pool()
        async with self._pool.acquire() as conn:
            status = await conn.execute(
                "DELETE FROM checkpoints WHERE thread_id = $1 AND step < $2",
                thread_id, before_step,
            )
        return self._affected_rows(status)

    async def close(self) -> None:
        """Close the pool if one is open; a later call reopens it lazily."""
        pool = self._pool
        if pool is None:
            return
        # Drop the reference first so no caller keeps using a closing pool,
        # even if close() itself raises.
        self._pool = None
        self._initialized = False
        await pool.close()

    @staticmethod
    def _coerce_timestamp(value: Any) -> Any:
        """Turn a ``created_at`` value into something asyncpg can bind to TIMESTAMPTZ.

        ``_row_to_metadata`` builds ``created_at`` as a string, and presumably
        callers do too (verify against ``CheckpointMetadata``). Strings are
        parsed as ISO-8601, including the ``str(datetime)`` form and a
        trailing ``Z``. A missing value (``None`` or ``""``) becomes the current
        UTC time, mirroring the column's ``DEFAULT NOW()``. Anything else,
        including unparseable strings, is returned unchanged so the driver
        reports it as before.
        """
        from datetime import datetime, timezone

        if value is None or value == "":
            return datetime.now(timezone.utc)
        if isinstance(value, str):
            text = value.strip()
            if text.endswith(("Z", "z")):
                # fromisoformat only accepts "Z" from Python 3.11 onwards.
                text = text[:-1] + "+00:00"
            try:
                return datetime.fromisoformat(text)
            except ValueError:
                return value
        return value

    @staticmethod
    def _affected_rows(status: str) -> int:
        """Extract the row count from a command status tag such as ``"DELETE 3"``."""
        return int(status.split()[-1]) if status else 0

    @staticmethod
    def _decode_json(value: Any, decoded_type: type) -> Any:
        """Return *value* if it is already a *decoded_type*, else parse it as JSON text.

        asyncpg returns JSONB as text by default; a registered type codec may
        hand back already-decoded Python objects instead.
        """
        return value if isinstance(value, decoded_type) else json.loads(value)

    @staticmethod
    def _row_to_metadata(row: Any) -> CheckpointMetadata:
        """Build a ``CheckpointMetadata`` from a ``checkpoints`` row.

        ``created_at`` is rendered with ``str()`` of the driver's value.
        """
        return CheckpointMetadata(
            thread_id=row["thread_id"],
            checkpoint_id=row["checkpoint_id"],
            parent_checkpoint_id=row["parent_id"],
            step=row["step"],
            created_at=str(row["created_at"]),
            tags=PostgresCheckpointer._decode_json(row["tags"], list),
            summary=row["summary"],
        )

    @staticmethod
    def _row_to_checkpoint(row: Any) -> Checkpoint:
        """Build a full ``Checkpoint`` (metadata plus decoded payloads) from a row."""
        decode = PostgresCheckpointer._decode_json
        return Checkpoint(
            metadata=PostgresCheckpointer._row_to_metadata(row),
            messages=decode(row["messages_blob"], list),
            state=decode(row["state_blob"], dict),
            tools_result=decode(row["tools_blob"], dict),
            next_node=row["next_node"],
        )