# Coverage for src/dataknobs_data/backends/sqlite.py: 17% of 289 statements
# (coverage.py v7.10.3, created at 2025-08-31 15:06 -0600)

"""SQLite backend implementation with sync and async support."""

from __future__ import annotations

import json
import logging
import sqlite3
import time
import uuid
from pathlib import Path
from typing import Any, TYPE_CHECKING

import numpy as np
from dataknobs_config import ConfigurableBase

from ..database import SyncDatabase
from ..query import Query
from ..query_logic import ComplexQuery
from ..records import Record
from ..vector.bulk_embed_mixin import BulkEmbedMixin
from ..vector.mixins import VectorOperationsMixin
from ..vector.python_vector_search import PythonVectorSearchMixin
from .sql_base import SQLQueryBuilder, SQLRecordSerializer, SQLTableManager
from .sqlite_mixins import SQLiteVectorSupport
from .vector_config_mixin import VectorConfigMixin

if TYPE_CHECKING:
    from collections.abc import Iterator
    from ..streaming import StreamConfig, StreamResult
    from ..vector.types import DistanceMetric, VectorSearchResult


logger = logging.getLogger(__name__)


class SyncSQLiteDatabase(  # type: ignore[misc]
    SyncDatabase,
    ConfigurableBase,
    VectorConfigMixin,
    PythonVectorSearchMixin,  # Provides python_vector_search_sync
    BulkEmbedMixin,  # Must come before VectorOperationsMixin to override bulk_embed_and_store
    VectorOperationsMixin,
    SQLiteVectorSupport,
    SQLRecordSerializer,  # Use the standard SQL serializer
):
    """Synchronous SQLite database backend."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize SQLite database.

        Args:
            config: Configuration with the following optional keys:
                - path: Database file path (default: ":memory:")
                - table: Table name (default: "records")
                - timeout: Connection timeout in seconds (default: 5.0)
                - check_same_thread: If True, restrict the connection to its
                  creating thread; the default of False allows cross-thread use
                - journal_mode: Journal mode (WAL, DELETE, etc.) (default: None)
                - synchronous: Synchronous mode (NORMAL, FULL, OFF) (default: None)
                - vector_enabled: Enable vector support (default: False)
                - vector_metric: Distance metric for vector search (default: "cosine")
        """
        super().__init__(config)
        SQLiteVectorSupport.__init__(self)

        # Parse vector configuration using the mixin
        self._parse_vector_config(config)

        self.db_path = self.config.get("path", ":memory:")
        self.table_name = self.config.get("table", "records")
        self.timeout = self.config.get("timeout", 5.0)
        self.check_same_thread = self.config.get("check_same_thread", False)
        self.journal_mode = self.config.get("journal_mode")
        self.synchronous = self.config.get("synchronous")

        self.query_builder = SQLQueryBuilder(self.table_name, dialect="sqlite", param_style="qmark")
        self.table_manager = SQLTableManager(self.table_name, dialect="sqlite")

        self.conn: sqlite3.Connection | None = None
        self._connected = False
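
    # Example (illustrative; the file path and pragma values are made up,
    # but the keys are the ones documented in __init__ above):
    #
    #     db = SyncSQLiteDatabase({"path": "data/app.db", "journal_mode": "WAL"})
    #     db.connect()
    #     try:
    #         ...  # create/read/search calls go here
    #     finally:
    #         db.close()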

    @classmethod
    def from_config(cls, config: dict) -> SyncSQLiteDatabase:
        """Create from config dictionary."""
        return cls(config)

    def connect(self) -> None:
        """Connect to the SQLite database."""
        if self._connected:
            return

        # Create directory if needed for file-based database
        if self.db_path != ":memory:":
            db_file = Path(self.db_path)
            db_file.parent.mkdir(parents=True, exist_ok=True)

        # Connect to database
        self.conn = sqlite3.connect(
            self.db_path,
            timeout=self.timeout,
            check_same_thread=self.check_same_thread
        )

        # Enable row factory for dict-like access
        self.conn.row_factory = sqlite3.Row

        # Configure SQLite for better performance
        self._configure_sqlite()

        # Create table if it doesn't exist
        self._ensure_table()

        self._connected = True
        logger.info(f"Connected to SQLite database: {self.db_path}")

    def close(self) -> None:
        """Close the database connection."""
        if self.conn:
            self.conn.close()
            self.conn = None
            self._connected = False
            logger.info(f"Disconnected from SQLite database: {self.db_path}")

    def _configure_sqlite(self) -> None:
        """Configure SQLite settings for performance."""
        if not self.conn:
            return

        cursor = self.conn.cursor()

        # Set journal mode if specified
        if self.journal_mode:
            cursor.execute(f"PRAGMA journal_mode = {self.journal_mode}")
            logger.debug(f"Set journal_mode to {self.journal_mode}")

        # Set synchronous mode if specified
        if self.synchronous:
            cursor.execute(f"PRAGMA synchronous = {self.synchronous}")
            logger.debug(f"Set synchronous to {self.synchronous}")

        # Enable foreign keys
        cursor.execute("PRAGMA foreign_keys = ON")

        # Optimize for performance
        cursor.execute("PRAGMA temp_store = MEMORY")
        cursor.execute("PRAGMA mmap_size = 30000000000")

        cursor.close()

    def _ensure_table(self) -> None:
        """Ensure the table exists."""
        if not self.conn:
            raise RuntimeError("Database not connected. Call connect() first.")

        cursor = self.conn.cursor()
        cursor.executescript(self.table_manager.get_create_table_sql())
        self.conn.commit()
        cursor.close()

    def _check_connection(self) -> None:
        """Check if database is connected."""
        if not self._connected or not self.conn:
            raise RuntimeError("Database not connected. Call connect() first.")

    def create(self, record: Record) -> str:
        """Create a new record."""
        self._check_connection()

        # Update vector dimensions tracking if needed
        if self._has_vector_fields(record):
            self._update_vector_dimensions(record)

        # Use centralized method to prepare record
        record, storage_id = self._prepare_record_for_storage(record)

        # Use the standard SQL serializer
        data_json = self.record_to_json(record)
        metadata_json = json.dumps(record.metadata) if record.metadata else None

        # Build insert query for SQLite's standard table structure
        query = f"INSERT INTO {self.table_name} (id, data, metadata) VALUES (?, ?, ?)"
        params = [storage_id, data_json, metadata_json]

        cursor = self.conn.cursor()

        try:
            cursor.execute(query, params)
            self.conn.commit()
            return storage_id
        except sqlite3.IntegrityError as e:
            self.conn.rollback()
            raise ValueError(f"Record with ID {record.id} already exists") from e
        finally:
            cursor.close()

    def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        self._check_connection()

        query, params = self.query_builder.build_read_query(id)
        cursor = self.conn.cursor()

        try:
            cursor.execute(query, params)
            row = cursor.fetchone()

            if row:
                # Use the standard SQL serializer
                record = self.row_to_record(dict(row))
                # Use centralized method to prepare record
                return self._prepare_record_from_storage(record, id)
            return None
        finally:
            cursor.close()

    def update(self, id: str, record: Record) -> bool:
        """Update an existing record."""
        self._check_connection()

        # Update vector dimensions tracking if needed
        if self._has_vector_fields(record):
            self._update_vector_dimensions(record)

        # Use the standard SQL serializer
        data_json = self.record_to_json(record)
        metadata_json = json.dumps(record.metadata) if record.metadata else None

        # Build update query
        query = f"UPDATE {self.table_name} SET data = ?, metadata = ? WHERE id = ?"
        params = [data_json, metadata_json, id]

        cursor = self.conn.cursor()

        try:
            cursor.execute(query, params)
            self.conn.commit()
            return cursor.rowcount > 0
        finally:
            cursor.close()

    def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        self._check_connection()

        query, params = self.query_builder.build_delete_query(id)
        cursor = self.conn.cursor()

        try:
            cursor.execute(query, params)
            self.conn.commit()
            return cursor.rowcount > 0
        finally:
            cursor.close()

    def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()

        query, params = self.query_builder.build_exists_query(id)
        cursor = self.conn.cursor()

        try:
            cursor.execute(query, params)
            result = cursor.fetchone()
            return result is not None
        finally:
            cursor.close()

    def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching a query."""
        self._check_connection()

        # Handle ComplexQuery with native SQL support
        if isinstance(query, ComplexQuery):
            sql_query, params = self.query_builder.build_complex_search_query(query)
        else:
            sql_query, params = self.query_builder.build_search_query(query)

        cursor = self.conn.cursor()

        try:
            cursor.execute(sql_query, params)
            rows = cursor.fetchall()

            records = [self.row_to_record(dict(row)) for row in rows]

            # Apply field projection if specified
            if query.fields:
                records = [r.project(query.fields) for r in records]

            return records
        finally:
            cursor.close()

    def count(self, query: Query | None = None) -> int:
        """Count records matching a query."""
        self._check_connection()

        sql_query, params = self.query_builder.build_count_query(query)
        cursor = self.conn.cursor()

        try:
            cursor.execute(sql_query, params)
            result = cursor.fetchone()
            return result[0] if result else 0
        finally:
            cursor.close()

    def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently using a single query.

        Uses multi-value INSERT for better performance.
        """
        if not records:
            return []

        self._check_connection()

        # Use the shared batch create query builder
        query, params, ids = self.query_builder.build_batch_create_query(records)

        cursor = self.conn.cursor()
        try:
            # Execute the batch insert in a transaction
            cursor.execute("BEGIN TRANSACTION")
            cursor.execute(query, params)
            self.conn.commit()

            # Return the generated IDs
            return ids
        except Exception:
            self.conn.rollback()
            raise
        finally:
            cursor.close()
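
    # For reference, the batch-create statement is expected to look roughly
    # like the multi-value INSERT below (illustrative; the exact SQL and
    # parameter layout come from SQLQueryBuilder):
    #
    #     INSERT INTO records (id, data, metadata)
    #     VALUES (?, ?, ?), (?, ?, ?), ...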

    def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using a single query.

        Uses CASE expressions for batch updates, similar to PostgreSQL.
        """
        if not updates:
            return []

        self._check_connection()

        # Use the shared batch update query builder
        query, params = self.query_builder.build_batch_update_query(updates)

        cursor = self.conn.cursor()
        try:
            # Execute the batch update in a transaction
            cursor.execute("BEGIN TRANSACTION")
            cursor.execute(query, params)
            self.conn.commit()

            # Check which records were actually updated. SQLite only gained
            # RETURNING in 3.35, so for portability we verify each ID with a
            # follow-up SELECT instead.
            update_ids = [record_id for record_id, _ in updates]
            placeholders = ", ".join(["?" for _ in update_ids])
            check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})"
            cursor.execute(check_query, update_ids)
            existing_ids = {row[0] for row in cursor.fetchall()}

            # Return results for each update
            results = []
            for record_id, _ in updates:
                results.append(record_id in existing_ids)

            return results
        except Exception:
            self.conn.rollback()
            raise
        finally:
            cursor.close()
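
    # For reference, the batch-update statement is expected to look roughly
    # like the CASE-based UPDATE below (illustrative; the exact SQL comes
    # from SQLQueryBuilder):
    #
    #     UPDATE records SET
    #         data     = CASE id WHEN ? THEN ? WHEN ? THEN ? ... END,
    #         metadata = CASE id WHEN ? THEN ? WHEN ? THEN ? ... END
    #     WHERE id IN (?, ?, ...)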

    def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently using a single query.

        Uses single DELETE with IN clause for better performance.
        """
        if not ids:
            return []

        self._check_connection()

        # Check which IDs exist before deletion
        placeholders = ", ".join(["?" for _ in ids])
        check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})"

        cursor = self.conn.cursor()
        try:
            cursor.execute(check_query, ids)
            existing_ids = {row[0] for row in cursor.fetchall()}

            # Use the shared batch delete query builder
            query, params = self.query_builder.build_batch_delete_query(ids)

            # Execute the batch delete in a transaction
            cursor.execute("BEGIN TRANSACTION")
            cursor.execute(query, params)
            self.conn.commit()

            # Return results based on which IDs existed
            results = []
            for id in ids:
                results.append(id in existing_ids)

            return results
        except Exception:
            self.conn.rollback()
            raise
        finally:
            cursor.close()

    def _initialize(self) -> None:
        """Initialize method - connection setup handled in connect()."""
        pass

    def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()
        cursor = self.conn.cursor()
        try:
            cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
            result = cursor.fetchone()
            return result[0] if result else 0
        finally:
            cursor.close()

    def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> Iterator[Record]:
        """Stream records from database."""
        from ..streaming import StreamConfig

        config = config or StreamConfig()
        query = query or Query()

        # Use the existing stream method's logic but yield individual records
        offset = 0
        while True:
            # Fetch a batch
            query_copy = query.copy()
            query_copy.offset(offset).limit(config.batch_size)
            batch = self.search(query_copy)

            if not batch:
                break

            for record in batch:
                yield record

            offset += len(batch)

            # If we got less than batch_size, we're done
            if len(batch) < config.batch_size:
                break
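
    # Example (illustrative): iterate over all records in batches of 500.
    # Only batch_size is referenced above; no other StreamConfig fields are
    # assumed here, and handle() is a placeholder.
    #
    #     for record in db.stream_read(config=StreamConfig(batch_size=500)):
    #         handle(record)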

    def stream_write(
        self,
        records: Iterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into database."""
        from ..streaming import StreamConfig, StreamResult

        config = config or StreamConfig()
        batch = []
        total_written = 0
        start_time = time.time()

        for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write the batch
                self.create_batch(batch)
                total_written += len(batch)
                batch = []

        # Write any remaining records
        if batch:
            self.create_batch(batch)
            total_written += len(batch)

        elapsed = time.time() - start_time

        return StreamResult(
            total_processed=total_written,
            successful=total_written,
            failed=0,
            duration=elapsed,
            total_batches=(total_written + config.batch_size - 1) // config.batch_size
        )
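
    # Example (illustrative): bulk-load an iterator of records and inspect
    # the StreamResult fields populated above.
    #
    #     result = db.stream_write(iter(records), StreamConfig(batch_size=500))
    #     print(result.successful, result.total_batches)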

    # Vector support methods
    def has_vector_support(self) -> bool:
        """Check if this backend has vector support.

        Returns:
            False - SQLite has no native vector support, uses Python-based similarity
        """
        return False  # No native vector support

    def enable_vector_support(self) -> bool:
        """Enable vector support for this backend.

        Returns:
            True - Vector support is always available (Python-based)
        """
        # SQLite doesn't need any special setup for vector support
        # We handle vectors as JSON strings
        self.vector_enabled = True
        return True

    def vector_search(
        self,
        query_vector: np.ndarray,
        field_name: str = "embedding",
        k: int = 10,
        filter: Query | None = None,
        metric: DistanceMetric | None = None,
        **kwargs
    ) -> list[VectorSearchResult]:
        """Perform vector similarity search using Python-based calculations.

        Delegates to PythonVectorSearchMixin for the implementation.

        Args:
            query_vector: Query vector
            field_name: Name of the vector field to search
            k: Number of results to return
            filter: Optional filter conditions
            metric: Distance metric (uses instance default if not specified)
            **kwargs: Additional arguments for compatibility

        Returns:
            List of search results with scores
        """
        self._check_connection()

        # Delegate to the mixin's implementation
        return self.python_vector_search_sync(
            query_vector=query_vector,
            vector_field=field_name,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )
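
    # Example (illustrative; the 384-dimension query vector is made up):
    #
    #     results = db.vector_search(
    #         np.random.rand(384).astype(np.float32), field_name="embedding", k=5
    #     )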

    def add_vectors(
        self,
        vectors: list[np.ndarray],
        ids: list[str] | None = None,
        metadata: list[dict[str, Any]] | None = None,
        field_name: str = "embedding",
    ) -> list[str]:
        """Add vectors to the database.

        Args:
            vectors: List of vectors to add
            ids: Optional list of IDs
            metadata: Optional list of metadata dicts
            field_name: Name of the vector field

        Returns:
            List of created record IDs
        """
        from collections import OrderedDict

        from ..fields import VectorField

        # Generate IDs if not provided
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]

        # Create records with vector fields
        records = []
        for i, vector in enumerate(vectors):
            # Create vector field
            vector_field = VectorField(
                name=field_name,
                value=vector,
                dimensions=len(vector) if isinstance(vector, (list, np.ndarray)) else None
            )

            # Create record
            record_metadata = metadata[i] if metadata and i < len(metadata) else {}
            record = Record(
                data=OrderedDict({field_name: vector_field}),
                metadata=record_metadata,
                storage_id=ids[i]
            )
            records.append(record)

        # Use batch create for efficiency
        return self.create_batch(records)
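
    # Example (illustrative; the vectors and metadata values are made up):
    #
    #     ids = db.add_vectors(
    #         vectors=[np.zeros(3, dtype=np.float32), np.ones(3, dtype=np.float32)],
    #         metadata=[{"source": "a"}, {"source": "b"}],
    #     )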