Coverage for src/dataknobs_data/vector/stores/faiss.py: 12%

208 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-31 15:06 -0600

1"""Faiss vector store implementation.""" 

2 

3from __future__ import annotations 

4 

5import os 

6import pickle 

7from typing import TYPE_CHECKING, Any 

8from uuid import uuid4 

9 

10from ..types import DistanceMetric 

11from .base import VectorStore 

12 

13if TYPE_CHECKING: 

14 import numpy as np 

15 

16try: 

17 import faiss 

18 FAISS_AVAILABLE = True 

19except ImportError: 

20 FAISS_AVAILABLE = False 

21 

22 

class FaissVectorStore(VectorStore):
    """Faiss-based vector store for efficient similarity search.

    Faiss is a library for efficient similarity search and clustering of dense vectors.
    It provides various index types optimized for different use cases:

    - ``flat``: Exact search, best for small datasets
    - ``ivfflat``: Inverted file index, good for medium datasets
    - ``hnsw``: Hierarchical navigable small world, good for large datasets
    - ``ivfpq``: IVF with product quantization, for compressed storage

    Faiss requires int64 vector IDs, so external string IDs are mapped to
    monotonically increasing internal IDs (never reused); per-vector metadata
    is kept in a side dict keyed by the internal ID.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize Faiss vector store.

        Args:
            config: Optional backend configuration forwarded to the
                :class:`VectorStore` base class.

        Raises:
            ImportError: If the ``faiss`` package is not installed.
        """
        if not FAISS_AVAILABLE:
            raise ImportError(
                "Faiss is not installed. Install with: pip install faiss-cpu"
            )

        super().__init__(config)
        self.index = None  # Created lazily in initialize()
        self.id_map: dict[str, int] = {}  # external string ID -> Faiss internal int64 ID
        self.metadata_store: dict[int, dict[str, Any]] = {}  # internal ID -> metadata dict
        self.next_idx = 0  # Next internal ID to assign (monotonic, never reused)

    def _parse_backend_config(self) -> None:
        """Parse Faiss-specific configuration.

        Reads index/search tuning knobs from ``index_params`` and
        ``search_params`` (provided by the base class). A top-level
        ``config["index_type"]`` overrides ``index_params["type"]``.
        """
        self.index_type = self.index_params.get("type", "auto")
        if "index_type" in self.config:
            self.index_type = self.config["index_type"]

        self.nlist = self.index_params.get("nlist", 100)  # IVF: number of clusters
        self.m = self.index_params.get("m", 32)  # HNSW: graph connectivity
        self.ef_construction = self.index_params.get("ef_construction", 200)  # HNSW build quality
        self.ef_search = self.index_params.get("ef_search", 50)  # HNSW search quality
        self.nprobe = self.search_params.get("nprobe", 10)  # IVF: clusters probed per query

    async def initialize(self) -> None:
        """Create the Faiss index and load any persisted state.

        Idempotent: subsequent calls are no-ops once initialized.
        """
        if self._initialized:
            return

        # Create index based on configured type and metric.
        self.index = self._create_index()

        # Load existing index if a persisted copy exists on disk.
        if self.persist_path and os.path.exists(self.persist_path):
            await self.load()

        self._initialized = True

    def _create_index(self) -> Any:
        """Create a Faiss index based on configuration.

        Returns:
            A :class:`faiss.IndexIDMap` wrapping the configured index type,
            so vectors can be added/removed by our own int64 IDs.

        Raises:
            ValueError: If ``self.index_type`` is not one of
                ``flat``/``ivfflat``/``hnsw``/``ivfpq``.
        """
        dimensions = self.dimensions

        # Map our distance metric onto Faiss metric constants.
        if self.metric == DistanceMetric.COSINE:
            # Cosine is implemented as inner product over normalized vectors
            # (normalization happens in add_vectors/search).
            metric = faiss.METRIC_INNER_PRODUCT
        elif self.metric in (DistanceMetric.EUCLIDEAN, DistanceMetric.L2):
            metric = faiss.METRIC_L2
        elif self.metric in (DistanceMetric.DOT_PRODUCT, DistanceMetric.INNER_PRODUCT):
            metric = faiss.METRIC_INNER_PRODUCT
        else:
            metric = faiss.METRIC_L2  # Safe default for unknown metrics

        # Auto-select index type. Heuristic keyed on dimensionality only;
        # dataset size is unknown at creation time.
        if self.index_type == "auto":
            self.index_type = "flat" if dimensions < 100 else "ivfflat"

        if self.index_type == "flat":
            if metric == faiss.METRIC_INNER_PRODUCT:
                index = faiss.IndexFlatIP(dimensions)
            else:
                index = faiss.IndexFlatL2(dimensions)

        elif self.index_type == "ivfflat":
            # The quantizer assigns vectors to clusters; IVF needs training.
            quantizer = faiss.IndexFlatL2(dimensions)
            if metric == faiss.METRIC_INNER_PRODUCT:
                index = faiss.IndexIVFFlat(quantizer, dimensions, self.nlist, metric)
            else:
                index = faiss.IndexIVFFlat(quantizer, dimensions, self.nlist)

        elif self.index_type == "hnsw":
            index = faiss.IndexHNSWFlat(dimensions, self.m, metric)
            index.hnsw.efConstruction = self.ef_construction
            index.hnsw.efSearch = self.ef_search

        elif self.index_type == "ivfpq":
            # Product quantization compresses vectors into 8 subquantizers
            # of 8 bits each.
            m = 8  # Number of subquantizers
            quantizer = faiss.IndexFlatL2(dimensions)
            index = faiss.IndexIVFPQ(quantizer, dimensions, self.nlist, m, 8)

        else:
            raise ValueError(f"Unknown index type: {self.index_type}")

        # Wrap with IDMap so we can address vectors by our own int64 IDs.
        return faiss.IndexIDMap(index)

    async def close(self) -> None:
        """Persist the index (if configured) and mark the store closed."""
        if self.persist_path and self._initialized:
            await self.save()
        self._initialized = False

    async def add_vectors(
        self,
        vectors: np.ndarray | list[np.ndarray],
        ids: list[str] | None = None,
        metadata: list[dict[str, Any]] | None = None,
    ) -> list[str]:
        """Add vectors to the index.

        Args:
            vectors: Vectors to insert, as a 2D array or list of 1D arrays.
            ids: Optional external IDs (one per vector). Generated UUIDs are
                used when omitted.
            metadata: Optional metadata dicts (one per vector). Empty dicts
                are used when omitted.

        Returns:
            The list of external IDs for the inserted vectors.

        Raises:
            ValueError: If ``ids`` or ``metadata`` is provided with a length
                that does not match the number of vectors.
        """
        if not self._initialized:
            await self.initialize()

        import numpy as np

        # Normalize for cosine so inner product equals cosine similarity.
        vectors = self._prepare_vector(vectors, normalize=(self.metric == DistanceMetric.COSINE))

        # Faiss requires C-contiguous input arrays.
        if not vectors.flags['C_CONTIGUOUS']:
            vectors = np.ascontiguousarray(vectors)

        if ids is None:
            ids = [str(uuid4()) for _ in range(len(vectors))]
        elif len(ids) != len(vectors):
            # Fail fast: a silent mismatch would misalign IDs and vectors.
            raise ValueError(
                f"Expected {len(vectors)} ids, got {len(ids)}"
            )

        if metadata is None:
            metadata = [{} for _ in range(len(vectors))]
        elif len(metadata) != len(vectors):
            raise ValueError(
                f"Expected {len(vectors)} metadata entries, got {len(metadata)}"
            )

        # IVF-family indexes must be trained before vectors can be added.
        if hasattr(self.index, "is_trained") and not self.index.is_trained:
            if len(vectors) >= self.nlist:
                self.index.train(vectors)
            # else: too few vectors to train; the add_with_ids call below
            # will fail inside Faiss. NOTE(review): consider buffering
            # vectors until nlist are available, or falling back to flat.

        # Assign internal int64 IDs and record metadata.
        internal_ids = []
        for i, ext_id in enumerate(ids):
            internal_id = self.next_idx
            self.next_idx += 1
            self.id_map[ext_id] = internal_id
            self.metadata_store[internal_id] = metadata[i]
            internal_ids.append(internal_id)

        internal_ids_array = np.array(internal_ids, dtype=np.int64)
        self.index.add_with_ids(vectors, internal_ids_array)

        return ids

    async def get_vectors(
        self,
        ids: list[str],
        include_metadata: bool = True,
    ) -> list[tuple[np.ndarray | None, dict[str, Any] | None]]:
        """Retrieve vectors by external ID.

        Args:
            ids: External IDs to look up.
            include_metadata: Whether to include stored metadata.

        Returns:
            One ``(vector, metadata)`` tuple per requested ID, in order;
            ``(None, None)`` for IDs that are unknown or cannot be
            reconstructed from the index.
        """
        if not self._initialized:
            await self.initialize()

        results: list[tuple[np.ndarray | None, dict[str, Any] | None]] = []
        for ext_id in ids:
            if ext_id not in self.id_map:
                results.append((None, None))
                continue

            internal_id = self.id_map[ext_id]

            # Reconstruct the stored vector from the index. Faiss raises
            # (typically RuntimeError) when the index type does not support
            # reconstruction or the ID is gone; treat that as "not found".
            try:
                vector = self.index.reconstruct(int(internal_id))
                metadata = self.metadata_store.get(internal_id) if include_metadata else None
                results.append((vector, metadata))
            except Exception:
                results.append((None, None))

        return results

    async def delete_vectors(self, ids: list[str]) -> int:
        """Delete vectors by external ID.

        Args:
            ids: External IDs to delete. Unknown IDs are ignored.

        Returns:
            The number of vectors Faiss reports as removed.
        """
        if not self._initialized:
            await self.initialize()

        import numpy as np

        # Resolve and drop the local bookkeeping first.
        internal_ids = []
        for ext_id in ids:
            if ext_id in self.id_map:
                internal_id = self.id_map[ext_id]
                internal_ids.append(internal_id)
                del self.id_map[ext_id]
                if internal_id in self.metadata_store:
                    del self.metadata_store[internal_id]

        if internal_ids:
            internal_ids_array = np.array(internal_ids, dtype=np.int64)
            removed = self.index.remove_ids(internal_ids_array)
            return removed

        return 0

    async def search(
        self,
        query_vector: np.ndarray,
        k: int = 10,
        filter: dict[str, Any] | None = None,
        include_metadata: bool = True,
    ) -> list[tuple[str, float, dict[str, Any] | None]]:
        """Search for similar vectors.

        Args:
            query_vector: The query vector.
            k: Maximum number of results to return (capped at index size).
            filter: Optional exact-match key/value metadata filter. Applied
                post-search, so fewer than ``k`` results may be returned.
            include_metadata: Whether to include metadata in results.

        Returns:
            ``(external_id, score, metadata)`` tuples, best matches first.
            Scores are similarities: cosine/inner-product scores are returned
            as-is; L2 distances are converted via ``1 / (1 + distance)``.
        """
        if not self._initialized:
            await self.initialize()

        # Normalize for cosine so inner product equals cosine similarity.
        query = self._prepare_vector(query_vector, normalize=(self.metric == DistanceMetric.COSINE))

        # Apply IVF probe count when the index supports it.
        if hasattr(self.index, "nprobe"):
            self.index.nprobe = self.nprobe

        k = min(k, self.index.ntotal)  # Don't ask for more than we hold
        if k == 0:
            return []

        scores, indices = self.index.search(query, k)

        results = []
        reverse_id_map = {v: k for k, v in self.id_map.items()}

        for i in range(len(indices[0])):
            internal_id = indices[0][i]
            if internal_id == -1:  # Faiss pads missing results with -1
                continue

            score = float(scores[0][i])

            # Convert score based on metric. Cosine/inner-product scores are
            # already similarities; L2 distances become similarities.
            if self.metric in (DistanceMetric.EUCLIDEAN, DistanceMetric.L2):
                score = 1.0 / (1.0 + score)

            ext_id = reverse_id_map.get(internal_id, str(internal_id))

            # Apply the metadata filter regardless of include_metadata
            # (previously the filter was silently skipped when metadata was
            # not requested or was an empty dict). Entries without matching
            # metadata are excluded.
            stored_meta = self.metadata_store.get(internal_id)
            if filter:
                if not stored_meta or any(
                    stored_meta.get(key) != value
                    for key, value in filter.items()
                ):
                    continue

            results.append((ext_id, score, stored_meta if include_metadata else None))

        return results

    async def update_metadata(
        self,
        ids: list[str],
        metadata: list[dict[str, Any]],
    ) -> int:
        """Replace metadata for existing vectors.

        Args:
            ids: External IDs to update; unknown IDs are skipped.
            metadata: Replacement metadata dicts, parallel to ``ids``.

        Returns:
            The number of entries actually updated.
        """
        if not self._initialized:
            await self.initialize()

        updated = 0
        for ext_id, meta in zip(ids, metadata, strict=False):
            if ext_id in self.id_map:
                internal_id = self.id_map[ext_id]
                self.metadata_store[internal_id] = meta
                updated += 1

        return updated

    async def count(self, filter: dict[str, Any] | None = None) -> int:
        """Count vectors in the store.

        Args:
            filter: Optional exact-match key/value metadata filter.

        Returns:
            Total vectors when no filter is given; otherwise the number of
            metadata entries matching every filter key/value pair.
        """
        if not self._initialized:
            await self.initialize()

        if filter is None:
            return self.index.ntotal

        count = 0
        for metadata in self.metadata_store.values():
            match = all(
                metadata.get(key) == value
                for key, value in filter.items()
            )
            if match:
                count += 1

        return count

    async def clear(self) -> None:
        """Remove all vectors, metadata, and ID mappings from the store."""
        if not self._initialized:
            await self.initialize()

        # Rebuild a fresh index and reset all bookkeeping.
        self.index = self._create_index()
        self.id_map.clear()
        self.metadata_store.clear()
        self.next_idx = 0

    async def save(self) -> None:
        """Save the index and metadata to ``persist_path``.

        Writes the Faiss index to ``persist_path`` and the ID map, metadata
        store, and config snapshot to ``persist_path + ".meta"`` (pickle).
        No-op when no persist path is configured.
        """
        if not self.persist_path:
            return

        # Create the parent directory if the path has one. os.makedirs("")
        # raises, so guard against a bare filename with no directory part.
        directory = os.path.dirname(self.persist_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        faiss.write_index(self.index, self.persist_path)

        # Save metadata and ID mappings alongside the index.
        metadata_path = self.persist_path + ".meta"
        with open(metadata_path, "wb") as f:
            pickle.dump({
                "id_map": self.id_map,
                "metadata_store": self.metadata_store,
                "next_idx": self.next_idx,
                "config": {
                    "dimensions": self.dimensions,
                    "metric": self.metric.value,
                    "index_type": self.index_type,
                }
            }, f)

    async def load(self) -> None:
        """Load the index and metadata from ``persist_path``.

        No-op when no persist path is configured or the file is missing.
        """
        if not self.persist_path or not os.path.exists(self.persist_path):
            return

        self.index = faiss.read_index(self.persist_path)

        # SECURITY NOTE: pickle.load executes arbitrary code from the file.
        # Only load .meta files from trusted locations.
        metadata_path = self.persist_path + ".meta"
        if os.path.exists(metadata_path):
            with open(metadata_path, "rb") as f:
                data = pickle.load(f)
                self.id_map = data["id_map"]
                self.metadata_store = data["metadata_store"]
                self.next_idx = data["next_idx"]