Coverage for agentos/rag/reranker.py: 0%

143 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1"""Re-ranking for RAG pipeline. 

2 

3Cross-encoder and LLM-based reranking to refine retrieval results. 

4Supports: cross-encoder (sentence-transformers), LLM reranking, 

5and simple heuristic reranking (diversity, freshness). 

6""" 

7 

8from __future__ import annotations 

9 

10from dataclasses import dataclass, field 

11from typing import Any, Dict, List, Optional, Tuple 

12import math 

13 

14 

@dataclass
class RerankConfig:
    """Settings consumed by ``Reranker``.

    Attributes:
        method: Strategy name; one of ``cross_encoder``, ``llm``,
            ``diversity`` or ``mmr``.
        model: Cross-encoder model identifier.
        top_n: How many passages are kept after reranking.
        diversity_lambda: Weight used by the MMR strategy.
        llm_prompt_template: Custom prompt for the LLM reranker.
        batch_size: Batch size handed to the cross-encoder.
    """

    # Strategy selector: cross_encoder | llm | diversity | mmr
    method: str = "cross_encoder"
    model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    # Number of results after reranking
    top_n: int = 5
    # MMR diversity weight
    diversity_lambda: float = 0.5
    # Custom prompt for the LLM reranker
    llm_prompt_template: str = ""
    batch_size: int = 8

24 

25 

