Coverage for src / kemi / adapters / storage / sqlite.py: 92%
391 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
1from __future__ import annotations
3import json
4import logging
5import sqlite3
6import struct
7import threading
8from contextlib import contextmanager
9from datetime import datetime
10from typing import TYPE_CHECKING, Any
12from kemi import scoring
13from kemi.adapters.base import StorageAdapter
14from kemi.models import LifecycleState, MemoryObject, MemorySource, MemoryType
16if TYPE_CHECKING:
17 from kemi.encryption import EncryptionConfig
19logger = logging.getLogger(__name__)
22class SQLiteStorageAdapter(StorageAdapter):
23 """SQLite storage adapter with WAL mode and thread-local connections.
25 Embedding stored as BLOB (float32 bytes) for compactness.
26 Schema version tracked in schema_version table.
28 Thread-safety: Uses thread-local storage for connections, giving each
29 thread its own connection. This avoids SQLite's thread-safety issues
30 while allowing concurrent access from multiple threads.
32 Encryption: When encryption config is provided, uses Fernet field-level
33 encryption for content and metadata fields. Pass ``encryption=`` to
34 ``__init__`` or set ``KEMI_ENCRYPTION_KEY`` / ``KEMI_ENCRYPTION_ENABLED``
35 environment variables.
36 """
38 CURRENT_VERSION = 8
def __init__(
    self,
    db_path: str = "kemi.db",
    encryption: "EncryptionConfig | None" = None,
) -> None:
    """Create the adapter, initialise the schema and set up field encryption.

    Args:
        db_path: Path passed to ``sqlite3.connect`` for every per-thread
            connection.
        encryption: Explicit encryption configuration. When ``None``, the
            configuration is read with ``EncryptionConfig.from_env()``.
    """
    self._db_path = db_path
    self._local = threading.local()
    self._init_schema()
    # Lazy import to avoid circular dependency at module level
    from kemi.encryption import EncryptionConfig, FieldEncryptor

    if encryption is None:
        # Try environment-based config
        try:
            env_config = EncryptionConfig.from_env()
            self._field_encryptor = FieldEncryptor(env_config) if env_config.enabled else None
        except Exception as exc:
            # Still best-effort (the adapter keeps working), but no longer
            # silent: an operator who asked for encryption through the
            # environment must be able to see that it is NOT active. Only
            # the exception type is logged so key material cannot leak.
            logger.warning(
                "Could not set up encryption from environment (%s); "
                "content and metadata will be stored unencrypted",
                type(exc).__name__,
            )
            self._field_encryptor = None
    else:
        self._field_encryptor = FieldEncryptor(encryption) if encryption.enabled else None
61 def _get_connection(self) -> sqlite3.Connection:
62 """Get or create a connection for the current thread.
64 Each thread gets its own connection to avoid SQLite thread-safety
65 issues. Connections are created on-demand and cached per-thread.
66 """
67 if not hasattr(self._local, "conn") or self._local.conn is None:
68 conn = sqlite3.connect(self._db_path)
69 conn.execute("PRAGMA journal_mode=WAL")
70 conn.row_factory = sqlite3.Row
71 self._local.conn = conn
72 logger.debug("Created new connection for thread %s", threading.current_thread().name)
73 return self._local.conn # type: ignore[no-any-return]
75 @contextmanager
76 def _transaction(self) -> Any:
77 """Context manager for explicit transaction handling.
79 Ensures that multiple operations are atomic. If an exception occurs
80 within the context, the transaction is rolled back. Otherwise, it
81 commits at context exit.
83 Usage:
84 with adapter._transaction() as conn:
85 conn.execute(...)
86 """
87 conn = self._get_connection()
88 try:
89 conn.execute("BEGIN IMMEDIATE")
90 yield conn
91 conn.execute("COMMIT")
92 except Exception:
93 conn.execute("ROLLBACK")
94 logger.exception("Transaction failed, rolled back")
95 raise
def close(self) -> None:
    """Close the calling thread's connection, if it has one.

    Only the connection cached for the current thread is affected;
    connections cached by other threads are left alone. Errors raised
    while closing are ignored, and the cached slot is reset to ``None``.
    """
    conn = getattr(self._local, "conn", None)
    if conn is None:
        return
    try:
        conn.close()
    except Exception:  # pragma: no cover
        pass
    self._local.conn = None
