Coverage for src/kemi/scoring.py: 92%
181 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
1import math
2from collections.abc import Callable
3from datetime import datetime, timezone
5try: # pragma: no cover
6 import numpy as np
8 _NUMPY_AVAILABLE = True
9except ImportError:
10 _NUMPY_AVAILABLE = False # pragma: no cover
12from kemi.models import MemoryObject
def bm25_score(query: str, document: str) -> float:
    """Score *document* against *query* with a corpus-free BM25 variant.

    Both strings are lowercased and split on whitespace. With no corpus
    available, the document serves as its own average length, so the ``b``
    length-normalisation factor always evaluates to 1. Each matching query
    term therefore contributes ``tf * (k1 + 1) / (tf + k1)``. The total is
    divided by ``len(query_terms) * (k1 + 1) / k1`` and capped at 1.0.

    Args:
        query: Search query string.
        document: Document to score against.

    Returns:
        Keyword score in the range [0.0, 1.0]; 0.0 when either input is
        empty or whitespace-only, or when no query term occurs in the document.
    """
    if not query or not query.strip() or not document or not document.strip():
        return 0.0

    q_tokens = query.lower().split()
    d_tokens = document.lower().split()
    # Defensive only: the strip() checks above already rule out empty token lists.
    if not (q_tokens and d_tokens):
        return 0.0

    k1, b = 1.5, 0.75
    length = len(d_tokens)
    avg_length = max(length, 1)  # no corpus: the document is its own average

    counts: dict[str, int] = {}
    for tok in d_tokens:
        counts[tok] = counts.get(tok, 0) + 1

    # Identical for every term, so compute it once.
    length_norm = 1 - b + b * length / avg_length

    total = 0.0
    for tok in q_tokens:
        tf = counts.get(tok)
        if tf is None:
            continue
        total += (tf * (k1 + 1)) / (tf + k1 * length_norm)

    ceiling = len(q_tokens) * (k1 + 1) / k1
    return min(1.0, total / ceiling) if ceiling > 0 else total
def bm25_score_corpus(
    query: str,
    document: str,
    corpus: list[str],
    k1: float = 1.5,
    b: float = 0.75,
) -> float:
    """Compute an Okapi BM25 score for *document*, with IDF taken from *corpus*.

    Terms are lowercased and split on whitespace. Each query term contributes
    ``idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * len(document) / avgdl))``.
    The IDF is ``log(1 + (N - df + 0.5) / (df + 0.5))``, where ``N`` is
    ``len(corpus)``, ``df`` is the number of corpus documents containing the
    term, and ``avgdl`` is the mean corpus document length (1.0 if every
    corpus document is empty). A term repeated in the query contributes once
    per occurrence, but is counted only once per corpus document for ``df``.
    *document* does not need to be a member of *corpus*.

    Args:
        query: Search query string.
        document: Document to score against.
        corpus: List of document strings to compute IDF from.
        k1: Term frequency saturation parameter.
        b: Document length normalization parameter.

    Returns:
        Unnormalised BM25 score (non-negative for k1 > 0 and 0 <= b <= 1).
        Returns 0.0 for an empty or whitespace-only query or document. With an
        empty corpus, the result of :func:`bm25_score` is returned instead.
    """
    if not query or not query.strip():
        return 0.0

    if not document or not document.strip():
        return 0.0

    if not corpus:
        return bm25_score(query, document)

    query_terms = query.lower().split()
    doc_terms = document.lower().split()

    if not query_terms or not doc_terms:  # pragma: no cover (unreachable)
        return 0.0

    n_docs = len(corpus)  # > 0: the empty-corpus case returned above
    doc_length = len(doc_terms)

    # One pass over the corpus gathers both the total length (for avgdl) and
    # document frequencies. df is counted over the *distinct* query terms.
    # Iterating the raw query list would count a repeated word ("cat cat")
    # several times per document, letting df exceed n_docs and turning the
    # IDF negative.
    unique_query_terms = set(query_terms)
    total_length = 0
    df_counts: dict[str, int] = {}
    for doc in corpus:
        doc_words = doc.lower().split()
        total_length += len(doc_words)
        for term in unique_query_terms.intersection(doc_words):
            df_counts[term] = df_counts.get(term, 0) + 1

    avgdl: float = total_length / n_docs
    if avgdl == 0:
        avgdl = 1.0  # every corpus document is empty; avoid dividing by zero

    term_freqs: dict[str, int] = {}
    for term in doc_terms:
        term_freqs[term] = term_freqs.get(term, 0) + 1

    # Length normalisation depends only on the document, not on the term.
    length_norm = k1 * (1 - b + b * doc_length / avgdl)

    score = 0.0
    for query_term in query_terms:
        tf = term_freqs.get(query_term, 0)
        if tf == 0:
            # Contributes nothing; skipping also avoids 0/0 when k1 == 0.
            continue
        df = df_counts.get(query_term, 0)
        # The "+ 1" inside the log keeps IDF positive even for terms that
        # appear in every corpus document.
        idf = math.log((n_docs - df + 0.5) / (df + 0.5) + 1)
        tf_norm = (tf * (k1 + 1)) / (tf + length_norm)
        score += idf * tf_norm

    return score
