Coverage for src / kemi / pipeline / retrieval.py: 95%

121 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

1"""Retrieval pipeline: turn a ``(user_id, query, ...)`` request into ranked results. 

2 

3Extracted from ``kemi._memory_impl``. The pipeline is a stateful 

4object (``RetrievalPipeline``) that holds a :class:`RetrievalContext` 

5with all its dependencies. It does not reference the ``Memory`` 

6class — everything it needs comes through the context, which keeps 

7the pipeline independently testable. 

8 

9The pipeline owns the full recall flow: default resolution, query 

10embedding, cache check, hook firing, storage search, metadata 

11filtering, embedding-dimension check, entity extraction, scoring, 

12MMR reranking, token truncation, lifecycle updates, metric 

13increments, cache write, and adaptive retrieval feedback. The 

14caller (:meth:`kemi._memory_impl.Memory.recall`) is responsible for 

15input validation and latency tracking. 

16""" 

17 

18from __future__ import annotations 

19 

20import logging 

21from collections.abc import Callable 

22from dataclasses import dataclass 

23from datetime import datetime, timezone 

24from typing import TYPE_CHECKING, Any 

25 

26from kemi import lifecycle, scoring 

27from kemi.models import LifecycleState, MemoryConfig, MemoryObject 

28 

29if TYPE_CHECKING: 

30 from kemi.adapters.base import EmbeddingAdapter, StorageAdapter 

31 from kemi.entities import EntityLinker 

32 

# Module-level logger named after this module's import path; the
# adaptive-feedback step logs its best-effort failures here.
logger = logging.getLogger(name=__name__)

34 

35 

@dataclass
class RetrievalContext:
    """Bundle of every dependency a single retrieval needs.

    Everything is handed over explicitly: the pipeline reads no global
    state and never touches the ``Memory`` class. The two side-effect
    hooks (``run_hooks`` and ``track_operation``) are plain callables,
    so the orchestrator can plug in closures over its own state while
    isolated tests can keep the default no-ops.
    """

    # Core adapters and configuration.
    store: "StorageAdapter"
    embed: "EmbeddingAdapter"
    config: MemoryConfig
    entity_linker: "EntityLinker"

    # Optional collaborators; the pipeline skips the matching step
    # (caching, metrics, adaptive feedback) when one is ``None``.
    query_cache: Any | None
    metrics: Any | None
    adaptive_retriever: Any | None

    # Side-effect callbacks, wired by the ``Memory`` orchestrator to the
    # implementations in ``kemi.operations._ops_*``. The defaults accept
    # any arguments and do nothing.
    run_hooks: Callable[..., None] = lambda *_args, **_kwargs: None
    track_operation: Callable[..., None] = lambda *_args, **_kwargs: None

59 

60 