111 @property
112 def _shared_conn(self) -> sqlite3.Connection | None:
113 """Backward-compat property for code/tests that expect _shared_conn."""
114 return getattr(self._local, "conn", None)
116 def __del__(self) -> None:
117 self.close()
119 def __enter__(self) -> "SQLiteStorageAdapter":
120 return self
122 def __exit__(self, *args: object) -> None:
123 self.close()
125 def _init_schema(self) -> None:
126 with self._get_connection() as conn:
127 conn.execute("""
128 CREATE TABLE IF NOT EXISTS schema_version (
129 version INTEGER PRIMARY KEY,
130 applied_at TEXT NOT NULL DEFAULT (datetime('now'))
131 )
132 """)
134 conn.execute("""
135 CREATE TABLE IF NOT EXISTS memories (
136 memory_id TEXT PRIMARY KEY,
137 user_id TEXT NOT NULL,
138 content TEXT NOT NULL,
139 embedding BLOB,
140 embedding_dim INTEGER,
141 created_at TEXT NOT NULL,
142 last_accessed_at TEXT NOT NULL,
143 source TEXT NOT NULL DEFAULT 'user_stated',
144 importance REAL NOT NULL DEFAULT 0.5,
145 lifecycle_state TEXT NOT NULL DEFAULT 'active',
146 metadata TEXT NOT NULL DEFAULT '{}',
147 tags TEXT NOT NULL DEFAULT '',
148 confidence REAL NOT NULL DEFAULT 1.0,
149 memory_type TEXT NOT NULL DEFAULT 'episodic',
150 session_id TEXT,
151 namespace TEXT NOT NULL DEFAULT 'default',
152 version INTEGER NOT NULL DEFAULT 1,
153 agent_id TEXT,
154 run_id TEXT,
155 app_id TEXT,
156 expires_at TEXT
157 )
158 """)
160 conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories(user_id)")
161 conn.execute(
162 "CREATE INDEX IF NOT EXISTS idx_memories_lifecycle ON memories(lifecycle_state)"
163 )
164 conn.execute(
165 "CREATE INDEX IF NOT EXISTS idx_memories_user_lifecycle "
166 "ON memories(user_id, lifecycle_state)"
167 )
168 conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_tags ON memories(tags)")
169 conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_namespace ON memories(namespace)")
170 # idx_memories_expires_at is created by migration v7 below, AFTER
171 # the column is added. The CREATE TABLE above includes the column
172 # for fresh DBs only; legacy DBs need the ALTER TABLE first.
173 # A redundant try/except here keeps init safe if migration order
174 # ever changes.
175 try:
176 conn.execute(
177 "CREATE INDEX IF NOT EXISTS idx_memories_expires_at "
178 "ON memories(expires_at)"
179 )
180 except sqlite3.OperationalError:
181 logger.debug(
182 "expires_at column not yet present; index will be created "
183 "by migration v7"
184 )
186 # Create FTS5 virtual table for full-text search with BM25 ranking
187 # Using simple FTS5 without content_rowid linkage since we manage sync manually
188 conn.execute("""
189 CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
190 memory_id,
191 user_id,
192 content,
193 namespace,
194 session_id,
195 tokenize='porter unicode61'
196 )
197 """)
199 self._run_migrations(conn)
201 def _get_schema_version(self, conn: sqlite3.Connection) -> int:
202 try:
203 cursor = conn.execute(
204 "SELECT version FROM schema_version ORDER BY version DESC LIMIT 1"
205 )
206 row = cursor.fetchone()
207 if row:
208 return row[0] # type: ignore[no-any-return]
209 return 0
210 except sqlite3.OperationalError: # pragma: no cover
211 return 0
213 def _run_migrations(self, conn: sqlite3.Connection) -> None:
214 current = self._get_schema_version(conn)
216 if current >= self.CURRENT_VERSION:
217 return
219 if current < 2:
220 try:
221 conn.execute("ALTER TABLE memories ADD COLUMN tags TEXT NOT NULL DEFAULT ''")
222 except sqlite3.OperationalError: # pragma: no cover
223 pass
224 conn.execute("INSERT OR REPLACE INTO schema_version (version) VALUES (2)")
226 if current < 3:
227 for col, dtype in [
228 ("confidence", "REAL NOT NULL DEFAULT 1.0"),
229 ("memory_type", "TEXT NOT NULL DEFAULT 'episodic'"),
230 ("session_id", "TEXT"),
231 ("namespace", "TEXT NOT NULL DEFAULT 'default'"),
232 ("version", "INTEGER NOT NULL DEFAULT 1"),
233 ("agent_id", "TEXT"),
234 ("run_id", "TEXT"),
235 ("app_id", "TEXT"),
236 ]:
237 try:
238 conn.execute(f"ALTER TABLE memories ADD COLUMN {col} {dtype}")
239 except sqlite3.OperationalError:
240 pass
241 conn.execute("INSERT OR REPLACE INTO schema_version (version) VALUES (3)")
243 if current < 4:
244 # Ensure FTS5 table exists for BM25 search
245 try:
246 conn.execute("""
247 CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
248 memory_id,
249 user_id,
250 content,
251 namespace,
252 session_id,
253 tokenize='porter unicode61'
254 )
255 """)
256 except sqlite3.OperationalError:
257 pass # Table already exists
258 conn.execute("INSERT OR REPLACE INTO schema_version (version) VALUES (4)")
260 if current < 6:
261 for col, dtype in [
262 ("agent_id", "TEXT"),
263 ("run_id", "TEXT"),
264 ("app_id", "TEXT"),
265 ]:
266 try:
267 conn.execute(f"ALTER TABLE memories ADD COLUMN {col} {dtype}")
268 except sqlite3.OperationalError:
269 pass
270 conn.execute("INSERT OR REPLACE INTO schema_version (version) VALUES (6)")
272 if current < 7:
273 # TTL: add expires_at column and index for fast sweeper queries
274 try:
275 conn.execute("ALTER TABLE memories ADD COLUMN expires_at TEXT")
276 except sqlite3.OperationalError:
277 pass
278 try:
279 conn.execute(
280 "CREATE INDEX IF NOT EXISTS idx_memories_expires_at "
281 "ON memories(expires_at)"
282 )
283 except sqlite3.OperationalError:
284 pass
285 conn.execute("INSERT OR REPLACE INTO schema_version (version) VALUES (7)")
287 if current < 8:
288 # API key authentication table for multi-tenant FastAPI server
289 conn.execute("""
290 CREATE TABLE IF NOT EXISTS api_keys (
291 key_id TEXT PRIMARY KEY,
292 user_id TEXT NOT NULL,
293 hashed_key TEXT NOT NULL UNIQUE,
294 name TEXT NOT NULL,
295 created_at TEXT NOT NULL,
296 expires_at TEXT,
297 last_used_at TEXT,
298 revoked_at TEXT
299 )
300 """)
301 conn.execute(
302 "CREATE INDEX IF NOT EXISTS idx_api_keys_user_id ON api_keys(user_id)"
303 )
304 conn.execute(
305 "CREATE INDEX IF NOT EXISTS idx_api_keys_hashed_key ON api_keys(hashed_key)"
306 )
307 conn.execute("INSERT OR REPLACE INTO schema_version (version) VALUES (8)")
def _row_to_memory(self, row: sqlite3.Row) -> MemoryObject:
    """Convert one ``memories`` row into a ``MemoryObject``.

    The embedding BLOB is unpacked as native float32 values; content and
    metadata are JSON-decoded and, when a field encryptor is configured,
    decrypted if they are encrypted envelopes. The ``score`` of the
    returned object is always 0.0.

    Args:
        row: A row selected from ``memories``. Optional columns
            (``session_id``, ``namespace``, ``version``, ``agent_id``,
            ``run_id``, ``app_id``, ``expires_at``) may be absent.

    Returns:
        The reconstructed ``MemoryObject``.
    """
    import re  # local import: only needed for tag parsing below

    embedding = None
    if row["embedding"] is not None:
        num_floats = len(row["embedding"]) // 4
        embedding = list(struct.unpack(f"{num_floats}f", row["embedding"]))

    # Use dict-style get for columns that may not exist in older schemas
    row_dict = dict(row)

    # Both are stored as JSON strings (either plain values or encrypted envelopes).
    content_str: str = row["content"]
    metadata_str: str = row["metadata"]

    try:
        content_val = json.loads(content_str)
    except (json.JSONDecodeError, TypeError):
        content_val = content_str  # fallback for plain text without JSON wrapping

    try:
        metadata_val = json.loads(metadata_str)
    except (json.JSONDecodeError, TypeError):
        metadata_val = {}

    encryptor = self._field_encryptor
    if encryptor is not None:
        if encryptor._is_encrypted(content_val):
            content_val = encryptor.decrypt_field("content", content_val)
        if encryptor._is_encrypted(metadata_val):
            metadata_val = encryptor.decrypt_field("metadata", metadata_val)
    elif content_val is None:
        # Without encryption a JSON null is normalised to an empty string;
        # every other value is stringified when the object is built below.
        content_val = ""

    # Decrypt user_id only if user_id encryption is enabled
    user_id_val: str = row["user_id"]
    if encryptor is not None and encryptor._encrypt_user_id:
        try:
            uid_parsed = json.loads(user_id_val)
            if encryptor._is_encrypted(uid_parsed):
                user_id_val = encryptor.decrypt_field("user_id", uid_parsed)
        except (json.JSONDecodeError, TypeError):
            pass

    expires_at = None
    if row_dict.get("expires_at"):
        try:
            expires_at = datetime.fromisoformat(row_dict["expires_at"])
        except (TypeError, ValueError):
            expires_at = None

    # Tags are written comma-joined with literal commas escaped as "\,".
    # Split only on commas that are NOT preceded by a backslash, then
    # unescape; splitting on every comma first would cut an escaped tag
    # such as "a\,b" into two pieces that can never be restored.
    # NOTE(review): the writer does not escape backslashes themselves, so a
    # tag that ends in a backslash remains ambiguous.
    raw_tags = row["tags"]
    if raw_tags:
        tags = [part.replace("\\,", ",") for part in re.split(r"(?<!\\),", raw_tags)]
    else:
        tags = []

    return MemoryObject(
        memory_id=row["memory_id"],
        user_id=user_id_val,
        content=str(content_val),
        embedding=embedding,
        score=0.0,
        created_at=datetime.fromisoformat(row["created_at"]),
        last_accessed_at=datetime.fromisoformat(row["last_accessed_at"]),
        source=MemorySource(row["source"]),
        importance=row["importance"],
        lifecycle_state=LifecycleState(row["lifecycle_state"]),
        metadata=metadata_val if isinstance(metadata_val, dict) else {},
        embedding_dim=row["embedding_dim"],
        tags=tags,
        confidence=row["confidence"],
        memory_type=MemoryType(row["memory_type"]),
        session_id=row_dict.get("session_id"),
        namespace=row_dict.get("namespace", "default"),
        version=row_dict.get("version", 1),
        agent_id=row_dict.get("agent_id"),
        run_id=row_dict.get("run_id"),
        app_id=row_dict.get("app_id"),
        expires_at=expires_at,
    )
