Coverage for src/dataknobs_data/vector/mixins.py: 35%

52 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-13 11:23 -0700

1"""Mixins and protocols for vector-capable databases.""" 

2 

3from __future__ import annotations 

4 

5from abc import ABC, abstractmethod 

6from typing import TYPE_CHECKING, Any, Protocol 

7 

8from ..fields import FieldType 

9from .types import DistanceMetric, VectorSearchResult 

10 

11if TYPE_CHECKING: 

12 import numpy as np 

13 from collections.abc import Callable 

14 from ..query import Query 

15 from ..records import Record 

16 

17 

class VectorCapable(Protocol):
    """Structural interface for backends that can perform vector operations.

    ``detect_vector_fields`` and ``get_vector_config`` carry default
    implementations that classes explicitly subclassing this protocol inherit.
    """

    async def has_vector_support(self) -> bool:
        """Report whether vector operations are available on this backend.

        Returns:
            True if vector operations are supported
        """
        ...

    async def enable_vector_support(self) -> bool:
        """Turn on vector support (install extensions, configure indices, etc.).

        Returns:
            True if vector support was successfully enabled
        """
        ...

    async def detect_vector_fields(self, record: Record) -> list[str]:
        """List the names of the fields in ``record`` that hold vectors.

        A field counts as a vector field when its ``type`` is
        ``FieldType.VECTOR`` or ``FieldType.SPARSE_VECTOR``. Names are returned
        in the iteration order of ``record.fields``.

        Args:
            record: Record to examine

        Returns:
            Names of the vector-typed fields
        """
        vector_types = (FieldType.VECTOR, FieldType.SPARSE_VECTOR)
        found: list[str] = []
        for name, fld in record.fields.items():
            if fld.type in vector_types:
                found.append(name)
        return found

    def get_vector_config(self) -> dict[str, Any]:
        """Return backend-specific vector configuration.

        The default implementation has no options and returns a new, empty
        dictionary on every call.

        Returns:
            Dictionary of vector configuration options
        """
        config: dict[str, Any] = {}
        return config

59 

60 