class RetrievalPipeline:
    """Encapsulates the recall flow.

    The public entry point is :meth:`retrieve`; the remaining methods are
    private helpers that split the flow into independently testable
    steps. Every dependency comes from the :class:`RetrievalContext`
    passed to the constructor.
    """

    # Over-fetch factor applied to ``top_k`` when a metadata filter is
    # active: filtering happens after the storage search, so more
    # candidates are fetched to still have ``top_k`` survivors.
    _METADATA_FETCH_MULTIPLIER = 10
    # Over-fetch factor without a metadata filter. Fetching more than
    # ``top_k`` also gives MMR reranking candidates to choose from
    # (``_mmr_rerank`` is a no-op when there are at most ``top_k``).
    _DEFAULT_FETCH_MULTIPLIER = 3
    # Passed to ``scoring.mmr_rerank`` as ``lambda_param``.
    _MMR_LAMBDA = 0.7

    def __init__(self, ctx: RetrievalContext) -> None:
        self._ctx = ctx

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def retrieve(
        self,
        user_id: str,
        query: str,
        top_k: int = 5,
        max_tokens: int | None = None,
        lifecycle_filter: list[LifecycleState] | None = None,
        hybrid_search: bool | None = None,
        namespace: str = "default",
        session_id: str | None = None,
        metadata_filter: dict[str, Any] | None = None,
    ) -> list[MemoryObject]:
        """Run the recall flow and return the top-k ranked results.

        Args:
            user_id: Owner whose memories are searched.
            query: Free-text query; it is embedded only on a cache miss.
            top_k: Maximum number of results returned.
            max_tokens: Token budget for the results; ``None`` falls back
                to ``config.max_tokens_default``.
            lifecycle_filter: Lifecycle states to search; ``None`` uses
                ``lifecycle.get_recall_filter()``.
            hybrid_search: Forwarded to the scorer; ``None`` uses
                ``config.hybrid_search``.
            namespace: Namespace forwarded to storage, hooks and tracking.
            session_id: Optional session scope forwarded to storage.
            metadata_filter: Exact-match metadata constraints applied
                after the storage search.

        Returns:
            Up to ``top_k`` memories. On a cache hit a copy of the cached
            list is returned, and hooks, lifecycle updates, metrics and
            adaptive feedback are skipped (only the operation is tracked).

        Raises:
            ValueError: If any retrieved memory records an embedding
                dimension different from the current adapter's.
        """
        if hybrid_search is None:
            hybrid_search = self._ctx.config.hybrid_search
        if lifecycle_filter is None:
            lifecycle_filter = lifecycle.get_recall_filter()

        # The cache key does not depend on the query embedding, so check
        # the cache first and skip the (potentially expensive) embedding
        # call entirely on a hit.
        cached = self._check_cache(
            user_id,
            query,
            top_k,
            max_tokens,
            lifecycle_filter,
            hybrid_search,
            namespace,
            session_id,
            metadata_filter,
        )
        if cached is not None:
            return cached

        query_embedding = self._embed_query(query)

        self._ctx.run_hooks(
            "pre", "recall", user_id=user_id, query=query, namespace=namespace
        )

        # Metadata filtering is applied post-hoc, so over-fetch more
        # aggressively when a filter is active to improve the odds of
        # still having top_k results after filtering.
        fetch_multiplier = (
            self._METADATA_FETCH_MULTIPLIER
            if metadata_filter is not None
            else self._DEFAULT_FETCH_MULTIPLIER
        )
        search_results = self._search_storage(
            user_id=user_id,
            query_embedding=query_embedding,
            top_k=top_k,
            lifecycle_filter=lifecycle_filter,
            namespace=namespace,
            session_id=session_id,
            fetch_multiplier=fetch_multiplier,
            metadata_filter=metadata_filter,
        )

        self._validate_embedding_dim(search_results)

        query_entities, memory_entities_map = self._build_entity_maps(query, search_results)

        ranked = self._rank(
            search_results=search_results,
            query_embedding=query_embedding,
            query=query,
            hybrid_search=hybrid_search,
            query_entities=query_entities,
            memory_entities_map=memory_entities_map,
        )
        ranked = self._mmr_rerank(ranked, query_embedding, top_k)

        effective_max_tokens = (
            max_tokens if max_tokens is not None else self._ctx.config.max_tokens_default
        )
        ranked = self._truncate(ranked, effective_max_tokens)

        final_results = ranked[:top_k]

        self._update_lifecycle(final_results)

        if self._ctx.metrics is not None:
            self._ctx.metrics.total_memories.set(self._ctx.store.count(user_id))

        self._cache_results(
            user_id=user_id,
            query=query,
            top_k=top_k,
            max_tokens=max_tokens,
            lifecycle_filter=lifecycle_filter,
            hybrid_search=hybrid_search,
            namespace=namespace,
            session_id=session_id,
            metadata_filter=metadata_filter,
            results=final_results,
        )

        self._ctx.run_hooks(
            "post",
            "recall",
            user_id=user_id,
            query=query,
            results=final_results,
            namespace=namespace,
        )
        self._ctx.track_operation(
            "recall",
            user_id,
            {"query": query, "results_count": len(final_results), "cache_hit": False},
            namespace=namespace,
        )
        self._adaptive_feedback(user_id, query)

        return final_results

    # ------------------------------------------------------------------
    # Pipeline steps
    # ------------------------------------------------------------------

    def _embed_query(self, query: str) -> list[float]:
        """Embed the query text with the context's embedding adapter."""
        return self._ctx.embed.embed_single(query)

    def _cache_key(
        self,
        user_id: str,
        query: str,
        top_k: int,
        max_tokens: int | None,
        lifecycle_filter: list[LifecycleState],
        hybrid_search: bool,
        namespace: str,
        session_id: str | None,
        metadata_filter: dict[str, Any] | None,
    ) -> Any:
        """Build the query-cache key for a request (requires a cache).

        Shared by the cache read and the cache write so both always agree
        on which request parameters participate in the key.
        """
        return self._ctx.query_cache._make_key(
            user_id,
            query,
            top_k,
            max_tokens,
            lifecycle_filter,
            hybrid_search,
            namespace,
            session_id,
            metadata_filter,
        )

    def _check_cache(
        self,
        user_id: str,
        query: str,
        top_k: int,
        max_tokens: int | None,
        lifecycle_filter: list[LifecycleState],
        hybrid_search: bool,
        namespace: str,
        session_id: str | None,
        metadata_filter: dict[str, Any] | None,
    ) -> list[MemoryObject] | None:
        """Return cached results on a hit, else ``None``.

        A hit is recorded via ``track_operation`` with ``cache_hit=True``.
        Returns ``None`` immediately when no cache is configured.
        """
        if self._ctx.query_cache is None:
            return None
        cache_key = self._cache_key(
            user_id,
            query,
            top_k,
            max_tokens,
            lifecycle_filter,
            hybrid_search,
            namespace,
            session_id,
            metadata_filter,
        )
        cached = self._ctx.query_cache.get(cache_key)
        if cached is None:
            return None
        self._ctx.track_operation(
            "recall",
            user_id,
            {"query": query, "results_count": len(cached), "cache_hit": True},
            namespace=namespace,
        )
        # Hand back a shallow copy so a caller that mutates its result
        # list cannot corrupt the cached entry.
        return list(cached)

    def _search_storage(
        self,
        user_id: str,
        query_embedding: list[float],
        top_k: int,
        lifecycle_filter: list[LifecycleState],
        namespace: str,
        session_id: str | None,
        fetch_multiplier: int,
        metadata_filter: dict[str, Any] | None,
    ) -> list[MemoryObject]:
        """Run the storage search and apply the metadata filter post-hoc.

        Requests ``top_k * fetch_multiplier`` candidates from storage, then
        keeps only memories whose metadata equals every filter value.
        """
        results = self._ctx.store.search(
            user_id=user_id,
            query_embedding=query_embedding,
            top_k=top_k * fetch_multiplier,
            lifecycle_filter=lifecycle_filter,
            namespace=namespace,
            session_id=session_id,
        )
        if metadata_filter is not None:
            results = [
                m
                for m in results
                if all(m.metadata.get(k) == v for k, v in metadata_filter.items())
            ]
        return results

    def _validate_embedding_dim(self, search_results: list[MemoryObject]) -> None:
        """Raise if any stored memory's embedding dimension differs from the adapter's.

        Memories whose ``embedding_dim`` is ``None`` are not checked.

        Raises:
            ValueError: On the first memory with a mismatching dimension.
        """
        if not search_results:
            return
        current_dim = self._ctx.embed.dimension()
        # Check every result, not just the first: with partially migrated
        # data the top hit may match while later ones do not, which would
        # make the outcome depend on result order.
        for mem in search_results:
            stored_dim = mem.embedding_dim
            if stored_dim is not None and stored_dim != current_dim:
                raise ValueError(
                    "Embedding dimension mismatch: stored memories use "
                    f"{stored_dim} dimensions but current adapter produces "
                    f"{current_dim} dimensions. Run memory.migrate(user_id, "
                    "new_adapter) to re-embed your memories."
                )

    def _build_entity_maps(
        self,
        query: str,
        search_results: list[MemoryObject],
    ) -> tuple[set[str] | None, dict[str, set[str]] | None]:
        """Extract query entities and a per-memory entity map.

        Returns ``(None, None)`` when ``config.enable_entity_boost`` is off.
        Entities cached in a memory's ``extracted_entities`` metadata are
        reused; otherwise they are extracted from the memory content.
        """
        if not self._ctx.config.enable_entity_boost:
            return None, None
        linker = self._ctx.entity_linker
        query_entities = linker.extract(query)
        memory_entities_map: dict[str, set[str]] = {}
        for m in search_results:
            cached = m.metadata.get("extracted_entities")
            if cached is not None:
                memory_entities_map[m.memory_id] = set(cached)
            else:
                memory_entities_map[m.memory_id] = linker.extract(m.content)
        return query_entities, memory_entities_map

    def _rank(
        self,
        search_results: list[MemoryObject],
        query_embedding: list[float],
        query: str,
        hybrid_search: bool,
        query_entities: set[str] | None,
        memory_entities_map: dict[str, set[str]] | None,
    ) -> list[MemoryObject]:
        """Score and order candidates via ``scoring.rank_memories`` using config weights."""
        config = self._ctx.config
        return scoring.rank_memories(
            search_results,
            query_embedding,
            query,
            hybrid_search,
            weight_semantic=config.weight_semantic,
            weight_recency=config.weight_recency,
            weight_bm25=config.weight_bm25,
            weight_semantic_no_embed=config.weight_semantic_no_embed,
            weight_recency_no_embed=config.weight_recency_no_embed,
            weight_importance=config.weight_importance,
            query_entities=query_entities,
            memory_entities_map=memory_entities_map,
            weight_entity=config.entity_boost_weight,
        )

    def _mmr_rerank(
        self,
        ranked: list[MemoryObject],
        query_embedding: list[float],
        top_k: int,
    ) -> list[MemoryObject]:
        """Apply MMR reranking, skipping it when there is nothing to diversify."""
        if len(ranked) <= top_k or top_k <= 1:
            return ranked
        return scoring.mmr_rerank(ranked, query_embedding, top_k, lambda_param=self._MMR_LAMBDA)

    def _truncate(
        self,
        ranked: list[MemoryObject],
        max_tokens: int | None,
    ) -> list[MemoryObject]:
        """Trim results to the token budget; ``None`` means no limit."""
        if max_tokens is None:
            return ranked
        return scoring.truncate_by_tokens(ranked, max_tokens)

    def _update_lifecycle(self, results: list[MemoryObject]) -> None:
        """Bump ``last_accessed_at`` and apply lifecycle transitions.

        Each memory whose evaluated state differs from its current one is
        transitioned, written back via ``store.update`` and counted in
        ``metrics.lifecycle_transitions`` when metrics are enabled.
        """
        threshold = self._ctx.config.decay_threshold_hours
        # One timestamp for the whole recall so every returned memory
        # records the same access time.
        now = datetime.now(timezone.utc)
        for mem in results:
            mem.last_accessed_at = now
            new_state = lifecycle.evaluate_lifecycle(mem, threshold)
            if new_state != mem.lifecycle_state:
                # NOTE(review): the new last_accessed_at is only persisted
                # here, i.e. when a transition happens; and if
                # ``transition`` returns a copy, ``results`` still holds the
                # pre-transition object. Confirm both are intended.
                updated = lifecycle.transition(mem, new_state)
                self._ctx.store.update(updated)
                if self._ctx.metrics is not None:
                    self._ctx.metrics.lifecycle_transitions.inc(1)

    def _cache_results(
        self,
        user_id: str,
        query: str,
        top_k: int,
        max_tokens: int | None,
        lifecycle_filter: list[LifecycleState],
        hybrid_search: bool,
        namespace: str,
        session_id: str | None,
        metadata_filter: dict[str, Any] | None,
        results: list[MemoryObject],
    ) -> None:
        """Store ``results`` in the query cache (no-op without a cache)."""
        if self._ctx.query_cache is None:
            return
        cache_key = self._cache_key(
            user_id,
            query,
            top_k,
            max_tokens,
            lifecycle_filter,
            hybrid_search,
            namespace,
            session_id,
            metadata_filter,
        )
        # Store a shallow copy: ``results`` is also returned to the caller,
        # and mutating that list must not alter the cached entry.
        self._ctx.query_cache.put(cache_key, list(results))

    def _adaptive_feedback(self, user_id: str, query: str) -> None:
        """Feed the query to the adaptive retriever, if one is configured."""
        if self._ctx.adaptive_retriever is None:
            return
        try:
            profile = self._ctx.adaptive_retriever.analyze_query(query)
            self._ctx.adaptive_retriever.record_feedback(user_id, query, profile)
        except Exception:
            # Adaptive retrieval is best-effort. Any failure here must
            # not break the recall response that is about to be returned.
            logger.debug("Adaptive retrieval analysis failed", exc_info=True)

395 

396 

# Public API of this module.
__all__ = [
    "RetrievalContext",
    "RetrievalPipeline",
]