Coverage for src/dataknobs_data/backends/sqlite.py: 16%

307 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-26 15:45 -0700

1"""SQLite backend implementation with sync and async support.""" 

2 

3from __future__ import annotations 

4 

5import json 

6import logging 

7import sqlite3 

8import time 

9import uuid 

10from pathlib import Path 

11from typing import Any, TYPE_CHECKING 

12 

13import numpy as np 

14from dataknobs_config import ConfigurableBase 

15 

16from ..database import SyncDatabase 

17from ..query import Query 

18from ..query_logic import ComplexQuery 

19from ..records import Record 

20from ..vector.bulk_embed_mixin import BulkEmbedMixin 

21from ..vector.mixins import VectorOperationsMixin 

22from ..vector.python_vector_search import PythonVectorSearchMixin 

23from .sql_base import SQLQueryBuilder, SQLRecordSerializer, SQLTableManager 

24from .sqlite_mixins import SQLiteVectorSupport 

25from .vector_config_mixin import VectorConfigMixin 

26 

27if TYPE_CHECKING: 

28 from collections.abc import Iterator 

29 from ..streaming import StreamConfig, StreamResult 

30 from ..vector.types import DistanceMetric, VectorSearchResult 

31 

32 

33logger = logging.getLogger(__name__) 

34 

35 

class SyncSQLiteDatabase(  # type: ignore[misc]
    SyncDatabase,
    ConfigurableBase,
    VectorConfigMixin,
    PythonVectorSearchMixin,  # Provides python_vector_search_sync
    BulkEmbedMixin,  # Must come before VectorOperationsMixin to override bulk_embed_and_store
    VectorOperationsMixin,
    SQLiteVectorSupport,
    SQLRecordSerializer,  # Use the standard SQL serializer
):
    """Synchronous SQLite database backend."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Set up the backend from an optional configuration mapping.

        Args:
            config: Configuration with the following optional keys:
                - path: Database file path (default: ":memory:")
                - table: Table name (default: "records")
                - timeout: Connection timeout in seconds (default: 5.0)
                - check_same_thread: Allow sharing across threads (default: False)
                - journal_mode: Journal mode (WAL, DELETE, etc.) (default: None)
                - synchronous: Synchronous mode (NORMAL, FULL, OFF) (default: None)
                - vector_enabled: Enable vector support (default: False)
                - vector_metric: Distance metric for vector search (default: "cosine")
        """
        super().__init__(config)
        SQLiteVectorSupport.__init__(self)

        # Vector-related settings are parsed by the shared mixin.
        self._parse_vector_config(config)

        cfg = self.config
        self.db_path = cfg.get("path", ":memory:")
        self.table_name = cfg.get("table", "records")
        self.timeout = cfg.get("timeout", 5.0)
        self.check_same_thread = cfg.get("check_same_thread", False)
        self.journal_mode = cfg.get("journal_mode")
        self.synchronous = cfg.get("synchronous")

        # Shared helpers for SQL generation and table DDL.
        self.query_builder = SQLQueryBuilder(self.table_name, dialect="sqlite", param_style="qmark")
        self.table_manager = SQLTableManager(self.table_name, dialect="sqlite")

        # Connection is established lazily in connect().
        self.conn: sqlite3.Connection | None = None
        self._connected = False

    @classmethod
    def from_config(cls, config: dict) -> SyncSQLiteDatabase:
        """Alternate constructor: build an instance from a config dict."""
        return cls(config)

85 

86 def connect(self) -> None: 

87 """Connect to the SQLite database.""" 

88 if self._connected: 

89 return 

90 

91 # Create directory if needed for file-based database 

92 if self.db_path != ":memory:": 

93 db_file = Path(self.db_path) 

94 db_file.parent.mkdir(parents=True, exist_ok=True) 

95 

96 # Connect to database 

97 self.conn = sqlite3.connect( 

98 self.db_path, 

99 timeout=self.timeout, 

100 check_same_thread=self.check_same_thread 

101 ) 

102 

103 # Enable row factory for dict-like access 

104 self.conn.row_factory = sqlite3.Row 

105 

106 # Configure SQLite for better performance 