class VectorOperationsMixin(ABC):
    """Mixin providing vector operations for databases.

    This mixin should be added to database backend classes that support
    vector operations. ``vector_search`` and ``bulk_embed_and_store`` are
    abstract and must be implemented by the concrete backend class. The
    remaining methods have defaults; ``update_vector`` and
    ``delete_from_index`` rely on the host class providing async ``read``,
    ``update`` and ``delete`` methods.
    """

    @abstractmethod
    async def vector_search(
        self,
        query_vector: np.ndarray | list[float],
        vector_field: str = "embedding",
        k: int = 10,
        metric: DistanceMetric = DistanceMetric.COSINE,
        filter: Query | None = None,
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> list[VectorSearchResult]:
        """Search for similar vectors.

        Args:
            query_vector: The vector to search for
            vector_field: Name of the vector field to search
            k: Number of results to return
            metric: Distance metric to use
            filter: Optional query filter to apply before vector search
            include_source: Whether to include source text in results
            score_threshold: Optional minimum similarity score

        Returns:
            List of search results ordered by similarity
        """
        ...

    @abstractmethod
    async def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Callable[[list[str]], np.ndarray] | None = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records.

        Args:
            records: Records to process
            text_field: Field name(s) containing text to embed
            vector_field: Field name to store vectors in
            embedding_fn: Function to generate embeddings
            batch_size: Number of records to process at once
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            List of record IDs that were processed
        """
        ...

    async def update_vector(
        self,
        record_id: str,
        vector_field: str,
        vector: np.ndarray | list[float],
        metadata: dict[str, Any] | None = None,
    ) -> bool:
        """Update a vector field for a specific record.

        Default implementation: read the record through the host's ``read``,
        replace ``vector_field`` with a new ``VectorField``, and write the
        record back through the host's ``update``.

        Args:
            record_id: ID of the record to update
            vector_field: Name of the vector field
            vector: New vector value
            metadata: Optional metadata to attach to the new vector field

        Returns:
            True if update was successful; False if the record was not found
            or the host's ``update`` reported failure.
        """
        record = await self.read(record_id)  # type: ignore
        # Compare against None explicitly: an existing record must not be
        # mistaken for "not found" just because it evaluates as falsy.
        if record is None:
            return False

        from ..fields import VectorField

        # Replace (not merge) the vector field with the new value.
        record.fields[vector_field] = VectorField(
            name=vector_field,
            value=vector,
            metadata=metadata,
        )

        result = await self.update(record_id, record)  # type: ignore
        # Hosts may report the outcome as a bool or as the updated record /
        # None. A bare ``is not None`` test would turn a ``False`` result into
        # reported success, so honour explicit booleans first.
        if isinstance(result, bool):
            return result
        return result is not None

    async def delete_from_index(
        self,
        record_id: str,
        vector_field: str = "embedding",
    ) -> bool:
        """Remove a record from the vector index.

        Note: the default implementation deletes the entire record through the
        host's ``delete`` (``vector_field`` is not used). Backends with a
        separate vector index should override this.

        Args:
            record_id: ID of the record to remove
            vector_field: Name of the vector field

        Returns:
            True if deletion was successful
        """
        return await self.delete(record_id)  # type: ignore

    async def create_vector_index(
        self,
        vector_field: str = "embedding",
        dimensions: int | None = None,
        metric: DistanceMetric = DistanceMetric.COSINE,
        index_type: str = "auto",
        **kwargs: Any,
    ) -> bool:
        """Create an index for vector similarity search.

        The default implementation is a no-op that reports success.

        Args:
            vector_field: Name of the vector field to index
            dimensions: Number of dimensions (if known)
            metric: Distance metric for the index
            index_type: Type of index to create
            **kwargs: Backend-specific index parameters

        Returns:
            True if index was created successfully
        """
        return True

    async def drop_vector_index(
        self,
        vector_field: str = "embedding",
    ) -> bool:
        """Drop a vector index.

        The default implementation is a no-op that reports success.

        Args:
            vector_field: Name of the vector field

        Returns:
            True if index was dropped successfully
        """
        return True

    async def get_vector_index_stats(
        self,
        vector_field: str = "embedding",
    ) -> dict[str, Any]:
        """Get statistics about a vector index.

        The default implementation reports an unindexed field with no vectors.

        Args:
            vector_field: Name of the vector field

        Returns:
            Dictionary with ``field``, ``indexed`` and ``vector_count`` keys
        """
        return {
            "field": vector_field,
            "indexed": False,
            "vector_count": 0,
        }

229 

230 

class VectorSyncMixin:
    """Mixin for synchronizing vectors with source text.

    Operates only on the ``Record`` objects passed in; it mutates them in
    place and does not persist anything itself.
    """

    async def sync_vectors_with_text(
        self,
        records: list[Record],
        text_fields: list[str],
        vector_field: str = "embedding",
        embedding_fn: Callable[[list[str]], np.ndarray] | None = None,
        force: bool = False,
    ) -> int:
        """Synchronize vector embeddings with text content.

        A record is (re-)embedded when ``force`` is set, when it has no
        ``vector_field`` yet, or when the comma-separated ``source_field``
        stored in the existing vector's metadata names a different set of
        fields than ``text_fields``. Missing metadata is treated as an unknown
        source and triggers re-embedding. The non-empty values of
        ``text_fields`` are joined with single spaces and embedded. Records
        whose text is empty are left unchanged.

        Args:
            records: Records to synchronize (updated in place)
            text_fields: Text fields to generate vectors from
            vector_field: Vector field to update
            embedding_fn: Callable mapping a list of texts to a sequence of
                vectors; it may return the result directly or an awaitable
            force: Force re-generation even if vectors exist

        Returns:
            Number of records whose vector field was (re)written

        Raises:
            ValueError: If ``embedding_fn`` is not provided
        """
        if embedding_fn is None:
            raise ValueError("Embedding function is required for vector synchronization")

        import inspect

        wanted_sources = set(text_fields)
        joined_sources = ",".join(text_fields)
        updated = 0
        for record in records:
            needs_update = force or vector_field not in record.fields

            if not needs_update:
                # Re-embed when the vector came from a different set of text
                # fields. Metadata, or its "source_field" entry, may be absent
                # or None, so fall back to an empty string instead of crashing.
                vector_meta = record.fields[vector_field].metadata or {}
                stored_sources = vector_meta.get("source_field") or ""
                needs_update = set(stored_sources.split(",")) != wanted_sources

            if not needs_update:
                continue

            # Concatenate the truthy text values, reading each field only once.
            values = (record.get_value(field) for field in text_fields)
            text_content = " ".join(str(value) for value in values if value)
            if not text_content:
                continue

            from ..fields import VectorField

            result = embedding_fn([text_content])
            # Support both synchronous and asynchronous embedding functions.
            if inspect.isawaitable(result):
                embeddings = await result
            else:
                embeddings = result

            record.fields[vector_field] = VectorField(
                name=vector_field,
                value=embeddings[0],
                source_field=joined_sources,
            )
            updated += 1

        return updated