Coverage for src / kemi / scoring.py: 92%

181 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

1import math 

2from collections.abc import Callable 

3from datetime import datetime, timezone 

4 

5try: # pragma: no cover 

6 import numpy as np 

7 

8 _NUMPY_AVAILABLE = True 

9except ImportError: 

10 _NUMPY_AVAILABLE = False # pragma: no cover 

11 

12from kemi.models import MemoryObject 

13 

14 

def bm25_score(query: str, document: str) -> float:
    """Score a single document against a query with a BM25-style formula.

    Both strings are lowercased and split on whitespace; no external
    libraries are involved. There is no corpus here, so no IDF weighting
    is applied and the document's own length stands in for the average
    document length.

    Args:
        query: Search query string.
        document: Document to score against.

    Returns:
        Score clipped to the [0.0, 1.0] range. 0.0 when either input is
        empty/blank or no query term occurs in the document.
    """
    if not (query and query.strip() and document and document.strip()):
        return 0.0

    wanted = query.lower().split()
    tokens = document.lower().split()

    k1, b = 1.5, 0.75

    # With a single document the "average" length is the document's own
    # length, so the length-normalization ratio below is always 1.
    n_tokens = len(tokens)
    mean_length = max(n_tokens, 1)
    saturation = k1 * (1 - b + b * n_tokens / mean_length)

    occurrences: dict[str, int] = {}
    for token in tokens:
        occurrences[token] = occurrences.get(token, 0) + 1

    raw = 0.0
    for term in wanted:
        hits = occurrences.get(term, 0)
        if hits:
            raw += hits * (k1 + 1) / (hits + saturation)

    # Scale by a per-term ceiling and clip, so heavy repetition saturates at 1.0.
    ceiling = len(wanted) * (k1 + 1) / k1
    return min(1.0, raw / ceiling)

67 

68 

def bm25_score_corpus(
    query: str,
    document: str,
    corpus: list[str],
    k1: float = 1.5,
    b: float = 0.75,
) -> float:
    """Compute BM25 score with IDF from a corpus.

    Uses Inverse Document Frequency to weight terms based on how rare they are
    across the corpus. Query, document and corpus entries are lowercased and
    split on whitespace.

    Args:
        query: Search query string.
        document: Document to score against.
        corpus: List of document strings to compute IDF from. When empty,
            the call falls back to ``bm25_score(query, document)``.
        k1: Term frequency saturation parameter.
        b: Document length normalization parameter.

    Returns:
        BM25 score as float. Unlike ``bm25_score`` the value is not
        normalized and may exceed 1.0. A query term that appears several
        times in the query contributes once per occurrence.
    """
    if not query or not query.strip():
        return 0.0

    if not document or not document.strip():
        return 0.0

    if not corpus:
        return bm25_score(query, document)

    query_terms = query.lower().split()
    doc_terms = document.lower().split()
    unique_terms = set(query_terms)

    n_docs = len(corpus)

    # Single pass over the corpus: accumulate total length for the average
    # and count, per *unique* query term, how many documents contain it.
    # Counting per unique term keeps df <= n_docs even when the query
    # repeats a word; otherwise the IDF below could turn negative.
    total_length = 0
    df_counts: dict[str, int] = dict.fromkeys(unique_terms, 0)
    for doc in corpus:
        words = doc.lower().split()
        total_length += len(words)
        for term in unique_terms.intersection(words):
            df_counts[term] += 1

    avgdl: float = total_length / n_docs
    if avgdl == 0:
        # Corpus made only of blank strings; avoid dividing by zero below.
        avgdl = 1.0

    term_freqs: dict[str, int] = {}
    for term in doc_terms:
        if term in unique_terms:
            term_freqs[term] = term_freqs.get(term, 0) + 1

    # Length normalization is the same for every term of this document.
    length_norm = k1 * (1 - b + b * len(doc_terms) / avgdl)

    score = 0.0
    for query_term in query_terms:
        tf = term_freqs.get(query_term, 0)
        if tf == 0:
            # Term absent from the document: contributes nothing.
            continue

        df = df_counts[query_term]
        idf = math.log((n_docs - df + 0.5) / (df + 0.5) + 1)
        score += idf * (tf * (k1 + 1) / (tf + length_norm))

    return score