107 self._configure_sqlite() 

108 

109 # Create table if it doesn't exist 

110 self._ensure_table() 

111 

112 self._connected = True 

113 logger.info(f"Connected to SQLite database: {self.db_path}") 

114 

115 def close(self) -> None: 

116 """Close the database connection.""" 

117 if self.conn: 

118 self.conn.close() 

119 self.conn = None 

120 self._connected = False 

121 logger.info(f"Disconnected from SQLite database: {self.db_path}") 

122 

123 def _configure_sqlite(self) -> None: 

124 """Configure SQLite settings for performance.""" 

125 if not self.conn: 

126 return 

127 

128 cursor = self.conn.cursor() 

129 

130 # Set journal mode if specified 

131 if self.journal_mode: 

132 cursor.execute(f"PRAGMA journal_mode = {self.journal_mode}") 

133 logger.debug(f"Set journal_mode to {self.journal_mode}") 

134 

135 # Set synchronous mode if specified 

136 if self.synchronous: 

137 cursor.execute(f"PRAGMA synchronous = {self.synchronous}") 

138 logger.debug(f"Set synchronous to {self.synchronous}") 

139 

140 # Enable foreign keys 

141 cursor.execute("PRAGMA foreign_keys = ON") 

142 

143 # Optimize for performance 

144 cursor.execute("PRAGMA temp_store = MEMORY") 

145 cursor.execute("PRAGMA mmap_size = 30000000000") 

146 

147 cursor.close() 

148 

149 def _ensure_table(self) -> None: 

150 """Ensure the table exists.""" 

151 if not self.conn: 

152 raise RuntimeError("Database not connected. Call connect() first.") 

153 

154 cursor = self.conn.cursor() 

155 cursor.executescript(self.table_manager.get_create_table_sql()) 

156 self.conn.commit() 

157 cursor.close() 

158 

159 def _check_connection(self) -> None: 

160 """Check if database is connected.""" 

161 if not self._connected or not self.conn: 

162 raise RuntimeError("Database not connected. Call connect() first.") 

163 

164 def create(self, record: Record) -> str: 

165 """Create a new record.""" 

166 self._check_connection() 

167 

168 # Update vector dimensions tracking if needed 

169 if self._has_vector_fields(record): 

170 self._update_vector_dimensions(record) 

171 

172 # Use centralized method to prepare record 

173 record, storage_id = self._prepare_record_for_storage(record) 

174 

175 # Use the standard SQL serializer 

176 data_json = self.record_to_json(record) 

177 metadata_json = json.dumps(record.metadata) if record.metadata else None 

178 

179 # Build insert query for SQLite's standard table structure 

180 query = f"INSERT INTO {self.table_name} (id, data, metadata) VALUES (?, ?, ?)" 

181 params = [storage_id, data_json, metadata_json] 

182 

183 cursor = self.conn.cursor() 

184 

185 try: 

186 cursor.execute(query, params) 

187 self.conn.commit() 

188 return storage_id 

189 except sqlite3.IntegrityError as e: 

190 self.conn.rollback() 

191 raise ValueError(f"Record with ID {record.id} already exists") from e 

192 finally: 

193 cursor.close() 

194 

195 def read(self, id: str) -> Record | None: 

196 """Read a record by ID.""" 

197 self._check_connection() 

198 

199 query, params = self.query_builder.build_read_query(id) 

200 cursor = self.conn.cursor() 

201 

202 try: 

203 cursor.execute(query, params) 

204 row = cursor.fetchone() 

205 

206 if row: 

207 # Use the standard SQL serializer 

208 record = self.row_to_record(dict(row)) 

209 # Use centralized method to prepare record 

210 return self._prepare_record_from_storage(record, id) 

211 return None 

212 finally: 

213 cursor.close() 

214 

215 def update(self, id: str, record: Record) -> bool: 

216 """Update an existing record. 

217 

218 Args: 

219 id: The record ID to update 

220 record: The record data to update with 

221 

222 Returns: 

223 True if the record was updated, False if no record with the given ID exists 

224 """ 

225 self._check_connection() 

226 

227 # Update vector dimensions tracking if needed 

