Coverage for src/dataknobs_data/vector/stores/base.py: 15%

82 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-26 15:45 -0700

1"""Base class for specialized vector stores.""" 

2 

3from __future__ import annotations 

4 

5from abc import ABC, abstractmethod 

6from typing import TYPE_CHECKING, Any 

7 

8from ...fields import VectorField 

9from ...records import Record 

10from ..types import VectorSearchResult 

11from .common import VectorStoreBase 

12 

13if TYPE_CHECKING: 

14 import numpy as np 

15 from collections.abc import Callable 

16 

17 

class VectorStore(ABC, VectorStoreBase):
    """Abstract base class for specialized vector stores.

    This provides a dedicated vector storage backend that can be used
    independently or alongside traditional databases. It inherits from
    VectorStoreBase, which provides common configuration parsing and
    utility methods (including ``self.batch_size``, used by
    :meth:`bulk_embed_and_store`).

    Subclasses implement the abstract storage primitives below; the
    concrete convenience methods (``update_vectors``, ``add_records``,
    ``search_similar_records``, ``bulk_embed_and_store``) are built on
    top of those primitives and may be overridden for efficiency.
    """

    # ------------------------------------------------------------------
    # Abstract primitives -- every backend must implement these.
    # ------------------------------------------------------------------

    @abstractmethod
    async def initialize(self) -> None:
        """Initialize the vector store (create index, connect, etc.)."""

    @abstractmethod
    async def close(self) -> None:
        """Close connections and clean up resources."""

    @abstractmethod
    async def add_vectors(
        self,
        vectors: np.ndarray | list[np.ndarray],
        ids: list[str] | None = None,
        metadata: list[dict[str, Any]] | None = None,
    ) -> list[str]:
        """Add vectors to the store.

        Args:
            vectors: Vector(s) to add
            ids: Optional IDs for vectors (generated if not provided)
            metadata: Optional metadata for each vector

        Returns:
            List of IDs for the added vectors
        """

    @abstractmethod
    async def get_vectors(
        self,
        ids: list[str],
        include_metadata: bool = True,
    ) -> list[tuple[np.ndarray, dict[str, Any] | None]]:
        """Retrieve vectors by ID.

        Args:
            ids: Vector IDs to retrieve
            include_metadata: Whether to include metadata

        Returns:
            List of (vector, metadata) tuples
        """

    @abstractmethod
    async def delete_vectors(self, ids: list[str]) -> int:
        """Delete vectors by ID.

        Args:
            ids: Vector IDs to delete

        Returns:
            Number of vectors deleted
        """

    @abstractmethod
    async def search(
        self,
        query_vector: np.ndarray,
        k: int = 10,
        filter: dict[str, Any] | None = None,  # shadows builtin; kept for API compat
        include_metadata: bool = True,
    ) -> list[tuple[str, float, dict[str, Any] | None]]:
        """Search for similar vectors.

        Args:
            query_vector: Query vector
            k: Number of results
            filter: Optional metadata filter
            include_metadata: Whether to include metadata

        Returns:
            List of (id, score, metadata) tuples
        """

    @abstractmethod
    async def update_metadata(
        self,
        ids: list[str],
        metadata: list[dict[str, Any]],
    ) -> int:
        """Update metadata for existing vectors.

        Args:
            ids: Vector IDs to update
            metadata: New metadata for each vector

        Returns:
            Number of vectors updated
        """

    @abstractmethod
    async def count(self, filter: dict[str, Any] | None = None) -> int:
        """Count vectors in the store.

        Args:
            filter: Optional metadata filter

        Returns:
            Number of vectors matching filter
        """

    @abstractmethod
    async def clear(self) -> None:
        """Clear all vectors from the store."""

    # ------------------------------------------------------------------
    # Concrete convenience methods built on the primitives above.
    # ------------------------------------------------------------------

    async def update_vectors(
        self,
        vectors: np.ndarray | list[np.ndarray],
        ids: list[str],
        metadata: list[dict[str, Any]] | None = None,
    ) -> list[str]:
        """Update existing vectors by ID.

        This is a convenience method that deletes and re-adds vectors
        under the same IDs. NOTE: this is not atomic -- if
        ``add_vectors`` raises after the delete succeeds, the vectors
        are lost. Backends may override this with a more efficient
        (and transactional) implementation.

        Args:
            vectors: New vector values
            ids: IDs of vectors to update
            metadata: Optional new metadata

        Returns:
            List of updated IDs
        """
        await self.delete_vectors(ids)
        return await self.add_vectors(vectors, ids, metadata)

    def _build_record_metadata(
        self,
        record: Record,
        vector_obj: VectorField,
        include_fields: list[str] | None,
    ) -> dict[str, Any]:
        """Build the metadata dict stored alongside one record's vector."""
        metadata: dict[str, Any] = {"record_id": record.id}

        # Provenance: which field the vector was derived from, plus the
        # original text when that field is still present on the record.
        if vector_obj.source_field:
            metadata["source_field"] = vector_obj.source_field
            if vector_obj.source_field in record.fields:
                metadata["source_text"] = record.get_value(vector_obj.source_field)

        # Embedding-model provenance, when known.
        if vector_obj.model_name:
            metadata["model_name"] = vector_obj.model_name
        if vector_obj.model_version:
            metadata["model_version"] = vector_obj.model_version

        # Caller-requested extra fields (missing fields are skipped).
        if include_fields:
            for field_name in include_fields:
                if field_name in record.fields:
                    metadata[field_name] = record.get_value(field_name)

        return metadata

    async def add_records(
        self,
        records: list[Record],
        vector_field: str = "embedding",
        include_fields: list[str] | None = None,
    ) -> list[str]:
        """Add records with vector fields to the store.

        Records that lack the vector field, whose field is not a
        ``VectorField``, or that have no ID are skipped silently.

        Args:
            records: Records containing vector fields
            vector_field: Name of the vector field
            include_fields: Fields to include in metadata

        Returns:
            List of IDs for added vectors (empty if nothing qualified)
        """
        vectors: list[Any] = []
        ids: list[str] = []
        metadatas: list[dict[str, Any]] = []

        for record in records:
            if vector_field not in record.fields:
                continue

            vector_obj = record.fields[vector_field]
            if not isinstance(vector_obj, VectorField):
                continue

            # Records without IDs cannot be linked back to a source; skip.
            if record.id is None:
                continue

            vectors.append(vector_obj.value)
            ids.append(record.id)
            metadatas.append(
                self._build_record_metadata(record, vector_obj, include_fields)
            )

        if vectors:
            return await self.add_vectors(vectors, ids=ids, metadata=metadatas)
        return []

    async def search_similar_records(
        self,
        query_vector: np.ndarray,
        k: int = 10,
        filter: dict[str, Any] | None = None,  # shadows builtin; kept for API compat
        fetch_records: Callable[[list[str]], list[Record]] | None = None,
    ) -> list[VectorSearchResult]:
        """Search and return results as VectorSearchResult objects.

        Args:
            query_vector: Query vector
            k: Number of results
            filter: Optional metadata filter
            fetch_records: Optional function to fetch full records by ID

        Returns:
            List of VectorSearchResult objects
        """
        results = await self.search(
            query_vector, k=k, filter=filter, include_metadata=True
        )

        # Resolve each hit's record ID exactly once: metadata may carry an
        # explicit "record_id"; otherwise the vector ID doubles as one.
        resolved = [
            (
                vector_id,
                score,
                metadata,
                metadata.get("record_id", vector_id) if metadata else vector_id,
            )
            for vector_id, score, metadata in results
        ]

        # Fetch full records in a single batch if a fetcher was provided.
        records_map: dict[Any, Record] = {}
        if fetch_records and resolved:
            fetched = fetch_records([record_id for *_, record_id in resolved])
            records_map = {r.id: r for r in fetched}

        search_results: list[VectorSearchResult] = []
        for vector_id, score, metadata, record_id in resolved:
            if record_id in records_map:
                record = records_map[record_id]
            else:
                # Fall back to a minimal record carrying the stored metadata.
                record = Record({"id": record_id})
                if metadata:
                    for key, value in metadata.items():
                        if key not in ["record_id", "source_text", "source_field"]:
                            # NOTE(review): raw values are assigned into
                            # record.fields here (not Field objects) --
                            # confirm Record tolerates this.
                            record.fields[key] = value

            source_text = metadata.get("source_text") if metadata else None

            search_results.append(
                VectorSearchResult(
                    record=record,
                    score=score,
                    source_text=source_text,
                    vector_field=metadata.get("source_field") if metadata else None,
                    metadata=metadata or {},
                )
            )

        return search_results

    async def bulk_embed_and_store(
        self,
        texts: list[str],
        embedding_fn: Callable[[list[str]], np.ndarray],
        ids: list[str] | None = None,
        metadata: list[dict[str, Any]] | None = None,
        batch_size: int | None = None,
    ) -> list[str]:
        """Embed texts and store the resulting vectors in batches.

        Each stored vector's metadata gains a ``"source_text"`` entry
        holding the original text. Caller-supplied metadata dicts are
        shallow-copied before being augmented, so the caller's objects
        are never mutated.

        Args:
            texts: Texts to embed
            embedding_fn: Function to generate embeddings for a batch
            ids: Optional IDs for vectors
            metadata: Optional metadata for each vector
            batch_size: Batch size for embedding (defaults to
                ``self.batch_size`` from VectorStoreBase)

        Returns:
            List of IDs for added vectors
        """
        batch_size = batch_size or self.batch_size
        all_ids: list[str] = []

        for start in range(0, len(texts), batch_size):
            stop = start + batch_size
            batch_texts = texts[start:stop]
            batch_ids = ids[start:stop] if ids else None

            if metadata is None:
                batch_metadata: list[dict[str, Any]] = [{} for _ in batch_texts]
            else:
                # BUG FIX: the original wrote "source_text" directly into
                # the caller's dicts; copy each one before augmenting.
                batch_metadata = [dict(m) for m in metadata[start:stop]]

            for j, text in enumerate(batch_texts):
                batch_metadata[j]["source_text"] = text

            # Generate embeddings for this batch and persist them.
            embeddings = embedding_fn(batch_texts)
            stored_ids = await self.add_vectors(
                embeddings, ids=batch_ids, metadata=batch_metadata
            )
            all_ids.extend(stored_ids)

        return all_ids