Coverage for src / kemi / adaptive.py: 99%

153 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

"""Adaptive retrieval for kemi memory.

Auto-tunes hybrid search weights based on query characteristics.
Provides query analysis, classification, and dynamic weight adjustment.

Features:
- Query classification (factual, conversational, procedural, keyword-dense,
  temporal, comparative)
- Dynamic weight adjustment for semantic vs BM25 vs recency
- Query length impact assessment
- Per-user query-type feedback tracking
- Query specificity scoring

Usage:
    from kemi.adaptive import AdaptiveRetriever

    retriever = AdaptiveRetriever()
    profile = retriever.analyze_query("What are my food preferences?")
    # profile.recommended_weights is a dict of floats keyed by
    # "weight_semantic", "weight_recency", "weight_bm25", ...
    weights = retriever.get_weights("What are my food preferences?")
    # weights is an AdaptiveWeights instance carrying the same values
"""

20 

21import re 

22from dataclasses import dataclass, field 

23from enum import Enum 

24from typing import Any 

25 

26 

class QueryType(Enum):
    """Classification of query types for adaptive retrieval.

    Each member's value is the lowercase form of its name.
    """

    # "What is X?", "Who is Y?"
    FACTUAL = "factual"
    # "How are you?", "Tell me about..."
    CONVERSATIONAL = "conversational"
    # "How do I...", "Steps to..."
    PROCEDURAL = "procedural"
    # "dark mode preference vegetarian food"
    KEYWORD_DENSE = "keyword_dense"
    # "What did I do yesterday?", "Last week's..."
    TEMPORAL = "temporal"
    # "X vs Y", "better option"
    COMPARATIVE = "comparative"
    # Unclear query intent
    AMBIGUOUS = "ambiguous"

37 

38 

# Keyword patterns for query classification.
# Each list holds regular-expression source strings; a query type scores one
# point for every pattern in its list that matches the lower-cased query.

# Direct "wh-" questions and requests for a definition.
_FACTUAL_PATTERNS = [
    r"\bwhat (is|are|was|were)\b",
    r"\bwho (is|are|was|were)\b",
    r"\bwhen (is|was|did)\b",
    r"\bwhere (is|are|was|were)\b",
    r"\bwhich (is|are|was|were)\b",
    r"\bdefine\b",
    r"\bdefinition\b",
    r"\bmeaning of\b",
]

# Chatty phrasing and open-ended requests.
_CONVERSATIONAL_PATTERNS = [
    r"\bhow are you\b",
    r"\btell me about\b",
    r"\bcan you\b",
    r"\bplease\b",
    r"\bthanks?\b",
    r"\bhelp me\b",
    r"\bexplain\b",
    r"\bdescribe\b",
]

# "How do I ..." questions and step/guide vocabulary.
_PROCEDURAL_PATTERNS = [
    r"\bhow (do|can|would|should|to)\b",
    r"\bsteps?\b",
    r"\bguide\b",
    r"\btutorial\b",
    r"\bprocess\b",
    r"\binstruction\b",
    r"\bwalkthrough\b",
]

# References to points or spans in time.
# NOTE: "when did" / "when was" also match a pattern in _FACTUAL_PATTERNS,
# so such queries score for both types.
_TEMPORAL_PATTERNS = [
    r"\b(yesterday|today|tomorrow)\b",
    r"\blast (week|month|year|night|time)\b",
    r"\bthis (week|month|year)\b",
    r"\b(ago|recently|lately|earlier)\b",
    r"\bwhen (did|was|were)\b",
    r"\bwhat (happened|occurred)\b",
]

# Comparison and preference vocabulary.
_COMPARATIVE_PATTERNS = [
    r"\b(vs|versus|compared)\b",
    r"\b(better|worse|best|worst)\b",
    r"\b(difference|similar)\b",
    r"\b(option|choice|alternative)\b",
    r"\b(prefer|rather)\b",
]

88 

89 

@dataclass
class QueryProfile:
    """Analysis result for a query."""

    # Query text that was analysed.
    query: str
    # Category assigned to the query.
    query_type: QueryType
    # Number of whitespace-separated words in the query.
    word_count: int
    # Ratio of content words to total words.
    keyword_density: float = 0.0
    # 0.0 = vague, 1.0 = highly specific.
    specificity: float = 0.0
    has_question_mark: bool = False
    # Has capitalized words or numbers.
    has_named_entity_hint: bool = False
    # Weight name -> value; a fresh dict per instance.
    recommended_weights: dict[str, float] = field(default_factory=dict)
    # Confidence in the classification.
    confidence: float = 0.5