142def cosine_similarity(a: list[float] | None, b: list[float] | None) -> float:
143 """Compute cosine similarity between two vectors.
145 Handles dimension mismatches by computing over the minimum dimension.
146 Returns 0.0 if either vector is None or empty to avoid division by zero.
147 Never returns NaN.
148 """
149 if a is None or b is None or not a or not b:
150 return 0.0
152 # Handle dimension mismatch gracefully — truncate to min dim so numpy
153 # doesn't raise ValueError on mismatched vector lengths.
154 min_dim = min(len(a), len(b))
156 if _NUMPY_AVAILABLE: # pragma: no cover
157 a_arr = np.array(a[:min_dim])
158 b_arr = np.array(b[:min_dim])
159 norm = np.linalg.norm(a_arr) * np.linalg.norm(b_arr)
160 return float(np.dot(a_arr, b_arr) / norm) if norm != 0 else 0.0
162 dot_product = 0.0
163 norm_a = 0.0
164 norm_b = 0.0
166 for i in range(min_dim):
167 dot_product += a[i] * b[i]
168 norm_a += a[i] * a[i]
169 norm_b += b[i] * b[i]
171 norm_a = norm_a**0.5
172 norm_b = norm_b**0.5
174 if norm_a == 0.0 or norm_b == 0.0:
175 return 0.0
177 return dot_product / (norm_a * norm_b) # type: ignore[no-any-return]
def temporal_recency(last_accessed: datetime, half_life_hours: float = 168.0) -> float:
    """Compute a temporal recency score using exponential (half-life) decay.

    The score is ``2 ** (-hours_elapsed / half_life_hours)``:

    * accessed now (or with a future timestamp) -> 1.0
    * accessed ``half_life_hours`` ago -> 0.5
    * accessed ``2 * half_life_hours`` ago -> 0.25

    Args:
        last_accessed: When the memory was last accessed. Timezone-aware
            values are compared against the current UTC time. Naive values
            (no UTC offset) are interpreted as UTC; previously they raised
            ``TypeError`` on the aware/naive subtraction.
        half_life_hours: Hours for the score to halve (default 168 = 7 days).

    Returns:
        A float in [0.0, 1.0] for positive ``half_life_hours``.
    """
    if last_accessed.utcoffset() is None:
        # NOTE(review): assumes naive timestamps are UTC (as produced by
        # datetime.utcnow()); confirm against how last-access times are stored.
        last_accessed = last_accessed.replace(tzinfo=timezone.utc)

    now = datetime.now(timezone.utc)
    hours_elapsed = (now - last_accessed).total_seconds() / 3600.0

    if hours_elapsed <= 0:
        return 1.0

    return 2.0 ** (-hours_elapsed / half_life_hours)  # type: ignore[no-any-return]
