Source code for jeevesagent.memory.redis_facts

"""Redis-backed bi-temporal fact store.

Each fact lives in a Redis hash at ``{prefix}{id}`` (default prefix
``jeeves:fact:``). Fields:

* ``id`` / ``subject`` / ``predicate`` / ``object`` — strings
* ``user_id`` — string; empty for the anonymous (``None``) bucket, used
  to partition facts per user
* ``confidence`` — string-encoded float
* ``valid_from_ts`` / ``recorded_at_ts`` — string-encoded floats
* ``valid_until_ts`` — string-encoded float (``"0"`` when still valid)
* ``currently_valid`` — ``b"1"`` / ``b"0"`` flag (mirror of
  ``valid_until_ts == 0``)
* ``sources`` — JSON-encoded list of episode ids
* ``embedding`` — float32 BLOB (only present when an embedder is
  configured)

Supersession is a brute-force scan: ``SCAN`` for all fact keys, find
those with matching subject + predicate that are currently valid and
have a different object, and ``HSET`` each to flip
``currently_valid=False`` + stamp ``valid_until_ts``. RediSearch with
HNSW + numeric/tag indexes is a follow-up.
"""

from __future__ import annotations

import json
import math
from collections.abc import AsyncIterator
from datetime import UTC, datetime
from typing import Any

import anyio

from ..core.protocols import Embedder
from ..core.types import Fact
from ._embedding_util import pack_float32, unpack_float32
from .embedder import HashEmbedder

DEFAULT_KEY_PREFIX = "jeeves:fact:"


class RedisFactStore:
    """Bi-temporal fact store over plain Redis hashes."""

    def __init__(
        self,
        client: Any,
        *,
        embedder: Embedder | None = None,
        key_prefix: str = DEFAULT_KEY_PREFIX,
    ) -> None:
        """Wrap an existing async Redis *client*.

        ``embedder`` defaults to the deterministic :class:`HashEmbedder`
        when not supplied; ``key_prefix`` namespaces every fact hash.
        """
        self._client = client
        # Fall back to the hash embedder so embeddings are always written.
        if embedder is None:
            embedder = HashEmbedder()
        self._embedder: Embedder = embedder
        self._key_prefix = key_prefix
        # Serializes supersede-then-write so appends can't interleave.
        self._lock = anyio.Lock()
