Coverage for src / kemi / pipeline / retrieval.py: 95%
121 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
1"""Retrieval pipeline: turn a ``(user_id, query, ...)`` request into ranked results.
3Extracted from ``kemi._memory_impl``. The pipeline is a stateful
4object (``RetrievalPipeline``) that holds a :class:`RetrievalContext`
5with all its dependencies. It does not reference the ``Memory``
6class — everything it needs comes through the context, which keeps
7the pipeline independently testable.
9The pipeline owns the full recall flow: default resolution, query
10embedding, cache check, hook firing, storage search, metadata
11filtering, embedding-dimension check, entity extraction, scoring,
12MMR reranking, token truncation, lifecycle updates, metric
13increments, cache write, and adaptive retrieval feedback. The
14caller (:meth:`kemi._memory_impl.Memory.recall`) is responsible for
15input validation and latency tracking.
16"""
18from __future__ import annotations
20import logging
21from collections.abc import Callable
22from dataclasses import dataclass
23from datetime import datetime, timezone
24from typing import TYPE_CHECKING, Any
26from kemi import lifecycle, scoring
27from kemi.models import LifecycleState, MemoryConfig, MemoryObject
29if TYPE_CHECKING:
30 from kemi.adapters.base import EmbeddingAdapter, StorageAdapter
31 from kemi.entities import EntityLinker
# Module-level logger. In this module it is only used to record, at DEBUG
# level, failures of the best-effort adaptive-retrieval step.
logger: logging.Logger = logging.getLogger(__name__)
@dataclass
class RetrievalContext:
    """Bundle of collaborators that a single retrieval depends on.

    Every dependency is handed over explicitly: there is no module-level
    state and no reference back to the ``Memory`` class. The two
    side-effect hooks at the end are ordinary callables. The orchestrator
    can therefore close over its own state, while tests can simply keep
    the no-op defaults.
    """

    # Storage backend queried for candidate memories (search/count/update).
    store: "StorageAdapter"
    # Adapter that embeds the query and reports its vector dimension.
    embed: "EmbeddingAdapter"
    # Scoring weights, defaults and feature flags read by the pipeline.
    config: MemoryConfig
    # Extracts entities from the query and from memory contents.
    entity_linker: "EntityLinker"
    # Optional collaborators: the pipeline skips the related step when None.
    query_cache: Any | None
    metrics: Any | None
    adaptive_retriever: Any | None

    # Side-effect callbacks. The ``Memory`` orchestrator points these at
    # the implementations in ``kemi.operations._ops_*``; by default both
    # accept anything and do nothing.
    run_hooks: Callable[..., None] = lambda *_args, **_kwargs: None
    track_operation: Callable[..., None] = lambda *_args, **_kwargs: None
