"""Redis-backed :class:`Memory`.
Two flavours, picked at construction time:
* **vector mode** (default when RediSearch is available) — episodes
are stored as Redis hashes; a RediSearch ``FT.CREATE`` index with
``HNSW`` provides cosine-similarity recall.
* **brute-force mode** (when RediSearch isn't available, e.g. plain
Redis) — episodes still go to hashes but recall scans every
episode in process. Fine for small corpora; switch to the vector
mode (RedisStack) for production scale.
Both modes use the ``redis.asyncio`` client. Working blocks live in
process memory; the redundancy of putting them in Redis isn't worth
the extra round-trip for the small payloads we have.
"""
from __future__ import annotations
import math
from datetime import UTC, datetime
from typing import Any
import anyio
from ..core.errors import MemoryStoreError
from ..core.protocols import Embedder
from ..core.types import Episode, Fact, MemoryBlock, Message, Role
from ._embedding_util import pack_float32, unpack_float32
from .embedder import HashEmbedder
DEFAULT_KEY_PREFIX = "jeeves:episode:"
DEFAULT_INDEX_NAME = "jeeves_idx"
[docs]
class RedisMemory:
"""Redis-backed :class:`Memory`. Use :meth:`connect` to construct."""
def __init__(
self,
client: Any,
*,
embedder: Embedder | None = None,
key_prefix: str = DEFAULT_KEY_PREFIX,
index_name: str = DEFAULT_INDEX_NAME,
use_vector_index: bool = True,
fact_store: Any | None = None,
) -> None:
self._client = client
self._embedder: Embedder = embedder if embedder is not None else HashEmbedder()
self._key_prefix = key_prefix
self._index_name = index_name
self._use_vector_index = use_vector_index
self._index_ready = False
self._blocks: dict[str, MemoryBlock] = {}
self._lock = anyio.Lock()
# The Agent loop's fact-recall hook. ``None`` by default —
# construct an explicit :class:`RedisFactStore` (or pass
# ``with_facts=True`` to :meth:`connect`) to attach one.
self.facts: Any | None = fact_store
# ---- factory ---------------------------------------------------------
[docs]
@classmethod
async def connect(
cls,
url: str = "redis://localhost:6379/0",
*,
embedder: Embedder | None = None,
key_prefix: str = DEFAULT_KEY_PREFIX,
index_name: str = DEFAULT_INDEX_NAME,
use_vector_index: bool = True,
with_facts: bool = False,
fact_key_prefix: str = "jeeves:fact:",
) -> RedisMemory:
"""Open an async Redis connection.
``with_facts=True`` attaches a :class:`RedisFactStore` sharing
the same client; facts go to ``{fact_key_prefix}*`` keys so
they don't collide with episode keys.
"""
try:
from redis.asyncio import ( # type: ignore[import-not-found, import-untyped]
from_url,
)
except ImportError as exc: # pragma: no cover
raise ImportError(
"redis is not installed. "
"Install with: pip install redis"
) from exc
client = from_url(url, decode_responses=False)
instance = cls(
client,
embedder=embedder,
key_prefix=key_prefix,
index_name=index_name,
use_vector_index=use_vector_index,
)
if with_facts:
from .redis_facts import RedisFactStore
instance.facts = RedisFactStore(
client,
embedder=instance._embedder,
key_prefix=fact_key_prefix,
)
return instance
[docs]
async def aclose(self) -> None:
if self._client is not None and hasattr(self._client, "aclose"):
await self._client.aclose()
# ---- index management -----------------------------------------------
[docs]
async def ensure_index(self) -> None:
"""Create the RediSearch HNSW index, if not already present.
Skipped silently when ``use_vector_index=False`` or when
RediSearch isn't available on the server.
"""
if self._index_ready or not self._use_vector_index:
self._index_ready = True
return
try:
await self._client.execute_command(
"FT.CREATE",
self._index_name,
"ON",
"HASH",
"PREFIX",
"1",
self._key_prefix,
"SCHEMA",
"session_id",
"TAG",
"occurred_at",
"NUMERIC",
"input",
"TEXT",
"output",
"TEXT",
"embedding",
"VECTOR",
"HNSW",
"6",
"TYPE",
"FLOAT32",
"DIM",
str(self._embedder.dimensions),
"DISTANCE_METRIC",
"COSINE",
)
except Exception as exc: # noqa: BLE001
# ``Index already exists`` is OK; otherwise fall back to
# brute-force recall.
msg = str(exc).lower()
if "already exists" not in msg:
self._use_vector_index = False
self._index_ready = True
# ---- working blocks --------------------------------------------------
[docs]
async def working(self) -> list[MemoryBlock]:
async with self._lock:
return sorted(self._blocks.values(), key=lambda b: b.pinned_order)
[docs]
async def update_block(self, name: str, content: str) -> None:
async with self._lock:
existing = self._blocks.get(name)
self._blocks[name] = MemoryBlock(
name=name,
content=content,
pinned_order=(
existing.pinned_order if existing else len(self._blocks)
),
)
[docs]
async def append_block(self, name: str, content: str) -> None:
async with self._lock:
existing = self._blocks.get(name)
if existing is None:
self._blocks[name] = MemoryBlock(
name=name,
content=content,
pinned_order=len(self._blocks),
)
else:
self._blocks[name] = MemoryBlock(
name=name,
content=existing.content + content,
pinned_order=existing.pinned_order,
)
# ---- episodes --------------------------------------------------------
[docs]
async def remember(self, episode: Episode) -> str:
if episode.embedding is None:
text = "\n".join(p for p in (episode.input, episode.output) if p)
embedding = await self._embedder.embed(text)
episode = episode.model_copy(update={"embedding": embedding})
await self.ensure_index()
embedding_bytes = _pack_float32(episode.embedding or [])
key = self._key_for(episode.id)
mapping = {
"id": episode.id.encode("utf-8"),
"session_id": episode.session_id.encode("utf-8"),
# Persist ``user_id`` so recall queries can filter
# by namespace partition. Encoded as the empty bytestring
# for ``None`` so we can round-trip it (Redis doesn't
# natively distinguish missing fields from empty values).
"user_id": (episode.user_id or "").encode("utf-8"),
"occurred_at": str(episode.occurred_at.timestamp()).encode("utf-8"),
"input": episode.input.encode("utf-8"),
"output": episode.output.encode("utf-8"),
"embedding": embedding_bytes,
}
await self._client.hset(key, mapping=mapping)
return episode.id
[docs]
async def recall(
self,
query: str,
*,
kind: str = "episodic",
limit: int = 5,
time_range: tuple[datetime, datetime] | None = None,
user_id: str | None = None,
) -> list[Episode]:
await self.ensure_index()
if not query.strip():
return await self._recall_recent(limit, time_range, user_id)
query_embedding = await self._embedder.embed(query)
# Over-fetch when filtering so we have enough candidates after
# the namespace partition is applied. The vector index lacks
# native ``user_id`` faceting today; this is a post-filter.
fetch_limit = limit * 8 if user_id is not None else limit
if self._use_vector_index:
episodes = await self._recall_via_index(query_embedding, fetch_limit)
else:
episodes = await self._recall_brute_force(query_embedding, fetch_limit)
# Hard namespace partition by ``user_id``.
episodes = [e for e in episodes if e.user_id == user_id]
if time_range is not None:
lo, hi = time_range
episodes = [e for e in episodes if lo <= e.occurred_at <= hi]
return episodes[:limit]
async def _recall_via_index(
self, query_embedding: list[float], limit: int
) -> list[Episode]:
params = [
"FT.SEARCH",
self._index_name,
f"*=>[KNN {limit} @embedding $vec AS score]",
"PARAMS",
"2",
"vec",
_pack_float32(query_embedding),
"SORTBY",
"score",
"RETURN",
"6",
"session_id",
"user_id",
"occurred_at",
"input",
"output",
"score",
"DIALECT",
"2",
"LIMIT",
"0",
str(limit),
]
try:
result = await self._client.execute_command(*params)
except Exception as exc: # noqa: BLE001
raise MemoryStoreError(f"RediSearch KNN query failed: {exc}") from exc
return _decode_ft_search(result)
async def _recall_brute_force(
self, query_embedding: list[float], limit: int
) -> list[Episode]:
episodes = await self._scan_all_episodes()
scored: list[tuple[float, Episode]] = []
for ep in episodes:
if ep.embedding is None:
continue
scored.append((_cosine(query_embedding, ep.embedding), ep))
scored.sort(key=lambda pair: pair[0], reverse=True)
return [ep for _, ep in scored[:limit]]
async def _recall_recent(
self,
limit: int,
time_range: tuple[datetime, datetime] | None,
user_id: str | None,
) -> list[Episode]:
episodes = await self._scan_all_episodes()
episodes = [e for e in episodes if e.user_id == user_id]
if time_range is not None:
lo, hi = time_range
episodes = [e for e in episodes if lo <= e.occurred_at <= hi]
episodes.sort(key=lambda e: e.occurred_at, reverse=True)
return episodes[:limit]
async def _scan_all_episodes(self) -> list[Episode]:
cursor: int = 0
match = f"{self._key_prefix}*".encode()
episodes: list[Episode] = []
while True:
cursor, keys = await self._client.scan(cursor=cursor, match=match)
for key in keys:
data = await self._client.hgetall(key)
if not data:
continue
ep = _decode_hash(data)
if ep is not None:
episodes.append(ep)
if cursor == 0:
break
return episodes
[docs]
async def recall_facts(
self,
query: str,
*,
limit: int = 5,
valid_at: datetime | None = None,
user_id: str | None = None,
) -> list[Fact]:
if self.facts is None:
return []
return list(
await self.facts.recall_text(
query, limit=limit, valid_at=valid_at, user_id=user_id
)
)
[docs]
async def session_messages(
self,
session_id: str,
*,
user_id: str | None = None,
limit: int = 20,
) -> list[Message]:
# No native ``WHERE session_id`` index in vanilla Redis Hash;
# scan all episode keys, post-filter, and slice. Matches the
# InMemoryMemory + Vector backends' best-effort behaviour for
# the M2 session-continuity path.
episodes = await self._scan_all_episodes()
episodes = [
e
for e in episodes
if e.session_id == session_id and e.user_id == user_id
]
episodes.sort(key=lambda e: e.occurred_at)
max_episodes = max(1, limit // 2)
episodes = episodes[-max_episodes:]
out: list[Message] = []
for ep in episodes:
if ep.input:
out.append(Message(role=Role.USER, content=ep.input))
if ep.output:
out.append(Message(role=Role.ASSISTANT, content=ep.output))
return out
[docs]
async def consolidate(self) -> None:
return None
# ---- key helpers -----------------------------------------------------
def _key_for(self, episode_id: str) -> str:
return f"{self._key_prefix}{episode_id}"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Re-exports of the shared util so existing callers
# (``from .redis import _pack_float32``) keep working.
# Both pack/unpack a ``list[float]`` to/from a raw float32 byte blob,
# the format the RediSearch VECTOR field expects.
_pack_float32 = pack_float32
_unpack_float32 = unpack_float32
def _cosine(a: list[float], b: list[float]) -> float:
if not a or not b:
return 0.0
dot = 0.0
na = 0.0
nb = 0.0
for x, y in zip(a, b, strict=False):
dot += x * y
na += x * x
nb += y * y
if na <= 0.0 or nb <= 0.0:
return 0.0
return dot / (math.sqrt(na) * math.sqrt(nb))
def _decode_field(value: Any) -> str:
if isinstance(value, bytes):
return value.decode("utf-8", errors="replace")
return str(value)
def _decode_hash(data: dict[Any, Any]) -> Episode | None:
    """Pull an :class:`Episode` from a Redis HGETALL result.

    Returns ``None`` when the hash carries no ``id`` field (e.g. a
    foreign key matched the SCAN pattern).
    """
    # Keys may come back as bytes; normalise to str.
    norm: dict[str, Any] = {}
    for k, v in data.items():
        key = k.decode("utf-8") if isinstance(k, bytes) else str(k)
        norm[key] = v
    eid = _decode_field(norm.get("id", b""))
    if not eid:
        return None
    occurred_raw = _decode_field(norm.get("occurred_at", "0"))
    try:
        occurred_at = datetime.fromtimestamp(float(occurred_raw), tz=UTC)
    except (ValueError, OverflowError, OSError):
        # ``float()`` raises ValueError on garbage; ``fromtimestamp``
        # can additionally raise OverflowError/OSError for timestamps
        # outside the platform-supported range. Treat all corrupt
        # values the same: fall back to "now" rather than crash recall.
        occurred_at = datetime.now(UTC)
    embedding_blob = norm.get("embedding")
    if isinstance(embedding_blob, bytes | bytearray):
        embedding: list[float] | None = _unpack_float32(bytes(embedding_blob))
    else:
        embedding = None
    # Empty string round-trips back to ``None`` (see ``remember``).
    user_id_raw = _decode_field(norm.get("user_id", ""))
    return Episode(
        id=eid,
        session_id=_decode_field(norm.get("session_id", "")),
        user_id=user_id_raw or None,
        occurred_at=occurred_at,
        input=_decode_field(norm.get("input", "")),
        output=_decode_field(norm.get("output", "")),
        embedding=embedding,
    )
def _decode_ft_search(result: Any, key_prefix: str | None = None) -> list[Episode]:
    """Translate a ``FT.SEARCH`` reply into Episodes.

    The reply shape is ``[total, id1, [k1, v1, k2, v2, ...], id2, [...], ...]``.

    ``key_prefix``, when given, is stripped from the document id to
    recover the episode id; otherwise everything up to the last colon
    is stripped (doc ids look like ``jeeves:episode:<id>``).
    """
    if not result or not isinstance(result, list):
        return []
    out: list[Episode] = []
    # First element is the total count.
    body = result[1:]
    for i in range(0, len(body), 2):
        if i + 1 >= len(body):
            break
        kvs = body[i + 1]
        if not isinstance(kvs, list):
            continue
        decoded: dict[str, Any] = {}
        for j in range(0, len(kvs), 2):
            if j + 1 >= len(kvs):
                break
            k = kvs[j]
            v = kvs[j + 1]
            key = k.decode("utf-8") if isinstance(k, bytes) else str(k)
            decoded[key] = v
        # Use the doc id as our episode id.
        doc_id = body[i]
        eid = doc_id.decode("utf-8") if isinstance(doc_id, bytes) else str(doc_id)
        # Strip the key prefix to recover the bare episode id. The old
        # ``split(":", 1)`` removed only the FIRST segment, so the
        # default prefix ``jeeves:episode:`` left ``episode:<id>``
        # behind. Strip the known prefix when given, else everything up
        # to the LAST colon (assumes episode ids themselves contain no
        # colon — TODO confirm against the id generator).
        if key_prefix and eid.startswith(key_prefix):
            eid = eid[len(key_prefix):]
        elif ":" in eid:
            eid = eid.rsplit(":", 1)[-1]
        occurred_raw = _decode_field(decoded.get("occurred_at", "0"))
        try:
            occurred_at = datetime.fromtimestamp(float(occurred_raw), tz=UTC)
        except ValueError:
            occurred_at = datetime.now(UTC)
        user_id_raw = _decode_field(decoded.get("user_id", ""))
        out.append(
            Episode(
                id=eid,
                session_id=_decode_field(decoded.get("session_id", "")),
                user_id=user_id_raw or None,
                occurred_at=occurred_at,
                input=_decode_field(decoded.get("input", "")),
                output=_decode_field(decoded.get("output", "")),
            )
        )
    return out
# Public surface: the memory class plus the tunable key/index defaults.
__all__ = [
    "RedisMemory",
    "DEFAULT_INDEX_NAME",
    "DEFAULT_KEY_PREFIX",
]