Coverage for agentos/rag/hybrid.py: 0%
100 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""Hybrid search (dense + sparse) for RAG pipeline.
3Combines dense vector search with BM25 sparse retrieval
4using reciprocal rank fusion (RRF) or weighted score fusion.
5"""
7from __future__ import annotations
9import math
10from dataclasses import dataclass, field
11from typing import Any, Dict, List, Optional, Tuple
@dataclass
class HybridConfig:
    """Tuning knobs for hybrid (dense + BM25) retrieval.

    ``fusion_method`` selects how the two ranked lists are combined:
    ``"weighted"`` blends the raw dense score with a squashed BM25 score
    using ``dense_weight`` / ``sparse_weight``; ``"rrf"`` uses reciprocal
    rank fusion with constant ``rrf_k``.
    """

    # -- weighted fusion ---------------------------------------------------
    dense_weight: float = field(default=0.6)      # multiplier on dense scores
    sparse_weight: float = field(default=0.4)     # multiplier on BM25 scores
    fusion_method: str = field(default="weighted")  # "weighted" | "rrf"

    # -- reciprocal rank fusion -------------------------------------------
    rrf_k: int = field(default=60)                # damping constant added to each rank

    # -- BM25 parameters ----------------------------------------------------
    bm25_k1: float = field(default=1.5)           # term-frequency saturation
    bm25_b: float = field(default=0.75)           # document-length normalization

    # -- candidate pool -----------------------------------------------------
    top_k_per_source: int = field(default=20)     # hits requested from each retriever before fusion
class BM25Retriever:
    """BM25 sparse retrieval with Okapi BM25 scoring.

    Works with pre-chunked documents and builds an in-memory index made of
    per-document term frequencies plus corpus-wide document frequencies.

    Args:
        k1: Term-frequency saturation parameter.
        b: Document-length normalization strength (0 = none, 1 = full).
        stop_words: Tokens dropped during tokenization. ``None`` selects the
            module's default English stop-word set; an explicit empty list
            disables stop-word filtering.
    """

    def __init__(
        self,
        k1: float = 1.5,
        b: float = 0.75,
        stop_words: Optional[List[str]] = None,
    ):
        self.k1 = k1
        self.b = b
        # Fall back to the defaults only when nothing was passed: an explicit
        # empty list means "no stop words", not "use the defaults".
        self.stop_words = set(_DEFAULT_STOP_WORDS if stop_words is None else stop_words)
        self._docs: List[str] = []
        self._doc_lens: List[int] = []  # token count per document
        self._avgdl: float = 0.0  # mean token count over the corpus
        self._df: Dict[str, int] = {}  # term -> document frequency
        self._term_freqs: List[Dict[str, int]] = []  # per-doc term freqs
        self._built = False

    def _tokenize(self, text: str) -> List[str]:
        """Lower-case *text*, split it into ``\\w+`` runs, and drop stop words
        and single-character tokens."""
        import re
        tokens = re.findall(r'\w+', text.lower())
        return [t for t in tokens if t not in self.stop_words and len(t) > 1]

    def index(self, documents: List[str]):
        """Build the BM25 index from *documents*, replacing any previous index.

        A document's position in *documents* is the ``doc_index`` that
        ``search`` reports for it.
        """
        # Snapshot the corpus so later mutation of the caller's list cannot
        # change N (used in IDF) or misalign indices.
        self._docs = list(documents)
        self._doc_lens = []
        self._df = {}
        self._term_freqs = []

        for doc in self._docs:
            tokens = self._tokenize(doc)
            tf: Dict[str, int] = {}
            for t in tokens:
                tf[t] = tf.get(t, 0) + 1
            self._term_freqs.append(tf)
            # BM25 length normalization must use the same units as the term
            # frequencies (tokens after filtering), not raw characters.
            self._doc_lens.append(len(tokens))
            for t in tf:
                self._df[t] = self._df.get(t, 0) + 1

        self._avgdl = sum(self._doc_lens) / max(len(self._docs), 1)
        self._built = True

    def search(self, query: str, top_k: int = 10) -> List[Tuple[int, float]]:
        """Score the indexed documents against *query* with Okapi BM25.

        Repeated query terms contribute once per occurrence.

        Returns:
            Up to ``top_k`` ``(doc_index, score)`` pairs with a positive score,
            sorted by score descending; ``[]`` if ``index`` was never called.
        """
        if not self._built:
            return []

        query_tokens = self._tokenize(query)
        n_docs = len(self._docs)
        # IDF ("+1 inside the log" variant, always positive) for the query
        # terms that occur in the corpus. Iterating the distinct query terms
        # avoids scanning the whole vocabulary.
        idf_cache = {
            t: math.log(1 + (n_docs - self._df[t] + 0.5) / (self._df[t] + 0.5))
            for t in set(query_tokens)
            if t in self._df
        }

        k1, b, avgdl = self.k1, self.b, self._avgdl
        scores: List[Tuple[int, float]] = []
        for i, tf in enumerate(self._term_freqs):
            matched = [t for t in query_tokens if t in tf]
            if not matched:
                continue
            # A matched doc has >= 1 token, so avgdl > 0 here.
            length_norm = k1 * (1 - b + b * self._doc_lens[i] / avgdl)
            score = 0.0
            for t in matched:
                f = tf[t]
                numerator = f * (k1 + 1)
                score += idf_cache[t] * numerator / (f + length_norm)
            if score > 0:
                scores.append((i, score))

        scores.sort(key=lambda x: x[1], reverse=True)
        return scores[:top_k]
class HybridRetriever:
    """Combined dense + sparse retrieval with score fusion.

    Usage:
        retriever = HybridRetriever(
            dense_fn=your_dense_search_fn,
            bm25=bm25_retriever,
        )
        results = await retriever.search(query="how to train a model", top_k=5)

    Args:
        dense_fn: Async callable awaited as ``dense_fn(query, k)``. It must
            return a list of dicts; the optional keys ``"index"``, ``"text"``,
            ``"score"`` and ``"metadata"`` are read. ``"index"`` is the key
            used to merge dense hits with BM25 hits (BM25 reports positions in
            its indexed corpus) and defaults to the hit's list position.
        bm25: Sparse retriever. When omitted, one is created with
            ``config.bm25_k1`` and ``config.bm25_b``.
        config: Fusion settings; defaults to ``HybridConfig()``.
    """

    def __init__(
        self,
        dense_fn,
        bm25: Optional[BM25Retriever] = None,
        config: Optional[HybridConfig] = None,
    ):
        self.dense_fn = dense_fn
        self.config = config if config is not None else HybridConfig()
        if bm25 is None:
            # Build the default retriever from the config so that
            # bm25_k1 / bm25_b actually take effect.
            bm25 = BM25Retriever(k1=self.config.bm25_k1, b=self.config.bm25_b)
        self.bm25 = bm25

    def index_documents(self, documents: List[str]):
        """Index *documents* for BM25 sparse retrieval."""
        self.bm25.index(documents)

    async def search(
        self,
        query: str,
        top_k: int = 5,
    ) -> List[Dict[str, Any]]:
        """Hybrid search combining dense and sparse scores.

        Both retrievers are asked for ``config.top_k_per_source`` candidates;
        their hits are merged and scored by ``_fuse_scores`` and the ``top_k``
        best are returned.

        Returns:
            List of dicts with keys 'text', 'score', 'dense_score',
            'sparse_score', 'index', 'metadata', 'dense_rank' and
            'sparse_rank', sorted by 'score' descending.
        """
        # Dense retrieval
        dense_results = await self.dense_fn(query, self.config.top_k_per_source)

        # BM25 retrieval
        bm25_pairs = self.bm25.search(query, self.config.top_k_per_source)

        # Merge and fuse scores
        fused = self._fuse_scores(dense_results, bm25_pairs)
        fused.sort(key=lambda x: x["score"], reverse=True)

        return fused[:top_k]

    def _fuse_scores(
        self,
        dense_results: List[Dict[str, Any]],
        bm25_pairs: List[Tuple[int, float]],
    ) -> List[Dict[str, Any]]:
        """Merge dense and sparse hits by document index and score each one.

        Ranks are 1-based; a rank of 0 means "not returned by that retriever".

        With ``fusion_method == "rrf"`` the score is
        ``1/(rrf_k + dense_rank) + 1/(rrf_k + sparse_rank)``, where a missing
        rank counts as ``top_k_per_source + 1``. Any other value uses weighted
        fusion: ``dense_weight * dense_score + sparse_weight *
        _normalize_bm25(sparse_score)``. Dense scores are used as-is.
        NOTE(review): this presumably assumes they are already in a range
        comparable to [0, 1]; confirm against ``dense_fn``.
        """
        cfg = self.config
        # Corpus text lets BM25-only hits carry their document text instead
        # of an empty string.
        corpus = getattr(self.bm25, "_docs", None) or []

        index_map: Dict[int, Dict[str, Any]] = {}
        for i, r in enumerate(dense_results):
            idx = r.get("index", i)
            if idx in index_map:
                # Duplicate document on the dense side: keep the first
                # (best-ranked) occurrence instead of overwriting it.
                continue
            index_map[idx] = {
                "text": r.get("text", ""),
                "dense_score": r.get("score", 0.0),
                "sparse_score": 0.0,
                "index": idx,
                "metadata": r.get("metadata", {}),
                "dense_rank": i + 1,
                "sparse_rank": 0,
            }

        for rank, (idx, bm25_score) in enumerate(bm25_pairs, start=1):
            entry = index_map.get(idx)
            if entry is not None:
                entry["sparse_score"] = bm25_score
                entry["sparse_rank"] = rank
                continue
            in_corpus = isinstance(idx, int) and 0 <= idx < len(corpus)
            index_map[idx] = {
                "text": corpus[idx] if in_corpus else "",
                "dense_score": 0.0,
                "sparse_score": bm25_score,
                "index": idx,
                "metadata": {},
                "dense_rank": 0,
                "sparse_rank": rank,
            }

        # Compute fused score
        missing_rank = cfg.top_k_per_source + 1
        use_rrf = cfg.fusion_method == "rrf"
        for entry in index_map.values():
            if use_rrf:
                dr = entry["dense_rank"] or missing_rank
                sr = entry["sparse_rank"] or missing_rank
                entry["score"] = 1.0 / (cfg.rrf_k + dr) + 1.0 / (cfg.rrf_k + sr)
            else:
                entry["score"] = (
                    cfg.dense_weight * entry["dense_score"]
                    + cfg.sparse_weight * self._normalize_bm25(entry["sparse_score"])
                )

        return list(index_map.values())

    def _normalize_bm25(self, score: float) -> float:
        """Squash a BM25 score into [0, 1] with a shifted sigmoid.

        Computes ``2 / (1 + exp(-score / 3)) - 1`` (equivalently
        ``tanh(score / 6)``); non-positive scores map to 0.
        """
        if score <= 0:
            return 0.0
        return 2.0 / (1.0 + math.exp(-score / 3.0)) - 1.0
208_DEFAULT_STOP_WORDS = {
209 "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for",
210 "of", "with", "by", "from", "is", "are", "was", "were", "be", "been",
211 "being", "have", "has", "had", "do", "does", "did", "will", "would",
212 "could", "should", "may", "might", "can", "shall", "it", "its", "this",
213 "that", "these", "those", "i", "you", "he", "she", "they", "we", "my",
214 "your", "his", "her", "our", "their", "not", "no", "if", "so", "as",
215 "than", "then", "just", "about", "also", "very", "too", "into", "over",
216 "after", "before", "between", "under", "more", "up", "out", "some",
217 "such", "only", "other", "each", "all", "both", "few", "most",
218}