390 def _memory_to_row(self, memory: MemoryObject) -> dict[str, Any]:
391 embedding_blob = None
392 if memory.embedding is not None:
393 embedding_blob = struct.pack(f"{len(memory.embedding)}f", *memory.embedding)
395 # Start with content and metadata as their native types so the
396 # field encryptor can process them before JSON serialization.
397 content_val: Any = memory.content
398 metadata_val: Any = memory.metadata
399 user_id_val: Any = memory.user_id
401 if self._field_encryptor is not None:
402 content_val = self._field_encryptor.encrypt_field("content", memory.content)
403 metadata_val = self._field_encryptor.encrypt_field("metadata", memory.metadata)
404 if self._field_encryptor._encrypt_user_id:
405 user_id_val = self._field_encryptor.encrypt_field("user_id", memory.user_id)
406 user_id_val = json.dumps(user_id_val)
408 content_json = json.dumps(content_val)
409 metadata_json = json.dumps(metadata_val)
411 return {
412 "memory_id": memory.memory_id,
413 "user_id": user_id_val,
414 "content": content_json,
415 "embedding": embedding_blob,
416 "embedding_dim": memory.embedding_dim,
417 "created_at": memory.created_at.isoformat(),
418 "last_accessed_at": memory.last_accessed_at.isoformat(),
419 "source": memory.source.value,
420 "importance": memory.importance,
421 "lifecycle_state": memory.lifecycle_state.value,
422 "metadata": metadata_json,
423 "tags": ",".join(t.replace(",", "\\,") for t in memory.tags) if memory.tags else "",
424 "confidence": memory.confidence,
425 "memory_type": memory.memory_type.value,
426 "session_id": memory.session_id,
427 "namespace": memory.namespace,
428 "version": memory.version,
429 "agent_id": memory.agent_id,
430 "run_id": memory.run_id,
431 "app_id": memory.app_id,
432 "expires_at": memory.expires_at.isoformat() if memory.expires_at else None,
433 }
def store(self, memory: MemoryObject) -> None:
    """Insert ``memory``, replacing any row with the same ``memory_id``.

    The full-text index entry is refreshed on the same connection; the
    connection's context manager commits when the block exits cleanly.
    """
    upsert_sql = """
        INSERT OR REPLACE INTO memories
        (memory_id, user_id, content, embedding, embedding_dim, created_at,
        last_accessed_at, source, importance, lifecycle_state, metadata, tags,
        confidence, memory_type, session_id, namespace, version,
        agent_id, run_id, app_id, expires_at)
        VALUES (:memory_id, :user_id, :content, :embedding, :embedding_dim,
        :created_at, :last_accessed_at, :source, :importance,
        :lifecycle_state, :metadata, :tags,
        :confidence, :memory_type, :session_id, :namespace, :version,
        :agent_id, :run_id, :app_id, :expires_at)
    """
    with self._get_connection() as conn:
        conn.execute(upsert_sql, self._memory_to_row(memory))
        # Sync to FTS5 index for BM25 search
        self._sync_fts_single(conn, memory)
