Coverage for src / kemi / reranker.py: 86%

134 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

"""Cross-encoder re-ranking for precise semantic result ordering.

In a two-stage retrieval pipeline:
  Stage 1 (Bi-Encoder): Fast ANN/BM25 retrieval → top N candidates (high recall)
  Stage 2 (Cross-Encoder): Re-rank top N → final top-K (high precision)

Unlike bi-encoders, which encode query and document independently,
cross-encoders process (query, document) pairs jointly, enabling
deeper semantic understanding of relevance.

This module provides:
- ``FallbackReranker``: a dependency-free scorer combining keyword overlap,
  a position bonus and, optionally, query/document embedding similarity.
- ``NomicReranker``: scoring through a local HTTP embeddings endpoint
  (``localhost:11434/api/embeddings``). It falls back to ``FallbackReranker``
  when that endpoint or the ``requests`` package is unavailable.
- ``CrossEncoderReranker``: a placeholder for local cross-encoder models
  (sentence-transformers). It currently delegates to ``FallbackReranker``.

OpenAI API-based reranking is not implemented yet. ``rerank_results`` maps the
``"openai"`` and ``"sentence-transformers"`` providers to ``FallbackReranker``.
"""

16 

17from __future__ import annotations 

18 

19import math 

20from dataclasses import dataclass 

21from typing import TYPE_CHECKING 

22 

23if TYPE_CHECKING: 

24 from kemi.models import MemoryObject 

25 

# Public API of this module. ``fuse_and_rerank`` is a documented public
# convenience entry point, so it is exported alongside ``rerank_results``.
__all__ = [
    "RerankerConfig",
    "RerankerResult",
    "CrossEncoderReranker",
    "NomicReranker",
    "FallbackReranker",
    "rerank_results",
    "fuse_and_rerank",
]

34 

35 

36# --------------------------------------------------------------------------- 

37# Config and result types 

38# --------------------------------------------------------------------------- 

39 

# Conventional smoothing constant ``k`` for Reciprocal Rank Fusion
# (k = 60 is the commonly used default).
# NOTE(review): not referenced anywhere in this module; presumably kept for
# parity with the fusion code that produces ``FusionResult.rrf_score``.
# Confirm before removing.
_DEFAULT_RRF_K = 60

41 

42 

@dataclass
class RerankerConfig:
    """Settings that select and tune the re-ranking backend.

    Attributes:
        provider: Backend name. ``rerank_results`` sends ``"nomic"`` to
            ``NomicReranker``. Every other value, including
            ``"sentence-transformers"`` and ``"openai"``, currently uses
            ``FallbackReranker``.
        model: Optional model identifier handed to the backend
            (e.g. ``"BAAI/bge-reranker-base"``).
        device: Intended compute device (``"cpu"`` or ``"cuda"``); not read
            by ``rerank_results``.
        batch_size: Intended number of (query, document) pairs scored per
            batch; not read by ``rerank_results``.
        score_threshold: When positive, results scoring below it are dropped.
    """

    provider: str = "fallback"
    model: str | None = None
    device: str = "cpu"
    batch_size: int = 8
    score_threshold: float = 0.0

52 

53 

@dataclass
class RerankerResult:
    """One memory paired with its re-ranking score and rank positions.

    Attributes:
        memory: The memory being ranked.
        cross_encoder_score: Relevance score the reranker assigned to the
            (query, memory) pair; higher values sort first.
        bi_encoder_rank: Zero-based position in the first-stage result list.
        cross_encoder_rank: Zero-based position after re-ranking.
    """

    memory: MemoryObject
    cross_encoder_score: float
    bi_encoder_rank: int
    cross_encoder_rank: int

62 

63 

64# --------------------------------------------------------------------------- 

65# Core reranking logic 

66# --------------------------------------------------------------------------- 

67 

