Coverage for agentos/rag/reranker.py: 0%
143 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""Re-ranking for RAG pipeline.
3Cross-encoder and LLM-based reranking to refine retrieval results.
4Supports: cross-encoder (sentence-transformers), LLM reranking,
5and simple heuristic reranking (diversity, freshness).
6"""
8from __future__ import annotations
10from dataclasses import dataclass, field
11from typing import Any, Dict, List, Optional, Tuple
12import math
@dataclass
class RerankConfig:
    """Settings that control how ``Reranker`` orders passages.

    Attributes:
        method: Name of the reranking strategy. Recognised values are
            ``cross_encoder``, ``llm``, ``diversity`` and ``mmr``.
        model: Model identifier handed to the sentence-transformers
            ``CrossEncoder`` when the ``cross_encoder`` strategy is used.
        top_n: Number of passages kept after reranking.
        diversity_lambda: MMR trade-off weight. The MMR score is
            ``lambda * relevance + (1 - lambda) * diversity``, so larger
            values favour relevance and smaller values favour diversity.
        llm_prompt_template: Custom prompt for the LLM reranker; empty by
            default.
        batch_size: Batch size passed to the cross-encoder's ``predict``.
    """

    # Strategy selection.
    method: str = "cross_encoder"
    model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"

    # Output size and MMR trade-off.
    top_n: int = 5
    diversity_lambda: float = 0.5

    # Strategy-specific knobs.
    llm_prompt_template: str = ""
    batch_size: int = 8
class Reranker:
    """Re-rank retrieval results for improved relevance.

    The strategy is chosen by ``config.method``:
        - cross_encoder: sentence-transformers cross-encoder scoring; falls
          back to the term-overlap heuristic when the package is missing.
        - mmr: Maximal Marginal Relevance using Jaccard similarity on token
          sets as a stand-in for embeddings.
        - llm: framework hook; currently delegates to the term-overlap
          heuristic.
        - any other value (including ``diversity``): original score reduced
          by a token-overlap penalty against earlier passages.

    Every strategy writes ``rerank_score`` and ``rerank_method`` into the
    passage dicts it is given. The caller's list itself is never reordered;
    a new list is returned.
    """

    def __init__(self, config: Optional[RerankConfig] = None):
        """Store the configuration, defaulting to ``RerankConfig()``."""
        self.config = config or RerankConfig()
        self._cross_encoder = None  # created lazily on first cross-encoder use
        self._embed_fn = None  # for MMR diversity

    @staticmethod
    def _tokenize(text: str) -> set:
        """Return the set of lower-cased ``\\w+`` tokens in *text*."""
        import re

        return set(re.findall(r"\w+", text.lower()))

    @staticmethod
    def _jaccard(a: set, b: set) -> float:
        """Return the Jaccard similarity of two sets (0.0 if either is empty)."""
        if not a or not b:
            return 0.0
        return len(a & b) / len(a | b)

    def _top_ranked(self, passages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return the best ``top_n`` passages by ``rerank_score``, descending.

        Works on a sorted copy so the caller's list keeps its order, and
        clamps ``top_n`` at zero so a negative value yields an empty list
        rather than a tail-trimmed slice.
        """
        ranked = sorted(
            passages,
            key=lambda p: p.get("rerank_score", 0),
            reverse=True,
        )
        return ranked[: max(self.config.top_n, 0)]

    async def rerank(
        self,
        query: str,
        passages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Re-rank passages by relevance to query.

        Args:
            query: Original search query.
            passages: List of dicts with 'text' and 'score' keys.

        Returns:
            Re-ranked list (at most ``config.top_n`` items) whose dicts carry
            'rerank_score' and 'rerank_method' keys.
        """
        if not passages:
            return []

        method = self.config.method
        if method == "cross_encoder":
            return await self._cross_encode_rerank(query, passages)
        if method == "mmr":
            return self._mmr_rerank(query, passages)
        if method == "llm":
            return await self._llm_rerank(query, passages)
        # "diversity" and any unrecognised method: token-overlap penalty.
        return self._diversity_rerank(passages)

    async def _cross_encode_rerank(
        self,
        query: str,
        passages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Use cross-encoder model for relevance scoring.

        Falls back to the term-overlap heuristic when sentence-transformers
        is not installed.
        """
        try:
            from sentence_transformers import CrossEncoder
        except ImportError:
            # Fallback: use simple heuristic based on term overlap
            return self._fallback_rerank(query, passages)

        if self._cross_encoder is None:
            self._cross_encoder = CrossEncoder(self.config.model)

        pairs = [(query, p["text"]) for p in passages]
        # NOTE: predict() is called synchronously inside this coroutine.
        scores = self._cross_encoder.predict(
            pairs, batch_size=self.config.batch_size
        )

        for passage, score in zip(passages, scores):
            passage["rerank_score"] = float(score)
            passage["rerank_method"] = "cross_encoder"

        return self._top_ranked(passages)

    def _mmr_rerank(
        self,
        query: str,
        passages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Maximal Marginal Relevance: balance relevance with diversity.

        Without embeddings, uses Jaccard similarity on token sets as proxy.
        Relevance is the query/passage Jaccard similarity, or the passage's
        original 'score' when the query has no tokens. The stored
        'rerank_score' is that relevance value, not the combined MMR score.
        """
        if not passages:
            return []

        token_sets = [self._tokenize(p["text"]) for p in passages]
        query_tokens = self._tokenize(query) if query else set()

        # Initial relevance scores (original scores or query overlap)
        relevance = [
            self._jaccard(query_tokens, tokens)
            if query_tokens
            else passage.get("score", 0.0)
            for passage, tokens in zip(passages, token_sets)
        ]

        weight = self.config.diversity_lambda
        selected: List[int] = []
        remaining = list(range(len(passages)))

        while remaining and len(selected) < self.config.top_n:
            best_idx = None
            best_score = -float("inf")

            for idx in remaining:
                # Diversity is the distance to the closest already-selected
                # passage; the first pick has nothing to be compared with.
                if selected:
                    diversity = min(
                        1.0 - self._jaccard(token_sets[idx], token_sets[s])
                        for s in selected
                    )
                else:
                    diversity = 1.0

                score = weight * relevance[idx] + (1 - weight) * diversity

                # Strict '>' keeps the earliest candidate on ties.
                if score > best_score:
                    best_score = score
                    best_idx = idx

            if best_idx is None:
                break
            selected.append(best_idx)
            remaining.remove(best_idx)

        result = [passages[i] for i in selected]
        for passage, idx in zip(result, selected):
            passage["rerank_score"] = relevance[idx]
            passage["rerank_method"] = "mmr"

        return result

    async def _llm_rerank(
        self,
        query: str,
        passages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """LLM-based reranking: ask an LLM to score passage relevance.

        Falls back to fallback heuristic if no LLM is configured.
        """
        # This is a framework hook — the actual LLM call is done by the caller
        # by injecting an llm_call function or using the default heuristic
        return self._fallback_rerank(query, passages)

    def _diversity_rerank(
        self,
        passages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Simple diversity reranking: penalize token overlap with earlier passages.

        Each passage's original 'score' (0.5 when absent) is scaled down by
        0.1 per unit of Jaccard overlap with every passage that precedes it
        in the input, with the total penalty capped at 0.5.
        """
        token_sets = [self._tokenize(p["text"]) for p in passages]

        for i, passage in enumerate(passages):
            penalty = sum(
                self._jaccard(token_sets[i], token_sets[j]) * 0.1
                for j in range(i)
            )
            passage["rerank_score"] = passage.get("score", 0.5) * (
                1.0 - min(penalty, 0.5)
            )
            passage["rerank_method"] = "diversity"

        return self._top_ranked(passages)

    def _fallback_rerank(
        self,
        query: str,
        passages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Fallback: simple term-overlap heuristic rerank.

        Blends the original 'score' with the fraction of query tokens found
        in the passage (50/50). When either side has no tokens the original
        score is used unchanged.
        """
        query_tokens = self._tokenize(query) if query else set()

        for passage in passages:
            text_tokens = self._tokenize(passage["text"])
            base = passage.get("score", 0.0)
            if query_tokens and text_tokens:
                overlap = len(query_tokens & text_tokens) / len(query_tokens)
                passage["rerank_score"] = base * 0.5 + overlap * 0.5
            else:
                passage["rerank_score"] = base
            passage["rerank_method"] = "fallback_heuristic"

        return self._top_ranked(passages)
class DiversityRanker:
    """Diversity-focused reranker for varied search results.

    Greedy selection: the first passage is always kept, then each further
    pick maximises ``lambda * score - (1 - lambda) * max_similarity`` where
    similarity is the Jaccard overlap of token sets with the passages
    already chosen.
    """

    def __init__(self, lambda_param: float = 0.6):
        """Store the relevance weight; ``1 - lambda_param`` weighs similarity."""
        self.lambda_param = lambda_param

    def rerank(
        self,
        passages: List[Dict[str, Any]],
        top_n: int = 5,
    ) -> List[Dict[str, Any]]:
        """Maximize result diversity while keeping relevance high.

        Args:
            passages: Dicts with a 'text' key and an optional 'score' key
                (0.5 is assumed when 'score' is absent). The first entry
                seeds the selection.
            top_n: Maximum number of passages to return. Zero or a negative
                value yields an empty list.

        Returns:
            A new list of at most ``top_n`` passages in selection order.
        """
        # Guard top_n here: the selection is seeded with index 0, so without
        # this check a request for zero results would still return one.
        if not passages or top_n <= 0:
            return []

        import re

        token_sets = [set(re.findall(r"\w+", p["text"].lower())) for p in passages]

        def sim(a: set, b: set) -> float:
            """Jaccard similarity; 0.0 when either set is empty."""
            if not a or not b:
                return 0.0
            return len(a & b) / len(a | b)

        selected = [0]
        # An ordered list (not a set) makes tie-breaking deterministic:
        # among equal scores the lowest input index wins.
        remaining = list(range(1, len(passages)))
        limit = min(top_n, len(passages))

        while remaining and len(selected) < limit:
            best_idx = -1
            best_score = -float("inf")
            for idx in remaining:
                max_sim = max(sim(token_sets[idx], token_sets[s]) for s in selected)
                score = (
                    self.lambda_param * passages[idx].get("score", 0.5)
                    - (1 - self.lambda_param) * max_sim
                )
                if score > best_score:
                    best_score = score
                    best_idx = idx
            if best_idx < 0:
                break
            selected.append(best_idx)
            remaining.remove(best_idx)

        return [passages[i] for i in selected]