Coverage for src/kemi/decomposer.py: 98%
116 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
"""Query decomposition and result fusion for improved recall.

Breaks complex multi-aspect queries into simpler sub-queries, executes each
sub-query independently against the memory store, then fuses the ranked
results using Reciprocal Rank Fusion (RRF).

RRF needs only a single smoothing constant (``k``) and works with any
retrieval method (vector, keyword, or hybrid). It uses only ranks (not
scores), so it is robust to score-scale differences between retrieval
methods.

Usage:
    decomposed = decompose_query(
        "What did I eat for breakfast and dinner yesterday?",
        strategy="simple",
    )
    # decompose_query returns a DecomposedQuery; the strings to search for
    # live in its ``sub_queries`` attribute, e.g.
    # ["What did I eat for breakfast", "Tell me about dinner yesterday"]
    results = fused_recall(memory, user_id, decomposed.sub_queries, top_k=5)
"""
22from __future__ import annotations
24import re
25from dataclasses import dataclass
26from typing import TYPE_CHECKING, Any
28from kemi.models import MemoryObject
30if TYPE_CHECKING:
31 from kemi import Memory
32 from kemi.adapters.base import EmbeddingAdapter
# Public API of this module: the two entry points first, then the types
# they accept or return.
__all__ = [
    "decompose_query", "fused_recall",
    "QueryDecompositionStrategy", "DecomposedQuery", "FusionResult",
]
43# ---------------------------------------------------------------------------
44# Query decomposition strategies
45# ---------------------------------------------------------------------------
class QueryDecompositionStrategy:
    """Common interface for query decomposition strategies.

    A strategy maps one query string to a list of sub-query strings.
    Concrete strategies override :meth:`decompose`.
    """

    def decompose(self, query: str) -> list[str]:
        """Split *query* into sub-queries; subclasses must override this.

        Raises:
            NotImplementedError: Always, when called on the base class.
        """
        raise NotImplementedError()
class SimpleDecomposition(QueryDecompositionStrategy):
    """Split a query into clauses on conjunction-like words.

    The query is split wherever one of ``and``, ``or``, ``but``, ``however``,
    ``additionally``, ``also``, ``plus`` or ``while`` appears as a whole word
    (case-insensitive). Each non-empty clause then becomes a sub-query:

    - Clauses starting with a question word (``what``, ``when``, ...) or a
      request/auxiliary word (``tell``, ``show``, ``did``, ``is``, ...) are
      kept exactly as written.
    - Any other clause has its trailing ``?`` removed and, unless it starts
      with a pronoun or article (``i``, ``my``, ``the``, ...), is prefixed
      with ``"Tell me about "``.

    Sub-queries shorter than two words are dropped. If nothing survives, the
    stripped original query is returned on its own, so the query is never
    lost entirely.

    Examples:
        - "What did I eat for breakfast and dinner?"
          → ["What did I eat for breakfast", "Tell me about dinner"]
        - "Tell me about my work meetings and personal tasks"
          → ["Tell me about my work meetings", "Tell me about personal tasks"]
    """

    CONJUNCTION_PATTERN = re.compile(
        r"\b(?:and|or|but|however|additionally|also|plus|while)\b",
        re.IGNORECASE,
    )
    QUESTION_STARTS = {"what", "when", "where", "who", "whom", "whose", "why", "how", "which"}
    # Leading words that already make a clause a complete request. Such clauses
    # must not get a "Tell me about" prefix; otherwise "Tell me about X" would
    # become "Tell me about Tell me about X".
    REQUEST_STARTS = {
        "tell", "show", "list", "find", "give", "describe", "explain", "remind",
        "did", "do", "does", "is", "are", "was", "were", "can", "could",
        "should", "would", "will", "have", "has", "had",
    }
    # Leading words after which a fragment is kept without the prefix.
    _UNPREFIXED_STARTS = {"i", "my", "me", "the", "a", "an", "that", "this"}

    def decompose(self, query: str) -> list[str]:
        """Return the clause-level sub-queries of *query* (``[]`` if blank)."""
        if not query or not query.strip():
            return []

        stripped = query.strip()
        # No conjunction: the query is already a single clause.
        if not self.CONJUNCTION_PATTERN.search(query):
            return [stripped]

        sub_queries: list[str] = []
        for part in self.CONJUNCTION_PATTERN.split(query):
            cleaned = part.strip()
            if not cleaned:
                continue
            if (
                self._starts_with_question_word(cleaned)
                or self._starts_with_request_word(cleaned)
            ):
                sub_queries.append(cleaned)
            else:
                sub_queries.append(self._reconstruct_query(cleaned))

        # Single-word fragments are too vague to search on by themselves.
        kept = [q for q in sub_queries if len(q.split()) >= 2]
        # Fall back to the whole query rather than returning nothing.
        return kept or [stripped]

    def _first_word(self, text: str) -> str:
        """Lower-cased first word of *text* without a trailing ``?`` ("" if none)."""
        words = text.split()
        return words[0].lower().rstrip("?") if words else ""

    def _starts_with_question_word(self, text: str) -> bool:
        return self._first_word(text) in self.QUESTION_STARTS

    def _starts_with_request_word(self, text: str) -> bool:
        return self._first_word(text) in self.REQUEST_STARTS

    def _reconstruct_query(self, part: str) -> str:
        """Try to build a self-standing query from a bare clause fragment."""
        text = part.strip().rstrip("?")
        first_word = self._first_word(text)
        if first_word and first_word not in self._UNPREFIXED_STARTS:
            return f"Tell me about {text}"
        return text
class SubqueryExpansion(QueryDecompositionStrategy):
    """Produce synonym-substituted variants of a query to widen recall.

    For each known term that occurs in the query as a whole word
    (case-insensitive), its first occurrence is replaced by each of the
    term's first two synonyms. The stripped original query comes first.
    New variants follow in discovery order, without duplicates. At most
    five queries are returned. Uses a built-in synonym table only; no
    external library is required.
    """

    SYNONYMS: dict[str, list[str]] = {
        "eat": ["consume", "have", "dined", "food"],
        "breakfast": ["morning meal", "breakfast"],
        "dinner": ["evening meal", "supper", "dinner"],
        "lunch": ["midday meal", "lunch"],
        "work": ["job", "profession", "career", "task"],
        "meeting": ["discussion", "standup", "sync", "call"],
        "exercise": ["workout", "gym", "run", "fitness"],
        "travel": ["trip", "visit", "journey", "flight"],
        "buy": ["purchase", "shop", "acquire"],
        "learn": ["study", "understand", "discover", "explore"],
        "remember": ["recall", "note", "record"],
        "important": ["significant", "crucial", "priority"],
        "happy": ["glad", "pleased", "delighted", "joyful"],
        "sad": ["unhappy", "upset", "depressed", "melancholy"],
    }

    def decompose(self, query: str) -> list[str]:
        if not (query and query.strip()):
            return []

        variants = [query.strip()]
        haystack = query.lower()
        for term, alternatives in self.SYNONYMS.items():
            # Cheap substring pre-check; the word-boundary regex decides.
            if term.lower() not in haystack:
                continue
            whole_word = re.compile(rf"\b{re.escape(term)}\b", re.IGNORECASE)
            for alternative in alternatives[:2]:
                rewritten = whole_word.sub(alternative, query, count=1)
                if rewritten == query:
                    continue
                rewritten = rewritten.strip()
                if rewritten not in variants:
                    variants.append(rewritten)
        return variants[:5]
154# ---------------------------------------------------------------------------
155# Public API
156# ---------------------------------------------------------------------------
@dataclass()
class DecomposedQuery:
    """Outcome of decomposing one query.

    Attributes:
        strategy: Name of the strategy that produced ``sub_queries``.
        sub_queries: The sub-query strings to execute (may be empty).
        original_query: The query string exactly as it was passed in.
    """

    strategy: str
    sub_queries: list[str]
    original_query: str
@dataclass()
class FusionResult:
    """One memory after rank fusion across several sub-queries.

    Attributes:
        memory: The recalled memory object.
        rrf_score: The fused score assigned to this memory; higher is better.
        source_ranks: Maps each sub-query string to the memory's rank
            (0-based position) in that sub-query's result list.
    """

    memory: MemoryObject
    rrf_score: float
    source_ranks: dict[str, int]
def decompose_query(
    query: str,
    strategy: str = "simple",
) -> DecomposedQuery:
    """Decompose a complex query into simpler sub-queries.

    Args:
        query: The original search query (may contain multiple aspects/conjunctions).
        strategy: Decomposition strategy. Options:
            - "simple": Split on conjunction words (and, or, but, ...) via
              :class:`SimpleDecomposition`.
            - "expand": Generate synonym-substituted variants via
              :class:`SubqueryExpansion`.
            - "both": Run both strategies, drop case-insensitive duplicates,
              and keep at most 5 sub-queries.
            - "none": Return the stripped original query unchanged.

    Returns:
        A DecomposedQuery with the strategy used and the list of sub-queries.
        A blank query yields an empty ``sub_queries`` list for every strategy.

    Raises:
        ValueError: If ``strategy`` is not one of the options above.
    """
    if strategy == "none":
        stripped = query.strip() if query else ""
        # A blank query has nothing to search for. Match the other strategies,
        # which return no sub-queries for it, instead of yielding [""].
        return DecomposedQuery(
            strategy="none",
            sub_queries=[stripped] if stripped else [],
            original_query=query,
        )

    if strategy == "simple":
        sub_queries = SimpleDecomposition().decompose(query)
    elif strategy == "expand":
        sub_queries = SubqueryExpansion().decompose(query)
    elif strategy == "both":
        seen: set[str] = set()
        combined: list[str] = []
        candidates = SimpleDecomposition().decompose(query) + SubqueryExpansion().decompose(query)
        for q in candidates:
            normalized = q.lower().strip()
            if normalized not in seen:
                seen.add(normalized)
                combined.append(q)
        sub_queries = combined[:5]
    else:
        raise ValueError(
            f"Unknown decomposition strategy: {strategy!r}. "
            "Options: 'simple', 'expand', 'both', 'none'."
        )

    return DecomposedQuery(
        strategy=strategy,
        sub_queries=sub_queries,
        original_query=query,
    )
def fused_recall(
    memory: Memory,
    user_id: str,
    sub_queries: list[str],
    *,
    top_k: int = 5,
    rrf_k: int = 60,
    namespace: str = "default",
    session_id: str | None = None,
    lifecycle_filter: list | None = None,
    metadata_filter: dict[str, Any] | None = None,
) -> list[FusionResult]:
    """Execute multiple sub-queries and fuse their results using RRF.

    Each distinct sub-query is executed once via :meth:`Memory.recall`, and
    the resulting rankings are combined with Reciprocal Rank Fusion:

        RRF_score(d) = Σ_i 1 / (k + rank_i(d))

    Here k is ``rrf_k`` and rank_i(d) is the 0-based position of memory d in
    the i-th result list. Lists that do not contain d contribute nothing. A
    single sub-query goes through the same computation, so its results also
    carry real ranks and scores.

    Args:
        memory: A Memory instance.
        user_id: User ID to recall memories for.
        sub_queries: Sub-queries to execute. Exact duplicates are executed only
            once (first occurrence kept).
        top_k: Number of results to retrieve per sub-query (before fusion).
        rrf_k: Positive RRF constant (default 60). Larger values flatten the
            score gap between high and low ranks.
        namespace: Memory namespace.
        session_id: Optional session ID filter.
        lifecycle_filter: Optional lifecycle state filter.
        metadata_filter: Optional metadata filter.

    Returns:
        FusionResult objects sorted by RRF score, descending; ties keep
        first-seen order. ``rrf_score`` is rounded to 4 decimals, and
        ``source_ranks`` maps each sub-query to the memory's 0-based rank in
        that sub-query's results.

    Raises:
        ValueError: If ``rrf_k`` is not positive (checked only when there are
            sub-queries to run).
    """
    if not sub_queries:
        return []

    # Execute each distinct sub-query once, preserving order. A repeated query
    # would otherwise double-count its ranking, while ``source_ranks`` (keyed by
    # query text) could only record it once.
    queries = list(dict.fromkeys(sub_queries))

    if rrf_k <= 0:
        # rrf_k == 0 divides by zero for the top hit (rank 0); negative values
        # produce meaningless (negative or infinite) scores.
        raise ValueError(f"rrf_k must be positive, got {rrf_k!r}")

    # memory_id -> {query index -> best 0-based rank in that query's hits}
    memory_ranks: dict[str, dict[int, int]] = {}
    memory_objects: dict[str, MemoryObject] = {}
    for q_idx, sq in enumerate(queries):
        hits = memory.recall(
            user_id,
            sq,
            top_k=top_k,
            namespace=namespace,
            session_id=session_id,
            lifecycle_filter=lifecycle_filter,
            metadata_filter=metadata_filter,
        )
        for rank, mem in enumerate(hits):
            mem_id = mem.memory_id
            memory_objects[mem_id] = mem
            # setdefault keeps the first (best) rank if a memory repeats within
            # one result list.
            memory_ranks.setdefault(mem_id, {}).setdefault(q_idx, rank)

    rrf_scores = {
        mem_id: sum(1.0 / (rrf_k + rank) for rank in ranks.values())
        for mem_id, ranks in memory_ranks.items()
    }
    # sorted() is stable, so equal scores keep first-seen order.
    ranked_ids = sorted(rrf_scores, key=rrf_scores.__getitem__, reverse=True)

    return [
        FusionResult(
            memory=memory_objects[mem_id],
            rrf_score=round(rrf_scores[mem_id], 4),
            source_ranks={
                queries[q_idx]: rank for q_idx, rank in memory_ranks[mem_id].items()
            },
        )
        for mem_id in ranked_ids
    ]
def rerank_with_reranker(
    memory: Memory,
    user_id: str,
    query: str,
    results: list[MemoryObject],
    *,
    provider: str = "cross-encoder",
    model: str | None = None,
) -> list[MemoryObject]:
    """Placeholder hook for future cross-encoder reranking.

    The intended behavior is to re-order *results* by scoring (query,
    document) pairs jointly rather than independently. No reranker is
    implemented yet. Every argument other than *results* is ignored, and
    *results* itself is returned: the same list object, in its original
    order.

    Args:
        memory: Memory instance (intended for accessing the embed adapter; unused).
        user_id: User ID (for context; unused).
        query: The original query string (unused).
        results: List of MemoryObjects from initial retrieval.
        provider: Cross-encoder provider (future: "cross-encoder", "bge-reranker"; unused).
        model: Model name override (unused).

    Returns:
        The ``results`` argument, unchanged.
    """
    # Arguments are accepted for a future implementation but not used yet.
    del memory, user_id, query, provider, model
    return results