def store_many(self, memories: list[MemoryObject]) -> int:
    """Store multiple memories in a single atomic transaction.

    All stores are wrapped in a transaction for atomicity. If any store
    fails, all changes are rolled back.

    Args:
        memories: List of MemoryObjects to store.

    Returns:
        Number of memories stored (0 for an empty list, in which case no
        transaction is opened).
    """
    if not memories:
        return 0

    # Same column list as store(). ``expires_at`` must be included:
    # leaving it out silently drops the TTL of batch-stored memories, and
    # because this is INSERT OR REPLACE it would also reset the TTL of a
    # row that is being replaced.
    upsert_sql = """
        INSERT OR REPLACE INTO memories
        (memory_id, user_id, content, embedding, embedding_dim, created_at,
        last_accessed_at, source, importance, lifecycle_state, metadata, tags,
        confidence, memory_type, session_id, namespace, version,
        agent_id, run_id, app_id, expires_at)
        VALUES (:memory_id, :user_id, :content, :embedding, :embedding_dim,
        :created_at, :last_accessed_at, :source, :importance,
        :lifecycle_state, :metadata, :tags,
        :confidence, :memory_type, :session_id, :namespace, :version,
        :agent_id, :run_id, :app_id, :expires_at)
    """
    with self._transaction() as conn:
        for memory in memories:
            conn.execute(upsert_sql, self._memory_to_row(memory))
            # Sync to FTS5 index
            self._sync_fts_single(conn, memory)
    return len(memories)
