Coverage for src / kemi / adapters / storage / sqlite_vec.py: 67%

238 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

1"""SQLite storage adapter with ANN vector index via sqlite-vec. 

2 

3Same database file as the standard SQLite adapter, but search uses an 

4HNSW approximate nearest neighbor index for sub-millisecond vector search 

5instead of brute-force cosine similarity. 

6 

7Install: pip install kemi[vec] 

8 

9Falls back to brute-force if sqlite-vec is not installed. 

10""" 

11 

12from __future__ import annotations 

13 

14import json 

15import logging 

16import sqlite3 

17from typing import TYPE_CHECKING, Any 

18 

19from kemi.adapters.storage.sqlite import SQLiteStorageAdapter 

20from kemi.models import LifecycleState, MemoryObject 

21 

22if TYPE_CHECKING: 

23 from kemi.encryption import EncryptionConfig 

24 

logger = logging.getLogger(__name__)


def _probe_sqlite_vec() -> Any:
    """Return the ``sqlite_vec`` module if it is importable AND loadable, else ``None``.

    Importing the package is not enough: some environments ship the Python
    package without a usable native library, and some Python builds compile
    ``sqlite3`` without extension-loading support. Both cases are detected by
    loading the extension into a throwaway in-memory connection.
    """
    try:
        import sqlite_vec
    except Exception:  # pragma: no cover - optional dependency
        return None

    conn = None
    try:
        conn = sqlite3.connect(":memory:")
        conn.enable_load_extension(True)
        sqlite_vec.load(conn)
        conn.enable_load_extension(False)
    except Exception:  # pragma: no cover - depends on the local build
        logger.debug("sqlite-vec is installed but its extension could not be loaded", exc_info=True)
        return None
    finally:
        # Release the probe connection on both the success and failure paths.
        if conn is not None:
            conn.close()
    return sqlite_vec


# Module handle used by the adapter (None when unavailable) and the flag it checks.
_sqlite_vec = _probe_sqlite_vec()
_SQLITE_VEC_AVAILABLE = _sqlite_vec is not None

43 

44 

def _embedding_to_json(embedding: list[float]) -> str:
    """Encode *embedding* as a JSON array string, the text form passed to vec0."""
    encoded = json.JSONEncoder().encode(embedding)
    return encoded

48 

49 

