Coverage for src / kemi / adaptive.py: 99%
153 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
"""Adaptive retrieval for kemi memory.

Auto-tunes hybrid search weights based on query characteristics.
Provides query analysis, classification, and dynamic weight adjustment.

Features:
- Query classification (factual, conversational, procedural, keyword-dense,
  temporal, comparative, ambiguous)
- Dynamic weight adjustment for semantic vs BM25 vs recency
- Query length impact assessment
- Per-user query-type feedback tracking
- Query specificity scoring

Usage:
    from kemi.adaptive import AdaptiveRetriever

    retriever = AdaptiveRetriever()
    weights = retriever.get_weights("What are my food preferences?")
    # weights is an AdaptiveWeights instance, e.g. (approximately)
    # weight_semantic=0.58, weight_recency=0.19, weight_bm25=0.22
    profile = retriever.analyze_query("What are my food preferences?")
    # profile is a QueryProfile; profile.recommended_weights holds the same
    # values as a dict
"""
21import re
22from dataclasses import dataclass, field
23from enum import Enum
24from typing import Any
class QueryType(Enum):
    """Categories a query can be assigned to by the adaptive retriever.

    Members, with representative queries:
        FACTUAL         -- "What is X?", "Who is Y?"
        CONVERSATIONAL  -- "How are you?", "Tell me about..."
        PROCEDURAL      -- "How do I...", "Steps to..."
        KEYWORD_DENSE   -- "dark mode preference vegetarian food"
        TEMPORAL        -- "What did I do yesterday?", "Last week's..."
        COMPARATIVE     -- "X vs Y", "better option"
        AMBIGUOUS       -- intent could not be determined
    """

    FACTUAL = "factual"
    CONVERSATIONAL = "conversational"
    PROCEDURAL = "procedural"
    KEYWORD_DENSE = "keyword_dense"
    TEMPORAL = "temporal"
    COMPARATIVE = "comparative"
    AMBIGUOUS = "ambiguous"
39# Keyword patterns for query classification
40_FACTUAL_PATTERNS = [
41 r"\bwhat (is|are|was|were)\b",
42 r"\bwho (is|are|was|were)\b",
43 r"\bwhen (is|was|did)\b",
44 r"\bwhere (is|are|was|were)\b",
45 r"\bwhich (is|are|was|were)\b",
46 r"\bdefine\b",
47 r"\bdefinition\b",
48 r"\bmeaning of\b",
49]
51_CONVERSATIONAL_PATTERNS = [
52 r"\bhow are you\b",
53 r"\btell me about\b",
54 r"\bcan you\b",
55 r"\bplease\b",
56 r"\bthanks?\b",
57 r"\bhelp me\b",
58 r"\bexplain\b",
59 r"\bdescribe\b",
60]
62_PROCEDURAL_PATTERNS = [
63 r"\bhow (do|can|would|should|to)\b",
64 r"\bsteps?\b",
65 r"\bguide\b",
66 r"\btutorial\b",
67 r"\bprocess\b",
68 r"\binstruction\b",
69 r"\bwalkthrough\b",
70]
72_TEMPORAL_PATTERNS = [
73 r"\b(yesterday|today|tomorrow)\b",
74 r"\blast (week|month|year|night|time)\b",
75 r"\bthis (week|month|year)\b",
76 r"\b(ago|recently|lately|earlier)\b",
77 r"\bwhen (did|was|were)\b",
78 r"\bwhat (happened|occurred)\b",
79]
81_COMPARATIVE_PATTERNS = [
82 r"\b(vs|versus|compared)\b",
83 r"\b(better|worse|best|worst)\b",
84 r"\b(difference|similar)\b",
85 r"\b(option|choice|alternative)\b",
86 r"\b(prefer|rather)\b",
87]
@dataclass
class QueryProfile:
    """Result of analyzing a single query.

    Attributes:
        query: The query text that was analyzed.
        query_type: The classification assigned to the query.
        word_count: Number of whitespace-separated words.
        keyword_density: Share of words that are not stop words (0.0-1.0).
        specificity: Heuristic score, 0.0 = vague up to 1.0 = highly specific.
        has_question_mark: True when the query ends with "?".
        has_named_entity_hint: True when the query contains a capitalized
            word or a number.
        recommended_weights: Retrieval weight name -> value.
        confidence: Confidence in ``query_type``.
    """

    query: str
    query_type: QueryType
    word_count: int
    keyword_density: float = 0.0
    specificity: float = 0.0
    has_question_mark: bool = False
    has_named_entity_hint: bool = False
    recommended_weights: dict[str, float] = field(default_factory=dict)
    confidence: float = 0.5
