Source code for jeevesagent.memory.chroma

"""Memory backed by Chroma (local persistent or in-memory client).

Chroma's Python API is sync; we dispatch every blocking call to a
worker thread via :func:`anyio.to_thread.run_sync` so the event loop
stays free.

Working blocks are kept in process memory (small, re-derivable);
episodes go to Chroma. The collection is created lazily on first use
and — if a ``persist_directory`` was supplied — survives process
restarts.
"""

from __future__ import annotations

from datetime import UTC, datetime
from typing import Any

import anyio

from ..core.protocols import Embedder
from ..core.types import (
    Episode,
    Fact,
    MemoryBlock,
    MemoryExport,
    MemoryProfile,
    Message,
    Role,
)
from .embedder import HashEmbedder

DEFAULT_COLLECTION = "jeeves_episodes"


[docs] class ChromaMemory: """Memory backed by ``chromadb``. Construct via :meth:`local` for an on-disk persistent client or :meth:`ephemeral` for a process-local in-memory client. """ def __init__( self, client: Any, *, embedder: Embedder | None = None, collection_name: str = DEFAULT_COLLECTION, fact_store: Any | None = None, ) -> None: self._client = client self._embedder: Embedder = embedder if embedder is not None else HashEmbedder() self._collection_name = collection_name self._collection: Any | None = None # Working blocks partition by ``user_id``; key is ``(user_id, name)``. self._blocks: dict[tuple[str | None, str], MemoryBlock] = {} self._lock = anyio.Lock() # ``facts`` is the Agent loop's hook for surfacing semantic # claims into the model's context. Defaults to ``None`` to # avoid creating a second Chroma collection by surprise; pass # an explicit :class:`ChromaFactStore` or use # :meth:`ChromaMemory.ephemeral` / :meth:`ChromaMemory.local` # with ``with_facts=True`` to wire one in. self.facts: Any | None = fact_store # ---- factory ---------------------------------------------------------
[docs] @classmethod def local( cls, persist_directory: str, *, embedder: Embedder | None = None, collection_name: str = DEFAULT_COLLECTION, with_facts: bool = False, facts_collection_name: str = "jeeves_facts", ) -> ChromaMemory: """Persistent on-disk client at ``persist_directory``. ``with_facts=True`` attaches a :class:`ChromaFactStore` rooted at the same client so facts persist alongside episodes in the same on-disk store. """ client = _make_client(persist_directory=persist_directory) instance = cls( client, embedder=embedder, collection_name=collection_name ) if with_facts: from .chroma_facts import ChromaFactStore instance.facts = ChromaFactStore( client, embedder=instance._embedder, collection_name=facts_collection_name, ) return instance
[docs] @classmethod def ephemeral( cls, *, embedder: Embedder | None = None, collection_name: str = DEFAULT_COLLECTION, with_facts: bool = False, facts_collection_name: str = "jeeves_facts", ) -> ChromaMemory: """In-memory client (lost on process exit). Great for tests.""" client = _make_client(persist_directory=None) instance = cls( client, embedder=embedder, collection_name=collection_name ) if with_facts: from .chroma_facts import ChromaFactStore instance.facts = ChromaFactStore( client, embedder=instance._embedder, collection_name=facts_collection_name, ) return instance
# ---- collection lazy-init ------------------------------------------- async def _get_collection(self) -> Any: if self._collection is not None: return self._collection # ``get_or_create_collection`` is sync; dispatch to thread. coll = await anyio.to_thread.run_sync( lambda: self._client.get_or_create_collection( name=self._collection_name ) ) self._collection = coll return coll # ---- working blocks --------------------------------------------------
[docs] async def working( self, *, user_id: str | None = None ) -> list[MemoryBlock]: async with self._lock: scoped = [ b for (uid, _name), b in self._blocks.items() if uid == user_id ] return sorted(scoped, key=lambda b: b.pinned_order)
[docs] async def update_block( self, name: str, content: str, *, user_id: str | None = None ) -> None: key = (user_id, name) async with self._lock: existing = self._blocks.get(key) user_count = sum( 1 for (uid, _) in self._blocks if uid == user_id ) self._blocks[key] = MemoryBlock( name=name, content=content, pinned_order=existing.pinned_order if existing else user_count, )
[docs] async def append_block( self, name: str, content: str, *, user_id: str | None = None ) -> None: key = (user_id, name) async with self._lock: existing = self._blocks.get(key) if existing is None: user_count = sum( 1 for (uid, _) in self._blocks if uid == user_id ) self._blocks[key] = MemoryBlock( name=name, content=content, pinned_order=user_count, ) else: self._blocks[key] = MemoryBlock( name=name, content=existing.content + content, pinned_order=existing.pinned_order, )
# ---- episodes --------------------------------------------------------
[docs] async def remember(self, episode: Episode) -> str: if episode.embedding is None: text = _embedding_text(episode) embedding = await self._embedder.embed(text) episode = episode.model_copy(update={"embedding": embedding}) coll = await self._get_collection() document = _embedding_text(episode) # Store ``user_id`` as a metadata field so Chroma's ``where`` # filter can partition recall queries natively. Chroma rejects # ``None`` metadata values, so we substitute the empty string # for the anonymous bucket and round-trip on read. metadata = { "session_id": episode.session_id, "user_id": episode.user_id or "", "input": episode.input, "output": episode.output, "occurred_at": episode.occurred_at.isoformat(), } embedding = list(episode.embedding) if episode.embedding else [] await anyio.to_thread.run_sync( lambda: coll.upsert( ids=[episode.id], embeddings=[embedding], documents=[document], metadatas=[metadata], ) ) return episode.id
[docs] async def recall( self, query: str, *, kind: str = "episodic", limit: int = 5, time_range: tuple[datetime, datetime] | None = None, user_id: str | None = None, ) -> list[Episode]: coll = await self._get_collection() if not query.strip(): return await self._recall_recent(coll, limit, time_range, user_id) query_embedding = list(await self._embedder.embed(query)) # Hard namespace partition by ``user_id``, pushed into Chroma's # native ``where`` filter so we don't waste a round-trip on # other users' rows. Empty string is the anonymous bucket. where_filter = {"user_id": user_id or ""} result = await anyio.to_thread.run_sync( lambda: coll.query( query_embeddings=[query_embedding], n_results=limit, where=where_filter, ) ) episodes = _decode_query_result(result) if time_range is not None: lo, hi = time_range episodes = [e for e in episodes if lo <= e.occurred_at <= hi] return episodes
async def _recall_recent( self, coll: Any, limit: int, time_range: tuple[datetime, datetime] | None, user_id: str | None, ) -> list[Episode]: where_filter = {"user_id": user_id or ""} result = await anyio.to_thread.run_sync( lambda: coll.get( limit=None, # we'll sort + slice ourselves where=where_filter, include=["metadatas", "documents", "embeddings"], ) ) episodes = _decode_get_result(result) if time_range is not None: lo, hi = time_range episodes = [e for e in episodes if lo <= e.occurred_at <= hi] episodes.sort(key=lambda e: e.occurred_at, reverse=True) return episodes[:limit]
[docs] async def recall_facts( self, query: str, *, limit: int = 5, valid_at: datetime | None = None, user_id: str | None = None, ) -> list[Fact]: if self.facts is None: return [] return list( await self.facts.recall_text( query, limit=limit, valid_at=valid_at, user_id=user_id ) )
[docs] async def session_messages( self, session_id: str, *, user_id: str | None = None, limit: int = 20, ) -> list[Message]: coll = await self._get_collection() # Native ``where`` filter — namespace partition on user_id + # session pin. Empty-string is the anonymous bucket on disk. where_filter = { "$and": [ {"user_id": user_id or ""}, {"session_id": session_id}, ] } result = await anyio.to_thread.run_sync( lambda: coll.get( where=where_filter, include=["metadatas", "documents", "embeddings"], ) ) episodes = _decode_get_result(result) episodes.sort(key=lambda e: e.occurred_at) max_episodes = max(1, limit // 2) episodes = episodes[-max_episodes:] out: list[Message] = [] for ep in episodes: if ep.input: out.append(Message(role=Role.USER, content=ep.input)) if ep.output: out.append(Message(role=Role.ASSISTANT, content=ep.output)) return out
# ---- profile / forget / export (GDPR) -------------------------------
[docs] async def profile( self, *, user_id: str | None = None ) -> MemoryProfile: coll = await self._get_collection() where_filter = {"user_id": user_id or ""} result = await anyio.to_thread.run_sync( lambda: coll.get( where=where_filter, include=["metadatas", "documents", "embeddings"] ) ) episodes = _decode_get_result(result) last_seen: datetime | None = ( max(e.occurred_at for e in episodes) if episodes else None ) seen: set[str] = set() recent_sessions: list[str] = [] for e in sorted(episodes, key=lambda x: x.occurred_at, reverse=True): if e.session_id in seen: continue seen.add(e.session_id) recent_sessions.append(e.session_id) if len(recent_sessions) >= 10: break sample_facts: list[Fact] = [] fact_count = 0 if self.facts is not None: sample_facts = list( await self.facts.query(user_id=user_id, limit=10) ) all_facts = await self.facts.query(user_id=user_id, limit=100_000) fact_count = len(all_facts) return MemoryProfile( user_id=user_id, episode_count=len(episodes), fact_count=fact_count, last_seen=last_seen, recent_sessions=recent_sessions, sample_facts=sample_facts, )
[docs] async def forget( self, *, user_id: str | None = None, session_id: str | None = None, before: datetime | None = None, ) -> int: coll = await self._get_collection() where_filter: dict[str, Any] = {"user_id": user_id or ""} # Chroma's where filter doesn't natively support "<" on # numeric strings; fetch then post-filter for the time-range # case. Session filter we can push down. if session_id is not None: where_filter = { "$and": [ {"user_id": user_id or ""}, {"session_id": session_id}, ] } result = await anyio.to_thread.run_sync( lambda: coll.get(where=where_filter, include=["metadatas"]) ) ids = list(result.get("ids") or []) if before is not None: metas = list(result.get("metadatas") or []) keep_idx = [] for i, meta in enumerate(metas): if meta is None: continue ts = meta.get("occurred_at") if isinstance(ts, str): try: if datetime.fromisoformat(ts) < before: keep_idx.append(i) except ValueError: pass ids = [ids[i] for i in keep_idx] if ids: await anyio.to_thread.run_sync(lambda: coll.delete(ids=ids)) deleted = len(ids) # Facts: rely on the FactStore's ``query`` + per-id deletion. if session_id is None and self.facts is not None: facts = await self.facts.query(user_id=user_id, limit=100_000) if before is not None: facts = [f for f in facts if f.recorded_at < before] # ChromaFactStore stores facts in its own collection; # rely on the public ``query`` + private ``_collection`` # for delete (no public delete method yet). if facts and hasattr(self.facts, "_collection"): fact_coll = self.facts._collection # type: ignore[attr-defined] fact_ids = [f.id for f in facts] await anyio.to_thread.run_sync( lambda: fact_coll.delete(ids=fact_ids) ) deleted += len(fact_ids) return deleted
[docs] async def export( self, *, user_id: str | None = None ) -> MemoryExport: coll = await self._get_collection() result = await anyio.to_thread.run_sync( lambda: coll.get( where={"user_id": user_id or ""}, include=["metadatas", "documents", "embeddings"], ) ) episodes = _decode_get_result(result) facts: list[Fact] = [] if self.facts is not None: facts = list( await self.facts.query(user_id=user_id, limit=100_000) ) return MemoryExport( user_id=user_id, episodes=sorted(episodes, key=lambda e: e.occurred_at), facts=sorted(facts, key=lambda f: f.recorded_at), )
[docs] async def consolidate(self) -> None: return None
# --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_client(*, persist_directory: str | None) -> Any: try: import chromadb except ImportError as exc: # pragma: no cover — depends on user env raise ImportError( "chromadb is not installed. " "Install with: pip install chromadb" ) from exc if persist_directory is None: return chromadb.EphemeralClient() return chromadb.PersistentClient(path=persist_directory) def _embedding_text(episode: Episode) -> str: return "\n".join(p for p in (episode.input, episode.output) if p) def _parse_occurred(meta: dict[str, Any]) -> datetime: raw = meta.get("occurred_at") if isinstance(raw, str): try: return datetime.fromisoformat(raw) except ValueError: pass return datetime.now(UTC) def _safe_list(result: dict[str, Any], key: str) -> list[Any]: """``result[key]`` may be None, a list, or (for embeddings) a numpy array. ``or []`` doesn't work on numpy arrays — they raise ``ValueError: The truth value of an array... is ambiguous`` — so we use an explicit None check.""" val = result.get(key) return list(val) if val is not None else [] def _decode_query_result(result: dict[str, Any]) -> list[Episode]: """Translate a Chroma ``query()`` result into our Episodes.""" ids_lists = _safe_list(result, "ids") metas_lists = _safe_list(result, "metadatas") embeds_lists = _safe_list(result, "embeddings") ids = list(ids_lists[0]) if ids_lists else [] metas = list(metas_lists[0]) if metas_lists else [] embeds = list(embeds_lists[0]) if embeds_lists else [] return _episodes_from_parallel(ids, metas, embeds) def _decode_get_result(result: dict[str, Any]) -> list[Episode]: """Translate a Chroma ``get()`` result (flat lists) into Episodes.""" ids = _safe_list(result, "ids") metas = _safe_list(result, "metadatas") embeds = _safe_list(result, "embeddings") return _episodes_from_parallel(ids, metas, embeds) def _episodes_from_parallel( ids: list[Any], metas: list[Any], embeds: list[Any], ) -> list[Episode]: episodes: list[Episode] = [] for i, eid in enumerate(ids): meta = metas[i] if i < len(metas) and metas[i] is not None else {} emb = list(embeds[i]) if i < len(embeds) else None # Chroma can't store ``None`` so the anonymous bucket is the # empty string on the wire; round-trip back to ``None`` here. user_id_raw = str(meta.get("user_id", "")) episodes.append( Episode( id=str(eid), session_id=str(meta.get("session_id", "")), user_id=user_id_raw or None, occurred_at=_parse_occurred(meta), input=str(meta.get("input", "")), output=str(meta.get("output", "")), embedding=emb, ) ) return episodes