class SQLiteVecStorageAdapter(SQLiteStorageAdapter):
    """SQLite storage with a sqlite-vec ``vec0`` vector index.

    Uses the same SQLite database file and ``memories`` table as the
    standard adapter, but adds a ``memories_vec`` vec0 virtual table so
    nearest-neighbour search runs inside SQLite instead of loading every
    row and computing cosine similarity in Python.

    ``search()`` queries the vec0 table when the extension is loaded and
    the search is not session-scoped; otherwise it delegates to the
    parent's brute-force scan. All other read methods (get, count,
    export, etc.) are inherited from the parent.

    If ``sqlite-vec`` is not installed, the adapter silently falls back
    to brute-force (same as ``SQLiteStorageAdapter``).

    Parameters
    ----------
    db_path : str
        Path to the SQLite database file.
    embedding_dim : int
        Dimension of embeddings (e.g. 384 for fastembed, 1536 for OpenAI).
        Must match the dimension in use. It only takes effect when
        ``memories_vec`` is first created. Default 384.
    lazy : bool
        If True, store() stages embeddings in ``memories_vec_pending`` and
        the next vector search copies them into ``memories_vec`` in one
        transaction. Inserts are cheaper, but the first search after a batch
        of inserts pays for the flush. Default False.
    encryption : EncryptionConfig | None
        Passed through unchanged to the parent adapter.
    """

    CURRENT_VERSION = 7

    # DDL for the lazy-mode staging table. Shared by _init_schema and the v4 migration.
    _PENDING_TABLE_SQL = """CREATE TABLE IF NOT EXISTS memories_vec_pending (
        memory_id TEXT PRIMARY KEY,
        user_id TEXT NOT NULL,
        embedding TEXT NOT NULL,
        lifecycle_state TEXT NOT NULL
    )"""

    def __init__(
        self,
        db_path: str = "kemi.db",
        embedding_dim: int = 384,
        lazy: bool = False,
        encryption: "EncryptionConfig | None" = None,
    ) -> None:
        # Set before super().__init__() because the parent constructor may call
        # the overridden hooks below (e.g. _init_schema). TODO confirm in parent.
        self._embedding_dim = embedding_dim
        self._lazy = lazy
        self._vec_loaded = False
        # Cached row count of memories_vec_pending; None means "re-query".
        self._pending_count: int | None = None
        super().__init__(db_path, encryption=encryption)

    def is_lazy(self) -> bool:
        """Return True if vec0 insertion is deferred to the pending table."""
        return self._lazy

    # ── Connection (load vec0 extension once on the shared conn) ─────

    def _get_connection(self) -> sqlite3.Connection:
        """Return the parent's connection with the sqlite-vec extension loaded.

        Loading is attempted until it succeeds once (tracked by
        ``_vec_loaded``). On failure the connection is returned without the
        extension and callers take the brute-force paths.
        """
        conn = super()._get_connection()
        if not _SQLITE_VEC_AVAILABLE or _sqlite_vec is None or self._vec_loaded:
            return conn
        try:
            conn.enable_load_extension(True)
            try:
                _sqlite_vec.load(conn)
            finally:
                # Never leave arbitrary extension loading enabled, even if load() raises.
                conn.enable_load_extension(False)
            self._vec_loaded = True
        except (sqlite3.OperationalError, AttributeError):  # pragma: no cover
            logger.debug("Could not load sqlite-vec on this connection", exc_info=True)
        return conn

    # ── Schema ──────────────────────────────────────────────────

    def _init_schema(self) -> None:
        """Create tables, migrate older databases, then create indexes.

        Indexes are created after migrations. Columns such as ``tags``,
        ``namespace`` and ``expires_at`` only exist on databases written by
        older versions once their migrations have run, and ``CREATE INDEX``
        on a missing column fails.
        """
        with self._get_connection() as conn:
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute(
                """CREATE TABLE IF NOT EXISTS schema_version (
                    version INTEGER PRIMARY KEY,
                    applied_at TEXT NOT NULL DEFAULT (datetime('now'))
                )"""
            )
            conn.execute(
                """CREATE TABLE IF NOT EXISTS memories (
                    memory_id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    content TEXT NOT NULL,
                    embedding BLOB,
                    embedding_dim INTEGER,
                    created_at TEXT NOT NULL,
                    last_accessed_at TEXT NOT NULL,
                    source TEXT NOT NULL DEFAULT 'user_stated',
                    importance REAL NOT NULL DEFAULT 0.5,
                    lifecycle_state TEXT NOT NULL DEFAULT 'active',
                    metadata TEXT NOT NULL DEFAULT '{}',
                    tags TEXT NOT NULL DEFAULT '',
                    vec_rowid INTEGER,
                    confidence REAL NOT NULL DEFAULT 1.0,
                    memory_type TEXT NOT NULL DEFAULT 'episodic',
                    session_id TEXT,
                    namespace TEXT NOT NULL DEFAULT 'default',
                    version INTEGER NOT NULL DEFAULT 1,
                    agent_id TEXT,
                    run_id TEXT,
                    app_id TEXT,
                    expires_at TEXT
                )"""
            )
            # Pending vec0 inserts for lazy mode
            conn.execute(self._PENDING_TABLE_SQL)

            self._init_vec_table(conn)
            self._run_migrations(conn)

            for index_sql in (
                "CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories(user_id)",
                "CREATE INDEX IF NOT EXISTS idx_memories_lifecycle ON memories(lifecycle_state)",
                "CREATE INDEX IF NOT EXISTS idx_memories_user_lifecycle "
                "ON memories(user_id, lifecycle_state)",
                "CREATE INDEX IF NOT EXISTS idx_memories_tags ON memories(tags)",
                "CREATE INDEX IF NOT EXISTS idx_memories_expires_at ON memories(expires_at)",
                "CREATE INDEX IF NOT EXISTS idx_memories_namespace ON memories(namespace)",
            ):
                conn.execute(index_sql)

    def _init_vec_table(self, conn: sqlite3.Connection) -> None:
        """Create the ``memories_vec`` vec0 table if sqlite-vec is usable.

        ``CREATE ... IF NOT EXISTS`` means an existing table keeps the
        dimension it was created with, whatever ``embedding_dim`` is now.
        """
        if not _SQLITE_VEC_AVAILABLE:
            return

        dim = self._embedding_dim
        try:
            conn.execute(
                f"""CREATE VIRTUAL TABLE IF NOT EXISTS memories_vec
                USING vec0(
                    embedding float[{dim}],
                    memory_id text,
                    user_id text,
                    lifecycle_state text
                )"""
            )
        except sqlite3.OperationalError:  # pragma: no cover
            logger.warning("Could not create memories_vec vec0 table", exc_info=True)
            return
        self._vec_loaded = True

    @staticmethod
    def _try_ddl(conn: sqlite3.Connection, sql: str) -> None:
        """Execute one migration statement, ignoring ``sqlite3.OperationalError``.

        Migrations also run against schemas that already contain the change
        (e.g. ``ADD COLUMN`` on a freshly created table), so such errors are
        expected.
        """
        try:
            conn.execute(sql)
        except sqlite3.OperationalError:
            pass

    def _run_migrations(self, conn: sqlite3.Connection) -> None:
        """Bring an older database schema up to ``CURRENT_VERSION``."""
        current = self._get_schema_version(conn)
        if current >= self.CURRENT_VERSION:
            return

        def mark(version: int) -> None:
            conn.execute("INSERT OR REPLACE INTO schema_version (version) VALUES (?)", (version,))

        if current < 2:
            self._try_ddl(conn, "ALTER TABLE memories ADD COLUMN tags TEXT NOT NULL DEFAULT ''")
            mark(2)

        if current < 3:
            self._try_ddl(conn, "ALTER TABLE memories ADD COLUMN vec_rowid INTEGER")
            mark(3)

        if current < 4:
            self._try_ddl(conn, self._PENDING_TABLE_SQL)
            mark(4)

        if current < 5:
            for col, dtype in (
                ("confidence", "REAL NOT NULL DEFAULT 1.0"),
                ("memory_type", "TEXT NOT NULL DEFAULT 'episodic'"),
                ("session_id", "TEXT"),
                ("namespace", "TEXT NOT NULL DEFAULT 'default'"),
                ("version", "INTEGER NOT NULL DEFAULT 1"),
            ):
                self._try_ddl(conn, f"ALTER TABLE memories ADD COLUMN {col} {dtype}")
            mark(5)

        if current < 6:
            for col, dtype in (("agent_id", "TEXT"), ("run_id", "TEXT"), ("app_id", "TEXT")):
                self._try_ddl(conn, f"ALTER TABLE memories ADD COLUMN {col} {dtype}")
            mark(6)

        if current < 7:
            # TTL: add expires_at column and index
            self._try_ddl(conn, "ALTER TABLE memories ADD COLUMN expires_at TEXT")
            self._try_ddl(
                conn,
                "CREATE INDEX IF NOT EXISTS idx_memories_expires_at ON memories(expires_at)",
            )
            mark(7)

    # ── Store ───────────────────────────────────────────────────

    def store(self, memory: MemoryObject) -> None:
        """Insert or replace *memory* and keep the vec0 index in sync.

        In lazy mode the embedding is staged in ``memories_vec_pending``.
        Otherwise it is written to ``memories_vec`` immediately. Memories
        without an embedding, or stores made while the extension is not
        loaded, touch only ``memories``.
        """
        with self._get_connection() as conn:
            # INSERT OR REPLACE rewrites every column, so a MemoryObject whose
            # metadata lacks ``_vec_rowid`` would reset vec_rowid to NULL. In
            # both lazy mode (_flush_pending) and eager mode
            # (_store_vec_direct_on_conn) that caused a second vec0 row for the
            # same memory. Carry the existing link over in both modes.
            if (
                self._vec_loaded
                and memory.metadata is not None
                and memory.metadata.get("_vec_rowid") is None
            ):
                existing = conn.execute(
                    "SELECT vec_rowid FROM memories WHERE memory_id = ?",
                    (memory.memory_id,),
                ).fetchone()
                if existing and existing["vec_rowid"] is not None:
                    memory.metadata["_vec_rowid"] = existing["vec_rowid"]

            row = self._memory_to_row(memory)
            conn.execute(
                """INSERT OR REPLACE INTO memories
                (memory_id, user_id, content, embedding, embedding_dim,
                 created_at, last_accessed_at, source, importance,
                 lifecycle_state, metadata, tags, vec_rowid,
                 confidence, memory_type, session_id, namespace, version,
                 agent_id, run_id, app_id, expires_at)
                VALUES (:memory_id, :user_id, :content, :embedding, :embedding_dim,
                        :created_at, :last_accessed_at, :source, :importance,
                        :lifecycle_state, :metadata, :tags, :vec_rowid,
                        :confidence, :memory_type, :session_id, :namespace, :version,
                        :agent_id, :run_id, :app_id, :expires_at)""",
                row,
            )

            if not self._vec_loaded or memory.embedding is None:
                return

            if self._lazy:
                self._store_pending_on_conn(conn, memory)
            else:
                self._store_vec_direct_on_conn(conn, memory)

    def _store_pending_on_conn(self, conn: sqlite3.Connection, memory: MemoryObject) -> None:
        """Stage the embedding in the pending table (memories_vec is untouched)."""
        assert memory.embedding is not None  # guaranteed by store()
        conn.execute(
            """INSERT OR REPLACE INTO memories_vec_pending
            (memory_id, user_id, embedding, lifecycle_state)
            VALUES (?, ?, ?, ?)""",
            (
                memory.memory_id,
                memory.user_id,
                _embedding_to_json(memory.embedding),
                memory.lifecycle_state.value,
            ),
        )
        # Invalidate cached pending count so next _has_pending() is accurate
        self._pending_count = None

    def _store_vec_direct_on_conn(self, conn: sqlite3.Connection, memory: MemoryObject) -> None:
        """Insert or update this memory's row in ``memories_vec`` immediately.

        If ``memories.vec_rowid`` already points at a vec0 row, that row is
        updated. Otherwise a new row is inserted and linked, and its rowid is
        recorded in ``memory.metadata["_vec_rowid"]``.
        """
        assert memory.embedding is not None  # guaranteed by store()
        embedding_json = _embedding_to_json(memory.embedding)
        existing = conn.execute(
            "SELECT vec_rowid FROM memories WHERE memory_id = ?",
            (memory.memory_id,),
        ).fetchone()
        vec_rowid = existing[0] if existing else None

        if vec_rowid is not None:
            self._update_vec_row(
                conn, vec_rowid, embedding_json, memory.user_id, memory.lifecycle_state.value
            )
            return

        new_rowid = self._insert_vec_row(
            conn, embedding_json, memory.memory_id, memory.user_id, memory.lifecycle_state.value
        )
        if new_rowid is not None:
            memory.metadata["_vec_rowid"] = new_rowid

    @staticmethod
    def _insert_vec_row(
        conn: sqlite3.Connection,
        embedding_json: str,
        memory_id: str,
        user_id: str,
        lifecycle_state: str,
    ) -> int | None:
        """Insert a ``memories_vec`` row, link it from ``memories.vec_rowid``, return its rowid."""
        cursor = conn.execute(
            """INSERT INTO memories_vec (embedding, memory_id, user_id, lifecycle_state)
            VALUES (?, ?, ?, ?)""",
            (embedding_json, memory_id, user_id, lifecycle_state),
        )
        new_rowid = cursor.lastrowid
        if new_rowid is not None:
            conn.execute(
                "UPDATE memories SET vec_rowid = ? WHERE memory_id = ?",
                (new_rowid, memory_id),
            )
        return new_rowid

    @staticmethod
    def _update_vec_row(
        conn: sqlite3.Connection,
        vec_rowid: int,
        embedding_json: str,
        user_id: str,
        lifecycle_state: str,
    ) -> None:
        """Overwrite the embedding and filter columns of an existing ``memories_vec`` row."""
        conn.execute(
            "UPDATE memories_vec SET embedding=?, user_id=?, lifecycle_state=? WHERE rowid=?",
            (embedding_json, user_id, lifecycle_state, vec_rowid),
        )

    # ── Flush pending → vec0 ───────────────────────────────────

    def _count_pending(self) -> int:
        """Return how many memories are waiting in the pending table (cached)."""
        if self._pending_count is not None:
            return self._pending_count
        with self._get_connection() as conn:
            row = conn.execute("SELECT COUNT(*) FROM memories_vec_pending").fetchone()
        self._pending_count = row[0] if row else 0
        return self._pending_count

    def _has_pending(self) -> bool:
        """Return True if the pending table has rows awaiting a flush."""
        return self._count_pending() > 0

    def _flush_pending(self) -> None:
        """Move every pending embedding into ``memories_vec`` in one transaction.

        A memory that already has a ``vec_rowid`` (flushed before, then
        re-stored) has its vec0 row updated rather than duplicated. Pending
        rows whose memory no longer exists are dropped instead of producing
        orphan vec0 rows.
        """
        if not self._vec_loaded or not self._has_pending():
            return

        logger.info("Flushing %d pending vectors to ANN index…", self._pending_count or 0)

        with self._transaction() as conn:
            rows = conn.execute(
                "SELECT p.memory_id, p.user_id, p.embedding, p.lifecycle_state, "
                "       m.vec_rowid "
                "FROM memories_vec_pending p "
                "JOIN memories m ON m.memory_id = p.memory_id"
            ).fetchall()

            for row in rows:
                if row["vec_rowid"] is not None:
                    self._update_vec_row(
                        conn,
                        row["vec_rowid"],
                        row["embedding"],
                        row["user_id"],
                        row["lifecycle_state"],
                    )
                else:
                    self._insert_vec_row(
                        conn,
                        row["embedding"],
                        row["memory_id"],
                        row["user_id"],
                        row["lifecycle_state"],
                    )

            conn.execute("DELETE FROM memories_vec_pending")

        self._pending_count = 0
        logger.info("Flushed %d vectors to ANN index", len(rows))

    # ── Search ──────────────────────────────────────────────

    def search(
        self,
        user_id: str,
        query_embedding: list[float],
        top_k: int = 10,
        lifecycle_filter: list[LifecycleState] | None = None,
        namespace: str = "default",
        session_id: str | None = None,
    ) -> list[MemoryObject]:
        """Return up to ``top_k`` of ``user_id``'s memories closest to ``query_embedding``.

        ``lifecycle_filter`` defaults to ACTIVE and DECAYING. The vec0 index
        is used when it is loaded, the query is non-empty and ``session_id``
        is None. In lazy mode, pending embeddings are flushed first. Every
        other case delegates to the parent's brute-force search.
        """
        if lifecycle_filter is None:
            lifecycle_filter = [LifecycleState.ACTIVE, LifecycleState.DECAYING]

        # memories_vec has no session column, and the vec0 path previously
        # dropped session_id silently. Session-scoped searches go to the
        # parent, whose session handling is authoritative.
        if self._vec_loaded and query_embedding and session_id is None:
            if self._lazy and self._has_pending():
                self._flush_pending()
            states_list = [s.value for s in lifecycle_filter]
            return self._search_vec(user_id, query_embedding, top_k, states_list, namespace)

        # Fallback: brute-force scan via parent
        return super().search(
            user_id, query_embedding, top_k, lifecycle_filter, namespace, session_id
        )

    def _search_vec(
        self,
        user_id: str,
        query_embedding: list[float],
        top_k: int,
        states_list: list[str],
        namespace: str = "default",
    ) -> list[MemoryObject]:
        """Search using the vec0 index.

        Returns at most ``top_k`` distinct memories of ``user_id`` that have a
        lifecycle state in ``states_list`` and a namespace equal to
        ``namespace``. Results are ordered by ascending vec0 distance, and
        each has ``score`` clamped to [0, 1].
        """
        if top_k <= 0 or not states_list:
            return []

        embedding_json = _embedding_to_json(query_embedding)
        state_marks = ",".join("?" * len(states_list))

        # Over-fetch from vec0 because namespace is filtered post-hoc.
        # vec0 doesn't index namespace, so we can't push it into the ANN query.
        # Multiplying by 3 is a heuristic; matches core.py's recall() which also
        # fetches top_k * 3 from the storage layer.
        fetch_k = top_k * 3

        with self._get_connection() as conn:
            hits = conn.execute(
                f"""SELECT memory_id, distance
                FROM memories_vec
                WHERE embedding MATCH ?
                  AND user_id = ?
                  AND lifecycle_state IN ({state_marks})
                ORDER BY distance
                LIMIT ?""",
                [embedding_json, user_id, *states_list, fetch_k],
            ).fetchall()
            if not hits:
                return []

            # Hits arrive closest-first. Keep only the first hit per memory_id
            # so duplicate vec0 rows for one memory (which earlier store()
            # behaviour could create) never yield duplicate results.
            best_distance: dict[str, float] = {}
            for hit in hits:
                best_distance.setdefault(hit["memory_id"], hit["distance"])

            id_marks = ",".join("?" * len(best_distance))
            memory_rows = conn.execute(
                f"SELECT * FROM memories WHERE memory_id IN ({id_marks})",
                list(best_distance),
            ).fetchall()

        rows_by_id = {r["memory_id"]: r for r in memory_rows}
        results: list[MemoryObject] = []
        for mid, distance in best_distance.items():
            row = rows_by_id.get(mid)
            if row is None:
                # Orphan vec0 row: its memories row no longer exists.
                continue
            mem = self._row_to_memory(row)
            # Filter by namespace (vec0 doesn't support this in the query)
            if mem.namespace != namespace:
                continue
            # memories_vec is declared without a distance_metric option, so vec0
            # reports its default (L2) distance. For unit-length embeddings that
            # is in [0, 2]; map it linearly onto a [0, 1] score.
            # NOTE(review): confirm this is comparable to the parent's brute-force score.
            mem.score = max(0.0, min(1.0, 1.0 - distance / 2.0))
            results.append(mem)
            if len(results) >= top_k:
                break

        return results

    # ── Delete ──────────────────────────────────────────────

    def delete_by_id(self, memory_id: str) -> bool:
        """Delete one memory with its vec0 row and pending embedding.

        Returns True if a ``memories`` row was deleted.
        """
        with self._get_connection() as conn:
            if self._vec_loaded:
                row = conn.execute(
                    "SELECT vec_rowid FROM memories WHERE memory_id = ?",
                    (memory_id,),
                ).fetchone()
                if row and row["vec_rowid"] is not None:
                    conn.execute(
                        "DELETE FROM memories_vec WHERE rowid = ?",
                        (row["vec_rowid"],),
                    )

            # memories_vec_pending is a plain table created by _init_schema, so
            # clean it even when the extension did not load on this run.
            conn.execute(
                "DELETE FROM memories_vec_pending WHERE memory_id = ?",
                (memory_id,),
            )
            self._pending_count = None

            cursor = conn.execute("DELETE FROM memories WHERE memory_id = ?", (memory_id,))
            return cursor.rowcount > 0

    def delete_by_user(self, user_id: str) -> int:
        """Delete all of ``user_id``'s memories, vec0 rows and pending embeddings.

        Returns the number of ``memories`` rows deleted.
        """
        with self._get_connection() as conn:
            if self._vec_loaded:
                rows = conn.execute(
                    "SELECT vec_rowid FROM memories WHERE user_id = ? AND vec_rowid IS NOT NULL",
                    (user_id,),
                ).fetchall()
                for r in rows:
                    conn.execute("DELETE FROM memories_vec WHERE rowid = ?", (r[0],))

            # Plain table: safe to clean regardless of extension state.
            conn.execute(
                "DELETE FROM memories_vec_pending WHERE user_id = ?",
                (user_id,),
            )
            self._pending_count = None

            cursor = conn.execute("DELETE FROM memories WHERE user_id = ?", (user_id,))
            return cursor.rowcount

    # ── Update ──────────────────────────────────────────────────

    def update(self, memory: MemoryObject) -> None:
        """Persist changes to *memory*; identical to store() (INSERT OR REPLACE)."""
        self.store(memory)

    # ── Row helpers ──────────────────────────────────────────────

    def _memory_to_row(self, memory: MemoryObject) -> dict[str, Any]:
        """Extend the parent's row dict with ``vec_rowid`` taken from metadata."""
        row = super()._memory_to_row(memory)
        row["vec_rowid"] = memory.metadata.get("_vec_rowid") if memory.metadata else None
        return row

    def _row_to_memory(self, row: sqlite3.Row) -> MemoryObject:
        """Build a MemoryObject via the parent, exposing ``vec_rowid`` as metadata."""
        mem = super()._row_to_memory(row)
        try:
            vec_rowid = row["vec_rowid"]
        except (IndexError, KeyError):  # pragma: no cover
            return mem
        if vec_rowid is not None:
            mem.metadata["_vec_rowid"] = vec_rowid
        return mem

    # ── Utility ──────────────────────────────────────────────

    @classmethod
    def is_vec_available(cls) -> bool:
        """Returns True if sqlite-vec is installed and usable."""
        return _SQLITE_VEC_AVAILABLE