Coverage for src / kemi / adapters / storage / sqlite_vec.py: 67%
238 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
"""SQLite storage adapter with ANN vector index via sqlite-vec.

Same database file as the standard SQLite adapter, but search uses an
HNSW approximate nearest neighbor index for sub-millisecond vector search
instead of brute-force cosine similarity.

Install: pip install kemi[vec]

Falls back to brute-force if sqlite-vec is not installed.
"""
12from __future__ import annotations
14import json
15import logging
16import sqlite3
17from typing import TYPE_CHECKING, Any
19from kemi.adapters.storage.sqlite import SQLiteStorageAdapter
20from kemi.models import LifecycleState, MemoryObject
22if TYPE_CHECKING:
23 from kemi.encryption import EncryptionConfig
25logger = logging.getLogger(__name__)
# Probe once, at import time, whether sqlite-vec is really usable.
# ``_SQLITE_VEC_AVAILABLE`` is True only when the Python package imports AND
# its native extension loads into a throwaway in-memory database; otherwise
# ``_sqlite_vec`` is None and the adapter uses the brute-force code path.
try:
    import sqlite_vec as _sqlite_vec

    # Verify the extension can actually be loaded (not just imported).
    # Some environments have the Python package but lack the native .so.
    _test_conn = sqlite3.connect(":memory:")
    try:
        _test_conn.enable_load_extension(True)
        _sqlite_vec.load(_test_conn)
        _test_conn.enable_load_extension(False)
    finally:
        # Always release the probe connection, even when loading fails.
        _test_conn.close()
    _SQLITE_VEC_AVAILABLE = True
except Exception:  # pragma: no cover
    # Deliberately broad: any failure (ImportError, missing shared library,
    # sqlite3 built without extension support, ...) means "not available".
    _sqlite_vec = None
    _SQLITE_VEC_AVAILABLE = False
def _embedding_to_json(embedding: list[float]) -> str:
    """Return *embedding* encoded as a JSON array string.

    The resulting text is what this module binds as the vector value when
    writing to or querying the ``vec0`` table.
    """
    serialized = json.dumps(embedding)
    return serialized
