Coverage for agentos/checkpoint/sqlite.py: 39%
59 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:36 +0800
"""
SQLite Checkpointer — zero-dependency local persistence.

Intended for: single-machine deployments, development/debugging, and POCs.
For multi-machine production deployments, use PostgresCheckpointer instead.
"""
8from __future__ import annotations
10import asyncio
11import json
12import os
13import sqlite3
14from datetime import datetime, timezone
15from typing import Any
17from agentos.checkpoint.base import (
18 Checkpoint,
19 CheckpointBackend,
20 CheckpointMetadata,
21)
# Public API exported by ``from ... import *``.
__all__ = [
    "SQLiteCheckpointer",
]
# DDL executed by SQLiteCheckpointer._init_db(); every statement uses
# IF NOT EXISTS, so running it against an existing database is a no-op.
#
# The *_blob columns hold JSON text produced by json.dumps() in put().
#
# ``checkpoint_id`` is declared UNIQUE, so SQLite already maintains an
# automatic unique index for it. An additional explicit index on that column
# would only duplicate it, costing disk space and extra work on every write.
# Databases created before this change may still carry the redundant
# ``idx_checkpoint_id`` index; it is harmless and may be dropped manually.
_SCHEMA = """
CREATE TABLE IF NOT EXISTS checkpoints (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    thread_id TEXT NOT NULL,
    checkpoint_id TEXT NOT NULL UNIQUE,
    parent_id TEXT,
    step INTEGER NOT NULL,
    created_at TEXT NOT NULL,
    tags TEXT NOT NULL DEFAULT '[]',
    summary TEXT NOT NULL DEFAULT '',
    messages_blob TEXT NOT NULL DEFAULT '[]',
    state_blob TEXT NOT NULL DEFAULT '{}',
    tools_blob TEXT NOT NULL DEFAULT '{}',
    next_node TEXT NOT NULL DEFAULT ''
);

CREATE INDEX IF NOT EXISTS idx_thread_step ON checkpoints(thread_id, step DESC);
CREATE INDEX IF NOT EXISTS idx_parent ON checkpoints(parent_id);
"""
class SQLiteCheckpointer(CheckpointBackend):
    """SQLite-backed checkpointer.

    Usage::

        cp = SQLiteCheckpointer(db_path="data/checkpoints.db")
        await cp.put(checkpoint)
        latest = await cp.get_latest("thread_abc")

    Every public operation opens a short-lived connection, runs its
    statements in a single transaction (committed on success, rolled back if
    an exception escapes) and always closes the connection afterwards. The
    blocking ``sqlite3`` calls run in the event loop's default executor, so
    awaiting these coroutines does not stall the loop.

    Because every operation uses a fresh connection, ``db_path=":memory:"``
    is not usable: each ``sqlite3.connect(":memory:")`` opens a separate,
    empty database, so the schema created in ``__init__`` is not visible to
    later calls.
    """

    def __init__(self, db_path: str = "checkpoints.db"):
        self._db_path = db_path
        # Ensure the parent directory exists; a bare file name has dirname ""
        # and falls back to the current directory.
        os.makedirs(os.path.dirname(db_path) or ".", exist_ok=True)
        self._init_db()

    def _init_db(self) -> None:
        """Create the ``checkpoints`` table and its indexes if missing."""
        conn = sqlite3.connect(self._db_path)
        try:
            conn.executescript(_SCHEMA)
            conn.commit()
        finally:
            # ``with sqlite3.connect(...)`` only manages the transaction and
            # never closes the connection, so close it explicitly.
            conn.close()

    def _get_conn(self) -> sqlite3.Connection:
        """Open a new connection whose rows support access by column name.

        The caller owns the returned connection and is responsible for
        closing it.
        """
        conn = sqlite3.connect(self._db_path)
        conn.row_factory = sqlite3.Row
        return conn

    def _execute_in_transaction(self, work: Any) -> Any:
        """Call ``work(conn)`` inside one transaction on a fresh connection.

        Args:
            work: callable taking a ``sqlite3.Connection``; its return value
                is passed through unchanged.

        The connection's context manager commits when ``work`` returns and
        rolls back when it raises; the connection is closed in every case.
        """
        conn = self._get_conn()
        try:
            with conn:
                return work(conn)
        finally:
            conn.close()

    async def _run_blocking(self, work: Any) -> Any:
        """Run :meth:`_execute_in_transaction` off the event-loop thread.

        The connection is created, used and closed entirely inside the worker
        thread, which keeps ``sqlite3``'s same-thread check satisfied.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._execute_in_transaction, work)

    def _row_to_metadata(self, row: sqlite3.Row) -> CheckpointMetadata:
        """Build the metadata part of a checkpoint from a ``checkpoints`` row."""
        return CheckpointMetadata(
            thread_id=row["thread_id"],
            checkpoint_id=row["checkpoint_id"],
            parent_checkpoint_id=row["parent_id"],
            step=row["step"],
            created_at=row["created_at"],
            tags=json.loads(row["tags"]),
            summary=row["summary"],
        )

    def _row_to_checkpoint(self, row: sqlite3.Row) -> Checkpoint:
        """Build a full checkpoint, decoding the JSON blob columns."""
        return Checkpoint(
            metadata=self._row_to_metadata(row),
            messages=json.loads(row["messages_blob"]),
            state=json.loads(row["state_blob"]),
            tools_result=json.loads(row["tools_blob"]),
            next_node=row["next_node"],
        )

    async def put(self, checkpoint: Checkpoint) -> str:
        """Store ``checkpoint``, replacing any row with the same checkpoint_id.

        ``INSERT OR REPLACE`` deletes the conflicting row and inserts a new
        one, so a rewritten checkpoint receives a fresh (larger) ``id``.

        Returns:
            The stored checkpoint's ``checkpoint_id``.
        """
        meta = checkpoint.metadata
        # Serialize on the calling side so JSON errors surface before any
        # database work is scheduled.
        params = (
            meta.thread_id,
            meta.checkpoint_id,
            meta.parent_checkpoint_id,
            meta.step,
            meta.created_at,
            json.dumps(meta.tags, ensure_ascii=False),
            meta.summary,
            json.dumps(checkpoint.messages, ensure_ascii=False),
            json.dumps(checkpoint.state, ensure_ascii=False),
            json.dumps(checkpoint.tools_result, ensure_ascii=False),
            checkpoint.next_node,
        )

        def write(conn: sqlite3.Connection) -> None:
            conn.execute(
                """INSERT OR REPLACE INTO checkpoints
                   (thread_id, checkpoint_id, parent_id, step, created_at, tags, summary,
                    messages_blob, state_blob, tools_blob, next_node)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                params,
            )

        await self._run_blocking(write)
        return meta.checkpoint_id

    async def get(self, checkpoint_id: str) -> Checkpoint | None:
        """Return the checkpoint with ``checkpoint_id``, or ``None`` if absent."""

        def read(conn: sqlite3.Connection) -> Checkpoint | None:
            row = conn.execute(
                "SELECT * FROM checkpoints WHERE checkpoint_id = ?", (checkpoint_id,)
            ).fetchone()
            return self._row_to_checkpoint(row) if row is not None else None

        return await self._run_blocking(read)

    async def get_latest(self, thread_id: str) -> Checkpoint | None:
        """Return the highest-step checkpoint of ``thread_id``, or ``None``.

        Ties on ``step`` resolve to the most recently inserted row (largest
        ``id``), so the result is deterministic and agrees with
        :meth:`list_threads`.
        """

        def read(conn: sqlite3.Connection) -> Checkpoint | None:
            row = conn.execute(
                "SELECT * FROM checkpoints WHERE thread_id = ? "
                "ORDER BY step DESC, id DESC LIMIT 1",
                (thread_id,),
            ).fetchone()
            return self._row_to_checkpoint(row) if row is not None else None

        return await self._run_blocking(read)

    async def list_threads(
        self, limit: int = 50, offset: int = 0
    ) -> list[CheckpointMetadata]:
        """List the latest checkpoint metadata of each thread.

        The latest checkpoint is chosen by the same rule as
        :meth:`get_latest`. Results are ordered newest ``created_at`` first,
        with ``thread_id`` as a tie-break so LIMIT/OFFSET pages are stable.
        """

        def read(conn: sqlite3.Connection) -> list[CheckpointMetadata]:
            rows = conn.execute(
                """SELECT c.* FROM checkpoints AS c
                   WHERE c.id = (
                       SELECT latest.id FROM checkpoints AS latest
                       WHERE latest.thread_id = c.thread_id
                       ORDER BY latest.step DESC, latest.id DESC
                       LIMIT 1
                   )
                   ORDER BY c.created_at DESC, c.thread_id
                   LIMIT ? OFFSET ?""",
                (limit, offset),
            ).fetchall()
            return [self._row_to_metadata(r) for r in rows]

        return await self._run_blocking(read)

    async def list_checkpoints(
        self, thread_id: str, limit: int = 100, offset: int = 0
    ) -> list[CheckpointMetadata]:
        """List checkpoint metadata of ``thread_id``, highest step first.

        ``id DESC`` breaks ties on ``step`` so pagination is stable.
        """

        def read(conn: sqlite3.Connection) -> list[CheckpointMetadata]:
            rows = conn.execute(
                "SELECT * FROM checkpoints WHERE thread_id = ? "
                "ORDER BY step DESC, id DESC LIMIT ? OFFSET ?",
                (thread_id, limit, offset),
            ).fetchall()
            return [self._row_to_metadata(r) for r in rows]

        return await self._run_blocking(read)

    async def delete_thread(self, thread_id: str) -> int:
        """Delete every checkpoint of ``thread_id``; return the rows removed."""

        def delete(conn: sqlite3.Connection) -> int:
            cur = conn.execute("DELETE FROM checkpoints WHERE thread_id = ?", (thread_id,))
            return cur.rowcount

        return await self._run_blocking(delete)

    async def delete_before(self, thread_id: str, before_step: int) -> int:
        """Delete checkpoints of ``thread_id`` with ``step < before_step``.

        Returns:
            The number of rows removed.
        """

        def delete(conn: sqlite3.Connection) -> int:
            cur = conn.execute(
                "DELETE FROM checkpoints WHERE thread_id = ? AND step < ?",
                (thread_id, before_step),
            )
            return cur.rowcount

        return await self._run_blocking(delete)