493 def _sync_fts_single(self, conn: sqlite3.Connection, memory: MemoryObject) -> None:
494 """Sync a single memory to FTS5 index.
496 Deletes any existing FTS row for this memory_id first, then inserts
497 the new one. INSERT OR REPLACE alone is not sufficient because the
498 memories_fts table is contentless and has no UNIQUE constraint on
499 memory_id — it would add a new row on every store(), causing FTS
500 duplicates to accumulate.
501 """
502 try:
503 conn.execute(
504 "DELETE FROM memories_fts WHERE memory_id = ?",
505 (memory.memory_id,),
506 )
507 conn.execute(
508 """
509 INSERT INTO memories_fts (memory_id, user_id, content, namespace, session_id)
510 VALUES (?, ?, ?, ?, ?)
511 """,
512 (
513 memory.memory_id,
514 memory.user_id,
515 memory.content,
516 memory.namespace,
517 memory.session_id,
518 ),
519 )
520 except sqlite3.OperationalError as e:
521 logger.warning("FTS5 sync failed: %s", e)
def rebuild_fts_index(self, user_id: str | None = None) -> int:
    """Rebuild the FTS5 index from the memories table.

    Use this if the FTS index gets out of sync with the memories table.

    Args:
        user_id: If provided, only reindex memories for this user. Otherwise
            the full index is rebuilt.

    Returns:
        Number of memories indexed.
    """
    insert_sql = """
        INSERT INTO memories_fts (memory_id, user_id, content, namespace, session_id)
        VALUES (?, ?, ?, ?, ?)
    """
    with self._get_connection() as conn:
        if user_id is None:
            # Clear and rebuild the full FTS index
            conn.execute("DELETE FROM memories_fts")
            cursor = conn.execute("SELECT * FROM memories")
        else:
            # Rebuild for a specific user: delete their existing FTS rows,
            # then re-insert them. The rest of the index is untouched.
            conn.execute("DELETE FROM memories_fts WHERE user_id = ?", (user_id,))
            cursor = conn.execute("SELECT * FROM memories WHERE user_id = ?", (user_id,))
        rows = cursor.fetchall()

        for row in rows:
            # The ``content`` column holds a JSON-encoded (and possibly
            # encrypted) value, not the text itself. Decode each row through
            # _row_to_memory so the rebuilt index contains exactly what
            # _sync_fts_single writes on a normal store(); copying the raw
            # column would index JSON quoting/escapes or ciphertext.
            memory = self._row_to_memory(row)
            conn.execute(
                insert_sql,
                (
                    memory.memory_id,
                    memory.user_id,
                    memory.content,
                    memory.namespace,
                    memory.session_id,
                ),
            )

    return len(rows)
