Coverage for src / kemi / reranker.py: 86%
134 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
1"""Cross-encoder re-ranking for precise semantic result ordering.

In a two-stage retrieval pipeline:
    Stage 1 (Bi-Encoder): Fast ANN/BM25 retrieval → top N candidates (high recall)
    Stage 2 (Cross-Encoder): Re-rank top N → final top-K (high precision)

Unlike bi-encoders, which encode query and document independently,
cross-encoders process (query, document) pairs jointly, enabling
deeper semantic understanding of relevance.

This module provides:
- A lightweight fallback scorer (keyword overlap + embedding similarity)
  that needs no external model
- A Nomic-style scorer that sends (query, document) pairs to a local
  embeddings endpoint
- Placeholders for local cross-encoder models (sentence-transformers) and
  OpenAI API-based reranking; these currently use the fallback scorer
15"""
from __future__ import annotations

import math
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from kemi.models import MemoryObject
# Public API. fuse_and_rerank is a documented public convenience function,
# so it is exported alongside the other entry points.
__all__ = [
    "RerankerConfig",
    "RerankerResult",
    "CrossEncoderReranker",
    "NomicReranker",
    "rerank_results",
    "FallbackReranker",
    "fuse_and_rerank",
]
36# ---------------------------------------------------------------------------
37# Config and result types
38# ---------------------------------------------------------------------------
# Presumably the reciprocal-rank-fusion smoothing constant k (RRF scores are
# conventionally 1 / (k + rank)). NOTE(review): not referenced in this module;
# confirm whether another module imports it before removing.
_DEFAULT_RRF_K: int = 60
@dataclass
class RerankerConfig:
    """Settings that select and tune the reranker.

    Attributes:
        provider: Which scorer to use. Documented values are "fallback",
            "sentence-transformers", "openai" and "nomic".
        model: Optional model identifier, e.g. "BAAI/bge-reranker-base".
        device: Compute device, "cpu" or "cuda".
        batch_size: Number of (query, doc) pairs to score per batch.
        score_threshold: Results scoring below this value are dropped; the
            default of 0.0 disables filtering.

    NOTE(review): ``device`` and ``batch_size`` are not read by
    ``rerank_results`` at present.
    """

    provider: str = "fallback"
    model: str | None = None
    device: str = "cpu"
    batch_size: int = 8
    score_threshold: float = 0.0
@dataclass
class RerankerResult:
    """One memory paired with its reranking score and list positions.

    Attributes:
        memory: The memory being ranked.
        cross_encoder_score: Score the reranker assigned to the
            (query, document) pair.
        bi_encoder_rank: Position of the memory in the list handed to the
            reranker.
        cross_encoder_rank: Position of the memory after sorting by score.
    """

    memory: MemoryObject
    cross_encoder_score: float
    bi_encoder_rank: int
    cross_encoder_rank: int
64# ---------------------------------------------------------------------------
65# Core reranking logic
66# ---------------------------------------------------------------------------
def rerank_results(
    results: list[MemoryObject],
    query: str,
    config: RerankerConfig,
    embed_fn=None,
) -> list[MemoryObject]:
    """Re-rank a list of MemoryObjects using a cross-encoder.

    Uses the configured provider to score (query, document) pairs jointly,
    then sorts results by score descending. Only the "nomic" provider has a
    dedicated scorer; "fallback", "sentence-transformers", "openai" and any
    unknown provider use the lightweight FallbackReranker.

    Args:
        results: Initial retrieval results (from bi-encoder / BM25).
        query: The search query string.
        config: RerankerConfig specifying provider and model.
        embed_fn: Optional embedder (object with ``embed_single``) used by
            the fallback scorer.

    Returns:
        Re-ranked list of MemoryObjects (same set, different order).
        Results below ``config.score_threshold`` (when it is > 0) are dropped.

    Side effects:
        Every input memory gets a ``bi_encoder_rank`` attribute (its index in
        ``results``). Every returned memory gets a ``cross_encoder_score``
        attribute holding the score it was sorted by.
    """
    if not results:
        return []

    # Record the incoming (bi-encoder) order on each memory.
    for idx, mem in enumerate(results):
        mem.bi_encoder_rank = idx  # type: ignore[attr-defined]

    if config.provider == "nomic":
        reranker = NomicReranker(model=config.model)
    else:
        # "fallback", not-yet-implemented providers, and unknown values.
        reranker = FallbackReranker(embed_fn=embed_fn)

    scored = reranker.score(query, results)

    if config.score_threshold > 0.0:
        scored = [s for s in scored if s.cross_encoder_score >= config.score_threshold]

    # list.sort is stable, so equal scores keep their bi-encoder order.
    scored.sort(key=lambda s: s.cross_encoder_score, reverse=True)

    for idx, s in enumerate(scored):
        s.cross_encoder_rank = idx
        # The return value carries only memories, so expose the score on the
        # memory itself; fuse_and_rerank reads it back via getattr. Without
        # this its blend always saw 0.0 and reranking had no effect.
        s.memory.cross_encoder_score = s.cross_encoder_score  # type: ignore[attr-defined]

    return [s.memory for s in scored]
