"""Chroma-backed vector store.
Wraps ``chromadb`` for persistent on-disk or hosted Chroma. Lazy
import — install via ``pip install 'jeevesagent[vectorstore-chroma]'``.
Embeddings come from our framework's :class:`Embedder` protocol so
swapping embedders works the same across every vector store. We
pass ``None`` as Chroma's ``embedding_function`` and supply
embeddings ourselves at ``add`` time.
Filter operators are translated from our Mongo-style language to
Chroma's native ``where`` syntax (which already speaks Mongo-ish
``$eq`` / ``$in`` / ``$gt`` etc., so the translation is mostly
direct).
"""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
import anyio
from ..core.ids import new_id
from ..core.protocols import Embedder
from ..loader.base import Chunk
from ._filter import COMPARISON_OPERATORS, LOGICAL_OPERATORS, FilterError
from ._mmr import mmr_select
from .base import SearchResult, _chunks_from_texts
def _translate_filter(
filter: Mapping[str, Any] | None,
) -> dict[str, Any] | None:
"""Translate our Mongo-style filter to Chroma's ``where`` shape.
Chroma already speaks Mongo-ish operators natively, but with two
quirks: scalar shorthand isn't accepted (must be ``{"$eq": v}``),
and the top-level ``$and`` is implicit when multiple keys are
present. We normalize both.
"""
if not filter:
return None
return _xlate_node(filter)
def _xlate_node(node: Mapping[str, Any]) -> dict[str, Any]:
    """Translate one filter node (recursively) to Chroma shape.

    Args:
        node: Mongo-style filter mapping; keys are either logical
            operators (``$and``/``$or``/``$not``) or field names.

    Returns:
        A Chroma ``where`` node with exactly one top-level key.

    Raises:
        FilterError: For ``$not`` (Chroma has no equivalent), for an
            unknown ``$``-prefixed key, or for a logical operator whose
            value is not a list.
    """
    out: dict[str, Any] = {}
    for key, value in node.items():
        if key in LOGICAL_OPERATORS:
            if key == "$not":
                # Chroma 0.5+ doesn't have a plain $not; emulating via
                # negated comparisons isn't always feasible. Raise
                # and ask the caller to invert manually.
                raise FilterError(
                    "$not isn't supported by Chroma. "
                    "Invert the underlying comparison instead."
                )
            if not isinstance(value, list):
                # Validate explicitly rather than with `assert`, which
                # is stripped under `python -O`.
                raise FilterError(
                    f"{key} expects a list of sub-filters, "
                    f"got {type(value).__name__}"
                )
            out[key] = [_xlate_node(sub) for sub in value]
        elif key.startswith("$"):
            raise FilterError(f"Unknown top-level operator: {key}")
        else:
            out[key] = _xlate_field(value)
    # Chroma requires exactly one key per `where` node, so any
    # multi-key node — field constraints, logical operators, or a mix
    # of both — must be wrapped in an explicit $and.
    if len(out) > 1:
        return {"$and": [{k: v} for k, v in out.items()]}
    return out