140 

141 

142def cosine_similarity(a: list[float] | None, b: list[float] | None) -> float: 

143 """Compute cosine similarity between two vectors. 

144 

145 Handles dimension mismatches by computing over the minimum dimension. 

146 Returns 0.0 if either vector is None or empty to avoid division by zero. 

147 Never returns NaN. 

148 """ 

149 if a is None or b is None or not a or not b: 

150 return 0.0 

151 

152 # Handle dimension mismatch gracefully — truncate to min dim so numpy 

153 # doesn't raise ValueError on mismatched vector lengths. 

154 min_dim = min(len(a), len(b)) 

155 

156 if _NUMPY_AVAILABLE: # pragma: no cover 

157 a_arr = np.array(a[:min_dim]) 

158 b_arr = np.array(b[:min_dim]) 

159 norm = np.linalg.norm(a_arr) * np.linalg.norm(b_arr) 

160 return float(np.dot(a_arr, b_arr) / norm) if norm != 0 else 0.0 

161 

162 dot_product = 0.0 

163 norm_a = 0.0 

164 norm_b = 0.0 

165 

166 for i in range(min_dim): 

167 dot_product += a[i] * b[i] 

168 norm_a += a[i] * a[i] 

169 norm_b += b[i] * b[i] 

170 

171 norm_a = norm_a**0.5 

172 norm_b = norm_b**0.5 

173 

174 if norm_a == 0.0 or norm_b == 0.0: 

175 return 0.0 

176 

177 return dot_product / (norm_a * norm_b) # type: ignore[no-any-return] 

178 

179 

def temporal_recency(last_accessed: datetime, half_life_hours: float = 168.0) -> float:
    """Compute temporal recency score using exponential decay.

    A memory accessed now (or with a timestamp in the future) scores 1.0.
    A memory accessed half_life_hours ago scores 0.5.
    A memory accessed 2x half_life_hours ago scores 0.25.

    Default half_life is 168 hours (7 days).

    Args:
        last_accessed: When the memory was last accessed. Timezone-naive
            values are interpreted as UTC so they can be compared with the
            timezone-aware current time instead of raising TypeError.
        half_life_hours: Hours after which the score halves. Expected to be
            positive; zero raises ZeroDivisionError.

    Returns:
        Recency score, 1.0 for "now" and decaying towards 0.0 with age.
    """
    if last_accessed.utcoffset() is None:
        # Naive datetime: aware and naive values cannot be subtracted.
        last_accessed = last_accessed.replace(tzinfo=timezone.utc)

    now = datetime.now(timezone.utc)
    hours_elapsed = (now - last_accessed).total_seconds() / 3600.0

    if hours_elapsed <= 0:
        return 1.0

    return 2.0 ** (-hours_elapsed / half_life_hours)  # type: ignore[no-any-return]

196 

197 

def jaccard_similarity(a: set[str], b: set[str]) -> float:
    """Return the Jaccard index |a ∩ b| / |a ∪ b| of two string sets.

    An empty set on either side yields 0.0.
    """
    if a and b:
        # Both sets are non-empty, so the union cannot be empty.
        shared = a & b
        combined = a | b
        return len(shared) / len(combined)
    return 0.0

208 

209 