@dataclass
class AdaptiveWeights:
    """Retrieval weight configuration chosen for a query.

    The first three fields apply when embeddings are available. The
    ``*_no_embed`` fields and ``weight_importance`` form the alternative
    group used without embeddings. The defaults are the baseline
    configuration.

    Attributes:
        weight_semantic: Weight of the semantic-similarity signal.
        weight_recency: Weight of the recency signal.
        weight_bm25: Weight of the BM25 keyword signal.
        weight_semantic_no_embed: Semantic weight for the no-embedding group.
        weight_recency_no_embed: Recency weight for the no-embedding group.
        weight_importance: Importance weight for the no-embedding group.
        query_type: Classification the weights were derived from.
        analysis_confidence: Confidence of that classification.
    """

    weight_semantic: float = 0.6
    weight_recency: float = 0.25
    weight_bm25: float = 0.15
    weight_semantic_no_embed: float = 0.5
    weight_recency_no_embed: float = 0.3
    weight_importance: float = 0.2
    query_type: QueryType = QueryType.AMBIGUOUS
    analysis_confidence: float = 0.5
class AdaptiveRetriever:
    """Auto-tunes retrieval weights based on query characteristics.

    Uses heuristic analysis of the query text to determine the best
    hybrid search weight configuration. No ML models required.

    Limitations:
        - Classification is based on keyword/regex pattern matching and
          may misclassify unusual or ambiguous queries.
        - When no pattern matches, the query is classified AMBIGUOUS and
          starts from the AMBIGUOUS base weights (identical to the
          ``AdaptiveWeights`` defaults), which are then still fine-tuned.
        - For production use with very diverse query types, consider
          training a small classifier or using LLM-based classification.
    """

    # Stop words ignored when measuring keyword density and when deciding
    # whether a query is keyword-dense.
    _STOP_WORDS: set[str] = {
        "a", "an", "the", "is", "are", "was", "were", "be", "been", "being",
        "have", "has", "had", "do", "does", "did", "will", "would", "could",
        "should", "may", "might", "can", "shall", "to", "of", "in", "for",
        "on", "with", "at", "by", "from", "as", "into", "through", "during",
        "before", "after", "above", "below", "between", "and", "but", "or",
        "nor", "not", "so", "yet", "both", "either", "neither", "each",
        "every", "all", "any", "few", "more", "most", "other", "some", "such",
        "no", "only", "own", "same", "than", "too", "very", "just", "about",
        "how", "what", "which", "who", "whom", "whose", "why", "when",
        "where", "if", "then", "else", "that", "this", "these", "those", "it",
        "its", "he", "she", "they", "them", "their", "we", "you", "me", "my",
        "your", "our", "i", "him", "her", "us",
    }

    # Interrogatives whose presence rules out the KEYWORD_DENSE type.
    _QUESTION_WORDS: tuple[str, ...] = ("what", "who", "when", "where", "why", "how")

    # Regexes reused on every analysis call.
    _CAPITALIZED_RE = re.compile(r"[A-Z][a-z]{2,}")
    _DIGITS_RE = re.compile(r"\d+")
    _SPECIFICITY_MODIFIER_RE = re.compile(r"\b(exactly|specifically|precisely|particular)\b")
    # Leading/trailing non-word characters, e.g. the "?" in "you?".
    _EDGE_PUNCT_RE = re.compile(r"^\W+|\W+$")

    # Base weight configurations for each query type. The first three keys and
    # the last three keys each sum to 1.0.
    _TYPE_WEIGHTS: dict[QueryType, dict[str, float]] = {
        QueryType.FACTUAL: {
            "weight_semantic": 0.55,
            "weight_recency": 0.20,
            "weight_bm25": 0.25,
            "weight_semantic_no_embed": 0.45,
            "weight_recency_no_embed": 0.25,
            "weight_importance": 0.30,
        },
        QueryType.CONVERSATIONAL: {
            "weight_semantic": 0.70,
            "weight_recency": 0.20,
            "weight_bm25": 0.10,
            "weight_semantic_no_embed": 0.60,
            "weight_recency_no_embed": 0.25,
            "weight_importance": 0.15,
        },
        QueryType.PROCEDURAL: {
            "weight_semantic": 0.50,
            "weight_recency": 0.15,
            "weight_bm25": 0.35,
            "weight_semantic_no_embed": 0.40,
            "weight_recency_no_embed": 0.25,
            "weight_importance": 0.35,
        },
        QueryType.KEYWORD_DENSE: {
            "weight_semantic": 0.40,
            "weight_recency": 0.15,
            "weight_bm25": 0.45,
            "weight_semantic_no_embed": 0.35,
            "weight_recency_no_embed": 0.20,
            "weight_importance": 0.45,
        },
        QueryType.TEMPORAL: {
            "weight_semantic": 0.45,
            "weight_recency": 0.40,
            "weight_bm25": 0.15,
            "weight_semantic_no_embed": 0.35,
            "weight_recency_no_embed": 0.45,
            "weight_importance": 0.20,
        },
        QueryType.COMPARATIVE: {
            "weight_semantic": 0.60,
            "weight_recency": 0.15,
            "weight_bm25": 0.25,
            "weight_semantic_no_embed": 0.50,
            "weight_recency_no_embed": 0.20,
            "weight_importance": 0.30,
        },
        QueryType.AMBIGUOUS: {
            "weight_semantic": 0.60,
            "weight_recency": 0.25,
            "weight_bm25": 0.15,
            "weight_semantic_no_embed": 0.50,
            "weight_recency_no_embed": 0.30,
            "weight_importance": 0.20,
        },
    }

    # Weight groups renormalized independently in _adjust_weights.
    _EMBED_KEYS: tuple[str, ...] = ("weight_semantic", "weight_recency", "weight_bm25")
    _NO_EMBED_KEYS: tuple[str, ...] = (
        "weight_semantic_no_embed",
        "weight_recency_no_embed",
        "weight_importance",
    )

    def __init__(
        self,
        enable_adaptation: bool = True,
        feedback_weight: float = 0.1,
    ) -> None:
        """Initialize adaptive retriever.

        Args:
            enable_adaptation: If False, ``get_weights`` always returns the
                default ``AdaptiveWeights()``.
            feedback_weight: Intended strength of feedback-based adjustment,
                clamped to 0.0-1.0. NOTE: currently stored only; no method in
                this class applies it to the weights.
        """
        self._enable_adaptation = enable_adaptation
        self._feedback_weight = max(0.0, min(1.0, feedback_weight))
        # user_id -> {query_type.value: count}
        self._user_query_history: dict[str, dict[str, int]] = {}

    def analyze_query(self, query: str) -> QueryProfile:
        """Analyze a query and return its profile.

        Args:
            query: The search query string.

        Returns:
            QueryProfile with classification and recommended weights. The
            ``recommended_weights`` dict is always a fresh copy, so callers
            may mutate it without affecting the class-level weight table.
        """
        if not query or not query.strip():
            return QueryProfile(
                query="",
                query_type=QueryType.AMBIGUOUS,
                word_count=0,
                # Copy: handing out the shared table entry would let a caller
                # corrupt _TYPE_WEIGHTS for every later query.
                recommended_weights=dict(self._TYPE_WEIGHTS[QueryType.AMBIGUOUS]),
            )

        words = query.strip().split()
        word_count = len(words)

        query_type, confidence = self._classify_query(query.lower())
        keyword_density = self._compute_keyword_density(words)
        specificity = self._compute_specificity(query, words)

        # NOTE(review): a sentence-initial capital ("What ...") also matches,
        # so this is a loose hint rather than real named-entity detection.
        has_named_entity_hint = bool(
            self._CAPITALIZED_RE.search(query) or self._DIGITS_RE.search(query)
        )

        adjusted_weights = self._adjust_weights(
            dict(self._TYPE_WEIGHTS[query_type]),
            keyword_density,
            specificity,
            word_count,
        )

        return QueryProfile(
            query=query,
            query_type=query_type,
            word_count=word_count,
            keyword_density=keyword_density,
            specificity=specificity,
            has_question_mark=query.rstrip().endswith("?"),
            has_named_entity_hint=has_named_entity_hint,
            recommended_weights=adjusted_weights,
            confidence=confidence,
        )

    def get_weights(self, query: str) -> AdaptiveWeights:
        """Get adaptive retrieval weights for a query.

        This is the main entry point for integration with the recall pipeline.

        Args:
            query: The search query string.

        Returns:
            AdaptiveWeights with the recommended weight configuration, or the
            default ``AdaptiveWeights()`` when adaptation is disabled.
        """
        if not self._enable_adaptation:
            return AdaptiveWeights()

        profile = self.analyze_query(query)
        rw = profile.recommended_weights

        return AdaptiveWeights(
            weight_semantic=rw["weight_semantic"],
            weight_recency=rw["weight_recency"],
            weight_bm25=rw["weight_bm25"],
            weight_semantic_no_embed=rw["weight_semantic_no_embed"],
            weight_recency_no_embed=rw["weight_recency_no_embed"],
            weight_importance=rw["weight_importance"],
            query_type=profile.query_type,
            analysis_confidence=profile.confidence,
        )

    def record_feedback(
        self,
        user_id: str,
        query: str,
        profile: QueryProfile,
    ) -> None:
        """Count the profile's query type in this user's history.

        The counts are exposed through ``get_user_profile``; they are not
        currently fed back into weight selection.

        Args:
            user_id: User who made the query.
            query: The original query (accepted for API compatibility; unused).
            profile: The QueryProfile that was used.
        """
        counts = self._user_query_history.setdefault(user_id, {})
        qtype = profile.query_type.value
        counts[qtype] = counts.get(qtype, 0) + 1

    def get_user_profile(self, user_id: str) -> dict[str, Any]:
        """Get the query type distribution for a user.

        Args:
            user_id: User to get profile for.

        Returns:
            Dict with keys ``user_id``, ``total_queries``, ``distribution``
            (query type value -> fraction of queries) and ``dominant_type``
            (most frequent query type value, or None with no history).
        """
        history = self._user_query_history.get(user_id, {})
        total = sum(history.values())

        if total == 0:
            return {
                "user_id": user_id,
                "total_queries": 0,
                "distribution": {},
                "dominant_type": None,
            }

        return {
            "user_id": user_id,
            "total_queries": total,
            "distribution": {k: v / total for k, v in history.items()},
            "dominant_type": max(history, key=history.__getitem__),
        }

    @classmethod
    def _normalize_token(cls, word: str) -> str:
        """Lowercase *word* and strip leading/trailing punctuation ("you?" -> "you")."""
        return cls._EDGE_PUNCT_RE.sub("", word.lower())

    def _classify_query(self, query_lower: str) -> tuple[QueryType, float]:
        """Classify query into a type using pattern matching.

        Args:
            query_lower: The query, already lowercased.

        Returns:
            Tuple of (QueryType, confidence). Confidence is the winning
            type's share of all points, capped at 0.95. AMBIGUOUS with 0.3 is
            returned when nothing matches or the query is empty.
        """
        words = query_lower.split()
        if not words:
            return QueryType.AMBIGUOUS, 0.3

        # Insertion order matters: on a tie, max() keeps the earliest type.
        pattern_groups = (
            (QueryType.FACTUAL, _FACTUAL_PATTERNS),
            (QueryType.CONVERSATIONAL, _CONVERSATIONAL_PATTERNS),
            (QueryType.PROCEDURAL, _PROCEDURAL_PATTERNS),
            (QueryType.TEMPORAL, _TEMPORAL_PATTERNS),
            (QueryType.COMPARATIVE, _COMPARATIVE_PATTERNS),
        )
        scores: dict[QueryType, int] = {
            qtype: sum(1 for pattern in patterns if re.search(pattern, query_lower))
            for qtype, patterns in pattern_groups
        }

        # Keyword-dense: short, no interrogative, mostly content words.
        tokens = [self._normalize_token(w) for w in words]
        has_question_word = any(t in self._QUESTION_WORDS for t in tokens)
        if not has_question_word and len(words) <= 6:
            # Empty tokens (pure punctuation) are not content words.
            content_words = [t for t in tokens if t and t not in self._STOP_WORDS]
            if len(content_words) >= len(words) * 0.6:
                scores[QueryType.KEYWORD_DENSE] = 3 if len(content_words) >= 2 else 1

        max_score = max(scores.values())
        if max_score == 0:
            return QueryType.AMBIGUOUS, 0.3

        best_type = max(scores, key=scores.__getitem__)
        confidence = max_score / sum(scores.values())
        return best_type, min(confidence, 0.95)

    def _compute_keyword_density(self, words: list[str]) -> float:
        """Compute the ratio of content (non-stop) words to total words.

        Tokens are compared after lowercasing and stripping surrounding
        punctuation, so "you?" counts as the stop word "you".
        """
        if not words:
            return 0.0
        content_words = [
            t for t in map(self._normalize_token, words) if t and t not in self._STOP_WORDS
        ]
        return len(content_words) / len(words)

    def _compute_specificity(self, query: str, words: list[str]) -> float:
        """Estimate query specificity (0.0 = vague, 1.0 = highly specific).

        Factors:
            - Query length (longer = more specific): +0.1 / 0.3 / 0.5 / 0.7
            - Unique words ratio: up to +0.3
            - Presence of digits (+0.15) and capitalized words (+0.15)
            - Specificity modifiers ("exactly", "specifically", ...): +0.1
        """
        if not words:
            return 0.0

        n = len(words)
        if n <= 2:
            score = 0.1
        elif n <= 4:
            score = 0.3
        elif n <= 8:
            score = 0.5
        else:
            score = 0.7

        score += len({w.lower() for w in words}) / n * 0.3

        if self._DIGITS_RE.search(query):
            score += 0.15
        if self._CAPITALIZED_RE.search(query):
            score += 0.15
        if self._SPECIFICITY_MODIFIER_RE.search(query.lower()):
            score += 0.1

        return min(1.0, score)

    @staticmethod
    def _normalize_group(weights: dict[str, float], keys: tuple[str, ...]) -> None:
        """Rescale ``weights[k]`` for *keys* in place so they sum to ~1.0.

        Each value is rounded to 4 decimals; a group summing to zero is left
        unchanged.
        """
        total = sum(weights[k] for k in keys)
        if total > 0:
            scale = 1.0 / total
            for k in keys:
                weights[k] = round(weights[k] * scale, 4)

    def _adjust_weights(
        self,
        base_weights: dict[str, float],
        keyword_density: float,
        specificity: float,
        word_count: int,
    ) -> dict[str, float]:
        """Fine-tune weights based on query characteristics.

        Rules:
            - Higher keyword density -> boost BM25, reduce semantic
            - Higher specificity -> boost semantic (more precise semantic match)
            - Long queries (>= 10 words) -> boost BM25 (more keywords)
            - Very short queries (<= 2 words) -> boost semantic, trim BM25

        The embedding group is then renormalized to sum to ~1.0. The
        no-embedding group is not tuned by these rules; it is only
        renormalized the same way.

        Returns:
            A new dict; *base_weights* is not modified.
        """
        weights = dict(base_weights)

        # Keyword density adjustment, within +/-0.1.
        bm25_adjust = (keyword_density - 0.5) * 0.2
        weights["weight_bm25"] = max(0.05, min(0.55, weights["weight_bm25"] + bm25_adjust))
        weights["weight_semantic"] = max(
            0.30, min(0.80, weights["weight_semantic"] - bm25_adjust * 0.5)
        )

        # Specificity adjustment, within +/-0.05.
        sem_adjust = (specificity - 0.5) * 0.1
        weights["weight_semantic"] = max(0.30, min(0.80, weights["weight_semantic"] + sem_adjust))

        if word_count <= 2:
            weights["weight_semantic"] = min(0.80, weights["weight_semantic"] + 0.05)
            weights["weight_bm25"] = max(0.05, weights["weight_bm25"] - 0.03)
        elif word_count >= 10:
            weights["weight_bm25"] = min(0.55, weights["weight_bm25"] + 0.05)

        self._normalize_group(weights, self._EMBED_KEYS)
        self._normalize_group(weights, self._NO_EMBED_KEYS)
        return weights