Coverage for src / dataknobs_data / vector / hybrid.py: 100%

69 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-26 15:45 -0700

1"""Hybrid search types and fusion algorithms. 

2 

3This module provides types and utilities for combining keyword (text) search 

4with vector (semantic) search for improved retrieval quality. 

5""" 

6 

7from __future__ import annotations 

8 

9from dataclasses import dataclass, field 

10from enum import Enum 

11from typing import TYPE_CHECKING, Any 

12 

13if TYPE_CHECKING: 

14 from ..records import Record 

15 

16 

class FusionStrategy(Enum):
    """How text and vector search results are merged into one ranking."""

    # Reciprocal Rank Fusion: rank-based, robust to mismatched score scales.
    RRF = "rrf"
    # Linear combination of (optionally normalized) scores from both searches.
    WEIGHTED_SUM = "weighted_sum"
    # Delegate fusion to the storage backend's own hybrid implementation.
    NATIVE = "native"

23 

24 

@dataclass
class HybridSearchConfig:
    """Tunable parameters for a hybrid (text + vector) search.

    Attributes:
        text_weight: Relative weight of text-search scores, in [0.0, 1.0].
        vector_weight: Relative weight of vector-search scores, in [0.0, 1.0].
        fusion_strategy: Algorithm used to merge the two result lists.
        rrf_k: Rank-damping constant for Reciprocal Rank Fusion (default 60).
        text_fields: Restrict text matching to these fields; None means all
            text fields.
    """

    text_weight: float = 0.5
    vector_weight: float = 0.5
    fusion_strategy: FusionStrategy = FusionStrategy.RRF
    rrf_k: int = 60
    text_fields: list[str] | None = None

    def __post_init__(self) -> None:
        """Reject out-of-range parameters at construction time.

        Raises:
            ValueError: If either weight falls outside [0, 1] or rrf_k is
                not strictly positive.
        """
        if not (0.0 <= self.text_weight <= 1.0):
            raise ValueError(f"text_weight must be between 0 and 1, got {self.text_weight}")
        if not (0.0 <= self.vector_weight <= 1.0):
            raise ValueError(f"vector_weight must be between 0 and 1, got {self.vector_weight}")
        if self.rrf_k <= 0:
            raise ValueError(f"rrf_k must be positive, got {self.rrf_k}")

    def normalize_weights(self) -> tuple[float, float]:
        """Return the two weights rescaled to sum to 1.0.

        Falls back to an even (0.5, 0.5) split when both weights are zero,
        since there is no ratio to preserve in that case.

        Returns:
            Tuple of (normalized_text_weight, normalized_vector_weight).
        """
        combined = self.text_weight + self.vector_weight
        if combined == 0:
            return 0.5, 0.5
        return self.text_weight / combined, self.vector_weight / combined

62 

63 

@dataclass
class HybridSearchResult:
    """A single record returned by hybrid search, with score provenance.

    Attributes:
        record: The matched record.
        combined_score: Final score after fusing text and vector results.
        text_score: Raw text-search score (None if not matched by text).
        vector_score: Raw vector-search score (None if not matched by vector).
        text_rank: Rank within text results (None if not matched).
        vector_rank: Rank within vector results (None if not matched).
        metadata: Free-form extra information about this result.
    """

    record: Record
    combined_score: float
    text_score: float | None = None
    vector_score: float | None = None
    text_rank: int | None = None
    vector_rank: int | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def __lt__(self, other: HybridSearchResult) -> bool:
        """Order results best-first so a plain sort() ranks by score.

        Deliberately inverted: a "lesser" result is one with the HIGHER
        combined score, which makes ascending sorts yield descending scores.
        """
        return other.combined_score < self.combined_score

    def __repr__(self) -> str:
        """Compact one-line summary including per-modality scores."""

        def fmt(score: float | None) -> str:
            # Missing modality scores render as "N/A" rather than a number.
            return "N/A" if score is None else f"{score:.4f}"

        return (
            f"HybridSearchResult(score={self.combined_score:.4f}, "
            f"text={fmt(self.text_score)}, "
            f"vector={fmt(self.vector_score)}, "
            f"record_id={self.record.id})"
        )

