Source code for jeevesagent.memory.sqlite_facts

"""SQLite-backed bi-temporal fact store.

Same shape as :class:`InMemoryFactStore` (supersession on append,
``valid_at`` queries, optional embedder) but durable across process
restarts. Sync sqlite3 calls dispatched through
:func:`anyio.to_thread.run_sync`.

Schema:

* ``facts(id, user_id, subject, predicate, object, confidence, valid_from,
  valid_until, recorded_at, sources, embedding)`` — timestamps stored
  as unix-epoch floats; ``sources`` as a JSON-encoded array;
  ``embedding`` as a float32 BLOB or NULL.
* Indexes on ``subject`` and ``(subject, predicate)`` for the common
  filter shapes.
"""

from __future__ import annotations

import json
import math
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import UTC, datetime
from pathlib import Path
from typing import Any

import anyio

from ..core.protocols import Embedder
from ..core.types import Fact
from ._embedding_util import pack_float32, unpack_float32

_FACTS_DDL = """
CREATE TABLE IF NOT EXISTS facts (
    id          TEXT PRIMARY KEY,
    user_id     TEXT,
    subject     TEXT NOT NULL,
    predicate   TEXT NOT NULL,
    object      TEXT NOT NULL,
    confidence  REAL NOT NULL DEFAULT 1.0,
    valid_from  REAL NOT NULL,
    valid_until REAL,
    recorded_at REAL NOT NULL,
    sources     TEXT NOT NULL DEFAULT '[]',
    embedding   BLOB
)
"""

# Idempotent ALTER for upgrades from a pre-``user_id`` schema.
_FACTS_ADD_USER_ID = "ALTER TABLE facts ADD COLUMN user_id TEXT"

_FACTS_SUBJECT_INDEX = (
    "CREATE INDEX IF NOT EXISTS facts_subject_idx ON facts (subject)"
)
_FACTS_USER_SUBJECT_PRED_INDEX = (
    "CREATE INDEX IF NOT EXISTS facts_user_subject_predicate_idx "
    "ON facts (user_id, subject, predicate)"
)


class SqliteFactStore:
    """Durable bi-temporal fact store rooted at a sqlite file."""

    def __init__(
        self,
        path: str | Path,
        *,
        embedder: Embedder | None = None,
    ) -> None:
        self._path = Path(path)
        # Make sure the parent directory exists before sqlite creates the file.
        self._path.parent.mkdir(parents=True, exist_ok=True)
        self._embedder = embedder
        self._init_schema()

    @property
    def path(self) -> Path:
        """Filesystem location of the sqlite database."""
        return self._path

    @property
    def embedder(self) -> Embedder | None:
        """Optional embedder used to vectorize fact triples."""
        return self._embedder

    # ---- connection management -------------------------------------------

    @contextmanager
    def _connect(self) -> Iterator[sqlite3.Connection]:
        """Yield a fresh connection, always closed on exit.

        One connection per call: sqlite3 connections aren't safe to share
        across the worker threads we hop into.
        """
        db = sqlite3.connect(self._path)
        try:
            yield db
        finally:
            db.close()

    def _init_schema(self) -> None:
        """Create table/indexes; upgrade pre-``user_id`` files in place."""
        with self._connect() as db:
            db.execute(_FACTS_DDL)
            try:
                # ``ALTER TABLE ADD COLUMN`` raises when the column already
                # exists — that case is the no-op we want; anything else
                # is a real error and propagates.
                db.execute(_FACTS_ADD_USER_ID)
            except sqlite3.OperationalError as exc:
                if "duplicate column name" not in str(exc).lower():
                    raise
            db.execute(_FACTS_SUBJECT_INDEX)
            db.execute(_FACTS_USER_SUBJECT_PRED_INDEX)
            db.commit()

    # ---- mutation --------------------------------------------------------