122# ---------------------------------------------------------------------------
123# Fallback reranker (no external model required)
124# ---------------------------------------------------------------------------
class FallbackReranker:
    """Lightweight re-ranker using keyword overlap + embedding similarity.

    Used when no cross-encoder model is available. Each memory receives a
    combined score (at most about 0.95) built from four weighted signals:

    - Exact term overlap (weight 0.25): fraction of distinct query terms that
      also appear as terms in the document.
    - Stemmed overlap (weight 0.15): the same fraction over terms longer than
      4 characters, after crude suffix stripping.
    - Position bonus (weight 0.10): +0.1 per query term found as a substring
      of the first 100 characters of the content, capped at 0.5.
    - Embedding similarity (weight 0.50): cosine similarity between the query
      embedding and the memory's stored embedding, rescaled from [-1, 1] to
      [0, 1]. When no similarity is available the raw value is 0.0, so this
      signal contributes a constant 0.25.
    """

    STEM_SUFFIXES = ("ing", "ed", "es", "er", "ly", "tion", "ness", "ment")

    # Runs of word characters. Splitting on whitespace alone kept punctuation
    # attached ("name?"), so such terms never matched the bare word ("name").
    _TOKEN_RE = re.compile(r"\w+")

    def __init__(self, embed_fn=None) -> None:
        """Create a fallback reranker.

        Args:
            embed_fn: Optional object with an ``embed_single(text)`` method,
                used to embed the query. When None, the embedding signal is
                skipped.
        """
        self._embed_fn = embed_fn

    def score(
        self, query: str, results: list[MemoryObject]
    ) -> list[RerankerResult]:
        """Score each result using keyword matching + embedding similarity.

        Args:
            query: The search query string.
            results: Memories to score; each must expose ``content`` and
                ``embedding`` attributes.

        Returns:
            One RerankerResult per input memory, in input order (not sorted),
            with scores rounded to 4 decimals and ``cross_encoder_rank`` left
            at 0 for the caller to assign.
        """
        query_terms = set(self._normalize_terms(query))

        # The query embedding is the same for every memory: compute it at most
        # once, and only when some memory has an embedding to compare with.
        query_emb = None
        query_emb_ready = False

        scored_results: list[RerankerResult] = []
        for mem in results:
            content_terms = set(self._normalize_terms(mem.content))

            exact_overlap = len(query_terms & content_terms)
            exact_score = exact_overlap / max(len(query_terms), 1)

            partial_score = self._stemmed_overlap(query_terms, content_terms)

            # Query terms appearing in the first 100 characters score higher.
            position_score = self._position_bonus(query_terms, mem.content)

            embed_score = 0.0
            if self._embed_fn is not None and mem.embedding is not None:
                if not query_emb_ready:
                    query_emb = self._embed_query(query)
                    query_emb_ready = True
                if query_emb is not None:
                    try:
                        embed_score = self._cosine_sim(query_emb, mem.embedding)
                    except Exception:
                        # Best-effort: a malformed vector only loses this signal.
                        embed_score = 0.0

            # Keyword 40% (exact 25% + stemmed 15%), position 10%, embedding 50%.
            combined = (
                exact_score * 0.25
                + partial_score * 0.15
                + position_score * 0.10
                + ((embed_score + 1.0) / 2.0) * 0.50
            )

            scored_results.append(
                RerankerResult(
                    memory=mem,
                    cross_encoder_score=round(combined, 4),
                    bi_encoder_rank=getattr(mem, "bi_encoder_rank", 0),
                    cross_encoder_rank=0,
                )
            )

        return scored_results

    def _embed_query(self, query: str):
        """Embed *query* with the configured embedder; return None if it raises."""
        try:
            return self._embed_fn.embed_single(query)
        except Exception:
            # Best-effort: an embedding failure disables the similarity signal.
            return None

    def _normalize_terms(self, text: str) -> list[str]:
        """Lowercase *text* and return its word-character runs as terms."""
        return self._TOKEN_RE.findall(text.lower())

    def _stemmed_overlap(self, query_terms: set[str], doc_terms: set[str]) -> float:
        """Return the fraction of long query stems also present among doc stems.

        Only terms longer than 4 characters are considered; stems come from
        simple suffix stripping.
        """
        stemmed_query = {self._strip_suffix(t) for t in query_terms if len(t) > 4}
        stemmed_doc = {self._strip_suffix(t) for t in doc_terms if len(t) > 4}
        overlap = len(stemmed_query & stemmed_doc)
        return overlap / max(len(stemmed_query), 1)

    def _strip_suffix(self, word: str) -> str:
        """Strip the first matching suffix in STEM_SUFFIXES (no external library)."""
        for suffix in self.STEM_SUFFIXES:
            if word.endswith(suffix):
                return word[: -len(suffix)]
        return word

    def _position_bonus(self, query_terms: set[str], content: str) -> float:
        """Award 0.1 per query term found in the first 100 chars, capped at 0.5.

        Matching is by substring, so short terms can also match inside longer
        words.
        """
        first_100 = content.lower()[:100]
        bonus = 0.0
        for term in query_terms:
            if term in first_100:
                bonus += 0.1
        return min(bonus, 0.5)

    def _cosine_sim(self, a: list[float], b: list[float]) -> float:
        """Return the cosine similarity of *a* and *b*, or 0.0 if either is zero.

        ``zip`` pairs elements only up to the shorter vector's length.
        """
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = math.sqrt(sum(x * x for x in a))
        norm_b = math.sqrt(sum(y * y for y in b))
        if norm_a == 0.0 or norm_b == 0.0:
            return 0.0
        return dot / (norm_a * norm_b)