class RetrievalPipeline:
    """Encapsulates the recall flow.

    Public entry point is :meth:`retrieve` — the rest are private
    helpers that split the flow into testable steps. Every dependency
    (storage, embedder, cache, metrics, callbacks) is read from the
    :class:`RetrievalContext` given to the constructor.
    """

    # Candidate over-fetch factor while a metadata filter is active. The
    # filter runs after the storage search, so fetching extra candidates
    # raises the chance that ``top_k`` results survive it.
    _METADATA_FETCH_MULTIPLIER = 10
    # Over-fetch factor otherwise, so ranking and MMR reranking choose
    # among more than ``top_k`` candidates.
    _DEFAULT_FETCH_MULTIPLIER = 3
    # Passed to ``scoring.mmr_rerank`` as ``lambda_param``.
    _MMR_LAMBDA = 0.7

    def __init__(self, ctx: RetrievalContext) -> None:
        self._ctx = ctx

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def retrieve(
        self,
        user_id: str,
        query: str,
        top_k: int = 5,
        max_tokens: int | None = None,
        lifecycle_filter: list[LifecycleState] | None = None,
        hybrid_search: bool | None = None,
        namespace: str = "default",
        session_id: str | None = None,
        metadata_filter: dict[str, Any] | None = None,
    ) -> list[MemoryObject]:
        """Run the recall flow and return the top-k ranked results.

        Args:
            user_id: Owner of the memories to search.
            query: Free-text query. It is embedded for the storage search
                and also passed to the ranker and the entity linker.
            top_k: Maximum number of results to return.
            max_tokens: Token budget for the results. ``None`` falls back
                to ``config.max_tokens_default``; no truncation happens
                when that default is also ``None``.
            lifecycle_filter: Lifecycle states to search. ``None`` uses
                ``lifecycle.get_recall_filter()``.
            hybrid_search: Forwarded to the ranker. ``None`` uses
                ``config.hybrid_search``.
            namespace: Namespace to search within.
            session_id: Optional session restriction passed to storage.
            metadata_filter: Exact-match metadata constraints, applied
                after the storage search.

        Returns:
            At most ``top_k`` memories in ranked order. On a cache hit the
            cached list is returned directly. In that case no query
            embedding, hooks, lifecycle updates, metric updates or
            adaptive feedback run.

        Raises:
            ValueError: If a stored candidate was embedded with a
                different dimension than the current adapter produces.
        """
        if hybrid_search is None:
            hybrid_search = self._ctx.config.hybrid_search
        if lifecycle_filter is None:
            lifecycle_filter = lifecycle.get_recall_filter()

        cached = self._check_cache(
            user_id,
            query,
            top_k,
            max_tokens,
            lifecycle_filter,
            hybrid_search,
            namespace,
            session_id,
            metadata_filter,
        )
        if cached is not None:
            return cached

        # Embed only once a cache miss is known. Embedding can be costly
        # (e.g. a remote model call), and a cache hit never uses the vector.
        query_embedding = self._embed_query(query)

        self._ctx.run_hooks(
            "pre", "recall", user_id=user_id, query=query, namespace=namespace
        )

        search_results = self._search_storage(
            user_id=user_id,
            query_embedding=query_embedding,
            top_k=top_k,
            lifecycle_filter=lifecycle_filter,
            namespace=namespace,
            session_id=session_id,
            fetch_multiplier=self._fetch_multiplier(metadata_filter),
            metadata_filter=metadata_filter,
        )

        self._validate_embedding_dim(search_results)

        query_entities, memory_entities_map = self._build_entity_maps(
            query, search_results
        )

        ranked = self._rank(
            search_results=search_results,
            query_embedding=query_embedding,
            query=query,
            hybrid_search=hybrid_search,
            query_entities=query_entities,
            memory_entities_map=memory_entities_map,
        )
        ranked = self._mmr_rerank(ranked, query_embedding, top_k)

        # Only truncation falls back to the configured default; the cache
        # key below is built from the caller's original ``max_tokens``.
        effective_max_tokens = (
            max_tokens if max_tokens is not None else self._ctx.config.max_tokens_default
        )
        ranked = self._truncate(ranked, effective_max_tokens)

        final_results = ranked[:top_k]

        self._update_lifecycle(final_results)

        if self._ctx.metrics is not None:
            self._ctx.metrics.total_memories.set(self._ctx.store.count(user_id))

        self._cache_results(
            user_id=user_id,
            query=query,
            top_k=top_k,
            max_tokens=max_tokens,
            lifecycle_filter=lifecycle_filter,
            hybrid_search=hybrid_search,
            namespace=namespace,
            session_id=session_id,
            metadata_filter=metadata_filter,
            results=final_results,
        )

        self._ctx.run_hooks(
            "post",
            "recall",
            user_id=user_id,
            query=query,
            results=final_results,
            namespace=namespace,
        )
        self._ctx.track_operation(
            "recall",
            user_id,
            {"query": query, "results_count": len(final_results), "cache_hit": False},
            namespace=namespace,
        )
        self._adaptive_feedback(user_id, query)

        return final_results

    # ------------------------------------------------------------------
    # Pipeline steps
    # ------------------------------------------------------------------

    def _embed_query(self, query: str) -> list[float]:
        """Embed ``query`` with the context's embedding adapter."""
        return self._ctx.embed.embed_single(query)

    def _fetch_multiplier(self, metadata_filter: dict[str, Any] | None) -> int:
        """Return how many storage candidates to fetch per requested result.

        Only a non-empty metadata filter can discard candidates, so only
        that case gets the larger factor. ``None`` and ``{}`` constrain
        nothing and use the default.
        """
        if metadata_filter:
            return self._METADATA_FETCH_MULTIPLIER
        return self._DEFAULT_FETCH_MULTIPLIER

    def _cache_key(
        self,
        user_id: str,
        query: str,
        top_k: int,
        max_tokens: int | None,
        lifecycle_filter: list[LifecycleState],
        hybrid_search: bool,
        namespace: str,
        session_id: str | None,
        metadata_filter: dict[str, Any] | None,
    ) -> Any:
        """Build the query-cache key from every argument that shapes the result.

        Callers must have checked that ``query_cache`` is not ``None``.
        Lookup and write both go through here, so the two keys always agree.
        """
        return self._ctx.query_cache._make_key(
            user_id,
            query,
            top_k,
            max_tokens,
            lifecycle_filter,
            hybrid_search,
            namespace,
            session_id,
            metadata_filter,
        )

    def _check_cache(
        self,
        user_id: str,
        query: str,
        top_k: int,
        max_tokens: int | None,
        lifecycle_filter: list[LifecycleState],
        hybrid_search: bool,
        namespace: str,
        session_id: str | None,
        metadata_filter: dict[str, Any] | None,
    ) -> list[MemoryObject] | None:
        """Return cached results on a hit, else ``None``.

        A hit is reported through ``track_operation`` with
        ``cache_hit=True``. A disabled cache (``query_cache is None``)
        always misses.
        """
        if self._ctx.query_cache is None:
            return None
        cache_key = self._cache_key(
            user_id,
            query,
            top_k,
            max_tokens,
            lifecycle_filter,
            hybrid_search,
            namespace,
            session_id,
            metadata_filter,
        )
        cached = self._ctx.query_cache.get(cache_key)
        if cached is None:
            return None
        self._ctx.track_operation(
            "recall",
            user_id,
            {"query": query, "results_count": len(cached), "cache_hit": True},
            namespace=namespace,
        )
        return cached

    def _search_storage(
        self,
        user_id: str,
        query_embedding: list[float],
        top_k: int,
        lifecycle_filter: list[LifecycleState],
        namespace: str,
        session_id: str | None,
        fetch_multiplier: int,
        metadata_filter: dict[str, Any] | None,
    ) -> list[MemoryObject]:
        """Fetch ``top_k * fetch_multiplier`` candidates, then apply the metadata filter.

        A memory passes the filter when ``metadata.get(key) == value`` for
        every filter item. Because of ``dict.get``, a missing key compares
        as ``None``, so ``{"k": None}`` also matches memories that lack
        ``"k"``.
        """
        results = self._ctx.store.search(
            user_id=user_id,
            query_embedding=query_embedding,
            top_k=top_k * fetch_multiplier,
            lifecycle_filter=lifecycle_filter,
            namespace=namespace,
            session_id=session_id,
        )
        if not metadata_filter:
            # None or {} keeps every candidate; skip the per-item pass.
            return results
        wanted = tuple(metadata_filter.items())
        return [
            m for m in results if all(m.metadata.get(k) == v for k, v in wanted)
        ]

    def _validate_embedding_dim(self, search_results: list[MemoryObject]) -> None:
        """Raise if any candidate's embedding dimension differs from the current adapter's.

        Every candidate is checked, not only the first. A store that is
        only partly migrated could hold mixed dimensions, and such vectors
        cannot be compared with the query embedding. Candidates whose
        ``embedding_dim`` is ``None`` are skipped.

        Raises:
            ValueError: On the first mismatching candidate, suggesting
                ``memory.migrate``.
        """
        if not search_results:
            return
        current_dim = self._ctx.embed.dimension()
        for mem in search_results:
            stored_dim = mem.embedding_dim
            if stored_dim is not None and stored_dim != current_dim:
                raise ValueError(
                    "Embedding dimension mismatch: stored memories use "
                    f"{stored_dim} dimensions but current adapter produces "
                    f"{current_dim} dimensions. Run memory.migrate(user_id, "
                    "new_adapter) to re-embed your memories."
                )

    def _build_entity_maps(
        self,
        query: str,
        search_results: list[MemoryObject],
    ) -> tuple[set[str] | None, dict[str, set[str]] | None]:
        """Extract query entities and a per-memory entity map for the entity boost.

        Returns ``(None, None)`` when ``config.enable_entity_boost`` is off.
        If a memory's metadata already carries ``"extracted_entities"``,
        those are reused instead of running the linker on its content again.
        """
        if not self._ctx.config.enable_entity_boost:
            return None, None
        linker = self._ctx.entity_linker
        query_entities = linker.extract(query)
        memory_entities_map: dict[str, set[str]] = {}
        for m in search_results:
            precomputed = m.metadata.get("extracted_entities")
            if precomputed is not None:
                memory_entities_map[m.memory_id] = set(precomputed)
            else:
                memory_entities_map[m.memory_id] = linker.extract(m.content)
        return query_entities, memory_entities_map

    def _rank(
        self,
        search_results: list[MemoryObject],
        query_embedding: list[float],
        query: str,
        hybrid_search: bool,
        query_entities: set[str] | None,
        memory_entities_map: dict[str, set[str]] | None,
    ) -> list[MemoryObject]:
        """Score candidates with ``scoring.rank_memories`` using the configured weights."""
        cfg = self._ctx.config
        return scoring.rank_memories(
            search_results,
            query_embedding,
            query,
            hybrid_search,
            weight_semantic=cfg.weight_semantic,
            weight_recency=cfg.weight_recency,
            weight_bm25=cfg.weight_bm25,
            weight_semantic_no_embed=cfg.weight_semantic_no_embed,
            weight_recency_no_embed=cfg.weight_recency_no_embed,
            weight_importance=cfg.weight_importance,
            query_entities=query_entities,
            memory_entities_map=memory_entities_map,
            weight_entity=cfg.entity_boost_weight,
        )

    def _mmr_rerank(
        self,
        ranked: list[MemoryObject],
        query_embedding: list[float],
        top_k: int,
    ) -> list[MemoryObject]:
        """Rerank with ``scoring.mmr_rerank`` when there is a choice to make.

        Returns ``ranked`` unchanged when it holds no more than ``top_k``
        items or when at most one result is requested.
        """
        if len(ranked) <= top_k or top_k <= 1:
            return ranked
        return scoring.mmr_rerank(ranked, query_embedding, top_k, lambda_param=self._MMR_LAMBDA)

    def _truncate(
        self,
        ranked: list[MemoryObject],
        max_tokens: int | None,
    ) -> list[MemoryObject]:
        """Apply the token budget via ``scoring.truncate_by_tokens``; ``None`` means no budget."""
        if max_tokens is None:
            return ranked
        return scoring.truncate_by_tokens(ranked, max_tokens)

    def _update_lifecycle(self, results: list[MemoryObject]) -> None:
        """Stamp the returned memories as accessed and apply lifecycle transitions.

        All memories returned by one recall share a single access
        timestamp. A memory is written back to storage only when its
        lifecycle state changes.
        """
        threshold = self._ctx.config.decay_threshold_hours
        now = datetime.now(timezone.utc)
        for mem in results:
            # NOTE(review): the bumped ``last_accessed_at`` is persisted only
            # when a transition triggers ``store.update`` below — confirm
            # that is intended.
            mem.last_accessed_at = now
            new_state = lifecycle.evaluate_lifecycle(mem, threshold)
            if new_state != mem.lifecycle_state:
                # NOTE(review): storage receives ``updated``. If
                # ``lifecycle.transition`` returns a copy, the caller and
                # the query cache still hold ``mem`` with its old state.
                updated = lifecycle.transition(mem, new_state)
                self._ctx.store.update(updated)
                if self._ctx.metrics is not None:
                    self._ctx.metrics.lifecycle_transitions.inc(1)

    def _cache_results(
        self,
        user_id: str,
        query: str,
        top_k: int,
        max_tokens: int | None,
        lifecycle_filter: list[LifecycleState],
        hybrid_search: bool,
        namespace: str,
        session_id: str | None,
        metadata_filter: dict[str, Any] | None,
        results: list[MemoryObject],
    ) -> None:
        """Store ``results`` under the same key that :meth:`_check_cache` looks up."""
        if self._ctx.query_cache is None:
            return
        cache_key = self._cache_key(
            user_id,
            query,
            top_k,
            max_tokens,
            lifecycle_filter,
            hybrid_search,
            namespace,
            session_id,
            metadata_filter,
        )
        self._ctx.query_cache.put(cache_key, results)

    def _adaptive_feedback(self, user_id: str, query: str) -> None:
        """Feed the query to the adaptive retriever, if one is configured."""
        if self._ctx.adaptive_retriever is None:
            return
        try:
            profile = self._ctx.adaptive_retriever.analyze_query(query)
            self._ctx.adaptive_retriever.record_feedback(user_id, query, profile)
        except Exception:
            # Adaptive retrieval is best-effort. The recall results are
            # already computed and about to be returned, so a failure here
            # is logged and must not break the recall.
            logger.debug("Adaptive retrieval analysis failed", exc_info=True)
# Public API of this module; the pipeline's step methods are private.
__all__ = [
    "RetrievalContext",
    "RetrievalPipeline",
]