class SQLiteVecStorageAdapter(SQLiteStorageAdapter):
    """SQLite storage with a vector index via sqlite-vec.

    Uses the same SQLite database file and ``memories`` table as the
    standard adapter, but adds a ``memories_vec`` vec0 virtual table
    for fast nearest neighbor search.

    On ``search()``, the adapter queries the vec0 table instead of delegating
    to the parent's brute-force search. Methods that are not overridden here
    (get, count, export, etc.) are inherited from the parent unchanged.

    If ``sqlite-vec`` is not installed or cannot be loaded, the adapter
    falls back to brute-force (same as ``SQLiteStorageAdapter``).

    NOTE(review): the original documentation describes the index as HNSW and
    the distances as cosine. The vec0 column below is declared without a
    ``distance_metric`` option; confirm against the installed sqlite-vec
    version which index type and metric are actually in effect.

    Parameters
    ----------
    db_path : str
        Path to the SQLite database file.
    embedding_dim : int
        Dimension of embeddings (e.g. 384 for fastembed, 1536 for OpenAI).
        Must match the dimension in use. Default 384.
    lazy : bool
        If True, defer vector index updates until search time.
        Inserts are faster (no vec0 index maintenance), but the first
        search after a batch of inserts will be slightly slower while
        the index catches up. Default False.
    encryption : EncryptionConfig | None
        Forwarded unchanged to ``SQLiteStorageAdapter.__init__``.
    """

    CURRENT_VERSION = 7

    # Staging table used in lazy mode: embeddings wait here (as JSON text)
    # until ``_flush_pending`` moves them into ``memories_vec``.
    _PENDING_TABLE_DDL = """CREATE TABLE IF NOT EXISTS memories_vec_pending (
        memory_id TEXT PRIMARY KEY,
        user_id TEXT NOT NULL,
        embedding TEXT NOT NULL,
        lifecycle_state TEXT NOT NULL
    )"""

    # (target_version, statements) pairs applied in order by
    # ``_run_migrations``. Every statement is best-effort: an
    # OperationalError (e.g. "duplicate column name") is ignored.
    _VEC_MIGRATIONS = (
        (2, ("ALTER TABLE memories ADD COLUMN tags TEXT NOT NULL DEFAULT ''",)),
        (3, ("ALTER TABLE memories ADD COLUMN vec_rowid INTEGER",)),
        (4, (_PENDING_TABLE_DDL,)),
        (
            5,
            (
                "ALTER TABLE memories ADD COLUMN confidence REAL NOT NULL DEFAULT 1.0",
                "ALTER TABLE memories ADD COLUMN memory_type TEXT NOT NULL DEFAULT 'episodic'",
                "ALTER TABLE memories ADD COLUMN session_id TEXT",
                "ALTER TABLE memories ADD COLUMN namespace TEXT NOT NULL DEFAULT 'default'",
                "ALTER TABLE memories ADD COLUMN version INTEGER NOT NULL DEFAULT 1",
            ),
        ),
        (
            6,
            (
                "ALTER TABLE memories ADD COLUMN agent_id TEXT",
                "ALTER TABLE memories ADD COLUMN run_id TEXT",
                "ALTER TABLE memories ADD COLUMN app_id TEXT",
            ),
        ),
        (
            7,
            (
                # TTL: add expires_at column and index
                "ALTER TABLE memories ADD COLUMN expires_at TEXT",
                "CREATE INDEX IF NOT EXISTS idx_memories_expires_at ON memories(expires_at)",
            ),
        ),
    )

    def __init__(
        self,
        db_path: str = "kemi.db",
        embedding_dim: int = 384,
        lazy: bool = False,
        encryption: "EncryptionConfig | None" = None,
    ) -> None:
        # Adapter state is assigned before delegating to the parent
        # constructor; presumably the parent opens the connection and
        # initialises the schema, which reads these attributes — confirm.
        self._embedding_dim = embedding_dim
        self._lazy = lazy
        self._vec_loaded = False
        # Cached row count of memories_vec_pending; None means "unknown".
        self._pending_count: int | None = None
        super().__init__(db_path, encryption=encryption)

    def is_lazy(self) -> bool:
        """Returns True if deferred vector index insertion is enabled."""
        return self._lazy

    # ── Connection (load vec0 extension once on the shared conn) ─────

    def _get_connection(self) -> sqlite3.Connection:
        """Return the parent's connection, loading sqlite-vec into it once.

        The extension is loaded only while ``self._vec_loaded`` is False.
        A load failure is non-fatal: the flag stays False and the adapter
        keeps using the brute-force path.
        """
        conn = super()._get_connection()
        if not _SQLITE_VEC_AVAILABLE or self._vec_loaded or _sqlite_vec is None:
            return conn

        try:
            conn.enable_load_extension(True)
        except (sqlite3.OperationalError, AttributeError):  # pragma: no cover
            # sqlite3 was built without loadable-extension support.
            return conn

        try:
            _sqlite_vec.load(conn)
            self._vec_loaded = True
        except (sqlite3.OperationalError, AttributeError):  # pragma: no cover
            logger.debug("sqlite-vec could not be loaded on this connection", exc_info=True)
        finally:
            # Never leave extension loading enabled, even when load() failed.
            conn.enable_load_extension(False)
        return conn

    # ── Schema ──────────────────────────────────────────────────

    def _init_schema(self) -> None:
        """Create tables and indexes if missing, then run migrations."""
        with self._get_connection() as conn:
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute(
                """CREATE TABLE IF NOT EXISTS schema_version (
                    version INTEGER PRIMARY KEY,
                    applied_at TEXT NOT NULL DEFAULT (datetime('now'))
                )"""
            )
            conn.execute(
                """CREATE TABLE IF NOT EXISTS memories (
                    memory_id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    content TEXT NOT NULL,
                    embedding BLOB,
                    embedding_dim INTEGER,
                    created_at TEXT NOT NULL,
                    last_accessed_at TEXT NOT NULL,
                    source TEXT NOT NULL DEFAULT 'user_stated',
                    importance REAL NOT NULL DEFAULT 0.5,
                    lifecycle_state TEXT NOT NULL DEFAULT 'active',
                    metadata TEXT NOT NULL DEFAULT '{}',
                    tags TEXT NOT NULL DEFAULT '',
                    vec_rowid INTEGER,
                    confidence REAL NOT NULL DEFAULT 1.0,
                    memory_type TEXT NOT NULL DEFAULT 'episodic',
                    session_id TEXT,
                    namespace TEXT NOT NULL DEFAULT 'default',
                    version INTEGER NOT NULL DEFAULT 1,
                    agent_id TEXT,
                    run_id TEXT,
                    app_id TEXT,
                    expires_at TEXT
                )"""
            )
            conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories(user_id)")
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_memories_lifecycle ON memories(lifecycle_state)"
            )
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_memories_user_lifecycle "
                "ON memories(user_id, lifecycle_state)"
            )
            conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_tags ON memories(tags)")
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_memories_expires_at ON memories(expires_at)"
            )
            conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_namespace ON memories(namespace)")

            # Pending vec0 inserts for lazy mode
            conn.execute(self._PENDING_TABLE_DDL)

            self._init_vec_table(conn)
            self._run_migrations(conn)

    def _init_vec_table(self, conn: sqlite3.Connection) -> None:
        """Create the ``memories_vec`` vec0 table when sqlite-vec is usable."""
        if not _SQLITE_VEC_AVAILABLE:
            return

        # Coerce to int so only digits can reach the interpolated DDL.
        dim = int(self._embedding_dim)
        try:
            conn.execute(
                f"""CREATE VIRTUAL TABLE IF NOT EXISTS memories_vec
                    USING vec0(
                        embedding float[{dim}],
                        memory_id text,
                        user_id text,
                        lifecycle_state text
                    )"""
            )
            self._vec_loaded = True
        except sqlite3.OperationalError:  # pragma: no cover
            # Best-effort: keep going without the vec0 table.
            logger.warning("Could not create memories_vec virtual table", exc_info=True)

    def _run_migrations(self, conn: sqlite3.Connection) -> None:
        """Apply every migration newer than the stored schema version."""
        current = self._get_schema_version(conn)
        if current >= self.CURRENT_VERSION:
            return

        for version, statements in self._VEC_MIGRATIONS:
            if current >= version:
                continue
            for statement in statements:
                try:
                    conn.execute(statement)
                except sqlite3.OperationalError:
                    # Column/table/index already exists — nothing to do.
                    pass
            conn.execute(
                "INSERT OR REPLACE INTO schema_version (version) VALUES (?)",
                (version,),
            )

    # ── Store ───────────────────────────────────────────────────

    def store(self, memory: MemoryObject) -> None:
        """Insert or replace *memory* and keep the vector index in step.

        Side effect: ``memory.metadata["_vec_rowid"]`` may be filled in with
        the rowid of the memory's entry in ``memories_vec``.
        """
        with self._get_connection() as conn:
            # Preserve existing vec_rowid from DB if not set in metadata.
            # Without this, INSERT OR REPLACE overwrites vec_rowid to NULL and
            # the next vec0 write (eager insert or lazy flush) creates a
            # duplicate vec0 entry. This applies to both modes, so the lookup
            # is not restricted to lazy mode.
            if self._vec_loaded and memory.metadata.get("_vec_rowid") is None:
                existing = conn.execute(
                    "SELECT vec_rowid FROM memories WHERE memory_id = ?",
                    (memory.memory_id,),
                ).fetchone()
                if existing and existing["vec_rowid"] is not None:
                    memory.metadata["_vec_rowid"] = existing["vec_rowid"]

            row = self._memory_to_row(memory)
            conn.execute(
                """INSERT OR REPLACE INTO memories
                   (memory_id, user_id, content, embedding, embedding_dim,
                    created_at, last_accessed_at, source, importance,
                    lifecycle_state, metadata, tags, vec_rowid,
                    confidence, memory_type, session_id, namespace, version,
                    agent_id, run_id, app_id, expires_at)
                   VALUES (:memory_id, :user_id, :content, :embedding, :embedding_dim,
                    :created_at, :last_accessed_at, :source, :importance,
                    :lifecycle_state, :metadata, :tags, :vec_rowid,
                    :confidence, :memory_type, :session_id, :namespace, :version,
                    :agent_id, :run_id, :app_id, :expires_at)""",
                row,
            )

            if not self._vec_loaded or memory.embedding is None:
                return

            if self._lazy:
                self._store_pending_on_conn(conn, memory)
            else:
                self._store_vec_direct_on_conn(conn, memory)

    def _store_pending_on_conn(self, conn: sqlite3.Connection, memory: MemoryObject) -> None:
        """Store embedding in the pending table (no vector index update)."""
        # Internal invariant (store() checks this); also narrows the type.
        assert memory.embedding is not None
        conn.execute(
            """INSERT OR REPLACE INTO memories_vec_pending
               (memory_id, user_id, embedding, lifecycle_state)
               VALUES (?, ?, ?, ?)""",
            (
                memory.memory_id,
                memory.user_id,
                _embedding_to_json(memory.embedding),
                memory.lifecycle_state.value,
            ),
        )
        # Invalidate cached pending count so next _has_pending() is accurate
        self._pending_count = None

    def _store_vec_direct_on_conn(self, conn: sqlite3.Connection, memory: MemoryObject) -> None:
        """Insert/update the vector in the vec0 table directly.

        Slower than lazy insertion but keeps the index always up-to-date.
        Updates the existing vec0 row when ``memories.vec_rowid`` is set;
        otherwise inserts a new row and records its rowid.
        """
        assert memory.embedding is not None
        embedding_json = _embedding_to_json(memory.embedding)
        existing_row = conn.execute(
            "SELECT vec_rowid FROM memories WHERE memory_id = ?",
            (memory.memory_id,),
        ).fetchone()
        vec_rowid: int | None = existing_row[0] if existing_row else None

        if vec_rowid is not None:
            conn.execute(
                "UPDATE memories_vec SET embedding=?, user_id=?, lifecycle_state=? WHERE rowid=?",
                (embedding_json, memory.user_id, memory.lifecycle_state.value, vec_rowid),
            )
            return

        cursor = conn.execute(
            """INSERT INTO memories_vec (embedding, memory_id, user_id, lifecycle_state)
               VALUES (?, ?, ?, ?)""",
            (embedding_json, memory.memory_id, memory.user_id, memory.lifecycle_state.value),
        )
        new_rowid = cursor.lastrowid
        if new_rowid is not None:
            conn.execute(
                "UPDATE memories SET vec_rowid = ? WHERE memory_id = ?",
                (new_rowid, memory.memory_id),
            )
            memory.metadata["_vec_rowid"] = new_rowid

    # ── Flush pending → vec0 ───────────────────────────────────

    def _count_pending(self) -> int:
        """Return how many memories are waiting in the pending table.

        The count is cached in ``self._pending_count`` until invalidated.
        """
        if self._pending_count is not None:
            return self._pending_count
        with self._get_connection() as conn:
            row = conn.execute("SELECT COUNT(*) FROM memories_vec_pending").fetchone()
        self._pending_count = row[0] if row else 0
        return self._pending_count

    def _has_pending(self) -> bool:
        """Return True when at least one vector is waiting to be flushed."""
        return self._count_pending() > 0

    def _flush_pending(self) -> None:
        """Batch-insert all pending vectors into the vec0 table.

        During lazy mode the index is not updated per insert, so all pending
        entries are moved here inside one ``self._transaction()`` block.

        Handles re-flush gracefully: if a memory already has a vec_rowid
        (e.g. it was flushed before, then updated and re-stored), we
        UPDATE the existing vec0 row instead of creating a duplicate.
        """
        if not self._vec_loaded or not self._has_pending():
            return

        count = self._pending_count or 0
        logger.info("Flushing %d pending vectors to ANN index…", count)

        with self._transaction() as conn:
            rows = conn.execute(
                "SELECT p.memory_id, p.user_id, p.embedding, p.lifecycle_state, "
                "       m.vec_rowid "
                "FROM memories_vec_pending p "
                "LEFT JOIN memories m ON m.memory_id = p.memory_id"
            ).fetchall()

            for row in rows:
                mid = row["memory_id"]
                existing_vec_rowid = row["vec_rowid"]

                if existing_vec_rowid is not None:
                    conn.execute(
                        "UPDATE memories_vec SET embedding=?, user_id=?, lifecycle_state=? "
                        "WHERE rowid=?",
                        (
                            row["embedding"],
                            row["user_id"],
                            row["lifecycle_state"],
                            existing_vec_rowid,
                        ),
                    )
                    continue

                cursor = conn.execute(
                    """INSERT INTO memories_vec
                       (embedding, memory_id, user_id, lifecycle_state)
                       VALUES (?, ?, ?, ?)""",
                    (row["embedding"], mid, row["user_id"], row["lifecycle_state"]),
                )
                # Same mechanism as _store_vec_direct_on_conn.
                new_rowid = cursor.lastrowid
                if new_rowid is not None:
                    conn.execute(
                        "UPDATE memories SET vec_rowid = ? WHERE memory_id = ?",
                        (new_rowid, mid),
                    )

            conn.execute("DELETE FROM memories_vec_pending")

        self._pending_count = 0
        logger.info("Flushed %d vectors to ANN index", count)

    # ── Search ──────────────────────────────────────────────

    def search(
        self,
        user_id: str,
        query_embedding: list[float],
        top_k: int = 10,
        lifecycle_filter: list[LifecycleState] | None = None,
        namespace: str = "default",
        session_id: str | None = None,
    ) -> list[MemoryObject]:
        """Return up to *top_k* memories of *user_id* nearest to the query.

        Uses the vec0 table when it is loaded and *query_embedding* is
        non-empty (flushing pending vectors first in lazy mode); otherwise
        delegates to the parent's search.

        NOTE(review): ``session_id`` is only forwarded to the parent
        fallback; the vec0 path ignores it. Confirm whether that is intended.
        """
        if lifecycle_filter is None:
            lifecycle_filter = [LifecycleState.ACTIVE, LifecycleState.DECAYING]

        if self._vec_loaded and query_embedding:
            if self._lazy and self._has_pending():
                self._flush_pending()
            states_list = [state.value for state in lifecycle_filter]
            return self._search_vec(user_id, query_embedding, top_k, states_list, namespace)

        # Fallback: brute-force scan via parent
        return super().search(
            user_id, query_embedding, top_k, lifecycle_filter, namespace, session_id
        )

    def _search_vec(
        self,
        user_id: str,
        query_embedding: list[float],
        top_k: int,
        states_list: list[str],
        namespace: str = "default",
    ) -> list[MemoryObject]:
        """Search using the vec0 index.

        Returns at most *top_k* memories in ascending distance order, each
        with ``score`` set; returns ``[]`` when *top_k* is not positive or
        *states_list* is empty.
        """
        # Nothing can match; also avoids "LIMIT <= 0" and "IN ()" queries.
        if top_k <= 0 or not states_list:
            return []

        embedding_json = _embedding_to_json(query_embedding)
        placeholders = ",".join("?" * len(states_list))

        # Over-fetch from vec0 because namespace is filtered post-hoc.
        # vec0 doesn't index namespace, so we can't push it into the ANN query.
        # Multiplying by 3 is a heuristic; matches core.py's recall() which also
        # fetches top_k * 3 from the storage layer.
        fetch_k = top_k * 3

        with self._get_connection() as conn:
            rows = conn.execute(
                f"""SELECT rowid, distance, memory_id
                    FROM memories_vec
                    WHERE embedding MATCH ?
                      AND user_id = ?
                      AND lifecycle_state IN ({placeholders})
                    ORDER BY distance
                    LIMIT ?""",
                [embedding_json, user_id] + states_list + [fetch_k],
            ).fetchall()

            if not rows:
                return []

            # Rows arrive sorted by distance; keep the first (closest) hit
            # per memory_id so a stray duplicate vec0 row cannot produce
            # duplicate results. Dict order preserves the ranking.
            distances: dict[str, float] = {}
            for hit in rows:
                distances.setdefault(hit["memory_id"], hit["distance"])
            memory_ids = list(distances)

            id_placeholders = ",".join("?" * len(memory_ids))
            memory_rows = conn.execute(
                f"SELECT * FROM memories WHERE memory_id IN ({id_placeholders})",
                memory_ids,
            ).fetchall()

        mem_map = {r["memory_id"]: r for r in memory_rows}
        results: list[MemoryObject] = []
        for mid in memory_ids:
            if mid not in mem_map:
                # vec0 row without a matching memories row; skip it.
                continue
            mem = self._row_to_memory(mem_map[mid])
            # Map distance to a [0, 1] score, treating it as lying in [0, 2].
            # NOTE(review): this assumes a cosine-style distance, but the vec0
            # column is declared without distance_metric — confirm the metric.
            mem.score = max(0.0, min(1.0, 1.0 - distances[mid] / 2.0))
            # Filter by namespace (vec0 doesn't support this in the query)
            if mem.namespace == namespace:
                results.append(mem)
                if len(results) >= top_k:
                    break

        return results

    # ── Delete ──────────────────────────────────────────────

    def delete_by_id(self, memory_id: str) -> bool:
        """Delete one memory plus its vec0 and pending entries.

        Returns True when a row was removed from ``memories``.
        """
        with self._get_connection() as conn:
            if self._vec_loaded:
                row = conn.execute(
                    "SELECT vec_rowid FROM memories WHERE memory_id = ?",
                    (memory_id,),
                ).fetchone()
                if row and row["vec_rowid"] is not None:
                    conn.execute(
                        "DELETE FROM memories_vec WHERE rowid = ?",
                        (row["vec_rowid"],),
                    )

            # memories_vec_pending is a plain table, so purge it even when
            # the extension is not loaded; otherwise a staged embedding for
            # a deleted memory would be flushed into vec0 later.
            conn.execute(
                "DELETE FROM memories_vec_pending WHERE memory_id = ?",
                (memory_id,),
            )
            self._pending_count = None

            cursor = conn.execute("DELETE FROM memories WHERE memory_id = ?", (memory_id,))
            return cursor.rowcount > 0

    def delete_by_user(self, user_id: str) -> int:
        """Delete all memories of *user_id* plus their vec0/pending entries.

        Returns the number of rows removed from ``memories``.
        """
        with self._get_connection() as conn:
            if self._vec_loaded:
                rows = conn.execute(
                    "SELECT vec_rowid FROM memories WHERE user_id = ? AND vec_rowid IS NOT NULL",
                    (user_id,),
                ).fetchall()
                conn.executemany(
                    "DELETE FROM memories_vec WHERE rowid = ?",
                    [(r["vec_rowid"],) for r in rows],
                )

            # See delete_by_id: staged embeddings are removed unconditionally.
            conn.execute(
                "DELETE FROM memories_vec_pending WHERE user_id = ?",
                (user_id,),
            )
            self._pending_count = None

            cursor = conn.execute("DELETE FROM memories WHERE user_id = ?", (user_id,))
            return cursor.rowcount

    # ── Update ──────────────────────────────────────────────────

    def update(self, memory: MemoryObject) -> None:
        """Update *memory*; identical to ``store()`` (insert-or-replace)."""
        self.store(memory)

    # ── Row helpers ──────────────────────────────────────────────

    def _memory_to_row(self, memory: MemoryObject) -> dict[str, Any]:
        """Extend the parent's row dict with the ``vec_rowid`` column."""
        row = super()._memory_to_row(memory)
        row["vec_rowid"] = memory.metadata.get("_vec_rowid") if memory.metadata else None
        return row

    def _row_to_memory(self, row: sqlite3.Row) -> MemoryObject:
        """Build a memory via the parent and expose ``vec_rowid`` in metadata."""
        mem = super()._row_to_memory(row)
        try:
            vec_rowid = row["vec_rowid"]
        except (IndexError, KeyError):  # pragma: no cover
            # Row came from a query that did not select vec_rowid.
            return mem
        if vec_rowid is not None:
            mem.metadata["_vec_rowid"] = vec_rowid
        return mem

    # ── Utility ──────────────────────────────────────────────

    @classmethod
    def is_vec_available(cls) -> bool:
        """Returns True if sqlite-vec is installed and usable."""
        return _SQLITE_VEC_AVAILABLE