def score_memory(
    memory: MemoryObject,
    query_embedding: list[float],
    query: str | None = None,
    hybrid_search: bool = True,
    corpus: list[str] | None = None,
    weight_semantic: float = 0.6,
    weight_recency: float = 0.25,
    weight_bm25: float = 0.15,
    weight_semantic_no_embed: float = 0.5,
    weight_recency_no_embed: float = 0.3,
    weight_importance: float = 0.2,
    query_entities: set[str] | None = None,
    memory_entities: set[str] | None = None,
    weight_entity: float = 0.1,
) -> float:
    """Compute final relevance score for a memory.

    When hybrid_search=True and query is provided:
        Formula: (semantic × weight_semantic) + (recency × weight_recency) + (bm25 × weight_bm25)

    When hybrid_search=False or no query:
        Formula: (semantic × weight_semantic_no_embed)
                 + (recency × weight_recency_no_embed)
                 + (importance × weight_importance)

    If entity parameters are provided, an additional boost is applied:
        + (jaccard(query_entities, memory_entities) × weight_entity)

    If either memory.embedding or query_embedding is None or empty, the
    semantic contribution is 0.0. Otherwise cosine similarity is mapped from
    [-1, 1] to [0, 1].

    Args:
        memory: The memory object to score.
        query_embedding: Embedding vector for semantic search.
        query: Optional query string for keyword search.
        hybrid_search: Use hybrid scoring (default True).
        corpus: List of document strings to compute IDF from for BM25. Only
            used when it holds more than one document.
        weight_semantic: Weight for semantic similarity in hybrid mode (default 0.6).
        weight_recency: Weight for recency in hybrid mode (default 0.25).
        weight_bm25: Weight for BM25 keyword match in hybrid mode (default 0.15).
        weight_semantic_no_embed: Weight for semantic similarity in the
            non-hybrid formula (default 0.5).
        weight_recency_no_embed: Weight for recency in the non-hybrid
            formula (default 0.3).
        weight_importance: Weight for importance in the non-hybrid formula
            (default 0.2).
        query_entities: Optional set of entities extracted from the query.
        memory_entities: Optional set of entities extracted from the memory content.
        weight_entity: Weight for entity overlap boost (default 0.1).

    Returns:
        The weighted relevance score.
    """
    # Require real vectors on both sides. cosine_similarity returns 0.0 for
    # empty input, which the [-1, 1] -> [0, 1] mapping below would otherwise
    # turn into a spurious 0.5 semantic score.
    has_vectors = (
        memory.embedding is not None
        and query_embedding is not None
        and len(memory.embedding) > 0
        and len(query_embedding) > 0
    )

    semantic_score = 0.0
    if has_vectors:
        similarity = cosine_similarity(memory.embedding, query_embedding)
        semantic_score = (similarity + 1.0) / 2.0

    recency_score = temporal_recency(memory.last_accessed_at)

    if hybrid_search and query:
        # NOTE(review): bm25_score_corpus returns an unnormalized score while
        # bm25_score is clipped to [0, 1]; confirm the differing scales are
        # intended before relying on weight_bm25 across both paths.
        if corpus and len(corpus) > 1:
            bm25_keyword_score = bm25_score_corpus(query, memory.content, corpus)
        else:
            bm25_keyword_score = bm25_score(query, memory.content)

        final_score = (
            semantic_score * weight_semantic
            + recency_score * weight_recency
            + bm25_keyword_score * weight_bm25
        )
    else:
        # Importance is clamped to [0, 1] before weighting.
        importance_score = max(0.0, min(1.0, memory.importance))
        final_score = (
            semantic_score * weight_semantic_no_embed
            + recency_score * weight_recency_no_embed
            + importance_score * weight_importance
        )

    # Entity-aware boost
    if query_entities is not None and memory_entities is not None:
        entity_score = jaccard_similarity(query_entities, memory_entities)
        final_score += entity_score * weight_entity

    return final_score

289 

290 