100 

101 

def reciprocal_rank_fusion(
    text_results: list[tuple[str, float]],
    vector_results: list[tuple[str, float]],
    k: int = 60,
    text_weight: float = 1.0,
    vector_weight: float = 1.0,
) -> list[tuple[str, float]]:
    """Merge two ranked result lists with Reciprocal Rank Fusion.

    Each appearance of a record contributes weight / (k + rank) to its fused
    score, so only rank positions matter and raw scores are ignored. This
    makes RRF robust to mismatched score distributions between the two
    searches, with no normalization required.

    Args:
        text_results: (record_id, score) pairs from text search, best first.
        vector_results: (record_id, score) pairs from vector search, best first.
        k: Rank-damping constant (default 60); larger k reduces the advantage
            of the very top ranks.
        text_weight: Multiplier applied to text-search contributions.
        vector_weight: Multiplier applied to vector-search contributions.

    Returns:
        List of (record_id, rrf_score) ordered by fused score, highest first.
    """
    fused: dict[str, float] = {}

    # Accumulate both rankings with the same formula; ranks are 1-based.
    for ranking, weight in ((text_results, text_weight), (vector_results, vector_weight)):
        for rank, (record_id, _unused_score) in enumerate(ranking, start=1):
            fused[record_id] = fused.get(record_id, 0.0) + weight / (k + rank)

    return sorted(fused.items(), key=lambda pair: pair[1], reverse=True)

140 

141 

def weighted_score_fusion(
    text_results: list[tuple[str, float]],
    vector_results: list[tuple[str, float]],
    text_weight: float = 0.5,
    vector_weight: float = 0.5,
    normalize_scores: bool = True,
) -> list[tuple[str, float]]:
    """Merge two result lists via a weighted sum of their scores.

    combined = text_weight * text_score + vector_weight * vector_score,
    where a record missing from one list contributes 0.0 for that side.

    Args:
        text_results: (record_id, score) pairs from text search.
        vector_results: (record_id, score) pairs from vector search.
        text_weight: Weight for text scores (0.0 to 1.0).
        vector_weight: Weight for vector scores (0.0 to 1.0).
        normalize_scores: If True, min-max scale each list to [0, 1] first so
            differently-scaled score distributions are comparable.

    Returns:
        List of (record_id, combined_score) ordered by score descending.
    """

    def minmax(raw: dict[str, float]) -> dict[str, float]:
        # Min-max scale to [0, 1]; a constant list maps every entry to 1.0.
        if not raw:
            return {}
        lo = min(raw.values())
        hi = max(raw.values())
        if hi == lo:
            return {rid: 1.0 for rid in raw}
        return {rid: (score - lo) / (hi - lo) for rid, score in raw.items()}

    by_text = dict(text_results)
    by_vector = dict(vector_results)

    if normalize_scores:
        by_text = minmax(by_text)
        by_vector = minmax(by_vector)

    # Union of IDs: records matched by either search participate in fusion.
    fused = {
        rid: text_weight * by_text.get(rid, 0.0) + vector_weight * by_vector.get(rid, 0.0)
        for rid in by_text.keys() | by_vector.keys()
    }

    return sorted(fused.items(), key=lambda pair: pair[1], reverse=True)

184 

185 

186def _normalize_scores(scores: dict[str, float]) -> dict[str, float]: 

187 """Normalize scores to 0-1 range using min-max normalization. 

188 

189 Args: 

190 scores: Dictionary mapping record IDs to scores 

191 

192 Returns: 

193 Dictionary with normalized scores 

194 """ 

195 if not scores: 

196 return {} 

197 

198 values = list(scores.values()) 

199 min_score = min(values) 

200 max_score = max(values) 

201 

202 if max_score == min_score: 

203 # All scores are the same 

204 return {k: 1.0 for k in scores} 

205 

206 return { 

207 k: (v - min_score) / (max_score - min_score) 

208 for k, v in scores.items() 

209 }