def rerank_results(
    results: list[MemoryObject],
    query: str,
    config: RerankerConfig,
    embed_fn=None,
) -> list[MemoryObject]:
    """Re-rank a list of MemoryObjects using the configured reranker.

    Uses the configured provider to score (query, document) pairs, then sorts
    results by that score, descending. ``"nomic"`` selects NomicReranker. Every
    other provider (including unknown names) uses FallbackReranker, which is
    the only backend that receives ``embed_fn``.

    Args:
        results: Initial retrieval results (from bi-encoder / BM25).
        query: The search query string.
        config: RerankerConfig specifying provider, model and threshold.
        embed_fn: Optional query embedder for fallback scoring.

    Returns:
        Re-ranked list of MemoryObjects (same set, different order). When
        ``config.score_threshold`` is positive, results scoring below it are
        dropped.

    Side effects:
        Every input memory gets a ``bi_encoder_rank`` attribute (its input
        position). Every returned memory also gets ``cross_encoder_score`` and
        ``cross_encoder_rank`` attributes. This lets callers such as
        ``fuse_and_rerank`` read the reranker's score back from the memory.
    """
    if not results:
        return []

    # Record first-stage positions on the memories; the rerankers read them
    # back with getattr(mem, "bi_encoder_rank", 0).
    for idx, mem in enumerate(results):
        mem.bi_encoder_rank = idx  # type: ignore[attr-defined]

    if config.provider == "nomic":
        reranker = NomicReranker(model=config.model)
    else:
        # "fallback", plus providers with no real implementation yet
        # ("sentence-transformers", "openai") and unknown names.
        reranker = FallbackReranker(embed_fn=embed_fn)

    scored = reranker.score(query, results)

    if config.score_threshold > 0.0:
        scored = [s for s in scored if s.cross_encoder_score >= config.score_threshold]

    # list.sort is stable: equal scores keep their first-stage order.
    scored.sort(key=lambda s: s.cross_encoder_score, reverse=True)

    for idx, s in enumerate(scored):
        s.cross_encoder_rank = idx
        # Expose the score on the memory itself. Only memories leave this
        # function, so without this, fuse_and_rerank's
        # getattr(mem, "cross_encoder_score", 0.0) would always see 0.0.
        s.memory.cross_encoder_score = s.cross_encoder_score  # type: ignore[attr-defined]
        s.memory.cross_encoder_rank = idx  # type: ignore[attr-defined]

    return [s.memory for s in scored]

120 

121 

122# --------------------------------------------------------------------------- 

123# Fallback reranker (no external model required) 

124# --------------------------------------------------------------------------- 

125 

