Coverage for src/dataknobs_data/backends/sqlite_async.py: 18%

229 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-31 15:06 -0600

1"""Async SQLite backend implementation using aiosqlite.""" 

2 

3from __future__ import annotations 

4 

5import logging 

6from pathlib import Path 

7from typing import Any, TYPE_CHECKING 

8 

9import aiosqlite 

10from dataknobs_config import ConfigurableBase 

11 

12from ..database import AsyncDatabase 

13from ..pooling import ConnectionPoolManager 

14from ..query import Query 

15from ..query_logic import ComplexQuery 

16from ..vector import VectorOperationsMixin 

17from ..vector.bulk_embed_mixin import BulkEmbedMixin 

18from ..vector.python_vector_search import PythonVectorSearchMixin 

19from .sql_base import SQLQueryBuilder, SQLTableManager 

20from .sqlite_mixins import SQLiteVectorSupport 

21from .vector_config_mixin import VectorConfigMixin 

22 

23if TYPE_CHECKING: 

24 from collections.abc import AsyncIterator 

25 from ..records import Record 

26 from ..streaming import StreamConfig, StreamResult 

27 

28 

# Module-level logger, named after this module per stdlib logging convention.
logger = logging.getLogger(__name__)

# Global pool manager for SQLite connections
# NOTE(review): _pool_manager is not referenced anywhere in this module's
# visible code — presumably consumed by external callers or a later revision;
# confirm before removing.
_pool_manager = ConnectionPoolManager()

33 

34 

