Coverage for agentos/memory/persistence.py: 19%

190 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 16:01 +0800

1""" 

2AgentOS v1.14.9 — Memory Persistence Manager. 

3 

4Unified save/load for all 12 memory subsystems, bridging the gap between 

5the existing in-memory pyramid and crash-safe disk persistence. 

6 

7All writes are atomic (write to temp file, then rename). JSON format with 

8gzip compression for production efficiency; plain JSON for debug. 

9 

10Usage: 

11 mgr = MemoryPersistenceManager(base_dir="~/.agentos/memory") 

12 

13 # Save everything 

14 await mgr.save_all( 

15 pyramid=pyramid, 

16 working=working, 

17 conversation=conv, 

18 long_term=lterm, 

19 reflection_engine=reflection, 

20 consolidation_pipeline=pipeline, 

21 retriever_index=retriever_index, 

22 ) 

23 

24 # Restore everything 

25 state = await mgr.load_all() 

26""" 

27 

28from __future__ import annotations 

29 

30import gzip 

31import json 

32import os 

33import tempfile 

34import time 

35from dataclasses import dataclass, field 

36from pathlib import Path 

37from typing import Any, Optional 

38 

39 

40# ── Snapshot Data Models ────────────────────── 

41 

42 

@dataclass
class MemorySnapshot:
    """Complete state of all memory subsystems at a point in time.

    Each ``*_state`` attribute holds one subsystem's state as a plain dict.
    An empty dict means "nothing saved for this subsystem"; such entries are
    omitted from the serialized form produced by :meth:`to_dict`.
    """

    version: str = "1.14.9"
    created_at: float = field(default_factory=time.time)
    # Per-subsystem state dicts (optional — only non-empty ones are saved)
    pyramid_state: dict[str, Any] = field(default_factory=dict)
    working_state: dict[str, Any] = field(default_factory=dict)
    conversation_state: dict[str, Any] = field(default_factory=dict)
    long_term_state: dict[str, Any] = field(default_factory=dict)
    reflection_state: dict[str, Any] = field(default_factory=dict)
    consolidation_state: dict[str, Any] = field(default_factory=dict)
    retriever_index_state: dict[str, Any] = field(default_factory=dict)

    # Names of the per-subsystem state attributes, in serialization order.
    # Deliberately left un-annotated so @dataclass does not treat it as a
    # field; it is shared by to_dict() and from_dict() so the two cannot
    # drift apart.
    _STATE_FIELDS = (
        "pyramid_state",
        "working_state",
        "conversation_state",
        "long_term_state",
        "reflection_state",
        "consolidation_state",
        "retriever_index_state",
    )

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict form of the snapshot.

        ``version`` and ``created_at`` are always present; a subsystem state
        is included only when it is non-empty (truthy).
        """
        result: dict[str, Any] = {
            "version": self.version,
            "created_at": self.created_at,
        }
        for field_name in self._STATE_FIELDS:
            val = getattr(self, field_name)
            if val:
                result[field_name] = val
        return result

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "MemorySnapshot":
        """Build a snapshot from a dict such as the one made by :meth:`to_dict`.

        Missing keys fall back to the field defaults. A state entry that is
        present but is not a dict (for example JSON ``null`` or a list) is
        replaced by an empty dict so the ``*_state`` attributes always honour
        their declared ``dict`` type. ``version`` / ``created_at`` given as
        ``None`` fall back to their defaults as well.
        """
        states: dict[str, Any] = {}
        for field_name in cls._STATE_FIELDS:
            value = d.get(field_name)
            # Keep the caller's dict object as-is; coerce anything else.
            states[field_name] = value if isinstance(value, dict) else {}

        version = d.get("version")
        created_at = d.get("created_at")
        return cls(
            version="1.14.9" if version is None else version,
            created_at=time.time() if created_at is None else created_at,
            **states,
        )

86 

87 

88# ── Persistence Manager ────────────────────── 

89 

90 

class MemoryPersistenceManager:
    """Centralized save/load manager for all memory subsystems.

    A snapshot is written as JSON (gzip-compressed when ``compress`` is true)
    to a single file under ``base_dir``. Every write goes to a temporary file
    in the same directory and is then renamed over the live snapshot. Before
    each write the previous snapshot is kept as ``snapshot.1.json[.gz]`` and
    older backups are shifted up, keeping ``_max_backups`` of them.
    """

    def __init__(
        self,
        base_dir: str = "",
        compress: bool = True,
    ):
        """Create the manager and make sure the storage directory exists.

        Args:
            base_dir: Directory for snapshot files. A leading ``~`` is
                expanded to the user's home directory. Empty means
                ``~/.agentos/memory``.
            compress: Write ``snapshot.json.gz`` (gzip) when true, plain
                ``snapshot.json`` when false.
        """
        if base_dir:
            # Path() does not expand "~" by itself; without expanduser() a
            # value like "~/.agentos/memory" creates a literal "./~" folder.
            base = Path(base_dir).expanduser()
        else:
            base = Path.home() / ".agentos" / "memory"
        base.mkdir(parents=True, exist_ok=True)
        self.base_dir: Path = base
        self.compress = compress
        self._snapshot_path: Path = base / ("snapshot.json.gz" if compress else "snapshot.json")
        self._max_backups: int = 3

    # ── Save ────────────────────────────────

    async def save_all(
        self,
        pyramid: Any = None,
        working: Any = None,
        conversation: Any = None,
        long_term: Any = None,
        reflection_engine: Any = None,
        consolidation_pipeline: Any = None,
        retriever_index: Optional[dict[str, Any]] = None,
    ) -> str:
        """Save all memory subsystems to a single snapshot file.

        Each non-None subsystem is asked for its state via ``get_state()``;
        a subsystem without that method is skipped. ``retriever_index`` is
        stored as given.

        NOTE: the file write itself is performed synchronously inside this
        coroutine.

        Returns the snapshot file path.
        """
        snapshot = self._build_snapshot(
            pyramid,
            working,
            conversation,
            long_term,
            reflection_engine,
            consolidation_pipeline,
            retriever_index,
        )
        return self._atomic_write(snapshot)

    def save_sync(
        self,
        pyramid: Any = None,
        working: Any = None,
        conversation: Any = None,
        long_term: Any = None,
        reflection_engine: Any = None,
        consolidation_pipeline: Any = None,
        retriever_index: Optional[dict[str, Any]] = None,
    ) -> str:
        """Synchronous save — for use in signal handlers / atexit hooks.

        Collects state exactly like :meth:`save_all` and returns the
        snapshot file path.
        """
        snapshot = self._build_snapshot(
            pyramid,
            working,
            conversation,
            long_term,
            reflection_engine,
            consolidation_pipeline,
            retriever_index,
        )
        return self._atomic_write(snapshot)

    def _build_snapshot(
        self,
        pyramid: Any,
        working: Any,
        conversation: Any,
        long_term: Any,
        reflection_engine: Any,
        consolidation_pipeline: Any,
        retriever_index: Optional[dict[str, Any]],
    ) -> "MemorySnapshot":
        """Collect the state of every supplied subsystem into a new snapshot."""
        snapshot = MemorySnapshot()

        for obj, attr in (
            (pyramid, "pyramid_state"),
            (working, "working_state"),
            (conversation, "conversation_state"),
            (long_term, "long_term_state"),
            (reflection_engine, "reflection_state"),
            (consolidation_pipeline, "consolidation_state"),
        ):
            if obj is None:
                continue
            try:
                setattr(snapshot, attr, obj.get_state())
            except AttributeError:
                # Best effort: a subsystem that does not expose get_state()
                # is left out of the snapshot instead of failing the save.
                continue

        if retriever_index is not None:
            snapshot.retriever_index_state = retriever_index

        return snapshot

    # ── Load ────────────────────────────────

    async def load_all(self) -> "MemorySnapshot":
        """Load the latest memory snapshot from disk.

        Returns a MemorySnapshot; empty fields mean no saved state for that
        subsystem. A missing or unreadable file yields a default snapshot.
        """
        return self.load_sync()

    def load_sync(self) -> "MemorySnapshot":
        """Synchronous load; same result as :meth:`load_all`."""
        if not self._snapshot_path.exists():
            return MemorySnapshot()

        data = self._read_snapshot_file()
        if data is None:
            return MemorySnapshot()

        return MemorySnapshot.from_dict(data)

    async def restore_all(
        self,
        pyramid: Any = None,
        working: Any = None,
        conversation: Any = None,
        long_term: Any = None,
        reflection_engine: Any = None,
        consolidation_pipeline: Any = None,
        retriever_index_target: Optional[dict[str, Any]] = None,
    ) -> int:
        """Load snapshot from disk and restore into live objects.

        For every non-None target whose saved state is non-empty,
        ``restore_state(state_dict)`` is called; targets without that method
        are skipped. ``retriever_index_target`` is cleared and refilled in
        place when a saved retriever index exists.

        Returns count of subsystems restored.
        """
        snapshot = await self.load_all()
        restored = 0

        for obj, state_attr in (
            (pyramid, "pyramid_state"),
            (working, "working_state"),
            (conversation, "conversation_state"),
            (long_term, "long_term_state"),
            (reflection_engine, "reflection_state"),
            (consolidation_pipeline, "consolidation_state"),
        ):
            state = getattr(snapshot, state_attr, {})
            if obj is None or not state:
                continue
            try:
                obj.restore_state(state)
            except AttributeError:
                # Best effort: target does not support restore_state().
                continue
            restored += 1

        if retriever_index_target is not None and snapshot.retriever_index_state:
            retriever_index_target.clear()
            retriever_index_target.update(snapshot.retriever_index_state)
            restored += 1

        return restored

    # ── Atomic write ────────────────────────

    def _atomic_write(self, snapshot: "MemorySnapshot") -> str:
        """Write snapshot atomically: temp file → rename.

        Returns the live snapshot path. On any failure the temporary file is
        removed and the exception is re-raised.
        """
        data = snapshot.to_dict()
        # NOTE: default=str stringifies any value json cannot encode, so such
        # values come back as strings when the snapshot is loaded.
        json_bytes = json.dumps(data, ensure_ascii=False, indent=2, default=str).encode("utf-8")

        if self.compress:
            json_bytes = gzip.compress(json_bytes, compresslevel=6)

        # The temp file lives in base_dir so the final rename stays on one
        # filesystem.
        fd, tmp_path = tempfile.mkstemp(
            dir=str(self.base_dir),
            prefix=".snapshot-tmp-",
            suffix=".json.gz" if self.compress else ".json",
        )
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(json_bytes)
                # Push the bytes to disk before the rename makes them live.
                f.flush()
                os.fsync(f.fileno())

            # Back up the previous snapshot; the live file stays in place
            # until the replace below.
            self._rotate_backups()

            os.replace(tmp_path, str(self._snapshot_path))
        except BaseException:
            # Also clean up on KeyboardInterrupt/SystemExit; always re-raised.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

        return str(self._snapshot_path)

    # ── Read snapshot ────────────────────────

    def _read_snapshot_file(self) -> Optional[dict[str, Any]]:
        """Read and parse snapshot file. Returns None on failure.

        Failure covers: the file cannot be read, the gzip stream is invalid,
        corrupt or truncated, the bytes are not UTF-8, the text is not valid
        JSON, or the top-level JSON value is not an object.
        """
        import zlib  # local import: only needed for the zlib.error type

        try:
            raw = self._snapshot_path.read_bytes()

            if self.compress:
                raw = gzip.decompress(raw)

            data = json.loads(raw.decode("utf-8"))
        except (OSError, EOFError, ValueError, zlib.error):
            # OSError includes gzip.BadGzipFile; EOFError is a truncated gzip
            # stream; ValueError includes JSONDecodeError/UnicodeDecodeError.
            return None

        return data if isinstance(data, dict) else None

    # ── Backup rotation ──────────────────────

    def _backup_path(self, index: int) -> Path:
        """Return the path of backup number *index* for the current mode."""
        suffix = ".json.gz" if self.compress else ".json"
        return self.base_dir / f"snapshot.{index}{suffix}"

    def _rotate_backups(self) -> None:
        """Rotate old snapshot backups, keeping self._max_backups.

        Existing backups are shifted up by one (the oldest is overwritten),
        then the live snapshot is duplicated into backup 1. The live file is
        never moved, so a snapshot exists on disk at every point of a save.
        All steps are best effort: OS errors are ignored.
        """
        import shutil  # local import: copy fallback when hard links fail

        for i in range(self._max_backups - 1, 0, -1):
            old_path = self._backup_path(i)
            if not old_path.exists():
                continue
            try:
                os.replace(str(old_path), str(self._backup_path(i + 1)))
            except OSError:
                pass

        # Duplicate current into .1
        if not self._snapshot_path.exists():
            return

        backup_path = self._backup_path(1)
        try:
            # Only still present if the shift above failed; os.link() needs a
            # free destination name.
            if backup_path.exists():
                backup_path.unlink()
            try:
                # A hard link is instant; the later os.replace() of the live
                # name leaves the backup pointing at the old contents.
                os.link(str(self._snapshot_path), str(backup_path))
            except (OSError, NotImplementedError, AttributeError):
                shutil.copy2(str(self._snapshot_path), str(backup_path))
        except OSError:
            pass

    # ── Query ────────────────────────────────

    def snapshot_info(self) -> dict[str, Any]:
        """Return metadata about the current snapshot.

        ``{"exists": False}`` when there is no snapshot file,
        ``{"exists": True, "error": "unreadable"}`` when the file cannot be
        stat'ed or parsed, otherwise path, size, creation time, version,
        number of non-empty subsystem states and the compression flag.
        """
        if not self._snapshot_path.exists():
            return {"exists": False}

        unreadable: dict[str, Any] = {"exists": True, "error": "unreadable"}

        try:
            stat = self._snapshot_path.stat()
        except OSError:
            return unreadable

        # Parse directly rather than via load_sync(), which would hide a
        # corrupt file behind a default (empty) snapshot.
        data = self._read_snapshot_file()
        if data is None:
            return unreadable

        snapshot = MemorySnapshot.from_dict(data)
        subsystems_saved = sum(
            1
            for v in (
                snapshot.pyramid_state,
                snapshot.working_state,
                snapshot.conversation_state,
                snapshot.long_term_state,
                snapshot.reflection_state,
                snapshot.consolidation_state,
                snapshot.retriever_index_state,
            )
            if v
        )

        return {
            "exists": True,
            "path": str(self._snapshot_path),
            "size_bytes": stat.st_size,
            "created_at": snapshot.created_at,
            "version": snapshot.version,
            "subsystems_saved": subsystems_saved,
            "compressed": self.compress,
        }

    def delete_snapshot(self) -> bool:
        """Delete the current snapshot file(s), including backups.

        Removes every ``snapshot*.json.gz`` and ``snapshot*.json`` file in
        ``base_dir``. Returns True when at least one file was removed.
        """
        deleted = False
        for pattern in ("snapshot*.json.gz", "snapshot*.json"):
            for path in self.base_dir.glob(pattern):
                try:
                    path.unlink()
                    deleted = True
                except OSError:
                    pass
        return deleted

    def list_backups(self) -> list[dict[str, Any]]:
        """List all available snapshot backups.

        Returns one ``{"name", "size_bytes", "mtime"}`` dict per file in
        ``base_dir`` matching ``snapshot*.json*`` (the live snapshot
        included), sorted by path. Files that cannot be stat'ed are skipped.
        """
        results: list[dict[str, Any]] = []
        for path in sorted(self.base_dir.glob("snapshot*.json*")):
            try:
                stat = path.stat()
            except OSError:
                continue
            results.append({
                "name": path.name,
                "size_bytes": stat.st_size,
                "mtime": stat.st_mtime,
            })
        return results