103 

104 

@dataclass
class AdaptiveWeights:
    """Dynamically computed retrieval weights.

    Holds two weight groups, each of whose defaults sums to 1.0:
    semantic / recency / BM25, and the ``*_no_embed`` pair plus importance.
    """

    # Group 1: semantic / recency / BM25.
    weight_semantic: float = 0.6
    weight_recency: float = 0.25
    weight_bm25: float = 0.15
    # Group 2: the "no_embed" variants plus importance.
    # NOTE(review): presumably used when no embedding is available; confirm
    # against the recall pipeline.
    weight_semantic_no_embed: float = 0.5
    weight_recency_no_embed: float = 0.3
    weight_importance: float = 0.2
    # Classification the weights were derived from.
    query_type: QueryType = QueryType.AMBIGUOUS
    # Confidence of that classification.
    analysis_confidence: float = 0.5

117 

118 

class AdaptiveRetriever:
    """Auto-tunes retrieval weights based on query characteristics.

    Uses heuristic analysis of the query text to determine the best
    hybrid search weight configuration. No ML models required.

    Limitations:
        - Classification is based on keyword/regex pattern matching and
          may misclassify unusual or ambiguous queries.
        - When no pattern matches, the query is classified as AMBIGUOUS and
          the AMBIGUOUS base weights are used as the starting point.
        - For production use with very diverse query types, consider
          training a small classifier or using LLM-based classification.
    """

    # Characters stripped from both ends of a token before stop-word and
    # question-word lookup, so that "you?" is recognised as "you".
    _EDGE_PUNCTUATION = ".,;:!?\"'()[]{}"

    # Words whose presence marks a query as having question structure.
    _QUESTION_WORDS = frozenset({"what", "who", "when", "where", "why", "how"})

    # Stop words to filter out for keyword density calculation
    _STOP_WORDS: set[str] = {
        # articles, auxiliaries and modals
        "a", "an", "the", "is", "are", "was", "were", "be", "been", "being",
        "have", "has", "had", "do", "does", "did", "will", "would", "could",
        "should", "may", "might", "can", "shall",
        # prepositions
        "to", "of", "in", "for", "on", "with", "at", "by", "from", "as",
        "into", "through", "during", "before", "after", "above", "below",
        "between",
        # conjunctions and quantifiers
        "and", "but", "or", "nor", "not", "so", "yet", "both", "either",
        "neither", "each", "every", "all", "any", "few", "more", "most",
        "other", "some", "such", "no", "only", "own", "same", "than", "too",
        "very", "just", "about",
        # question words and connectives
        "how", "what", "which", "who", "whom", "whose", "why", "when",
        "where", "if", "then", "else",
        # demonstratives and pronouns
        "that", "this", "these", "those", "it", "its", "he", "she", "they",
        "them", "their", "we", "you", "me", "my", "your", "our", "i", "him",
        "her", "us",
    }

    # Base weight configurations for each query type
    _TYPE_WEIGHTS: dict[QueryType, dict[str, float]] = {
        QueryType.FACTUAL: {
            "weight_semantic": 0.55,
            "weight_recency": 0.20,
            "weight_bm25": 0.25,
            "weight_semantic_no_embed": 0.45,
            "weight_recency_no_embed": 0.25,
            "weight_importance": 0.30,
        },
        QueryType.CONVERSATIONAL: {
            "weight_semantic": 0.70,
            "weight_recency": 0.20,
            "weight_bm25": 0.10,
            "weight_semantic_no_embed": 0.60,
            "weight_recency_no_embed": 0.25,
            "weight_importance": 0.15,
        },
        QueryType.PROCEDURAL: {
            "weight_semantic": 0.50,
            "weight_recency": 0.15,
            "weight_bm25": 0.35,
            "weight_semantic_no_embed": 0.40,
            "weight_recency_no_embed": 0.25,
            "weight_importance": 0.35,
        },
        QueryType.KEYWORD_DENSE: {
            "weight_semantic": 0.40,
            "weight_recency": 0.15,
            "weight_bm25": 0.45,
            "weight_semantic_no_embed": 0.35,
            "weight_recency_no_embed": 0.20,
            "weight_importance": 0.45,
        },
        QueryType.TEMPORAL: {
            "weight_semantic": 0.45,
            "weight_recency": 0.40,
            "weight_bm25": 0.15,
            "weight_semantic_no_embed": 0.35,
            "weight_recency_no_embed": 0.45,
            "weight_importance": 0.20,
        },
        QueryType.COMPARATIVE: {
            "weight_semantic": 0.60,
            "weight_recency": 0.15,
            "weight_bm25": 0.25,
            "weight_semantic_no_embed": 0.50,
            "weight_recency_no_embed": 0.20,
            "weight_importance": 0.30,
        },
        QueryType.AMBIGUOUS: {
            "weight_semantic": 0.60,
            "weight_recency": 0.25,
            "weight_bm25": 0.15,
            "weight_semantic_no_embed": 0.50,
            "weight_recency_no_embed": 0.30,
            "weight_importance": 0.20,
        },
    }

    def __init__(
        self,
        enable_adaptation: bool = True,
        feedback_weight: float = 0.1,
    ) -> None:
        """Initialize adaptive retriever.

        Args:
            enable_adaptation: If False, ``get_weights`` always returns
                default weights.
            feedback_weight: How much to adjust weights from feedback
                (0.0-1.0). Values outside that range are clamped. The value
                is stored but not read by any method of this class.
        """
        self._enable_adaptation = enable_adaptation
        self._feedback_weight = max(0.0, min(1.0, feedback_weight))
        # Track per-user query type distribution for better adaptation:
        # user_id -> {query type value -> number of recorded queries}
        self._user_query_history: dict[str, dict[str, int]] = {}

    def analyze_query(self, query: str) -> QueryProfile:
        """Analyze a query and return its profile.

        Args:
            query: The search query string.

        Returns:
            QueryProfile with classification and recommended weights. A
            blank query yields an AMBIGUOUS profile with an empty query
            string, zero words and the unadjusted AMBIGUOUS weights.
        """
        if not query or not query.strip():
            return QueryProfile(
                query="",
                query_type=QueryType.AMBIGUOUS,
                word_count=0,
                # Hand out a copy: returning the class-level dict itself
                # would let a caller mutate the shared defaults.
                recommended_weights=dict(self._TYPE_WEIGHTS[QueryType.AMBIGUOUS]),
            )

        words = query.split()
        word_count = len(words)

        query_type, confidence = self._classify_query(query.lower())
        keyword_density = self._compute_keyword_density(words)
        specificity = self._compute_specificity(query, words)

        # Named-entity hint: a capital followed by 2+ lowercase letters, or
        # any digit. A capitalised first word of the query also matches.
        has_named_entity_hint = bool(
            re.search(r"[A-Z][a-z]{2,}", query) or re.search(r"\d+", query)
        )

        # Start from a copy of the base weights for this type, then fine-tune.
        adjusted_weights = self._adjust_weights(
            dict(self._TYPE_WEIGHTS[query_type]),
            keyword_density,
            specificity,
            word_count,
        )

        return QueryProfile(
            query=query,
            query_type=query_type,
            word_count=word_count,
            keyword_density=keyword_density,
            specificity=specificity,
            has_question_mark=query.rstrip().endswith("?"),
            has_named_entity_hint=has_named_entity_hint,
            recommended_weights=adjusted_weights,
            confidence=confidence,
        )

    def get_weights(self, query: str) -> AdaptiveWeights:
        """Get adaptive retrieval weights for a query.

        This is the main entry point for integration with the recall pipeline.

        Args:
            query: The search query string.

        Returns:
            AdaptiveWeights with the recommended weight configuration, or a
            default-constructed AdaptiveWeights when adaptation is disabled.
        """
        if not self._enable_adaptation:
            return AdaptiveWeights()

        profile = self.analyze_query(query)
        recommended = profile.recommended_weights

        return AdaptiveWeights(
            weight_semantic=recommended["weight_semantic"],
            weight_recency=recommended["weight_recency"],
            weight_bm25=recommended["weight_bm25"],
            weight_semantic_no_embed=recommended["weight_semantic_no_embed"],
            weight_recency_no_embed=recommended["weight_recency_no_embed"],
            weight_importance=recommended["weight_importance"],
            query_type=profile.query_type,
            analysis_confidence=profile.confidence,
        )

    def record_feedback(
        self,
        user_id: str,
        query: str,
        profile: QueryProfile,
    ) -> None:
        """Record query type for this user to improve future adaptation.

        Increments the user's counter for ``profile.query_type``.

        Args:
            user_id: User who made the query.
            query: The original query. Accepted for interface
                compatibility; not used by this method.
            profile: The QueryProfile that was used.
        """
        counts = self._user_query_history.setdefault(user_id, {})
        qtype = profile.query_type.value
        counts[qtype] = counts.get(qtype, 0) + 1

    def get_user_profile(self, user_id: str) -> dict[str, Any]:
        """Get the query type distribution for a user.

        Args:
            user_id: User to get profile for.

        Returns:
            Dict with keys ``user_id``, ``total_queries``, ``distribution``
            (query type value -> fraction of the user's recorded queries)
            and ``dominant_type`` (the most frequent query type value, or
            None when nothing has been recorded for the user).
        """
        history = self._user_query_history.get(user_id, {})
        total = sum(history.values())

        if total == 0:
            return {
                "user_id": user_id,
                "total_queries": 0,
                "distribution": {},
                "dominant_type": None,
            }

        return {
            "user_id": user_id,
            "total_queries": total,
            "distribution": {qtype: count / total for qtype, count in history.items()},
            "dominant_type": max(history, key=history.get),
        }

    @classmethod
    def _is_content_word(cls, token: str) -> bool:
        """Return True if *token* is a non-empty, non-stop word.

        Surrounding punctuation is stripped and the token lower-cased
        before the stop-word lookup.
        """
        core = token.strip(cls._EDGE_PUNCTUATION).lower()
        return bool(core) and core not in cls._STOP_WORDS

    def _classify_query(self, query_lower: str) -> tuple[QueryType, float]:
        """Classify query into a type using pattern matching.

        Each pattern-based type scores one point per matching pattern. A
        short query (at most 6 words) without a question word, in which at
        least 60% of the words are content words, additionally scores as
        KEYWORD_DENSE (3 points with two or more content words, else 1).

        Args:
            query_lower: The query text, already lower-cased.

        Returns:
            Tuple of (QueryType, confidence). The highest-scoring type wins;
            ties go to the type listed first. Confidence is the winning
            score divided by the sum of all scores, capped at 0.95. When
            nothing scores, ``(QueryType.AMBIGUOUS, 0.3)`` is returned.
        """
        # Order matters: it is the tie-break order used by max() below.
        pattern_groups = (
            (QueryType.FACTUAL, _FACTUAL_PATTERNS),
            (QueryType.CONVERSATIONAL, _CONVERSATIONAL_PATTERNS),
            (QueryType.PROCEDURAL, _PROCEDURAL_PATTERNS),
            (QueryType.TEMPORAL, _TEMPORAL_PATTERNS),
            (QueryType.COMPARATIVE, _COMPARATIVE_PATTERNS),
        )
        scores: dict[QueryType, int] = {
            qtype: sum(1 for pattern in patterns if re.search(pattern, query_lower))
            for qtype, patterns in pattern_groups
        }

        # Check for keyword-dense: no question structure, short, many nouns.
        # Punctuation is stripped first so "how?" still counts as a question
        # word and "you?" still counts as a stop word.
        tokens = query_lower.split()
        has_question_word = any(
            token.strip(self._EDGE_PUNCTUATION) in self._QUESTION_WORDS
            for token in tokens
        )
        if tokens and not has_question_word and len(tokens) <= 6:
            content_count = sum(1 for token in tokens if self._is_content_word(token))
            if content_count >= len(tokens) * 0.6:
                scores[QueryType.KEYWORD_DENSE] = 3 if content_count >= 2 else 1

        # Find the highest scoring type
        best_type = max(scores, key=lambda qtype: scores[qtype])
        best_score = scores[best_type]
        if best_score == 0:
            return QueryType.AMBIGUOUS, 0.3

        # best_score > 0 guarantees a non-zero denominator.
        confidence = best_score / sum(scores.values())
        return best_type, min(confidence, 0.95)

    def _compute_keyword_density(self, words: list[str]) -> float:
        """Compute ratio of content words to total words.

        A word counts as content when, after stripping surrounding
        punctuation and lower-casing, it is non-empty and not a stop word.
        Returns 0.0 for an empty word list.
        """
        if not words:
            return 0.0
        content_count = sum(1 for word in words if self._is_content_word(word))
        return content_count / len(words)

    def _compute_specificity(self, query: str, words: list[str]) -> float:
        """Estimate query specificity (0.0 = vague, 1.0 = highly specific).

        Factors:
            - Query length (longer = more specific)
            - Unique words ratio
            - Presence of numbers and capitalized words
            - Specificity modifiers ("exactly", "specifically", "precisely",
              "particular")

        Args:
            query: The original query text (case preserved).
            words: The query split into words.

        Returns:
            Score in [0.0, 1.0]; 0.0 when *words* is empty.
        """
        if not words:
            return 0.0

        # Length factor
        n_words = len(words)
        if n_words <= 2:
            score = 0.1
        elif n_words <= 4:
            score = 0.3
        elif n_words <= 8:
            score = 0.5
        else:
            score = 0.7

        # Unique words ratio (case-insensitive)
        unique_ratio = len({word.lower() for word in words}) / n_words
        score += unique_ratio * 0.3

        # Numbers and proper nouns
        if re.search(r"\d+", query):
            score += 0.15
        if re.search(r"[A-Z][a-z]{2,}", query):
            score += 0.15

        # Specificity modifiers
        if re.search(r"\b(exactly|specifically|precisely|particular)\b", query.lower()):
            score += 0.1

        return min(1.0, score)

    @staticmethod
    def _rescale_to_unit_sum(weights: dict[str, float], keys: tuple[str, ...]) -> None:
        """Scale ``weights[k]`` for *keys* in place so they sum to ~1.0.

        Values are rounded to 4 decimals; nothing changes if the sum is
        not positive.
        """
        total = sum(weights[key] for key in keys)
        if total > 0:
            scale = 1.0 / total
            for key in keys:
                weights[key] = round(weights[key] * scale, 4)

    def _adjust_weights(
        self,
        base_weights: dict[str, float],
        keyword_density: float,
        specificity: float,
        word_count: int,
    ) -> dict[str, float]:
        """Fine-tune weights based on query characteristics.

        Rules:
            - Higher keyword density → boost BM25, reduce semantic
            - Higher specificity → boost semantic (more precise semantic match)
            - Longer queries → slight boost to BM25 (more keywords)
            - Very short queries → boost semantic (more likely conceptual)

        Args:
            base_weights: Starting weights; must contain all six weight
                keys. Not modified.
            keyword_density: Ratio of content words to total words.
            specificity: Specificity score of the query.
            word_count: Number of words in the query.

        Returns:
            A new dict in which the semantic/recency/BM25 group and the
            no-embed/importance group are each rescaled to sum to
            approximately 1.0 (values rounded to 4 decimals).
        """
        weights = dict(base_weights)

        # Keyword density adjustment (max ±0.1)
        # High keyword density = good for BM25 keyword matching
        bm25_adjust = (keyword_density - 0.5) * 0.2
        weights["weight_bm25"] = max(0.05, min(0.55, weights["weight_bm25"] + bm25_adjust))
        weights["weight_semantic"] = max(
            0.30, min(0.80, weights["weight_semantic"] - bm25_adjust * 0.5)
        )

        # Specificity adjustment
        # High specificity = better semantic matching possible
        sem_adjust = (specificity - 0.5) * 0.1
        weights["weight_semantic"] = max(0.30, min(0.80, weights["weight_semantic"] + sem_adjust))

        # Word count adjustment
        if word_count <= 2:
            # Very short: boost semantic slightly
            weights["weight_semantic"] = min(0.80, weights["weight_semantic"] + 0.05)
            weights["weight_bm25"] = max(0.05, weights["weight_bm25"] - 0.03)
        elif word_count >= 10:
            # Very long: boost BM25
            weights["weight_bm25"] = min(0.55, weights["weight_bm25"] + 0.05)

        # Ensure each weight group sums approximately to 1.0
        self._rescale_to_unit_sum(
            weights, ("weight_semantic", "weight_recency", "weight_bm25")
        )
        self._rescale_to_unit_sum(
            weights,
            ("weight_semantic_no_embed", "weight_recency_no_embed", "weight_importance"),
        )

        return weights