# Coverage for agentos/memory/retriever.py: 29% (179 statements)
# Report generated by coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
"""
Semantic Memory Retriever — Embedding-based memory retrieval with hybrid search.

Supports semantic (embedding), keyword (BM25), and hybrid search across
conversation memory, long-term memory, and working memory. Aligns with
ConversationMemory window strategies and LongTermMemory persistence.
"""
from __future__ import annotations

import json
import math
import time
from collections import Counter
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Optional

import numpy as np
class RetrievalStrategy(Enum):
    """Enumeration of the available retrieval strategies."""

    # Cosine similarity between the query embedding and entry embeddings.
    SEMANTIC = "semantic"
    # BM25-style keyword matching; needs no embedder.
    KEYWORD = "keyword"
    # Weighted blend of the semantic and keyword scores.
    HYBRID = "hybrid"
    # Newest entries first, ordered by timestamp.
    RECENT = "recent"
@dataclass
class MemoryEntry:
    """A single memory entry with content and metadata."""

    # Unique key; the retriever stores entries in a dict keyed by this value.
    id: str
    # Text that is tokenized for keyword search and passed to the embedder.
    content: str
    # Free-form extra data; a fresh dict is created per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Embedding vector; when None the retriever fills it in via its embedder.
    embedding: Optional[list[float]] = None
    # Sort key for the RECENT strategy; a missing timestamp sorts as 0.
    timestamp: Optional[float] = None
    # Compared against ``min_importance`` when filtering candidates.
    importance: float = 0.5
    # Compared against ``filter_source``: conversation / long_term / working.
    source: str = "conversation"
@dataclass
class RetrievalResult:
    """A single retrieval result with relevance score."""

    # The memory entry that matched the query.
    entry: MemoryEntry
    # Relevance score; its scale depends on the strategy that produced it.
    score: float
    # Strategy that actually produced this result.
    strategy: RetrievalStrategy
@dataclass
class RetrievalStats:
    """Statistics for a retrieval operation."""

    # Number of entries considered.
    total_entries: int
    # Number of results returned.
    retrieved: int
    # Strategies involved; a fresh list is created per instance.
    strategies_used: list[RetrievalStrategy] = field(default_factory=list)
    # Latency of the operation in milliseconds.
    latency_ms: float = 0.0
