Coverage for src/dataknobs_data/backends/sqlite.py: 17%

289 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-31 15:06 -0600

1"""SQLite backend implementation with sync and async support.""" 

2 

3from __future__ import annotations 

4 

5import json 

6import logging 

7import sqlite3 

8import time 

9import uuid 

10from pathlib import Path 

11from typing import Any, TYPE_CHECKING 

12 

13import numpy as np 

14from dataknobs_config import ConfigurableBase 

15 

16from ..database import SyncDatabase 

17from ..query import Query 

18from ..query_logic import ComplexQuery 

19from ..records import Record 

20from ..vector.bulk_embed_mixin import BulkEmbedMixin 

21from ..vector.mixins import VectorOperationsMixin 

22from ..vector.python_vector_search import PythonVectorSearchMixin 

23from .sql_base import SQLQueryBuilder, SQLRecordSerializer, SQLTableManager 

24from .sqlite_mixins import SQLiteVectorSupport 

25from .vector_config_mixin import VectorConfigMixin 

26 

27if TYPE_CHECKING: 

28 from collections.abc import Iterator 

29 from ..streaming import StreamConfig, StreamResult 

30 from ..vector.types import DistanceMetric, VectorSearchResult 

31 

32 

# Module-level logger; handlers/levels are inherited from application config.
logger = logging.getLogger(__name__)

34 

35 

36class SyncSQLiteDatabase( # type: ignore[misc] 

37 SyncDatabase, 

38 ConfigurableBase, 

39 VectorConfigMixin, 

40 PythonVectorSearchMixin, # Provides python_vector_search_sync 

41 BulkEmbedMixin, # Must come before VectorOperationsMixin to override bulk_embed_and_store 

42 VectorOperationsMixin, 

43 SQLiteVectorSupport, 

44 SQLRecordSerializer, # Use the standard SQL serializer 

45): 

46 """Synchronous SQLite database backend.""" 

47 

48 def __init__(self, config: dict[str, Any] | None = None): 

49 """Initialize SQLite database. 

50  

51 Args: 

52 config: Configuration with the following optional keys: 

53 - path: Database file path (default: ":memory:") 

54 - table: Table name (default: "records") 

55 - timeout: Connection timeout in seconds (default: 5.0) 

56 - check_same_thread: Allow sharing across threads (default: False) 

57 - journal_mode: Journal mode (WAL, DELETE, etc.) (default: None) 

58 - synchronous: Synchronous mode (NORMAL, FULL, OFF) (default: None) 

59 - vector_enabled: Enable vector support (default: False) 

60 - vector_metric: Distance metric for vector search (default: "cosine") 

61 """ 

62 super().__init__(config) 

63 SQLiteVectorSupport.__init__(self) 

64 

65 # Parse vector configuration using the mixin 

66 self._parse_vector_config(config) 

67 

68 self.db_path = self.config.get("path", ":memory:") 

69 self.table_name = self.config.get("table", "records") 

70 self.timeout = self.config.get("timeout", 5.0) 

71 self.check_same_thread = self.config.get("check_same_thread", False) 

72 self.journal_mode = self.config.get("journal_mode") 

73 self.synchronous = self.config.get("synchronous") 

74 

75 self.query_builder = SQLQueryBuilder(self.table_name, dialect="sqlite", param_style="qmark") 

76 self.table_manager = SQLTableManager(self.table_name, dialect="sqlite") 

77 

78 self.conn: sqlite3.Connection | None = None 

79 self._connected = False 

80 

81 @classmethod 

82 def from_config(cls, config: dict) -> SyncSQLiteDatabase: 

83 """Create from config dictionary.""" 

84 return cls(config) 

85 

86 def connect(self) -> None: 

87 """Connect to the SQLite database.""" 

88 if self._connected: 

89 return 

90 

91 # Create directory if needed for file-based database 

92 if self.db_path != ":memory:": 

93 db_file = Path(self.db_path) 