232# ---------------------------------------------------------------------------
# Model-backed rerankers (cross-encoder placeholder and Nomic embeddings scorer)
234# ---------------------------------------------------------------------------
class CrossEncoderReranker:
    """Placeholder for a true cross-encoder reranker.

    Intended for a future sentence-transformers / OpenAI integration, which
    will presumably require ``pip install sentence-transformers``. Until a
    model is wired in, the constructor arguments are only recorded and all
    scoring is delegated to a FallbackReranker created without an embedder.
    """

    def __init__(
        self,
        model: str | None = None,
        device: str = "cpu",
        batch_size: int = 8,
    ) -> None:
        # Recorded for the future model integration; not read yet.
        self._batch_size = batch_size
        self._device = device
        self._model = model
        # Delegate that does the actual scoring for now.
        self._reranker = FallbackReranker()

    def score(
        self, query: str, results: list[MemoryObject]
    ) -> list[RerankerResult]:
        """Score (query, document) pairs.

        Currently returns exactly what the delegate FallbackReranker produces
        for the same inputs.
        """
        delegate = self._reranker
        return delegate.score(query, results)
class NomicReranker:
    """Reranker that scores (query, document) pairs via Nomic embeddings.

    Each pair is sent as a single prompt ("query: ...\\ndocument: ...") to an
    embeddings HTTP endpoint (default ``DEFAULT_ENDPOINT``). The mean of the
    returned embedding's components, passed through a sigmoid, is used as the
    relevance score. A non-200 response scores 0.0; an empty embedding
    scores 0.5.

    NOTE(review): the mean of an embedding's components is a heuristic proxy,
    not a trained relevance score; validate ranking quality before relying
    on it.

    Falls back to FallbackReranker (without an embedder) if ``requests`` is
    unavailable or any request/parse step raises.
    """

    DEFAULT_ENDPOINT = "http://localhost:11434/api/embeddings"

    def __init__(
        self,
        model: str | None = None,
        *,
        endpoint: str = DEFAULT_ENDPOINT,
        timeout: float = 10,
    ) -> None:
        """Create a Nomic reranker.

        Args:
            model: Embedding model name; defaults to "nomic-embed-text-v1.5".
            endpoint: URL that accepts ``{"model", "prompt"}`` JSON POSTs and
                returns ``{"embedding": [...]}``.
            timeout: Per-request timeout in seconds, passed to requests.
        """
        self._model = model or "nomic-embed-text-v1.5"
        self._endpoint = endpoint
        self._timeout = timeout
        # Created lazily, only if the HTTP path fails.
        self._reranker: FallbackReranker | None = None

    @staticmethod
    def _sigmoid(x: float) -> float:
        """Return the logistic function of *x* without overflowing math.exp."""
        if x >= 0.0:
            return 1.0 / (1.0 + math.exp(-x))
        z = math.exp(x)
        return z / (1.0 + z)

    def score(
        self, query: str, results: list[MemoryObject]
    ) -> list[RerankerResult]:
        """Score each memory with one embeddings request per (query, doc) pair.

        Returns one RerankerResult per input memory, in input order, with
        ``cross_encoder_rank`` left at 0 for the caller to assign.
        """
        try:
            import requests

            scored = []
            for mem in results:
                # Format as cross-encoder pair: query [SEP] document
                pair_input = f"query: {query}\ndocument: {mem.content}"
                response = requests.post(
                    self._endpoint,
                    json={"model": self._model, "prompt": pair_input},
                    timeout=self._timeout,
                )
                if response.status_code == 200:
                    emb = response.json().get("embedding", [])
                    mean = sum(emb) / len(emb) if emb else 0.0
                    norm_score = self._sigmoid(mean)
                else:
                    norm_score = 0.0

                scored.append(
                    RerankerResult(
                        memory=mem,
                        cross_encoder_score=norm_score,
                        bi_encoder_rank=getattr(mem, "bi_encoder_rank", 0),
                        cross_encoder_rank=0,
                    )
                )
            return scored

        except Exception:
            # Deliberate best-effort: missing `requests`, connection errors and
            # malformed responses all degrade to the keyword-based fallback.
            if self._reranker is None:
                self._reranker = FallbackReranker()
            return self._reranker.score(query, results)