def search(
    self,
    user_id: str,
    query_embedding: list[float],
    top_k: int = 10,
    lifecycle_filter: list[LifecycleState] | None = None,
    namespace: str = "default",
    session_id: str | None = None,
) -> list[MemoryObject]:
    """Rank a user's memories by cosine similarity to ``query_embedding``.

    All rows matching the user, namespace and lifecycle states are loaded
    and scored in Python. The similarity is rescaled with
    ``(similarity + 1) / 2``; memories without an embedding keep the score
    they were loaded with.

    Args:
        user_id: Owner of the memories.
        query_embedding: Vector to compare against.
        top_k: Maximum number of results.
        lifecycle_filter: States to include; defaults to ACTIVE and DECAYING.
        namespace: Namespace to search.
        session_id: When given, keeps rows of that session plus rows whose
            session is NULL.

    Returns:
        Up to ``top_k`` memories, highest score first.
    """
    if lifecycle_filter is None:
        lifecycle_filter = [LifecycleState.ACTIVE, LifecycleState.DECAYING]
    state_values = [state.value for state in lifecycle_filter]

    sql = """
        SELECT * FROM memories
        WHERE user_id = ? AND namespace = ? AND lifecycle_state IN ({})
    """.format(",".join("?" * len(state_values)))
    params: list[Any] = [user_id, namespace, *state_values]

    if session_id is not None:
        sql += " AND (session_id = ? OR session_id IS NULL)"
        params.append(session_id)

    with self._get_connection() as conn:
        rows = conn.execute(sql, params).fetchall()

    candidates = []
    for row in rows:
        candidate = self._row_to_memory(row)
        if candidate.embedding is not None:
            similarity = scoring.cosine_similarity(candidate.embedding, query_embedding)
            candidate.score = (similarity + 1.0) / 2.0
        candidates.append(candidate)

    return sorted(candidates, key=lambda m: m.score, reverse=True)[:top_k]
def get(self, memory_id: str) -> MemoryObject | None:
    """Fetch one memory by its id, or ``None`` when no such row exists."""
    with self._get_connection() as conn:
        found = conn.execute(
            "SELECT * FROM memories WHERE memory_id = ?", (memory_id,)
        ).fetchone()

    return self._row_to_memory(found) if found else None
def update(self, memory: MemoryObject) -> None:
    """Persist changes to ``memory`` by delegating to ``store()``."""
    self.store(memory)
def delete_by_user(self, user_id: str) -> int:
    """Delete every memory owned by ``user_id``.

    The user's full-text index rows are removed first; an
    ``sqlite3.OperationalError`` from that step is ignored.

    Returns:
        Number of rows removed from ``memories``.
    """
    owner = (user_id,)
    with self._get_connection() as conn:
        try:
            conn.execute("DELETE FROM memories_fts WHERE user_id = ?", owner)
        except sqlite3.OperationalError:
            pass  # FTS table might not exist

        removed = conn.execute("DELETE FROM memories WHERE user_id = ?", owner).rowcount
        return removed
def delete_by_id(self, memory_id: str) -> bool:
    """Delete one memory and its full-text index entry.

    An ``sqlite3.OperationalError`` while cleaning the index is ignored.

    Returns:
        ``True`` when a row was removed from ``memories``.
    """
    key = (memory_id,)
    with self._get_connection() as conn:
        deleted = conn.execute("DELETE FROM memories WHERE memory_id = ?", key).rowcount
        # Also delete from FTS5 index to keep it in sync
        try:
            conn.execute("DELETE FROM memories_fts WHERE memory_id = ?", key)
        except sqlite3.OperationalError:
            pass  # FTS table might not exist in old databases
        return deleted > 0
def get_all_by_user(
    self,
    user_id: str,
    lifecycle_filter: list[LifecycleState] | None = None,
    namespace: str = "default",
    session_id: str | None = None,
    limit: int | None = None,
    offset: int | None = None,
) -> list[MemoryObject]:
    """List a user's memories in one namespace, optionally paginated.

    Args:
        user_id: Owner of the memories.
        lifecycle_filter: States to include; defaults to ACTIVE and DECAYING.
        namespace: Namespace to read from.
        session_id: When given, keeps rows of that session plus rows whose
            session is NULL.
        limit: Maximum number of rows.
        offset: Number of rows to skip.

    Returns:
        Matching memories. They are ordered newest-first by ``created_at``
        only when ``limit`` or ``offset`` is supplied; otherwise no
        ordering is requested.
    """
    if lifecycle_filter is None:
        lifecycle_filter = [LifecycleState.ACTIVE, LifecycleState.DECAYING]
    state_values = [state.value for state in lifecycle_filter]

    sql = """
        SELECT * FROM memories
        WHERE user_id = ? AND namespace = ? AND lifecycle_state IN ({})
    """.format(",".join("?" * len(state_values)))
    params: list[Any] = [user_id, namespace, *state_values]

    if session_id is not None:
        sql += " AND (session_id = ? OR session_id IS NULL)"
        params.append(session_id)

    if offset is None:
        if limit is not None:
            sql += " ORDER BY created_at DESC LIMIT ?"
            params.append(limit)
    else:
        sql += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
        # Use -1 as "no limit" when limit is None but offset is provided
        params.extend([-1 if limit is None else limit, offset])

    with self._get_connection() as conn:
        rows = conn.execute(sql, params).fetchall()

    return [self._row_to_memory(row) for row in rows]
def count(self, user_id: str) -> int:
    """Return how many rows in ``memories`` belong to ``user_id``."""
    with self._get_connection() as conn:
        (total,) = conn.execute(
            "SELECT COUNT(*) FROM memories WHERE user_id = ?", (user_id,)
        ).fetchone()
        return total  # type: ignore[no-any-return]