class FallbackReranker:
    """Lightweight re-ranker using keyword overlap + embedding similarity.

    Used when no cross-encoder model is available. Each result's score is a
    weighted sum of four signals:

    - Exact term overlap (weight 0.25): fraction of query terms that also
      occur in the document.
    - Stemmed overlap (weight 0.15): fraction of long (> 4 chars) query terms
      whose suffix-stripped stem matches a stem from the document.
    - Position bonus (weight 0.10): +0.1 per query term found as a substring
      of the first 100 characters, capped at 0.5.
    - Embedding similarity (weight 0.50): cosine similarity between the query
      embedding and ``mem.embedding``, mapped from [-1, 1] to [0, 1]. When no
      embedding is available the cosine is taken as 0 (0.5 after mapping).

    Args:
        embed_fn: Optional query embedder. It may be an object exposing
            ``embed_single(text) -> list[float]`` (an embedding adapter) or a
            plain callable ``text -> list[float]``.
    """

    STEM_SUFFIXES = ("ing", "ed", "es", "er", "ly", "tion", "ness", "ment")

    def __init__(self, embed_fn=None) -> None:
        self._embed_fn = embed_fn

    def score(
        self, query: str, results: list[MemoryObject]
    ) -> list[RerankerResult]:
        """Score each result using keyword matching + similarity.

        Args:
            query: The search query.
            results: Candidate memories; ``content`` and ``embedding`` are read.

        Returns:
            One RerankerResult per input memory, in input order. The combined
            score is rounded to 4 decimals and ``cross_encoder_rank`` is 0; the
            caller assigns final ranks.
        """
        query_terms = set(self._normalize_terms(query))

        # The query embedding does not depend on the document. Compute it at
        # most once, and only when some result actually has an embedding.
        query_emb: list[float] | None = None
        query_emb_tried = False

        scored_results: list[RerankerResult] = []

        for mem in results:
            content_terms = set(self._normalize_terms(mem.content))

            exact_score = len(query_terms & content_terms) / max(len(query_terms), 1)
            partial_score = self._stemmed_overlap(query_terms, content_terms)
            position_score = self._position_bonus(query_terms, mem.content)

            embed_score = 0.0
            if self._embed_fn is not None and mem.embedding is not None:
                if not query_emb_tried:
                    query_emb_tried = True
                    query_emb = self._embed_query(query)
                if query_emb is not None:
                    try:
                        embed_score = self._cosine_sim(query_emb, mem.embedding)
                    except (TypeError, ValueError):
                        # Malformed (non-numeric) vector values: no signal.
                        embed_score = 0.0

            # Weights: keyword 40% (exact 25% + stemmed 15%), position 10%,
            # embedding 50% (cosine mapped from [-1, 1] to [0, 1]).
            combined = (
                exact_score * 0.25
                + partial_score * 0.15
                + position_score * 0.10
                + ((embed_score + 1.0) / 2.0) * 0.50
            )

            scored_results.append(
                RerankerResult(
                    memory=mem,
                    cross_encoder_score=round(combined, 4),
                    bi_encoder_rank=getattr(mem, "bi_encoder_rank", 0),
                    cross_encoder_rank=0,
                )
            )

        return scored_results

    def _embed_query(self, query: str) -> list[float] | None:
        """Embed *query* with the configured embedder; None if unavailable.

        This is best-effort by design. A missing, unsupported or failing
        embedder only removes the embedding signal; it never breaks re-ranking.
        """
        embed = getattr(self._embed_fn, "embed_single", None)
        if embed is None:
            if not callable(self._embed_fn):
                return None
            embed = self._embed_fn
        try:
            return list(embed(query))
        except Exception:  # noqa: BLE001 - embedder failures are non-fatal here
            return None

    def _normalize_terms(self, text: str) -> list[str]:
        """Lowercase *text* and split it into word-character tokens.

        Punctuation acts as a separator, so "memory," matches "memory".
        """
        import re

        return re.findall(r"\w+", text.lower())

    def _stemmed_overlap(self, query_terms: set[str], doc_terms: set[str]) -> float:
        """Return the fraction of long query terms whose stem occurs in the document.

        Only terms longer than 4 characters are considered. Stems come from
        ``_strip_suffix``.
        """
        stemmed_query = {self._strip_suffix(t) for t in query_terms if len(t) > 4}
        stemmed_doc = {self._strip_suffix(t) for t in doc_terms if len(t) > 4}
        overlap = len(stemmed_query & stemmed_doc)
        return overlap / max(len(stemmed_query), 1)

    def _strip_suffix(self, word: str) -> str:
        """Remove the first matching suffix from STEM_SUFFIXES (no external library)."""
        for suffix in self.STEM_SUFFIXES:
            if word.endswith(suffix):
                return word[: -len(suffix)]
        return word

    def _position_bonus(self, query_terms: set[str], content: str) -> float:
        """Award 0.1 per query term found early in *content*, capped at 0.5.

        A term counts when it occurs as a substring of the first 100
        characters (case-insensitive). Short terms can therefore also match
        inside longer words.
        """
        first_100 = content.lower()[:100]
        bonus = 0.0
        for term in query_terms:
            if term in first_100:
                bonus += 0.1
        return min(bonus, 0.5)

    def _cosine_sim(self, a: list[float], b: list[float]) -> float:
        """Compute cosine similarity between two vectors.

        Returns 0.0 for a zero vector or for vectors of different lengths.
        zip() would silently truncate the longer vector, and vectors from
        different models are not comparable.
        """
        if len(a) != len(b):
            return 0.0
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = math.sqrt(sum(x * x for x in a))
        norm_b = math.sqrt(sum(y * y for y in b))
        if norm_a == 0.0 or norm_b == 0.0:
            return 0.0
        return dot / (norm_a * norm_b)

230 

231 

# ---------------------------------------------------------------------------
# Model-backed rerankers: cross-encoder placeholder and Nomic embeddings
# ---------------------------------------------------------------------------

