# Source code for jeevesagent.memory.chroma_facts

"""Chroma-backed bi-temporal fact store.

Each fact lives in a Chroma collection as a (id, embedding, document,
metadata) tuple. The metadata carries the bi-temporal fields:

* ``subject`` / ``predicate`` / ``object`` — strings
* ``confidence`` — float
* ``valid_from_ts`` / ``recorded_at_ts`` — unix-epoch floats
* ``valid_until_ts`` — unix-epoch float; ``0.0`` when still valid
* ``currently_valid`` — bool, mirrors ``valid_until_ts == 0`` so we
  can use it directly in Chroma's ``where`` filters
* ``sources`` — JSON-encoded list of episode ids

Supersession is two round-trips: a ``coll.get`` to find the prior
currently-valid facts with matching subject + predicate + different
object, followed by a ``coll.update`` that flips their
``currently_valid`` to false and stamps ``valid_until_ts`` to the new
fact's ``valid_from``.
"""

from __future__ import annotations

import json
from datetime import UTC, datetime
from typing import Any

import anyio

from ..core.protocols import Embedder
from ..core.types import Fact
from .embedder import HashEmbedder

DEFAULT_FACTS_COLLECTION = "jeeves_facts"


class ChromaFactStore:
    """Bi-temporal fact store backed by a Chroma collection."""

    def __init__(
        self,
        client: Any,
        *,
        embedder: Embedder | None = None,
        collection_name: str = DEFAULT_FACTS_COLLECTION,
    ) -> None:
        """Wrap *client* with lazy collection init and a write lock.

        ``embedder`` defaults to the deterministic :class:`HashEmbedder`
        when the caller does not supply one.
        """
        self._client = client
        self._embedder: Embedder = (
            HashEmbedder() if embedder is None else embedder
        )
        self._collection_name = collection_name
        # Resolved lazily on first use; ``None`` until then.
        self._collection: Any | None = None
        # Serializes the supersede-then-upsert sequence in ``append``.
        self._lock = anyio.Lock()

    # ---- factories -------------------------------------------------------