[docs] async def append(self, fact: Fact) -> str: """Append a fact, invalidating any superseded predecessors. Same supersession rule as :class:`InMemoryFactStore`: if there's an existing currently-valid fact with matching subject + predicate but different object, set its ``valid_until`` to the new fact's ``valid_from``. """ embedding_blob: bytes | None = None if self._embedder is not None: triple = f"{fact.subject} {fact.predicate} {fact.object}" embedding = await self._embedder.embed(triple) embedding_blob = pack_float32(embedding) await anyio.to_thread.run_sync( self._append_sync, fact, embedding_blob ) return fact.id
def _append_sync( self, fact: Fact, embedding_blob: bytes | None, ) -> None: with self._connect() as conn: # Close off any still-valid superseded predecessors. # Namespace-scoped: alice's facts never invalidate bob's. # SQLite uses ``IS`` to compare against NULL (since # ``=`` returns NULL when either side is NULL). conn.execute( "UPDATE facts SET valid_until = ? " "WHERE user_id IS ? " "AND subject = ? AND predicate = ? AND object != ? " "AND valid_until IS NULL", ( _to_epoch(fact.valid_from), fact.user_id, fact.subject, fact.predicate, fact.object, ), ) conn.execute( "INSERT OR REPLACE INTO facts " "(id, user_id, subject, predicate, object, confidence, " "valid_from, valid_until, recorded_at, sources, embedding) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", ( fact.id, fact.user_id, fact.subject, fact.predicate, fact.object, fact.confidence, _to_epoch(fact.valid_from), _to_epoch(fact.valid_until) if fact.valid_until is not None else None, _to_epoch(fact.recorded_at), json.dumps(list(fact.sources)), embedding_blob, ), ) conn.commit() # ---- queries ---------------------------------------------------------
[docs] async def query( self, *, subject: str | None = None, predicate: str | None = None, object_: str | None = None, valid_at: datetime | None = None, limit: int = 10, user_id: str | None = None, ) -> list[Fact]: rows = await anyio.to_thread.run_sync( self._query_sync, subject, predicate, object_, valid_at, limit, user_id, ) return [_row_to_fact(r) for r in rows]
def _query_sync( self, subject: str | None, predicate: str | None, object_: str | None, valid_at: datetime | None, limit: int, user_id: str | None, ) -> list[tuple[Any, ...]]: # Hard namespace partition by ``user_id`` (always in WHERE). sql_parts = ["SELECT * FROM facts WHERE user_id IS ?"] params: list[Any] = [user_id] if subject is not None: sql_parts.append("AND subject = ?") params.append(subject) if predicate is not None: sql_parts.append("AND predicate = ?") params.append(predicate) if object_ is not None: sql_parts.append("AND object = ?") params.append(object_) if valid_at is not None: ts = _to_epoch(valid_at) sql_parts.append( "AND valid_from <= ? " "AND (valid_until IS NULL OR ? < valid_until)" ) params.extend([ts, ts]) sql_parts.append("ORDER BY recorded_at DESC LIMIT ?") params.append(limit) with self._connect() as conn: cursor = conn.execute(" ".join(sql_parts), params) return cursor.fetchall()
[docs] async def recall_text( self, query: str, *, limit: int = 5, valid_at: datetime | None = None, user_id: str | None = None, ) -> list[Fact]: if self._embedder is not None: return await self._recall_embedding(query, limit, valid_at, user_id) return await self._recall_tokens(query, limit, valid_at, user_id)
async def _recall_embedding( self, query: str, limit: int, valid_at: datetime | None, user_id: str | None, ) -> list[Fact]: assert self._embedder is not None query_embedding = await self._embedder.embed(query) rows = await anyio.to_thread.run_sync( self._scan_for_recall, valid_at, user_id ) scored: list[tuple[float, tuple[Any, ...]]] = [] for row in rows: blob = row[10] # embedding column (index shifted by user_id) if not blob: continue stored = unpack_float32(bytes(blob)) scored.append((_cosine(query_embedding, stored), row)) scored.sort(key=lambda pair: pair[0], reverse=True) return [_row_to_fact(r) for _, r in scored[:limit]] async def _recall_tokens( self, query: str, limit: int, valid_at: datetime | None, user_id: str | None, ) -> list[Fact]: rows = await anyio.to_thread.run_sync( self._scan_for_recall, valid_at, user_id ) query_tokens = _tokenize(query) if not query_tokens: # Recency fallback when query has no useful tokens. return [_row_to_fact(r) for r in rows[:limit]] scored: list[tuple[int, int, tuple[Any, ...]]] = [] for row in rows: # Columns shifted by user_id at index 1: subject/predicate/ # object are now indices 2/3/4. haystack = f"{row[2]} {row[3]} {row[4]}" haystack_tokens = _tokenize(haystack) overlap = sum(1 for t in query_tokens if t in haystack_tokens) if overlap > 0: scored.append((-overlap, len(haystack), row)) scored.sort() return [_row_to_fact(r) for _, _, r in scored[:limit]] def _scan_for_recall( self, valid_at: datetime | None, user_id: str | None ) -> list[tuple[Any, ...]]: with self._connect() as conn: if valid_at is None: cursor = conn.execute( "SELECT * FROM facts WHERE user_id IS ? " "ORDER BY recorded_at DESC", (user_id,), ) else: ts = _to_epoch(valid_at) cursor = conn.execute( "SELECT * FROM facts " "WHERE user_id IS ? " "AND valid_from <= ? " "AND (valid_until IS NULL OR ? < valid_until) " "ORDER BY recorded_at DESC", (user_id, ts, ts), ) return cursor.fetchall()
[docs] async def all_facts(self) -> list[Fact]: rows = await anyio.to_thread.run_sync(self._all_sync) return [_row_to_fact(r) for r in rows]
def _all_sync(self) -> list[tuple[Any, ...]]: with self._connect() as conn: return conn.execute( "SELECT * FROM facts ORDER BY recorded_at DESC" ).fetchall()
[docs] async def aclose(self) -> None: return None
# --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _to_epoch(dt: datetime) -> float: if dt.tzinfo is None: dt = dt.replace(tzinfo=UTC) return dt.timestamp() def _from_epoch(ts: float | None) -> datetime | None: if ts is None: return None return datetime.fromtimestamp(ts, tz=UTC) def _row_to_fact(row: tuple[Any, ...]) -> Fact: # Column layout (after the user_id migration): # 0:id 1:user_id 2:subject 3:predicate 4:object 5:confidence # 6:valid_from 7:valid_until 8:recorded_at 9:sources 10:embedding sources: list[str] = [] if row[9]: try: sources = list(json.loads(row[9])) except json.JSONDecodeError: sources = [] valid_from = _from_epoch(row[6]) assert valid_from is not None recorded_at = _from_epoch(row[8]) assert recorded_at is not None return Fact( id=row[0], user_id=row[1], subject=row[2], predicate=row[3], object=row[4], confidence=row[5], valid_from=valid_from, valid_until=_from_epoch(row[7]), recorded_at=recorded_at, sources=sources, ) def _tokenize(text: str) -> set[str]: """Same tokenisation as :mod:`memory.facts`. Duplicated here rather than imported to avoid a circular import between ``facts`` and ``sqlite_facts``. 
""" out: set[str] = set() buf: list[str] = [] for ch in text.lower(): if ch.isalnum(): buf.append(ch) else: if buf: token = "".join(buf) if len(token) >= 2 and token not in _STOP_WORDS: out.add(token) buf = [] if buf: token = "".join(buf) if len(token) >= 2 and token not in _STOP_WORDS: out.add(token) return out def _cosine(a: list[float], b: list[float]) -> float: if not a or not b: return 0.0 dot = 0.0 na = 0.0 nb = 0.0 for x, y in zip(a, b, strict=False): dot += x * y na += x * x nb += y * y if na <= 0.0 or nb <= 0.0: return 0.0 return dot / (math.sqrt(na) * math.sqrt(nb)) _STOP_WORDS: frozenset[str] = frozenset( { "the", "and", "for", "with", "from", "into", "this", "that", "what", "tell", "you", "are", "is", "be", "of", "to", "in", "on", "an", "or", "me", "my", "us", "our", "by", "as", "at", "it", "its", "have", "has", "had", "do", "does", "did", "will", "would", "could", "should", "can", } )