Coverage for src/dataknobs_data/backends/sqlite_async.py: 18%

237 statements

coverage.py v7.13.0, created at 2025-12-26 15:45 -0700

1"""Async SQLite backend implementation using aiosqlite.""" 

2 

3from __future__ import annotations 

4 

5import logging 

6from pathlib import Path 

7from typing import Any, TYPE_CHECKING 

8 

9import aiosqlite 

10from dataknobs_config import ConfigurableBase 

11 

12from ..database import AsyncDatabase 

13from ..pooling import ConnectionPoolManager 

14from ..query import Query 

15from ..query_logic import ComplexQuery 

16from ..vector import VectorOperationsMixin 

17from ..vector.bulk_embed_mixin import BulkEmbedMixin 

18from ..vector.python_vector_search import PythonVectorSearchMixin 

19from .sql_base import SQLQueryBuilder, SQLTableManager 

20from .sqlite_mixins import SQLiteVectorSupport 

21from .vector_config_mixin import VectorConfigMixin 

22 

23if TYPE_CHECKING: 

24 from collections.abc import AsyncIterator 

25 from ..records import Record 

26 from ..streaming import StreamConfig, StreamResult 

27 

28 

29logger = logging.getLogger(__name__) 

30 

31# Global pool manager for SQLite connections 

32_pool_manager = ConnectionPoolManager() 

33 

34 

35class AsyncSQLiteDatabase( # type: ignore[misc] 

36 AsyncDatabase, 

37 ConfigurableBase, 

38 VectorConfigMixin, 

39 SQLiteVectorSupport, 

40 PythonVectorSearchMixin, # Provides python_vector_search_async 

41 BulkEmbedMixin, # Must come before VectorOperationsMixin to override bulk_embed_and_store 

42 VectorOperationsMixin 

43): 

44 """Asynchronous SQLite database backend using aiosqlite.""" 

45 

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize the async SQLite database.

        Args:
            config: Configuration with the following optional keys:
                - path: Database file path (default: ":memory:")
                - table: Table name (default: "records")
                - timeout: Connection timeout in seconds (default: 5.0)
                - journal_mode: Journal mode such as WAL or DELETE (default: WAL for file-based databases)
                - synchronous: Synchronous mode: NORMAL, FULL, or OFF (default: NORMAL)
                - pool_size: Number of connections in the pool (default: 5)
        """
        super().__init__(config)
        config = config or {}
        self.db_path = config.get("path", ":memory:")
        self.table_name = config.get("table", "records")
        self.timeout = config.get("timeout", 5.0)
        self.journal_mode = config.get("journal_mode", "WAL" if self.db_path != ":memory:" else None)
        self.synchronous = config.get("synchronous", "NORMAL")
        self.pool_size = config.get("pool_size", 5)

        # Start with the standard query builder; it is customized after the mixins are initialized
        self.query_builder = SQLQueryBuilder(self.table_name, dialect="sqlite", param_style="qmark")
        self.table_manager = SQLTableManager(self.table_name, dialect="sqlite")

        self.db: aiosqlite.Connection | None = None
        self._connected = False

        # Initialize vector support
        self._parse_vector_config(config)
        self._init_vector_state()
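    # Configuration sketch (illustrative, not from the source): the keys mirror the
    # options documented above; the file path shown is a hypothetical example.
    #
    #     db = AsyncSQLiteDatabase({
    #         "path": "/tmp/records.db",
    #         "table": "records",
    #         "timeout": 5.0,
    #         "journal_mode": "WAL",
    #         "synchronous": "NORMAL",
    #         "pool_size": 5,
    #     })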

    @classmethod
    def from_config(cls, config: dict) -> AsyncSQLiteDatabase:
        """Create from config dictionary."""
        return cls(config)

    async def connect(self) -> None:
        """Connect to the SQLite database."""
        if self._connected:
            return

        # Create directory if needed for file-based database
        if self.db_path != ":memory:":
            db_file = Path(self.db_path)
            db_file.parent.mkdir(parents=True, exist_ok=True)

        # Connect to database
        self.db = await aiosqlite.connect(
            self.db_path,
            timeout=self.timeout
        )

        # Enable row factory for dict-like access
        self.db.row_factory = aiosqlite.Row

        # Configure SQLite for better performance
        await self._configure_sqlite()

        # Create table if it doesn't exist
        await self._ensure_table()

        self._connected = True
        logger.info(f"Connected to async SQLite database: {self.db_path}")

    async def close(self) -> None:
        """Close the database connection."""
        if self.db:
            await self.db.close()
            self.db = None
            self._connected = False
            logger.info(f"Disconnected from async SQLite database: {self.db_path}")
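    # Lifecycle sketch (illustrative): connect() must be awaited before any other
    # call, and close() releases the aiosqlite connection; a try/finally keeps the
    # pairing explicit. The body of the try block is hypothetical caller code.
    #
    #     db = AsyncSQLiteDatabase({"path": ":memory:"})
    #     await db.connect()
    #     try:
    #         ...  # CRUD / search calls go here
    #     finally:
    #         await db.close()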

    async def _configure_sqlite(self) -> None:
        """Configure SQLite settings for performance."""
        if not self.db:
            return

        # Set journal mode if specified
        if self.journal_mode:
            await self.db.execute(f"PRAGMA journal_mode = {self.journal_mode}")
            logger.debug(f"Set journal_mode to {self.journal_mode}")

        # Set synchronous mode
        await self.db.execute(f"PRAGMA synchronous = {self.synchronous}")
        logger.debug(f"Set synchronous to {self.synchronous}")

        # Enable foreign keys
        await self.db.execute("PRAGMA foreign_keys = ON")

        # Keep temporary tables in memory and enable memory-mapped I/O for reads
        await self.db.execute("PRAGMA temp_store = MEMORY")
        await self.db.execute("PRAGMA mmap_size = 30000000000")

        await self.db.commit()

    async def _ensure_table(self) -> None:
        """Ensure the table exists."""
        if not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

        await self.db.executescript(self.table_manager.get_create_table_sql())
        await self.db.commit()

    def _check_connection(self) -> None:
        """Check if database is connected."""
        if not self._connected or not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

    async def create(self, record: Record) -> str:
        """Create a new record."""
        self._check_connection()

        query, params = self.query_builder.build_create_query(record)

        try:
            await self.db.execute(query, params)
            await self.db.commit()

            # Rather than relying on RETURNING (only available in SQLite 3.35+),
            # return the ID we generated for the insert
            record_id = params[0]  # ID is the first parameter
            return record_id
        except aiosqlite.IntegrityError as e:
            await self.db.rollback()
            raise ValueError(f"Record with ID {params[0]} already exists") from e
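    # CRUD round-trip sketch (illustrative): assumes a connected instance and that
    # Record can be built from a plain dict, which this module does not show.
    #
    #     record_id = await db.create(Record({"name": "alice"}))
    #     fetched = await db.read(record_id)       # Record | None
    #     updated = await db.update(record_id, fetched)   # False if the ID is unknown
    #     removed = await db.delete(record_id)     # True if a row was deleted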

    async def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        self._check_connection()

        query, params = self.query_builder.build_read_query(id)

        async with self.db.execute(query, params) as cursor:
            row = await cursor.fetchone()

        if row:
            return SQLQueryBuilder.row_to_record(dict(row))
        return None

    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record.

        Args:
            id: The record ID to update
            record: The record data to update with

        Returns:
            True if the record was updated, False if no record with the given ID exists
        """
        self._check_connection()

        query, params = self.query_builder.build_update_query(id, record)

        cursor = await self.db.execute(query, params)
        await self.db.commit()
        rows_affected = cursor.rowcount

        if rows_affected == 0:
            logger.warning(f"Update affected 0 rows for id={id}. Record may not exist.")

        return rows_affected > 0

    async def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        self._check_connection()

        query, params = self.query_builder.build_delete_query(id)

        cursor = await self.db.execute(query, params)
        await self.db.commit()
        return cursor.rowcount > 0

    async def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()

        query, params = self.query_builder.build_exists_query(id)

        async with self.db.execute(query, params) as cursor:
            result = await cursor.fetchone()
            return result is not None

    async def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching a query."""
        self._check_connection()

        # Handle ComplexQuery with native SQL support
        if isinstance(query, ComplexQuery):
            sql_query, params = self.query_builder.build_complex_search_query(query)
        else:
            sql_query, params = self.query_builder.build_search_query(query)

        async with self.db.execute(sql_query, params) as cursor:
            rows = await cursor.fetchall()

        records = []
        for row in rows:
            row_dict = dict(row)
            record = SQLQueryBuilder.row_to_record(row_dict)

            # Populate storage_id from database ID
            record.storage_id = str(row_dict['id'])

            records.append(record)

        # Apply field projection if specified
        if query.fields:
            records = [r.project(query.fields) for r in records]

        return records
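    # Search sketch (illustrative): only copy(), offset(), limit(), and the `fields`
    # attribute of Query are exercised elsewhere in this module, so the filter_by()
    # call below is a hypothetical stand-in for whatever filtering API Query exposes.
    #
    #     results = await db.search(Query().filter_by(status="active").limit(20))
    #     total = await db.count()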

    async def count(self, query: Query | None = None) -> int:
        """Count records matching a query."""
        self._check_connection()

        sql_query, params = self.query_builder.build_count_query(query)

        async with self.db.execute(sql_query, params) as cursor:
            result = await cursor.fetchone()
            return result[0] if result else 0

    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently using a single query.

        Uses multi-value INSERT for better performance.
        """
        if not records:
            return []

        self._check_connection()

        # Use the shared batch create query builder
        query, params, ids = self.query_builder.build_batch_create_query(records)

        # Execute the batch insert in a transaction
        await self.db.execute("BEGIN TRANSACTION")

        try:
            await self.db.execute(query, params)
            await self.db.commit()

            # Return the generated IDs
            return ids
        except Exception:
            await self.db.rollback()
            raise
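    # Batch-create sketch (illustrative): the whole list is written with one
    # multi-value INSERT, so this is the preferred path for bulk loads. Building a
    # Record from a dict is an assumption, not something this module defines.
    #
    #     ids = await db.create_batch([Record({"n": i}) for i in range(1000)])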

    async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using a single query.

        Uses CASE expressions for batch updates, similar to PostgreSQL.
        """
        if not updates:
            return []

        self._check_connection()

        # Use the shared batch update query builder
        query, params = self.query_builder.build_batch_update_query(updates)

        # Execute the batch update in a transaction
        await self.db.execute("BEGIN TRANSACTION")

        try:
            await self.db.execute(query, params)
            await self.db.commit()

            # Determine which records were actually updated. Rather than relying on
            # RETURNING (only available in SQLite 3.35+), check which requested IDs exist.
            update_ids = [record_id for record_id, _ in updates]
            placeholders = ", ".join(["?" for _ in update_ids])
            check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})"

            async with self.db.execute(check_query, update_ids) as check_cursor:
                rows = await check_cursor.fetchall()
                existing_ids = {row[0] for row in rows}

            # Return a result for each requested update
            results = []
            for record_id, _ in updates:
                results.append(record_id in existing_ids)

            return results
        except Exception:
            await self.db.rollback()
            raise
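    # Batch-update sketch (illustrative): each entry pairs a record ID with its
    # replacement data, and the returned booleans follow the input order, reporting
    # whether the ID exists after the statement runs. The names below are made up.
    #
    #     outcomes = await db.update_batch([(id_a, rec_a), (id_b, rec_b)])
    #     # e.g. [True, False] if id_b was never created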

    async def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently using a single query.

        Uses single DELETE with IN clause for better performance.
        """
        if not ids:
            return []

        self._check_connection()

        # Check which IDs exist before deletion
        placeholders = ", ".join(["?" for _ in ids])
        check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})"

        async with self.db.execute(check_query, ids) as cursor:
            rows = await cursor.fetchall()
            existing_ids = {row[0] for row in rows}

        # Use the shared batch delete query builder
        query, params = self.query_builder.build_batch_delete_query(ids)

        # Execute the batch delete in a transaction
        await self.db.execute("BEGIN TRANSACTION")

        try:
            await self.db.execute(query, params)
            await self.db.commit()

            # Return results based on which IDs existed
            results = []
            for id in ids:
                results.append(id in existing_ids)

            return results
        except Exception:
            await self.db.rollback()
            raise

    def _initialize(self) -> None:
        """Initialize method - connection setup handled in connect()."""
        pass

    async def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()

        async with self.db.execute(f"SELECT COUNT(*) FROM {self.table_name}") as cursor:
            result = await cursor.fetchone()
            return result[0] if result else 0

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records from the database."""
        from ..streaming import StreamConfig

        config = config or StreamConfig()
        query = query or Query()

        # Page through results in batches and yield individual records
        offset = 0
        while True:
            # Fetch a batch
            query_copy = query.copy()
            query_copy.offset(offset).limit(config.batch_size)
            batch = await self.search(query_copy)

            if not batch:
                break

            for record in batch:
                yield record

            offset += len(batch)

            # If we got less than batch_size, we're done
            if len(batch) < config.batch_size:
                break
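    # Streaming-read sketch (illustrative): records are fetched in batches of
    # config.batch_size and yielded one at a time. Passing batch_size to the
    # StreamConfig constructor and the process() callback are both assumptions.
    #
    #     async for record in db.stream_read(config=StreamConfig(batch_size=500)):
    #         process(record)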

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into the database."""
        import time

        from ..streaming import StreamConfig, StreamResult

        config = config or StreamConfig()
        batch = []
        total_written = 0
        start_time = time.time()

        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write the batch
                await self.create_batch(batch)
                total_written += len(batch)
                batch = []

        # Write any remaining records
        if batch:
            await self.create_batch(batch)
            total_written += len(batch)

        elapsed = time.time() - start_time

        return StreamResult(
            total_processed=total_written,
            successful=total_written,
            failed=0,
            duration=elapsed,
            total_batches=(total_written + config.batch_size - 1) // config.batch_size
        )
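    # Streaming-write sketch (illustrative): any async iterator of Records can be
    # drained into the table; generate_records() is a hypothetical source and the
    # dict-based Record construction is assumed.
    #
    #     async def generate_records():
    #         for i in range(10_000):
    #             yield Record({"n": i})
    #
    #     result = await db.stream_write(generate_records())
    #     assert result.successful == result.total_processed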

    async def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform async vector similarity search using Python-based calculations.

        Delegates to PythonVectorSearchMixin for the implementation.

        Args:
            query_vector: Query vector
            vector_field: Name of the vector field to search
            k: Number of results to return
            filter: Optional filter conditions
            metric: Distance metric (uses instance default if not specified)
            **kwargs: Additional arguments for compatibility

        Returns:
            List of VectorSearchResult objects with scores
        """
        self._check_connection()

        # Delegate to the mixin's implementation
        return await self.python_vector_search_async(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )
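# Vector-search sketch (illustrative, module-level note rather than source code):
# similarity scoring is delegated to PythonVectorSearchMixin, and results come back
# as VectorSearchResult objects. The query vector values and the result attribute
# names used below are assumptions.
#
#     hits = await db.vector_search([0.1, 0.2, 0.3], vector_field="embedding", k=5)
#     for hit in hits:
#         print(hit.record, hit.score)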