94 db_file.parent.mkdir(parents=True, exist_ok=True) 

95 

96 # Connect to database 

97 self.conn = sqlite3.connect( 

98 self.db_path, 

99 timeout=self.timeout, 

100 check_same_thread=self.check_same_thread 

101 ) 

102 

103 # Enable row factory for dict-like access 

104 self.conn.row_factory = sqlite3.Row 

105 

106 # Configure SQLite for better performance 

107 self._configure_sqlite() 

108 

109 # Create table if it doesn't exist 

110 self._ensure_table() 

111 

112 self._connected = True 

113 logger.info(f"Connected to SQLite database: {self.db_path}") 

114 

115 def close(self) -> None: 

116 """Close the database connection.""" 

117 if self.conn: 

118 self.conn.close() 

119 self.conn = None 

120 self._connected = False 

121 logger.info(f"Disconnected from SQLite database: {self.db_path}") 

122 

123 def _configure_sqlite(self) -> None: 

124 """Configure SQLite settings for performance.""" 

125 if not self.conn: 

126 return 

127 

128 cursor = self.conn.cursor() 

129 

130 # Set journal mode if specified 

131 if self.journal_mode: 

132 cursor.execute(f"PRAGMA journal_mode = {self.journal_mode}") 

133 logger.debug(f"Set journal_mode to {self.journal_mode}") 

134 

135 # Set synchronous mode if specified 

136 if self.synchronous: 

137 cursor.execute(f"PRAGMA synchronous = {self.synchronous}") 

138 logger.debug(f"Set synchronous to {self.synchronous}") 

139 

140 # Enable foreign keys 

141 cursor.execute("PRAGMA foreign_keys = ON") 

142 

143 # Optimize for performance 

144 cursor.execute("PRAGMA temp_store = MEMORY") 

145 cursor.execute("PRAGMA mmap_size = 30000000000") 

146 

147 cursor.close() 

148 

149 def _ensure_table(self) -> None: 

150 """Ensure the table exists.""" 

151 if not self.conn: 

152 raise RuntimeError("Database not connected. Call connect() first.") 

153 

154 cursor = self.conn.cursor() 

155 cursor.executescript(self.table_manager.get_create_table_sql()) 

156 self.conn.commit() 

157 cursor.close() 

158 

159 def _check_connection(self) -> None: 

160 """Check if database is connected.""" 

161 if not self._connected or not self.conn: 

162 raise RuntimeError("Database not connected. Call connect() first.") 

163 

164 def create(self, record: Record) -> str: 

165 """Create a new record.""" 

166 self._check_connection() 

167 

168 # Update vector dimensions tracking if needed 

169 if self._has_vector_fields(record): 

170 self._update_vector_dimensions(record) 

171 

172 # Use centralized method to prepare record 

173 record, storage_id = self._prepare_record_for_storage(record) 

174 

175 # Use the standard SQL serializer 

176 data_json = self.record_to_json(record) 

177 metadata_json = json.dumps(record.metadata) if record.metadata else None 

178 

179 # Build insert query for SQLite's standard table structure 

180 query = f"INSERT INTO {self.table_name} (id, data, metadata) VALUES (?, ?, ?)" 

181 params = [storage_id, data_json, metadata_json] 

182 

183 cursor = self.conn.cursor() 

184 

185 try: 

186 cursor.execute(query, params) 

187 self.conn.commit() 

188 return storage_id 

189 except sqlite3.IntegrityError as e: 

190 self.conn.rollback() 

191 raise ValueError(f"Record with ID {record.id} already exists") from e 

192 finally: 

193 cursor.close() 

194 

195 def read(self, id: str) -> Record | None: 

196 """Read a record by ID.""" 

197 self._check_connection() 

198 

199 query, params = self.query_builder.build_read_query(id) 

200 cursor = self.conn.cursor() 

201 

202 try: 

203 cursor.execute(query, params) 

204 row = cursor.fetchone() 

205 

206 if row: 