228 if self._has_vector_fields(record): 

229 self._update_vector_dimensions(record) 

230 

231 # Use the standard SQL serializer 

232 data_json = self.record_to_json(record) 

233 metadata_json = json.dumps(record.metadata) if record.metadata else None 

234 

235 # Build update query 

236 query = f"UPDATE {self.table_name} SET data = ?, metadata = ? WHERE id = ?" 

237 params = [data_json, metadata_json, id] 

238 

239 cursor = self.conn.cursor() 

240 

241 try: 

242 cursor.execute(query, params) 

243 self.conn.commit() 

244 rows_affected = cursor.rowcount 

245 

246 if rows_affected == 0: 

247 logger.warning(f"Update affected 0 rows for id={id}. Record may not exist.") 

248 

249 return rows_affected > 0 

250 finally: 

251 cursor.close() 

252 

253 def delete(self, id: str) -> bool: 

254 """Delete a record by ID.""" 

255 self._check_connection() 

256 

257 query, params = self.query_builder.build_delete_query(id) 

258 cursor = self.conn.cursor() 

259 

260 try: 

261 cursor.execute(query, params) 

262 self.conn.commit() 

263 return cursor.rowcount > 0 

264 finally: 

265 cursor.close() 

266 

267 def exists(self, id: str) -> bool: 

268 """Check if a record exists.""" 

269 self._check_connection() 

270 

271 query, params = self.query_builder.build_exists_query(id) 

272 cursor = self.conn.cursor() 

273 

274 try: 

275 cursor.execute(query, params) 

276 result = cursor.fetchone() 

277 return result is not None 

278 finally: 

279 cursor.close() 

280 

281 def clear(self) -> int: 

282 """Clear all records from the database.""" 

283 self._check_connection() 

284 

285 cursor = self.conn.cursor() 

286 try: 

287 # Get count before clearing 

288 cursor.execute(f"SELECT COUNT(*) FROM {self.table_manager.table_name}") 

289 count = cursor.fetchone()[0] 

290 

291 # Clear the table 

292 cursor.execute(f"DELETE FROM {self.table_manager.table_name}") 

293 self.conn.commit() 

294 

295 return count 

296 finally: 

297 cursor.close() 

298 

299 def search(self, query: Query | ComplexQuery) -> list[Record]: 

300 """Search for records matching a query.""" 

301 self._check_connection() 

302 

303 # Handle ComplexQuery with native SQL support 

304 if isinstance(query, ComplexQuery): 

305 sql_query, params = self.query_builder.build_complex_search_query(query) 

306 else: 

307 sql_query, params = self.query_builder.build_search_query(query) 

308 

309 cursor = self.conn.cursor() 

310 

311 try: 

312 cursor.execute(sql_query, params) 

313 rows = cursor.fetchall() 

314 

315 records = [] 

316 for row in rows: 

317 row_dict = dict(row) 

318 record = self.row_to_record(row_dict) 

319 

320 # Populate storage_id from database ID 

321 record.storage_id = str(row_dict['id']) 

322 

323 records.append(record) 

324 

325 # Apply field projection if specified 

326 if query.fields: 

327 records = [r.project(query.fields) for r in records] 

328 

329 return records 

330 finally: 

331 cursor.close() 

332 

333 def count(self, query: Query | None = None) -> int: 

334 """Count records matching a query.""" 

335 self._check_connection() 

336 

337 sql_query, params = self.query_builder.build_count_query(query) 

338 cursor = self.conn.cursor() 

339 

340 try: 

341 cursor.execute(sql_query, params) 

342 result = cursor.fetchone() 

343 return result[0] if result else 0 

344 finally: 

345 cursor.close() 

346 

347 def create_batch(self, records: list[Record]) -> list[str]: 

348 """Create multiple records efficiently using a single query. 

349  

350 Uses multi-value INSERT for better performance. 

351 """ 

352 if not records: 

353 return [] 

354 

355 self._check_connection() 

356 

357 # Use the shared batch create query builder 

358 query, params, ids = self.query_builder.build_batch_create_query(records) 

359 

360 cursor = self.conn.cursor() 

