Coverage for agentos/rag/hybrid_search.py: 24%

271 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Hybrid Search + Re-Ranking for RAG (v1.9.0) 

3 

4Production-grade hybrid search combining: 

5 - Dense (semantic) retrieval via embeddings 

6 - Sparse (keyword) retrieval via BM25 

7 - Cross-encoder re-ranking for precision 

8 - Citation tracking with source provenance 

9 - Multi-modal: text, code, markdown, tables 

10 - Fusion algorithms: RRF, weighted sum, cascade 

11 

12Compatible with existing ChromaStore + RAGPipeline. 

13""" 

14 

15from __future__ import annotations 

16 

17import hashlib 

18import json 

19import math 

20import re 

21import time 

22from collections import Counter, defaultdict 

23from dataclasses import dataclass, field 

24from pathlib import Path 

25from typing import Any, Optional, Callable 

26 

27 

28# ── Types ─────────────────────────────────────────────────────────── 

29 

@dataclass
class SearchResult:
    """One retrieved chunk together with the scores each stage gave it.

    Only ``doc_id`` and ``content`` are required; every other field has a
    neutral default so retrievers can fill in just what they know.
    """

    # Identity and payload.
    doc_id: str
    content: str

    # Provenance: file path, URL, or source identifier.
    source: str = ""
    title: str = ""

    # Final score followed by the per-stage component scores.
    score: float = 0.0
    dense_score: float = 0.0
    sparse_score: float = 0.0
    rerank_score: float = 0.0

    chunk_index: int = 0
    metadata: dict[str, Any] = field(default_factory=dict)

    # Specific sentences/quotes attributed to this result.
    citations: list[str] = field(default_factory=list)

44 

45 

@dataclass
class Citation:
    """A quoted span attributed to a piece of source material.

    ``text`` and ``source`` are required; the remaining fields locate the
    span and default to neutral values.
    """

    # The quoted text and the identifier of where it came from.
    text: str
    source: str

    # Location of the quote within the source.
    doc_id: str = ""
    chunk_index: int = 0
    start_pos: int = 0
    end_pos: int = 0

    # Confidence attached to the attribution; defaults to 1.0.
    confidence: float = 1.0

56 

57 

58# ── BM25 Sparse Retriever ─────────────────────────────────────────── 

59 

class BM25Retriever:
    """Pure Python BM25 implementation for keyword search.

    No external dependencies. Tokenizes, builds an inverted index,
    and scores documents using Okapi BM25.
    """

    # Compiled once; same pattern the tokenizer has always used.
    _TOKEN_RE = re.compile(r'[\w\u4e00-\u9fff]+')

    def __init__(self, k1: float = 1.5, b: float = 0.75):
        """Create an empty retriever.

        Args:
            k1: Term-frequency saturation parameter of BM25.
            b: Document-length normalization parameter of BM25.
        """
        self.k1 = k1
        self.b = b
        self._docs: list[str] = []
        self._doc_ids: list[str] = []
        self._doc_lengths: list[int] = []
        self._avg_dl: float = 0.0
        # token -> {document position -> term frequency}
        self._inverted_index: dict[str, dict[int, int]] = defaultdict(dict)
        self._idf: dict[str, float] = {}
        self._N: int = 0

    def index(self, documents: list[dict[str, str]]):
        """Build the BM25 index from documents, replacing any previous index.

        Args:
            documents: List of {id, content} dicts. A missing ``id`` becomes
                ``doc_<position>``; a missing ``content`` becomes ``""``.
        """
        self._docs = [doc.get("content", "") for doc in documents]
        self._doc_ids = [doc.get("id", f"doc_{i}") for i, doc in enumerate(documents)]
        self._N = len(self._docs)

        self._inverted_index.clear()
        doc_lengths: list[int] = []
        doc_freq: dict[str, int] = defaultdict(int)

        # Single pass: each document is tokenized once and feeds both the
        # length table and the postings.
        for position, text in enumerate(self._docs):
            tokens = self._tokenize(text)
            doc_lengths.append(len(tokens))
            for token, count in Counter(tokens).items():
                self._inverted_index[token][position] = count
                doc_freq[token] += 1

        self._doc_lengths = doc_lengths
        self._avg_dl = sum(doc_lengths) / max(self._N, 1)

        # The "1 +" inside the log keeps every IDF value positive.
        self._idf = {
            token: math.log(1 + (self._N - freq + 0.5) / (freq + 0.5))
            for token, freq in doc_freq.items()
        }

    def search(self, query: str, top_k: int = 10) -> list[SearchResult]:
        """Run a BM25 keyword search.

        Args:
            query: Free-text query; tokenized the same way as documents.
            top_k: Maximum number of results. Values <= 0 yield no results.

        Returns:
            Up to ``top_k`` results with a positive BM25 score, best first.
            Scores are divided by the best score, so the top hit has 1.0.
            Content is truncated to 500 characters.
        """
        if not self._docs or top_k <= 0:
            return []

        scores: list[float] = [0.0] * self._N
        avg_dl = max(self._avg_dl, 1)

        for token in self._tokenize(query):
            # .get() avoids creating empty entries in the defaultdict.
            postings = self._inverted_index.get(token)
            if not postings:
                continue
            idf = self._idf.get(token, 0)
            for position, tf in postings.items():
                dl = self._doc_lengths[position]
                numerator = tf * (self.k1 + 1)
                denominator = tf + self.k1 * (1 - self.b + self.b * dl / avg_dl)
                scores[position] += idf * numerator / max(denominator, 1e-9)

        # Only matching documents can be returned, so sort just those.
        ranked = sorted(
            ((position, score) for position, score in enumerate(scores) if score > 0),
            key=lambda pair: pair[1],
            reverse=True,
        )[:top_k]
        if not ranked:
            return []

        max_score = max(ranked[0][1], 1e-9)
        return [
            SearchResult(
                doc_id=self._doc_ids[position],
                content=self._docs[position][:500],
                sparse_score=score / max_score,
                score=score / max_score,
                metadata={"method": "bm25"},
            )
            for position, score in ranked
        ]

    def _tokenize(self, text: str) -> list[str]:
        """Lowercase ``text``, extract word-character runs, drop 1-char tokens."""
        return [t for t in self._TOKEN_RE.findall(text.lower()) if len(t) > 1]

144 

145 

146# ── Dense Retriever ───────────────────────────────────────────────── 

147 

class DenseRetriever:
    """Semantic search via embeddings.

    Wraps an embedding function (e.g., OpenAI embeddings, sentence-transformers)
    and a vector store (ChromaDB or similar).
    """

    def __init__(
        self,
        vector_store=None,
        embed_fn: Optional[Callable[[str], list[float]]] = None,
    ):
        """Store the collaborators.

        Args:
            vector_store: Object exposing ``async search(query, top_k=...)``
                that returns a list of dicts. ``None`` disables dense search.
            embed_fn: Embedding callable. It is stored but not called by
                this class.
        """
        self._store = vector_store
        self._embed = embed_fn

    async def search(self, query: str, top_k: int = 10) -> list[SearchResult]:
        """Run a dense vector search through the configured store.

        Args:
            query: Query text, passed to the store unchanged.
            top_k: Number of results requested from the store.

        Returns:
            Results in the order the store returned them, with scores divided
            by the highest score in the batch. Returns ``[]`` when no store is
            configured or when the store call / result conversion fails
            (best effort; the failure is logged).
        """
        # Identity check: a store that defines __len__ and is currently empty
        # is falsy but still a valid store.
        if self._store is None:
            return []

        try:
            results = await self._store.search(query, top_k=top_k)
            if not results:
                return []

            # Normalize by the true maximum rather than the first element so
            # an unsorted result list cannot produce scores above 1.
            max_score = max(
                max(result.get("score", 0) for result in results), 1e-9
            )

            return [
                SearchResult(
                    doc_id=result.get("id", ""),
                    content=(result.get("content") or "")[:500],
                    dense_score=result.get("score", 0) / max_score,
                    score=result.get("score", 0) / max_score,
                    metadata=result.get("metadata") or {},
                )
                for result in results
            ]
        except Exception:
            # Dense retrieval is best effort, so callers still get sparse
            # results. Log instead of failing silently.
            import logging

            logging.getLogger(__name__).warning(
                "Dense retrieval failed; returning no dense results",
                exc_info=True,
            )
            return []

185 

186 

187# ── Cross-Encoder Re-Ranker ───────────────────────────────────────── 

188 

class CrossEncoderReranker:
    """Re-rank search results with a cross-encoder model.

    Instead of embedding query and documents independently (bi-encoder),
    a cross-encoder processes (query, document) pairs together for higher
    accuracy — at the cost of more computation.

    Supports:
    - HuggingFace cross-encoder models (e.g., ms-marco-MiniLM)
    - LLM-based re-ranking (use an LLM to judge relevance)
    """

    # Only this many candidates are listed in the LLM prompt.
    _LLM_MAX_CANDIDATES = 20
    # Rerank score given to candidates the LLM did not rate.
    _LLM_DEFAULT_SCORE = 0.5
    # Matches lines such as "[3] 7.5" in the LLM response.
    _LLM_SCORE_RE = re.compile(r'\[(\d+)\]\s*(\d+(?:\.\d+)?)')

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        use_llm: bool = False,
        llm_client=None,
    ):
        """Configure the re-ranker.

        Args:
            model_name: Name passed to ``sentence_transformers.CrossEncoder``;
                the model is created lazily on first use.
            use_llm: Use ``llm_client`` instead of the cross-encoder.
            llm_client: Object exposing ``async complete(prompt)`` that
                returns the response text.
        """
        self._model_name = model_name
        self._model = None
        self._use_llm = use_llm
        self._llm = llm_client

    async def rerank(
        self,
        query: str,
        candidates: list[SearchResult],
        top_k: int = 5,
    ) -> list[SearchResult]:
        """Re-rank candidates by relevance to query.

        Args:
            query: Original search query
            candidates: Initial retrieval results
            top_k: Number of results to return after re-ranking

        Returns:
            Re-ranked candidates with updated rerank_score. The candidate
            objects are updated in place; the input list is not reordered.
        """
        if not candidates:
            return []

        if self._use_llm and self._llm:
            return await self._llm_rerank(query, candidates, top_k)
        return await self._cross_encoder_rerank(query, candidates, top_k)

    async def _cross_encoder_rerank(
        self,
        query: str,
        candidates: list[SearchResult],
        top_k: int,
    ) -> list[SearchResult]:
        """Re-rank using a HuggingFace cross-encoder.

        Falls back to the first ``top_k`` candidates unchanged when
        ``sentence_transformers`` cannot be imported.
        """
        try:
            from sentence_transformers import CrossEncoder

            if self._model is None:
                self._model = CrossEncoder(self._model_name)

            pairs = [(query, c.content[:1000]) for c in candidates]
            scores = self._model.predict(pairs)

            for candidate, score in zip(candidates, scores):
                relevance = float(score)
                candidate.rerank_score = relevance
                # Weighted fusion of the retrieval and re-rank scores.
                candidate.score = (
                    candidate.dense_score * 0.3 +
                    candidate.sparse_score * 0.2 +
                    relevance * 0.5
                )

            # sorted() leaves the caller's list order untouched.
            ranked = sorted(candidates, key=lambda c: c.rerank_score, reverse=True)
            return ranked[:top_k]

        except ImportError:
            return candidates[:top_k]  # Fallback: no re-ranking

    async def _llm_rerank(
        self,
        query: str,
        candidates: list[SearchResult],
        top_k: int,
    ) -> list[SearchResult]:
        """Re-rank using LLM relevance judgment.

        The first ``_LLM_MAX_CANDIDATES`` candidates are shown to the LLM with
        300 characters of content each. Ratings on the 0-10 scale are mapped
        to [0, 1]; candidates without a usable rating get
        ``_LLM_DEFAULT_SCORE``. Any failure falls back to the first ``top_k``
        candidates in their incoming order.
        """
        if not self._llm:
            return candidates[:top_k]

        shown = candidates[:self._LLM_MAX_CANDIDATES]
        listing = "".join(
            f"[{i}] {c.content[:300]}\n\n" for i, c in enumerate(shown)
        )
        prompt = (
            f"Query: {query}\n\nRate each document's relevance on a scale of 0-10:\n\n"
            + listing
            + "Output format: [doc_id] score"
        )

        try:
            response = await self._llm.complete(prompt)

            scores: dict[int, float] = {}
            for line in response.split("\n"):
                match = self._LLM_SCORE_RE.match(line.strip())
                if not match:
                    continue
                idx = int(match.group(1))
                # Ignore indices for documents the LLM was never shown.
                if idx < len(shown):
                    # Clamp so an out-of-scale rating cannot exceed 1.0.
                    scores[idx] = min(float(match.group(2)) / 10.0, 1.0)

            for i, candidate in enumerate(candidates):
                candidate.rerank_score = scores.get(i, self._LLM_DEFAULT_SCORE)
                candidate.score = (
                    candidate.dense_score * 0.25 +
                    candidate.sparse_score * 0.15 +
                    candidate.rerank_score * 0.6
                )

            ranked = sorted(candidates, key=lambda c: c.rerank_score, reverse=True)
            return ranked[:top_k]

        except Exception:
            # Best effort: keep the incoming order, but log the failure.
            import logging

            logging.getLogger(__name__).warning(
                "LLM re-ranking failed; returning candidates without re-ranking",
                exc_info=True,
            )
            return candidates[:top_k]

307 

308 

309# ── Fusion Algorithms ─────────────────────────────────────────────── 

310 

class FusionMethod:
    """Collection of rank fusion algorithms.

    Each method merges a dense and a sparse result list into one list with
    one entry per ``doc_id``. When a document occurs in both lists the dense
    result object is kept, and the sparse hit's ``sparse_score`` is copied
    onto it so later stages can still see both component scores.
    """

    @staticmethod
    def _carry_sparse_score(kept: SearchResult, sparse_hit: SearchResult) -> None:
        """Copy the sparse component score onto the result object that is kept."""
        if kept is not sparse_hit:
            # max() so a repeated, lower-scored hit never downgrades the value.
            kept.sparse_score = max(kept.sparse_score, sparse_hit.sparse_score)

    @staticmethod
    def reciprocal_rank_fusion(
        dense_results: list[SearchResult],
        sparse_results: list[SearchResult],
        k: int = 60,
    ) -> list[SearchResult]:
        """RRF: Reciprocal Rank Fusion.

        RRF_score(d) = sum_{ranker} 1 / (k + rank(d)), with ranks starting at 1.

        Args:
            dense_results: Dense ranking, best first.
            sparse_results: Sparse ranking, best first.
            k: Smoothing constant of the RRF formula.

        Returns:
            One result per doc_id, sorted by fused score (descending). Each
            returned object's ``score`` is overwritten with its RRF score.
        """
        scores: dict[str, float] = {}
        docs: dict[str, SearchResult] = {}

        for rank, result in enumerate(dense_results, start=1):
            scores[result.doc_id] = 1.0 / (k + rank)
            docs[result.doc_id] = result

        for rank, result in enumerate(sparse_results, start=1):
            contribution = 1.0 / (k + rank)
            if result.doc_id in scores:
                scores[result.doc_id] += contribution
                FusionMethod._carry_sparse_score(docs[result.doc_id], result)
            else:
                scores[result.doc_id] = contribution
                docs[result.doc_id] = result

        results = []
        for doc_id, score in sorted(scores.items(), key=lambda item: item[1], reverse=True):
            doc = docs[doc_id]
            doc.score = score
            results.append(doc)

        return results

    @staticmethod
    def weighted_sum(
        dense_results: list[SearchResult],
        sparse_results: list[SearchResult],
        dense_weight: float = 0.6,
        sparse_weight: float = 0.4,
    ) -> list[SearchResult]:
        """Weighted score summation.

        Args:
            dense_results: Results whose ``dense_score`` is used.
            sparse_results: Results whose ``sparse_score`` is used.
            dense_weight: Multiplier applied to dense scores.
            sparse_weight: Multiplier applied to sparse scores.

        Returns:
            One result per doc_id, sorted by the weighted sum (descending).
            Each returned object's ``score`` is overwritten with that sum.
        """
        scores: dict[str, list[float]] = defaultdict(list)
        docs: dict[str, SearchResult] = {}

        for result in dense_results:
            scores[result.doc_id].append(result.dense_score * dense_weight)
            docs[result.doc_id] = result

        for result in sparse_results:
            scores[result.doc_id].append(result.sparse_score * sparse_weight)
            if result.doc_id in docs:
                FusionMethod._carry_sparse_score(docs[result.doc_id], result)
            else:
                docs[result.doc_id] = result

        fused = []
        for doc_id, wscores in scores.items():
            doc = docs[doc_id]
            doc.score = sum(wscores)
            fused.append(doc)

        fused.sort(key=lambda x: x.score, reverse=True)
        return fused

    @staticmethod
    def cascade(
        dense_results: list[SearchResult],
        sparse_results: list[SearchResult],
    ) -> list[SearchResult]:
        """Cascade: dense first, then sparse fills gaps.

        Returns:
            All dense results in their given order, followed by sparse
            results whose doc_id was not already seen. ``score`` values are
            left as they were.
        """
        seen: set[str] = set()
        first_dense: dict[str, SearchResult] = {}
        results: list[SearchResult] = []

        for r in dense_results:
            results.append(r)
            seen.add(r.doc_id)
            first_dense.setdefault(r.doc_id, r)

        for r in sparse_results:
            if r.doc_id not in seen:
                results.append(r)
                seen.add(r.doc_id)
            elif r.doc_id in first_dense:
                FusionMethod._carry_sparse_score(first_dense[r.doc_id], r)

        return results

395 

396 

397# ── Citation Tracker ──────────────────────────────────────────────── 

398 

class CitationTracker:
    """Track and verify citations from source documents.

    Key features:
    - Extract citations from generated text
    - Verify against source documents
    - Report the share of sentences without a citation
    - Track citation usage statistics
    """

    # Sentences shorter than this (after stripping) are never considered.
    _MIN_SENTENCE_CHARS = 15
    _SENTENCE_SPLIT_RE = re.compile(r'[.!?]+')

    def __init__(self):
        self._citations: list[Citation] = []
        self._source_index: dict[str, dict] = {}  # doc_id → content + metadata

    def add_source(self, doc_id: str, content: str, metadata: dict[str, Any] | None = None):
        """Register a source document, replacing any entry with the same id."""
        self._source_index[doc_id] = {
            "content": content,
            "metadata": metadata or {},
        }

    @classmethod
    def _candidate_sentences(cls, text: str) -> list[str]:
        """Split ``text`` on ., ! and ? and keep stripped sentences long enough to cite."""
        pieces = (piece.strip() for piece in cls._SENTENCE_SPLIT_RE.split(text))
        return [piece for piece in pieces if len(piece) >= cls._MIN_SENTENCE_CHARS]

    def extract_citations(self, text: str, sources: list[SearchResult]) -> list[Citation]:
        """Extract and verify citations from generated text.

        Args:
            text: Generated response text
            sources: Source documents used for generation

        Returns:
            Citation objects for sentences attributed to a source,
            de-duplicated on the first 50 characters of the sentence (the
            first matching source wins). The returned citations are also
            appended to the tracker's running history.
        """
        # Sentence splitting does not depend on the source, so do it once.
        sentences = [(sent, sent.lower()) for sent in self._candidate_sentences(text)]
        citations: list[Citation] = []

        for source in sources:
            source_content = source.content.lower()
            source_words = set(source_content.split())

            for sent, sent_lower in sentences:
                # Exact containment first, then word-overlap matching.
                if self._is_from_source(sent_lower, source_content, source_words=source_words):
                    citations.append(Citation(
                        text=sent,
                        source=source.source or source.doc_id,
                        doc_id=source.doc_id,
                        chunk_index=source.chunk_index,
                        confidence=0.9,
                    ))

        # Deduplicate
        seen: set[str] = set()
        unique = []
        for c in citations:
            key = c.text[:50]
            if key not in seen:
                seen.add(key)
                unique.append(c)

        self._citations.extend(unique)
        return unique

    def verify(self, text: str, sources: list[SearchResult]) -> dict[str, Any]:
        """Verify all claims in text against source documents.

        Only sentences long enough to be cited are counted, so empty
        fragments left by splitting and very short sentences do not inflate
        the risk. Text with no countable sentence reports a risk of 0.

        Returns:
            Dict with sentence counts, ``hallucination_risk`` (uncited share,
            rounded to 3 places), up to 10 citations and a ``status`` of
            "clean" (< 0.3), "medium_risk" (< 0.6) or "high_risk".
        """
        citations = self.extract_citations(text, sources)

        sentences = self._candidate_sentences(text)
        total_sentences = len(sentences)
        # A sentence counts as cited when it contains the first 30 characters
        # of any citation.
        prefixes = {c.text[:30].lower() for c in citations}
        cited_sentences = sum(
            1 for s in sentences
            if any(prefix in s.lower() for prefix in prefixes)
        )

        uncited = total_sentences - cited_sentences
        hallucination_risk = uncited / max(total_sentences, 1)

        if hallucination_risk < 0.3:
            status = "clean"
        elif hallucination_risk < 0.6:
            status = "medium_risk"
        else:
            status = "high_risk"

        return {
            "total_sentences": total_sentences,
            "cited_sentences": cited_sentences,
            "uncited_sentences": uncited,
            "hallucination_risk": round(hallucination_risk, 3),
            "citations": [
                {"text": c.text[:100], "source": c.source, "confidence": c.confidence}
                for c in citations[:10]
            ],
            "status": status,
        }

    def _is_from_source(
        self,
        text: str,
        source: str,
        threshold: float = 0.6,
        source_words: set[str] | None = None,
    ) -> bool:
        """Check if text originated from source using substring and word overlap.

        Args:
            text: Candidate sentence (callers pass it lowercased).
            source: Source content (callers pass it lowercased).
            threshold: Minimum share of the sentence's distinct words that
                must occur in the source.
            source_words: Optional precomputed ``set(source.split())`` so a
                caller checking many sentences builds it only once.
        """
        if text in source:
            return True

        text_words = set(text.split())
        if not text_words:
            return False

        if source_words is None:
            source_words = set(source.split())

        overlap = len(text_words & source_words) / len(text_words)
        return overlap >= threshold

    def get_stats(self) -> dict[str, Any]:
        """Get citation statistics accumulated over all extraction calls."""
        return {
            "total_citations": len(self._citations),
            "by_source": Counter(c.source for c in self._citations),
            "avg_confidence": (
                sum(c.confidence for c in self._citations) / len(self._citations)
                if self._citations else 0
            ),
            "sources_indexed": len(self._source_index),
        }

520 

521 

522# ── Hybrid Search Engine ──────────────────────────────────────────── 

523 

class HybridSearchEngine:
    """Unified hybrid search engine.

    Combines dense + sparse retrieval with fusion and re-ranking.

    Usage:
        engine = HybridSearchEngine(
            dense_retriever=DenseRetriever(vector_store=chroma_store),
            sparse_retriever=BM25Retriever(),
        )

        # Index documents
        engine.index_sparse(documents)

        # Hybrid search
        results = await engine.search("How to implement retry logic?")
        for r in results:
            print(f"{r.score:.3f} | {r.content[:100]}")
    """

    def __init__(
        self,
        dense_retriever: Optional[DenseRetriever] = None,
        sparse_retriever: Optional[BM25Retriever] = None,
        reranker: Optional[CrossEncoderReranker] = None,
        citation_tracker: Optional[CitationTracker] = None,
        fusion_method: str = "rrf",
        dense_weight: float = 0.6,
    ):
        """Wire up the pipeline; any omitted component gets a default instance.

        Args:
            dense_retriever: Dense retriever (default: one with no store).
            sparse_retriever: BM25 retriever (default: an empty one).
            reranker: Re-ranker (default: cross-encoder re-ranker).
            citation_tracker: Citation tracker (default: a fresh one).
            fusion_method: "rrf" or "cascade"; any other value selects the
                weighted sum.
            dense_weight: Dense weight for the weighted sum; the sparse
                weight is ``1.0 - dense_weight``.
        """
        self.dense = dense_retriever or DenseRetriever()
        self.sparse = sparse_retriever or BM25Retriever()
        self.reranker = reranker or CrossEncoderReranker()
        self.citations = citation_tracker or CitationTracker()

        self.fusion_method = fusion_method
        self.dense_weight = dense_weight

    def index_sparse(self, documents: list[dict[str, str]]):
        """Build the sparse index and register each document as a citation source.

        Documents without an ``id`` are registered as ``doc_<position>``,
        the same fallback the sparse index applies, so search results and
        citation sources share ids.
        """
        self.sparse.index(documents)
        for position, doc in enumerate(documents):
            self.citations.add_source(
                doc_id=doc.get("id", f"doc_{position}"),
                content=doc.get("content", ""),
                metadata=doc.get("metadata"),
            )

    async def search(
        self,
        query: str,
        top_k: int = 10,
        rerank: bool = True,
        return_citations: bool = False,
    ) -> list[SearchResult]:
        """Hybrid search: dense + sparse → fusion → rerank.

        Args:
            query: Search query
            top_k: Number of results
            rerank: Whether to apply re-ranking
            return_citations: Whether to attach citation info

        Returns:
            Ranked SearchResults.
        """
        # Step 1: Retrieval. Each retriever is asked for twice top_k so the
        # fusion step has spare candidates. The two calls run one after the
        # other: dense is awaited, then sparse runs synchronously.
        candidate_count = top_k * 2
        dense_results = await self.dense.search(query, top_k=candidate_count)
        sparse_results = self.sparse.search(query, top_k=candidate_count)

        # Step 2: Fusion
        if self.fusion_method == "rrf":
            fused = FusionMethod.reciprocal_rank_fusion(dense_results, sparse_results)
        elif self.fusion_method == "cascade":
            fused = FusionMethod.cascade(dense_results, sparse_results)
        else:  # weighted_sum
            fused = FusionMethod.weighted_sum(
                dense_results, sparse_results,
                dense_weight=self.dense_weight,
                sparse_weight=1.0 - self.dense_weight,
            )

        # Step 3: Re-rank (optional). The re-ranker is only invoked when there
        # are more fused candidates than requested; otherwise the fused order
        # is returned as is.
        if rerank and len(fused) > top_k:
            fused = await self.reranker.rerank(query, fused, top_k=top_k)
        else:
            fused = fused[:top_k]

        # Step 4: Attach citations (optional). Each result's own content is
        # matched against itself.
        if return_citations:
            for result in fused:
                result.citations = [
                    c.text for c in self.citations.extract_citations(
                        result.content, [result]
                    )
                ]

        return fused

    async def search_with_citations(
        self,
        query: str,
        top_k: int = 10,
    ) -> dict[str, Any]:
        """Search and return both results and verified citations.

        Returns:
            Dict with ``results``, the ``verification`` report for the
            combined result text, ``top_result`` (or ``None``) and the
            tracker's ``citation_stats``.
        """
        results = await self.search(query, top_k=top_k, return_citations=True)

        # Build combined text from top results
        combined = "\n\n".join(r.content for r in results)

        # Verify citations
        verification = self.citations.verify(combined, results)

        return {
            "results": results,
            "verification": verification,
            "top_result": results[0] if results else None,
            "citation_stats": self.citations.get_stats(),
        }

    def get_stats(self) -> dict[str, Any]:
        """Get search engine statistics (index size, vocabulary, citations)."""
        return {
            "bm25_documents": self.sparse._N if self.sparse else 0,
            "bm25_vocabulary": len(self.sparse._idf) if self.sparse else 0,
            "citation_stats": self.citations.get_stats(),
            "fusion_method": self.fusion_method,
        }