207 # Use the standard SQL serializer 

208 record = self.row_to_record(dict(row)) 

209 # Use centralized method to prepare record 

210 return self._prepare_record_from_storage(record, id) 

211 return None 

212 finally: 

213 cursor.close() 

214 

215 def update(self, id: str, record: Record) -> bool: 

216 """Update an existing record.""" 

217 self._check_connection() 

218 

219 # Update vector dimensions tracking if needed 

220 if self._has_vector_fields(record): 

221 self._update_vector_dimensions(record) 

222 

223 # Use the standard SQL serializer 

224 data_json = self.record_to_json(record) 

225 metadata_json = json.dumps(record.metadata) if record.metadata else None 

226 

227 # Build update query 

228 query = f"UPDATE {self.table_name} SET data = ?, metadata = ? WHERE id = ?" 

229 params = [data_json, metadata_json, id] 

230 

231 cursor = self.conn.cursor() 

232 

233 try: 

234 cursor.execute(query, params) 

235 self.conn.commit() 

236 return cursor.rowcount > 0 

237 finally: 

238 cursor.close() 

239 

240 def delete(self, id: str) -> bool: 

241 """Delete a record by ID.""" 

242 self._check_connection() 

243 

244 query, params = self.query_builder.build_delete_query(id) 

245 cursor = self.conn.cursor() 

246 

247 try: 

248 cursor.execute(query, params) 

249 self.conn.commit() 

250 return cursor.rowcount > 0 

251 finally: 

252 cursor.close() 

253 

254 def exists(self, id: str) -> bool: 

255 """Check if a record exists.""" 

256 self._check_connection() 

257 

258 query, params = self.query_builder.build_exists_query(id) 

259 cursor = self.conn.cursor() 

260 

261 try: 

262 cursor.execute(query, params) 

263 result = cursor.fetchone() 

264 return result is not None 

265 finally: 

266 cursor.close() 

267 

268 def search(self, query: Query | ComplexQuery) -> list[Record]: 

269 """Search for records matching a query.""" 

270 self._check_connection() 

271 

272 # Handle ComplexQuery with native SQL support 

273 if isinstance(query, ComplexQuery): 

274 sql_query, params = self.query_builder.build_complex_search_query(query) 

275 else: 

276 sql_query, params = self.query_builder.build_search_query(query) 

277 

278 cursor = self.conn.cursor() 

279 

280 try: 

281 cursor.execute(sql_query, params) 

282 rows = cursor.fetchall() 

283 

284 records = [self.row_to_record(dict(row)) for row in rows] 

285 

286 # Apply field projection if specified 

287 if query.fields: 

288 records = [r.project(query.fields) for r in records] 

289 

290 return records 

291 finally: 

292 cursor.close() 

293 

294 def count(self, query: Query | None = None) -> int: 

295 """Count records matching a query.""" 

296 self._check_connection() 

297 

298 sql_query, params = self.query_builder.build_count_query(query) 

299 cursor = self.conn.cursor() 

300 

301 try: 

302 cursor.execute(sql_query, params) 

303 result = cursor.fetchone() 

304 return result[0] if result else 0 

305 finally: 

306 cursor.close() 

307 

308 def create_batch(self, records: list[Record]) -> list[str]: 

309 """Create multiple records efficiently using a single query. 

310  

311 Uses multi-value INSERT for better performance. 

312 """ 

313 if not records: 

314 return [] 

315 

316 self._check_connection() 

317 

318 # Use the shared batch create query builder 

319 query, params, ids = self.query_builder.build_batch_create_query(records) 

320 

321 cursor = self.conn.cursor() 

322 try: 

323 # Execute the batch insert in a transaction 

324 cursor.execute("BEGIN TRANSACTION") 

325 cursor.execute(query, params) 

326 self.conn.commit() 

327 

328 # Return the generated IDs 

329 return ids 

330 except Exception: 

331 self.conn.rollback() 

332 raise 

333 finally: 

334 cursor.close() 

