Coverage for agentos/storage/base.py: 55%

29 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2AgentOS v0.20 持久化存储层。 

3Base + SQLite实现,支持Checkpoint持久化。 

4""" 

5 

6from __future__ import annotations 

7 

8import json 

9import sqlite3 

10import time 

11from abc import ABC, abstractmethod 

12from dataclasses import dataclass 

13from typing import Any 

14 

15 

16# ── 抽象基类 ──────────────────────────────────── 

17 

class CheckpointStore(ABC):
    """Abstract interface for checkpoint storage backends.

    Concrete backends must override all four coroutine methods below; until
    they do, the class cannot be instantiated.
    """

    @abstractmethod
    async def save(self, session_id: str, snapshot: dict):
        """Store *snapshot* under *session_id* (implemented by subclasses)."""

    @abstractmethod
    async def load(self, session_id: str) -> dict | None:
        """Return the snapshot for *session_id*, or ``None`` when there is none."""

    @abstractmethod
    async def delete(self, session_id: str):
        """Remove the snapshot stored under *session_id* (implemented by subclasses)."""

    @abstractmethod
    async def list_sessions(self, limit: int = 50) -> list[str]:
        """Return at most *limit* stored session ids (implemented by subclasses)."""

30 

31 

@dataclass
class SqliteStore(CheckpointStore):
    """SQLite-backed checkpoint store.

    Snapshots live in a ``checkpoints`` table with one row per ``session_id``.
    Each row holds the JSON-encoded snapshot plus ``created_at`` and
    ``updated_at`` timestamps taken from ``time.time()``.

    Attributes:
        path: Database location passed to :func:`sqlite3.connect`. The
            default ``":memory:"`` opens a private in-memory database.

    Notes:
        * The ``async`` methods call the synchronous :mod:`sqlite3` API
          directly, so each call blocks the running event loop while it runs.
        * Snapshots are encoded with ``json.dumps(..., default=str)``. Values
          JSON cannot represent are stored as their ``str()`` text, and
          non-string dict keys become strings. So :meth:`load` can return
          something that differs from what was passed to :meth:`save`.
        * The connection is opened with ``check_same_thread=False``; this
          class adds no locking of its own.
        * Writes run inside ``with self._conn:``. That block commits on
          success and rolls back on error, so a failed statement never
          leaves a transaction open on the shared connection.
        * Call :meth:`close` (or use the instance in a ``with`` statement)
          to release the connection.
    """

    path: str = ":memory:"

    def __post_init__(self) -> None:
        """Open the connection and create the table and index if missing."""
        self._conn = sqlite3.connect(self.path, check_same_thread=False)
        try:
            with self._conn:
                self._conn.execute(
                    """CREATE TABLE IF NOT EXISTS checkpoints (
                        session_id TEXT PRIMARY KEY,
                        snapshot TEXT NOT NULL,
                        created_at REAL NOT NULL,
                        updated_at REAL NOT NULL
                    )"""
                )
                # Index on the column list_sessions() orders by.
                self._conn.execute(
                    "CREATE INDEX IF NOT EXISTS idx_updated ON checkpoints(updated_at DESC)"
                )
        except sqlite3.Error:
            # Don't leak the freshly opened connection if schema setup fails.
            self._conn.close()
            raise

    async def save(self, session_id: str, snapshot: dict) -> None:
        """Insert or overwrite the snapshot stored for *session_id*.

        The first save sets both timestamps to the current time. Later saves
        update only ``snapshot`` and ``updated_at``, so ``created_at`` keeps
        the time of the first save.
        """
        now = time.time()
        payload = json.dumps(snapshot, default=str)
        with self._conn:
            self._conn.execute(
                """INSERT INTO checkpoints(session_id, snapshot, created_at, updated_at)
                VALUES(?, ?, ?, ?)
                ON CONFLICT(session_id) DO UPDATE SET
                    snapshot=excluded.snapshot, updated_at=excluded.updated_at""",
                (session_id, payload, now, now),
            )

    async def load(self, session_id: str) -> dict | None:
        """Return the decoded snapshot for *session_id*, or ``None`` if no row exists."""
        row = self._conn.execute(
            "SELECT snapshot FROM checkpoints WHERE session_id=?", (session_id,)
        ).fetchone()
        return None if row is None else json.loads(row[0])

    async def delete(self, session_id: str) -> None:
        """Delete the row for *session_id*; deleting a missing id is not an error."""
        with self._conn:
            self._conn.execute("DELETE FROM checkpoints WHERE session_id=?", (session_id,))

    async def list_sessions(self, limit: int = 50) -> list[str]:
        """Return up to *limit* session ids, most recently updated first.

        ``limit`` is bound directly into SQLite's ``LIMIT`` clause.
        """
        rows = self._conn.execute(
            "SELECT session_id FROM checkpoints ORDER BY updated_at DESC LIMIT ?", (limit,)
        ).fetchall()
        return [session_id for (session_id,) in rows]

    def close(self) -> None:
        """Close the underlying sqlite3 connection.

        Store operations attempted afterwards raise :class:`sqlite3.ProgrammingError`.
        """
        self._conn.close()

    def __enter__(self) -> SqliteStore:
        """Return the store itself so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the connection when leaving the ``with`` block."""
        self.close()