235 

class CrossEncoderReranker:
    """Placeholder for a real cross-encoder (sentence-transformers / OpenAI).

    No model is loaded yet. The constructor only records its settings, and
    every call to ``score`` is delegated to a FallbackReranker built without
    an embedder. Installing ``sentence-transformers``
    (``pip install sentence-transformers``) is the intended prerequisite once
    a real model is wired in.

    Args:
        model: Cross-encoder model name to use once implemented.
        device: Target device, e.g. ``"cpu"`` or ``"cuda"``.
        batch_size: Pairs per scoring batch once implemented.
    """

    def __init__(
        self,
        model: str | None = None,
        device: str = "cpu",
        batch_size: int = 8,
    ) -> None:
        self._model, self._device, self._batch_size = model, device, batch_size
        self._reranker = FallbackReranker()

    def score(
        self, query: str, results: list[MemoryObject]
    ) -> list[RerankerResult]:
        """Delegate (query, document) scoring to the fallback reranker.

        This is temporary until a real cross-encoder model is integrated.
        """
        delegate = self._reranker
        return delegate.score(query, results)

264 

265 

class NomicReranker:
    """Re-ranker backed by a local Ollama-style embeddings endpoint (Nomic models).

    The query and each memory's content are embedded separately through
    ``POST {base_url}/api/embeddings``. Each memory is scored by the cosine
    similarity of the two vectors, mapped from [-1, 1] to [0, 1]. For
    ``nomic-embed-text`` models, the task prefixes ``search_query: `` and
    ``search_document: `` are prepended to the query and document.

    This is an embedding-similarity score, not a true joint cross-encoder.
    Embedding models expose no relevance head, so the components of a single
    vector carry no relevance signal by themselves.

    It falls back to FallbackReranker (without an embedder) in three cases:
    ``requests`` is not installed, the endpoint errors or is unreachable, or
    the query cannot be embedded. A document whose embedding request returns a
    non-200 status scores 0.0.

    Args:
        model: Embedding model name; defaults to ``"nomic-embed-text-v1.5"``.
        base_url: Base URL of the embeddings server.
        timeout: Per-request timeout in seconds. One request is made for the
            query plus one per document.
    """

    _QUERY_PREFIX = "search_query: "
    _DOCUMENT_PREFIX = "search_document: "

    def __init__(
        self,
        model: str | None = None,
        base_url: str = "http://localhost:11434",
        timeout: float = 10.0,
    ) -> None:
        self._model = model or "nomic-embed-text-v1.5"
        self._embeddings_url = base_url.rstrip("/") + "/api/embeddings"
        self._timeout = timeout
        # Nomic text-embedding models are trained with task prefixes; other
        # models receive the raw text.
        self._use_task_prefixes = self._model.startswith("nomic-embed-text")
        # Created lazily, the first time a fallback is needed.
        self._reranker: FallbackReranker | None = None

    def score(
        self, query: str, results: list[MemoryObject]
    ) -> list[RerankerResult]:
        """Score results by query/document embedding similarity.

        Returns:
            One RerankerResult per input memory, in input order, with
            ``cross_encoder_rank`` 0 (the caller assigns final ranks).
        """
        if not results:
            return []

        try:
            import requests
        except ImportError:
            return self._fallback_score(query, results)

        try:
            query_emb = self._embed(requests, self._prefixed(query, self._QUERY_PREFIX))
            if query_emb is None:
                return self._fallback_score(query, results)

            scored: list[RerankerResult] = []
            for mem in results:
                doc_emb = self._embed(
                    requests, self._prefixed(mem.content, self._DOCUMENT_PREFIX)
                )
                # A document that could not be embedded ranks last.
                norm_score = (
                    (self._cosine(query_emb, doc_emb) + 1.0) / 2.0 if doc_emb else 0.0
                )
                scored.append(
                    RerankerResult(
                        memory=mem,
                        cross_encoder_score=norm_score,
                        bi_encoder_rank=getattr(mem, "bi_encoder_rank", 0),
                        cross_encoder_rank=0,
                    )
                )
            return scored
        except (requests.RequestException, ValueError, TypeError):
            # Connection errors/timeouts, undecodable JSON, or malformed vectors.
            return self._fallback_score(query, results)

    def _embed(self, requests_mod, text: str) -> list[float] | None:
        """POST *text* for embedding; None on a non-200 status or an empty reply."""
        response = requests_mod.post(
            self._embeddings_url,
            json={"model": self._model, "prompt": text},
            timeout=self._timeout,
        )
        if response.status_code != 200:
            return None
        payload = response.json()
        embedding = payload.get("embedding") if isinstance(payload, dict) else None
        return embedding or None

    def _prefixed(self, text: str, prefix: str) -> str:
        """Prepend the task *prefix* when the model expects one."""
        return prefix + text if self._use_task_prefixes else text

    @staticmethod
    def _cosine(a: list[float], b: list[float]) -> float:
        """Cosine similarity; 0.0 for zero vectors or mismatched lengths."""
        if len(a) != len(b):
            return 0.0
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = math.sqrt(sum(x * x for x in a))
        norm_b = math.sqrt(sum(y * y for y in b))
        if norm_a == 0.0 or norm_b == 0.0:
            return 0.0
        return dot / (norm_a * norm_b)

    def _fallback_score(
        self, query: str, results: list[MemoryObject]
    ) -> list[RerankerResult]:
        """Score with a lazily created FallbackReranker (no embedder)."""
        if self._reranker is None:
            self._reranker = FallbackReranker()
        return self._reranker.score(query, results)