def jaccard_similarity(a: set[str], b: set[str]) -> float:
    """Return len(a & b) / len(a | b) for two sets of strings.

    If either set is empty the result is 0.0, so two empty sets count as
    completely dissimilar rather than triggering a division by zero.
    """
    if a and b:
        shared = a & b
        combined = a | b  # non-empty because both operands are
        return len(shared) / len(combined)
    return 0.0
def score_memory(
    memory: MemoryObject,
    query_embedding: list[float],
    query: str | None = None,
    hybrid_search: bool = True,
    corpus: list[str] | None = None,
    weight_semantic: float = 0.6,
    weight_recency: float = 0.25,
    weight_bm25: float = 0.15,
    weight_semantic_no_embed: float = 0.5,
    weight_recency_no_embed: float = 0.3,
    weight_importance: float = 0.2,
    query_entities: set[str] | None = None,
    memory_entities: set[str] | None = None,
    weight_entity: float = 0.1,
) -> float:
    """Compute the final relevance score for a memory.

    Hybrid mode (``hybrid_search=True`` and a non-empty ``query``)::

        semantic * weight_semantic + recency * weight_recency + bm25 * weight_bm25

    Otherwise::

        semantic * weight_semantic_no_embed
        + recency * weight_recency_no_embed
        + clamp(memory.importance, 0.0, 1.0) * weight_importance

    When both entity sets are provided, an extra
    ``jaccard(query_entities, memory_entities) * weight_entity`` is added.

    ``semantic`` is the cosine similarity rescaled from [-1, 1] to [0, 1].
    It is 0.0 when ``memory.embedding`` or ``query_embedding`` is None or
    empty. ``recency`` comes from :func:`temporal_recency` applied to
    ``memory.last_accessed_at``. BM25 uses :func:`bm25_score_corpus` when
    ``corpus`` has more than one entry and :func:`bm25_score` otherwise.

    Args:
        memory: The memory object to score.
        query_embedding: Embedding vector for semantic search.
        query: Optional query string for keyword search.
        hybrid_search: Use hybrid scoring (default True).
        corpus: List of document strings to compute IDF from for BM25.
        weight_semantic: Weight for semantic similarity in hybrid mode (default 0.6).
        weight_recency: Weight for recency in hybrid mode (default 0.25).
        weight_bm25: Weight for BM25 keyword match in hybrid mode (default 0.15).
        weight_semantic_no_embed: Weight for semantic similarity when hybrid
            scoring is off or no query is given (default 0.5).
        weight_recency_no_embed: Weight for recency when hybrid scoring is off
            or no query is given (default 0.3).
        weight_importance: Weight for importance when hybrid scoring is off or
            no query is given (default 0.2).
        query_entities: Optional set of entities extracted from the query.
        memory_entities: Optional set of entities extracted from the memory content.
        weight_entity: Weight for entity overlap boost (default 0.1).

    Returns:
        The weighted score. It is not clamped, so the entity boost can push it
        above the sum of the base weights.
    """
    semantic_score = 0.0
    # Truthiness (not just `is not None`) so empty embeddings count as missing.
    # cosine_similarity() returns 0.0 for them, and the [-1, 1] -> [0, 1]
    # rescale below would otherwise turn that into a spurious 0.5.
    if memory.embedding and query_embedding:
        similarity = cosine_similarity(memory.embedding, query_embedding)
        semantic_score = (similarity + 1.0) / 2.0

    recency_score = temporal_recency(memory.last_accessed_at)

    if hybrid_search and query:
        # IDF needs at least two documents to be informative.
        if corpus and len(corpus) > 1:
            bm25_keyword_score = bm25_score_corpus(query, memory.content, corpus)
        else:
            bm25_keyword_score = bm25_score(query, memory.content)

        final_score = (
            semantic_score * weight_semantic
            + recency_score * weight_recency
            + bm25_keyword_score * weight_bm25
        )
    else:
        importance_score = max(0.0, min(1.0, memory.importance))
        final_score = (
            semantic_score * weight_semantic_no_embed
            + recency_score * weight_recency_no_embed
            + importance_score * weight_importance
        )

    # Entity-aware boost
    if query_entities is not None and memory_entities is not None:
        entity_score = jaccard_similarity(query_entities, memory_entities)
        final_score += entity_score * weight_entity

    return final_score
