Coverage for src/dataknobs_data/vector/stores/chroma.py: 13%
181 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 14:14 -0600
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 14:14 -0600
1"""Chroma vector store implementation."""
3from __future__ import annotations
5from typing import TYPE_CHECKING, Any
6from uuid import uuid4
8from ..types import DistanceMetric
9from .base import VectorStore
11if TYPE_CHECKING:
12 import numpy as np
14try:
15 import chromadb
16 from chromadb.config import Settings
17 CHROMA_AVAILABLE = True
18except ImportError:
19 CHROMA_AVAILABLE = False
class ChromaVectorStore(VectorStore):
    """Chroma-based vector store for semantic search.

    Chroma is a vector database designed for AI applications with features like:
    - Built-in embedding functions
    - Metadata filtering
    - Persistent storage
    - Multi-tenancy support

    The attributes ``config``, ``dimensions``, ``metric``, ``persist_path`` and
    ``_initialized`` are read here but never assigned in this class; they are
    presumably provided by the ``VectorStore`` base class (not visible in this
    file -- confirm against ``base.py``).
    """

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize Chroma vector store.

        Args:
            config: Optional configuration mapping, forwarded to the base
                class. Keys read by this backend are ``collection_name``,
                ``embedding_function`` and ``openai_api_key``.

        Raises:
            ImportError: If the ``chromadb`` package could not be imported.
        """
        if not CHROMA_AVAILABLE:
            raise ImportError(
                "ChromaDB is not installed. Install with: pip install chromadb"
            )

        super().__init__(config)
        # The client and collection are created lazily by initialize().
        self.client = None
        self.collection = None

    def _parse_backend_config(self) -> None:
        """Parse Chroma-specific configuration.

        Sets ``dimensions`` (when it is 0), ``collection_name``,
        ``embedding_function`` and ``chroma_metric`` on the instance.
        """
        # Set default dimensions if not provided
        if self.dimensions == 0:
            self.dimensions = 384  # Default for sentence-transformers

        # Chroma-specific configuration
        self.collection_name = self.config.get("collection_name", "vectors")

        # Handle embedding function: either a well-known name or a ready-made
        # embedding function object supplied by the caller.
        self.embedding_function = None
        if "embedding_function" in self.config:
            ef = self.config["embedding_function"]
            if isinstance(ef, str):
                # Map string to Chroma embedding functions. Unknown names are
                # ignored and leave ``embedding_function`` as None.
                if ef == "default":
                    from chromadb.utils import embedding_functions

                    self.embedding_function = (
                        embedding_functions.DefaultEmbeddingFunction()
                    )
                elif ef == "openai":
                    from chromadb.utils import embedding_functions

                    api_key = self.config.get("openai_api_key")
                    self.embedding_function = (
                        embedding_functions.OpenAIEmbeddingFunction(api_key=api_key)
                    )
                # Add more as needed
            else:
                self.embedding_function = ef

        # Map distance metrics onto the names Chroma uses for "hnsw:space";
        # anything unrecognized falls back to cosine.
        metric_map = {
            DistanceMetric.COSINE: "cosine",
            DistanceMetric.EUCLIDEAN: "l2",
            DistanceMetric.L2: "l2",
            DistanceMetric.DOT_PRODUCT: "ip",
            DistanceMetric.INNER_PRODUCT: "ip",
        }
        self.chroma_metric = metric_map.get(self.metric, "cosine")

    @staticmethod
    def _build_where(criteria: dict[str, Any] | None) -> dict[str, Any] | None:
        """Translate a simple key/value filter into a Chroma ``where`` clause.

        List values become ``$in`` conditions and every other value becomes an
        ``$eq`` condition. Several conditions are combined under ``$and``
        because Chroma expects a single operator at the top level of ``where``.

        Args:
            criteria: Mapping of metadata key to required value (or list of
                accepted values). ``None`` or an empty mapping means no filter.

        Returns:
            The ``where`` dict, or ``None`` when there is nothing to filter on.
        """
        if not criteria:
            return None

        clauses = [
            {key: {"$in": value} if isinstance(value, list) else {"$eq": value}}
            for key, value in criteria.items()
        ]
        if len(clauses) == 1:
            return clauses[0]
        return {"$and": clauses}

    def _distance_to_score(self, distance: float) -> float:
        """Convert a Chroma distance into a similarity score (higher is better).

        The conversion is keyed on ``chroma_metric`` (the space actually sent
        to Chroma) so that the fallback-to-cosine case is scored as cosine.

        Args:
            distance: Distance value reported by a Chroma query.

        Returns:
            ``1 / (1 + distance)`` for the ``l2`` space, otherwise
            ``1 - distance`` (Chroma documents both cosine and inner-product
            distances as one minus the similarity).
        """
        value = float(distance)
        if self.chroma_metric == "l2":
            return 1.0 / (1.0 + value)
        return 1.0 - value

    async def initialize(self) -> None:
        """Initialize Chroma client and collection.

        Does nothing when the store is already initialized. A persistent
        client is used when ``persist_path`` is set, otherwise an in-memory
        client.
        """
        if self._initialized:
            return

        # Create client
        if self.persist_path:
            # Persistent client
            self.client = chromadb.PersistentClient(
                path=self.persist_path,
                settings=Settings(anonymized_telemetry=False),
            )
        else:
            # In-memory client
            self.client = chromadb.Client(
                settings=Settings(anonymized_telemetry=False)
            )

        # Get or create collection
        try:
            self.collection = self.client.get_collection(
                name=self.collection_name,
                embedding_function=self.embedding_function,
            )
        except Exception:
            # Collection doesn't exist, create it.
            # NOTE(review): the broad except is kept on purpose -- the
            # "not found" exception type appears to differ between chromadb
            # releases (confirm). Any unrelated failure will most likely
            # resurface from create_collection below.
            self.collection = self.client.create_collection(
                name=self.collection_name,
                metadata={"hnsw:space": self.chroma_metric},
                embedding_function=self.embedding_function,
            )

        self._initialized = True

    async def close(self) -> None:
        """Close Chroma client.

        Only resets the initialized flag; the client and collection objects
        are left in place.
        """
        # Chroma handles persistence automatically
        self._initialized = False

    async def add_vectors(
        self,
        vectors: np.ndarray | list[np.ndarray],
        ids: list[str] | None = None,
        metadata: list[dict[str, Any]] | None = None,
    ) -> list[str]:
        """Add vectors to the collection.

        Args:
            vectors: A single 1-D array, a 2-D array with one row per vector,
                or a list of vectors (arrays or plain sequences).
            ids: Optional IDs, one per vector. Random UUID4 strings are
                generated when omitted.
            metadata: Optional metadata dicts, one per vector.

        Returns:
            The IDs of the added vectors, in input order.

        Raises:
            ValueError: If ``ids`` or ``metadata`` is given with a length that
                differs from the number of vectors.
        """
        if not self._initialized:
            await self.initialize()

        import numpy as np

        # Convert to list format for Chroma
        if isinstance(vectors, np.ndarray):
            embeddings = [vectors.tolist()] if vectors.ndim == 1 else vectors.tolist()
        else:
            # Convert element by element so mixed lists are handled too.
            embeddings = [
                v.tolist() if isinstance(v, np.ndarray) else v for v in vectors
            ]

        total = len(embeddings)
        if ids is not None and len(ids) != total:
            raise ValueError(
                f"Number of ids ({len(ids)}) does not match number of vectors ({total})"
            )
        if metadata is not None and len(metadata) != total:
            raise ValueError(
                f"Number of metadata entries ({len(metadata)}) does not match "
                f"number of vectors ({total})"
            )
        if total == 0:
            # Nothing to add; avoid sending an empty batch to Chroma.
            return []

        # Generate IDs if not provided
        if ids is None:
            ids = [str(uuid4()) for _ in range(total)]

        # Metadata is forwarded as-is; when the caller gave none, None is
        # passed rather than a list of empty dicts.
        # NOTE(review): Chroma's metadata validation appears to reject empty
        # dicts -- confirm against the installed chromadb version.
        self.collection.add(
            embeddings=embeddings,
            ids=ids,
            metadatas=metadata,
        )

        return ids

    async def get_vectors(
        self,
        ids: list[str],
        include_metadata: bool = True,
    ) -> list[tuple[np.ndarray | None, dict[str, Any] | None]]:
        """Retrieve vectors by ID.

        Args:
            ids: IDs to look up.
            include_metadata: Whether to fetch and return metadata as well.

        Returns:
            One ``(vector, metadata)`` tuple per requested ID, in the order of
            ``ids``. Vectors are returned as float32 arrays. IDs that are not
            found yield ``(None, None)``; ``metadata`` is ``None`` when
            ``include_metadata`` is false.
        """
        if not self._initialized:
            await self.initialize()

        import numpy as np

        if not ids:
            return []

        # Get from collection
        result = self.collection.get(
            ids=ids,
            include=["embeddings", "metadatas"] if include_metadata else ["embeddings"],
        )

        # Chroma does not promise to echo IDs in request order, so look each
        # one up by position. The dict replaces a per-ID list.index() scan.
        position_of = {found: pos for pos, found in enumerate(result["ids"])}

        # Compare against None instead of relying on truthiness: the
        # embeddings may come back as a NumPy array, whose truth value is
        # ambiguous.
        embeddings = result.get("embeddings")
        metadatas = result.get("metadatas") if include_metadata else None

        # Convert to expected format
        found_vectors = []
        for id_val in ids:
            pos = position_of.get(id_val)
            if pos is None:
                found_vectors.append((None, None))
                continue

            vector = None
            if embeddings is not None and pos < len(embeddings):
                raw = embeddings[pos]
                if raw is not None:
                    vector = np.array(raw, dtype=np.float32)

            meta = None
            if metadatas is not None and pos < len(metadatas):
                meta = metadatas[pos]

            found_vectors.append((vector, meta))

        return found_vectors

    async def delete_vectors(self, ids: list[str]) -> int:
        """Delete vectors by ID.

        Args:
            ids: IDs to delete. IDs not present in the collection are ignored.

        Returns:
            The number of vectors that existed and were deleted.
        """
        if not self._initialized:
            await self.initialize()

        if not ids:
            return 0

        # Check which IDs exist so the returned count reflects real deletions
        existing = self.collection.get(ids=ids, include=[])
        existing_ids = existing["ids"]

        if existing_ids:
            self.collection.delete(ids=existing_ids)
            return len(existing_ids)

        return 0

    async def search(
        self,
        query_vector: np.ndarray,
        k: int = 10,
        filter: dict[str, Any] | None = None,
        include_metadata: bool = True,
    ) -> list[tuple[str, float, dict[str, Any] | None]]:
        """Search for similar vectors.

        Args:
            query_vector: Query embedding (anything with ``tolist()`` is
                converted to a plain list).
            k: Maximum number of results to return.
            filter: Optional metadata filter; list values match any listed
                value, other values must match exactly.
            include_metadata: Whether to return stored metadata with each hit.

        Returns:
            ``(id, score, metadata)`` tuples in the order Chroma returned
            them. ``score`` is a similarity derived from the Chroma distance,
            so higher means more similar for every metric.
        """
        if not self._initialized:
            await self.initialize()

        if k <= 0:
            return []

        # Convert query vector
        if hasattr(query_vector, "tolist"):
            query_vector = query_vector.tolist()

        # Search
        results = self.collection.query(
            query_embeddings=[query_vector],
            n_results=k,
            where=self._build_where(filter),
            include=["metadatas", "distances"] if include_metadata else ["distances"],
        )

        # Results are batched per query; only one query was sent.
        id_batches = results.get("ids")
        if id_batches is None or len(id_batches) == 0:
            return []

        ids = id_batches[0]
        distance_batches = results.get("distances")
        distances = (
            distance_batches[0] if distance_batches is not None else [0] * len(ids)
        )
        metadata_batches = results.get("metadatas") if include_metadata else None
        metadatas = (
            metadata_batches[0] if metadata_batches is not None else [None] * len(ids)
        )

        return [
            (id_val, self._distance_to_score(distance), meta)
            for id_val, distance, meta in zip(ids, distances, metadatas, strict=False)
        ]

    async def update_metadata(
        self,
        ids: list[str],
        metadata: list[dict[str, Any]],
    ) -> int:
        """Update metadata for existing vectors.

        Args:
            ids: IDs whose metadata should be updated.
            metadata: New metadata dicts, paired with ``ids`` by position.

        Returns:
            The number of vectors updated; IDs that are not in the collection
            are skipped.
        """
        if not self._initialized:
            await self.initialize()

        if not ids:
            return 0

        # Check which IDs exist
        existing = self.collection.get(ids=ids, include=[])
        known_ids = set(existing["ids"])
        if not known_ids:
            return 0

        # Filter metadata to only update existing vectors
        updates = [
            (id_val, meta)
            for id_val, meta in zip(ids, metadata, strict=False)
            if id_val in known_ids
        ]
        if not updates:
            return 0

        self.collection.update(
            ids=[id_val for id_val, _ in updates],
            metadatas=[meta for _, meta in updates],
        )
        return len(updates)

    async def count(self, filter: dict[str, Any] | None = None) -> int:
        """Count vectors in the collection.

        Args:
            filter: Optional metadata filter, using the same format as
                ``search``. ``None`` or an empty filter counts everything.

        Returns:
            The number of vectors matching the filter.
        """
        if not self._initialized:
            await self.initialize()

        where = self._build_where(filter)
        if where is None:
            # Get total count
            return self.collection.count()

        # Fetch only the IDs of matching records and count them. A similarity
        # query cannot be used here because it caps the result at n_results.
        # This pulls every matching ID, so very large result sets cost memory.
        matching = self.collection.get(where=where, include=[])
        return len(matching["ids"])

    async def clear(self) -> None:
        """Clear all vectors from the collection.

        The collection is deleted and recreated with the same name, distance
        space and embedding function.
        """
        if not self._initialized:
            await self.initialize()

        # Delete and recreate collection
        self.client.delete_collection(name=self.collection_name)
        self.collection = self.client.create_collection(
            name=self.collection_name,
            metadata={"hnsw:space": self.chroma_metric},
            embedding_function=self.embedding_function,
        )

    async def add_documents(
        self,
        documents: list[str],
        ids: list[str] | None = None,
        metadata: list[dict[str, Any]] | None = None,
    ) -> list[str]:
        """Add documents to the collection (uses Chroma's embedding).

        Args:
            documents: Texts to store.
            ids: Optional IDs, one per document. Random UUID4 strings are
                generated when omitted.
            metadata: Optional metadata dicts, one per document.

        Returns:
            The IDs of the added documents, in input order.

        Raises:
            ValueError: If ``ids`` or ``metadata`` is given with a length that
                differs from the number of documents.
        """
        if not self._initialized:
            await self.initialize()

        total = len(documents)
        if ids is not None and len(ids) != total:
            raise ValueError(
                f"Number of ids ({len(ids)}) does not match number of documents ({total})"
            )
        if metadata is not None and len(metadata) != total:
            raise ValueError(
                f"Number of metadata entries ({len(metadata)}) does not match "
                f"number of documents ({total})"
            )
        if total == 0:
            # Nothing to add; avoid sending an empty batch to Chroma.
            return []

        # Generate IDs if not provided
        if ids is None:
            ids = [str(uuid4()) for _ in range(total)]

        # Add documents (Chroma will embed them if embedding_function is set).
        # As in add_vectors, missing metadata is passed as None rather than
        # as a list of empty dicts.
        self.collection.add(
            documents=documents,
            ids=ids,
            metadatas=metadata,
        )

        return ids

    async def search_documents(
        self,
        query_text: str,
        k: int = 10,
        filter: dict[str, Any] | None = None,
        include_metadata: bool = True,
    ) -> list[tuple[str, float, str, dict[str, Any] | None]]:
        """Search using text query (uses Chroma's embedding).

        Args:
            query_text: Text to embed and search for.
            k: Maximum number of results to return.
            filter: Optional metadata filter, using the same format as
                ``search``.
            include_metadata: Whether to return stored metadata with each hit.

        Returns:
            ``(id, score, document, metadata)`` tuples in the order Chroma
            returned them. ``score`` uses the same distance-to-similarity
            conversion as ``search``.
        """
        if not self._initialized:
            await self.initialize()

        if k <= 0:
            return []

        include = ["documents", "distances"]
        if include_metadata:
            include.append("metadatas")

        # Query with text
        results = self.collection.query(
            query_texts=[query_text],
            n_results=k,
            where=self._build_where(filter),
            include=include,
        )

        # Results are batched per query; only one query was sent.
        id_batches = results.get("ids")
        if id_batches is None or len(id_batches) == 0:
            return []

        ids = id_batches[0]
        distances = results["distances"][0]
        document_batches = results.get("documents")
        documents = (
            document_batches[0] if document_batches is not None else [None] * len(ids)
        )
        metadata_batches = results.get("metadatas") if include_metadata else None
        metadatas = (
            metadata_batches[0] if metadata_batches is not None else [None] * len(ids)
        )

        return [
            (id_val, self._distance_to_score(distance), doc, meta)
            for id_val, distance, doc, meta in zip(
                ids, distances, documents, metadatas, strict=False
            )
        ]