class Reranker:
    """Re-rank retrieval results for improved relevance.

    The strategy is chosen by ``config.method``:
    - cross_encoder: sentence-transformers cross-encoder scoring; falls back
      to the term-overlap heuristic when the package cannot be imported.
    - mmr: Maximal Marginal Relevance using Jaccard similarity of token sets.
    - llm: currently delegates to the term-overlap heuristic.
    - any other value (e.g. "diversity"): token-overlap diversity penalty.

    Every strategy writes 'rerank_score' and 'rerank_method' into the passage
    dicts it is given. The caller's list itself is never reordered; a new
    list of at most ``config.top_n`` passages is returned.
    """

    def __init__(self, config: Optional["RerankConfig"] = None):
        self.config = config or RerankConfig()
        self._cross_encoder = None  # created lazily on first cross-encoder use
        self._embed_fn = None  # for MMR diversity; not read anywhere in this class

    @staticmethod
    def _tokenize(text: str) -> set:
        """Return the set of lower-cased ``\\w+`` tokens found in *text*."""
        import re

        return set(re.findall(r"\w+", text.lower()))

    @staticmethod
    def _jaccard(a: set, b: set) -> float:
        """Return the Jaccard similarity of two sets (0.0 if either is empty)."""
        if not a or not b:
            return 0.0
        return len(a & b) / len(a | b)

    def _top_n(self, passages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return a new list of the best ``config.top_n`` passages.

        Passages are ordered by 'rerank_score', highest first. ``sorted`` is
        stable, so ties keep their incoming order. A negative ``top_n`` is
        clamped to 0 so that slicing cannot silently drop only the tail.
        """
        limit = max(self.config.top_n, 0)
        ranked = sorted(
            passages,
            key=lambda p: p.get("rerank_score", 0),
            reverse=True,
        )
        return ranked[:limit]

    async def rerank(
        self,
        query: str,
        passages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Re-rank passages by relevance to query.

        Args:
            query: Original search query.
            passages: List of dicts with 'text' and 'score' keys.

        Returns:
            Re-ranked list (at most ``config.top_n`` items) whose dicts carry
            updated 'rerank_score' and 'rerank_method' keys. An empty input
            yields an empty list.
        """
        if not passages:
            return []

        method = self.config.method
        if method == "cross_encoder":
            return await self._cross_encode_rerank(query, passages)
        if method == "mmr":
            return self._mmr_rerank(query, passages)
        if method == "llm":
            return await self._llm_rerank(query, passages)
        # Any other method name: penalize token overlap with earlier passages.
        return self._diversity_rerank(passages)

    async def _cross_encode_rerank(
        self,
        query: str,
        passages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Use cross-encoder model for relevance scoring.

        Falls back to the term-overlap heuristic when sentence-transformers
        is not importable.
        """
        try:
            from sentence_transformers import CrossEncoder
        except ImportError:
            return self._fallback_rerank(query, passages)

        if self._cross_encoder is None:
            self._cross_encoder = CrossEncoder(self.config.model)

        pairs = [(query, p["text"]) for p in passages]
        # NOTE(review): predict() is invoked synchronously inside this
        # coroutine, so the event loop waits while scoring runs.
        scores = self._cross_encoder.predict(
            pairs, batch_size=self.config.batch_size
        )

        for passage, score in zip(passages, scores):
            passage["rerank_score"] = float(score)
            passage["rerank_method"] = "cross_encoder"

        return self._top_n(passages)

    def _mmr_rerank(
        self,
        query: str,
        passages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Maximal Marginal Relevance: balance relevance with diversity.

        Without embeddings, uses Jaccard similarity on token sets as proxy.
        Relevance is the query/passage Jaccard similarity when the query has
        tokens, otherwise the passage's original 'score'. The stored
        'rerank_score' is that relevance value, not the combined MMR score.
        """
        if not passages:
            return []

        token_sets = [self._tokenize(p["text"]) for p in passages]
        query_tokens = self._tokenize(query) if query else set()

        if query_tokens:
            relevance = [self._jaccard(query_tokens, ts) for ts in token_sets]
        else:
            relevance = [p.get("score", 0.0) for p in passages]

        weight = self.config.diversity_lambda
        limit = max(self.config.top_n, 0)
        selected: List[int] = []
        remaining = list(range(len(passages)))

        while remaining and len(selected) < limit:
            best_idx = None
            best_score = -float("inf")

            for idx in remaining:
                if selected:
                    # Distance to the most similar already-selected passage.
                    diversity = min(
                        1.0 - self._jaccard(token_sets[idx], token_sets[s])
                        for s in selected
                    )
                else:
                    diversity = 1.0

                score = weight * relevance[idx] + (1 - weight) * diversity
                if score > best_score:
                    best_score = score
                    best_idx = idx

            if best_idx is None:
                # No candidate compared greater than -inf (e.g. NaN scores).
                break
            selected.append(best_idx)
            remaining.remove(best_idx)

        result = []
        for idx in selected:
            passage = passages[idx]
            passage["rerank_score"] = relevance[idx]
            passage["rerank_method"] = "mmr"
            result.append(passage)
        return result

    async def _llm_rerank(
        self,
        query: str,
        passages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """LLM-based reranking hook.

        No LLM call is made here; the method currently always delegates to
        the term-overlap fallback heuristic.
        """
        # This is a framework hook — the actual LLM call is done by the caller
        # by injecting an llm_call function or using the default heuristic
        return self._fallback_rerank(query, passages)

    def _diversity_rerank(
        self,
        passages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Simple diversity reranking: penalize overlap with earlier passages.

        Each passage's original 'score' (default 0.5) is scaled down by 0.1
        times its summed Jaccard similarity to every passage that precedes it
        in the input, with the total penalty capped at 0.5.
        """
        token_sets = [self._tokenize(p["text"]) for p in passages]

        for i, passage in enumerate(passages):
            penalty = 0.0
            for j in range(i):
                penalty += self._jaccard(token_sets[i], token_sets[j]) * 0.1
            passage["rerank_score"] = passage.get("score", 0.5) * (
                1.0 - min(penalty, 0.5)
            )
            passage["rerank_method"] = "diversity"

        return self._top_n(passages)

    def _fallback_rerank(
        self,
        query: str,
        passages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Fallback: simple term-overlap heuristic rerank.

        The score is the mean of the original 'score' (default 0.0) and the
        fraction of query tokens present in the passage. When either token
        set is empty the original score is used unchanged.
        """
        query_tokens = self._tokenize(query) if query else set()

        for passage in passages:
            text_tokens = self._tokenize(passage["text"])
            base = passage.get("score", 0.0)
            if query_tokens and text_tokens:
                overlap = len(query_tokens & text_tokens) / len(query_tokens)
                passage["rerank_score"] = base * 0.5 + overlap * 0.5
            else:
                passage["rerank_score"] = base
            passage["rerank_method"] = "fallback_heuristic"

        return self._top_n(passages)

224 

225 

class DiversityRanker:
    """Diversity-focused reranker for varied search results.

    Selection is greedy: each step picks the remaining passage maximizing
    ``lambda_param * score - (1 - lambda_param) * max_similarity``, where
    similarity is the Jaccard overlap of lower-cased word tokens with the
    passages already chosen.
    """

    def __init__(self, lambda_param: float = 0.6):
        # Weight on the passage's own score; (1 - lambda_param) weights the
        # similarity penalty.
        self.lambda_param = lambda_param

    def rerank(
        self,
        passages: List[Dict[str, Any]],
        top_n: int = 5,
    ) -> List[Dict[str, Any]]:
        """Maximize result diversity while keeping relevance high.

        Args:
            passages: Dicts with a 'text' key and an optional 'score' key
                (0.5 is assumed when 'score' is missing).
            top_n: Maximum number of passages to return.

        Returns:
            A new list of at most ``top_n`` passages in selection order.
            Empty when ``passages`` is empty or ``top_n`` is not positive.
        """
        # Without this guard the unconditional seed below would return one
        # passage even when the caller asked for none.
        if not passages or top_n <= 0:
            return []

        import re

        token_sets = [
            set(re.findall(r"\w+", p["text"].lower())) for p in passages
        ]

        def sim(a: set, b: set) -> float:
            if not a or not b:
                return 0.0
            return len(a & b) / len(a | b)

        # The first passage is always kept as the seed.
        # NOTE(review): presumably callers pass passages already ordered by
        # relevance — confirm.
        selected = [0]
        # An ordered list (not a set) keeps tie-breaking deterministic: the
        # lowest index wins among equal scores.
        remaining = list(range(1, len(passages)))
        limit = min(top_n, len(passages))

        while remaining and len(selected) < limit:
            best_idx = None
            best_score = -float("inf")
            for idx in remaining:
                max_sim = max(
                    sim(token_sets[idx], token_sets[s]) for s in selected
                )
                score = (
                    self.lambda_param * passages[idx].get("score", 0.5)
                    - (1 - self.lambda_param) * max_sim
                )
                if score > best_score:
                    best_score = score
                    best_idx = idx
            if best_idx is None:
                # No candidate compared greater than -inf (e.g. NaN scores).
                break
            selected.append(best_idx)
            remaining.remove(best_idx)

        return [passages[i] for i in selected]