def rank_memories(
    memories: list[MemoryObject],
    query_embedding: list[float],
    query: str | None = None,
    hybrid_search: bool = True,
    weight_semantic: float = 0.6,
    weight_recency: float = 0.25,
    weight_bm25: float = 0.15,
    weight_semantic_no_embed: float = 0.5,
    weight_recency_no_embed: float = 0.3,
    weight_importance: float = 0.2,
    query_entities: set[str] | None = None,
    memory_entities_map: dict[str, set[str]] | None = None,
    weight_entity: float = 0.1,
) -> list[MemoryObject]:
    """Score every memory and return them ordered best-first.

    Each MemoryObject has its ``score`` attribute overwritten in place with
    the value from ``score_memory``; the returned list is a new, sorted list.

    Args:
        memories: List of MemoryObjects to rank.
        query_embedding: Embedding vector for semantic search.
        query: Optional query string for keyword search.
        hybrid_search: Use hybrid scoring (default True).
        weight_semantic: Weight for semantic similarity in hybrid mode.
        weight_recency: Weight for recency in hybrid mode.
        weight_bm25: Weight for BM25 keyword match in hybrid mode.
        weight_semantic_no_embed: Semantic weight forwarded to score_memory.
        weight_recency_no_embed: Recency weight forwarded to score_memory.
        weight_importance: Importance weight forwarded to score_memory.
        query_entities: Optional set of entities extracted from the query.
        memory_entities_map: Optional dict mapping memory_id -> set of entities.
        weight_entity: Weight for entity overlap boost.

    Returns:
        The memories sorted by descending score.
    """
    # The memory contents double as the IDF corpus, but only when there is
    # more than one of them.
    corpus = None
    if len(memories) > 1:
        corpus = [item.content for item in memories]

    for item in memories:
        if memory_entities_map is None:
            item_entities = None
        else:
            item_entities = memory_entities_map.get(item.memory_id)

        item.score = score_memory(
            item,
            query_embedding,
            query=query,
            hybrid_search=hybrid_search,
            corpus=corpus,
            weight_semantic=weight_semantic,
            weight_recency=weight_recency,
            weight_bm25=weight_bm25,
            weight_semantic_no_embed=weight_semantic_no_embed,
            weight_recency_no_embed=weight_recency_no_embed,
            weight_importance=weight_importance,
            query_entities=query_entities,
            memory_entities=item_entities,
            weight_entity=weight_entity,
        )

    ranked = list(memories)
    ranked.sort(key=lambda item: item.score, reverse=True)
    return ranked

350 

351 

def mmr_rerank(
    memories: list[MemoryObject],
    query_embedding: list[float],
    top_k: int,
    lambda_param: float = 0.7,
) -> list[MemoryObject]:
    """Rerank memories using Maximal Marginal Relevance.

    Balances relevance (similarity to query) with diversity
    (dissimilarity to already selected memories).

    lambda_param controls the tradeoff:
        1.0 = pure relevance (same as no MMR)
        0.0 = pure diversity
        0.7 = default, slightly favors relevance

    Algorithm:
        - Start with empty selected list
        - At each step, pick the candidate that maximizes:
            lambda * relevance_score - (1 - lambda) * max_similarity_to_selected
          where relevance_score = memory.score (already computed)
          and max_similarity_to_selected = max cosine_similarity between
          candidate embedding and each already-selected memory embedding
        - Candidates with no embedding (embedding is None) are scored with
          relevance memory.score and similarity 0.0
        - Stop when top_k memories are selected or candidates exhausted
        - Return selected list in order selected

    The selection itself is performed by ``mmr_rerank_stream``; this function
    collects its output into a list so the two cannot drift apart.

    Args:
        memories: Candidate memories; the input list is not modified.
        query_embedding: Query embedding; when empty, no diversity penalty
            is applied.
        top_k: Maximum number of memories to return.
        lambda_param: Relevance/diversity tradeoff described above.

    Returns:
        Up to top_k memories in selection order; empty when top_k <= 0 or
        there are no memories.
    """
    if top_k <= 0 or not memories:
        return []

    return list(mmr_rerank_stream(memories, query_embedding, top_k, lambda_param))

