Coverage for agentos/rag/hybrid_search.py: 24%
271 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2Hybrid Search + Re-Ranking for RAG (v1.9.0)
4Production-grade hybrid search combining:
5 - Dense (semantic) retrieval via embeddings
6 - Sparse (keyword) retrieval via BM25
7 - Cross-encoder re-ranking for precision
8 - Citation tracking with source provenance
9 - Multi-modal: text, code, markdown, tables
10 - Fusion algorithms: RRF, weighted sum, cascade
12Compatible with existing ChromaStore + RAGPipeline.
13"""
15from __future__ import annotations
17import hashlib
18import json
19import math
20import re
21import time
22from collections import Counter, defaultdict
23from dataclasses import dataclass, field
24from pathlib import Path
25from typing import Any, Optional, Callable
28# ── Types ───────────────────────────────────────────────────────────
@dataclass
class SearchResult:
    """A single search result with metadata.

    Only ``doc_id`` and ``content`` are required; every other field has a
    default so a retriever can fill in just the scores it produces.
    """
    doc_id: str  # Identifier of the matched document
    content: str  # Text of the match
    source: str = ""  # File path, URL, or source identifier
    title: str = ""
    score: float = 0.0  # Overall score; reassigned by the fusion and re-ranking steps
    dense_score: float = 0.0  # Embedding-retrieval component
    sparse_score: float = 0.0  # Keyword (BM25) component
    rerank_score: float = 0.0  # Relevance assigned by the re-ranker
    chunk_index: int = 0
    metadata: dict[str, Any] = field(default_factory=dict)  # Fresh dict per instance
    citations: list[str] = field(default_factory=list)  # Specific sentences/quotes
@dataclass
class Citation:
    """A citation from source material.

    Links a piece of quoted ``text`` to the ``source`` it was matched against.
    """
    text: str  # The cited sentence
    source: str  # Source label of the document the text was matched against
    doc_id: str = ""
    chunk_index: int = 0
    # NOTE(review): presumably character offsets of ``text`` inside the source;
    # nothing visible in this module populates them, so they stay 0 — confirm.
    start_pos: int = 0
    end_pos: int = 0
    confidence: float = 1.0
58# ── BM25 Sparse Retriever ───────────────────────────────────────────
class BM25Retriever:
    """Pure Python BM25 implementation for keyword search.

    No external dependencies. Tokenizes documents, builds an inverted index
    (token -> {document position -> term frequency}) and scores documents
    against a query with the Okapi BM25 formula.

    Args:
        k1: Term-frequency saturation parameter.
        b: Document-length normalisation parameter.
    """

    # Compiled once instead of on every _tokenize call; the pattern is unchanged.
    _TOKEN_RE = re.compile(r'[\w\u4e00-\u9fff]+')

    def __init__(self, k1: float = 1.5, b: float = 0.75):
        self.k1 = k1
        self.b = b
        self._docs: list[str] = []
        self._doc_ids: list[str] = []
        self._doc_lengths: list[int] = []
        self._avg_dl: float = 0.0
        self._inverted_index: dict[str, dict[int, int]] = defaultdict(dict)
        self._idf: dict[str, float] = {}
        self._N: int = 0

    def index(self, documents: list[dict[str, str]]) -> None:
        """Build the BM25 index from documents, replacing any previous index.

        Args:
            documents: List of {id, content} dicts. A missing ``id`` falls
                back to ``doc_<position>``; a missing ``content`` is treated
                as an empty string.
        """
        self._docs = [doc.get("content", "") for doc in documents]
        self._doc_ids = [doc.get("id", f"doc_{i}") for i, doc in enumerate(documents)]

        # Tokenize each document exactly once and reuse the tokens for both
        # the length statistics and the inverted index.
        tokenized = [self._tokenize(text) for text in self._docs]
        self._doc_lengths = [len(tokens) for tokens in tokenized]
        self._N = len(self._docs)
        self._avg_dl = sum(self._doc_lengths) / max(self._N, 1)

        # Build inverted index
        self._inverted_index.clear()
        doc_freq: dict[str, int] = defaultdict(int)
        for doc_idx, tokens in enumerate(tokenized):
            for token, count in Counter(tokens).items():
                self._inverted_index[token][doc_idx] = count
                doc_freq[token] += 1

        # Compute IDF (the "+1 inside the log" variant, which is always positive)
        self._idf = {
            token: math.log(1 + (self._N - freq + 0.5) / (freq + 0.5))
            for token, freq in doc_freq.items()
        }

    def search(self, query: str, top_k: int = 10) -> list[SearchResult]:
        """BM25 keyword search.

        Args:
            query: Free-text query; tokenized the same way as the documents.
            top_k: Maximum number of results. Values <= 0 yield no results.

        Returns:
            Up to ``top_k`` results with a positive BM25 score, best first.
            Scores are divided by the best score, so the top hit has
            ``score == sparse_score == 1.0``. Content is truncated to 500
            characters.
        """
        if not self._docs or top_k <= 0:
            return []

        scores: list[float] = [0.0] * self._N
        avg_dl = max(self._avg_dl, 1)

        for token in self._tokenize(query):
            postings = self._inverted_index.get(token)
            if not postings:
                continue
            idf = self._idf.get(token, 0)
            for doc_idx, tf in postings.items():
                dl = self._doc_lengths[doc_idx]
                numerator = tf * (self.k1 + 1)
                denominator = tf + self.k1 * (1 - self.b + self.b * dl / avg_dl)
                scores[doc_idx] += idf * numerator / max(denominator, 1e-9)

        # Only documents that matched can be returned, so rank just those.
        # enumerate order plus a stable sort keeps ties in index order.
        matched = [(doc_idx, raw) for doc_idx, raw in enumerate(scores) if raw > 0]
        if not matched:
            return []
        matched.sort(key=lambda pair: pair[1], reverse=True)
        top_score = max(matched[0][1], 1e-9)

        results: list[SearchResult] = []
        for doc_idx, raw in matched[:top_k]:
            normalised = raw / top_score
            results.append(
                SearchResult(
                    doc_id=self._doc_ids[doc_idx],
                    content=self._docs[doc_idx][:500],
                    sparse_score=normalised,
                    score=normalised,
                    metadata={"method": "bm25"},
                )
            )
        return results

    def _tokenize(self, text: str) -> list[str]:
        """Simple tokenization: lowercase, split on non-word characters, drop 1-char tokens."""
        return [t for t in self._TOKEN_RE.findall(text.lower()) if len(t) > 1]
146# ── Dense Retriever ─────────────────────────────────────────────────
class DenseRetriever:
    """Semantic search via embeddings.

    Wraps an embedding function (e.g., OpenAI embeddings, sentence-transformers)
    and a vector store (ChromaDB or similar).

    Args:
        vector_store: Object exposing ``async search(query, top_k=...)`` that
            returns a list of dicts; the keys read here are ``id``,
            ``content``, ``score`` and ``metadata``. ``None`` disables dense
            retrieval.
        embed_fn: Optional embedding callable. It is stored but not called by
            this class.
    """

    def __init__(
        self,
        vector_store=None,
        embed_fn: Optional[Callable[[str], list[float]]] = None,
    ):
        self._store = vector_store
        self._embed = embed_fn

    async def search(self, query: str, top_k: int = 10) -> list[SearchResult]:
        """Dense vector search.

        Args:
            query: Search query, passed to the vector store unchanged.
            top_k: Number of hits requested from the store.

        Returns:
            One result per hit with ``score`` and ``dense_score`` divided by
            the highest score in the batch. Returns an empty list when no
            store is configured, the store returns nothing, or the store
            raises (best-effort: the failure is logged, not propagated).
        """
        if self._store is None:
            return []

        try:
            hits = await self._store.search(query, top_k=top_k)
            if not hits:
                return []

            # Normalise by the true maximum rather than trusting the store to
            # return hits best-first.
            top_score = max(max((hit.get("score") or 0) for hit in hits), 1e-9)

            results: list[SearchResult] = []
            for hit in hits:
                normalised = (hit.get("score") or 0) / top_score
                results.append(
                    SearchResult(
                        doc_id=hit.get("id", ""),
                        # `or` guards against an explicit None, which would
                        # otherwise discard the whole batch via the handler below.
                        content=(hit.get("content") or "")[:500],
                        dense_score=normalised,
                        score=normalised,
                        metadata=hit.get("metadata") or {},
                    )
                )
            return results
        except Exception:
            # Dense retrieval is best-effort so the sparse side can still
            # answer, but the failure must be visible in the logs.
            import logging

            logging.getLogger(__name__).warning(
                "Dense retrieval failed; returning no dense results", exc_info=True
            )
            return []
187# ── Cross-Encoder Re-Ranker ─────────────────────────────────────────
class CrossEncoderReranker:
    """Re-rank search results with a cross-encoder model.

    Instead of embedding query and documents independently (bi-encoder),
    a cross-encoder processes (query, document) pairs together for higher
    accuracy — at the cost of more computation.

    Supports:
    - HuggingFace cross-encoder models (e.g., ms-marco-MiniLM)
    - LLM-based re-ranking (use an LLM to judge relevance)

    Args:
        model_name: Name passed to ``sentence_transformers.CrossEncoder``;
            the model is loaded lazily on first use.
        use_llm: Use ``llm_client`` instead of the cross-encoder.
        llm_client: Object exposing ``async complete(prompt) -> str``.
    """

    # Only this many candidates are listed in the LLM prompt.
    _LLM_MAX_CANDIDATES = 20

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        use_llm: bool = False,
        llm_client=None,
    ):
        self._model_name = model_name
        self._model = None
        self._use_llm = use_llm
        self._llm = llm_client

    async def rerank(
        self,
        query: str,
        candidates: list[SearchResult],
        top_k: int = 5,
    ) -> list[SearchResult]:
        """Re-rank candidates by relevance to query.

        Args:
            query: Original search query
            candidates: Initial retrieval results. Their ``rerank_score`` and
                ``score`` are updated in place and the list itself is sorted.
            top_k: Number of results to return after re-ranking

        Returns:
            The best ``top_k`` candidates ordered by ``rerank_score``.
        """
        if not candidates:
            return []

        if self._use_llm and self._llm:
            return await self._llm_rerank(query, candidates, top_k)
        return await self._cross_encoder_rerank(query, candidates, top_k)

    async def _cross_encoder_rerank(
        self,
        query: str,
        candidates: list[SearchResult],
        top_k: int,
    ) -> list[SearchResult]:
        """Re-rank using a HuggingFace cross-encoder.

        Falls back to the incoming order (truncated to ``top_k``) when
        ``sentence_transformers`` cannot be imported.
        """
        try:
            from sentence_transformers import CrossEncoder
            if self._model is None:
                self._model = CrossEncoder(self._model_name)

            pairs = [(query, c.content[:1000]) for c in candidates]
            scores = self._model.predict(pairs)

            for candidate, raw in zip(candidates, scores):
                relevance = float(raw)
                candidate.rerank_score = relevance
                # Weighted fusion.
                # NOTE(review): the model output is blended unscaled with the
                # normalised retrieval scores — confirm the chosen model emits
                # values on a comparable scale.
                candidate.score = (
                    candidate.dense_score * 0.3 +
                    candidate.sparse_score * 0.2 +
                    relevance * 0.5
                )

            candidates.sort(key=lambda c: c.rerank_score, reverse=True)
            return candidates[:top_k]

        except ImportError:
            return candidates[:top_k]  # Fallback: no re-ranking

    async def _llm_rerank(
        self,
        query: str,
        candidates: list[SearchResult],
        top_k: int,
    ) -> list[SearchResult]:
        """Re-rank using LLM relevance judgment.

        The first ``_LLM_MAX_CANDIDATES`` candidates are listed in the prompt
        as ``[index] text``. Reply lines of the form ``[index] score`` are
        parsed, scores are divided by 10 and capped at 1.0. Candidates without
        a usable judgement get a neutral 0.5. If the LLM call or the parsing
        fails, the incoming order (truncated to ``top_k``) is returned.
        """
        if not self._llm:
            return candidates[:top_k]

        shown = candidates[:self._LLM_MAX_CANDIDATES]
        listing = "".join(
            f"[{i}] {c.content[:300]}\n\n" for i, c in enumerate(shown)
        )
        prompt = (
            f"Query: {query}\n\nRate each document's relevance on a scale of 0-10:\n\n"
            + listing
            + "Output format: [doc_id] score"
        )

        try:
            response = await self._llm.complete(prompt)

            # Parse scores
            judged: dict[int, float] = {}
            for line in response.split("\n"):
                match = re.match(r'\[(\d+)\]\s*(\d+(?:\.\d+)?)', line.strip())
                if not match:
                    continue
                idx = int(match.group(1))
                # Ignore indices the model was never shown: a score for them
                # cannot be a judgement of the actual document.
                if idx < len(shown):
                    # Cap replies that overshoot the 0-10 scale.
                    judged[idx] = min(float(match.group(2)) / 10.0, 1.0)

            for i, candidate in enumerate(candidates):
                candidate.rerank_score = judged.get(i, 0.5)
                candidate.score = (
                    candidate.dense_score * 0.25 +
                    candidate.sparse_score * 0.15 +
                    candidate.rerank_score * 0.6
                )

            candidates.sort(key=lambda c: c.rerank_score, reverse=True)
            return candidates[:top_k]

        except Exception:
            # Best-effort: keep the retrieval order, but leave a trace.
            import logging

            logging.getLogger(__name__).warning(
                "LLM re-ranking failed; returning candidates in retrieval order",
                exc_info=True,
            )
            return candidates[:top_k]
309# ── Fusion Algorithms ───────────────────────────────────────────────
class FusionMethod:
    """Collection of rank fusion algorithms.

    Every method takes the dense and sparse result lists (each best-first)
    and returns one merged list. When a document appears in both lists, the
    object from the dense list is kept and the sparse hit's ``sparse_score``
    is copied onto it, so later stages see both component scores. Result
    objects are updated in place.
    """

    @staticmethod
    def _absorb_sparse(kept: SearchResult, sparse_hit: SearchResult) -> None:
        """Carry the sparse component score over to the result that is kept."""
        if kept is not sparse_hit:
            kept.sparse_score = max(kept.sparse_score, sparse_hit.sparse_score)

    @staticmethod
    def reciprocal_rank_fusion(
        dense_results: list[SearchResult],
        sparse_results: list[SearchResult],
        k: int = 60,
    ) -> list[SearchResult]:
        """RRF: Reciprocal Rank Fusion.

        RRF_score(d) = sum_{ranker} 1 / (k + rank(d)), with ranks starting at 1.

        Args:
            dense_results: Dense hits, best first.
            sparse_results: Sparse hits, best first.
            k: Smoothing constant; larger values flatten rank differences.

        Returns:
            All distinct documents sorted by RRF score, which is written to
            each result's ``score``.
        """
        scores: dict[str, float] = {}
        docs: dict[str, SearchResult] = {}

        for rank, result in enumerate(dense_results):
            scores[result.doc_id] = 1.0 / (k + rank + 1)
            docs[result.doc_id] = result

        for rank, result in enumerate(sparse_results):
            contribution = 1.0 / (k + rank + 1)
            if result.doc_id in scores:
                scores[result.doc_id] += contribution
                FusionMethod._absorb_sparse(docs[result.doc_id], result)
            else:
                scores[result.doc_id] = contribution
                docs[result.doc_id] = result

        fused: list[SearchResult] = []
        for doc_id, rrf_score in sorted(scores.items(), key=lambda kv: kv[1], reverse=True):
            doc = docs[doc_id]
            doc.score = rrf_score
            fused.append(doc)
        return fused

    @staticmethod
    def weighted_sum(
        dense_results: list[SearchResult],
        sparse_results: list[SearchResult],
        dense_weight: float = 0.6,
        sparse_weight: float = 0.4,
    ) -> list[SearchResult]:
        """Weighted score summation.

        Args:
            dense_results: Dense hits; their ``dense_score`` is used.
            sparse_results: Sparse hits; their ``sparse_score`` is used.
            dense_weight: Multiplier for dense scores.
            sparse_weight: Multiplier for sparse scores.

        Returns:
            All distinct documents sorted by the weighted sum, which is
            written to each result's ``score``.
        """
        scores: dict[str, list[float]] = defaultdict(list)
        docs: dict[str, SearchResult] = {}

        for result in dense_results:
            scores[result.doc_id].append(result.dense_score * dense_weight)
            docs[result.doc_id] = result

        for result in sparse_results:
            scores[result.doc_id].append(result.sparse_score * sparse_weight)
            if result.doc_id in docs:
                FusionMethod._absorb_sparse(docs[result.doc_id], result)
            else:
                docs[result.doc_id] = result

        fused: list[SearchResult] = []
        for doc_id, parts in scores.items():
            doc = docs[doc_id]
            doc.score = sum(parts)
            fused.append(doc)

        fused.sort(key=lambda r: r.score, reverse=True)
        return fused

    @staticmethod
    def cascade(
        dense_results: list[SearchResult],
        sparse_results: list[SearchResult],
    ) -> list[SearchResult]:
        """Cascade: dense first, then sparse fills gaps.

        Returns:
            Every dense result in its original order, followed by sparse
            results whose ``doc_id`` was not already present. Scores are not
            recomputed.
        """
        by_id: dict[str, SearchResult] = {}
        results: list[SearchResult] = []

        for r in dense_results:
            results.append(r)
            # First occurrence wins when the dense list repeats an id.
            by_id.setdefault(r.doc_id, r)

        for r in sparse_results:
            kept = by_id.get(r.doc_id)
            if kept is None:
                results.append(r)
                by_id[r.doc_id] = r
            else:
                FusionMethod._absorb_sparse(kept, r)

        return results
397# ── Citation Tracker ────────────────────────────────────────────────
class CitationTracker:
    """Track and verify citations from source documents.

    Key features:
    - Extract citations from generated text
    - Verify against source documents
    - Mark unverifiable (potential hallucination)
    - Track citation usage statistics
    """

    def __init__(self):
        self._citations: list[Citation] = []
        self._source_index: dict[str, dict] = {}  # doc_id → {"content", "metadata"}

    def add_source(self, doc_id: str, content: str, metadata: dict[str, Any] | None = None):
        """Register a source document, replacing any earlier entry for ``doc_id``."""
        self._source_index[doc_id] = {
            "content": content,
            "metadata": metadata or {},
        }

    def extract_citations(self, text: str, sources: list[SearchResult]) -> list[Citation]:
        """Extract and verify citations from generated text.

        The text is split on ``.``, ``!`` and ``?``; fragments shorter than 15
        characters are ignored. A sentence is cited to a source when it is a
        substring of the source content or shares enough words with it (see
        ``_is_from_source``). Matching is case-insensitive.

        Args:
            text: Generated response text
            sources: Source documents used for generation

        Returns:
            Citations deduplicated on the first 50 characters of their text;
            when several sources match, the earliest source in ``sources``
            wins. The returned citations are also appended to the tracker's
            running history used by ``get_stats``.
        """
        # The sentence list does not depend on the source, so build it once.
        sentences = [
            sent for sent in (frag.strip() for frag in re.split(r'[.!?]+', text))
            if len(sent) >= 15
        ]
        lowered = [sent.lower() for sent in sentences]

        citations: list[Citation] = []
        for source in sources:
            source_content = source.content.lower()
            for sent, sent_lower in zip(sentences, lowered):
                # Check if this sentence appears in source (with fuzzy matching)
                if self._is_from_source(sent_lower, source_content):
                    citations.append(Citation(
                        text=sent,
                        source=source.source or source.doc_id,
                        doc_id=source.doc_id,
                        chunk_index=source.chunk_index,
                        confidence=0.9,
                    ))

        # Deduplicate
        seen: set[str] = set()
        unique: list[Citation] = []
        for c in citations:
            key = c.text[:50]
            if key not in seen:
                seen.add(key)
                unique.append(c)

        self._citations.extend(unique)
        return unique

    def verify(self, text: str, sources: list[SearchResult]) -> dict[str, Any]:
        """Verify all claims in text against source documents.

        Returns:
            Dict with ``total_sentences``, ``cited_sentences``,
            ``uncited_sentences``, ``hallucination_risk`` (uncited / total,
            rounded to 3 places; 0 for empty text), up to 10 ``citations`` and
            a ``status`` of ``clean`` (< 0.3), ``medium_risk`` (< 0.6) or
            ``high_risk``.
        """
        citations = self.extract_citations(text, sources)

        # re.split leaves empty fragments (after trailing punctuation, or for
        # empty text). They are not sentences and must not count as uncited.
        sentences = [
            sent for sent in (frag.strip() for frag in re.split(r'[.!?]+', text))
            if sent
        ]
        total_sentences = len(sentences)
        prefixes = [c.text[:30].lower() for c in citations]
        cited_sentences = sum(
            1 for sent in sentences
            if any(prefix in sent.lower() for prefix in prefixes)
        )

        uncited = total_sentences - cited_sentences
        hallucination_risk = uncited / max(total_sentences, 1)

        if hallucination_risk < 0.3:
            status = "clean"
        elif hallucination_risk < 0.6:
            status = "medium_risk"
        else:
            status = "high_risk"

        return {
            "total_sentences": total_sentences,
            "cited_sentences": cited_sentences,
            "uncited_sentences": uncited,
            "hallucination_risk": round(hallucination_risk, 3),
            "citations": [
                {"text": c.text[:100], "source": c.source, "confidence": c.confidence}
                for c in citations[:10]
            ],
            "status": status,
        }

    def _is_from_source(self, text: str, source: str, threshold: float = 0.6) -> bool:
        """Check if text originated from source using substring and word overlap.

        Args:
            text: Candidate sentence (callers pass it lower-cased).
            source: Source content (callers pass it lower-cased).
            threshold: Minimum fraction of the sentence's distinct
                whitespace-separated words that must occur in the source.

        Returns:
            True on an exact substring match or sufficient word overlap.
        """
        if text in source:
            return True

        text_words = set(text.split())
        if not text_words:
            return False

        source_words = set(source.split())
        overlap = len(text_words & source_words) / len(text_words)
        return overlap >= threshold

    def get_stats(self) -> dict[str, Any]:
        """Get citation statistics accumulated over all extract/verify calls."""
        count = len(self._citations)
        return {
            "total_citations": count,
            "by_source": Counter(c.source for c in self._citations),
            "avg_confidence": (
                sum(c.confidence for c in self._citations) / count
                if count else 0
            ),
            "sources_indexed": len(self._source_index),
        }
522# ── Hybrid Search Engine ────────────────────────────────────────────
class HybridSearchEngine:
    """Unified hybrid search engine.

    Combines dense + sparse retrieval with fusion and re-ranking.

    Usage:
        engine = HybridSearchEngine(
            dense_retriever=DenseRetriever(vector_store=chroma_store),
            sparse_retriever=BM25Retriever(),
        )

        # Index documents
        engine.index_sparse(documents)

        # Hybrid search
        results = await engine.search("How to implement retry logic?")
        for r in results:
            print(f"{r.score:.3f} | {r.content[:100]}")

    Args:
        dense_retriever: Dense retriever; a store-less ``DenseRetriever`` by default.
        sparse_retriever: Sparse retriever; an empty ``BM25Retriever`` by default.
        reranker: Re-ranker; a default ``CrossEncoderReranker`` by default.
        citation_tracker: Citation tracker; a fresh ``CitationTracker`` by default.
        fusion_method: ``"rrf"`` or ``"cascade"``; any other value selects
            the weighted sum.
        dense_weight: Dense weight for the weighted sum; the sparse weight is
            ``1.0 - dense_weight``.
    """

    def __init__(
        self,
        dense_retriever: Optional[DenseRetriever] = None,
        sparse_retriever: Optional[BM25Retriever] = None,
        reranker: Optional[CrossEncoderReranker] = None,
        citation_tracker: Optional[CitationTracker] = None,
        fusion_method: str = "rrf",
        dense_weight: float = 0.6,
    ):
        self.dense = dense_retriever or DenseRetriever()
        self.sparse = sparse_retriever or BM25Retriever()
        self.reranker = reranker or CrossEncoderReranker()
        self.citations = citation_tracker or CitationTracker()

        self.fusion_method = fusion_method
        self.dense_weight = dense_weight

    def index_sparse(self, documents: list[dict[str, str]]):
        """Build sparse index from documents and register them as citation sources.

        Args:
            documents: List of dicts with ``content`` and optionally ``id``
                and ``metadata``.
        """
        self.sparse.index(documents)
        for i, doc in enumerate(documents):
            self.citations.add_source(
                # Same fallback id that BM25Retriever.index assigns, so search
                # results and citation sources agree and id-less documents do
                # not overwrite each other under "".
                doc_id=doc.get("id", f"doc_{i}"),
                content=doc.get("content", ""),
                metadata=doc.get("metadata"),
            )

    async def search(
        self,
        query: str,
        top_k: int = 10,
        rerank: bool = True,
        return_citations: bool = False,
    ) -> list[SearchResult]:
        """Hybrid search: dense + sparse → fusion → rerank.

        Args:
            query: Search query
            top_k: Number of results
            rerank: Whether to apply re-ranking
            return_citations: Whether to attach citation info

        Returns:
            Ranked SearchResults.
        """
        # Step 1: Retrieval. Each retriever is asked for twice top_k so the
        # fusion / re-ranking steps have spare candidates to choose from.
        candidate_count = top_k * 2
        dense_results = await self.dense.search(query, top_k=candidate_count)
        sparse_results = self.sparse.search(query, top_k=candidate_count)

        # Step 2: Fusion
        if self.fusion_method == "rrf":
            fused = FusionMethod.reciprocal_rank_fusion(dense_results, sparse_results)
        elif self.fusion_method == "cascade":
            fused = FusionMethod.cascade(dense_results, sparse_results)
        else:  # weighted_sum
            fused = FusionMethod.weighted_sum(
                dense_results, sparse_results,
                dense_weight=self.dense_weight,
                sparse_weight=1.0 - self.dense_weight,
            )

        # Step 3: Re-rank (optional). The re-ranker only runs when there are
        # more candidates than requested; otherwise the fused order is kept.
        if rerank and len(fused) > top_k:
            fused = await self.reranker.rerank(query, fused, top_k=top_k)
        else:
            fused = fused[:top_k]

        # Step 4: Attach citations (optional)
        if return_citations:
            for result in fused:
                found = self.citations.extract_citations(result.content, [result])
                result.citations = [c.text for c in found]

        return fused

    async def search_with_citations(
        self,
        query: str,
        top_k: int = 10,
    ) -> dict[str, Any]:
        """Search and return both results and verified citations.

        Returns:
            Dict with ``results``, ``verification`` (output of
            ``CitationTracker.verify`` over the joined result contents),
            ``top_result`` (``None`` when nothing matched) and
            ``citation_stats``.
        """
        results = await self.search(query, top_k=top_k, return_citations=True)

        # Build combined text from top results
        combined = "\n\n".join(r.content for r in results)

        # Verify citations
        verification = self.citations.verify(combined, results)

        return {
            "results": results,
            "verification": verification,
            "top_result": results[0] if results else None,
            "citation_stats": self.citations.get_stats(),
        }

    def get_stats(self) -> dict[str, Any]:
        """Get search engine statistics."""
        return {
            "bm25_documents": self.sparse._N if self.sparse else 0,
            "bm25_vocabulary": len(self.sparse._idf) if self.sparse else 0,
            "citation_stats": self.citations.get_stats(),
            "fusion_method": self.fusion_method,
        }