Coverage for agentos/swarm/agent_memory.py: 23%

316 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2v1.9.7: Agent Memory System — layered memory with context window management. 

3 

4Three-tier memory architecture: 

5- WorkingMemory: current task context, small capacity, fast access 

6- ShortTermMemory: recent N conversation rounds, sliding window with summarization 

7- LongTermMemory: vector-based semantic retrieval for historical knowledge 

8 

9ContextWindowManager: auto-trim/compress context to fit token budgets. 

10""" 

11 

12from __future__ import annotations 

13 

14import heapq 

15import json 

16import time 

17import uuid 

18from collections import OrderedDict, deque 

19from dataclasses import dataclass, field 

20from typing import Any, Optional 

21 

22 

@dataclass
class MemoryEntry:
    """One unit of agent memory, shared by all memory tiers.

    Every field has a default, so ``MemoryEntry()`` produces an empty
    system-role entry with a fresh random id and the current timestamp.
    """

    # Short random identifier: the first 8 hex characters of a UUID4.
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:8])
    # Raw text of the memory.
    content: str = ""
    # Who produced the entry: system, user, assistant, tool
    role: str = "system"
    # Creation time in seconds since the epoch (time.time()).
    timestamp: float = field(default_factory=time.time)
    # Relative weight, 0.0-1.0
    importance: float = 0.5
    # Time-to-live in seconds, 0 = never expire
    ttl: float = 0.0
    # Free-form extra attributes; every instance gets its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Vector for long-term similarity search, when one is available.
    embedding: Optional[list[float]] = None
    # Compressed version of ``content`` for the context window.
    summary: str = ""

36 

37 

38# ── Working Memory ──────────────────────────────────────────────── 

39 

class WorkingMemory:
    """Ultra-fast, small-capacity memory for the current task.

    Holds the task goal, the active subtask, a free-form scratchpad and a
    bounded, insertion-ordered set of entries. When the number of entries
    exceeds ``max_entries`` the oldest ones are evicted.
    """

    def __init__(self, max_entries: int = 20):
        """Create an empty working memory holding at most *max_entries* entries."""
        self.max_entries = max_entries
        self._entries: OrderedDict[str, MemoryEntry] = OrderedDict()
        self.task_goal: str = ""
        self.active_subtask: str = ""
        self.scratchpad: dict[str, Any] = {}

    def add(self, entry: MemoryEntry) -> None:
        """Store *entry* under its id and evict the oldest entries on overflow."""
        self._entries[entry.id] = entry
        # The emptiness guard keeps a zero/negative ``max_entries`` from
        # raising KeyError out of ``popitem`` once the dict is drained.
        while self._entries and len(self._entries) > self.max_entries:
            self._entries.popitem(last=False)

    def set_task(self, goal: str, subtask: str = "") -> None:
        """Set the task goal; the subtask falls back to the goal when empty."""
        self.task_goal = goal
        self.active_subtask = subtask or goal

    def get_all(self) -> list[MemoryEntry]:
        """Return all entries, oldest first."""
        return list(self._entries.values())

    def get_last(self, n: int = 5) -> list[MemoryEntry]:
        """Return the *n* most recent entries, oldest first.

        A non-positive *n* yields an empty list (a bare ``[-0:]`` slice
        would otherwise return every entry).
        """
        if n <= 0:
            return []
        return list(self._entries.values())[-n:]

    def clear(self) -> None:
        """Drop all entries and the scratchpad; task goal/subtask are kept."""
        self._entries.clear()
        self.scratchpad.clear()

    def to_context(self, max_tokens: int = 500) -> str:
        """Serialize working memory as a context string for the LLM.

        The output lists the task header, the subtask header (only when it
        differs from the goal) and the five most recent entries, each
        capped at 200 characters. When the estimate exceeds *max_tokens*,
        the oldest entries are dropped first; the headers are dropped only
        if they alone still exceed the budget, subtask before task.
        """
        headers: list[str] = []
        if self.task_goal:
            headers.append(f"[Task] {self.task_goal}")
        if self.active_subtask and self.active_subtask != self.task_goal:
            headers.append(f"[SubTask] {self.active_subtask}")

        body: list[str] = []
        for entry in list(self._entries.values())[-5:]:
            content = entry.summary or entry.content
            if len(content) > 200:
                content = content[:197] + "..."
            body.append(f"[{entry.role}] {content}")

        def over_budget() -> bool:
            return self._estimate_tokens("\n".join(headers + body)) > max_tokens

        # Sacrifice the oldest entries before touching the task headers:
        # the goal is the most valuable line in this section.
        while body and over_budget():
            body.pop(0)
        while headers and over_budget():
            headers.pop()

        return "\n".join(headers + body)

    def _estimate_tokens(self, text: str) -> int:
        """Rough token estimation: ~4 chars per token, never less than 1."""
        return max(1, len(text) // 4)

97 

98 

99# ── Short-Term Memory ───────────────────────────────────────────── 

100 

class ShortTermMemory:
    """Sliding window of recent conversation rounds.

    A round is a list of entries. Rounds that fall out of the window are
    replaced by one-line text summaries when ``auto_summarize`` is on.
    """

    def __init__(
        self,
        max_rounds: int = 50,
        auto_summarize: bool = True,
        summarize_threshold: int = 20,  # Summarize when rounds > threshold
        keep_recent: int = 10,  # Keep N most recent rounds raw
    ):
        """Configure the window.

        Args:
            max_rounds: Hard cap on raw rounds; older rounds are evicted.
            auto_summarize: Whether evicted/compressed rounds are summarized.
            summarize_threshold: Compress once more raw rounds than this exist.
            keep_recent: Number of newest rounds left raw when compressing.
        """
        self.max_rounds = max_rounds
        self.auto_summarize = auto_summarize
        self.summarize_threshold = summarize_threshold
        self.keep_recent = keep_recent

        self._rounds: deque[list[MemoryEntry]] = deque()
        self._summaries: list[str] = []  # Compressed old rounds
        self.total_rounds = 0

    def add_round(self, entries: list[MemoryEntry]) -> None:
        """Add a full conversation round, then evict/compress as configured."""
        self._rounds.append(entries)
        self.total_rounds += 1

        # Enforce max rounds. The emptiness guard keeps a negative
        # ``max_rounds`` from raising IndexError out of ``popleft``.
        while self._rounds and len(self._rounds) > self.max_rounds:
            evicted = self._rounds.popleft()
            if self.auto_summarize:
                self._record_summary(evicted)

        # Auto-summarize older rounds when over threshold
        if self.auto_summarize and len(self._rounds) > self.summarize_threshold:
            self._compress_middle()

    def _record_summary(self, entries: list[MemoryEntry]) -> None:
        """Append the summary of *entries* to the history unless it is empty."""
        summary = self._summarize_round(entries)
        if summary:
            self._summaries.append(summary)

    def _compress_middle(self) -> None:
        """Summarize every round older than the ``keep_recent`` newest ones."""
        rounds = list(self._rounds)
        keep_count = max(0, min(self.keep_recent, len(rounds)))
        # Use an explicit split index: negative slicing breaks at zero,
        # since ``rounds[-0:]`` is the whole list and ``rounds[:-0]`` is empty.
        split = len(rounds) - keep_count
        if split == 0:
            return

        for entries in rounds[:split]:
            self._record_summary(entries)

        # Replace deque with only recent rounds
        self._rounds = deque(rounds[split:])

    def _summarize_round(self, entries: list[MemoryEntry]) -> str:
        """Create a one-line summary of a round ("" for an empty round).

        Each entry contributes ``role: content`` with the content capped at
        100 characters; the line is stamped with the first entry's timestamp.
        """
        if not entries:
            return ""

        parts = []
        for entry in entries:
            content = entry.content
            if len(content) > 100:
                content = content[:97] + "..."
            parts.append(f"{entry.role}: {content}")

        return f"[Round@{entries[0].timestamp:.0f}] " + " | ".join(parts)

    def get_context(
        self,
        include_summaries: bool = True,
        max_rounds: int = 15,
    ) -> list[MemoryEntry]:
        """Get flattened context entries.

        Args:
            include_summaries: Prepend the last three history summaries as
                low-importance system entries.
            max_rounds: Number of newest raw rounds to include; a
                non-positive value includes none.
        """
        flat: list[MemoryEntry] = []

        # Add summaries as system entries
        if include_summaries:
            for summary in self._summaries[-3:]:  # Keep last 3 summaries
                flat.append(MemoryEntry(
                    content=f"[History Summary] {summary}",
                    role="system",
                    importance=0.3,
                ))

        # Add recent rounds; guard against ``[-0:]`` returning everything.
        if max_rounds > 0:
            for entries in list(self._rounds)[-max_rounds:]:
                flat.extend(entries)

        return flat

    def clear(self) -> None:
        """Forget all rounds and summaries and reset the round counter."""
        self._rounds.clear()
        self._summaries.clear()
        self.total_rounds = 0

206 

207 

208# ── Long-Term Memory ────────────────────────────────────────────── 

209 

class LongTermMemory:
    """Vector-based semantic memory for historical knowledge retrieval.

    Stores entries whose importance reaches a threshold. Search uses cosine
    similarity when both stored embeddings and a query embedding exist, and
    falls back to keyword overlap otherwise.
    """

    def __init__(
        self,
        max_entries: int = 10000,
        importance_threshold: float = 0.4,  # Only store entries above this importance
        persist_path: str = "",
    ):
        """Configure capacity, the storage threshold and the default file path."""
        self.max_entries = max_entries
        self.importance_threshold = importance_threshold
        self.persist_path = persist_path

        self._entries: dict[str, MemoryEntry] = {}
        self._embeddings: dict[str, list[float]] = {}  # entry_id -> embedding

        self._embedder: Any = None  # Lazy-loaded embedder

    def add(self, entry: MemoryEntry) -> None:
        """Store a memory entry. Only stores if importance >= threshold."""
        if entry.importance < self.importance_threshold:
            return

        self._entries[entry.id] = entry
        if entry.embedding:
            self._embeddings[entry.id] = entry.embedding
        else:
            # Replacing an id with an entry that has no vector must not
            # leave the previous entry's vector searchable.
            self._embeddings.pop(entry.id, None)

        self._evict_over_capacity()

    def _evict_over_capacity(self) -> None:
        """Remove oldest-by-timestamp entries until ``max_entries`` is respected."""
        while self._entries and len(self._entries) > self.max_entries:
            oldest_id = min(
                self._entries,
                key=lambda k: self._entries[k].timestamp,
            )
            del self._entries[oldest_id]
            self._embeddings.pop(oldest_id, None)

    def search(
        self,
        query: str,
        top_k: int = 5,
        query_embedding: list[float] | None = None,
    ) -> list[MemoryEntry]:
        """Semantic search over stored memories.

        Uses cosine similarity if embeddings available, else keyword overlap.
        """
        if self._embeddings and query_embedding:
            return self._vector_search(query_embedding, top_k)
        return self._keyword_search(query, top_k)

    def search_keywords(
        self,
        keywords: list[str],
        top_k: int = 5,
    ) -> list[MemoryEntry]:
        """Search memories by case-insensitive keyword overlap.

        An entry's score is the number of keywords found as substrings of
        its content; ties are broken by higher importance.
        """
        lowered = [kw.lower() for kw in keywords]  # hoisted out of the scan
        results = []
        for entry in self._entries.values():
            content_lower = entry.content.lower()
            score = sum(1 for kw in lowered if kw in content_lower)
            if score > 0:
                results.append((score, entry))
        results.sort(key=lambda x: (-x[0], -x[1].importance))
        return [entry for _, entry in results[:top_k]]

    def search_by_timerange(
        self,
        start: float,
        end: float | None = None,
        top_k: int = 10,
    ) -> list[MemoryEntry]:
        """Return entries with ``start <= timestamp <= end``, newest first.

        ``end=None`` means "now". An explicit ``end`` of ``0.0`` is honored
        rather than being treated as missing.
        """
        if end is None:
            end = time.time()
        results = [
            entry for entry in self._entries.values()
            if start <= entry.timestamp <= end
        ]
        results.sort(key=lambda e: e.timestamp, reverse=True)
        return results[:top_k]

    def _vector_search(
        self,
        query_emb: list[float],
        top_k: int,
    ) -> list[MemoryEntry]:
        """Cosine similarity search over the stored embeddings."""
        scores = [
            (self._cosine_similarity(query_emb, emb), eid)
            for eid, emb in self._embeddings.items()
        ]
        scores.sort(reverse=True)
        return [self._entries[eid] for _, eid in scores[:top_k] if eid in self._entries]

    def _keyword_search(self, query: str, top_k: int) -> list[MemoryEntry]:
        """Fallback keyword overlap search on the whitespace-split query."""
        query_words = set(query.lower().split())
        if not query_words:
            return []
        return self.search_keywords(list(query_words), top_k)

    def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
        """Cosine similarity; 0.0 for mismatched lengths or zero vectors."""
        if len(a) != len(b):
            return 0.0
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(y * y for y in b) ** 0.5
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return dot / (norm_a * norm_b)

    def export_important(self, top_k: int = 20) -> list[MemoryEntry]:
        """Export the most important entries, ties broken by newer timestamp."""
        entries = sorted(
            self._entries.values(),
            key=lambda e: (e.importance, e.timestamp),
            reverse=True,
        )
        return entries[:top_k]

    def clear(self) -> None:
        """Remove every entry and embedding."""
        self._entries.clear()
        self._embeddings.clear()

    def save(self, path: str = "") -> None:
        """Persist entries to disk as JSON (without embeddings).

        Uses *path*, or ``persist_path`` when *path* is empty; does nothing
        when neither is set. Only id, content, role, timestamp, importance
        and metadata are written.
        """
        save_path = path or self.persist_path
        if not save_path:
            return

        data = [
            {
                "id": entry.id,
                "content": entry.content,
                "role": entry.role,
                "timestamp": entry.timestamp,
                "importance": entry.importance,
                "metadata": entry.metadata,
            }
            for entry in self._entries.values()
        ]

        # ensure_ascii=False emits raw non-ASCII text, so the encoding must
        # be pinned instead of depending on the platform locale.
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    def load(self, path: str = "") -> int:
        """Load entries from disk and return how many records were read.

        Returns 0 when no path is configured, the file is missing, or it
        does not contain a JSON list. Non-object items are skipped. Loaded
        entries are subject to the same capacity eviction as ``add``.
        """
        load_path = path or self.persist_path
        if not load_path:
            return 0

        try:
            with open(load_path, encoding="utf-8") as f:
                data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError, UnicodeDecodeError):
            return 0

        if not isinstance(data, list):
            return 0

        count = 0
        for item in data:
            if not isinstance(item, dict):
                continue
            entry = MemoryEntry(
                id=item.get("id", uuid.uuid4().hex[:8]),
                content=item.get("content", ""),
                role=item.get("role", "system"),
                timestamp=item.get("timestamp", time.time()),
                importance=item.get("importance", 0.5),
                metadata=item.get("metadata") or {},
            )
            self._entries[entry.id] = entry
            count += 1

        self._evict_over_capacity()
        return count

384 

385 

386# ── Context Window Manager ──────────────────────────────────────── 

387 

@dataclass
class ContextBudget:
    """Per-section token allowances used when assembling a context window.

    With the defaults, the five section budgets (system, working memory,
    short-term, long-term, query) add up to ``total_tokens``. Nothing in
    this class validates or enforces that relationship.
    """

    # Max total tokens for the whole window
    total_tokens: int = 4096
    # Reserved for the system prompt
    system_reserved: int = 512
    # Allowance for the working-memory section
    working_memory_budget: int = 640
    # Allowance for the short-term (recent history) section
    short_term_budget: int = 1536
    # Allowance for injected long-term memories
    long_term_budget: int = 512
    # Allowance for the current query
    query_budget: int = 896
    # Safety margin
    safety_margin: int = 128

399 

400 

class ContextWindowManager:
    """Manages token budgets and assembles context windows.

    Trims each section of the three-tier memory system to its own budget.
    Token counts are estimated at roughly four characters per token.
    """

    # Marker placed in front of text whose beginning was cut off.
    _TRUNCATION_PREFIX = "...(truncated) "

    def __init__(self, budget: ContextBudget | None = None):
        """Use *budget*, or a default ``ContextBudget`` when none is given."""
        self.budget = budget or ContextBudget()

    def assemble(
        self,
        working: WorkingMemory,
        short_term: ShortTermMemory,
        long_term: LongTermMemory,
        current_query: str = "",
        retrieval_query: str = "",
    ) -> str:
        """Assemble a full context window from all memory tiers.

        Sections appear in the order Working Memory, Recent History,
        Relevant Memories, Current Task; empty sections are omitted. Each
        section is limited by its own budget only; ``total_tokens``,
        ``system_reserved`` and ``safety_margin`` are not consulted here.

        Args:
            working: Source of the working-memory section.
            short_term: Source of the recent-history section.
            long_term: Searched with *retrieval_query*; when that is empty
                its most important entries are used instead.
            current_query: Text of the current task, trimmed to the query budget.
            retrieval_query: Query for long-term retrieval.

        Returns:
            A string ready to prepend to the LLM prompt.
        """
        budget = self.budget
        sections: list[tuple[str, str]] = []

        # 1. Working memory context
        wm_ctx = working.to_context(max_tokens=budget.working_memory_budget)
        if wm_ctx:
            sections.append(("Working Memory", wm_ctx))

        # 2. Short-term memory context
        st_entries = short_term.get_context(include_summaries=True, max_rounds=15)
        st_ctx = self._entries_to_context(st_entries, budget.short_term_budget)
        if st_ctx:
            sections.append(("Recent History", st_ctx))

        # 3. Long-term memory (semantic retrieval)
        if retrieval_query:
            lt_entries = long_term.search(retrieval_query, top_k=5)
        else:
            lt_entries = long_term.export_important(top_k=5)
        lt_ctx = self._entries_to_context(lt_entries, budget.long_term_budget)
        if lt_ctx:
            sections.append(("Relevant Memories", lt_ctx))

        # 4. Current query
        if current_query:
            sections.append(
                ("Current Task", self.fit_to_budget(current_query, budget.query_budget))
            )

        return "\n\n".join(f"--- {name} ---\n{content}" for name, content in sections)

    def fit_to_budget(self, text: str, max_tokens: int) -> str:
        """Trim text to fit within token budget; short text is returned as is."""
        if self._estimate_tokens(text) <= max_tokens:
            return text
        return self._truncate_text(text, max_tokens)

    def _entries_to_context(self, entries: list[MemoryEntry], max_tokens: int) -> str:
        """Convert memory entries to context lines within *max_tokens*.

        Entries are rendered as ``[role] text`` using the summary when one
        exists. The first entry that does not fit is included in truncated
        form if more than 20 tokens remain (after a 10-token reserve);
        nothing after it is emitted.
        """
        if not entries:
            return ""

        lines = []
        token_count = 0

        for entry in entries:
            content = entry.summary or entry.content
            line = f"[{entry.role}] {content}"
            line_tokens = self._estimate_tokens(line)

            if token_count + line_tokens > max_tokens:
                # Try truncated version; 10 tokens are held back for the
                # role prefix and the trailing ellipsis.
                available = max_tokens - token_count - 10
                if available > 20:
                    truncated = content[:available * 4]
                    # The truncated line must actually be emitted;
                    # previously it was built and then discarded.
                    lines.append(f"[{entry.role}] {truncated}...")
                break

            lines.append(line)
            token_count += line_tokens

        return "\n".join(lines)

    def _truncate_text(self, text: str, max_tokens: int) -> str:
        """Cut text from the beginning so the result fits the token budget.

        The tail of *text* is kept for relevance and prefixed with a
        truncation marker. The marker counts against the budget; when the
        budget is too small to hold it, the bare tail is returned (an empty
        string for a zero budget).
        """
        # Estimate char budget: ~4 chars per token
        char_budget = max(0, max_tokens) * 4
        if len(text) <= char_budget:
            return text

        tail_len = char_budget - len(self._TRUNCATION_PREFIX)
        if tail_len <= 0:
            # Index from the front: ``text[-0:]`` would return everything.
            return text[len(text) - char_budget:]
        return self._TRUNCATION_PREFIX + text[-tail_len:]

    def _estimate_tokens(self, text: str) -> int:
        """Rough token estimation: ~4 chars per token, never less than 1."""
        return max(1, len(text) // 4)

509 

510 

511# ── Unified Agent Memory ────────────────────────────────────────── 

512 

class AgentMemory:
    """Facade over the working, short-term and long-term memory tiers.

    Bundles the three tiers with a context window manager and exposes the
    everyday operations: record rounds, store single facts, search, and
    build the prompt context.

    Usage:
        memory = AgentMemory()
        memory.add_round([user_msg, assistant_msg])
        context = memory.get_context(query="What files did I create yesterday?")
    """

    def __init__(
        self,
        working_max: int = 20,
        short_term_max_rounds: int = 50,
        long_term_max: int = 10000,
        budget: ContextBudget | None = None,
    ):
        """Build the three tiers with the given capacities and token budget."""
        self.working = WorkingMemory(max_entries=working_max)
        self.short_term = ShortTermMemory(max_rounds=short_term_max_rounds)
        self.long_term = LongTermMemory(max_entries=long_term_max)
        self.window_manager = ContextWindowManager(budget=budget)

    def add_round(
        self,
        entries: list[MemoryEntry],
        importance: float = 0.5,
    ) -> None:
        """Record a conversation round in short-term memory.

        Entries whose own importance is at least 0.4 are also offered to
        long-term memory.

        NOTE(review): the ``importance`` argument is accepted but never
        read; only each entry's own importance is used. Confirm the intent.
        """
        self.short_term.add_round(entries)

        # Promote the sufficiently important entries to long-term storage.
        for promoted in (item for item in entries if item.importance >= 0.4):
            self.long_term.add(promoted)

    def set_task(self, goal: str, subtask: str = "") -> None:
        """Forward the task goal and subtask to working memory."""
        self.working.set_task(goal, subtask)

    def remember(
        self,
        content: str,
        role: str = "system",
        importance: float = 0.5,
        ttl: float = 0.0,
        metadata: dict | None = None,
    ) -> MemoryEntry:
        """Create an entry, put it in working memory, and return it.

        The entry is additionally handed to long-term memory when its
        importance is 0.6 or higher.
        """
        fresh = MemoryEntry(
            content=content,
            role=role,
            importance=importance,
            ttl=ttl,
            metadata=metadata or {},
        )
        self.working.add(fresh)
        if importance >= 0.6:
            self.long_term.add(fresh)
        return fresh

    def recall(
        self,
        query: str,
        top_k: int = 5,
        include_short_term: bool = True,
        include_long_term: bool = True,
    ) -> list[MemoryEntry]:
        """Search the selected tiers and return the top matches.

        Short-term entries match when any whitespace-separated query word
        occurs in their lower-cased content. Long-term hits already present
        are not repeated. Results are ordered by importance, then
        timestamp, both descending, and cut to *top_k*.
        """
        found: list[MemoryEntry] = []

        if include_short_term:
            candidates = self.short_term.get_context(include_summaries=True)
            terms = set(query.lower().split())
            # Plain substring filter over the short-term window.
            found.extend(
                item for item in candidates
                if any(term in item.content.lower() for term in terms)
            )

        if include_long_term:
            for hit in self.long_term.search(query, top_k=top_k):
                if hit not in found:
                    found.append(hit)

        ranked = sorted(
            found,
            key=lambda item: (item.importance, item.timestamp),
            reverse=True,
        )
        return ranked[:top_k]

    def get_context(self, query: str = "") -> str:
        """Build the context window; *query* is both the task text and the retrieval query."""
        return self.window_manager.assemble(
            working=self.working,
            short_term=self.short_term,
            long_term=self.long_term,
            current_query=query,
            retrieval_query=query,
        )

    def clear_working(self) -> None:
        """Clear working memory only."""
        self.working.clear()

    def clear_all(self) -> None:
        """Clear all three memory tiers."""
        self.working.clear()
        self.short_term.clear()
        self.long_term.clear()

622 self.long_term.clear()