Coverage for src/dataknobs_data/vector/types.py: 61%

85 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-29 14:14 -0600

1"""Core types and data structures for vector operations.""" 

2 

3from __future__ import annotations 

4 

5from dataclasses import dataclass, field 

6from enum import Enum 

7from typing import Any, TYPE_CHECKING 

8 

9if TYPE_CHECKING: 

10 from ..records import Record 

11 

12 

class DistanceMetric(Enum):
    """Distance / similarity metrics understood by the vector layer.

    ``INNER_PRODUCT`` and ``L2`` carry their own values (``"inner_product"``
    and ``"l2"``). Because Enum only aliases members with *equal* values, they
    are distinct members rather than Enum aliases of ``DOT_PRODUCT`` /
    ``EUCLIDEAN``.
    """

    COSINE = "cosine"
    EUCLIDEAN = "euclidean"
    DOT_PRODUCT = "dot_product"
    INNER_PRODUCT = "inner_product"  # Alias for dot_product
    L2 = "l2"  # Alias for euclidean
    L1 = "l1"  # Manhattan distance

    def get_aliases(self) -> list[str]:
        """Return the alternative spellings recognised for this metric.

        A fresh list is built on every call, so callers may mutate it freely.
        Only COSINE, EUCLIDEAN, DOT_PRODUCT and L1 have alternative names;
        every other member yields an empty list.

        NOTE(review): INNER_PRODUCT and L2 currently report no aliases even
        though they are documented as aliases themselves - confirm intended.
        """
        if self is DistanceMetric.COSINE:
            return ["cosine_similarity", "cos"]
        if self is DistanceMetric.EUCLIDEAN:
            return ["l2", "euclidean_distance"]
        if self is DistanceMetric.DOT_PRODUCT:
            return ["inner_product", "ip"]
        if self is DistanceMetric.L1:
            return ["manhattan", "l1_distance"]
        return []

32 

33 

@dataclass
class VectorSearchResult:
    """Result from a vector similarity search operation.

    Results order by ``score`` alone (see ``__lt__``), so ``sorted(results)``
    yields ascending scores. Equality is still the dataclass-generated
    ``__eq__``, which compares every field.
    """

    record: Record
    score: float
    source_text: str | None = None
    vector_field: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def __lt__(self, other: object) -> bool:
        """Order results by score.

        Returns ``NotImplemented`` for operands that are not
        ``VectorSearchResult`` so Python raises its usual ``TypeError``
        instead of leaking an ``AttributeError`` about ``.score``.
        """
        if not isinstance(other, VectorSearchResult):
            return NotImplemented
        return self.score < other.score

    def __repr__(self) -> str:
        """Return a compact representation with score, record id and field.

        The record id is read with ``getattr`` so that ``repr`` never raises
        for a record object lacking an ``id`` attribute (shown as ``None``).
        """
        record_id = getattr(self.record, "id", None)
        return (
            f"VectorSearchResult(score={self.score:.4f}, "
            f"record_id={record_id}, "
            f"vector_field={self.vector_field})"
        )

55 

56 

@dataclass
class VectorConfig:
    """Settings describing the vectors a field holds.

    The dataclass performs no checks on construction; call :meth:`validate`
    to verify the configured dimensionality.
    """

    dimensions: int
    metric: DistanceMetric = DistanceMetric.COSINE
    normalize: bool = False
    source_field: str | None = None
    model_name: str | None = None
    model_version: str | None = None

    def validate(self) -> None:
        """Ensure ``dimensions`` falls within the supported range 1..65536.

        Raises:
            ValueError: If ``dimensions`` is zero/negative, or above 65536.
        """
        upper_bound = 65536  # Common maximum for vector databases
        dims = self.dimensions

        if dims <= 0:
            raise ValueError(f"Dimensions must be positive, got {dims}")
        if dims > upper_bound:
            raise ValueError(
                f"Dimensions {dims} exceeds maximum supported ({upper_bound})"
            )

77 

78 