326 

327 

328# --------------------------------------------------------------------------- 

329# Convenience function: combine RRF fusion + cross-encoder reranking 

330# --------------------------------------------------------------------------- 

331 

def fuse_and_rerank(
    fusion_results,  # list of FusionResult from decompose.py
    query: str,
    config: RerankerConfig,
    embed_fn=None,
) -> list[RerankerResult]:
    """Run RRF fusion results through cross-encoder reranking.

    Each output's ``cross_encoder_score`` is a blend:
    ``0.4 * reranker score + 0.6 * rrf_score``. The reranker score is read
    from the memory's ``cross_encoder_score`` attribute (0.0 when absent).
    ``rrf_score`` comes from the FusionResult with the same ``memory_id``
    (0.0 when none matches).

    Args:
        fusion_results: List of FusionResult objects from fused_recall().
        query: Original query string.
        config: RerankerConfig.
        embed_fn: Optional embed function for fallback scoring.

    Returns:
        List of RerankerResult objects, sorted by the blended score descending,
        with ``cross_encoder_rank`` set to the zero-based output position.
    """
    if not fusion_results:
        return []

    memories = [fr.memory for fr in fusion_results]
    reranked = rerank_results(memories, query, config, embed_fn)

    # Index fusion results by memory id once, instead of scanning the whole
    # list for every memory (O(n^2)). setdefault keeps the first occurrence
    # when ids repeat, as a linear first-match search would.
    fusion_by_id: dict = {}
    for fr in fusion_results:
        fusion_by_id.setdefault(fr.memory.memory_id, fr)

    reranked_results: list[RerankerResult] = []
    for mem in reranked:
        fr = fusion_by_id.get(mem.memory_id)
        rrf_score = fr.rrf_score if fr is not None else 0.0

        # NOTE(review): if rrf_score has the usual 1 / (k + rank) form
        # (~0.016 for k=60), it is far smaller than reranker scores in [0, 1].
        # The 0.6 weight then gives RRF little real influence; confirm the
        # intended scaling.
        reranked_results.append(
            RerankerResult(
                memory=mem,
                cross_encoder_score=(
                    0.4 * getattr(mem, "cross_encoder_score", 0.0)
                    + 0.6 * rrf_score
                ),
                bi_encoder_rank=getattr(mem, "bi_encoder_rank", 0),
                cross_encoder_rank=0,
            )
        )

    reranked_results.sort(key=lambda x: x.cross_encoder_score, reverse=True)
    for idx, r in enumerate(reranked_results):
        r.cross_encoder_rank = idx

    return reranked_results