335 

336 def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]: 

337 """Update multiple records efficiently using a single query. 

338  

339 Uses CASE expressions for batch updates, similar to PostgreSQL. 

340 """ 

341 if not updates: 

342 return [] 

343 

344 self._check_connection() 

345 

346 # Use the shared batch update query builder 

347 query, params = self.query_builder.build_batch_update_query(updates) 

348 

349 cursor = self.conn.cursor() 

350 try: 

351 # Execute the batch update in a transaction 

352 cursor.execute("BEGIN TRANSACTION") 

353 cursor.execute(query, params) 

354 self.conn.commit() 

355 

356 # Check which records were actually updated 

357 # SQLite doesn't have RETURNING, so we need to verify each ID 

358 update_ids = [record_id for record_id, _ in updates] 

359 placeholders = ", ".join(["?" for _ in update_ids]) 

360 check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})" 

361 cursor.execute(check_query, update_ids) 

362 existing_ids = {row[0] for row in cursor.fetchall()} 

363 

364 # Return results for each update 

365 results = [] 

366 for record_id, _ in updates: 

367 results.append(record_id in existing_ids) 

368 

369 return results 

370 except Exception: 

371 self.conn.rollback() 

372 raise 

373 finally: 

374 cursor.close() 

375 

376 def delete_batch(self, ids: list[str]) -> list[bool]: 

377 """Delete multiple records efficiently using a single query. 

378  

379 Uses single DELETE with IN clause for better performance. 

380 """ 

381 if not ids: 

382 return [] 

383 

384 self._check_connection() 

385 

386 # Check which IDs exist before deletion 

387 placeholders = ", ".join(["?" for _ in ids]) 

388 check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})" 

389 

390 cursor = self.conn.cursor() 

391 try: 

392 cursor.execute(check_query, ids) 

393 existing_ids = {row[0] for row in cursor.fetchall()} 

394 

395 # Use the shared batch delete query builder 

396 query, params = self.query_builder.build_batch_delete_query(ids) 

397 

398 # Execute the batch delete in a transaction 

399 cursor.execute("BEGIN TRANSACTION") 

400 cursor.execute(query, params) 

401 self.conn.commit() 

402 

403 # Return results based on which IDs existed 

404 results = [] 

405 for id in ids: 

406 results.append(id in existing_ids) 

407 

408 return results 

409 except Exception: 

410 self.conn.rollback() 

411 raise 

412 finally: 

413 cursor.close() 

414 

415 def _initialize(self) -> None: 

416 """Initialize method - connection setup handled in connect().""" 

417 pass 

418 

419 def _count_all(self) -> int: 

420 """Count all records in the database.""" 

421 self._check_connection() 

422 cursor = self.conn.cursor() 

423 try: 

424 cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}") 

425 result = cursor.fetchone() 

426 return result[0] if result else 0 

427 finally: 

428 cursor.close() 

429 

430 def stream_read( 

431 self, 

432 query: Query | None = None, 

433 config: StreamConfig | None = None 

434 ) -> Iterator[Record]: 

435 """Stream records from database.""" 

436 from ..streaming import StreamConfig 

437 

438 config = config or StreamConfig() 

439 query = query or Query() 

440 

441 # Use the existing stream method's logic but yield individual records 

442 offset = 0 

443 while True: 

444 # Fetch a batch 

445 query_copy = query.copy() 

446 query_copy.offset(offset).limit(config.batch_size) 

447 batch = self.search(query_copy) 

448 

449 if not batch: 

450 break 

451 

452 for record in batch: 

453 yield record 

454 

455 offset += len(batch) 

456 

457 # If we got less than batch_size, we're done 

458 if len(batch) < config.batch_size: 

459 break 

460 

461 def stream_write( 

462 self, 

463 records: Iterator[Record], 

464 config: StreamConfig | None = None 

465 ) -> StreamResult: 

466 """Stream records into database.""" 

467 from ..streaming import StreamConfig, StreamResult 

