Coverage for src / kemi / adapters / storage / sqlite.py: 92%

391 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

1from __future__ import annotations 

2 

3import json 

4import logging 

5import sqlite3 

6import struct 

7import threading 

8from contextlib import contextmanager 

9from datetime import datetime 

10from typing import TYPE_CHECKING, Any 

11 

12from kemi import scoring 

13from kemi.adapters.base import StorageAdapter 

14from kemi.models import LifecycleState, MemoryObject, MemorySource, MemoryType 

15 

16if TYPE_CHECKING: 

17 from kemi.encryption import EncryptionConfig 

18 

# Logger named after this module, shared by every adapter instance below.
logger: logging.Logger = logging.getLogger(__name__)

20 

21 

class SQLiteStorageAdapter(StorageAdapter):
    """SQLite storage adapter with WAL mode and thread-local connections.

    Embedding stored as BLOB (float32 bytes) for compactness.
    Schema version tracked in schema_version table.

    Thread-safety: Uses thread-local storage for connections, giving each
    thread its own connection. This avoids SQLite's thread-safety issues
    while allowing concurrent access from multiple threads.

    Encryption: When encryption config is provided, uses Fernet field-level
    encryption for content and metadata fields. Pass ``encryption=`` to
    ``__init__`` or set ``KEMI_ENCRYPTION_KEY`` / ``KEMI_ENCRYPTION_ENABLED``
    environment variables. The FTS5 search index is filled from the
    in-memory (plaintext) values, so it is NOT covered by field encryption.
    """

    CURRENT_VERSION = 8

    # Single upsert statement shared by store() and store_many() so both
    # always write the same set of columns (including expires_at).
    _UPSERT_SQL = """
        INSERT OR REPLACE INTO memories
        (memory_id, user_id, content, embedding, embedding_dim, created_at,
         last_accessed_at, source, importance, lifecycle_state, metadata, tags,
         confidence, memory_type, session_id, namespace, version,
         agent_id, run_id, app_id, expires_at)
        VALUES (:memory_id, :user_id, :content, :embedding, :embedding_dim,
                :created_at, :last_accessed_at, :source, :importance,
                :lifecycle_state, :metadata, :tags,
                :confidence, :memory_type, :session_id, :namespace, :version,
                :agent_id, :run_id, :app_id, :expires_at)
    """

    # FTS5 table definition, used both at init and by migration v4.
    _FTS_DDL = """
        CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
            memory_id,
            user_id,
            content,
            namespace,
            session_id,
            tokenize='porter unicode61'
        )
    """

    _FTS_INSERT_SQL = """
        INSERT INTO memories_fts (memory_id, user_id, content, namespace, session_id)
        VALUES (?, ?, ?, ?, ?)
    """

    def __init__(
        self,
        db_path: str = "kemi.db",
        encryption: "EncryptionConfig | None" = None,
    ) -> None:
        """Open (or create) the database at ``db_path`` and migrate its schema.

        Args:
            db_path: Path passed to ``sqlite3.connect``.
            encryption: Explicit encryption config. When ``None``, the config
                is read from the environment via ``EncryptionConfig.from_env``.
        """
        self._db_path = db_path
        self._local = threading.local()
        self._init_schema()
        # Lazy import to avoid circular dependency at module level
        from kemi.encryption import EncryptionConfig, FieldEncryptor

        if encryption is None:
            # Try environment-based config
            try:
                env_config = EncryptionConfig.from_env()
                self._field_encryptor = FieldEncryptor(env_config) if env_config.enabled else None
            except Exception:
                # Best-effort, but never silently: a broken env config means
                # data will be written in plaintext.
                logger.warning(
                    "Could not load encryption config from environment; "
                    "field encryption is disabled for this adapter",
                    exc_info=True,
                )
                self._field_encryptor = None
        else:
            self._field_encryptor = FieldEncryptor(encryption) if encryption.enabled else None

    def _get_connection(self) -> sqlite3.Connection:
        """Get or create a connection for the current thread.

        Each thread gets its own connection to avoid SQLite thread-safety
        issues. Connections are created on-demand and cached per-thread.
        """
        if getattr(self._local, "conn", None) is None:
            conn = sqlite3.connect(self._db_path)
            conn.execute("PRAGMA journal_mode=WAL")
            conn.row_factory = sqlite3.Row
            self._local.conn = conn
            logger.debug("Created new connection for thread %s", threading.current_thread().name)
        return self._local.conn  # type: ignore[no-any-return]

    @contextmanager
    def _transaction(self) -> Any:
        """Context manager for explicit transaction handling.

        Ensures that multiple operations are atomic. If anything escapes the
        ``with`` body (including ``KeyboardInterrupt``), the transaction is
        rolled back and the exception re-raised. Otherwise it commits at exit.

        Usage:
            with adapter._transaction() as conn:
                conn.execute(...)
        """
        conn = self._get_connection()
        # BEGIN sits outside the try: if it fails (e.g. a transaction is
        # already open on this connection) there is nothing of ours to roll
        # back, and rolling back would discard someone else's work.
        conn.execute("BEGIN IMMEDIATE")
        try:
            yield conn
            conn.execute("COMMIT")
        except BaseException:
            # SQLite may already have rolled back on some errors; issuing
            # ROLLBACK then would raise and mask the original exception.
            if conn.in_transaction:
                conn.execute("ROLLBACK")
            logger.exception("Transaction failed, rolled back")
            raise

    def close(self) -> None:
        """Close the connection for the current thread.

        Note: This only closes the connection for the calling thread.
        Other threads will keep their own connections until they also
        call close() or are garbage collected.
        """
        conn = getattr(self._local, "conn", None)
        if conn is not None:
            try:
                conn.close()
            except Exception:  # pragma: no cover
                pass
            self._local.conn = None

    @property
    def _shared_conn(self) -> sqlite3.Connection | None:
        """Backward-compat property for code/tests that expect _shared_conn."""
        return getattr(self._local, "conn", None)

    def __del__(self) -> None:
        # Instances built via __new__ (or a failed __init__) may lack _local.
        if getattr(self, "_local", None) is not None:
            self.close()

    def __enter__(self) -> "SQLiteStorageAdapter":
        return self

    def __exit__(self, *args: object) -> None:
        self.close()

    def _init_schema(self) -> None:
        """Create tables, run migrations, then create indexes.

        Indexes are created AFTER migrations because several of them cover
        columns (tags, namespace, expires_at) that legacy databases only gain
        through the ALTER TABLE statements in ``_run_migrations``.
        """
        with self._get_connection() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS schema_version (
                    version INTEGER PRIMARY KEY,
                    applied_at TEXT NOT NULL DEFAULT (datetime('now'))
                )
            """)

            conn.execute("""
                CREATE TABLE IF NOT EXISTS memories (
                    memory_id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    content TEXT NOT NULL,
                    embedding BLOB,
                    embedding_dim INTEGER,
                    created_at TEXT NOT NULL,
                    last_accessed_at TEXT NOT NULL,
                    source TEXT NOT NULL DEFAULT 'user_stated',
                    importance REAL NOT NULL DEFAULT 0.5,
                    lifecycle_state TEXT NOT NULL DEFAULT 'active',
                    metadata TEXT NOT NULL DEFAULT '{}',
                    tags TEXT NOT NULL DEFAULT '',
                    confidence REAL NOT NULL DEFAULT 1.0,
                    memory_type TEXT NOT NULL DEFAULT 'episodic',
                    session_id TEXT,
                    namespace TEXT NOT NULL DEFAULT 'default',
                    version INTEGER NOT NULL DEFAULT 1,
                    agent_id TEXT,
                    run_id TEXT,
                    app_id TEXT,
                    expires_at TEXT
                )
            """)

            # FTS5 may be missing from the SQLite build; search_by_content and
            # the FTS sync paths already degrade gracefully, so don't fail init.
            try:
                conn.execute(self._FTS_DDL)
            except sqlite3.OperationalError as e:
                logger.warning("FTS5 table could not be created: %s", e)

            self._run_migrations(conn)

            conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories(user_id)")
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_memories_lifecycle ON memories(lifecycle_state)"
            )
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_memories_user_lifecycle "
                "ON memories(user_id, lifecycle_state)"
            )
            conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_tags ON memories(tags)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_namespace ON memories(namespace)")
            # Migration v7 normally creates this already; the guard keeps init
            # safe if the recorded schema version ever disagrees with the table.
            try:
                conn.execute(
                    "CREATE INDEX IF NOT EXISTS idx_memories_expires_at "
                    "ON memories(expires_at)"
                )
            except sqlite3.OperationalError:
                logger.debug("expires_at column missing; skipping idx_memories_expires_at")

    def _get_schema_version(self, conn: sqlite3.Connection) -> int:
        """Return the highest recorded schema version, or 0 if none."""
        try:
            cursor = conn.execute(
                "SELECT version FROM schema_version ORDER BY version DESC LIMIT 1"
            )
            row = cursor.fetchone()
            if row:
                return row[0]  # type: ignore[no-any-return]
            return 0
        except sqlite3.OperationalError:  # pragma: no cover
            return 0

    def _run_migrations(self, conn: sqlite3.Connection) -> None:
        """Apply every migration newer than the recorded schema version.

        ALTER TABLE failures are ignored because fresh databases already have
        the columns from CREATE TABLE.
        """
        current = self._get_schema_version(conn)

        if current >= self.CURRENT_VERSION:
            return

        if current < 2:
            try:
                conn.execute("ALTER TABLE memories ADD COLUMN tags TEXT NOT NULL DEFAULT ''")
            except sqlite3.OperationalError:  # pragma: no cover
                pass
            conn.execute("INSERT OR REPLACE INTO schema_version (version) VALUES (2)")

        if current < 3:
            for col, dtype in [
                ("confidence", "REAL NOT NULL DEFAULT 1.0"),
                ("memory_type", "TEXT NOT NULL DEFAULT 'episodic'"),
                ("session_id", "TEXT"),
                ("namespace", "TEXT NOT NULL DEFAULT 'default'"),
                ("version", "INTEGER NOT NULL DEFAULT 1"),
                ("agent_id", "TEXT"),
                ("run_id", "TEXT"),
                ("app_id", "TEXT"),
            ]:
                try:
                    conn.execute(f"ALTER TABLE memories ADD COLUMN {col} {dtype}")
                except sqlite3.OperationalError:
                    pass
            conn.execute("INSERT OR REPLACE INTO schema_version (version) VALUES (3)")

        if current < 4:
            # Ensure FTS5 table exists for BM25 search
            try:
                conn.execute(self._FTS_DDL)
            except sqlite3.OperationalError:
                pass  # FTS5 unavailable; search falls back to Python BM25
            conn.execute("INSERT OR REPLACE INTO schema_version (version) VALUES (4)")

        if current < 6:
            for col, dtype in [
                ("agent_id", "TEXT"),
                ("run_id", "TEXT"),
                ("app_id", "TEXT"),
            ]:
                try:
                    conn.execute(f"ALTER TABLE memories ADD COLUMN {col} {dtype}")
                except sqlite3.OperationalError:
                    pass
            conn.execute("INSERT OR REPLACE INTO schema_version (version) VALUES (6)")

        if current < 7:
            # TTL: add expires_at column and index for fast sweeper queries
            try:
                conn.execute("ALTER TABLE memories ADD COLUMN expires_at TEXT")
            except sqlite3.OperationalError:
                pass
            try:
                conn.execute(
                    "CREATE INDEX IF NOT EXISTS idx_memories_expires_at "
                    "ON memories(expires_at)"
                )
            except sqlite3.OperationalError:
                pass
            conn.execute("INSERT OR REPLACE INTO schema_version (version) VALUES (7)")

        if current < 8:
            # API key authentication table for multi-tenant FastAPI server
            conn.execute("""
                CREATE TABLE IF NOT EXISTS api_keys (
                    key_id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    hashed_key TEXT NOT NULL UNIQUE,
                    name TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    expires_at TEXT,
                    last_used_at TEXT,
                    revoked_at TEXT
                )
            """)
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_api_keys_user_id ON api_keys(user_id)"
            )
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_api_keys_hashed_key ON api_keys(hashed_key)"
            )
            conn.execute("INSERT OR REPLACE INTO schema_version (version) VALUES (8)")

    # ------------------------------------------------------------------
    # Row <-> MemoryObject conversion
    # ------------------------------------------------------------------

    @staticmethod
    def _encode_tags(tags: list[str]) -> str:
        """Join tags with commas, escaping commas inside a tag as backslash-comma."""
        return ",".join(t.replace(",", "\\,") for t in tags) if tags else ""

    @staticmethod
    def _decode_tags(raw: str | None) -> list[str]:
        """Inverse of ``_encode_tags``: split on unescaped commas only.

        A plain ``split(",")`` would cut an escaped comma in half, so tags
        that contained commas could never round-trip.
        """
        # NOTE(review): backslashes themselves are not escaped by the encoder,
        # so a tag that ENDS in a backslash is ambiguous with an escaped comma.
        if not raw:
            return []
        tags: list[str] = []
        current: list[str] = []
        i = 0
        while i < len(raw):
            ch = raw[i]
            if ch == "\\" and i + 1 < len(raw) and raw[i + 1] == ",":
                current.append(",")
                i += 2
                continue
            if ch == ",":
                tags.append("".join(current))
                current = []
            else:
                current.append(ch)
            i += 1
        tags.append("".join(current))
        return tags

    def _decode_content(self, content_str: str | None) -> str:
        """Turn the stored ``content`` column back into plaintext.

        Current rows hold ``json.dumps(content)`` (or of an encryption
        envelope). Legacy rows may hold raw text; if such text happens to be
        valid JSON that is not a string (e.g. ``true``, ``42``, ``null``) the
        raw text is returned unchanged instead of a re-stringified value.
        """
        try:
            decoded: Any = json.loads(content_str)  # type: ignore[arg-type]
        except (json.JSONDecodeError, TypeError):
            decoded = content_str  # fallback for plain text without JSON wrapping

        if self._field_encryptor is not None and self._field_encryptor._is_encrypted(decoded):
            return str(self._field_encryptor.decrypt_field("content", decoded))
        if isinstance(decoded, str):
            return decoded
        return content_str if content_str is not None else ""

    def _decode_user_id(self, stored: str) -> str:
        """Return the plaintext user_id for a stored ``user_id`` column value.

        Only rows written with user_id encryption hold a JSON envelope; all
        other values are returned as-is.
        """
        if self._field_encryptor is None or not self._field_encryptor._encrypt_user_id:
            return stored
        try:
            uid_parsed = json.loads(stored)
            if self._field_encryptor._is_encrypted(uid_parsed):
                return self._field_encryptor.decrypt_field("user_id", uid_parsed)  # type: ignore[no-any-return]
        except (json.JSONDecodeError, TypeError):
            pass
        return stored

    def _row_to_memory(self, row: sqlite3.Row) -> MemoryObject:
        """Build a MemoryObject from a ``memories`` row, decrypting as needed."""
        embedding = None
        if row["embedding"] is not None:
            num_floats = len(row["embedding"]) // 4
            embedding = list(struct.unpack(f"{num_floats}f", row["embedding"]))

        # Use dict-style get for columns that may not exist in older schemas
        row_dict = dict(row)

        content_val = self._decode_content(row["content"])

        try:
            metadata_val = json.loads(row["metadata"])
        except (json.JSONDecodeError, TypeError):
            metadata_val = {}
        if self._field_encryptor is not None and self._field_encryptor._is_encrypted(metadata_val):
            metadata_val = self._field_encryptor.decrypt_field("metadata", metadata_val)

        expires_at = None
        if row_dict.get("expires_at"):
            try:
                expires_at = datetime.fromisoformat(row_dict["expires_at"])
            except (TypeError, ValueError):
                expires_at = None

        return MemoryObject(
            memory_id=row["memory_id"],
            user_id=self._decode_user_id(row["user_id"]),
            content=content_val,
            embedding=embedding,
            score=0.0,
            created_at=datetime.fromisoformat(row["created_at"]),
            last_accessed_at=datetime.fromisoformat(row["last_accessed_at"]),
            source=MemorySource(row["source"]),
            importance=row["importance"],
            lifecycle_state=LifecycleState(row["lifecycle_state"]),
            metadata=metadata_val if isinstance(metadata_val, dict) else {},
            embedding_dim=row["embedding_dim"],
            tags=self._decode_tags(row["tags"]),
            confidence=row["confidence"],
            memory_type=MemoryType(row["memory_type"]),
            session_id=row_dict.get("session_id"),
            namespace=row_dict.get("namespace", "default"),
            version=row_dict.get("version", 1),
            agent_id=row_dict.get("agent_id"),
            run_id=row_dict.get("run_id"),
            app_id=row_dict.get("app_id"),
            expires_at=expires_at,
        )

    def _memory_to_row(self, memory: MemoryObject) -> dict[str, Any]:
        """Serialise a MemoryObject into the named parameters of ``_UPSERT_SQL``.

        ``content`` and ``metadata`` are always stored as JSON text, either of
        the plain value or of the encryption envelope.
        """
        embedding_blob = None
        if memory.embedding is not None:
            embedding_blob = struct.pack(f"{len(memory.embedding)}f", *memory.embedding)

        # Start with content and metadata as their native types so the
        # field encryptor can process them before JSON serialization.
        content_val: Any = memory.content
        metadata_val: Any = memory.metadata
        user_id_val: Any = memory.user_id

        if self._field_encryptor is not None:
            content_val = self._field_encryptor.encrypt_field("content", memory.content)
            metadata_val = self._field_encryptor.encrypt_field("metadata", memory.metadata)
            if self._field_encryptor._encrypt_user_id:
                # NOTE(review): if the encrypted value is non-deterministic,
                # "WHERE user_id = ?" lookups cannot match it — confirm.
                user_id_val = json.dumps(
                    self._field_encryptor.encrypt_field("user_id", memory.user_id)
                )

        return {
            "memory_id": memory.memory_id,
            "user_id": user_id_val,
            "content": json.dumps(content_val),
            "embedding": embedding_blob,
            "embedding_dim": memory.embedding_dim,
            "created_at": memory.created_at.isoformat(),
            "last_accessed_at": memory.last_accessed_at.isoformat(),
            "source": memory.source.value,
            "importance": memory.importance,
            "lifecycle_state": memory.lifecycle_state.value,
            "metadata": json.dumps(metadata_val),
            "tags": self._encode_tags(memory.tags),
            "confidence": memory.confidence,
            "memory_type": memory.memory_type.value,
            "session_id": memory.session_id,
            "namespace": memory.namespace,
            "version": memory.version,
            "agent_id": memory.agent_id,
            "run_id": memory.run_id,
            "app_id": memory.app_id,
            "expires_at": memory.expires_at.isoformat() if memory.expires_at else None,
        }

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def store(self, memory: MemoryObject) -> None:
        """Insert or replace one memory and refresh its FTS5 entry."""
        with self._get_connection() as conn:
            conn.execute(self._UPSERT_SQL, self._memory_to_row(memory))
            # Sync to FTS5 index for BM25 search
            self._sync_fts_single(conn, memory)

    def store_many(self, memories: list[MemoryObject]) -> int:
        """Store multiple memories in a single atomic transaction.

        All stores are wrapped in a transaction for atomicity. If any store
        fails, all changes are rolled back.

        Args:
            memories: List of MemoryObjects to store.

        Returns:
            Number of memories stored.
        """
        if not memories:
            return 0

        with self._transaction() as conn:
            for memory in memories:
                # Same statement as store(), so expires_at is persisted too.
                conn.execute(self._UPSERT_SQL, self._memory_to_row(memory))
                # Sync to FTS5 index
                self._sync_fts_single(conn, memory)
        return len(memories)

    def _insert_fts_row(self, conn: sqlite3.Connection, memory: MemoryObject) -> None:
        """Insert one plaintext FTS5 row for ``memory`` (no de-duplication)."""
        # NOTE(review): this indexes plaintext content/user_id even when field
        # encryption is enabled for the memories table — confirm acceptable.
        conn.execute(
            self._FTS_INSERT_SQL,
            (
                memory.memory_id,
                memory.user_id,
                memory.content,
                memory.namespace,
                memory.session_id,
            ),
        )

    def _sync_fts_single(self, conn: sqlite3.Connection, memory: MemoryObject) -> None:
        """Sync a single memory to FTS5 index.

        Deletes any existing FTS row for this memory_id first, then inserts
        the new one. INSERT OR REPLACE alone is not sufficient because
        memories_fts has no UNIQUE constraint on memory_id (FTS5 columns
        cannot carry constraints), so every store() would otherwise append a
        duplicate row. FTS errors are logged, not raised.
        """
        try:
            conn.execute(
                "DELETE FROM memories_fts WHERE memory_id = ?",
                (memory.memory_id,),
            )
            self._insert_fts_row(conn, memory)
        except sqlite3.OperationalError as e:
            logger.warning("FTS5 sync failed: %s", e)

    def rebuild_fts_index(self, user_id: str | None = None) -> int:
        """Rebuild the FTS5 index from memories table.

        Use this if the FTS index gets out of sync with the memories table.
        Stored columns are decoded (JSON-unwrapped and decrypted) first, so
        the rebuilt rows match what ``_sync_fts_single`` writes.

        Args:
            user_id: If provided, only reindex memories for this user. Otherwise
                the full index is rebuilt.

        Returns:
            Number of memories indexed.
        """
        select_sql = """
            SELECT memory_id, user_id, content, namespace, session_id
            FROM memories
        """
        with self._get_connection() as conn:
            if user_id is None:
                # Clear and rebuild the full FTS index
                conn.execute("DELETE FROM memories_fts")
                cursor = conn.execute(select_sql)
            else:
                # Rebuild for a specific user: delete their existing FTS rows,
                # then re-insert them. The rest of the index is untouched.
                conn.execute("DELETE FROM memories_fts WHERE user_id = ?", (user_id,))
                cursor = conn.execute(select_sql + " WHERE user_id = ?", (user_id,))
            rows = cursor.fetchall()

            for row in rows:
                conn.execute(
                    self._FTS_INSERT_SQL,
                    (
                        row["memory_id"],
                        self._decode_user_id(row["user_id"]),
                        self._decode_content(row["content"]),
                        row["namespace"],
                        row["session_id"],
                    ),
                )

        return len(rows)

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def search(
        self,
        user_id: str,
        query_embedding: list[float],
        top_k: int = 10,
        lifecycle_filter: list[LifecycleState] | None = None,
        namespace: str = "default",
        session_id: str | None = None,
    ) -> list[MemoryObject]:
        """Rank the user's memories by cosine similarity to ``query_embedding``.

        Scores are mapped from [-1, 1] to [0, 1]; memories without an
        embedding keep score 0.0. Returns at most ``top_k`` results.
        """
        if lifecycle_filter is None:
            lifecycle_filter = [LifecycleState.ACTIVE, LifecycleState.DECAYING]

        states = [s.value for s in lifecycle_filter]
        params: list[Any] = [user_id, namespace, *states]

        sql = """
            SELECT * FROM memories
            WHERE user_id = ? AND namespace = ? AND lifecycle_state IN ({})
        """.format(",".join("?" * len(states)))

        if session_id is not None:
            sql += " AND (session_id = ? OR session_id IS NULL)"
            params.append(session_id)

        with self._get_connection() as conn:
            rows = conn.execute(sql, params).fetchall()

        memories = []
        for row in rows:
            memory = self._row_to_memory(row)
            if memory.embedding is not None:
                similarity = scoring.cosine_similarity(memory.embedding, query_embedding)
                memory.score = (similarity + 1.0) / 2.0
            memories.append(memory)

        memories.sort(key=lambda m: m.score, reverse=True)
        return memories[:top_k]

    def get(self, memory_id: str) -> MemoryObject | None:
        """Return the memory with ``memory_id``, or ``None`` if absent."""
        with self._get_connection() as conn:
            row = conn.execute(
                "SELECT * FROM memories WHERE memory_id = ?", (memory_id,)
            ).fetchone()
        return self._row_to_memory(row) if row else None

    def update(self, memory: MemoryObject) -> None:
        """Update is an upsert; delegates to ``store``."""
        self.store(memory)

    def delete_by_user(self, user_id: str) -> int:
        """Delete all of a user's memories; return the number of rows removed."""
        with self._get_connection() as conn:
            # First delete from FTS5 index
            try:
                conn.execute("DELETE FROM memories_fts WHERE user_id = ?", (user_id,))
            except sqlite3.OperationalError:
                pass  # FTS table might not exist

            cursor = conn.execute("DELETE FROM memories WHERE user_id = ?", (user_id,))
            return cursor.rowcount

    def delete_by_id(self, memory_id: str) -> bool:
        """Delete one memory; return True if a row was removed."""
        with self._get_connection() as conn:
            cursor = conn.execute("DELETE FROM memories WHERE memory_id = ?", (memory_id,))
            # Also delete from FTS5 index to keep it in sync
            try:
                conn.execute("DELETE FROM memories_fts WHERE memory_id = ?", (memory_id,))
            except sqlite3.OperationalError:
                pass  # FTS table might not exist in old databases
            return cursor.rowcount > 0

    def get_all_by_user(
        self,
        user_id: str,
        lifecycle_filter: list[LifecycleState] | None = None,
        namespace: str = "default",
        session_id: str | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[MemoryObject]:
        """Return a user's memories, newest first when paginated."""
        if lifecycle_filter is None:
            lifecycle_filter = [LifecycleState.ACTIVE, LifecycleState.DECAYING]

        states = [s.value for s in lifecycle_filter]
        params: list[Any] = [user_id, namespace, *states]

        sql = """
            SELECT * FROM memories
            WHERE user_id = ? AND namespace = ? AND lifecycle_state IN ({})
        """.format(",".join("?" * len(states)))

        if session_id is not None:
            sql += " AND (session_id = ? OR session_id IS NULL)"
            params.append(session_id)

        if offset is not None:
            sql += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
            # Use -1 as "no limit" when limit is None but offset is provided
            params.extend([limit if limit is not None else -1, offset])
        elif limit is not None:
            sql += " ORDER BY created_at DESC LIMIT ?"
            params.append(limit)

        with self._get_connection() as conn:
            rows = conn.execute(sql, params).fetchall()

        return [self._row_to_memory(row) for row in rows]

    def count(self, user_id: str) -> int:
        """Count all of a user's memories (every namespace and lifecycle state)."""
        with self._get_connection() as conn:
            cursor = conn.execute("SELECT COUNT(*) FROM memories WHERE user_id = ?", (user_id,))
            return cursor.fetchone()[0]  # type: ignore[no-any-return]

    def get_all(
        self,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[MemoryObject]:
        """Return every memory in the database, newest first when paginated."""
        sql = "SELECT * FROM memories"
        params: list[Any] = []

        if offset is not None:
            sql += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
            # Use -1 as "no limit" when limit is None but offset is provided
            params.extend([limit if limit is not None else -1, offset])
        elif limit is not None:
            sql += " ORDER BY created_at DESC LIMIT ?"
            params.append(limit)

        with self._get_connection() as conn:
            rows = conn.execute(sql, params).fetchall()
        return [self._row_to_memory(row) for row in rows]

    def search_by_content(
        self,
        user_id: str,
        query: str,
        top_k: int = 10,
        lifecycle_filter: list[LifecycleState] | None = None,
        namespace: str = "default",
        session_id: str | None = None,
    ) -> list[MemoryObject]:
        """Search for memories using FTS5 full-text search with native BM25 ranking.

        This method uses SQLite's FTS5 for fast full-text search with BM25 scoring.
        This is much faster than Python-based BM25 scoring for large datasets.

        Falls back to Python-based scoring if FTS5 query fails.
        """
        try:
            return self._fts5_search(user_id, query, top_k, lifecycle_filter, namespace, session_id)
        except sqlite3.OperationalError as e:
            logger.warning("FTS5 search failed, falling back to Python BM25: %s", e)
            return self._bm25_python_fallback(
                user_id, query, top_k, lifecycle_filter, namespace, session_id
            )

    def _fts5_search(
        self,
        user_id: str,
        query: str,
        top_k: int,
        lifecycle_filter: list[LifecycleState] | None,
        namespace: str,
        session_id: str | None,
    ) -> list[MemoryObject]:
        """Native FTS5 BM25 search - much faster than Python scoring.

        Results are ordered by ``bm25()`` (lower is better) and given
        rank-based scores: 1st = 1.0, 2nd = 0.5, 3rd = 1/3, ...
        """
        if lifecycle_filter is None:
            lifecycle_filter = [LifecycleState.ACTIVE, LifecycleState.DECAYING]

        states = [s.value for s in lifecycle_filter]
        lifecycle_placeholders = ",".join("?" * len(states))

        # An FTS5 column filter applies only to the phrase right after it, so
        # 'content:"a"* OR "b"*' would let "b" match memory_id/user_id/
        # namespace/session_id. Parenthesise to scope every term to content.
        prepared = self._prepare_fts_query(query)
        fts_query = f"content:{prepared}" if prepared == '""' else f"content:({prepared})"

        sql = f"""
            SELECT m.*, bm25(memories_fts) as fts_score
            FROM memories m
            INNER JOIN memories_fts fts ON m.memory_id = fts.memory_id
            WHERE fts.user_id = ?
              AND fts.namespace = ?
              AND m.lifecycle_state IN ({lifecycle_placeholders})
              AND memories_fts MATCH ?
        """

        # Parameters must follow placeholder order exactly:
        # user_id, namespace, *states, MATCH, [session_id], LIMIT.
        params: list[Any] = [user_id, namespace, *states, fts_query]

        if session_id is not None:
            sql += " AND (m.session_id = ? OR m.session_id IS NULL)"
            params.append(session_id)

        sql += " ORDER BY fts_score LIMIT ?"
        params.append(top_k)

        with self._get_connection() as conn:
            rows = conn.execute(sql, params).fetchall()

        memories = []
        for rank, row in enumerate(rows, start=1):
            memory = self._row_to_memory(row)
            memory.score = 1.0 / rank  # rank-based normalisation
            memories.append(memory)
        return memories

    def _prepare_fts_query(self, query: str) -> str:
        """Prepare query string for FTS5 MATCH on content column.

        FTS5 syntax characters are replaced with spaces, and each remaining
        whitespace-separated term becomes a quoted prefix phrase; multiple
        terms are OR'd. Empty input yields the empty phrase ``""``.
        """
        if not query:
            return '""'

        escaped = query
        for char in '"*():~':
            escaped = escaped.replace(char, " ")

        terms = escaped.split()
        if not terms:
            return '""'
        return " OR ".join(f'"{term}"*' for term in terms)

    def _bm25_python_fallback(
        self,
        user_id: str,
        query: str,
        top_k: int,
        lifecycle_filter: list[LifecycleState] | None,
        namespace: str,
        session_id: str | None,
    ) -> list[MemoryObject]:
        """Python-based BM25 fallback when FTS5 is unavailable.

        Only the newest ``min(top_k * 20, 200)`` candidate rows are scored.
        """
        if lifecycle_filter is None:
            lifecycle_filter = [LifecycleState.ACTIVE, LifecycleState.DECAYING]

        # Limit fetch to prevent memory issues with large datasets
        fetch_limit = min(top_k * 20, 200)

        states = [s.value for s in lifecycle_filter]
        params: list[Any] = [user_id, namespace, *states]

        sql = """
            SELECT * FROM memories
            WHERE user_id = ? AND namespace = ? AND lifecycle_state IN ({})
        """.format(",".join("?" * len(states)))

        if session_id is not None:
            sql += " AND (session_id = ? OR session_id IS NULL)"
            params.append(session_id)

        sql += " ORDER BY created_at DESC LIMIT ?"
        params.append(fetch_limit)

        with self._get_connection() as conn:
            rows = conn.execute(sql, params).fetchall()

        memories = [self._row_to_memory(row) for row in rows]
        if not memories:
            return []

        corpus = [m.content for m in memories]
        for memory in memories:
            memory.score = scoring.bm25_score_corpus(query, memory.content, corpus)

        memories.sort(key=lambda m: m.score, reverse=True)
        return memories[:top_k]

    def get_all_users(self) -> list[str]:
        """Return the distinct stored ``user_id`` column values."""
        with self._get_connection() as conn:
            rows = conn.execute("SELECT DISTINCT user_id FROM memories").fetchall()
        return [row[0] for row in rows]

    def upgrade_schema(self, from_version: int, to_version: int) -> None:
        """Run pending migrations.

        ``from_version`` and ``to_version`` are accepted for interface
        compatibility but ignored: migrations always run from the recorded
        schema version up to ``CURRENT_VERSION``.
        """
        with self._get_connection() as conn:
            self._run_migrations(conn)

    def get_by_tag(
        self,
        user_id: str,
        tag: str,
        lifecycle_filter: list[LifecycleState] | None = None,
        namespace: str = "default",
    ) -> list[MemoryObject]:
        """Return the user's memories carrying ``tag`` (ASCII case-insensitive).

        SQL LIKE narrows the candidates; an exact per-tag comparison on the
        decoded tags then removes LIKE false positives (e.g. "b" inside a
        stored tag "a,b").
        """
        if lifecycle_filter is None:
            lifecycle_filter = [LifecycleState.ACTIVE, LifecycleState.DECAYING]

        states = [s.value for s in lifecycle_filter]
        placeholders = ",".join("?" * len(states))

        # Encode the tag exactly as _memory_to_row stores it (commas escaped),
        # then escape LIKE wildcards so "a_b" / "a%b" match literally.
        stored_form = tag.replace(",", "\\,")
        like_tag = stored_form.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

        with self._get_connection() as conn:
            rows = conn.execute(
                f"""
                SELECT * FROM memories
                WHERE user_id = ? AND namespace = ? AND lifecycle_state IN ({placeholders})
                AND (',' || tags || ',') LIKE ('%,' || ? || ',%') ESCAPE '\\'
                """,
                [user_id, namespace, *states, like_tag],
            ).fetchall()

        # LIKE is case-insensitive for ASCII; keep that, but require a whole-tag match.
        wanted = tag.lower()
        memories = [self._row_to_memory(row) for row in rows]
        return [m for m in memories if any(t.lower() == wanted for t in m.tags)]

    def get_api_key_manager(self) -> Any:
        """Return an APIKeyManager bound to this thread's connection.

        Lazy import to avoid a hard dependency on the api_keys module.
        Returns a fresh manager instance each call. The manager wraps the
        calling thread's cached connection, so it should only be used from
        the thread that obtained it.
        """
        from kemi.api_keys import APIKeyManager

        return APIKeyManager(connection=self._get_connection())