class AsyncSQLiteDatabase(  # type: ignore[misc]
    AsyncDatabase,
    ConfigurableBase,
    VectorConfigMixin,
    SQLiteVectorSupport,
    PythonVectorSearchMixin,  # Provides python_vector_search_async
    BulkEmbedMixin,  # Must come before VectorOperationsMixin to override bulk_embed_and_store
    VectorOperationsMixin
):
    """Asynchronous SQLite database backend using aiosqlite.

    Provides async CRUD, batch, streaming, and vector-search operations over a
    single table (``self.table_name``). SQL text is produced by
    ``SQLQueryBuilder``/``SQLTableManager`` (qmark parameter style); this class
    only executes it against an ``aiosqlite.Connection``.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async SQLite database.

        Args:
            config: Configuration with the following optional keys:
                - path: Database file path (default: ":memory:")
                - table: Table name (default: "records")
                - timeout: Connection timeout in seconds (default: 5.0)
                - journal_mode: Journal mode (WAL, DELETE, etc.) (default: WAL for file-based)
                - synchronous: Synchronous mode (NORMAL, FULL, OFF) (default: NORMAL)
                - pool_size: Number of connections in pool (default: 5)
        """
        super().__init__(config)
        config = config or {}
        self.db_path = config.get("path", ":memory:")
        self.table_name = config.get("table", "records")
        self.timeout = config.get("timeout", 5.0)
        # WAL journaling is only meaningful for file-backed databases; in-memory
        # databases get no explicit journal_mode (None skips the PRAGMA below).
        self.journal_mode = config.get("journal_mode", "WAL" if self.db_path != ":memory:" else None)
        self.synchronous = config.get("synchronous", "NORMAL")
        # NOTE(review): pool_size is stored but no pooling is visible in this
        # class (connect() opens a single connection) — confirm intended use.
        self.pool_size = config.get("pool_size", 5)

        # Start with standard query builder, will customize after mixins are initialized
        self.query_builder = SQLQueryBuilder(self.table_name, dialect="sqlite", param_style="qmark")
        self.table_manager = SQLTableManager(self.table_name, dialect="sqlite")

        # Live connection; None until connect() succeeds.
        self.db: aiosqlite.Connection | None = None
        self._connected = False

        # Initialize vector support
        self._parse_vector_config(config)
        self._init_vector_state()

    @classmethod
    def from_config(cls, config: dict) -> AsyncSQLiteDatabase:
        """Create from config dictionary."""
        return cls(config)

    async def connect(self) -> None:
        """Connect to the SQLite database.

        Idempotent: returns immediately if already connected. For file-backed
        databases the parent directory is created if missing. After connecting,
        applies performance PRAGMAs and ensures the backing table exists.
        """
        if self._connected:
            return

        # Create directory if needed for file-based database
        if self.db_path != ":memory:":
            db_file = Path(self.db_path)
            db_file.parent.mkdir(parents=True, exist_ok=True)

        # Connect to database
        self.db = await aiosqlite.connect(
            self.db_path,
            timeout=self.timeout
        )

        # Enable row factory for dict-like access
        self.db.row_factory = aiosqlite.Row

        # Configure SQLite for better performance
        await self._configure_sqlite()

        # Create table if it doesn't exist
        await self._ensure_table()

        self._connected = True
        logger.info(f"Connected to async SQLite database: {self.db_path}")

    async def close(self) -> None:
        """Close the database connection.

        No-op when there is no live connection.
        """
        if self.db:
            await self.db.close()
            self.db = None
            self._connected = False
            logger.info(f"Disconnected from async SQLite database: {self.db_path}")

    async def _configure_sqlite(self) -> None:
        """Configure SQLite settings for performance.

        Applies journal_mode/synchronous from config plus fixed PRAGMAs
        (foreign keys on, in-memory temp store, large mmap window).
        """
        if not self.db:
            return

        # Set journal mode if specified
        # NOTE(review): journal_mode/synchronous are interpolated into PRAGMA
        # text unvalidated from config; acceptable for trusted config, but an
        # allowlist would be safer — confirm config is never user-supplied.
        if self.journal_mode:
            await self.db.execute(f"PRAGMA journal_mode = {self.journal_mode}")
            logger.debug(f"Set journal_mode to {self.journal_mode}")

        # Set synchronous mode
        await self.db.execute(f"PRAGMA synchronous = {self.synchronous}")
        logger.debug(f"Set synchronous to {self.synchronous}")

        # Enable foreign keys
        await self.db.execute("PRAGMA foreign_keys = ON")

        # Optimize for performance
        await self.db.execute("PRAGMA temp_store = MEMORY")
        # ~30 GB mmap ceiling; SQLite caps this at the actual file size.
        await self.db.execute("PRAGMA mmap_size = 30000000000")

        await self.db.commit()

    async def _ensure_table(self) -> None:
        """Ensure the table exists.

        Raises:
            RuntimeError: If connect() has not established a connection.
        """
        if not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

        # executescript allows the manager to emit multiple statements
        # (table plus indexes) in one call.
        await self.db.executescript(self.table_manager.get_create_table_sql())
        await self.db.commit()

    def _check_connection(self) -> None:
        """Check if database is connected.

        Raises:
            RuntimeError: If not connected or the connection object is gone.
        """
        if not self._connected or not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

    async def create(self, record: Record) -> str:
        """Create a new record.

        Returns:
            The generated record ID.

        Raises:
            ValueError: If a record with the same ID already exists.
        """
        self._check_connection()

        query, params = self.query_builder.build_create_query(record)

        try:
            await self.db.execute(query, params)
            await self.db.commit()

            # SQLite doesn't support RETURNING, so we use the ID we generated
            # NOTE(review): SQLite >= 3.35 does support RETURNING; comment may
            # predate that. Also assumes the builder places the ID first in
            # params — confirm against build_create_query.
            record_id = params[0]  # ID is the first parameter
            return record_id
        except aiosqlite.IntegrityError as e:
            await self.db.rollback()
            raise ValueError(f"Record with ID {params[0]} already exists") from e

    async def read(self, id: str) -> Record | None:
        """Read a record by ID.

        Returns:
            The matching Record, or None if no row has this ID.
        """
        self._check_connection()

        query, params = self.query_builder.build_read_query(id)

        async with self.db.execute(query, params) as cursor:
            row = await cursor.fetchone()

        if row:
            return SQLQueryBuilder.row_to_record(dict(row))
        return None

    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record.

        Returns:
            True if a row was modified, False if no row matched the ID.
        """
        self._check_connection()

        query, params = self.query_builder.build_update_query(id, record)

        cursor = await self.db.execute(query, params)
        await self.db.commit()
        return cursor.rowcount > 0

    async def delete(self, id: str) -> bool:
        """Delete a record by ID.

        Returns:
            True if a row was deleted, False if no row matched the ID.
        """
        self._check_connection()

        query, params = self.query_builder.build_delete_query(id)

        cursor = await self.db.execute(query, params)
        await self.db.commit()
        return cursor.rowcount > 0

    async def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()

        query, params = self.query_builder.build_exists_query(id)

        async with self.db.execute(query, params) as cursor:
            result = await cursor.fetchone()
            # Any row back means the ID exists.
            return result is not None

    async def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching a query.

        Returns:
            Matching records, optionally projected to ``query.fields``.
        """
        self._check_connection()

        # Handle ComplexQuery with native SQL support
        if isinstance(query, ComplexQuery):
            sql_query, params = self.query_builder.build_complex_search_query(query)
        else:
            sql_query, params = self.query_builder.build_search_query(query)

        async with self.db.execute(sql_query, params) as cursor:
            rows = await cursor.fetchall()

        records = [SQLQueryBuilder.row_to_record(dict(row)) for row in rows]

        # Apply field projection if specified
        # NOTE(review): assumes both Query and ComplexQuery expose `.fields`
        # and Record exposes `.project` — confirm against those classes.
        if query.fields:
            records = [r.project(query.fields) for r in records]

        return records

    async def count(self, query: Query | None = None) -> int:
        """Count records matching a query (all records when query is None)."""
        self._check_connection()

        sql_query, params = self.query_builder.build_count_query(query)

        async with self.db.execute(sql_query, params) as cursor:
            result = await cursor.fetchone()
            # COUNT(*) always yields one row; guard anyway.
            return result[0] if result else 0

    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently using a single query.

        Uses multi-value INSERT for better performance.

        Returns:
            The generated IDs, in input order; [] for empty input.
        """
        if not records:
            return []

        self._check_connection()

        # Use the shared batch create query builder
        query, params, ids = self.query_builder.build_batch_create_query(records)

        # Execute the batch insert in a transaction
        await self.db.execute("BEGIN TRANSACTION")

        try:
            await self.db.execute(query, params)
            await self.db.commit()

            # Return the generated IDs
            return ids
        except Exception:
            await self.db.rollback()
            raise

    async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using a single query.

        Uses CASE expressions for batch updates, similar to PostgreSQL.

        Returns:
            One bool per input pair, in order; [] for empty input.
        """
        if not updates:
            return []

        self._check_connection()

        # Use the shared batch update query builder
        query, params = self.query_builder.build_batch_update_query(updates)

        # Execute the batch update in a transaction
        await self.db.execute("BEGIN TRANSACTION")

        try:
            await self.db.execute(query, params)
            await self.db.commit()

            # Check which records were actually updated
            # SQLite doesn't have RETURNING, so we need to verify each ID
            # NOTE(review): this verifies row *existence* after the update,
            # not that the UPDATE changed the row — the two coincide only if
            # the batch UPDATE touches every existing listed ID; confirm.
            update_ids = [record_id for record_id, _ in updates]
            placeholders = ", ".join(["?" for _ in update_ids])
            check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})"

            async with self.db.execute(check_query, update_ids) as check_cursor:
                rows = await check_cursor.fetchall()
                existing_ids = {row[0] for row in rows}

            # Return results for each update
            results = []
            for record_id, _ in updates:
                results.append(record_id in existing_ids)

            return results
        except Exception:
            await self.db.rollback()
            raise

    async def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently using a single query.

        Uses single DELETE with IN clause for better performance.

        Returns:
            One bool per input ID (True if it existed before deletion).
        """
        if not ids:
            return []

        self._check_connection()

        # Check which IDs exist before deletion
        # (existence is snapshotted first because DELETE reports no per-row
        # outcome for an IN-clause delete)
        placeholders = ", ".join(["?" for _ in ids])
        check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})"

        async with self.db.execute(check_query, ids) as cursor:
            rows = await cursor.fetchall()
            existing_ids = {row[0] for row in rows}

        # Use the shared batch delete query builder
        query, params = self.query_builder.build_batch_delete_query(ids)

        # Execute the batch delete in a transaction
        await self.db.execute("BEGIN TRANSACTION")

        try:
            await self.db.execute(query, params)
            await self.db.commit()

            # Return results based on which IDs existed
            results = []
            for id in ids:
                results.append(id in existing_ids)

            return results
        except Exception:
            await self.db.rollback()
            raise

    def _initialize(self) -> None:
        """Initialize method - connection setup handled in connect()."""
        pass

    async def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()

        async with self.db.execute(f"SELECT COUNT(*) FROM {self.table_name}") as cursor:
            result = await cursor.fetchone()
            return result[0] if result else 0

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records from database.

        Yields records one at a time by paging through search() in
        ``config.batch_size`` chunks.
        """
        from ..streaming import StreamConfig

        config = config or StreamConfig()
        query = query or Query()

        # Use the existing stream method's logic but yield individual records
        offset = 0
        while True:
            # Fetch a batch
            query_copy = query.copy()
            query_copy.offset(offset).limit(config.batch_size)
            batch = await self.search(query_copy)

            if not batch:
                break

            for record in batch:
                yield record

            offset += len(batch)

            # If we got less than batch_size, we're done
            if len(batch) < config.batch_size:
                break

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into database.

        Buffers incoming records and flushes via create_batch() every
        ``config.batch_size`` records, plus a final partial flush.

        Returns:
            StreamResult summarizing counts and elapsed time. ``failed`` is
            always 0 here because create_batch raises on error rather than
            reporting per-record failures.
        """
        import time

        from ..streaming import StreamConfig, StreamResult

        config = config or StreamConfig()
        batch = []
        total_written = 0
        start_time = time.time()

        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write the batch
                await self.create_batch(batch)
                total_written += len(batch)
                batch = []

        # Write any remaining records
        if batch:
            await self.create_batch(batch)
            total_written += len(batch)

        elapsed = time.time() - start_time

        return StreamResult(
            total_processed=total_written,
            successful=total_written,
            failed=0,
            duration=elapsed,
            # Ceiling division: partial final batch still counts as a batch.
            total_batches=(total_written + config.batch_size - 1) // config.batch_size
        )

    async def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform async vector similarity search using Python-based calculations.

        Delegates to PythonVectorSearchMixin for the implementation.

        Args:
            query_vector: Query vector
            vector_field: Name of the vector field to search
            k: Number of results to return
            filter: Optional filter conditions
            metric: Distance metric (uses instance default if not specified)
            **kwargs: Additional arguments for compatibility

        Returns:
            List of VectorSearchResult objects with scores
        """
        self._check_connection()

        # Delegate to the mixin's implementation
        return await self.python_vector_search_async(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )