Coverage for src / dataknobs_data / vector / hybrid.py: 100%
69 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-26 15:45 -0700
1"""Hybrid search types and fusion algorithms.
3This module provides types and utilities for combining keyword (text) search
4with vector (semantic) search for improved retrieval quality.
5"""
7from __future__ import annotations
9from dataclasses import dataclass, field
10from enum import Enum
11from typing import TYPE_CHECKING, Any
13if TYPE_CHECKING:
14 from ..records import Record
class FusionStrategy(Enum):
    """How text and vector search results are merged into one ranking.

    Members:
        RRF: Reciprocal Rank Fusion — combine by rank positions only.
        WEIGHTED_SUM: Linear combination of (optionally normalized) scores.
        NATIVE: Delegate fusion to the backend's own hybrid implementation.
    """

    RRF = "rrf"
    WEIGHTED_SUM = "weighted_sum"
    NATIVE = "native"
@dataclass
class HybridSearchConfig:
    """Configuration for hybrid search operations.

    Attributes:
        text_weight: Weight for text search scores (0.0 to 1.0).
        vector_weight: Weight for vector search scores (0.0 to 1.0).
        fusion_strategy: Strategy used to combine the two result sets.
        rrf_k: Constant k in the RRF formula (default 60).
        text_fields: Fields searched for text matching (None = all text fields).
    """

    text_weight: float = 0.5
    vector_weight: float = 0.5
    fusion_strategy: FusionStrategy = FusionStrategy.RRF
    rrf_k: int = 60
    text_fields: list[str] | None = None

    def __post_init__(self) -> None:
        """Reject out-of-range weights and a non-positive rrf_k."""
        for label, weight in (
            ("text_weight", self.text_weight),
            ("vector_weight", self.vector_weight),
        ):
            if not 0.0 <= weight <= 1.0:
                raise ValueError(f"{label} must be between 0 and 1, got {weight}")
        if self.rrf_k <= 0:
            raise ValueError(f"rrf_k must be positive, got {self.rrf_k}")

    def normalize_weights(self) -> tuple[float, float]:
        """Get normalized weights that sum to 1.0.

        Returns:
            Tuple of (normalized_text_weight, normalized_vector_weight).
            Falls back to an even (0.5, 0.5) split when both weights are zero.
        """
        combined = self.text_weight + self.vector_weight
        if combined == 0:
            return 0.5, 0.5
        return self.text_weight / combined, self.vector_weight / combined
@dataclass
class HybridSearchResult:
    """Result from a hybrid search operation.

    Attributes:
        record: The matched record.
        combined_score: Final combined score after fusion.
        text_score: Score from text search (None if not matched by text).
        vector_score: Score from vector search (None if not matched by vector).
        text_rank: Rank in text search results (None if not matched).
        vector_rank: Rank in vector search results (None if not matched).
        metadata: Additional result metadata.
    """

    record: Record
    combined_score: float
    text_score: float | None = None
    vector_score: float | None = None
    text_rank: int | None = None
    vector_rank: int | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def __lt__(self, other: HybridSearchResult) -> bool:
        """Sort higher combined scores first (descending order under sorted())."""
        return other.combined_score < self.combined_score

    def __repr__(self) -> str:
        """Compact summary of the combined, text, and vector scores."""

        def _fmt(score: float | None) -> str:
            # Absent component scores render as "N/A" rather than a number.
            return "N/A" if score is None else f"{score:.4f}"

        return (
            f"HybridSearchResult(score={self.combined_score:.4f}, "
            f"text={_fmt(self.text_score)}, "
            f"vector={_fmt(self.vector_score)}, "
            f"record_id={self.record.id})"
        )
def reciprocal_rank_fusion(
    text_results: list[tuple[str, float]],
    vector_results: list[tuple[str, float]],
    k: int = 60,
    text_weight: float = 1.0,
    vector_weight: float = 1.0,
) -> list[tuple[str, float]]:
    """Fuse two rankings with Reciprocal Rank Fusion.

    Every item accumulates weight / (k + rank) from each ranking it appears
    in. Because only rank positions (never raw scores) are used, RRF needs no
    score normalization and is robust to differing score distributions.

    Args:
        text_results: (record_id, score) pairs from text search, score-desc.
        vector_results: (record_id, score) pairs from vector search, score-desc.
        k: Ranking constant (default 60); larger k dampens top-rank dominance.
        text_weight: Weight multiplier for text search ranks.
        vector_weight: Weight multiplier for vector search ranks.

    Returns:
        List of (record_id, rrf_score) ordered by RRF score descending.
    """
    fused: dict[str, float] = {}

    # Text contributions first, then vector — matching the dict insertion
    # order (and hence tie ordering) of the per-source accumulation.
    for weight, ranking in ((text_weight, text_results), (vector_weight, vector_results)):
        for position, (record_id, _raw_score) in enumerate(ranking, start=1):
            fused[record_id] = fused.get(record_id, 0.0) + weight / (k + position)

    return sorted(fused.items(), key=lambda entry: entry[1], reverse=True)
def weighted_score_fusion(
    text_results: list[tuple[str, float]],
    vector_results: list[tuple[str, float]],
    text_weight: float = 0.5,
    vector_weight: float = 0.5,
    normalize_scores: bool = True,
) -> list[tuple[str, float]]:
    """Combine results using a weighted score sum.

    Combined score = text_weight * text_score + vector_weight * vector_score.
    A record missing from one ranking contributes 0.0 for that component.

    Args:
        text_results: (record_id, score) pairs from text search.
        vector_results: (record_id, score) pairs from vector search.
        text_weight: Weight for text scores (0.0 to 1.0).
        vector_weight: Weight for vector scores (0.0 to 1.0).
        normalize_scores: Whether to min-max normalize each source's scores
            to the 0-1 range before combining.

    Returns:
        List of (record_id, combined_score) ordered by score descending.
    """
    text_map = dict(text_results)
    vector_map = dict(vector_results)

    if normalize_scores:
        text_map = _normalize_scores(text_map)
        vector_map = _normalize_scores(vector_map)

    # Union of ids from both sources; absent components default to 0.0.
    return sorted(
        (
            (
                record_id,
                text_weight * text_map.get(record_id, 0.0)
                + vector_weight * vector_map.get(record_id, 0.0),
            )
            for record_id in text_map.keys() | vector_map.keys()
        ),
        key=lambda entry: entry[1],
        reverse=True,
    )
186def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
187 """Normalize scores to 0-1 range using min-max normalization.
189 Args:
190 scores: Dictionary mapping record IDs to scores
192 Returns:
193 Dictionary with normalized scores
194 """
195 if not scores:
196 return {}
198 values = list(scores.values())
199 min_score = min(values)
200 max_score = max(values)
202 if max_score == min_score:
203 # All scores are the same
204 return {k: 1.0 for k in scores}
206 return {
207 k: (v - min_score) / (max_score - min_score)
208 for k, v in scores.items()
209 }