361 try: 

362 # Execute the batch insert in a transaction 

363 cursor.execute("BEGIN TRANSACTION") 

364 cursor.execute(query, params) 

365 self.conn.commit() 

366 

367 # Return the generated IDs 

368 return ids 

369 except Exception: 

370 self.conn.rollback() 

371 raise 

372 finally: 

373 cursor.close() 

374 

375 def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]: 

376 """Update multiple records efficiently using a single query. 

377  

378 Uses CASE expressions for batch updates, similar to PostgreSQL. 

379 """ 

380 if not updates: 

381 return [] 

382 

383 self._check_connection() 

384 

385 # Use the shared batch update query builder 

386 query, params = self.query_builder.build_batch_update_query(updates) 

387 

388 cursor = self.conn.cursor() 

389 try: 

390 # Execute the batch update in a transaction 

391 cursor.execute("BEGIN TRANSACTION") 

392 cursor.execute(query, params) 

393 self.conn.commit() 

394 

395 # Check which records were actually updated 

396 # SQLite doesn't have RETURNING, so we need to verify each ID 

397 update_ids = [record_id for record_id, _ in updates] 

398 placeholders = ", ".join(["?" for _ in update_ids]) 

399 check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})" 

400 cursor.execute(check_query, update_ids) 

401 existing_ids = {row[0] for row in cursor.fetchall()} 

402 

403 # Return results for each update 

404 results = [] 

405 for record_id, _ in updates: 

406 results.append(record_id in existing_ids) 

407 

408 return results 

409 except Exception: 

410 self.conn.rollback() 

411 raise 

412 finally: 

413 cursor.close() 

414 

415 def delete_batch(self, ids: list[str]) -> list[bool]: 

416 """Delete multiple records efficiently using a single query. 

417  

418 Uses single DELETE with IN clause for better performance. 

419 """ 

420 if not ids: 

421 return [] 

422 

423 self._check_connection() 

424 

425 # Check which IDs exist before deletion 

426 placeholders = ", ".join(["?" for _ in ids]) 

427 check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})" 

428 

429 cursor = self.conn.cursor() 

430 try: 

431 cursor.execute(check_query, ids) 

432 existing_ids = {row[0] for row in cursor.fetchall()} 

433 

434 # Use the shared batch delete query builder 

435 query, params = self.query_builder.build_batch_delete_query(ids) 

436 

437 # Execute the batch delete in a transaction 

438 cursor.execute("BEGIN TRANSACTION") 

439 cursor.execute(query, params) 

440 self.conn.commit() 

441 

442 # Return results based on which IDs existed 

443 results = [] 

444 for id in ids: 

445 results.append(id in existing_ids) 

446 

447 return results 

448 except Exception: 

449 self.conn.rollback() 

450 raise 

451 finally: 

452 cursor.close() 

453 

454 def _initialize(self) -> None: 

455 """Initialize method - connection setup handled in connect().""" 

456 pass 

457 

458 def _count_all(self) -> int: 

459 """Count all records in the database.""" 

460 self._check_connection() 

461 cursor = self.conn.cursor() 

462 try: 

463 cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}") 

464 result = cursor.fetchone() 

465 return result[0] if result else 0 

466 finally: 

467 cursor.close() 

468 

469 def stream_read( 

470 self, 

471 query: Query | None = None, 

472 config: StreamConfig | None = None 

473 ) -> Iterator[Record]: 

474 """Stream records from database.""" 

475 from ..streaming import StreamConfig 

476 

477 config = config or StreamConfig() 

478 query = query or Query() 

479 

480 # Use the existing stream method's logic but yield individual records 

481 offset = 0 

482 while True: 

483 # Fetch a batch 

484 query_copy = query.copy() 

485 query_copy.offset(offset).limit(config.batch_size) 

486 batch = self.search(query_copy) 

487 

488 if not batch: 

489 break 

490 

491 for record in batch: 

492 yield record 

493 

494 offset += len(batch) 

495 

496 # If we got less than batch_size, we're done 

497 if len(batch) < config.batch_size: 

498 break 

499 

