Coverage for agentos/rag/hybrid.py: 0%

100 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1"""Hybrid search (dense + sparse) for RAG pipeline. 

2 

3Combines dense vector search with BM25 sparse retrieval 

4using reciprocal rank fusion (RRF) or weighted score fusion. 

5""" 

6 

7from __future__ import annotations 

8 

9import math 

10from dataclasses import dataclass, field 

11from typing import Any, Dict, List, Optional, Tuple 

12 

13 

@dataclass
class HybridConfig:
    """Tunable settings for hybrid (dense + BM25) retrieval.

    Attributes:
        dense_weight: Weight for dense scores.
        sparse_weight: Weight for BM25 scores.
        fusion_method: Score fusion strategy, ``"weighted"`` or ``"rrf"``.
        rrf_k: Constant added to each rank in reciprocal rank fusion.
        bm25_k1: BM25 term frequency saturation.
        bm25_b: BM25 document length normalization.
        top_k_per_source: Number of candidates requested from each
            retriever before fusion.
    """

    dense_weight: float = 0.6
    sparse_weight: float = 0.4
    fusion_method: str = "weighted"
    rrf_k: int = 60
    bm25_k1: float = 1.5
    bm25_b: float = 0.75
    top_k_per_source: int = 20

24 

25 

class BM25Retriever:
    """BM25 sparse retrieval with Okapi BM25 scoring.

    Works with pre-chunked documents and builds an in-memory index of
    per-document term frequencies plus corpus-wide document frequencies.
    Document length is measured in tokens (after stop-word filtering), the
    same unit the term frequencies are counted in.
    """

    def __init__(
        self,
        k1: float = 1.5,
        b: float = 0.75,
        stop_words: Optional[List[str]] = None,
    ):
        """Create an empty retriever.

        Args:
            k1: BM25 term frequency saturation parameter.
            b: BM25 document length normalization parameter.
            stop_words: Tokens to drop during tokenization. ``None`` selects
                the module's default list; an explicit empty list disables
                stop-word filtering.
        """
        self.k1 = k1
        self.b = b
        # Compare against None (not truthiness) so that `stop_words=[]`
        # means "no stop words" instead of silently using the defaults.
        self.stop_words = set(
            _DEFAULT_STOP_WORDS if stop_words is None else stop_words
        )
        self._docs: List[str] = []
        self._doc_lens: List[int] = []  # token count per document
        self._avgdl: float = 0.0
        self._df: Dict[str, int] = {}  # term -> document frequency
        self._term_freqs: List[Dict[str, int]] = []  # per-doc term freqs
        self._built = False

    def _tokenize(self, text: str) -> List[str]:
        """Lower-case *text* and split it into ``\\w+`` tokens.

        Stop words and single-character tokens are discarded.
        """
        import re

        tokens = re.findall(r'\w+', text.lower())
        return [t for t in tokens if t not in self.stop_words and len(t) > 1]

    def index(self, documents: List[str]):
        """Build the BM25 index from *documents*, replacing any prior index.

        Args:
            documents: Texts to index; a document's position in this
                sequence is the ``doc_index`` reported by :meth:`search`.
        """
        # Snapshot the input so later mutation of the caller's list cannot
        # desynchronise `_docs` from the statistics computed below.
        docs = list(documents)
        doc_lens: List[int] = []
        term_freqs: List[Dict[str, int]] = []
        doc_freq: Dict[str, int] = {}

        for doc in docs:
            tokens = self._tokenize(doc)
            doc_lens.append(len(tokens))
            tf: Dict[str, int] = {}
            for tok in tokens:
                tf[tok] = tf.get(tok, 0) + 1
            term_freqs.append(tf)
            # Each distinct term counts once per document.
            for tok in tf:
                doc_freq[tok] = doc_freq.get(tok, 0) + 1

        self._docs = docs
        self._doc_lens = doc_lens
        self._avgdl = sum(doc_lens) / len(docs) if docs else 0.0
        self._df = doc_freq
        self._term_freqs = term_freqs
        self._built = True

    def search(self, query: str, top_k: int = 10) -> List[Tuple[int, float]]:
        """Score every indexed document against *query*.

        Args:
            query: Free-text query; tokenized the same way as documents.
            top_k: Maximum number of hits to return. Values ``<= 0`` yield
                an empty list.

        Returns:
            ``(doc_index, score)`` pairs sorted by score descending. Only
            documents with a positive score are included. An empty list is
            returned when nothing has been indexed yet.
        """
        if not self._built or top_k <= 0:
            return []

        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []

        # Look up only the query's own terms instead of scanning the whole
        # vocabulary; terms absent from the corpus get no IDF entry.
        n_docs = len(self._docs)
        idf: Dict[str, float] = {}
        for term in set(query_tokens):
            df = self._df.get(term)
            if df:
                idf[term] = math.log(1 + (n_docs - df + 0.5) / (df + 0.5))
        if not idf:
            return []

        # avgdl is 0 only when no document produced any token; in that case
        # no term can match, but guard the division regardless.
        avgdl = self._avgdl or 1.0
        k1, b = self.k1, self.b

        scores: List[Tuple[int, float]] = []
        for doc_idx, tf in enumerate(self._term_freqs):
            # The length-normalisation term is the same for every query term.
            length_norm = k1 * (1 - b + b * self._doc_lens[doc_idx] / avgdl)
            score = 0.0
            # Iterate the token list (not the set): a term repeated in the
            # query contributes once per occurrence.
            for term in query_tokens:
                f = tf.get(term)
                if not f:
                    continue
                score += idf[term] * (f * (k1 + 1)) / (f + length_norm)
            if score > 0:
                scores.append((doc_idx, score))

        scores.sort(key=lambda pair: pair[1], reverse=True)
        return scores[:top_k]