def rank_memories(
    memories: list[MemoryObject],
    query_embedding: list[float],
    query: str | None = None,
    hybrid_search: bool = True,
    weight_semantic: float = 0.6,
    weight_recency: float = 0.25,
    weight_bm25: float = 0.15,
    weight_semantic_no_embed: float = 0.5,
    weight_recency_no_embed: float = 0.3,
    weight_importance: float = 0.2,
    query_entities: set[str] | None = None,
    memory_entities_map: dict[str, set[str]] | None = None,
    weight_entity: float = 0.1,
) -> list[MemoryObject]:
    """Score every memory and return them ordered from highest to lowest score.

    Side effect: the ``score`` attribute of each MemoryObject is overwritten
    in place. The returned list is a new list. The sort is stable, so
    memories with equal scores keep their input order.

    When there are at least two memories, their contents form the BM25 corpus
    used for IDF weighting.

    Args:
        memories: List of MemoryObjects to rank.
        query_embedding: Embedding vector for semantic search.
        query: Optional query string for keyword search.
        hybrid_search: Use hybrid scoring (default True).
        weight_semantic: Weight for semantic similarity in hybrid mode.
        weight_recency: Weight for recency in hybrid mode.
        weight_bm25: Weight for BM25 keyword match in hybrid mode.
        weight_semantic_no_embed: Weight for semantic in non-hybrid scoring.
        weight_recency_no_embed: Weight for recency in non-hybrid scoring.
        weight_importance: Weight for importance in non-hybrid scoring.
        query_entities: Optional set of entities extracted from the query.
        memory_entities_map: Optional dict mapping memory_id -> set of entities.
        weight_entity: Weight for entity overlap boost.
    """
    corpus = None
    if len(memories) > 1:
        corpus = [item.content for item in memories]

    for item in memories:
        entities = (
            memory_entities_map.get(item.memory_id)
            if memory_entities_map is not None
            else None
        )
        item.score = score_memory(
            item,
            query_embedding,
            query=query,
            hybrid_search=hybrid_search,
            corpus=corpus,
            weight_semantic=weight_semantic,
            weight_recency=weight_recency,
            weight_bm25=weight_bm25,
            weight_semantic_no_embed=weight_semantic_no_embed,
            weight_recency_no_embed=weight_recency_no_embed,
            weight_importance=weight_importance,
            query_entities=query_entities,
            memory_entities=entities,
            weight_entity=weight_entity,
        )

    ranked = list(memories)
    ranked.sort(key=lambda item: item.score, reverse=True)
    return ranked
def mmr_rerank(
    memories: list[MemoryObject],
    query_embedding: list[float],
    top_k: int,
    lambda_param: float = 0.7,
) -> list[MemoryObject]:
    """Rerank memories using Maximal Marginal Relevance.

    Balances relevance (the precomputed ``memory.score``) against diversity
    (dissimilarity to memories already selected).

    lambda_param controls the tradeoff:
        1.0 = pure relevance (same as no MMR)
        0.0 = pure diversity
        0.7 = default, slightly favors relevance

    Algorithm:
        - Start with an empty selected list.
        - At each step, pick the remaining candidate that maximizes
          lambda * memory.score - (1 - lambda) * max_similarity_to_selected,
          where max_similarity_to_selected is the largest cosine similarity
          between the candidate's embedding and any selected memory's
          embedding (floored at 0.0).
        - Candidates without an embedding, or any candidate when
          query_embedding is empty, get a similarity penalty of 0.0.
        - Ties go to the earliest candidate in input order.
        - Stop when top_k memories are selected or candidates are exhausted.

    Args:
        memories: Candidates with ``score`` already computed.
        query_embedding: Query vector; only its emptiness affects selection.
        top_k: Maximum number of memories to return; [] if <= 0.
        lambda_param: Relevance/diversity trade-off.

    Returns:
        The selected memories, in the order they were selected.
    """
    # The selection loop lives in mmr_rerank_stream; collecting its output here
    # keeps the eager and streaming variants from drifting apart.
    return list(mmr_rerank_stream(memories, query_embedding, top_k, lambda_param))