[docs] @classmethod def local( cls, persist_directory: str, *, embedder: Embedder | None = None, collection_name: str = DEFAULT_FACTS_COLLECTION, ) -> ChromaFactStore: client = _make_client(persist_directory=persist_directory) return cls( client, embedder=embedder, collection_name=collection_name, )
[docs] @classmethod def ephemeral( cls, *, embedder: Embedder | None = None, collection_name: str = DEFAULT_FACTS_COLLECTION, ) -> ChromaFactStore: client = _make_client(persist_directory=None) return cls( client, embedder=embedder, collection_name=collection_name, )
@property def embedder(self) -> Embedder: return self._embedder # ---- collection lazy-init ------------------------------------------- async def _get_collection(self) -> Any: if self._collection is not None: return self._collection coll = await anyio.to_thread.run_sync( lambda: self._client.get_or_create_collection( name=self._collection_name ) ) self._collection = coll return coll # ---- mutation --------------------------------------------------------
[docs] async def append(self, fact: Fact) -> str: triple = _triple_text(fact) embedding = await self._embedder.embed(triple) coll = await self._get_collection() async with self._lock: # Namespace-scoped supersession: only invalidate prior # facts in the same ``user_id`` partition. existing = await anyio.to_thread.run_sync( lambda: coll.get( where={ "$and": [ {"user_id": fact.user_id or ""}, {"subject": fact.subject}, {"predicate": fact.predicate}, {"currently_valid": True}, ] }, include=["metadatas"], ) ) ids_to_close: list[str] = [] metas_to_close: list[dict[str, Any]] = [] for eid, meta in zip( existing.get("ids") or [], existing.get("metadatas") or [], strict=False, ): meta = dict(meta or {}) if meta.get("object") == fact.object: continue # same triple — don't supersede meta["currently_valid"] = False meta["valid_until_ts"] = fact.valid_from.timestamp() ids_to_close.append(eid) metas_to_close.append(meta) if ids_to_close: await anyio.to_thread.run_sync( lambda: coll.update( ids=ids_to_close, metadatas=metas_to_close, ) ) metadata = _fact_to_metadata(fact) await anyio.to_thread.run_sync( lambda: coll.upsert( ids=[fact.id], embeddings=[embedding], documents=[triple], metadatas=[metadata], ) ) return fact.id
# ---- queries ---------------------------------------------------------
[docs] async def query( self, *, subject: str | None = None, predicate: str | None = None, object_: str | None = None, valid_at: datetime | None = None, limit: int = 10, user_id: str | None = None, ) -> list[Fact]: coll = await self._get_collection() where = _build_where(subject, predicate, object_, valid_at, user_id) # Chroma's ``get`` accepts ``limit`` only in newer releases; # fall back to slicing in Python if it raises. def _do_get() -> Any: try: return coll.get( where=where, limit=limit, include=["metadatas"], ) except TypeError: return coll.get(where=where, include=["metadatas"]) result = await anyio.to_thread.run_sync(_do_get) facts = _decode_get(result) # Sort by recorded_at desc; tie-break by valid_from desc. facts.sort( key=lambda f: (f.recorded_at, f.valid_from), reverse=True, ) return facts[:limit]
[docs] async def recall_text( self, query: str, *, limit: int = 5, valid_at: datetime | None = None, user_id: str | None = None, ) -> list[Fact]: coll = await self._get_collection() query_embedding = await self._embedder.embed(query) where = _build_where(None, None, None, valid_at, user_id) result = await anyio.to_thread.run_sync( lambda: coll.query( query_embeddings=[query_embedding], n_results=limit, where=where, include=["metadatas"], ) ) return _decode_query(result)
[docs] async def all_facts(self) -> list[Fact]: coll = await self._get_collection() result = await anyio.to_thread.run_sync( lambda: coll.get(include=["metadatas"]) ) return _decode_get(result)
[docs] async def aclose(self) -> None: return None
# --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_client(*, persist_directory: str | None) -> Any: try: import chromadb except ImportError as exc: # pragma: no cover raise ImportError( "chromadb is not installed. " "Install with: pip install chromadb" ) from exc if persist_directory is None: return chromadb.EphemeralClient() return chromadb.PersistentClient(path=persist_directory) def _triple_text(fact: Fact) -> str: return f"{fact.subject} {fact.predicate} {fact.object}" def _fact_to_metadata(fact: Fact) -> dict[str, Any]: return { # Empty string is the anonymous bucket — Chroma rejects None # metadata values, so we substitute and round-trip on read. "user_id": fact.user_id or "", "subject": fact.subject, "predicate": fact.predicate, "object": fact.object, "confidence": fact.confidence, "valid_from_ts": fact.valid_from.timestamp(), "valid_until_ts": ( fact.valid_until.timestamp() if fact.valid_until is not None else 0.0 ), "currently_valid": fact.valid_until is None, "recorded_at_ts": fact.recorded_at.timestamp(), "sources": json.dumps(list(fact.sources)), } def _metadata_to_fact(eid: str, meta: dict[str, Any]) -> Fact: raw_sources = meta.get("sources", "[]") sources: list[str] = [] if isinstance(raw_sources, str): try: sources = list(json.loads(raw_sources)) except json.JSONDecodeError: sources = [] valid_until: datetime | None = None until_ts = meta.get("valid_until_ts", 0.0) or 0.0 if not meta.get("currently_valid", True) and until_ts > 0: valid_until = datetime.fromtimestamp(float(until_ts), tz=UTC) user_id_raw = str(meta.get("user_id", "")) return Fact( id=eid, user_id=user_id_raw or None, subject=str(meta.get("subject", "")), predicate=str(meta.get("predicate", "")), object=str(meta.get("object", "")), confidence=float(meta.get("confidence", 1.0)), valid_from=datetime.fromtimestamp( float(meta.get("valid_from_ts", 0.0)), tz=UTC ), 
valid_until=valid_until, recorded_at=datetime.fromtimestamp( float(meta.get("recorded_at_ts", 0.0)), tz=UTC ), sources=sources, ) def _decode_get(result: dict[str, Any]) -> list[Fact]: ids = result.get("ids") or [] metas = result.get("metadatas") or [] facts: list[Fact] = [] for i, eid in enumerate(ids): meta = metas[i] if i < len(metas) and metas[i] is not None else {} facts.append(_metadata_to_fact(str(eid), dict(meta))) return facts def _decode_query(result: dict[str, Any]) -> list[Fact]: """``coll.query`` returns nested lists (one per query). We always pass a single query, so we look at the first row.""" ids_lists = result.get("ids") or [[]] metas_lists = result.get("metadatas") or [[]] ids = ids_lists[0] if ids_lists else [] metas = metas_lists[0] if metas_lists else [] facts: list[Fact] = [] for i, eid in enumerate(ids): meta = metas[i] if i < len(metas) and metas[i] is not None else {} facts.append(_metadata_to_fact(str(eid), dict(meta))) return facts def _build_where( subject: str | None, predicate: str | None, object_: str | None, valid_at: datetime | None, user_id: str | None, ) -> dict[str, Any] | None: """Compose Chroma ``where`` from optional filters. Always pins ``user_id`` (empty string for the anonymous bucket) so recall is namespace-partitioned by default. Multiple filters fold into a single ``$and``.""" clauses: list[dict[str, Any]] = [{"user_id": user_id or ""}] if subject is not None: clauses.append({"subject": subject}) if predicate is not None: clauses.append({"predicate": predicate}) if object_ is not None: clauses.append({"object": object_}) if valid_at is not None: ts = valid_at.timestamp() clauses.append({"valid_from_ts": {"$lte": ts}}) clauses.append( { "$or": [ {"currently_valid": True}, {"valid_until_ts": {"$gt": ts}}, ] } ) if not clauses: return None if len(clauses) == 1: return clauses[0] return {"$and": clauses}