def _xlate_field(condition: Any) -> dict[str, Any]:
"""Normalize a field constraint to operator form."""
if isinstance(condition, Mapping) and condition and all(
k.startswith("$") for k in condition
):
for op in condition:
if op not in COMPARISON_OPERATORS:
raise FilterError(f"Unknown field operator: {op}")
return dict(condition)
if isinstance(condition, list | tuple):
return {"$in": list(condition)}
return {"$eq": condition}
class ChromaVectorStore:
    """Vector store backed by ``chromadb``.

    Embeddings come from the injected :class:`Embedder`; Chroma's own
    ``embedding_function`` is never used because vectors are supplied
    explicitly at ``add``/query time. Mongo-style filters are
    translated to Chroma's ``where`` syntax via ``_translate_filter``.
    """

    # Identifier for this store implementation.
    name = "chroma"

    def __init__(
        self,
        embedder: Embedder,
        *,
        collection_name: str = "jeeves_vectors",
        persist_directory: str | None = None,
        client: Any = None,
    ) -> None:
        """Connect to (or create) a Chroma collection.

        Args:
            embedder: Supplies embeddings for documents and queries.
            collection_name: Collection to create or reuse.
            persist_directory: When set (and ``client`` is None), use
                an on-disk ``PersistentClient`` rooted at this path;
                otherwise an in-memory client is created.
            client: Pre-built chromadb client; takes precedence over
                ``persist_directory``.

        Raises:
            ValueError: If ``embedder`` is None.
            ImportError: If ``chromadb`` is not installed.
        """
        if embedder is None:
            raise ValueError("embedder is required")
        self._embedder = embedder
        self._collection_name = collection_name
        # Lazy import: chromadb is an optional extra.
        try:
            import chromadb  # type: ignore[import-not-found, import-untyped]
        except ImportError as exc:  # pragma: no cover
            raise ImportError(
                "chromadb is not installed. "
                "Install with: pip install 'jeevesagent[vectorstore-chroma]'."
            ) from exc
        if client is not None:
            self._client = client
        elif persist_directory is not None:
            self._client = chromadb.PersistentClient(
                path=persist_directory
            )
        else:
            self._client = chromadb.Client()
        self._collection = self._client.get_or_create_collection(
            name=collection_name
        )

    # ---------------------------------------------------------------
    # Factory classmethods — explicit kwargs so IDEs autocomplete
    # ---------------------------------------------------------------

    @classmethod
    async def from_chunks(
        cls,
        chunks: list[Chunk],
        *,
        embedder: Embedder,
        ids: list[str] | None = None,
        collection_name: str = "jeeves_vectors",
        persist_directory: str | None = None,
        client: Any = None,
    ) -> ChromaVectorStore:
        """One-shot: construct a ChromaVectorStore + add ``chunks``."""
        store = cls(
            embedder=embedder,
            collection_name=collection_name,
            persist_directory=persist_directory,
            client=client,
        )
        await store.add(chunks, ids=ids)
        return store

    @classmethod
    async def from_texts(
        cls,
        texts: list[str],
        *,
        embedder: Embedder,
        metadatas: list[dict[str, Any]] | None = None,
        ids: list[str] | None = None,
        collection_name: str = "jeeves_vectors",
        persist_directory: str | None = None,
        client: Any = None,
    ) -> ChromaVectorStore:
        """One-shot: construct a ChromaVectorStore from raw text
        strings (each becomes a :class:`Chunk` with the matching
        metadata dict, or empty if ``metadatas`` is None)."""
        return await cls.from_chunks(
            _chunks_from_texts(texts, metadatas),
            embedder=embedder,
            ids=ids,
            collection_name=collection_name,
            persist_directory=persist_directory,
            client=client,
        )

    @property
    def embedder(self) -> Embedder:
        """The embedder supplied at construction time."""
        return self._embedder

    async def add(
        self,
        chunks: list[Chunk],
        ids: list[str] | None = None,
    ) -> list[str]:
        """Embed ``chunks`` and insert them into the collection.

        Args:
            chunks: Documents to store; an empty list is a no-op.
            ids: Optional explicit ids (length must match ``chunks``);
                fresh ids are generated when omitted.

        Returns:
            The ids assigned to the stored chunks, in order.

        Raises:
            ValueError: If ``ids`` has a mismatched length.
        """
        if not chunks:
            return []
        if ids is not None and len(ids) != len(chunks):
            raise ValueError(
                f"ids length ({len(ids)}) must match chunks "
                f"length ({len(chunks)})"
            )
        contents = [c.content for c in chunks]
        vectors = await self._embed_all(contents)
        assigned = (
            list(ids)
            if ids is not None
            else [new_id("vec") for _ in chunks]
        )
        # Chroma rejects empty-dict metadatas; supply a sentinel.
        metadatas = [
            {**c.metadata} if c.metadata else {"_empty": True}
            for c in chunks
        ]
        # Chroma's client API is synchronous; run it off the event loop.
        await anyio.to_thread.run_sync(
            lambda: self._collection.add(
                ids=assigned,
                embeddings=vectors,
                documents=contents,
                metadatas=metadatas,
            )
        )
        return assigned

    async def _embed_all(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts``, preferring the batch API when available.

        Falls back to per-text ``embed`` calls when the embedder has no
        ``embed_batch`` attribute or raises ``NotImplementedError``.
        Probing with ``getattr`` (rather than catching AttributeError
        around the call) avoids silently masking AttributeErrors raised
        *inside* a working ``embed_batch`` implementation.
        """
        embed_batch = getattr(self._embedder, "embed_batch", None)
        if embed_batch is not None:
            try:
                return await embed_batch(texts)
            except NotImplementedError:
                pass
        return [await self._embedder.embed(t) for t in texts]

    async def delete(self, ids: list[str]) -> None:
        """Remove the given ids from the collection (no-op if empty)."""
        if not ids:
            return
        await anyio.to_thread.run_sync(
            lambda: self._collection.delete(ids=list(ids))
        )

    async def get_by_ids(self, ids: list[str]) -> list[Chunk]:
        """Fetch chunks by id, preserving the caller's order.

        Unknown ids are silently skipped, so the result may be shorter
        than ``ids``.
        """
        if not ids:
            return []
        result = await anyio.to_thread.run_sync(
            lambda: self._collection.get(ids=list(ids))
        )
        got_ids = result.get("ids") or []
        docs = result.get("documents") or []
        metas = result.get("metadatas") or [{}] * len(got_ids)
        # Preserve caller order; skip unknowns.
        index = {cid: i for i, cid in enumerate(got_ids)}
        out: list[Chunk] = []
        for cid in ids:
            if cid not in index:
                continue
            i = index[cid]
            meta = dict(metas[i] or {})
            meta.pop("_empty", None)  # strip the storage sentinel
            out.append(Chunk(content=docs[i] or "", metadata=meta))
        return out

    async def search(
        self,
        query: str,
        *,
        k: int = 4,
        filter: Mapping[str, Any] | None = None,
        diversity: float | None = None,
    ) -> list[SearchResult]:
        """Embed ``query`` and delegate to :meth:`search_by_vector`."""
        q_vec = await self._embedder.embed(query)
        return await self.search_by_vector(
            q_vec, k=k, filter=filter, diversity=diversity
        )

    async def search_by_vector(
        self,
        vector: list[float],
        *,
        k: int = 4,
        filter: Mapping[str, Any] | None = None,
        diversity: float | None = None,
    ) -> list[SearchResult]:
        """Nearest-neighbour search for ``vector``.

        Args:
            vector: Query embedding.
            k: Number of results to return.
            filter: Mongo-style metadata filter, translated to Chroma's
                ``where`` syntax.
            diversity: When set (> 0), fetch a wider candidate pool and
                rerank in-process with MMR.

        Returns:
            Results with ``score = max(0, 1 - distance)``, clamped to
            a non-negative value.
        """
        where = _translate_filter(filter)
        # When diversity is requested, fetch a wider candidate pool
        # and rerank in-process via MMR.
        n_fetch = max(k * 4, 20) if diversity else k
        result = await anyio.to_thread.run_sync(
            lambda: self._collection.query(
                query_embeddings=[vector],
                n_results=n_fetch,
                where=where,
                include=["documents", "metadatas", "distances", "embeddings"],
            )
        )

        def first_row(key: str, fallback: list[Any]) -> list[Any]:
            # Newer chromadb versions return numpy arrays for some
            # result fields; boolean tests like `x or y` raise on
            # multi-element arrays, so use explicit None/length checks.
            # An absent, None, or empty row falls back to a list of the
            # expected length instead of truncating the zip below.
            batch = result.get(key)
            if batch is None or len(batch) == 0:
                return fallback
            row = batch[0]
            if row is None or len(row) == 0:
                return fallback
            return list(row)

        hit_ids = first_row("ids", [])
        if not hit_ids:
            return []
        n = len(hit_ids)
        docs = first_row("documents", [""] * n)
        metas = first_row("metadatas", [{}] * n)
        dists = first_row("distances", [0.0] * n)
        embs = first_row("embeddings", [None] * n)
        candidates: list[SearchResult] = []
        cand_vecs: list[list[float]] = []
        for cid, doc, meta, dist, emb in zip(
            hit_ids, docs, metas, dists, embs, strict=False
        ):
            # Chroma reports a distance; convert to a similarity-like
            # score clamped at zero.
            score = max(0.0, 1.0 - float(dist))
            chunk_meta = dict(meta or {})
            chunk_meta.pop("_empty", None)  # strip the storage sentinel
            candidates.append(
                SearchResult(
                    chunk=Chunk(
                        content=doc or "",
                        metadata=chunk_meta,
                    ),
                    score=score,
                    id=cid,
                )
            )
            cand_vecs.append(list(emb) if emb is not None else [])
        if diversity is None or diversity <= 0:
            return candidates[:k]
        chosen = mmr_select(
            vector, cand_vecs, k, diversity=diversity
        )
        return [candidates[i] for i in chosen]

    async def count(self) -> int:
        """Number of vectors currently stored in the collection."""
        n = await anyio.to_thread.run_sync(self._collection.count)
        return int(n)