328# ---------------------------------------------------------------------------
329# Convenience function: combine RRF fusion + cross-encoder reranking
330# ---------------------------------------------------------------------------
def fuse_and_rerank(
    fusion_results,  # iterable of FusionResult from decompose.py
    query: str,
    config: RerankerConfig,
    embed_fn=None,
) -> list[RerankerResult]:
    """Run RRF fusion results through cross-encoder reranking.

    Each surviving memory's final score is
    ``0.4 * cross_encoder_score + 0.6 * rrf_score``. The cross-encoder part
    is read from the memory's ``cross_encoder_score`` attribute (0.0 if
    absent). The RRF part comes from the first FusionResult with the same
    ``memory_id`` (0.0 if none matches).

    Args:
        fusion_results: FusionResult objects from fused_recall(); each needs
            ``memory`` and ``rrf_score`` attributes.
        query: Original query string.
        config: RerankerConfig.
        embed_fn: Optional embedder for fallback scoring.

    Returns:
        List of RerankerResult objects, sorted by the blended score descending,
        with ``cross_encoder_rank`` set to the final position. Memories dropped
        by ``config.score_threshold`` are omitted.
    """
    # Materialize once so any iterable (e.g. a generator) can be walked twice.
    fusion_results = list(fusion_results)
    memories = [fr.memory for fr in fusion_results]
    reranked = rerank_results(memories, query, config, embed_fn)

    # Index fusion results by memory_id once instead of scanning the list for
    # every memory (was O(n^2)); setdefault keeps the first match, as before.
    fusion_by_id = {}
    for fr in fusion_results:
        fusion_by_id.setdefault(fr.memory.memory_id, fr)

    reranked_results: list[RerankerResult] = []
    for mem in reranked:
        fr = fusion_by_id.get(mem.memory_id)
        rrf_score = fr.rrf_score if fr is not None else 0.0

        reranked_results.append(
            RerankerResult(
                memory=mem,
                cross_encoder_score=(
                    0.4 * getattr(mem, "cross_encoder_score", 0.0)
                    + 0.6 * rrf_score
                ),
                bi_encoder_rank=getattr(mem, "bi_encoder_rank", 0),
                cross_encoder_rank=0,
            )
        )

    reranked_results.sort(key=lambda x: x.cross_encoder_score, reverse=True)
    for idx, r in enumerate(reranked_results):
        r.cross_encoder_rank = idx

    return reranked_results