"""Memory backed by Chroma (local persistent or in-memory client).
Chroma's Python API is sync; we dispatch every blocking call to a
worker thread via :func:`anyio.to_thread.run_sync` so the event loop
stays free.
Working blocks are kept in process memory (small, re-derivable);
episodes go to Chroma. The collection is created lazily on first use
and — if a ``persist_directory`` was supplied — survives process
restarts.
"""
from __future__ import annotations
from datetime import UTC, datetime
from typing import Any
import anyio
from ..core.protocols import Embedder
from ..core.types import Episode, Fact, MemoryBlock, Message, Role
from .embedder import HashEmbedder
DEFAULT_COLLECTION = "jeeves_episodes"
class ChromaMemory:
    """Memory backed by ``chromadb``.

    Construct via :meth:`local` for an on-disk persistent client or
    :meth:`ephemeral` for a process-local in-memory client.
    """

    def __init__(
        self,
        client: Any,
        *,
        embedder: Embedder | None = None,
        collection_name: str = DEFAULT_COLLECTION,
        fact_store: Any | None = None,
    ) -> None:
        self._client = client
        self._embedder: Embedder = HashEmbedder() if embedder is None else embedder
        self._collection_name = collection_name
        # Created lazily on first use; see _get_collection().
        self._collection: Any | None = None
        # Working blocks live only in process memory (small, re-derivable).
        self._blocks: dict[str, MemoryBlock] = {}
        self._lock = anyio.Lock()
        # ``facts`` is the Agent loop's hook for surfacing semantic
        # claims into the model's context. Defaults to ``None`` to
        # avoid creating a second Chroma collection by surprise; pass
        # an explicit :class:`ChromaFactStore` or use
        # :meth:`ChromaMemory.ephemeral` / :meth:`ChromaMemory.local`
        # with ``with_facts=True`` to wire one in.
        self.facts: Any | None = fact_store

    # ---- factories -------------------------------------------------------

    @classmethod
    def local(
        cls,
        persist_directory: str,
        *,
        embedder: Embedder | None = None,
        collection_name: str = DEFAULT_COLLECTION,
        with_facts: bool = False,
        facts_collection_name: str = "jeeves_facts",
    ) -> ChromaMemory:
        """Persistent on-disk client at ``persist_directory``.

        ``with_facts=True`` attaches a :class:`ChromaFactStore` rooted
        at the same client so facts persist alongside episodes in the
        same on-disk store.
        """
        return cls._build(
            _make_client(persist_directory=persist_directory),
            embedder=embedder,
            collection_name=collection_name,
            with_facts=with_facts,
            facts_collection_name=facts_collection_name,
        )

    @classmethod
    def ephemeral(
        cls,
        *,
        embedder: Embedder | None = None,
        collection_name: str = DEFAULT_COLLECTION,
        with_facts: bool = False,
        facts_collection_name: str = "jeeves_facts",
    ) -> ChromaMemory:
        """In-memory client (lost on process exit). Great for tests."""
        return cls._build(
            _make_client(persist_directory=None),
            embedder=embedder,
            collection_name=collection_name,
            with_facts=with_facts,
            facts_collection_name=facts_collection_name,
        )

    @classmethod
    def _build(
        cls,
        client: Any,
        *,
        embedder: Embedder | None,
        collection_name: str,
        with_facts: bool,
        facts_collection_name: str,
    ) -> ChromaMemory:
        """Shared factory body: construct and optionally attach a fact store."""
        instance = cls(
            client, embedder=embedder, collection_name=collection_name
        )
        if with_facts:
            # Imported lazily so ChromaMemory alone never pulls in the
            # fact-store module.
            from .chroma_facts import ChromaFactStore

            instance.facts = ChromaFactStore(
                client,
                embedder=instance._embedder,
                collection_name=facts_collection_name,
            )
        return instance

    # ---- collection lazy-init -------------------------------------------

    async def _get_collection(self) -> Any:
        """Return the episodes collection, creating it on first use."""
        if self._collection is None:
            # ``get_or_create_collection`` is sync; dispatch to a worker
            # thread so the event loop stays free.
            self._collection = await anyio.to_thread.run_sync(
                lambda: self._client.get_or_create_collection(
                    name=self._collection_name
                )
            )
        return self._collection

    # ---- working blocks --------------------------------------------------

    async def working(self) -> list[MemoryBlock]:
        """Snapshot the working blocks, ordered by ``pinned_order``."""
        async with self._lock:
            snapshot = list(self._blocks.values())
        snapshot.sort(key=lambda block: block.pinned_order)
        return snapshot

    async def update_block(self, name: str, content: str) -> None:
        """Replace block *name*; a pre-existing block keeps its pin slot."""
        async with self._lock:
            prior = self._blocks.get(name)
            order = len(self._blocks) if prior is None else prior.pinned_order
            self._blocks[name] = MemoryBlock(
                name=name, content=content, pinned_order=order
            )

    async def append_block(self, name: str, content: str) -> None:
        """Append *content* to block *name*, creating it last if absent."""
        async with self._lock:
            prior = self._blocks.get(name)
            if prior is None:
                merged = content
                order = len(self._blocks)
            else:
                merged = prior.content + content
                order = prior.pinned_order
            self._blocks[name] = MemoryBlock(
                name=name, content=merged, pinned_order=order
            )

    # ---- episodes --------------------------------------------------------

    async def remember(self, episode: Episode) -> str:
        """Upsert *episode* into Chroma and return its id.

        If the episode carries no embedding yet, one is derived from its
        input/output text first.
        """
        if episode.embedding is None:
            vector = await self._embedder.embed(_embedding_text(episode))
            episode = episode.model_copy(update={"embedding": vector})
        coll = await self._get_collection()
        document = _embedding_text(episode)
        # Store ``user_id`` as a metadata field so Chroma's ``where``
        # filter can partition recall queries natively. Chroma rejects
        # ``None`` metadata values, so we substitute the empty string
        # for the anonymous bucket and round-trip on read.
        metadata = {
            "session_id": episode.session_id,
            "user_id": episode.user_id or "",
            "input": episode.input,
            "output": episode.output,
            "occurred_at": episode.occurred_at.isoformat(),
        }
        embedding = [] if not episode.embedding else list(episode.embedding)
        await anyio.to_thread.run_sync(
            lambda: coll.upsert(
                ids=[episode.id],
                embeddings=[embedding],
                documents=[document],
                metadatas=[metadata],
            )
        )
        return episode.id

    async def recall(
        self,
        query: str,
        *,
        kind: str = "episodic",
        limit: int = 5,
        time_range: tuple[datetime, datetime] | None = None,
        user_id: str | None = None,
    ) -> list[Episode]:
        """Vector-search episodes for *query*, scoped to *user_id*.

        A blank query falls back to newest-first retrieval. When
        *time_range* is given, results outside ``[lo, hi]`` are dropped
        after the search (so fewer than *limit* episodes may come back).
        """
        coll = await self._get_collection()
        if not query.strip():
            return await self._recall_recent(coll, limit, time_range, user_id)
        embedded = list(await self._embedder.embed(query))
        # Hard namespace partition by ``user_id``, pushed into Chroma's
        # native ``where`` filter so we don't waste a round-trip on
        # other users' rows. Empty string is the anonymous bucket.
        result = await anyio.to_thread.run_sync(
            lambda: coll.query(
                query_embeddings=[embedded],
                n_results=limit,
                where={"user_id": user_id or ""},
            )
        )
        episodes = _decode_query_result(result)
        if time_range is None:
            return episodes
        lo, hi = time_range
        return [e for e in episodes if lo <= e.occurred_at <= hi]

    async def _recall_recent(
        self,
        coll: Any,
        limit: int,
        time_range: tuple[datetime, datetime] | None,
        user_id: str | None,
    ) -> list[Episode]:
        """Newest-first fallback used when the recall query is blank."""
        result = await anyio.to_thread.run_sync(
            lambda: coll.get(
                limit=None,  # we'll sort + slice ourselves
                where={"user_id": user_id or ""},
                include=["metadatas", "documents", "embeddings"],
            )
        )
        episodes = _decode_get_result(result)
        if time_range is not None:
            lo, hi = time_range
            episodes = [e for e in episodes if lo <= e.occurred_at <= hi]
        episodes.sort(key=lambda e: e.occurred_at, reverse=True)
        return episodes[:limit]

    async def recall_facts(
        self,
        query: str,
        *,
        limit: int = 5,
        valid_at: datetime | None = None,
        user_id: str | None = None,
    ) -> list[Fact]:
        """Delegate to the attached fact store; empty when none is wired."""
        store = self.facts
        if store is None:
            return []
        facts = await store.recall_text(
            query, limit=limit, valid_at=valid_at, user_id=user_id
        )
        return list(facts)

    async def session_messages(
        self,
        session_id: str,
        *,
        user_id: str | None = None,
        limit: int = 20,
    ) -> list[Message]:
        """Rebuild the newest chat turns of *session_id* as Messages.

        Each episode yields up to two messages (user input, assistant
        output), so at most ``max(1, limit // 2)`` episodes are used.
        """
        coll = await self._get_collection()
        # Native ``where`` filter — namespace partition on user_id +
        # session pin. Empty-string is the anonymous bucket on disk.
        where_filter = {
            "$and": [
                {"user_id": user_id or ""},
                {"session_id": session_id},
            ]
        }
        result = await anyio.to_thread.run_sync(
            lambda: coll.get(
                where=where_filter,
                include=["metadatas", "documents", "embeddings"],
            )
        )
        episodes = sorted(
            _decode_get_result(result), key=lambda e: e.occurred_at
        )
        keep = max(1, limit // 2)
        messages: list[Message] = []
        for ep in episodes[-keep:]:
            if ep.input:
                messages.append(Message(role=Role.USER, content=ep.input))
            if ep.output:
                messages.append(Message(role=Role.ASSISTANT, content=ep.output))
        return messages

    async def consolidate(self) -> None:
        """No-op: this backend performs no background consolidation."""
        return None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_client(*, persist_directory: str | None) -> Any:
try:
import chromadb
except ImportError as exc: # pragma: no cover — depends on user env
raise ImportError(
"chromadb is not installed. "
"Install with: pip install chromadb"
) from exc
if persist_directory is None:
return chromadb.EphemeralClient()
return chromadb.PersistentClient(path=persist_directory)
def _embedding_text(episode: Episode) -> str:
return "\n".join(p for p in (episode.input, episode.output) if p)
def _parse_occurred(meta: dict[str, Any]) -> datetime:
raw = meta.get("occurred_at")
if isinstance(raw, str):
try:
return datetime.fromisoformat(raw)
except ValueError:
pass
return datetime.now(UTC)
def _safe_list(result: dict[str, Any], key: str) -> list[Any]:
"""``result[key]`` may be None, a list, or (for embeddings) a numpy
array. ``or []`` doesn't work on numpy arrays — they raise
``ValueError: The truth value of an array... is ambiguous`` — so we
use an explicit None check."""
val = result.get(key)
return list(val) if val is not None else []
def _decode_query_result(result: dict[str, Any]) -> list[Episode]:
    """Translate a Chroma ``query()`` result into our Episodes.

    ``query()`` returns nested lists — one row per query embedding — and
    we only ever submit a single embedding, so decode row 0 of each column.
    """

    def first_row(key: str) -> list[Any]:
        # Pull column ``key`` and keep only the first (sole) query's row.
        rows = _safe_list(result, key)
        return list(rows[0]) if rows else []

    return _episodes_from_parallel(
        first_row("ids"), first_row("metadatas"), first_row("embeddings")
    )
def _decode_get_result(result: dict[str, Any]) -> list[Episode]:
    """Translate a Chroma ``get()`` result (flat parallel lists) into Episodes."""
    return _episodes_from_parallel(
        _safe_list(result, "ids"),
        _safe_list(result, "metadatas"),
        _safe_list(result, "embeddings"),
    )
def _episodes_from_parallel(
    ids: list[Any],
    metas: list[Any],
    embeds: list[Any],
) -> list[Episode]:
    """Zip parallel Chroma result columns into :class:`Episode` objects.

    ``metas`` and ``embeds`` may be shorter than ``ids`` or contain
    ``None`` entries; both are tolerated — a missing metadata row becomes
    ``{}`` and a missing embedding becomes ``None``.
    """
    episodes: list[Episode] = []
    for i, eid in enumerate(ids):
        meta = metas[i] if i < len(metas) and metas[i] is not None else {}
        # Fix: also guard against a ``None`` embedding entry — previously
        # ``list(embeds[i])`` raised TypeError on None, while the metadata
        # column already had this guard.
        raw_emb = embeds[i] if i < len(embeds) else None
        emb = list(raw_emb) if raw_emb is not None else None
        # Chroma can't store ``None`` so the anonymous bucket is the
        # empty string on the wire; round-trip back to ``None`` here.
        user_id_raw = str(meta.get("user_id", ""))
        episodes.append(
            Episode(
                id=str(eid),
                session_id=str(meta.get("session_id", "")),
                user_id=user_id_raw or None,
                occurred_at=_parse_occurred(meta),
                input=str(meta.get("input", "")),
                output=str(meta.get("output", "")),
                embedding=emb,
            )
        )
    return episodes