500 def stream_write( 

501 self, 

502 records: Iterator[Record], 

503 config: StreamConfig | None = None 

504 ) -> StreamResult: 

505 """Stream records into database.""" 

506 from ..streaming import StreamConfig, StreamResult 

507 

508 config = config or StreamConfig() 

509 batch = [] 

510 total_written = 0 

511 start_time = time.time() 

512 

513 for record in records: 

514 batch.append(record) 

515 

516 if len(batch) >= config.batch_size: 

517 # Write the batch 

518 self.create_batch(batch) 

519 total_written += len(batch) 

520 batch = [] 

521 

522 # Write any remaining records 

523 if batch: 

524 self.create_batch(batch) 

525 total_written += len(batch) 

526 

527 elapsed = time.time() - start_time 

528 

529 return StreamResult( 

530 total_processed=total_written, 

531 successful=total_written, 

532 failed=0, 

533 duration=elapsed, 

534 total_batches=(total_written + config.batch_size - 1) // config.batch_size 

535 ) 

536 

537 # Vector support methods 

538 def has_vector_support(self) -> bool: 

539 """Check if this backend has vector support. 

540  

541 Returns: 

542 False - SQLite has no native vector support, uses Python-based similarity 

543 """ 

544 return False # No native vector support 

545 

546 def enable_vector_support(self) -> bool: 

547 """Enable vector support for this backend. 

548  

549 Returns: 

550 True - Vector support is always available (Python-based) 

551 """ 

552 # SQLite doesn't need any special setup for vector support 

553 # We handle vectors as JSON strings 

554 self.vector_enabled = True 

555 return True 

556 

557 def vector_search( 

558 self, 

559 query_vector: np.ndarray, 

560 field_name: str = "embedding", 

561 k: int = 10, 

562 filter: Query | None = None, 

563 metric: DistanceMetric | None = None, 

564 **kwargs 

565 ) -> list[VectorSearchResult]: 

566 """Perform vector similarity search using Python-based calculations. 

567  

568 Delegates to PythonVectorSearchMixin for the implementation. 

569  

570 Args: 

571 query_vector: Query vector 

572 field_name: Name of the vector field to search 

573 k: Number of results to return 

574 filter: Optional filter conditions 

575 metric: Distance metric (uses instance default if not specified) 

576 **kwargs: Additional arguments for compatibility 

577  

578 Returns: 

579 List of search results with scores 

580 """ 

581 self._check_connection() 

582 

583 # Delegate to the mixin's implementation 

584 return self.python_vector_search_sync( 

585 query_vector=query_vector, 

586 vector_field=field_name, 

587 k=k, 

588 filter=filter, 

589 metric=metric, 

590 **kwargs 

591 ) 

592 

593 def add_vectors( 

594 self, 

595 vectors: list[np.ndarray], 

596 ids: list[str] | None = None, 

597 metadata: list[dict[str, Any]] | None = None, 

598 field_name: str = "embedding", 

599 ) -> list[str]: 

600 """Add vectors to the database. 

601  

602 Args: 

603 vectors: List of vectors to add 

604 ids: Optional list of IDs 

605 metadata: Optional list of metadata dicts 

606 field_name: Name of the vector field 

607  

608 Returns: 

609 List of created record IDs 

610 """ 

611 from collections import OrderedDict 

612 

613 from ..fields import VectorField 

614 

615 # Generate IDs if not provided 

616 if ids is None: 

617 ids = [str(uuid.uuid4()) for _ in vectors] 

618 

619 # Create records with vector fields 

620 records = [] 

621 for i, vector in enumerate(vectors): 

622 # Create vector field 

623 vector_field = VectorField( 

624 name=field_name, 

625 value=vector, 

626 dimensions=len(vector) if isinstance(vector, (list, np.ndarray)) else None 

627 ) 

628 

629 # Create record 

630 record_metadata = metadata[i] if metadata and i < len(metadata) else {} 

631 record = Record( 

632 data=OrderedDict({field_name: vector_field}), 

633 metadata=record_metadata, 

634 storage_id=ids[i] 

635 ) 

636 records.append(record) 

637 

638 # Use batch create for efficiency 

639 return self.create_batch(records)