def get_all(
    self,
    limit: int | None = None,
    offset: int | None = None,
) -> list[MemoryObject]:
    """Return memories of all users, optionally paginated.

    Rows are ordered newest-first by ``created_at`` only when ``limit`` or
    ``offset`` is supplied; otherwise no ordering is requested.

    Args:
        limit: Maximum number of rows.
        offset: Number of rows to skip.
    """
    sql = "SELECT * FROM memories"
    params: list[Any] = []

    if offset is None:
        if limit is not None:
            sql += " ORDER BY created_at DESC LIMIT ?"
            params.append(limit)
    else:
        sql += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
        # Use -1 as "no limit" when limit is None but offset is provided
        params.extend([-1 if limit is None else limit, offset])

    with self._get_connection() as conn:
        rows = conn.execute(sql, params).fetchall()
        return [self._row_to_memory(row) for row in rows]
def search_by_content(
    self,
    user_id: str,
    query: str,
    top_k: int = 10,
    lifecycle_filter: list[LifecycleState] | None = None,
    namespace: str = "default",
    session_id: str | None = None,
) -> list[MemoryObject]:
    """Full-text search over memory content.

    The FTS5-backed search is tried first. If it raises
    ``sqlite3.OperationalError`` a warning is logged and the Python BM25
    fallback is run with the same arguments.
    """
    search_args = (user_id, query, top_k, lifecycle_filter, namespace, session_id)
    try:
        return self._fts5_search(*search_args)
    except sqlite3.OperationalError as e:
        logger.warning("FTS5 search failed, falling back to Python BM25: %s", e)
        return self._bm25_python_fallback(*search_args)
728 def _fts5_search(
729 self,
730 user_id: str,
731 query: str,
732 top_k: int,
733 lifecycle_filter: list[LifecycleState] | None,
734 namespace: str,
735 session_id: str | None,
736 ) -> list[MemoryObject]:
737 """Native FTS5 BM25 search - much faster than Python scoring."""
738 if lifecycle_filter is None:
739 lifecycle_filter = [LifecycleState.ACTIVE, LifecycleState.DECAYING]
741 states = [s.value for s in lifecycle_filter]
743 # Build lifecycle filter for subquery
744 lifecycle_placeholders = ",".join("?" * len(states))
746 # Use FTS5 MATCH on content column only with BM25 ranking
747 # Join with main memories table to get full memory objects with lifecycle filtering
748 # Build full FTS5 query with content: prefix in Python to avoid parameter binding issues
749 fts_query = f"content:{self._prepare_fts_query(query)}"
751 sql = f"""
752 SELECT m.*, bm25(memories_fts) as fts_score
753 FROM memories m
754 INNER JOIN memories_fts fts ON m.memory_id = fts.memory_id
755 WHERE fts.user_id = ?
756 AND fts.namespace = ?
757 AND m.lifecycle_state IN ({lifecycle_placeholders})
758 AND memories_fts MATCH ?
759 """
761 params: list[Any] = [user_id, namespace] + states
763 if session_id is not None:
764 sql += " AND (m.session_id = ? OR m.session_id IS NULL)"
765 params.append(session_id)
767 sql += " ORDER BY fts_score LIMIT ?"
768 params.append(top_k)
770 # Append FTS query at the end (after lifecycle states and top_k)
771 params.insert(-1, fts_query)
773 with self._get_connection() as conn:
774 cursor = conn.execute(sql, params)
775 rows = cursor.fetchall()
777 memories = []
778 rank = 0
779 for row in rows:
780 memory = self._row_to_memory(row)
781 # Use rank-based scoring: higher rank position = lower score
782 # BM25 returns negative scores where lower = better, so invert
783 rank += 1
784 memory.score = (
785 1.0 / rank
786 ) # Rank-based normalization (1st result = 1.0, 2nd = 0.5, etc.)
787 memories.append(memory)
789 return memories
791 def _prepare_fts_query(self, query: str) -> str:
792 """Prepare query string for FTS5 MATCH on content column.
794 Handles escaping special FTS5 characters to prevent query syntax errors.
795 """
796 if not query or not query.strip():
797 return '""'
799 # Escape FTS5 special characters: " * ( ) : ~
800 # Replace with spaces to preserve word separation
801 escaped = query
802 for char in '"*():~':
803 escaped = escaped.replace(char, " ")
805 # Tokenize and create phrase queries for each term (prefix matching)
806 terms = escaped.strip().split()
807 if not terms:
808 return '""'
810 if len(terms) == 1:
811 # Single term - prefix match
812 term = terms[0].strip()
813 if term:
814 return f'"{term}"*'
815 return '""'
817 # Multiple terms - OR for matching any term
818 phrase_terms = []
819 for term in terms:
820 term = term.strip()
821 if term:
822 phrase_terms.append(f'"{term}"*')
824 if phrase_terms:
825 return " OR ".join(phrase_terms)
826 return '""'
828 def _bm25_python_fallback(
829 self,
830 user_id: str,
831 query: str,
832 top_k: int,
833 lifecycle_filter: list[LifecycleState] | None,
834 namespace: str,
835 session_id: str | None,
836 ) -> list[MemoryObject]:
837 """Python-based BM25 fallback when FTS5 is unavailable."""
838 if lifecycle_filter is None:
839 lifecycle_filter = [LifecycleState.ACTIVE, LifecycleState.DECAYING]
841 # Limit fetch to prevent memory issues with large datasets
842 fetch_limit = min(top_k * 20, 200)
844 states = [s.value for s in lifecycle_filter]
845 params: list[Any] = [user_id, namespace] + states
847 sql = """
848 SELECT * FROM memories
849 WHERE user_id = ? AND namespace = ? AND lifecycle_state IN ({})
850 """.format(",".join("?" * len(states)))
852 if session_id is not None:
853 sql += " AND (session_id = ? OR session_id IS NULL)"
854 params.append(session_id)
856 sql += f" ORDER BY created_at DESC LIMIT {fetch_limit}"
858 with self._get_connection() as conn:
859 cursor = conn.execute(sql, params)
860 rows = cursor.fetchall()
862 memories = [self._row_to_memory(row) for row in rows]
864 if not memories:
865 return []
867 corpus = [m.content for m in memories]
868 for memory in memories:
869 memory.score = scoring.bm25_score_corpus(query, memory.content, corpus)
871 memories.sort(key=lambda m: m.score, reverse=True)
872 return memories[:top_k]
def get_all_users(self) -> list[str]:
    """Return the distinct ``user_id`` values stored in ``memories``."""
    with self._get_connection() as conn:
        found = conn.execute("SELECT DISTINCT user_id FROM memories").fetchall()
        return [record[0] for record in found]
