Coverage for agentos/memory/persistence.py: 19%
190 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:01 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:01 +0800
"""
AgentOS v1.14.9 — Memory Persistence Manager.

Unified save/load for the memory subsystems (pyramid, working, conversation,
long-term, reflection, consolidation, and the retriever index), bridging the
gap between the existing in-memory pyramid and crash-safe disk persistence.

All writes are atomic (write to a temp file, then rename). Snapshots are JSON,
gzip-compressed by default for production efficiency; plain JSON is available
for debugging.

Usage:
    mgr = MemoryPersistenceManager(base_dir="~/.agentos/memory")

    # Save everything
    await mgr.save_all(
        pyramid=pyramid,
        working=working,
        conversation=conv,
        long_term=lterm,
        reflection_engine=reflection,
        consolidation_pipeline=pipeline,
        retriever_index=retriever_index,
    )

    # Load the saved snapshot (use restore_all() to push it into live objects)
    state = await mgr.load_all()
"""
from __future__ import annotations

import gzip
import json
import os
import shutil
import tempfile
import time
import zlib
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional
40# ── Snapshot Data Models ──────────────────────
@dataclass
class MemorySnapshot:
    """Point-in-time capture of every persisted memory subsystem.

    Each ``*_state`` field holds whatever state dict the corresponding
    subsystem reported when the snapshot was taken; an empty dict means
    nothing was saved for that subsystem.
    """

    version: str = "1.14.9"
    created_at: float = field(default_factory=time.time)
    # One state dict per subsystem; empty (falsy) ones are dropped by to_dict().
    pyramid_state: dict[str, Any] = field(default_factory=dict)
    working_state: dict[str, Any] = field(default_factory=dict)
    conversation_state: dict[str, Any] = field(default_factory=dict)
    long_term_state: dict[str, Any] = field(default_factory=dict)
    reflection_state: dict[str, Any] = field(default_factory=dict)
    consolidation_state: dict[str, Any] = field(default_factory=dict)
    retriever_index_state: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-ready dict.

        ``version`` and ``created_at`` are always present; each subsystem
        state is included only when it is non-empty (truthy).
        """
        subsystem_fields = (
            "pyramid_state",
            "working_state",
            "conversation_state",
            "long_term_state",
            "reflection_state",
            "consolidation_state",
            "retriever_index_state",
        )
        out: dict[str, Any] = {"version": self.version, "created_at": self.created_at}
        out.update(
            (name, getattr(self, name))
            for name in subsystem_fields
            if getattr(self, name)
        )
        return out

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "MemorySnapshot":
        """Rebuild a snapshot from a dict produced by ``to_dict()``.

        Missing subsystem keys become empty dicts; a missing ``version`` falls
        back to "1.14.9" and a missing ``created_at`` to the current time.
        """
        subsystem_fields = (
            "pyramid_state",
            "working_state",
            "conversation_state",
            "long_term_state",
            "reflection_state",
            "consolidation_state",
            "retriever_index_state",
        )
        states = {name: d.get(name, {}) for name in subsystem_fields}
        return cls(
            version=d.get("version", "1.14.9"),
            created_at=d.get("created_at", time.time()),
            **states,
        )
88# ── Persistence Manager ──────────────────────
class MemoryPersistenceManager:
    """Centralized save/load manager for all memory subsystems.

    Snapshots are written as JSON (gzip-compressed when ``compress`` is true)
    under ``base_dir``. Each write goes to a temp file that is flushed and
    fsynced, then renamed over the live snapshot. The previous snapshot is
    first *copied* into a rotating set of numbered backups
    (``snapshot.1.*`` is the newest), so a live snapshot file exists at every
    moment of the write.
    """

    def __init__(
        self,
        base_dir: str = "",
        compress: bool = True,
        max_backups: int = 3,
    ):
        """Create the manager and ensure ``base_dir`` exists.

        Args:
            base_dir: Directory for snapshot files. A leading ``~`` is
                expanded; an empty string selects ``~/.agentos/memory``.
            compress: When true, write gzip-compressed ``snapshot.json.gz``;
                otherwise plain ``snapshot.json``.
            max_backups: Number of rotated backups to keep. ``0`` (or a
                negative value) disables backups.
        """
        # expanduser(): without it "~/..." (as in the module usage example)
        # would create a literal "~" directory relative to the CWD.
        base = Path(base_dir).expanduser() if base_dir else Path.home() / ".agentos" / "memory"
        base.mkdir(parents=True, exist_ok=True)
        self.base_dir: Path = base
        self.compress = compress
        self._snapshot_path: Path = base / ("snapshot.json.gz" if compress else "snapshot.json")
        self._max_backups: int = max(0, max_backups)

    # ── Helpers ─────────────────────────────

    @staticmethod
    def _subsystem_slots(
        pyramid: Any,
        working: Any,
        conversation: Any,
        long_term: Any,
        reflection_engine: Any,
        consolidation_pipeline: Any,
    ) -> list[tuple[Any, str]]:
        """Pair each live subsystem object with its MemorySnapshot field name."""
        return [
            (pyramid, "pyramid_state"),
            (working, "working_state"),
            (conversation, "conversation_state"),
            (long_term, "long_term_state"),
            (reflection_engine, "reflection_state"),
            (consolidation_pipeline, "consolidation_state"),
        ]

    @staticmethod
    def _extract_state(obj: Any) -> Any:
        """Return obj's state via get_state(), else dump_state(), else None.

        The interface is probed with getattr so that an AttributeError raised
        *inside* the method is not mistaken for "method not supported".
        """
        for method_name in ("get_state", "dump_state"):
            method = getattr(obj, method_name, None)
            if callable(method):
                return method()
        return None

    # ── Save ────────────────────────────────

    async def save_all(
        self,
        pyramid: Any = None,
        working: Any = None,
        conversation: Any = None,
        long_term: Any = None,
        reflection_engine: Any = None,
        consolidation_pipeline: Any = None,
        retriever_index: dict[str, Any] | None = None,
    ) -> str:
        """Save all memory subsystems to a single snapshot file.

        Each subsystem is asked for its state via ``get_state()``, falling
        back to ``dump_state()``. Objects exposing neither are skipped;
        exceptions raised inside those methods propagate.

        NOTE: the disk write is synchronous and runs on the caller's thread,
        so it blocks the event loop for the duration of the write.

        Returns:
            The snapshot file path.
        """
        return self.save_sync(
            pyramid=pyramid,
            working=working,
            conversation=conversation,
            long_term=long_term,
            reflection_engine=reflection_engine,
            consolidation_pipeline=consolidation_pipeline,
            retriever_index=retriever_index,
        )

    def save_sync(
        self,
        pyramid: Any = None,
        working: Any = None,
        conversation: Any = None,
        long_term: Any = None,
        reflection_engine: Any = None,
        consolidation_pipeline: Any = None,
        retriever_index: dict[str, Any] | None = None,
    ) -> str:
        """Synchronous save, for use in signal handlers / atexit hooks.

        Same semantics as ``save_all``. Returns the snapshot file path.
        """
        snapshot = MemorySnapshot()
        slots = self._subsystem_slots(
            pyramid, working, conversation, long_term,
            reflection_engine, consolidation_pipeline,
        )
        for obj, attr in slots:
            if obj is None:
                continue
            state = self._extract_state(obj)
            if state is not None:
                setattr(snapshot, attr, state)

        if retriever_index is not None:
            snapshot.retriever_index_state = retriever_index

        return self._atomic_write(snapshot)

    # ── Load ────────────────────────────────

    async def load_all(self) -> MemorySnapshot:
        """Load the latest memory snapshot from disk.

        Returns a MemorySnapshot. Empty fields mean no saved state for that
        subsystem. A missing or unreadable file yields an empty snapshot.
        """
        return self.load_sync()

    def load_sync(self) -> MemorySnapshot:
        """Synchronous load. Same semantics as ``load_all``."""
        # _read_snapshot_file() already maps a missing file to None.
        data = self._read_snapshot_file()
        if data is None:
            return MemorySnapshot()
        return MemorySnapshot.from_dict(data)

    async def restore_all(
        self,
        pyramid: Any = None,
        working: Any = None,
        conversation: Any = None,
        long_term: Any = None,
        reflection_engine: Any = None,
        consolidation_pipeline: Any = None,
        retriever_index_target: dict[str, Any] | None = None,
    ) -> int:
        """Load the snapshot from disk and restore it into live objects.

        Each target object should expose ``restore_state(state_dict)``.
        Targets without that method, or with no saved state, are skipped.
        ``retriever_index_target`` is cleared and refilled in place.

        Returns:
            The number of subsystems restored.
        """
        snapshot = await self.load_all()
        restored = 0

        slots = self._subsystem_slots(
            pyramid, working, conversation, long_term,
            reflection_engine, consolidation_pipeline,
        )
        for obj, state_attr in slots:
            state = getattr(snapshot, state_attr, {})
            if obj is None or not state:
                continue
            restore = getattr(obj, "restore_state", None)
            if not callable(restore):
                continue
            restore(state)
            restored += 1

        if retriever_index_target is not None and snapshot.retriever_index_state:
            retriever_index_target.clear()
            retriever_index_target.update(snapshot.retriever_index_state)
            restored += 1

        return restored

    # ── Atomic write ────────────────────────

    def _atomic_write(self, snapshot: MemorySnapshot) -> str:
        """Write ``snapshot`` atomically: temp file → fsync → rename.

        NOTE: ``default=str`` stringifies any non-JSON value, so such values
        come back as strings on load.

        Returns:
            The snapshot file path.
        """
        payload = json.dumps(
            snapshot.to_dict(), ensure_ascii=False, indent=2, default=str
        ).encode("utf-8")
        if self.compress:
            payload = gzip.compress(payload, compresslevel=6)

        fd, tmp_path = tempfile.mkstemp(
            dir=str(self.base_dir),
            prefix=".snapshot-tmp-",
            suffix=".json.gz" if self.compress else ".json",
        )
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(payload)
                f.flush()
                # Make the data durable before the rename publishes it.
                os.fsync(f.fileno())

            # Backups are copies, so the live snapshot is never absent.
            self._rotate_backups()
            os.replace(tmp_path, str(self._snapshot_path))
        except BaseException:
            # Also clean up on KeyboardInterrupt/SystemExit (signal/atexit use).
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

        return str(self._snapshot_path)

    # ── Read snapshot ────────────────────────

    def _read_snapshot_file(self) -> dict[str, Any] | None:
        """Read and parse the snapshot file.

        Returns None when the file is missing, corrupt, truncated, or does
        not contain a JSON object.
        """
        try:
            raw = self._snapshot_path.read_bytes()
            if self.compress:
                raw = gzip.decompress(raw)
            data = json.loads(raw.decode("utf-8"))
        # OSError covers missing files and gzip.BadGzipFile. EOFError covers
        # truncated gzip streams and zlib.error covers corrupt deflate data.
        # ValueError covers JSONDecodeError and UnicodeDecodeError.
        except (OSError, EOFError, ValueError, zlib.error):
            return None
        return data if isinstance(data, dict) else None

    # ── Backup rotation ──────────────────────

    def _backup_path(self, index: int) -> Path:
        """Path of backup number ``index``, with an extension matching ``compress``."""
        suffix = "json.gz" if self.compress else "json"
        return self.base_dir / f"snapshot.{index}.{suffix}"

    def _rotate_backups(self) -> None:
        """Shift backups up by one and copy the live snapshot into slot 1.

        Keeps at most ``self._max_backups`` backups. Failures are ignored
        (best effort): a failed backup must not block the new snapshot.
        """
        if self._max_backups <= 0:
            return

        for i in range(self._max_backups - 1, 0, -1):
            older = self._backup_path(i)
            if older.exists():
                try:
                    os.replace(str(older), str(self._backup_path(i + 1)))
                except OSError:
                    pass

        # Copy (not move) so the live snapshot stays in place until the
        # caller's os.replace() swaps in the new one.
        if self._snapshot_path.exists():
            try:
                shutil.copy2(str(self._snapshot_path), str(self._backup_path(1)))
            except OSError:
                pass

    # ── Query ────────────────────────────────

    def snapshot_info(self) -> dict[str, Any]:
        """Return metadata about the current snapshot.

        Returns ``{"exists": False}`` when there is no snapshot, and
        ``{"exists": True, "error": "unreadable"}`` when it cannot be
        stat'ed or parsed.
        """
        if not self._snapshot_path.exists():
            return {"exists": False}

        try:
            stat = self._snapshot_path.stat()
        except OSError:
            return {"exists": True, "error": "unreadable"}

        # Parse directly: load_sync() would mask corruption as an empty
        # snapshot whose created_at is "now".
        data = self._read_snapshot_file()
        if data is None:
            return {"exists": True, "error": "unreadable"}
        snapshot = MemorySnapshot.from_dict(data)

        state_fields = (
            "pyramid_state",
            "working_state",
            "conversation_state",
            "long_term_state",
            "reflection_state",
            "consolidation_state",
            "retriever_index_state",
        )
        subsystems_saved = sum(1 for name in state_fields if getattr(snapshot, name))

        return {
            "exists": True,
            "path": str(self._snapshot_path),
            "size_bytes": stat.st_size,
            "created_at": snapshot.created_at,
            "version": snapshot.version,
            "subsystems_saved": subsystems_saved,
            "compressed": self.compress,
        }

    def delete_snapshot(self) -> bool:
        """Delete the current snapshot and all numbered backups.

        Both compressed and plain files are removed. Returns True if at least
        one file was deleted.
        """
        deleted = False
        for pattern in ("snapshot*.json.gz", "snapshot*.json"):
            for path in self.base_dir.glob(pattern):
                try:
                    path.unlink()
                    deleted = True
                except OSError:
                    pass
        return deleted

    def list_backups(self) -> list[dict[str, Any]]:
        """List snapshot files (current snapshot and backups), sorted by name.

        Each entry has ``name``, ``size_bytes`` and ``mtime``. Files that
        vanish or cannot be stat'ed are skipped.
        """
        results = []
        for path in sorted(self.base_dir.glob("snapshot*.json*")):
            try:
                stat = path.stat()
            except OSError:
                continue
            results.append({
                "name": path.name,
                "size_bytes": stat.st_size,
                "mtime": stat.st_mtime,
            })
        return results