Coverage for agentos/swarm/agent_memory.py: 23%
316 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
"""
v1.9.7: Agent Memory System — layered memory with context window management.

Three-tier memory architecture:
- WorkingMemory: current task context, small capacity, fast access
- ShortTermMemory: recent N conversation rounds, sliding window with summarization
- LongTermMemory: importance-filtered store with vector (cosine) retrieval and a keyword-search fallback

ContextWindowManager: trims each memory section to its token budget and assembles the final context.
"""
12from __future__ import annotations
14import heapq
15import json
16import time
17import uuid
18from collections import OrderedDict, deque
19from dataclasses import dataclass, field
20from typing import Any, Optional
@dataclass
class MemoryEntry:
    """A single memory record shared by all memory tiers.

    Attributes:
        id: Short random identifier (first 8 hex chars of a UUID4).
        content: Raw text of the memory.
        role: Origin of the text: "system", "user", "assistant" or "tool".
        timestamp: Creation time as a Unix timestamp (``time.time()``).
        importance: Relevance score in [0.0, 1.0]; tiers compare it against
            thresholds to decide what to retain.
        ttl: Time-to-live in seconds; ``0`` (or any non-positive value)
            means the entry never expires.  See :meth:`is_expired`.
        metadata: Free-form caller data.
        embedding: Optional vector used for long-term semantic search.
        summary: Optional compressed text, preferred over ``content`` when
            context windows are rendered.
    """

    id: str = field(default_factory=lambda: uuid.uuid4().hex[:8])
    content: str = ""
    role: str = "system"  # system, user, assistant, tool
    timestamp: float = field(default_factory=time.time)
    importance: float = 0.5  # 0.0-1.0
    ttl: float = 0.0  # Time-to-live in seconds, 0 = never expire
    metadata: dict[str, Any] = field(default_factory=dict)
    embedding: list[float] | None = None  # For long-term vector search
    summary: str = ""  # Compressed version for context window

    def is_expired(self, now: float | None = None) -> bool:
        """Return True once ``ttl`` seconds have elapsed since ``timestamp``.

        Args:
            now: Reference time in Unix seconds; defaults to ``time.time()``.

        Returns:
            False for entries with a non-positive ``ttl`` (they never expire),
            otherwise whether the entry's age has reached ``ttl``.
        """
        if self.ttl <= 0:
            return False
        current = time.time() if now is None else now
        return current - self.timestamp >= self.ttl
