Coverage for agentos/storage/base.py: 55%
29 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2AgentOS v0.20 持久化存储层。
3Base + SQLite实现,支持Checkpoint持久化。
4"""
6from __future__ import annotations
8import json
9import sqlite3
10import time
11from abc import ABC, abstractmethod
12from dataclasses import dataclass
13from typing import Any
16# ── 抽象基类 ────────────────────────────────────
18class CheckpointStore(ABC):
20 """检查点存储基类。"""
22 @abstractmethod
23 async def save(self, session_id: str, snapshot: dict): ...
24 @abstractmethod
25 async def load(self, session_id: str) -> dict | None: ...
26 @abstractmethod
27 async def delete(self, session_id: str): ...
28 @abstractmethod
29 async def list_sessions(self, limit: int = 50) -> list[str]: ...
@dataclass
class SqliteStore(CheckpointStore):
    """SQLite-backed ``CheckpointStore``.

    Snapshots are serialized to JSON and kept in a ``checkpoints`` table,
    one row per ``session_id``, together with ``created_at`` /
    ``updated_at`` timestamps taken from ``time.time()``.

    Attributes:
        path: Database location handed to ``sqlite3.connect``. The default
            ``":memory:"`` is a private in-memory database whose contents
            are lost when the connection is closed.

    NOTE: the ``async`` methods call SQLite synchronously, so each query
    blocks the running event loop while it executes. The connection is
    opened with ``check_same_thread=False`` but no lock is taken here, so
    simultaneous use from several threads is not serialized by this class.
    Call ``close()`` when the store is no longer needed.
    """

    path: str = ":memory:"

    def __post_init__(self) -> None:
        """Open the connection and create the table and index if missing."""
        conn = sqlite3.connect(self.path, check_same_thread=False)
        try:
            # The connection context manager commits on success and rolls
            # back on error; it does not close the connection.
            with conn:
                # NOT NULL is spelled out because SQLite accepts NULL in a
                # non-INTEGER PRIMARY KEY column. Because of IF NOT EXISTS, a
                # database created earlier keeps the schema it already has.
                conn.execute(
                    """CREATE TABLE IF NOT EXISTS checkpoints (
                        session_id TEXT PRIMARY KEY NOT NULL,
                        snapshot TEXT NOT NULL,
                        created_at REAL NOT NULL,
                        updated_at REAL NOT NULL
                    )"""
                )
                # Matches the newest-first ordering used by list_sessions().
                conn.execute(
                    "CREATE INDEX IF NOT EXISTS idx_updated "
                    "ON checkpoints(updated_at DESC)"
                )
        except BaseException:
            # Don't leak the handle when setup fails (e.g. path is not a DB).
            conn.close()
            raise
        self._conn = conn

    async def save(self, session_id: str, snapshot: dict) -> None:
        """Insert or overwrite the checkpoint for *session_id*.

        On the first save both timestamps are set to the current time. Later
        saves replace ``snapshot`` and ``updated_at`` but keep the original
        ``created_at``. Values that ``json`` cannot encode are converted with
        ``str()`` (``default=str``), so they load back as strings rather
        than their original types.

        A failing write is rolled back, so no half-open transaction is left
        behind for a later call to commit.
        """
        now = time.time()
        payload = json.dumps(snapshot, default=str)
        with self._conn:
            self._conn.execute(
                """INSERT INTO checkpoints(session_id, snapshot, created_at, updated_at)
                VALUES(?, ?, ?, ?)
                ON CONFLICT(session_id) DO UPDATE SET
                    snapshot=excluded.snapshot, updated_at=excluded.updated_at""",
                (session_id, payload, now, now),
            )

    async def load(self, session_id: str) -> dict | None:
        """Return the JSON-decoded snapshot for *session_id*, or ``None`` if absent."""
        row = self._conn.execute(
            "SELECT snapshot FROM checkpoints WHERE session_id=?", (session_id,)
        ).fetchone()
        return json.loads(row[0]) if row else None

    async def delete(self, session_id: str) -> None:
        """Delete the checkpoint for *session_id*; an unknown id is a no-op."""
        with self._conn:
            self._conn.execute(
                "DELETE FROM checkpoints WHERE session_id=?", (session_id,)
            )

    async def list_sessions(self, limit: int = 50) -> list[str]:
        """Return up to *limit* session ids, most recently updated first.

        SQLite treats a negative LIMIT as "no upper bound", so passing a
        negative *limit* returns every stored session.
        """
        rows = self._conn.execute(
            "SELECT session_id FROM checkpoints ORDER BY updated_at DESC LIMIT ?",
            (limit,),
        ).fetchall()
        return [session_id for (session_id,) in rows]

    def close(self) -> None:
        """Close the underlying SQLite connection.

        The store must not be used afterwards; further queries will fail.
        """
        self._conn.close()