102 

103 

class HybridRetriever:
    """Combined dense + sparse retrieval with score fusion.

    Usage:
        retriever = HybridRetriever(
            dense_fn=your_dense_search_fn,
            bm25=bm25_retriever,
        )
        results = await retriever.search(query="how to train a model", top_k=5)
    """

    def __init__(
        self,
        dense_fn,
        bm25: Optional[BM25Retriever] = None,
        config: Optional[HybridConfig] = None,
    ):
        """Wire up the two retrievers.

        Args:
            dense_fn: Awaitable callable invoked as ``dense_fn(query, k)``.
                Its result is iterated as a ranked list of dicts from which
                the optional keys ``index``, ``text``, ``score`` and
                ``metadata`` are read.
            bm25: Sparse retriever. When omitted, one is created from the
                BM25 parameters in *config*.
            config: Fusion settings; defaults to ``HybridConfig()``.
        """
        self.dense_fn = dense_fn
        self.config = config if config is not None else HybridConfig()
        if bm25 is None:
            # Honour the configured BM25 parameters rather than the
            # retriever's own defaults.
            bm25 = BM25Retriever(k1=self.config.bm25_k1, b=self.config.bm25_b)
        self.bm25 = bm25

    def index_documents(self, documents: List[str]):
        """Index documents for BM25 sparse retrieval."""
        self.bm25.index(documents)

    async def search(
        self,
        query: str,
        top_k: int = 5,
    ) -> List[Dict[str, Any]]:
        """Hybrid search combining dense and sparse scores.

        Args:
            query: Query text passed to both retrievers.
            top_k: Maximum number of fused results to return; values
                ``<= 0`` yield an empty list.

        Returns:
            Result dicts sorted by fused ``score`` descending. Each has the
            keys ``text``, ``score``, ``dense_score``, ``sparse_score``,
            ``index``, ``metadata``, ``dense_rank`` and ``sparse_rank``
            (a rank of ``0`` means that retriever did not return the doc).
        """
        candidates = self.config.top_k_per_source

        # Dense retrieval
        dense_results = await self.dense_fn(query, candidates)

        # BM25 retrieval
        bm25_pairs = self.bm25.search(query, candidates)

        # Normalize and fuse scores
        fused = self._fuse_scores(dense_results, bm25_pairs)
        fused.sort(key=lambda entry: entry["score"], reverse=True)

        # Clamp so a negative top_k cannot slice from the end of the list.
        return fused[:max(top_k, 0)]

    def _fuse_scores(
        self,
        dense_results: List[Dict[str, Any]],
        bm25_pairs: List[Tuple[int, float]],
    ) -> List[Dict[str, Any]]:
        """Fuse dense and sparse scores using the configured method.

        Args:
            dense_results: Dense hits in rank order (best first).
            bm25_pairs: ``(doc_index, bm25_score)`` pairs in rank order.

        Returns:
            One entry per distinct document index, unsorted, each carrying
            its per-source scores and ranks plus the fused ``score``.
        """
        # Build lookup: doc_index -> result
        index_map: Dict[int, Dict[str, Any]] = {}
        for position, hit in enumerate(dense_results):
            # A hit without an "index" key is keyed by its list position.
            idx = hit.get("index", position)
            index_map[idx] = {
                "text": hit.get("text", ""),
                "dense_score": hit.get("score", 0.0),
                "sparse_score": 0.0,
                "index": idx,
                "metadata": hit.get("metadata", {}),
                "dense_rank": position + 1,  # 1-based rank
                "sparse_rank": 0,
            }

        # Texts held by the sparse retriever, used to fill in hits that only
        # BM25 found (otherwise they would be returned with empty text).
        corpus = getattr(self.bm25, "_docs", None) or []

        for position, (idx, bm25_score) in enumerate(bm25_pairs):
            sparse_rank = position + 1
            entry = index_map.get(idx)
            if entry is not None:
                entry["sparse_score"] = bm25_score
                entry["sparse_rank"] = sparse_rank
                continue
            in_corpus = isinstance(idx, int) and 0 <= idx < len(corpus)
            index_map[idx] = {
                "text": corpus[idx] if in_corpus else "",
                "dense_score": 0.0,
                "sparse_score": bm25_score,
                "index": idx,
                "metadata": {},
                "dense_rank": 0,
                "sparse_rank": sparse_rank,
            }

        # Compute fused score
        use_rrf = self.config.fusion_method == "rrf"
        # Rank assigned to a document that one of the sources did not return.
        missing_rank = self.config.top_k_per_source + 1
        for entry in index_map.values():
            if use_rrf:
                dense_rank = entry["dense_rank"] or missing_rank
                sparse_rank = entry["sparse_rank"] or missing_rank
                entry["score"] = (
                    1.0 / (self.config.rrf_k + dense_rank)
                    + 1.0 / (self.config.rrf_k + sparse_rank)
                )
            else:
                # Any method other than "rrf" uses weighted fusion.
                entry["score"] = (
                    self.config.dense_weight * entry["dense_score"]
                    + self.config.sparse_weight
                    * self._normalize_bm25(entry["sparse_score"])
                )

        return list(index_map.values())

    def _normalize_bm25(self, score: float) -> float:
        """Squash a BM25 score into ``[0, 1)`` with a rescaled sigmoid.

        Non-positive scores map to ``0.0``.
        """
        if score <= 0:
            return 0.0
        return 2.0 / (1.0 + math.exp(-score / 3.0)) - 1.0

206 

207 

# Default English stop words, kept as a plain (mutable) set.
_DEFAULT_STOP_WORDS = set(
    (
        "a an the and or but in on at to for "
        "of with by from is are was were be been "
        "being have has had do does did will would "
        "could should may might can shall it its this "
        "that these those i you he she they we my "
        "your his her our their not no if so as "
        "than then just about also very too into over "
        "after before between under more up out some "
        "such only other each all both few most"
    ).split()
)