@dataclass
class VectorIndexConfig:
    """Configuration for vector index creation.

    Tuning fields left unset are filled with defaults by
    :meth:`get_optimal_params`.
    """

    index_type: str = "auto"  # auto, flat, ivfflat, hnsw
    lists: int | None = None  # For IVFFlat
    m: int | None = None  # For HNSW
    ef_construction: int | None = None  # For HNSW
    ef_search: int | None = None  # For HNSW search
    probes: int | None = None  # For IVFFlat search
    quantization: str | None = None  # none, scalar, product

    def get_optimal_params(self, num_vectors: int) -> dict[str, Any]:
        """Get optimal index parameters based on dataset size.

        With ``index_type == "auto"`` the type is picked from ``num_vectors``:
        ``"flat"`` below 10,000, ``"ivfflat"`` below 1,000,000 and ``"hnsw"``
        otherwise. Any other ``index_type`` is passed through unchanged as
        ``"type"``; ``"ivfflat"`` and ``"hnsw"`` also receive their tuning keys.

        Args:
            num_vectors: Expected number of vectors; only consulted for "auto".

        Returns:
            A dict with ``"type"``, any type-specific tuning keys, and
            ``"quantization"`` when ``self.quantization`` is truthy.
        """
        # Explicit annotation: the values mix str and int, which a bare ``{}``
        # followed by a str assignment would make mypy infer as dict[str, str].
        params: dict[str, Any]

        if self.index_type == "auto":
            if num_vectors < 10_000:
                params = {"type": "flat"}
            elif num_vectors < 1_000_000:
                # ~1 list per 1000 vectors, but never fewer than 100.
                params = {"type": "ivfflat"}
                params.update(self._ivfflat_params(max(num_vectors // 1000, 100)))
            else:
                params = {"type": "hnsw"}
                params.update(self._hnsw_params())
        else:
            params = {"type": self.index_type}
            if self.index_type == "ivfflat":
                params.update(self._ivfflat_params(100))
            elif self.index_type == "hnsw":
                params.update(self._hnsw_params())

        if self.quantization:
            params["quantization"] = self.quantization

        return params

    def _ivfflat_params(self, default_lists: int) -> dict[str, Any]:
        """Return IVFFlat tuning values; falsy fields fall back to defaults."""
        return {
            "lists": self.lists or default_lists,
            "probes": self.probes or 10,
        }

    def _hnsw_params(self) -> dict[str, Any]:
        """Return HNSW tuning values; falsy fields fall back to defaults."""
        return {
            "m": self.m or 16,
            "ef_construction": self.ef_construction or 200,
            "ef_search": self.ef_search or 64,
        }

122 

123 

@dataclass
class VectorMetadata:
    """Metadata associated with vector fields.

    In the dictionary form the model fields are nested under ``"model"`` as
    ``{"name": ..., "version": ...}``.
    """

    dimensions: int
    source_field: str | None = None
    model_name: str | None = None
    model_version: str | None = None
    created_at: str | None = None
    updated_at: str | None = None
    index_type: str | None = None
    metric: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary representation.

        ``"model"`` is ``None`` only when both ``model_name`` and
        ``model_version`` are ``None``. Otherwise both keys are written, so a
        version recorded without a name (or an empty-string name) survives a
        round trip through :meth:`from_dict`.
        """
        has_model = self.model_name is not None or self.model_version is not None
        return {
            "dimensions": self.dimensions,
            "source_field": self.source_field,
            "model": (
                {"name": self.model_name, "version": self.model_version}
                if has_model
                else None
            ),
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "index_type": self.index_type,
            "metric": self.metric,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> VectorMetadata:
        """Create from dictionary representation.

        Args:
            data: Mapping in the shape produced by :meth:`to_dict`. Only
                ``"dimensions"`` is required; ``"model"`` may be absent,
                ``None``, or a mapping with optional ``"name"``/``"version"``.

        Raises:
            KeyError: If ``"dimensions"`` is missing.
        """
        # Treat a missing or falsy "model" entry as "no model information".
        model_info = data.get("model") or {}
        return cls(
            dimensions=data["dimensions"],
            source_field=data.get("source_field"),
            model_name=model_info.get("name"),
            model_version=model_info.get("version"),
            created_at=data.get("created_at"),
            updated_at=data.get("updated_at"),
            index_type=data.get("index_type"),
            metric=data.get("metric"),
        )