416def _default_token_counter(text: str) -> int:
417 """Default token counter: rough estimate = word_count * 1.3"""
418 result: float = len(text.split()) * 1.3
419 return int(result)
def mmr_rerank_stream(
    memories: list[MemoryObject],
    query_embedding: list[float],
    top_k: int,
    lambda_param: float = 0.7,
):
    """Yield memories one at a time as Maximal Marginal Relevance selects them.

    Each step yields the remaining candidate that maximizes::

        lambda_param * memory.score - (1 - lambda_param) * max_sim_to_selected

    ``max_sim_to_selected`` is the largest cosine similarity (floored at 0.0)
    between the candidate's embedding and the embedding of any memory already
    yielded. The penalty stays 0.0 when the candidate has no embedding or
    ``query_embedding`` is empty. Yielded memories without an embedding do
    not contribute to it. Ties go to the earliest candidate in input order.

    Args:
        memories: Candidates with ``score`` already computed.
        query_embedding: Query vector; only its emptiness is checked here.
        top_k: Maximum number of memories to yield; nothing is yielded if <= 0.
        lambda_param: Relevance/diversity trade-off (1.0 = pure relevance).

    Yields:
        MemoryObject, each selected by the MMR criterion, in selection order.
    """
    if top_k <= 0 or not memories:
        return

    candidates = list(memories)
    # max_sims[i] caches the highest similarity between candidates[i] and any
    # memory yielded so far. Folding in only the newest selection each step
    # gives the same values as rescanning every selection. It costs O(n)
    # similarity calls per step instead of O(n * selected).
    max_sims = [0.0] * len(candidates)
    yielded = 0

    while yielded < top_k and candidates:
        best_idx = -1
        best_mmr = float("-inf")

        for i, candidate in enumerate(candidates):
            mmr_score = lambda_param * candidate.score - (1 - lambda_param) * max_sims[i]
            if mmr_score > best_mmr:
                best_mmr = mmr_score
                best_idx = i

        if best_idx == -1:
            # Only reachable when no score compares above -inf (e.g. NaN scores).
            return

        chosen = candidates.pop(best_idx)
        max_sims.pop(best_idx)  # keep the cache aligned with `candidates`
        yielded += 1
        yield chosen

        if yielded >= top_k:
            return

        # Fold the new selection into the cached penalties, following the same
        # rule as before: only when the selection has an embedding, a query
        # embedding is present, and the candidate has an embedding.
        if chosen.embedding is not None and query_embedding:
            for i, candidate in enumerate(candidates):
                if candidate.embedding is not None:
                    sim = cosine_similarity(candidate.embedding, chosen.embedding)
                    max_sims[i] = max(max_sims[i], sim)
def truncate_by_tokens(
    memories: list[MemoryObject],
    max_tokens: int | None,
    token_counter: Callable[[str], int] | None = None,
) -> list[MemoryObject]:
    """Keep the longest prefix of *memories* that fits within a token budget.

    Memories are consumed in order, and their token counts are accumulated
    from ``memory.content``. The walk stops at the first memory that would
    push the running total past ``max_tokens``. The first memory is always
    kept, even if it alone exceeds the budget, so a non-empty input never
    produces an empty result.

    Args:
        memories: Ranked memories, best first.
        max_tokens: Token budget; None disables truncation.
        token_counter: Function mapping text to a token count; defaults to
            the ~1.3-tokens-per-word estimate when not provided (or falsy).

    Returns:
        The input list itself when ``max_tokens`` is None or the input is
        empty. Otherwise, a new list holding the kept prefix.
    """
    if max_tokens is None or not memories:
        return memories

    count_tokens = token_counter or _default_token_counter
    kept: list[MemoryObject] = []
    used = 0

    for memory in memories:
        cost = count_tokens(memory.content)
        # The `kept` guard admits the first memory unconditionally.
        if kept and used + cost > max_tokens:
            break
        kept.append(memory)
        used += cost

    return kept