class SemanticMemoryRetriever:
    """
    Semantic retrieval engine for AgentOS memory systems.

    Supports four retrieval strategies:
    - **semantic**: Cosine similarity over embeddings (requires embedder;
      falls back to keyword search when no embedder is configured)
    - **keyword**: BM25 keyword matching (no embedder needed)
    - **hybrid**: Weighted combination of semantic + keyword scores
    - **recent**: Newest entries first, ordered by timestamp

    Example::

        retriever = SemanticMemoryRetriever(embedder=my_embedder)
        retriever.index(entries)
        results = retriever.retrieve(
            "What did we discuss about deployment?",
            top_k=5,
            strategy=RetrievalStrategy.HYBRID,
        )
        for r in results:
            print(f"[{r.score:.2f}] {r.entry.content[:80]}...")
    """

    def __init__(
        self,
        embedder: Optional[Callable[[str], list[float]]] = None,
        hybrid_weight: float = 0.7,
        min_keyword_score: float = 0.01,
        default_top_k: int = 10,
    ):
        """
        Args:
            embedder: Callable that takes text and returns embedding vector.
            hybrid_weight: Weight for semantic score in hybrid mode (0-1).
                Remaining weight goes to keyword score.
            min_keyword_score: Minimum BM25 score to include in results.
            default_top_k: Default number of results to return.
        """
        self._embedder = embedder
        self._hybrid_weight = hybrid_weight
        self._min_keyword_score = min_keyword_score
        self._default_top_k = default_top_k
        self._entries: dict[str, MemoryEntry] = {}
        self._idf_cache: dict[str, float] = {}
        self._doc_freqs: Counter[str] = Counter()
        self._total_docs: int = 0
        # Sum of token counts over all indexed documents (for BM25 avgdl).
        self._total_tokens: int = 0
        # Per-entry snapshot (token count, distinct terms) taken at index
        # time, so removal stays exact even if ``content`` is mutated later.
        self._indexed: dict[str, tuple[int, frozenset[str]]] = {}
        self._last_stats: Optional[RetrievalStats] = None

    # --- Index maintenance ---

    def index(self, entries: list[MemoryEntry]) -> None:
        """Add entries to the search index.

        An entry whose id is already indexed replaces the previous one; its
        old keyword statistics are withdrawn first so nothing is counted
        twice. Entries without an embedding are embedded when an embedder
        is configured.
        """
        for entry in entries:
            if entry.id in self._entries:
                self._remove_from_keyword_index(entry.id)
            self._entries[entry.id] = entry
            if not entry.embedding and self._embedder:
                entry.embedding = self._embedder(entry.content)
            self._add_to_keyword_index(entry)
        # Document frequencies changed, so cached IDF values are stale.
        self._idf_cache.clear()

    def remove(self, entry_ids: list[str]) -> None:
        """Remove entries from the index. Unknown ids are ignored."""
        for eid in entry_ids:
            if eid in self._entries:
                del self._entries[eid]
                self._remove_from_keyword_index(eid)
        self._idf_cache.clear()

    def _add_to_keyword_index(self, entry: MemoryEntry) -> None:
        """Record an entry's tokens in the keyword statistics."""
        tokens = self._tokenize(entry.content)
        terms = frozenset(tokens)
        self._doc_freqs.update(terms)
        self._total_docs += 1
        self._total_tokens += len(tokens)
        self._indexed[entry.id] = (len(tokens), terms)

    def _remove_from_keyword_index(self, entry_id: str) -> None:
        """Withdraw an entry's contribution from the keyword statistics."""
        snapshot = self._indexed.pop(entry_id, None)
        if snapshot is None:
            return
        doc_len, terms = snapshot
        for term in terms:
            if self._doc_freqs[term] > 1:
                self._doc_freqs[term] -= 1
            else:
                # Drop exhausted terms so ``unique_terms`` stays accurate.
                del self._doc_freqs[term]
        self._total_docs = max(0, self._total_docs - 1)
        self._total_tokens = max(0, self._total_tokens - doc_len)

    # --- Retrieval ---

    def retrieve(
        self,
        query: str,
        top_k: Optional[int] = None,
        strategy: RetrievalStrategy = RetrievalStrategy.HYBRID,
        filter_source: Optional[str] = None,
        min_importance: float = 0.0,
    ) -> list[RetrievalResult]:
        """
        Retrieve the most relevant memories for a query.

        Statistics for the call are stored and available via ``last_stats``.

        Args:
            query: Search query.
            top_k: Number of results to return; ``None`` or ``0`` selects
                the configured default.
            strategy: Retrieval strategy.
            filter_source: Only return entries from this source.
            min_importance: Minimum importance score filter.

        Returns:
            List of RetrievalResult sorted by relevance.
        """
        start = time.perf_counter()
        top_k = top_k or self._default_top_k

        # Filter entries
        candidates = [
            e for e in self._entries.values()
            if (filter_source is None or e.source == filter_source)
            and e.importance >= min_importance
        ]

        if not candidates:
            results: list[RetrievalResult] = []
        elif strategy == RetrievalStrategy.RECENT:
            results = self._retrieve_recent(candidates, top_k)
        elif strategy == RetrievalStrategy.KEYWORD:
            results = self._retrieve_keyword(query, candidates, top_k)
        elif strategy == RetrievalStrategy.SEMANTIC:
            results = self._retrieve_semantic(query, candidates, top_k)
        else:  # HYBRID
            results = self._retrieve_hybrid(query, candidates, top_k)

        # All results of one call share a strategy; it can differ from the
        # requested one when SEMANTIC falls back to keyword search.
        used = results[0].strategy if results else strategy
        self._last_stats = RetrievalStats(
            total_entries=len(self._entries),
            retrieved=len(results),
            strategies_used=[used],
            latency_ms=(time.perf_counter() - start) * 1000,
        )
        return results

    def _retrieve_recent(
        self, candidates: list[MemoryEntry], top_k: int,
    ) -> list[RetrievalResult]:
        """Return most recent entries sorted by timestamp."""
        newest_first = sorted(
            candidates,
            key=lambda e: e.timestamp or 0,
            reverse=True,
        )
        return [
            RetrievalResult(
                entry=e, score=1.0, strategy=RetrievalStrategy.RECENT,
            )
            for e in newest_first[:top_k]
        ]

    def _retrieve_keyword(
        self, query: str, candidates: list[MemoryEntry], top_k: int,
    ) -> list[RetrievalResult]:
        """BM25-style keyword search."""
        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []

        scored = []
        for entry in candidates:
            score = self._bm25_score(query_tokens, entry.content)
            if score >= self._min_keyword_score:
                scored.append((score, entry))

        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [
            RetrievalResult(
                entry=e, score=s, strategy=RetrievalStrategy.KEYWORD,
            )
            for s, e in scored[:top_k]
        ]

    def _retrieve_semantic(
        self, query: str, candidates: list[MemoryEntry], top_k: int,
    ) -> list[RetrievalResult]:
        """Cosine similarity semantic search.

        Falls back to keyword search when no embedder is configured.
        """
        if not self._embedder:
            return self._retrieve_keyword(query, candidates, top_k)

        query_embedding = np.array(self._embedder(query))
        scored = [
            (self._cosine_sim(query_embedding, self._entry_vector(entry)), entry)
            for entry in candidates
        ]

        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [
            RetrievalResult(
                entry=e, score=float(s), strategy=RetrievalStrategy.SEMANTIC,
            )
            for s, e in scored[:top_k]
        ]

    def _retrieve_hybrid(
        self, query: str, candidates: list[MemoryEntry], top_k: int,
    ) -> list[RetrievalResult]:
        """Weighted combination of semantic and keyword scores.

        Without an embedder the combined score is the keyword score alone.
        Only entries with a positive combined score are returned.
        """
        query_tokens = self._tokenize(query)
        has_embedder = self._embedder is not None

        if has_embedder:
            query_embedding = np.array(self._embedder(query))

        scored = []
        for entry in candidates:
            combined = self._bm25_score(query_tokens, entry.content)

            if has_embedder:
                sem_score = self._cosine_sim(
                    query_embedding, self._entry_vector(entry),
                )
                combined = (
                    self._hybrid_weight * sem_score
                    + (1 - self._hybrid_weight) * combined
                )

            if combined > 0:
                scored.append((combined, entry))

        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [
            RetrievalResult(
                entry=e, score=float(s), strategy=RetrievalStrategy.HYBRID,
            )
            for s, e in scored[:top_k]
        ]

    def _entry_vector(self, entry: MemoryEntry) -> np.ndarray:
        """Return the entry's embedding as an array, embedding it if missing.

        Only called when an embedder is configured.
        """
        if entry.embedding is None:
            entry.embedding = self._embedder(entry.content)
        return np.array(entry.embedding)

    # --- BM25 implementation ---

    @staticmethod
    def _tokenize(text: str) -> list[str]:
        """Simple word tokenizer.

        Lowercases the text, splits on every non-alphanumeric character and
        keeps only runs of two or more characters.
        """
        spaced = "".join(ch if ch.isalnum() else " " for ch in text.lower())
        return [tok for tok in spaced.split() if len(tok) >= 2]

    def _idf(self, term: str) -> float:
        """Inverse document frequency (cached until the index changes)."""
        if term not in self._idf_cache:
            df = self._doc_freqs.get(term, 0)
            if df == 0 or self._total_docs == 0:
                self._idf_cache[term] = 0.0
            else:
                # The "+ 1.0" inside the log keeps the value positive even
                # for terms present in most documents.
                self._idf_cache[term] = math.log(
                    (self._total_docs - df + 0.5) / (df + 0.5) + 1.0
                )
        return self._idf_cache[term]

    def _bm25_score(
        self, query_tokens: list[str], document: str,
        k1: float = 1.2, b: float = 0.75,
    ) -> float:
        """BM25 score for a document given query tokens.

        Args:
            query_tokens: Tokenized query; repeated tokens contribute once
                per occurrence.
            document: Raw document text (tokenized here).
            k1: Term-frequency saturation parameter.
            b: Length-normalization strength (0 disables it).

        Returns:
            The score rounded to six decimal places.
        """
        doc_tokens = self._tokenize(document)
        doc_len = len(doc_tokens)
        # Average document length in tokens over the indexed corpus; fall
        # back to 1.0 for an empty index to avoid dividing by zero.
        if self._total_docs and self._total_tokens:
            avg_doc_len = self._total_tokens / self._total_docs
        else:
            avg_doc_len = 1.0
        length_norm = k1 * (1 - b + b * doc_len / avg_doc_len)

        term_freqs = Counter(doc_tokens)
        score = 0.0

        for token in query_tokens:
            tf = term_freqs.get(token, 0)
            if tf == 0:
                continue
            score += self._idf(token) * tf * (k1 + 1) / (tf + length_norm)

        return round(score, 6)

    # --- Utilities ---

    @staticmethod
    def _cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity between two vectors (0.0 if either is zero)."""
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return float(np.dot(a, b) / (norm_a * norm_b))

    @property
    def entry_count(self) -> int:
        """Number of entries currently indexed."""
        return len(self._entries)

    @property
    def last_stats(self) -> Optional[RetrievalStats]:
        """Statistics of the most recent ``retrieve`` call, if any."""
        return self._last_stats

    def clear(self) -> None:
        """Drop all entries and reset every keyword statistic."""
        self._entries.clear()
        self._idf_cache.clear()
        self._doc_freqs.clear()
        self._indexed.clear()
        self._total_docs = 0
        self._total_tokens = 0
        self._last_stats = None

    def get_stats(self) -> dict[str, Any]:
        """Return index statistics."""
        return {
            "total_entries": len(self._entries),
            "total_docs": self._total_docs,
            "unique_terms": len(self._doc_freqs),
            "has_embedder": self._embedder is not None,
            "hybrid_weight": self._hybrid_weight,
        }