[docs] @classmethod async def connect( cls, url: str = "redis://localhost:6379/0", *, embedder: Embedder | None = None, key_prefix: str = DEFAULT_KEY_PREFIX, ) -> RedisFactStore: try: from redis.asyncio import ( # type: ignore[import-not-found, import-untyped] from_url, ) except ImportError as exc: # pragma: no cover raise ImportError( "redis is not installed. " "Install with: pip install redis" ) from exc client = from_url(url, decode_responses=False) return cls(client, embedder=embedder, key_prefix=key_prefix)
@property def embedder(self) -> Embedder: return self._embedder
[docs] async def aclose(self) -> None: if self._client is not None and hasattr(self._client, "aclose"): await self._client.aclose()
# ---- mutation --------------------------------------------------------
[docs] async def append(self, fact: Fact) -> str: triple = _triple_text(fact) embedding = await self._embedder.embed(triple) async with self._lock: await self._supersede(fact) await self._write_fact(fact, embedding) return fact.id
async def _supersede(self, fact: Fact) -> None: ts = str(fact.valid_from.timestamp()).encode("utf-8") async for key, data in self._scan_facts(): # Namespace-scoped supersession: alice's facts never # invalidate bob's. Anonymous bucket (None / empty) is # its own namespace. other_user = _decode_field(data.get(b"user_id", b"")) or None if other_user != fact.user_id: continue if _decode_field(data.get(b"subject", b"")) != fact.subject: continue if _decode_field(data.get(b"predicate", b"")) != fact.predicate: continue if data.get(b"currently_valid", b"0") != b"1": continue if _decode_field(data.get(b"object", b"")) == fact.object: continue await self._client.hset( key, mapping={ b"currently_valid": b"0", b"valid_until_ts": ts, }, ) async def _write_fact( self, fact: Fact, embedding: list[float] ) -> None: key = self._key_for(fact.id) valid_until_ts = ( str(fact.valid_until.timestamp()).encode("utf-8") if fact.valid_until is not None else b"0" ) mapping = { b"id": fact.id.encode("utf-8"), # Persist ``user_id`` so recall queries can filter by # namespace partition. Empty bytes for the anonymous # bucket; round-trip back to ``None`` on read. b"user_id": (fact.user_id or "").encode("utf-8"), b"subject": fact.subject.encode("utf-8"), b"predicate": fact.predicate.encode("utf-8"), b"object": fact.object.encode("utf-8"), b"confidence": str(fact.confidence).encode("utf-8"), b"valid_from_ts": str(fact.valid_from.timestamp()).encode("utf-8"), b"valid_until_ts": valid_until_ts, b"currently_valid": b"1" if fact.valid_until is None else b"0", b"recorded_at_ts": str(fact.recorded_at.timestamp()).encode("utf-8"), b"sources": json.dumps(list(fact.sources)).encode("utf-8"), b"embedding": pack_float32(embedding), } await self._client.hset(key, mapping=mapping) # ---- queries ---------------------------------------------------------
[docs] async def query( self, *, subject: str | None = None, predicate: str | None = None, object_: str | None = None, valid_at: datetime | None = None, limit: int = 10, user_id: str | None = None, ) -> list[Fact]: valid_at_ts = valid_at.timestamp() if valid_at is not None else None results: list[Fact] = [] async for _key, data in self._scan_facts(): fact = _hash_to_fact(data) if fact is None: continue # Hard namespace partition by ``user_id``. if fact.user_id != user_id: continue if subject is not None and fact.subject != subject: continue if predicate is not None and fact.predicate != predicate: continue if object_ is not None and fact.object != object_: continue if valid_at_ts is not None and not _is_valid_at(fact, valid_at_ts): continue results.append(fact) results.sort(key=lambda f: f.recorded_at, reverse=True) return results[:limit]
[docs] async def recall_text( self, query: str, *, limit: int = 5, valid_at: datetime | None = None, user_id: str | None = None, ) -> list[Fact]: query_embedding = await self._embedder.embed(query) valid_at_ts = valid_at.timestamp() if valid_at is not None else None scored: list[tuple[float, Fact]] = [] async for _key, data in self._scan_facts(): fact = _hash_to_fact(data) if fact is None: continue # Hard namespace partition by ``user_id``. if fact.user_id != user_id: continue if valid_at_ts is not None and not _is_valid_at(fact, valid_at_ts): continue blob = data.get(b"embedding") if not isinstance(blob, bytes | bytearray) or not blob: continue stored = unpack_float32(bytes(blob)) scored.append((_cosine(query_embedding, stored), fact)) scored.sort(key=lambda pair: pair[0], reverse=True) return [f for _, f in scored[:limit]]
[docs] async def all_facts(self) -> list[Fact]: out: list[Fact] = [] async for _key, data in self._scan_facts(): fact = _hash_to_fact(data) if fact is not None: out.append(fact) out.sort(key=lambda f: f.recorded_at, reverse=True) return out
# ---- scanning helpers ------------------------------------------------ def _key_for(self, fact_id: str) -> bytes: return f"{self._key_prefix}{fact_id}".encode() async def _scan_facts( self, ) -> AsyncIterator[tuple[bytes, dict[bytes, Any]]]: cursor: int = 0 match = f"{self._key_prefix}*".encode() while True: cursor, keys = await self._client.scan( cursor=cursor, match=match ) for key in keys: data = await self._client.hgetall(key) if data: yield key, _normalize_keys(data) if cursor == 0: break
# --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _triple_text(fact: Fact) -> str: return f"{fact.subject} {fact.predicate} {fact.object}" def _normalize_keys(data: dict[Any, Any]) -> dict[bytes, Any]: """Coerce all keys to ``bytes`` so we can index uniformly.""" out: dict[bytes, Any] = {} for k, v in data.items(): if isinstance(k, bytes): out[k] = v else: out[str(k).encode("utf-8")] = v return out def _decode_field(value: Any) -> str: if isinstance(value, bytes): return value.decode("utf-8", errors="replace") return str(value) def _hash_to_fact(data: dict[bytes, Any]) -> Fact | None: eid = _decode_field(data.get(b"id", b"")) if not eid: return None sources_raw = _decode_field(data.get(b"sources", "[]")) try: sources = list(json.loads(sources_raw)) except json.JSONDecodeError: sources = [] try: valid_from_ts = float(_decode_field(data.get(b"valid_from_ts", "0"))) recorded_at_ts = float( _decode_field(data.get(b"recorded_at_ts", "0")) ) valid_until_ts = float( _decode_field(data.get(b"valid_until_ts", "0")) ) except ValueError: return None valid_until = ( datetime.fromtimestamp(valid_until_ts, tz=UTC) if valid_until_ts > 0 and data.get(b"currently_valid") != b"1" else None ) try: confidence = float(_decode_field(data.get(b"confidence", "1.0"))) except ValueError: confidence = 1.0 user_id_raw = _decode_field(data.get(b"user_id", b"")) return Fact( id=eid, user_id=user_id_raw or None, subject=_decode_field(data.get(b"subject", b"")), predicate=_decode_field(data.get(b"predicate", b"")), object=_decode_field(data.get(b"object", b"")), confidence=confidence, valid_from=datetime.fromtimestamp(valid_from_ts, tz=UTC), valid_until=valid_until, recorded_at=datetime.fromtimestamp(recorded_at_ts, tz=UTC), sources=sources, ) def _is_valid_at(fact: Fact, when_ts: float) -> bool: if when_ts < fact.valid_from.timestamp(): return False if fact.valid_until is None: 
return True return when_ts < fact.valid_until.timestamp() def _cosine(a: list[float], b: list[float]) -> float: if not a or not b: return 0.0 dot = 0.0 na = 0.0 nb = 0.0 for x, y in zip(a, b, strict=False): dot += x * y na += x * x nb += y * y if na <= 0.0 or nb <= 0.0: return 0.0 return dot / (math.sqrt(na) * math.sqrt(nb))