468 

469 config = config or StreamConfig() 

470 batch = [] 

471 total_written = 0 

472 start_time = time.time() 

473 

474 for record in records: 

475 batch.append(record) 

476 

477 if len(batch) >= config.batch_size: 

478 # Write the batch 

479 self.create_batch(batch) 

480 total_written += len(batch) 

481 batch = [] 

482 

483 # Write any remaining records 

484 if batch: 

485 self.create_batch(batch) 

486 total_written += len(batch) 

487 

488 elapsed = time.time() - start_time 

489 

490 return StreamResult( 

491 total_processed=total_written, 

492 successful=total_written, 

493 failed=0, 

494 duration=elapsed, 

495 total_batches=(total_written + config.batch_size - 1) // config.batch_size 

496 ) 

497 

498 # Vector support methods 

499 def has_vector_support(self) -> bool: 

500 """Check if this backend has vector support. 

501  

502 Returns: 

503 False - SQLite has no native vector support, uses Python-based similarity 

504 """ 

505 return False # No native vector support 

506 

507 def enable_vector_support(self) -> bool: 

508 """Enable vector support for this backend. 

509  

510 Returns: 

511 True - Vector support is always available (Python-based) 

512 """ 

513 # SQLite doesn't need any special setup for vector support 

514 # We handle vectors as JSON strings 

515 self.vector_enabled = True 

516 return True 

517 

518 def vector_search( 

519 self, 

520 query_vector: np.ndarray, 

521 field_name: str = "embedding", 

522 k: int = 10, 

523 filter: Query | None = None, 

524 metric: DistanceMetric | None = None, 

525 **kwargs 

526 ) -> list[VectorSearchResult]: 

527 """Perform vector similarity search using Python-based calculations. 

528  

529 Delegates to PythonVectorSearchMixin for the implementation. 

530  

531 Args: 

532 query_vector: Query vector 

533 field_name: Name of the vector field to search 

534 k: Number of results to return 

535 filter: Optional filter conditions 

536 metric: Distance metric (uses instance default if not specified) 

537 **kwargs: Additional arguments for compatibility 

538  

539 Returns: 

540 List of search results with scores 

541 """ 

542 self._check_connection() 

543 

544 # Delegate to the mixin's implementation 

545 return self.python_vector_search_sync( 

546 query_vector=query_vector, 

547 vector_field=field_name, 

548 k=k, 

549 filter=filter, 

550 metric=metric, 

551 **kwargs 

552 ) 

553 

554 def add_vectors( 

555 self, 

556 vectors: list[np.ndarray], 

557 ids: list[str] | None = None, 

558 metadata: list[dict[str, Any]] | None = None, 

559 field_name: str = "embedding", 

560 ) -> list[str]: 

561 """Add vectors to the database. 

562  

563 Args: 

564 vectors: List of vectors to add 

565 ids: Optional list of IDs 

566 metadata: Optional list of metadata dicts 

567 field_name: Name of the vector field 

568  

569 Returns: 

570 List of created record IDs 

571 """ 

572 from collections import OrderedDict 

573 

574 from ..fields import VectorField 

575 

576 # Generate IDs if not provided 

577 if ids is None: 

578 ids = [str(uuid.uuid4()) for _ in vectors] 

579 

580 # Create records with vector fields 

581 records = [] 

582 for i, vector in enumerate(vectors): 

583 # Create vector field 

584 vector_field = VectorField( 

585 name=field_name, 

586 value=vector, 

587 dimensions=len(vector) if isinstance(vector, (list, np.ndarray)) else None 

588 ) 

589 

590 # Create record 

591 record_metadata = metadata[i] if metadata and i < len(metadata) else {} 

592 record = Record( 

593 data=OrderedDict({field_name: vector_field}), 

594 metadata=record_metadata, 

595 storage_id=ids[i] 

596 ) 

597 records.append(record) 

598 

599 # Use batch create for efficiency 

600 return self.create_batch(records)