414 

415 

416def _default_token_counter(text: str) -> int: 

417 """Default token counter: rough estimate = word_count * 1.3""" 

418 result: float = len(text.split()) * 1.3 

419 return int(result) 

420 

421 

def mmr_rerank_stream(
    memories: list[MemoryObject],
    query_embedding: list[float],
    top_k: int,
    lambda_param: float = 0.7,
):
    """Yield memories one at a time as MMR selects them.

    At each step the remaining candidate maximizing
        lambda_param * candidate.score
        - (1 - lambda_param) * max_similarity_to_selected
    is yielded immediately rather than being collected into a list.
    max_similarity_to_selected is the largest cosine similarity between the
    candidate's embedding and any already-yielded memory's embedding,
    floored at 0.0. It stays 0.0 for candidates without an embedding and
    whenever query_embedding is empty.

    The per-candidate maximum is kept as a running value and updated once
    after each selection, so each candidate/selected pair is compared only
    once instead of on every round.

    Args:
        memories: Candidate memories; the input list is not modified.
        query_embedding: Query embedding; when empty, no diversity penalty
            is applied.
        top_k: Maximum number of memories to yield.
        lambda_param: Relevance/diversity tradeoff (1.0 = pure relevance).

    Yields:
        MemoryObject, each selected by the MMR criterion.
    """
    if top_k <= 0 or not memories:
        return

    candidates = list(memories)
    # penalties[i] is the running max similarity of candidates[i] to
    # everything yielded so far; the two lists are kept index-aligned.
    penalties = [0.0] * len(candidates)
    emitted = 0

    while emitted < top_k and candidates:
        best_idx = -1
        best_mmr = float("-inf")

        for i, candidate in enumerate(candidates):
            mmr_score = lambda_param * candidate.score - (1 - lambda_param) * penalties[i]

            if mmr_score > best_mmr:
                best_mmr = mmr_score
                best_idx = i

        if best_idx == -1:
            # No candidate compared greater than -inf (e.g. all scores NaN).
            return

        chosen = candidates.pop(best_idx)
        penalties.pop(best_idx)
        emitted += 1
        yield chosen

        # Fold the new selection into the remaining candidates' penalties.
        if emitted < top_k and chosen.embedding is not None and query_embedding:
            for i, candidate in enumerate(candidates):
                if candidate.embedding is not None:
                    sim = cosine_similarity(candidate.embedding, chosen.embedding)
                    if sim > penalties[i]:
                        penalties[i] = sim

470 

471 

def truncate_by_tokens(
    memories: list[MemoryObject],
    max_tokens: int | None,
    token_counter: Callable[[str], int] | None = None,
) -> list[MemoryObject]:
    """Keep the leading memories that fit within a token budget.

    The list is walked in order and token counts of ``memory.content`` are
    accumulated; the walk stops at the first memory that would push the
    total over ``max_tokens``.

    - ``max_tokens`` of None disables truncation and returns the input as is.
    - An empty input is returned unchanged.
    - The first memory is always kept, even if it alone exceeds the budget,
      so a non-empty input never produces an empty result.

    Args:
        memories: Memories in ranked order.
        max_tokens: Token budget, or None for no limit.
        token_counter: Callable returning the token count of a string;
            falls back to ``_default_token_counter`` when not given.

    Returns:
        The input list itself when no truncation applies, otherwise a new
        list holding the kept prefix.
    """
    if max_tokens is None or not memories:
        return memories

    count_tokens = token_counter if token_counter else _default_token_counter

    # The head is admitted unconditionally; only later entries are budgeted.
    head, *tail = memories
    kept: list[MemoryObject] = [head]
    used = count_tokens(head.content)

    for item in tail:
        cost = count_tokens(item.content)
        if used + cost > max_tokens:
            break
        kept.append(item)
        used += cost

    return kept