Coverage for src/dataknobs_data/backends/sqlite_async.py: 18%
237 statements
« prev ^ index » next    coverage.py v7.11.0, created at 2025-10-29 14:14 -0600
1"""Async SQLite backend implementation using aiosqlite."""
3from __future__ import annotations
5import logging
6from pathlib import Path
7from typing import Any, TYPE_CHECKING
9import aiosqlite
10from dataknobs_config import ConfigurableBase
12from ..database import AsyncDatabase
13from ..pooling import ConnectionPoolManager
14from ..query import Query
15from ..query_logic import ComplexQuery
16from ..vector import VectorOperationsMixin
17from ..vector.bulk_embed_mixin import BulkEmbedMixin
18from ..vector.python_vector_search import PythonVectorSearchMixin
19from .sql_base import SQLQueryBuilder, SQLTableManager
20from .sqlite_mixins import SQLiteVectorSupport
21from .vector_config_mixin import VectorConfigMixin
23if TYPE_CHECKING:
24 from collections.abc import AsyncIterator
25 from ..records import Record
26 from ..streaming import StreamConfig, StreamResult
# Module-level logger for this backend.
logger = logging.getLogger(__name__)

# Global pool manager for SQLite connections
# NOTE(review): `_pool_manager` is not referenced anywhere in this module as
# shown — confirm it is consumed elsewhere (or by ConnectionPoolManager's own
# registration side effects) before assuming it is live code.
_pool_manager = ConnectionPoolManager()
class AsyncSQLiteDatabase( # type: ignore[misc]
    AsyncDatabase,
    ConfigurableBase,
    VectorConfigMixin,
    SQLiteVectorSupport,
    PythonVectorSearchMixin,  # Provides python_vector_search_async
    BulkEmbedMixin,  # Must come before VectorOperationsMixin to override bulk_embed_and_store
    VectorOperationsMixin
):
    """Asynchronous SQLite database backend using aiosqlite.

    All operations run against a single aiosqlite connection held in
    ``self.db``. Callers must ``await connect()`` before using any CRUD,
    batch, streaming, or vector-search method; every data method raises
    ``RuntimeError`` via ``_check_connection()`` otherwise.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async SQLite database.

        Args:
            config: Configuration with the following optional keys:
                - path: Database file path (default: ":memory:")
                - table: Table name (default: "records")
                - timeout: Connection timeout in seconds (default: 5.0)
                - journal_mode: Journal mode (WAL, DELETE, etc.) (default: WAL for file-based)
                - synchronous: Synchronous mode (NORMAL, FULL, OFF) (default: NORMAL)
                - pool_size: Number of connections in pool (default: 5)

        Note:
            ``pool_size`` is stored but no pooling is performed in this
            module as shown — connect() opens exactly one connection.
        """
        super().__init__(config)
        config = config or {}
        self.db_path = config.get("path", ":memory:")
        self.table_name = config.get("table", "records")
        self.timeout = config.get("timeout", 5.0)
        # WAL only makes sense for file-backed databases; in-memory gets no
        # explicit journal mode unless the caller sets one.
        self.journal_mode = config.get("journal_mode", "WAL" if self.db_path != ":memory:" else None)
        self.synchronous = config.get("synchronous", "NORMAL")
        self.pool_size = config.get("pool_size", 5)

        # Start with standard query builder, will customize after mixins are initialized
        # "qmark" -> positional `?` placeholders, SQLite's native parameter style.
        self.query_builder = SQLQueryBuilder(self.table_name, dialect="sqlite", param_style="qmark")
        self.table_manager = SQLTableManager(self.table_name, dialect="sqlite")

        # The live aiosqlite connection; None until connect() succeeds.
        self.db: aiosqlite.Connection | None = None
        self._connected = False

        # Initialize vector support
        self._parse_vector_config(config)
        self._init_vector_state()

    @classmethod
    def from_config(cls, config: dict) -> AsyncSQLiteDatabase:
        """Create from config dictionary."""
        return cls(config)

    async def connect(self) -> None:
        """Connect to the SQLite database.

        Idempotent: returns immediately if already connected. For file-backed
        databases the parent directory is created on demand. After connecting,
        applies PRAGMA tuning and creates the records table if missing.
        """
        if self._connected:
            return

        # Create directory if needed for file-based database
        if self.db_path != ":memory:":
            db_file = Path(self.db_path)
            db_file.parent.mkdir(parents=True, exist_ok=True)

        # Connect to database
        self.db = await aiosqlite.connect(
            self.db_path,
            timeout=self.timeout
        )

        # Enable row factory for dict-like access
        self.db.row_factory = aiosqlite.Row

        # Configure SQLite for better performance
        await self._configure_sqlite()

        # Create table if it doesn't exist
        await self._ensure_table()

        self._connected = True
        logger.info(f"Connected to async SQLite database: {self.db_path}")

    async def close(self) -> None:
        """Close the database connection.

        Safe to call when never connected or already closed (no-op when
        ``self.db`` is None).
        """
        if self.db:
            await self.db.close()
            self.db = None
            self._connected = False
            logger.info(f"Disconnected from async SQLite database: {self.db_path}")

    async def _configure_sqlite(self) -> None:
        """Configure SQLite settings for performance.

        NOTE(review): PRAGMA values are interpolated with f-strings because
        SQLite PRAGMAs cannot take bound parameters; this assumes
        ``journal_mode`` / ``synchronous`` come from trusted configuration.
        Consider validating against a whitelist if config can be untrusted.
        """
        if not self.db:
            return

        # Set journal mode if specified
        if self.journal_mode:
            await self.db.execute(f"PRAGMA journal_mode = {self.journal_mode}")
            logger.debug(f"Set journal_mode to {self.journal_mode}")

        # Set synchronous mode
        await self.db.execute(f"PRAGMA synchronous = {self.synchronous}")
        logger.debug(f"Set synchronous to {self.synchronous}")

        # Enable foreign keys
        await self.db.execute("PRAGMA foreign_keys = ON")

        # Optimize for performance
        await self.db.execute("PRAGMA temp_store = MEMORY")
        # ~30 GB mmap ceiling; SQLite caps this at what the OS actually allows.
        await self.db.execute("PRAGMA mmap_size = 30000000000")

        await self.db.commit()

    async def _ensure_table(self) -> None:
        """Ensure the table exists.

        Raises:
            RuntimeError: If called before the connection is established.
        """
        if not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

        # executescript: the DDL may contain multiple statements (table + indexes).
        await self.db.executescript(self.table_manager.get_create_table_sql())
        await self.db.commit()

    def _check_connection(self) -> None:
        """Check if database is connected.

        Raises:
            RuntimeError: If connect() has not completed successfully.
        """
        if not self._connected or not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

    async def create(self, record: Record) -> str:
        """Create a new record.

        Returns:
            The generated record ID.

        Raises:
            ValueError: If a record with the generated/provided ID already
                exists (wraps aiosqlite.IntegrityError).
        """
        self._check_connection()

        query, params = self.query_builder.build_create_query(record)

        try:
            await self.db.execute(query, params)
            await self.db.commit()

            # SQLite doesn't support RETURNING, so we use the ID we generated
            record_id = params[0]  # ID is the first parameter
            return record_id
        except aiosqlite.IntegrityError as e:
            await self.db.rollback()
            raise ValueError(f"Record with ID {params[0]} already exists") from e

    async def read(self, id: str) -> Record | None:
        """Read a record by ID.

        Returns:
            The record, or None if no row matches.

        NOTE(review): unlike search(), this does not populate
        ``record.storage_id`` explicitly — confirm whether row_to_record
        already sets it, otherwise read() and search() results differ.
        """
        self._check_connection()

        query, params = self.query_builder.build_read_query(id)

        async with self.db.execute(query, params) as cursor:
            row = await cursor.fetchone()

        if row:
            return SQLQueryBuilder.row_to_record(dict(row))
        return None

    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record.

        Args:
            id: The record ID to update
            record: The record data to update with

        Returns:
            True if the record was updated, False if no record with the given ID exists
        """
        self._check_connection()

        query, params = self.query_builder.build_update_query(id, record)

        cursor = await self.db.execute(query, params)
        await self.db.commit()
        # rowcount reflects rows matched by the UPDATE's WHERE clause.
        rows_affected = cursor.rowcount

        if rows_affected == 0:
            logger.warning(f"Update affected 0 rows for id={id}. Record may not exist.")

        return rows_affected > 0

    async def delete(self, id: str) -> bool:
        """Delete a record by ID.

        Returns:
            True if a row was deleted, False if the ID did not exist.
        """
        self._check_connection()

        query, params = self.query_builder.build_delete_query(id)

        cursor = await self.db.execute(query, params)
        await self.db.commit()
        return cursor.rowcount > 0

    async def exists(self, id: str) -> bool:
        """Check if a record exists.

        NOTE(review): this treats "a row came back" as existence. That is
        correct for a `SELECT 1 ... WHERE id = ?` style query, but would be
        wrong for a `SELECT EXISTS(...)` query (which always returns one
        row). Confirm what build_exists_query emits.
        """
        self._check_connection()

        query, params = self.query_builder.build_exists_query(id)

        async with self.db.execute(query, params) as cursor:
            result = await cursor.fetchone()
            return result is not None

    async def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching a query.

        Returns:
            Matching records, each with ``storage_id`` set from the row's
            database ID, optionally projected to ``query.fields``.
        """
        self._check_connection()

        # Handle ComplexQuery with native SQL support
        if isinstance(query, ComplexQuery):
            sql_query, params = self.query_builder.build_complex_search_query(query)
        else:
            sql_query, params = self.query_builder.build_search_query(query)

        async with self.db.execute(sql_query, params) as cursor:
            rows = await cursor.fetchall()

        records = []
        for row in rows:
            row_dict = dict(row)
            record = SQLQueryBuilder.row_to_record(row_dict)

            # Populate storage_id from database ID
            record.storage_id = str(row_dict['id'])

            records.append(record)

        # Apply field projection if specified
        if query.fields:
            records = [r.project(query.fields) for r in records]

        return records

    async def count(self, query: Query | None = None) -> int:
        """Count records matching a query.

        Args:
            query: Optional filter; counts all records when None.
        """
        self._check_connection()

        sql_query, params = self.query_builder.build_count_query(query)

        async with self.db.execute(sql_query, params) as cursor:
            result = await cursor.fetchone()
            return result[0] if result else 0

    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently using a single query.

        Uses multi-value INSERT for better performance.

        Returns:
            The generated IDs, in input order; [] for empty input.
        """
        if not records:
            return []

        self._check_connection()

        # Use the shared batch create query builder
        query, params, ids = self.query_builder.build_batch_create_query(records)

        # Execute the batch insert in a transaction
        await self.db.execute("BEGIN TRANSACTION")

        try:
            await self.db.execute(query, params)
            await self.db.commit()

            # Return the generated IDs
            return ids
        except Exception:
            await self.db.rollback()
            raise

    async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using a single query.

        Uses CASE expressions for batch updates, similar to PostgreSQL.

        Returns:
            One bool per input pair (in order): True when the ID exists in
            the table after the update. NOTE(review): existence after commit
            approximates "was updated"; a row inserted concurrently between
            the update and the check would be reported as True.
        """
        if not updates:
            return []

        self._check_connection()

        # Use the shared batch update query builder
        query, params = self.query_builder.build_batch_update_query(updates)

        # Execute the batch update in a transaction
        await self.db.execute("BEGIN TRANSACTION")

        try:
            await self.db.execute(query, params)
            await self.db.commit()

            # Check which records were actually updated
            # SQLite doesn't have RETURNING, so we need to verify each ID
            update_ids = [record_id for record_id, _ in updates]
            placeholders = ", ".join(["?" for _ in update_ids])
            check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})"

            async with self.db.execute(check_query, update_ids) as check_cursor:
                rows = await check_cursor.fetchall()
                existing_ids = {row[0] for row in rows}

            # Return results for each update
            results = []
            for record_id, _ in updates:
                results.append(record_id in existing_ids)

            return results
        except Exception:
            await self.db.rollback()
            raise

    async def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently using a single query.

        Uses single DELETE with IN clause for better performance.

        Returns:
            One bool per input ID (in order): True when the ID existed before
            the delete. NOTE(review): the existence check runs BEFORE the
            transaction begins, so there is a small window where another
            writer could change the outcome.
        """
        if not ids:
            return []

        self._check_connection()

        # Check which IDs exist before deletion
        placeholders = ", ".join(["?" for _ in ids])
        check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})"

        async with self.db.execute(check_query, ids) as cursor:
            rows = await cursor.fetchall()
            existing_ids = {row[0] for row in rows}

        # Use the shared batch delete query builder
        query, params = self.query_builder.build_batch_delete_query(ids)

        # Execute the batch delete in a transaction
        await self.db.execute("BEGIN TRANSACTION")

        try:
            await self.db.execute(query, params)
            await self.db.commit()

            # Return results based on which IDs existed
            results = []
            for id in ids:
                results.append(id in existing_ids)

            return results
        except Exception:
            await self.db.rollback()
            raise

    def _initialize(self) -> None:
        """Initialize method - connection setup handled in connect()."""
        pass

    async def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()

        async with self.db.execute(f"SELECT COUNT(*) FROM {self.table_name}") as cursor:
            result = await cursor.fetchone()
            return result[0] if result else 0

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records from database.

        Paginates via offset/limit over search(); yields records one at a
        time in batches of ``config.batch_size``.

        NOTE(review): assumes ``query.copy()`` returns an independent query
        whose offset/limit can be mutated per batch — confirm copy semantics.
        """
        from ..streaming import StreamConfig

        config = config or StreamConfig()
        query = query or Query()

        # Use the existing stream method's logic but yield individual records
        offset = 0
        while True:
            # Fetch a batch
            query_copy = query.copy()
            query_copy.offset(offset).limit(config.batch_size)
            batch = await self.search(query_copy)

            if not batch:
                break

            for record in batch:
                yield record

            offset += len(batch)

            # If we got less than batch_size, we're done
            if len(batch) < config.batch_size:
                break

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into database.

        Buffers incoming records and flushes them via create_batch() every
        ``config.batch_size`` records, plus a final partial batch.

        Returns:
            StreamResult with counts and elapsed wall-clock duration;
            ``failed`` is always 0 because create_batch raises on error.
        """
        import time

        from ..streaming import StreamConfig, StreamResult

        config = config or StreamConfig()
        batch = []
        total_written = 0
        start_time = time.time()

        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write the batch
                await self.create_batch(batch)
                total_written += len(batch)
                batch = []

        # Write any remaining records
        if batch:
            await self.create_batch(batch)
            total_written += len(batch)

        elapsed = time.time() - start_time

        return StreamResult(
            total_processed=total_written,
            successful=total_written,
            failed=0,
            duration=elapsed,
            # Ceiling division: number of full-or-partial batches written.
            total_batches=(total_written + config.batch_size - 1) // config.batch_size
        )

    async def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform async vector similarity search using Python-based calculations.

        Delegates to PythonVectorSearchMixin for the implementation.

        Args:
            query_vector: Query vector
            vector_field: Name of the vector field to search
            k: Number of results to return
            filter: Optional filter conditions
            metric: Distance metric (uses instance default if not specified)
            **kwargs: Additional arguments for compatibility

        Returns:
            List of VectorSearchResult objects with scores
        """
        self._check_connection()

        # Delegate to the mixin's implementation
        return await self.python_vector_search_async(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )