Coverage for src / kemi / decomposer.py: 98%

116 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

1"""Query decomposition and result fusion for improved recall. 

2 

3Breaks complex multi-aspect queries into simpler sub-queries, executes each 

4sub-query independently against the memory store, then fuses the ranked 

5results using Reciprocal Rank Fusion (RRF). 

6 

7RRF is parameter-free and works with any retrieval method (vector, keyword, 

8or hybrid). It uses only ranks (not scores) so it is robust to score scale 

9differences between retrieval methods. 

10 

Usage:
    decomposed = decompose_query(
        "What did I eat for breakfast and dinner yesterday?",
        strategy="simple",
    )
    # decomposed is a DecomposedQuery; decomposed.sub_queries ==
    # ["What did I eat for breakfast", "Tell me about dinner yesterday"]

    results = fused_recall(memory, user_id, decomposed.sub_queries, top_k=5)

20""" 

21 

22from __future__ import annotations 

23 

24import re 

25from dataclasses import dataclass 

26from typing import TYPE_CHECKING, Any 

27 

28from kemi.models import MemoryObject 

29 

30if TYPE_CHECKING: 

31 from kemi import Memory 

32 from kemi.adapters.base import EmbeddingAdapter 

33 

34__all__ = [ 

35 "decompose_query", 

36 "fused_recall", 

37 "QueryDecompositionStrategy", 

38 "DecomposedQuery", 

39 "FusionResult", 

40] 

41 

42 

43# --------------------------------------------------------------------------- 

44# Query decomposition strategies 

45# --------------------------------------------------------------------------- 

46 

47 

class QueryDecompositionStrategy:
    """Common interface for query decomposition strategies.

    A strategy turns one query string into a list of sub-query strings via
    :meth:`decompose`. This base class provides no behaviour of its own;
    calling :meth:`decompose` on it raises ``NotImplementedError``.
    """

    def decompose(self, query: str) -> list[str]:
        """Return the sub-queries derived from *query*.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

53 

54 

class SimpleDecomposition(QueryDecompositionStrategy):
    """Split a query into clauses at conjunctions.

    The query is cut at every word matched by ``CONJUNCTION_PATTERN``. Each
    non-empty clause is then handled as follows:

    * a clause that begins with a question word (``QUESTION_STARTS``) is kept
      as-is (only surrounding whitespace is removed);
    * any other clause goes through :meth:`_reconstruct_query`, which drops a
      trailing ``?`` and may prefix the clause with "Tell me about".

    Clauses shorter than two words are discarded. If that leaves nothing, the
    stripped original query is returned so a non-blank query is never lost.

    Examples:
        - "What did I eat for breakfast and dinner?"
          → ["What did I eat for breakfast", "Tell me about dinner"]
        - "Tell me about my work meetings and personal tasks"
          → ["Tell me about my work meetings", "Tell me about personal tasks"]
    """

    CONJUNCTION_PATTERN = re.compile(
        r"\b(?:and|or|but|however|additionally|also|plus|while)\b",
        re.IGNORECASE,
    )
    QUESTION_STARTS = {"what", "when", "where", "who", "whom", "whose", "why", "how", "which"}

    # Prefix used to turn a bare clause into a self-standing request.
    _PREFIX = "Tell me about"
    # First words for which a clause is returned without the prefix.
    _UNPREFIXED_STARTS = frozenset({"i", "my", "me", "the", "a", "an", "that", "this"})

    def decompose(self, query: str) -> list[str]:
        """Return the sub-queries obtained by splitting *query* at conjunctions.

        Args:
            query: The query to split.

        Returns:
            ``[]`` for an empty or whitespace-only query; otherwise a list with
            at least one entry. A query without any conjunction is returned
            stripped, as a single-element list.
        """
        if not query or not query.strip():
            return []

        original = query.strip()
        if not self.CONJUNCTION_PATTERN.search(query):
            return [original]

        sub_queries: list[str] = []
        for part in self.CONJUNCTION_PATTERN.split(query):
            clause = part.strip()
            if not clause:
                continue
            if self._starts_with_question_word(clause):
                sub_queries.append(clause)
            else:
                sub_queries.append(self._reconstruct_query(clause))

        # One-word fragments carry too little signal to be useful on their own.
        kept = [q for q in sub_queries if len(q.split()) >= 2]
        # Never turn a non-blank query into "no query at all".
        return kept or [original]

    def _starts_with_question_word(self, text: str) -> bool:
        """Return True when the first word of *text* is in ``QUESTION_STARTS``.

        The comparison is case-insensitive and ignores ``?`` characters at the
        end of that first word.
        """
        words = text.split()
        first_word = words[0].lower().rstrip("?") if words else ""
        return first_word in self.QUESTION_STARTS

    def _reconstruct_query(self, part: str) -> str:
        """Try to build a self-standing query from a clause.

        Trailing ``?`` characters are removed. The clause is prefixed with
        "Tell me about" unless it is empty, starts with one of
        ``_UNPREFIXED_STARTS``, or already starts with that prefix.
        """
        clause = part.strip().rstrip("?")
        words = part.split()
        first_word = words[0].lower() if words else ""
        if not first_word or first_word in self._UNPREFIXED_STARTS:
            return clause

        # Avoid producing "Tell me about Tell me about ...".
        prefix_words = self._PREFIX.lower().split()
        if clause.lower().split()[: len(prefix_words)] == prefix_words:
            return clause

        return f"{self._PREFIX} {clause}"

108 

109 

class SubqueryExpansion(QueryDecompositionStrategy):
    """Expand a query with synonyms and related terms to improve recall coverage.

    Generates query variants by substituting entries of ``SYNONYMS`` for
    whole-word occurrences of their key (no external library required). The
    original query is always the first entry of the result.
    """

    SYNONYMS: dict[str, list[str]] = {
        "eat": ["consume", "have", "dined", "food"],
        "breakfast": ["morning meal", "breakfast"],
        "dinner": ["evening meal", "supper", "dinner"],
        "lunch": ["midday meal", "lunch"],
        "work": ["job", "profession", "career", "task"],
        "meeting": ["discussion", "standup", "sync", "call"],
        "exercise": ["workout", "gym", "run", "fitness"],
        "travel": ["trip", "visit", "journey", "flight"],
        "buy": ["purchase", "shop", "acquire"],
        "learn": ["study", "understand", "discover", "explore"],
        "remember": ["recall", "note", "record"],
        "important": ["significant", "crucial", "priority"],
        "happy": ["glad", "pleased", "delighted", "joyful"],
        "sad": ["unhappy", "upset", "depressed", "melancholy"],
    }

    # At most this many synonyms are tried per matched term.
    SYNONYMS_PER_TERM = 2
    # Upper bound on the number of returned queries (original included).
    MAX_VARIANTS = 5

    def decompose(self, query: str) -> list[str]:
        """Return *query* followed by synonym-substituted variants.

        For each ``SYNONYMS`` key that occurs in *query* as a whole word
        (case-insensitive), the first occurrence is replaced by each of the
        first ``SYNONYMS_PER_TERM`` synonyms in turn. Variants that equal an
        already collected query, ignoring case, are skipped.

        Args:
            query: The query to expand.

        Returns:
            ``[]`` for an empty or whitespace-only query; otherwise up to
            ``MAX_VARIANTS`` stripped strings, the original first.
        """
        if not query or not query.strip():
            return []

        original = query.strip()
        variants = [original]
        # Case-insensitive dedup: a variant differing only in letter case
        # (e.g. "Breakfast" -> "breakfast") adds nothing to recall.
        seen = {original.lower()}
        lowered_query = query.lower()

        for term, synonyms in self.SYNONYMS.items():
            # Cheap substring pre-check before building the regex.
            if term.lower() not in lowered_query:
                continue
            whole_word = re.compile(r"\b" + re.escape(term) + r"\b", re.IGNORECASE)
            for synonym in synonyms[: self.SYNONYMS_PER_TERM]:
                variant = whole_word.sub(synonym, query, count=1).strip()
                key = variant.lower()
                if key in seen:
                    continue
                seen.add(key)
                variants.append(variant)

        return variants[: self.MAX_VARIANTS]

152 

153 

154# --------------------------------------------------------------------------- 

155# Public API 

156# --------------------------------------------------------------------------- 

157 

158 

@dataclass
class DecomposedQuery:
    """Outcome of decomposing one query.

    Attributes:
        strategy: Name of the decomposition strategy that was applied.
        sub_queries: The sub-queries produced by that strategy.
        original_query: The query exactly as it was passed in.
    """

    strategy: str
    sub_queries: list[str]
    original_query: str

165 

166 

@dataclass
class FusionResult:
    """One fused hit together with its RRF score and per-query ranks.

    Attributes:
        memory: The recalled memory object.
        rrf_score: The Reciprocal Rank Fusion score assigned to ``memory``.
        source_ranks: Maps a sub-query string to the rank ``memory`` had in
            that sub-query's result list.
    """

    memory: MemoryObject
    rrf_score: float
    source_ranks: dict[str, int]  # sub_query → rank in that result set

173 

174 

def decompose_query(
    query: str,
    strategy: str = "simple",
) -> DecomposedQuery:
    """Break a query into sub-queries using the named strategy.

    Args:
        query: The original search query.
        strategy: Which decomposition to apply:
            - "simple": use ``SimpleDecomposition`` (split at conjunctions).
            - "expand": use ``SubqueryExpansion`` (synonym variants).
            - "both": run both, concatenate simple results before expanded
              ones, drop case-insensitive duplicates, keep at most five.
            - "none": no decomposition; the only sub-query is the stripped
              original.

    Returns:
        A DecomposedQuery holding the strategy name, the sub-queries, and the
        query as given.

    Raises:
        ValueError: If ``strategy`` is not one of the four options above.
    """
    if strategy == "none":
        return DecomposedQuery(strategy="none", sub_queries=[query.strip()], original_query=query)

    if strategy == "both":
        simple_part = SimpleDecomposition().decompose(query)
        expanded_part = SubqueryExpansion().decompose(query)
        # Keyed by normalized text; setdefault keeps the first spelling seen,
        # and dict ordering preserves first-appearance order.
        first_by_key: dict[str, str] = {}
        for candidate in simple_part + expanded_part:
            first_by_key.setdefault(candidate.lower().strip(), candidate)
        return DecomposedQuery(
            strategy="both",
            sub_queries=list(first_by_key.values())[:5],
            original_query=query,
        )

    if strategy not in ("simple", "expand"):
        raise ValueError(
            f"Unknown decomposition strategy: {strategy!r}. "
            "Options: 'simple', 'expand', 'both', 'none'."
        )

    decomposer = SimpleDecomposition() if strategy == "simple" else SubqueryExpansion()
    return DecomposedQuery(
        strategy=strategy,
        sub_queries=decomposer.decompose(query),
        original_query=query,
    )

227 

228 

def fused_recall(
    memory: Memory,
    user_id: str,
    sub_queries: list[str],
    *,
    top_k: int = 5,
    rrf_k: int = 60,
    namespace: str = "default",
    session_id: str | None = None,
    lifecycle_filter: list | None = None,
    metadata_filter: dict[str, Any] | None = None,
) -> list[FusionResult]:
    """Execute multiple sub-queries and fuse results using RRF.

    Each sub-query is executed via ``memory.recall``. With two or more
    sub-queries, results are scored using Reciprocal Rank Fusion:

        RRF_score(d) = Σ 1 / (rrf_k + rank(d)_i)

    where rank(d)_i is the zero-based position of memory d in the i-th result
    list; lists that do not contain d contribute nothing. Memories are
    identified by their ``memory_id``.

    With exactly one sub-query there is nothing to fuse: hits are returned in
    recall order, each with ``rrf_score`` fixed at 1.0 and its zero-based
    position recorded in ``source_ranks``.

    Args:
        memory: Object providing ``recall``.
        user_id: User ID to recall memories for.
        sub_queries: List of sub-queries to execute.
        top_k: Number of results requested per sub-query (before fusion).
        rrf_k: Constant added to each rank in the RRF denominator.
        namespace: Forwarded to ``memory.recall``.
        session_id: Forwarded to ``memory.recall``.
        lifecycle_filter: Forwarded to ``memory.recall``.
        metadata_filter: Forwarded to ``memory.recall``.

    Returns:
        ``[]`` when ``sub_queries`` is empty. Otherwise a list of FusionResult
        objects; for multiple sub-queries it is sorted by unrounded RRF score,
        descending (ties keep first-seen order), and ``rrf_score`` is rounded
        to four decimals.
    """
    if not sub_queries:
        return []

    def _recall(sub_query: str) -> list[MemoryObject]:
        """Run one sub-query with the shared filter arguments."""
        return memory.recall(
            user_id,
            sub_query,
            top_k=top_k,
            namespace=namespace,
            session_id=session_id,
            lifecycle_filter=lifecycle_filter,
            metadata_filter=metadata_filter,
        )

    if len(sub_queries) == 1:
        only_query = sub_queries[0]
        # Record each hit's real position, matching the multi-query path.
        return [
            FusionResult(
                memory=hit,
                rrf_score=1.0,
                source_ranks={only_query: rank},
            )
            for rank, hit in enumerate(_recall(only_query))
        ]

    per_query_results = [_recall(sub_query) for sub_query in sub_queries]

    # memory_id -> {sub-query index -> rank in that sub-query's results}
    memory_ranks: dict[str, dict[int, int]] = {}
    memory_objects: dict[str, MemoryObject] = {}
    for sq_idx, hits in enumerate(per_query_results):
        for rank, hit in enumerate(hits):
            memory_objects[hit.memory_id] = hit
            memory_ranks.setdefault(hit.memory_id, {})[sq_idx] = rank

    rrf_scores = {
        mem_id: sum(1.0 / (rrf_k + rank) for rank in ranks.values())
        for mem_id, ranks in memory_ranks.items()
    }

    # sorted() is stable, so equal scores keep first-seen order.
    ordered_ids = sorted(rrf_scores, key=rrf_scores.__getitem__, reverse=True)

    return [
        FusionResult(
            memory=memory_objects[mem_id],
            rrf_score=round(rrf_scores[mem_id], 4),
            source_ranks={
                sub_queries[sq_idx]: rank for sq_idx, rank in memory_ranks[mem_id].items()
            },
        )
        for mem_id in ordered_ids
    ]

334 

335 

def rerank_with_reranker(
    memory: Memory,
    user_id: str,
    query: str,
    results: list[MemoryObject],
    *,
    provider: str = "cross-encoder",
    model: str | None = None,
) -> list[MemoryObject]:
    """Reserved hook for cross-encoder reranking; currently a pass-through.

    No reranking is implemented yet: every argument except ``results`` is
    ignored and ``results`` is handed back unchanged.

    Args:
        memory: Unused for now.
        user_id: Unused for now.
        query: Unused for now.
        results: MemoryObjects from the initial retrieval.
        provider: Unused for now.
        model: Unused for now.

    Returns:
        The very same ``results`` list, in its original order.
    """
    # Parameters are accepted for forward compatibility only.
    del memory, user_id, query, provider, model
    return results