def upgrade_schema(self, from_version: int, to_version: int) -> None:
    """Run the pending schema migrations on the current thread's connection.

    ``from_version`` and ``to_version`` are accepted but not used here;
    ``_run_migrations`` decides which steps to apply.
    """
    with self._get_connection() as conn:
        self._run_migrations(conn)
def get_by_tag(
    self,
    user_id: str,
    tag: str,
    lifecycle_filter: list[LifecycleState] | None = None,
    namespace: str = "default",
) -> list[MemoryObject]:
    """Return a user's memories that carry ``tag``.

    Args:
        user_id: Owner of the memories.
        tag: Tag to look for, matched as a whole entry of the
            comma-separated ``tags`` column.
        lifecycle_filter: States to include; defaults to ACTIVE and DECAYING.
        namespace: Namespace to search.

    Returns:
        Memories whose tag list contains ``tag``.
    """
    if lifecycle_filter is None:
        lifecycle_filter = [LifecycleState.ACTIVE, LifecycleState.DECAYING]

    states = [s.value for s in lifecycle_filter]
    placeholders = ",".join("?" * len(states))

    # Step 1: bring the tag into its stored form. _memory_to_row writes
    # literal commas as "\,", so a tag such as "a,b" only exists in the
    # column as "a\,b" and must be searched for in that shape.
    stored_form = tag.replace(",", "\\,")
    # Step 2: escape LIKE wildcards (and the escape character itself) so
    # that "_", "%" and "\" in the stored form match literally.
    escaped_tag = (
        stored_form.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    )
    # NOTE(review): a search for "b" still matches a stored tag "a\,b",
    # because ",b," occurs inside ",a\,b,". Fixing that needs a different
    # storage format for tags.

    with self._get_connection() as conn:
        cursor = conn.execute(
            f"""
            SELECT * FROM memories
            WHERE user_id = ? AND namespace = ? AND lifecycle_state IN ({placeholders})
            AND (',' || tags || ',') LIKE ('%,' || ? || ',%') ESCAPE '\\'
            """,
            [user_id, namespace] + states + [escaped_tag],
        )
        rows = cursor.fetchall()

    return [self._row_to_memory(row) for row in rows]
def get_api_key_manager(self) -> Any:
    """Build an ``APIKeyManager`` on top of the calling thread's connection.

    ``kemi.api_keys`` is imported lazily so this module carries no hard
    dependency on it. A new manager object is created on every call and
    is handed the connection returned by ``_get_connection()``.
    """
    from kemi.api_keys import APIKeyManager

    thread_conn = self._get_connection()
    return APIKeyManager(connection=thread_conn)