# ── Working Memory ────────────────────────────────────────────────
class WorkingMemory:
    """Small, bounded store for the current task's context.

    Holds the task goal, the active subtask, a free-form scratchpad and at
    most ``max_entries`` memory entries.  When the limit is exceeded, the
    least recently added entry is evicted first.
    """

    def __init__(self, max_entries: int = 20):
        self.max_entries = max_entries
        self._entries: OrderedDict[str, MemoryEntry] = OrderedDict()
        self.task_goal: str = ""
        self.active_subtask: str = ""
        self.scratchpad: dict[str, Any] = {}

    def add(self, entry: MemoryEntry) -> None:
        """Insert (or refresh) ``entry`` and evict the oldest entries on overflow.

        Re-adding an entry whose id is already present replaces it and marks
        it as the newest, so a just-refreshed entry is not evicted next.
        """
        self._entries[entry.id] = entry
        self._entries.move_to_end(entry.id)
        # Stop on an empty dict so a non-positive max_entries cannot make
        # popitem() raise KeyError.
        while self._entries and len(self._entries) > self.max_entries:
            self._entries.popitem(last=False)

    def set_task(self, goal: str, subtask: str = "") -> None:
        """Set the task goal; the active subtask defaults to the goal itself."""
        self.task_goal = goal
        self.active_subtask = subtask or goal

    def get_all(self) -> list[MemoryEntry]:
        """Return all entries, oldest first."""
        return list(self._entries.values())

    def get_last(self, n: int = 5) -> list[MemoryEntry]:
        """Return up to the ``n`` most recent entries, oldest first.

        ``n <= 0`` yields an empty list (a bare ``[-0:]`` slice would return
        every entry).
        """
        if n <= 0:
            return []
        return list(self._entries.values())[-n:]

    def clear(self) -> None:
        """Drop all entries and scratchpad data.

        ``task_goal`` and ``active_subtask`` are left unchanged.
        """
        self._entries.clear()
        self.scratchpad.clear()

    def to_context(self, max_tokens: int = 500) -> str:
        """Serialize working memory as a context string for the LLM.

        Emits the task goal, the subtask (only when it differs from the goal)
        and the five most recent entries, each clipped to 200 characters.  If
        the token estimate exceeds ``max_tokens``, whole lines are dropped
        from the front until it fits.
        """
        parts: list[str] = []
        if self.task_goal:
            parts.append(f"[Task] {self.task_goal}")
        if self.active_subtask and self.active_subtask != self.task_goal:
            parts.append(f"[SubTask] {self.active_subtask}")
        for entry in self.get_last(5):
            content = entry.summary or entry.content
            if len(content) > 200:
                content = content[:197] + "..."
            parts.append(f"[{entry.role}] {content}")
        result = "\n".join(parts)
        if self._estimate_tokens(result) > max_tokens:
            # Re-split on "\n" so multi-line entry content is trimmed line by line.
            lines = result.split("\n")
            while lines and self._estimate_tokens("\n".join(lines)) > max_tokens:
                lines.pop(0)
            result = "\n".join(lines)
        return result

    def _estimate_tokens(self, text: str) -> int:
        """Rough token estimation: ~4 chars per token, minimum 1."""
        return max(1, len(text) // 4)
# ── Short-Term Memory ─────────────────────────────────────────────
class ShortTermMemory:
    """Sliding window of recent conversation rounds.

    Each round is a list of entries.  Rounds beyond ``max_rounds`` are
    evicted (and summarized when ``auto_summarize`` is on).  Once more than
    ``summarize_threshold`` rounds are held, every round except the
    ``keep_recent`` newest ones is folded into a one-line text summary.

    NOTE(review): ``_summaries`` grows without bound; only the newest few
    are surfaced by :meth:`get_context`.
    """

    def __init__(
        self,
        max_rounds: int = 50,
        auto_summarize: bool = True,
        summarize_threshold: int = 20,  # Summarize when rounds > threshold
        keep_recent: int = 10,  # Keep N most recent rounds raw
    ):
        self.max_rounds = max_rounds
        self.auto_summarize = auto_summarize
        self.summarize_threshold = summarize_threshold
        self.keep_recent = keep_recent

        self._rounds: deque[list[MemoryEntry]] = deque()
        self._summaries: list[str] = []  # Compressed old rounds
        self.total_rounds = 0

    def add_round(self, entries: list[MemoryEntry]) -> None:
        """Append a full conversation round, then enforce size limits.

        Rounds evicted past ``max_rounds`` are summarized when
        ``auto_summarize`` is set.  Exceeding ``summarize_threshold``
        triggers compression of all but the ``keep_recent`` newest rounds.
        """
        self._rounds.append(entries)
        self.total_rounds += 1

        while len(self._rounds) > self.max_rounds:
            evicted = self._rounds.popleft()
            if self.auto_summarize:
                summary = self._summarize_round(evicted)
                if summary:
                    self._summaries.append(summary)

        if self.auto_summarize and len(self._rounds) > self.summarize_threshold:
            self._compress_middle()

    def _compress_middle(self) -> None:
        """Summarize every round except the ``keep_recent`` newest ones.

        ``keep_recent <= 0`` compresses all held rounds.  Previously a
        ``[-0:]`` slice kept everything and compression was silently skipped.
        """
        keep_count = max(0, min(self.keep_recent, len(self._rounds)))
        split = len(self._rounds) - keep_count
        if split <= 0:
            return

        rounds = list(self._rounds)
        for entries in rounds[:split]:
            summary = self._summarize_round(entries)
            if summary:
                self._summaries.append(summary)

        self._rounds = deque(rounds[split:])

    def _summarize_round(self, entries: list[MemoryEntry]) -> str:
        """Render a round as ``[Round@<ts>] role: text | role: text ...``.

        Each entry's content is clipped to 100 characters.  The timestamp is
        taken from the first entry.  Returns "" for an empty round.
        """
        if not entries:
            return ""

        parts = []
        for entry in entries:
            content = entry.content
            if len(content) > 100:
                content = content[:97] + "..."
            parts.append(f"{entry.role}: {content}")

        return f"[Round@{entries[0].timestamp:.0f}] " + " | ".join(parts)

    def get_context(
        self,
        include_summaries: bool = True,
        max_rounds: int = 15,
        max_summaries: int = 3,
    ) -> list[MemoryEntry]:
        """Return a flat list of context entries.

        Args:
            include_summaries: Prepend the newest ``max_summaries`` round
                summaries as system entries (importance 0.3).
            max_rounds: Number of newest raw rounds to include; ``<= 0``
                includes none.
            max_summaries: Number of newest summaries to include; ``<= 0``
                includes none.

        Returns:
            Summary entries first, then the selected rounds' entries in
            chronological order.
        """
        flat: list[MemoryEntry] = []

        if include_summaries and max_summaries > 0:
            for summary in self._summaries[-max_summaries:]:
                flat.append(MemoryEntry(
                    content=f"[History Summary] {summary}",
                    role="system",
                    importance=0.3,
                ))

        # Guard: list[-0:] would return every round.
        if max_rounds > 0:
            for entries in list(self._rounds)[-max_rounds:]:
                flat.extend(entries)

        return flat

    def clear(self) -> None:
        """Forget all rounds and summaries and reset the round counter."""
        self._rounds.clear()
        self._summaries.clear()
        self.total_rounds = 0
# ── Long-Term Memory ──────────────────────────────────────────────
class LongTermMemory:
    """Bounded store of important memories with semantic retrieval.

    Only entries whose importance is at least ``importance_threshold`` are
    kept.  Beyond ``max_entries``, the oldest entries (by timestamp) are
    evicted.  Search uses cosine similarity when both stored embeddings and
    a query embedding are available, and falls back to keyword overlap
    otherwise.  Contents can be saved to and loaded from a UTF-8 JSON file;
    embeddings are not persisted.
    """

    def __init__(
        self,
        max_entries: int = 10000,
        importance_threshold: float = 0.4,  # Only store entries above this importance
        persist_path: str = "",
    ):
        self.max_entries = max_entries
        self.importance_threshold = importance_threshold
        self.persist_path = persist_path

        self._entries: dict[str, MemoryEntry] = {}
        self._embeddings: dict[str, list[float]] = {}  # entry_id → embedding

        self._embedder: Any = None  # Lazy-loaded embedder

    def add(self, entry: MemoryEntry) -> None:
        """Store ``entry`` if its importance reaches the threshold.

        An entry re-added without an embedding drops any vector previously
        stored for its id, so vector search never matches stale content.
        """
        if entry.importance < self.importance_threshold:
            return

        self._entries[entry.id] = entry
        if entry.embedding:
            self._embeddings[entry.id] = entry.embedding
        else:
            self._embeddings.pop(entry.id, None)

        self._evict_overflow()

    def _evict_overflow(self) -> None:
        """Drop the oldest entries (by timestamp) until within ``max_entries``."""
        while self._entries and len(self._entries) > self.max_entries:
            oldest_id = min(self._entries, key=lambda k: self._entries[k].timestamp)
            del self._entries[oldest_id]
            self._embeddings.pop(oldest_id, None)

    def search(
        self,
        query: str,
        top_k: int = 5,
        query_embedding: list[float] | None = None,
    ) -> list[MemoryEntry]:
        """Semantic search over stored memories.

        Uses cosine similarity when embeddings are stored and a query
        embedding is given; otherwise uses keyword overlap with ``query``.
        """
        if self._embeddings and query_embedding:
            return self._vector_search(query_embedding, top_k)
        return self._keyword_search(query, top_k)

    def search_keywords(
        self,
        keywords: list[str],
        top_k: int = 5,
    ) -> list[MemoryEntry]:
        """Rank entries by how many ``keywords`` occur (case-insensitively) in them.

        Ties are broken by higher importance.  Entries matching no keyword
        are excluded.
        """
        needles = [kw.lower() for kw in keywords]  # lower once, not per entry
        scored = []
        for entry in self._entries.values():
            haystack = entry.content.lower()
            score = sum(1 for kw in needles if kw in haystack)
            if score > 0:
                scored.append((score, entry))
        scored.sort(key=lambda pair: (-pair[0], -pair[1].importance))
        return [entry for _, entry in scored[:top_k]]

    def search_by_timerange(
        self,
        start: float,
        end: float | None = None,
        top_k: int = 10,
    ) -> list[MemoryEntry]:
        """Return up to ``top_k`` entries with ``start <= timestamp <= end``, newest first.

        ``end`` defaults to the current time.  An explicit ``0.0`` is
        honored rather than being treated as "now".
        """
        if end is None:
            end = time.time()
        results = [
            entry for entry in self._entries.values()
            if start <= entry.timestamp <= end
        ]
        results.sort(key=lambda e: e.timestamp, reverse=True)
        return results[:top_k]

    def _vector_search(
        self,
        query_emb: list[float],
        top_k: int,
    ) -> list[MemoryEntry]:
        """Return the ``top_k`` entries most cosine-similar to ``query_emb``."""
        # nlargest == sorted(reverse=True)[:k], without sorting every score.
        best = heapq.nlargest(
            top_k,
            (
                (self._cosine_similarity(query_emb, emb), eid)
                for eid, emb in self._embeddings.items()
            ),
        )
        return [self._entries[eid] for _, eid in best if eid in self._entries]

    def _keyword_search(self, query: str, top_k: int) -> list[MemoryEntry]:
        """Fallback: keyword search on the query's unique lowercase words."""
        query_words = set(query.lower().split())
        if not query_words:
            return []
        return self.search_keywords(list(query_words), top_k)

    def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
        """Cosine similarity; 0.0 for mismatched lengths or zero-norm vectors."""
        if len(a) != len(b):
            return 0.0
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(y * y for y in b) ** 0.5
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return dot / (norm_a * norm_b)

    def export_important(self, top_k: int = 20) -> list[MemoryEntry]:
        """Return the ``top_k`` entries by importance, newest first on ties."""
        entries = sorted(
            self._entries.values(),
            key=lambda e: (e.importance, e.timestamp),
            reverse=True,
        )
        return entries[:top_k]

    def clear(self) -> None:
        """Remove all entries and embeddings."""
        self._entries.clear()
        self._embeddings.clear()

    def save(self, path: str = "") -> None:
        """Persist entries to ``path`` (or ``persist_path``) as UTF-8 JSON.

        Embeddings are not written.  Does nothing when no path is configured.
        """
        save_path = path or self.persist_path
        if not save_path:
            return

        data = [
            {
                "id": entry.id,
                "content": entry.content,
                "role": entry.role,
                "timestamp": entry.timestamp,
                "importance": entry.importance,
                "ttl": entry.ttl,
                "summary": entry.summary,
                "metadata": entry.metadata,
            }
            for entry in self._entries.values()
        ]

        # Explicit encoding: ensure_ascii=False emits raw non-ASCII text,
        # which the platform default encoding may not be able to represent.
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    def load(self, path: str = "") -> int:
        """Load entries saved by :meth:`save` and return how many were read.

        A missing, undecodable or non-list file yields 0.  Non-object items
        are skipped.  ``max_entries`` is enforced after loading.  Loaded
        entries carry no embedding, so any stored vector for the same id is
        dropped.
        """
        load_path = path or self.persist_path
        if not load_path:
            return 0

        try:
            with open(load_path, encoding="utf-8") as f:
                data = json.load(f)
        except (FileNotFoundError, UnicodeDecodeError, json.JSONDecodeError):
            return 0

        if not isinstance(data, list):
            return 0

        count = 0
        for item in data:
            if not isinstance(item, dict):
                continue
            entry = MemoryEntry(
                id=item.get("id", uuid.uuid4().hex[:8]),
                content=item.get("content", ""),
                role=item.get("role", "system"),
                timestamp=item.get("timestamp", time.time()),
                importance=item.get("importance", 0.5),
                ttl=item.get("ttl", 0.0),
                summary=item.get("summary", ""),
                metadata=item.get("metadata", {}),
            )
            self._entries[entry.id] = entry
            self._embeddings.pop(entry.id, None)
            count += 1

        self._evict_overflow()
        return count
# ── Context Window Manager ────────────────────────────────────────
@dataclass
class ContextBudget:
    """Token allowances for each part of an assembled context window.

    Attributes:
        total_tokens: Maximum size of the whole context window.
        system_reserved: Tokens set aside for the system prompt.
        working_memory_budget: Cap for the working-memory section.
        short_term_budget: Cap for the recent-history section.
        long_term_budget: Cap for injected long-term memories.
        query_budget: Cap for the current query text.
        safety_margin: Extra headroom to leave unused.

    ``total_tokens``, ``system_reserved`` and ``safety_margin`` are not
    consulted by ContextWindowManager in this module.

    NOTE(review): the per-section defaults plus the margin sum to 4224,
    which exceeds ``total_tokens`` (4096) by exactly ``safety_margin`` —
    confirm that is intended.
    """

    total_tokens: int = 4096
    system_reserved: int = 512
    working_memory_budget: int = 640
    short_term_budget: int = 1536
    long_term_budget: int = 512
    query_budget: int = 896
    safety_margin: int = 128
class ContextWindowManager:
    """Assembles a context window from the three memory tiers.

    Each section is trimmed to its own allowance from :class:`ContextBudget`.
    Token counts use a rough ~4-characters-per-token estimate.
    """

    def __init__(self, budget: ContextBudget | None = None):
        self.budget = budget or ContextBudget()

    def assemble(
        self,
        working: WorkingMemory,
        short_term: ShortTermMemory,
        long_term: LongTermMemory,
        current_query: str = "",
        retrieval_query: str = "",
    ) -> str:
        """Build the full context string from all memory tiers.

        Sections, in order and each only if non-empty: "Working Memory",
        "Recent History" (last 15 rounds plus summaries), "Relevant Memories"
        (top 5 search hits for ``retrieval_query``, or the 5 most important
        entries when it is empty) and "Current Task" (``current_query``).
        Each section is rendered as ``--- <name> ---\\n<content>`` and
        sections are separated by a blank line.
        """
        budget = self.budget
        sections: list[tuple[str, str]] = []

        wm_ctx = working.to_context(max_tokens=budget.working_memory_budget)
        if wm_ctx:
            sections.append(("Working Memory", wm_ctx))

        st_entries = short_term.get_context(include_summaries=True, max_rounds=15)
        # When history overflows its budget, lose the oldest lines, not the
        # latest conversation turns.
        st_ctx = self._entries_to_context(
            st_entries, budget.short_term_budget, prefer_recent=True
        )
        if st_ctx:
            sections.append(("Recent History", st_ctx))

        if retrieval_query:
            lt_entries = long_term.search(retrieval_query, top_k=5)
        else:
            lt_entries = long_term.export_important(top_k=5)
        lt_ctx = self._entries_to_context(lt_entries, budget.long_term_budget)
        if lt_ctx:
            sections.append(("Relevant Memories", lt_ctx))

        if current_query:
            sections.append(
                ("Current Task", self.fit_to_budget(current_query, budget.query_budget))
            )

        return "\n\n".join(f"--- {name} ---\n{content}" for name, content in sections)

    def fit_to_budget(self, text: str, max_tokens: int) -> str:
        """Return ``text`` unchanged if it fits ``max_tokens``, else its truncated tail."""
        if self._estimate_tokens(text) <= max_tokens:
            return text
        return self._truncate_text(text, max_tokens)

    def _entries_to_context(
        self,
        entries: list[MemoryEntry],
        max_tokens: int,
        prefer_recent: bool = False,
    ) -> str:
        """Render entries as ``[role] text`` lines within ``max_tokens``.

        Entries are taken in list order (or from the end when
        ``prefer_recent`` is set) until the budget runs out.  The entry that
        overflows is included as a truncated line ending in "..." when more
        than 20 tokens (after a 10-token allowance) remain.  Output lines
        always keep the original list order.
        """
        if not entries:
            return ""

        ordered = list(reversed(entries)) if prefer_recent else list(entries)
        lines: list[str] = []
        token_count = 0

        for entry in ordered:
            content = entry.summary or entry.content
            line = f"[{entry.role}] {content}"
            line_tokens = self._estimate_tokens(line)

            if token_count + line_tokens > max_tokens:
                available = max_tokens - token_count - 10
                if available > 20:
                    # Previously this truncated line was built but never
                    # appended, so it was silently dropped.
                    lines.append(f"[{entry.role}] {content[:available * 4]}...")
                break

            lines.append(line)
            token_count += line_tokens

        if prefer_recent:
            lines.reverse()
        return "\n".join(lines)

    def _truncate_text(self, text: str, max_tokens: int) -> str:
        """Keep the last ``max_tokens * 4`` characters, prefixed by a marker.

        The tail is kept because it is presumed the most relevant part.  A
        non-positive budget keeps no text (a ``[-0:]`` slice would have kept
        all of it).
        """
        char_budget = max(0, max_tokens * 4)
        if len(text) <= char_budget:
            return text
        return "...(truncated) " + text[len(text) - char_budget:]

    def _estimate_tokens(self, text: str) -> int:
        """Rough token estimation: ~4 chars per token, minimum 1."""
        return max(1, len(text) // 4)
# ── Unified Agent Memory ──────────────────────────────────────────
class AgentMemory:
    """Unified memory system combining all three tiers + context management.

    High-level API for agent memory operations:
    - Remember conversation rounds
    - Retrieve relevant history
    - Assemble context window

    Usage:
        memory = AgentMemory()
        memory.add_round([user_msg, assistant_msg])
        context = memory.get_context(query="What files did I create yesterday?")
    """

    def __init__(
        self,
        working_max: int = 20,
        short_term_max_rounds: int = 50,
        long_term_max: int = 10000,
        budget: ContextBudget | None = None,
    ):
        self.working = WorkingMemory(max_entries=working_max)
        self.short_term = ShortTermMemory(max_rounds=short_term_max_rounds)
        self.long_term = LongTermMemory(max_entries=long_term_max)
        self.window_manager = ContextWindowManager(budget=budget)

    def add_round(
        self,
        entries: list[MemoryEntry],
        importance: float = 0.5,
    ) -> None:
        """Record a conversation round and offer its entries to long-term memory.

        Long-term promotion is decided by ``LongTermMemory.add`` against
        ``self.long_term.importance_threshold`` (0.4 by default), so the
        threshold is configured in one place.

        Args:
            entries: The round's entries, in order.
            importance: Currently unused and accepted for backward
                compatibility; each entry's own ``importance`` drives
                promotion.
        """
        self.short_term.add_round(entries)
        for entry in entries:
            self.long_term.add(entry)

    def set_task(self, goal: str, subtask: str = "") -> None:
        """Set current task context in working memory."""
        self.working.set_task(goal, subtask)

    def remember(
        self,
        content: str,
        role: str = "system",
        importance: float = 0.5,
        ttl: float = 0.0,
        metadata: dict | None = None,
    ) -> MemoryEntry:
        """Create an entry, store it in working memory, and return it.

        Entries with ``importance >= 0.6`` are also offered to long-term
        memory, which applies its own threshold as well.
        """
        entry = MemoryEntry(
            content=content,
            role=role,
            importance=importance,
            ttl=ttl,
            metadata=metadata or {},
        )
        self.working.add(entry)
        if importance >= 0.6:
            self.long_term.add(entry)
        return entry

    def recall(
        self,
        query: str,
        top_k: int = 5,
        include_short_term: bool = True,
        include_long_term: bool = True,
    ) -> list[MemoryEntry]:
        """Search short-term and/or long-term memory for ``query``.

        Short-term entries match when any lowercase query word appears as a
        substring of their content.  Long-term hits come from
        ``LongTermMemory.search`` and are deduplicated by entry id.  The
        merged results are ordered by importance, then recency, and cut to
        ``top_k``.
        """
        results: list[MemoryEntry] = []
        seen: set[str] = set()

        if include_short_term:
            query_words = set(query.lower().split())
            for entry in self.short_term.get_context(include_summaries=True):
                content_lower = entry.content.lower()
                if any(w in content_lower for w in query_words):
                    results.append(entry)
                    seen.add(entry.id)

        if include_long_term:
            for entry in self.long_term.search(query, top_k=top_k):
                # O(1) id check instead of a field-by-field dataclass comparison
                # against every collected result.
                if entry.id not in seen:
                    results.append(entry)
                    seen.add(entry.id)

        results.sort(key=lambda e: (e.importance, e.timestamp), reverse=True)
        return results[:top_k]

    def get_context(self, query: str = "") -> str:
        """Assemble the full context window, using ``query`` for retrieval too."""
        return self.window_manager.assemble(
            working=self.working,
            short_term=self.short_term,
            long_term=self.long_term,
            current_query=query,
            retrieval_query=query,
        )

    def clear_working(self) -> None:
        """Clear working memory only."""
        self.working.clear()

    def clear_all(self) -> None:
        """Clear every memory tier."""
        self.working.clear()
        self.short_term.clear()
        self.long_term.clear()