Source code for jeevesagent.memory.redis

"""Redis-backed :class:`Memory`.

Two flavours, picked at construction time:

* **vector mode** (default when RediSearch is available) — episodes
  are stored as Redis hashes; a RediSearch ``FT.CREATE`` index with
  ``HNSW`` provides cosine-similarity recall.
* **brute-force mode** (when RediSearch isn't available, e.g. plain
  Redis) — episodes still go to hashes, but recall scans every
  episode in-process. Fine for small corpora; switch to vector mode
  (Redis Stack) for production scale.

Both modes use the ``redis.asyncio`` client. Working blocks live in
process memory; the redundancy of putting them in Redis isn't worth
the extra round-trip for the small payloads we have.
"""

from __future__ import annotations

import math
from datetime import UTC, datetime
from typing import Any

import anyio

from ..core.errors import MemoryStoreError
from ..core.protocols import Embedder
from ..core.types import Episode, Fact, MemoryBlock, Message, Role
from ._embedding_util import pack_float32, unpack_float32
from .embedder import HashEmbedder

DEFAULT_KEY_PREFIX = "jeeves:episode:"
DEFAULT_INDEX_NAME = "jeeves_idx"
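

# Minimal usage sketch against a local Redis. It exercises only the public
# API defined below; the Episode field values are illustrative placeholders,
# so adjust them to whatever fields the Episode model actually requires.
async def _usage_example() -> None:
    memory = await RedisMemory.connect("redis://localhost:6379/0")
    try:
        await memory.remember(
            Episode(
                id="ep-1",
                session_id="sess-1",
                user_id="alice",
                occurred_at=datetime.now(UTC),
                input="Where did we leave the deploy checklist?",
                output="It's pinned in the #ops channel.",
            )
        )
        hits = await memory.recall("deploy checklist", limit=3, user_id="alice")
        for episode in hits:
            print(episode.occurred_at, episode.input)
    finally:
        await memory.aclose()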


class RedisMemory:
    """Redis-backed :class:`Memory`. Use :meth:`connect` to construct."""

    def __init__(
        self,
        client: Any,
        *,
        embedder: Embedder | None = None,
        key_prefix: str = DEFAULT_KEY_PREFIX,
        index_name: str = DEFAULT_INDEX_NAME,
        use_vector_index: bool = True,
        fact_store: Any | None = None,
    ) -> None:
        self._client = client
        self._embedder: Embedder = embedder if embedder is not None else HashEmbedder()
        self._key_prefix = key_prefix
        self._index_name = index_name
        self._use_vector_index = use_vector_index
        self._index_ready = False
        self._blocks: dict[str, MemoryBlock] = {}
        self._lock = anyio.Lock()
        # The Agent loop's fact-recall hook. ``None`` by default —
        # construct an explicit :class:`RedisFactStore` (or pass
        # ``with_facts=True`` to :meth:`connect`) to attach one.
        self.facts: Any | None = fact_store

    # ---- factory ---------------------------------------------------------

    @classmethod
    async def connect(
        cls,
        url: str = "redis://localhost:6379/0",
        *,
        embedder: Embedder | None = None,
        key_prefix: str = DEFAULT_KEY_PREFIX,
        index_name: str = DEFAULT_INDEX_NAME,
        use_vector_index: bool = True,
        with_facts: bool = False,
        fact_key_prefix: str = "jeeves:fact:",
    ) -> RedisMemory:
        """Open an async Redis connection.

        ``with_facts=True`` attaches a :class:`RedisFactStore` sharing the
        same client; facts go to ``{fact_key_prefix}*`` keys so they don't
        collide with episode keys.
        """
        try:
            from redis.asyncio import (  # type: ignore[import-not-found, import-untyped]
                from_url,
            )
        except ImportError as exc:  # pragma: no cover
            raise ImportError(
                "redis is not installed. "
                "Install with: pip install redis"
            ) from exc
        client = from_url(url, decode_responses=False)
        instance = cls(
            client,
            embedder=embedder,
            key_prefix=key_prefix,
            index_name=index_name,
            use_vector_index=use_vector_index,
        )
        if with_facts:
            from .redis_facts import RedisFactStore

            instance.facts = RedisFactStore(
                client,
                embedder=instance._embedder,
                key_prefix=fact_key_prefix,
            )
        return instance
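
    # For example, a facts-enabled instance shares one client between the
    # episode store and a RedisFactStore (illustrative values; fact recall
    # goes through :meth:`recall_facts` further down):
    #
    #     memory = await RedisMemory.connect(
    #         "redis://localhost:6379/0",
    #         with_facts=True,
    #         fact_key_prefix="jeeves:fact:",
    #     )
    #     facts = await memory.recall_facts("deployment window", limit=3)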

    async def aclose(self) -> None:
        if self._client is not None and hasattr(self._client, "aclose"):
            await self._client.aclose()

    # ---- index management -----------------------------------------------

    async def ensure_index(self) -> None:
        """Create the RediSearch HNSW index, if not already present.

        Skipped silently when ``use_vector_index=False`` or when
        RediSearch isn't available on the server.
        """
        if self._index_ready or not self._use_vector_index:
            self._index_ready = True
            return
        try:
            await self._client.execute_command(
                "FT.CREATE", self._index_name,
                "ON", "HASH",
                "PREFIX", "1", self._key_prefix,
                "SCHEMA",
                "session_id", "TAG",
                "occurred_at", "NUMERIC",
                "input", "TEXT",
                "output", "TEXT",
                "embedding", "VECTOR", "HNSW", "6",
                "TYPE", "FLOAT32",
                "DIM", str(self._embedder.dimensions),
                "DISTANCE_METRIC", "COSINE",
            )
        except Exception as exc:  # noqa: BLE001
            # ``Index already exists`` is OK; otherwise fall back to
            # brute-force recall.
            msg = str(exc).lower()
            if "already exists" not in msg:
                self._use_vector_index = False
        self._index_ready = True
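
    # For reference, the command issued above corresponds to the following
    # redis-cli invocation, assuming the default index name and key prefix
    # and an embedder reporting (say) 256 dimensions:
    #
    #   FT.CREATE jeeves_idx ON HASH PREFIX 1 jeeves:episode: SCHEMA
    #     session_id TAG occurred_at NUMERIC input TEXT output TEXT
    #     embedding VECTOR HNSW 6 TYPE FLOAT32 DIM 256 DISTANCE_METRIC COSINE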

    # ---- working blocks --------------------------------------------------

    async def working(self) -> list[MemoryBlock]:
        async with self._lock:
            return sorted(self._blocks.values(), key=lambda b: b.pinned_order)

    async def update_block(self, name: str, content: str) -> None:
        async with self._lock:
            existing = self._blocks.get(name)
            self._blocks[name] = MemoryBlock(
                name=name,
                content=content,
                pinned_order=(
                    existing.pinned_order if existing else len(self._blocks)
                ),
            )

    async def append_block(self, name: str, content: str) -> None:
        async with self._lock:
            existing = self._blocks.get(name)
            if existing is None:
                self._blocks[name] = MemoryBlock(
                    name=name,
                    content=content,
                    pinned_order=len(self._blocks),
                )
            else:
                self._blocks[name] = MemoryBlock(
                    name=name,
                    content=existing.content + content,
                    pinned_order=existing.pinned_order,
                )

    # ---- episodes --------------------------------------------------------

    async def remember(self, episode: Episode) -> str:
        if episode.embedding is None:
            text = "\n".join(p for p in (episode.input, episode.output) if p)
            embedding = await self._embedder.embed(text)
            episode = episode.model_copy(update={"embedding": embedding})
        await self.ensure_index()
        embedding_bytes = _pack_float32(episode.embedding or [])
        key = self._key_for(episode.id)
        mapping = {
            "id": episode.id.encode("utf-8"),
            "session_id": episode.session_id.encode("utf-8"),
            # Persist ``user_id`` so recall queries can filter by the
            # namespace partition. Encoded as the empty bytestring for
            # ``None`` so we can round-trip it (Redis doesn't natively
            # distinguish missing fields from empty values).
            "user_id": (episode.user_id or "").encode("utf-8"),
            "occurred_at": str(episode.occurred_at.timestamp()).encode("utf-8"),
            "input": episode.input.encode("utf-8"),
            "output": episode.output.encode("utf-8"),
            "embedding": embedding_bytes,
        }
        await self._client.hset(key, mapping=mapping)
        return episode.id
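
    # A stored episode hash therefore looks roughly like this (values shown
    # as readable strings and purely illustrative; ``embedding`` is really
    # a packed FLOAT32 byte blob):
    #
    #   HGETALL jeeves:episode:ep-1
    #     id           ep-1
    #     session_id   sess-1
    #     user_id      alice
    #     occurred_at  1717000000.0
    #     input        Where did we leave the deploy checklist?
    #     output       It's pinned in the #ops channel.
    #     embedding    <packed float32 vector>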

    async def recall(
        self,
        query: str,
        *,
        kind: str = "episodic",
        limit: int = 5,
        time_range: tuple[datetime, datetime] | None = None,
        user_id: str | None = None,
    ) -> list[Episode]:
        await self.ensure_index()
        if not query.strip():
            return await self._recall_recent(limit, time_range, user_id)
        query_embedding = await self._embedder.embed(query)
        # Over-fetch when filtering so we have enough candidates after
        # the namespace partition is applied. The vector index lacks
        # native ``user_id`` faceting today; this is a post-filter.
        fetch_limit = limit * 8 if user_id is not None else limit
        if self._use_vector_index:
            episodes = await self._recall_via_index(query_embedding, fetch_limit)
        else:
            episodes = await self._recall_brute_force(query_embedding, fetch_limit)
        # Hard namespace partition by ``user_id``.
        episodes = [e for e in episodes if e.user_id == user_id]
        if time_range is not None:
            lo, hi = time_range
            episodes = [e for e in episodes if lo <= e.occurred_at <= hi]
        return episodes[:limit]
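
    # Concretely: with ``limit=5`` and a ``user_id`` filter, :meth:`recall`
    # asks the KNN query for 40 candidates (5 * 8); the ``user_id`` and
    # ``time_range`` post-filters then trim the result before the final
    # ``[:limit]`` slice.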

    async def _recall_via_index(
        self, query_embedding: list[float], limit: int
    ) -> list[Episode]:
        params = [
            "FT.SEARCH",
            self._index_name,
            f"*=>[KNN {limit} @embedding $vec AS score]",
            "PARAMS", "2", "vec", _pack_float32(query_embedding),
            "SORTBY", "score",
            "RETURN", "6",
            "session_id", "user_id", "occurred_at", "input", "output", "score",
            "DIALECT", "2",
            "LIMIT", "0", str(limit),
        ]
        try:
            result = await self._client.execute_command(*params)
        except Exception as exc:  # noqa: BLE001
            raise MemoryStoreError(f"RediSearch KNN query failed: {exc}") from exc
        return _decode_ft_search(result)

    async def _recall_brute_force(
        self, query_embedding: list[float], limit: int
    ) -> list[Episode]:
        episodes = await self._scan_all_episodes()
        scored: list[tuple[float, Episode]] = []
        for ep in episodes:
            if ep.embedding is None:
                continue
            scored.append((_cosine(query_embedding, ep.embedding), ep))
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [ep for _, ep in scored[:limit]]

    async def _recall_recent(
        self,
        limit: int,
        time_range: tuple[datetime, datetime] | None,
        user_id: str | None,
    ) -> list[Episode]:
        episodes = await self._scan_all_episodes()
        episodes = [e for e in episodes if e.user_id == user_id]
        if time_range is not None:
            lo, hi = time_range
            episodes = [e for e in episodes if lo <= e.occurred_at <= hi]
        episodes.sort(key=lambda e: e.occurred_at, reverse=True)
        return episodes[:limit]

    async def _scan_all_episodes(self) -> list[Episode]:
        cursor: int = 0
        match = f"{self._key_prefix}*".encode()
        episodes: list[Episode] = []
        while True:
            cursor, keys = await self._client.scan(cursor=cursor, match=match)
            for key in keys:
                data = await self._client.hgetall(key)
                if not data:
                    continue
                ep = _decode_hash(data)
                if ep is not None:
                    episodes.append(ep)
            if cursor == 0:
                break
        return episodes

    async def recall_facts(
        self,
        query: str,
        *,
        limit: int = 5,
        valid_at: datetime | None = None,
        user_id: str | None = None,
    ) -> list[Fact]:
        if self.facts is None:
            return []
        return list(
            await self.facts.recall_text(
                query, limit=limit, valid_at=valid_at, user_id=user_id
            )
        )

    async def session_messages(
        self,
        session_id: str,
        *,
        user_id: str | None = None,
        limit: int = 20,
    ) -> list[Message]:
        # No native ``WHERE session_id`` index on a vanilla Redis hash;
        # scan all episode keys, post-filter, and slice. Matches the
        # InMemoryMemory + Vector backends' best-effort behaviour for
        # the M2 session-continuity path.
        episodes = await self._scan_all_episodes()
        episodes = [
            e for e in episodes
            if e.session_id == session_id and e.user_id == user_id
        ]
        episodes.sort(key=lambda e: e.occurred_at)
        # Each episode expands to at most two messages, so cap episodes
        # at ``limit // 2``.
        max_episodes = max(1, limit // 2)
        episodes = episodes[-max_episodes:]
        out: list[Message] = []
        for ep in episodes:
            if ep.input:
                out.append(Message(role=Role.USER, content=ep.input))
            if ep.output:
                out.append(Message(role=Role.ASSISTANT, content=ep.output))
        return out

    async def consolidate(self) -> None:
        return None

    # ---- key helpers -----------------------------------------------------

    def _key_for(self, episode_id: str) -> str:
        return f"{self._key_prefix}{episode_id}"


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Re-exports of the shared util so existing callers
# (``from .redis import _pack_float32``) keep working.
_pack_float32 = pack_float32
_unpack_float32 = unpack_float32


def _cosine(a: list[float], b: list[float]) -> float:
    if not a or not b:
        return 0.0
    dot = 0.0
    na = 0.0
    nb = 0.0
    for x, y in zip(a, b, strict=False):
        dot += x * y
        na += x * x
        nb += y * y
    if na <= 0.0 or nb <= 0.0:
        return 0.0
    return dot / (math.sqrt(na) * math.sqrt(nb))


def _decode_field(value: Any) -> str:
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="replace")
    return str(value)


def _decode_hash(data: dict[Any, Any]) -> Episode | None:
    """Pull an :class:`Episode` from a Redis HGETALL result."""
    # Keys may come back as bytes; normalise to str.
    norm: dict[str, Any] = {}
    for k, v in data.items():
        key = k.decode("utf-8") if isinstance(k, bytes) else str(k)
        norm[key] = v
    eid = _decode_field(norm.get("id", b""))
    if not eid:
        return None
    occurred_raw = _decode_field(norm.get("occurred_at", "0"))
    try:
        occurred_at = datetime.fromtimestamp(float(occurred_raw), tz=UTC)
    except ValueError:
        occurred_at = datetime.now(UTC)
    embedding_blob = norm.get("embedding")
    if isinstance(embedding_blob, bytes | bytearray):
        embedding: list[float] | None = _unpack_float32(bytes(embedding_blob))
    else:
        embedding = None
    user_id_raw = _decode_field(norm.get("user_id", ""))
    return Episode(
        id=eid,
        session_id=_decode_field(norm.get("session_id", "")),
        user_id=user_id_raw or None,
        occurred_at=occurred_at,
        input=_decode_field(norm.get("input", "")),
        output=_decode_field(norm.get("output", "")),
        embedding=embedding,
    )


def _decode_ft_search(result: Any) -> list[Episode]:
    """Translate a ``FT.SEARCH`` reply into Episodes.

    The reply shape is
    ``[total, id1, [k1, v1, k2, v2, ...], id2, [...], ...]``.
    """
    if not result or not isinstance(result, list):
        return []
    out: list[Episode] = []
    # First element is the total count.
    body = result[1:]
    for i in range(0, len(body), 2):
        if i + 1 >= len(body):
            break
        kvs = body[i + 1]
        if not isinstance(kvs, list):
            continue
        decoded: dict[str, Any] = {}
        for j in range(0, len(kvs), 2):
            if j + 1 >= len(kvs):
                break
            k = kvs[j]
            v = kvs[j + 1]
            key = k.decode("utf-8") if isinstance(k, bytes) else str(k)
            decoded[key] = v
        # Use the doc id as our episode id.
        doc_id = body[i]
        eid = doc_id.decode("utf-8") if isinstance(doc_id, bytes) else str(doc_id)
        # Strip the key prefix (e.g. ``jeeves:episode:``) if present; the
        # episode id itself is assumed not to contain a colon.
        if ":" in eid:
            eid = eid.rsplit(":", 1)[-1]
        occurred_raw = _decode_field(decoded.get("occurred_at", "0"))
        try:
            occurred_at = datetime.fromtimestamp(float(occurred_raw), tz=UTC)
        except ValueError:
            occurred_at = datetime.now(UTC)
        user_id_raw = _decode_field(decoded.get("user_id", ""))
        out.append(
            Episode(
                id=eid,
                session_id=_decode_field(decoded.get("session_id", "")),
                user_id=user_id_raw or None,
                occurred_at=occurred_at,
                input=_decode_field(decoded.get("input", "")),
                output=_decode_field(decoded.get("output", "")),
            )
        )
    return out


__all__ = [
    "RedisMemory",
    "DEFAULT_INDEX_NAME",
    "DEFAULT_KEY_PREFIX",
]