Coverage for src/dataknobs_data/vector/stores/chroma.py: 13%

181 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-29 14:14 -0600

1"""Chroma vector store implementation.""" 

2 

3from __future__ import annotations 

4 

5from typing import TYPE_CHECKING, Any 

6from uuid import uuid4 

7 

8from ..types import DistanceMetric 

9from .base import VectorStore 

10 

11if TYPE_CHECKING: 

12 import numpy as np 

13 

14try: 

15 import chromadb 

16 from chromadb.config import Settings 

17 CHROMA_AVAILABLE = True 

18except ImportError: 

19 CHROMA_AVAILABLE = False 

20 

21 

class ChromaVectorStore(VectorStore):
    """Vector store backed by ChromaDB for semantic search.

    Chroma is a vector database aimed at AI workloads; it offers
    built-in embedding functions, metadata filtering, persistent
    storage, and multi-tenancy support.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        """Create the store, failing fast when chromadb is absent.

        Args:
            config: Optional backend configuration dictionary.

        Raises:
            ImportError: If the chromadb package is not installed.
        """
        if not CHROMA_AVAILABLE:
            raise ImportError(
                "ChromaDB is not installed. Install with: pip install chromadb"
            )

        super().__init__(config)
        # Both are created lazily in initialize(); None until then.
        self.client = None
        self.collection = None

42 

43 def _parse_backend_config(self) -> None: 

44 """Parse Chroma-specific configuration.""" 

45 # Set default dimensions if not provided 

46 if self.dimensions == 0: 

47 self.dimensions = 384 # Default for sentence-transformers 

48 

49 # Chroma-specific configuration 

50 self.collection_name = self.config.get("collection_name", "vectors") 

51 

52 # Handle embedding function 

53 self.embedding_function = None 

54 if "embedding_function" in self.config: 

55 ef = self.config["embedding_function"] 

56 if isinstance(ef, str): 

57 # Map string to Chroma embedding functions 

58 if ef == "default": 

59 from chromadb.utils import embedding_functions 

60 self.embedding_function = embedding_functions.DefaultEmbeddingFunction() 

61 elif ef == "openai": 

62 from chromadb.utils import embedding_functions 

63 api_key = self.config.get("openai_api_key") 

64 self.embedding_function = embedding_functions.OpenAIEmbeddingFunction(api_key=api_key) 

65 # Add more as needed 

66 else: 

67 self.embedding_function = ef 

68 

69 # Map distance metrics 

70 metric_map = { 

71 DistanceMetric.COSINE: "cosine", 

72 DistanceMetric.EUCLIDEAN: "l2", 

73 DistanceMetric.L2: "l2", 

74 DistanceMetric.DOT_PRODUCT: "ip", 

75 DistanceMetric.INNER_PRODUCT: "ip", 

76 } 

77 self.chroma_metric = metric_map.get(self.metric, "cosine") 

78 

79 async def initialize(self) -> None: 

80 """Initialize Chroma client and collection.""" 

81 if self._initialized: 

82 return 

83 

84 # Create client 

85 if self.persist_path: 

86 # Persistent client 

87 self.client = chromadb.PersistentClient( 

88 path=self.persist_path, 

89 settings=Settings(anonymized_telemetry=False) 

90 ) 

91 else: 

92 # In-memory client 

93 self.client = chromadb.Client( 

94 settings=Settings(anonymized_telemetry=False) 

95 ) 

96 

97 # Get or create collection 

98 try: 

99 self.collection = self.client.get_collection( 

100 name=self.collection_name, 

101 embedding_function=self.embedding_function 

102 ) 

103 except Exception: 

104 # Collection doesn't exist, create it 

105 self.collection = self.client.create_collection( 

106 name=self.collection_name, 

107 metadata={"hnsw:space": self.chroma_metric}, 

108 embedding_function=self.embedding_function 

109 ) 

110 

111 self._initialized = True 

112 

113 async def close(self) -> None: 

114 """Close Chroma client.""" 

115 # Chroma handles persistence automatically 

116 self._initialized = False 

117 

118 async def add_vectors( 

119 self, 

120 vectors: np.ndarray | list[np.ndarray], 

121 ids: list[str] | None = None, 

122 metadata: list[dict[str, Any]] | None = None, 

123 ) -> list[str]: 

124 """Add vectors to the collection.""" 

125 if not self._initialized: 

126 await self.initialize() 

127 

128 import numpy as np 

129 

130 # Convert to list format for Chroma 

131 if isinstance(vectors, np.ndarray): 

132 if vectors.ndim == 1: 

133 vectors = [vectors.tolist()] 

134 else: 

135 vectors = vectors.tolist() 

136 elif isinstance(vectors, list) and len(vectors) > 0: 

137 if isinstance(vectors[0], np.ndarray): 

138 vectors = [v.tolist() for v in vectors] 

139 

140 # Generate IDs if not provided 

141 if ids is None: 

142 ids = [str(uuid4()) for _ in range(len(vectors))] 

143 

144 # Ensure metadata is provided 

145 if metadata is None: 

146 metadata = [{} for _ in range(len(vectors))] 

147 

148 # Add to collection 

149 self.collection.add( 

150 embeddings=vectors, 

151 ids=ids, 

152 metadatas=metadata 

153 ) 

154 

155 return ids 

156 

157 async def get_vectors( 

158 self, 

159 ids: list[str], 

160 include_metadata: bool = True, 

161 ) -> list[tuple[np.ndarray | None, dict[str, Any] | None]]: 

162 """Retrieve vectors by ID.""" 

163 if not self._initialized: 

164 await self.initialize() 

165 

166 import numpy as np 

167 

168 # Get from collection 

169 result = self.collection.get( 

170 ids=ids, 

171 include=["embeddings", "metadatas"] if include_metadata else ["embeddings"] 

172 ) 

173 

174 # Convert to expected format 

175 vectors = [] 

176 for id_val in ids: 

177 try: 

178 idx = result["ids"].index(id_val) 

179 embedding = result["embeddings"][idx] if result["embeddings"] else None 

180 metadata = result["metadatas"][idx] if include_metadata and result.get("metadatas") else None 

181 

182 if embedding is not None: 

183 embedding = np.array(embedding, dtype=np.float32) 

184 

185 vectors.append((embedding, metadata)) 

186 except (ValueError, IndexError): 

187 vectors.append((None, None)) 

188 

189 return vectors 

190 

191 async def delete_vectors(self, ids: list[str]) -> int: 

192 """Delete vectors by ID.""" 

193 if not self._initialized: 

194 await self.initialize() 

195 

196 # Check which IDs exist 

197 existing = self.collection.get(ids=ids, include=[]) 

198 existing_ids = existing["ids"] 

199 

200 if existing_ids: 

201 self.collection.delete(ids=existing_ids) 

202 return len(existing_ids) 

203 

204 return 0 

205 

206 async def search( 

207 self, 

208 query_vector: np.ndarray, 

209 k: int = 10, 

210 filter: dict[str, Any] | None = None, 

211 include_metadata: bool = True, 

212 ) -> list[tuple[str, float, dict[str, Any] | None]]: 

213 """Search for similar vectors.""" 

214 if not self._initialized: 

215 await self.initialize() 

216 

217 # Convert query vector 

218 if hasattr(query_vector, "tolist"): 

219 query_vector = query_vector.tolist() 

220 

221 # Build where clause for metadata filtering 

222 where = None 

223 if filter: 

224 # Chroma uses a different filter syntax 

225 # Convert simple key-value filter to Chroma format 

226 where = {} 

227 for key, value in filter.items(): 

228 if isinstance(value, list): 

229 where[key] = {"$in": value} 

230 else: 

231 where[key] = {"$eq": value} 

232 

233 # Search 

234 results = self.collection.query( 

235 query_embeddings=[query_vector], 

236 n_results=k, 

237 where=where, 

238 include=["metadatas", "distances"] if include_metadata else ["distances"] 

239 ) 

240 

241 # Convert results 

242 search_results = [] 

243 if results["ids"] and len(results["ids"]) > 0: 

244 ids = results["ids"][0] 

245 distances = results["distances"][0] if results.get("distances") else [0] * len(ids) 

246 metadatas = results["metadatas"][0] if include_metadata and results.get("metadatas") else [None] * len(ids) 

247 

248 for id_val, distance, metadata in zip(ids, distances, metadatas, strict=False): 

249 # Convert distance to similarity score 

250 if self.metric == DistanceMetric.COSINE: 

251 # Chroma returns cosine distance (1 - similarity) 

252 score = 1.0 - distance 

253 elif self.metric in (DistanceMetric.EUCLIDEAN, DistanceMetric.L2): 

254 # Convert distance to similarity 

255 score = 1.0 / (1.0 + distance) 

256 else: 

257 score = float(distance) 

258 

259 search_results.append((id_val, score, metadata)) 

260 

261 return search_results 

262 

263 async def update_metadata( 

264 self, 

265 ids: list[str], 

266 metadata: list[dict[str, Any]], 

267 ) -> int: 

268 """Update metadata for existing vectors.""" 

269 if not self._initialized: 

270 await self.initialize() 

271 

272 # Check which IDs exist 

273 existing = self.collection.get(ids=ids, include=[]) 

274 existing_ids = existing["ids"] 

275 

276 if existing_ids: 

277 # Filter metadata to only update existing vectors 

278 filtered_ids = [] 

279 filtered_metadata = [] 

280 for id_val, meta in zip(ids, metadata, strict=False): 

281 if id_val in existing_ids: 

282 filtered_ids.append(id_val) 

283 filtered_metadata.append(meta) 

284 

285 if filtered_ids: 

286 self.collection.update( 

287 ids=filtered_ids, 

288 metadatas=filtered_metadata 

289 ) 

290 return len(filtered_ids) 

291 

292 return 0 

293 

294 async def count(self, filter: dict[str, Any] | None = None) -> int: 

295 """Count vectors in the collection.""" 

296 if not self._initialized: 

297 await self.initialize() 

298 

299 if filter is None: 

300 # Get total count 

301 return self.collection.count() 

302 

303 # Count with filter 

304 where = {} 

305 for key, value in filter.items(): 

306 if isinstance(value, list): 

307 where[key] = {"$in": value} 

308 else: 

309 where[key] = {"$eq": value} 

310 

311 # Query with limit 1 to get count efficiently 

312 results = self.collection.query( 

313 query_embeddings=[[0.0] * self.dimensions], # Dummy query 

314 n_results=1, 

315 where=where, 

316 include=[] 

317 ) 

318 

319 # The actual count would need a different approach 

320 # For now, return the number of results (limited approach) 

321 # In production, you might want to maintain counts separately 

322 return len(results["ids"][0]) if results["ids"] else 0 

323 

324 async def clear(self) -> None: 

325 """Clear all vectors from the collection.""" 

326 if not self._initialized: 

327 await self.initialize() 

328 

329 # Delete and recreate collection 

330 self.client.delete_collection(name=self.collection_name) 

331 self.collection = self.client.create_collection( 

332 name=self.collection_name, 

333 metadata={"hnsw:space": self.chroma_metric}, 

334 embedding_function=self.embedding_function 

335 ) 

336 

337 async def add_documents( 

338 self, 

339 documents: list[str], 

340 ids: list[str] | None = None, 

341 metadata: list[dict[str, Any]] | None = None, 

342 ) -> list[str]: 

343 """Add documents to the collection (uses Chroma's embedding).""" 

344 if not self._initialized: 

345 await self.initialize() 

346 

347 # Generate IDs if not provided 

348 if ids is None: 

349 ids = [str(uuid4()) for _ in range(len(documents))] 

350 

351 # Ensure metadata is provided 

352 if metadata is None: 

353 metadata = [{} for _ in range(len(documents))] 

354 

355 # Add documents (Chroma will embed them if embedding_function is set) 

356 self.collection.add( 

357 documents=documents, 

358 ids=ids, 

359 metadatas=metadata 

360 ) 

361 

362 return ids 

363 

364 async def search_documents( 

365 self, 

366 query_text: str, 

367 k: int = 10, 

368 filter: dict[str, Any] | None = None, 

369 include_metadata: bool = True, 

370 ) -> list[tuple[str, float, str, dict[str, Any] | None]]: 

371 """Search using text query (uses Chroma's embedding).""" 

372 if not self._initialized: 

373 await self.initialize() 

374 

375 # Build where clause 

376 where = None 

377 if filter: 

378 where = {} 

379 for key, value in filter.items(): 

380 if isinstance(value, list): 

381 where[key] = {"$in": value} 

382 else: 

383 where[key] = {"$eq": value} 

384 

385 # Query with text 

386 results = self.collection.query( 

387 query_texts=[query_text], 

388 n_results=k, 

389 where=where, 

390 include=["documents", "metadatas", "distances"] 

391 ) 

392 

393 # Convert results 

394 search_results = [] 

395 if results["ids"] and len(results["ids"]) > 0: 

396 ids = results["ids"][0] 

397 distances = results["distances"][0] 

398 documents = results["documents"][0] if results.get("documents") else [None] * len(ids) 

399 metadatas = results["metadatas"][0] if include_metadata and results.get("metadatas") else [None] * len(ids) 

400 

401 for id_val, distance, doc, metadata in zip(ids, distances, documents, metadatas, strict=False): 

402 # Convert distance to similarity score 

403 score = 1.0 - distance # Cosine distance to similarity 

404 search_results.append((id_val, score, doc, metadata)) 

405 

406 return search_results