Coverage for src/dataknobs_data/backends/sqlite_async.py: 18%
229 statements
coverage.py v7.10.3, created at 2025-08-31 15:06 -0600
1"""Async SQLite backend implementation using aiosqlite."""
from __future__ import annotations

import logging
from pathlib import Path
from typing import Any, TYPE_CHECKING

import aiosqlite
from dataknobs_config import ConfigurableBase

from ..database import AsyncDatabase
from ..pooling import ConnectionPoolManager
from ..query import Query
from ..query_logic import ComplexQuery
from ..vector import VectorOperationsMixin
from ..vector.bulk_embed_mixin import BulkEmbedMixin
from ..vector.python_vector_search import PythonVectorSearchMixin
from .sql_base import SQLQueryBuilder, SQLTableManager
from .sqlite_mixins import SQLiteVectorSupport
from .vector_config_mixin import VectorConfigMixin

if TYPE_CHECKING:
    from collections.abc import AsyncIterator
    from ..records import Record
    from ..streaming import StreamConfig, StreamResult


logger = logging.getLogger(__name__)

# Global pool manager for SQLite connections
_pool_manager = ConnectionPoolManager()

class AsyncSQLiteDatabase(  # type: ignore[misc]
    AsyncDatabase,
    ConfigurableBase,
    VectorConfigMixin,
    SQLiteVectorSupport,
    PythonVectorSearchMixin,  # Provides python_vector_search_async
    BulkEmbedMixin,  # Must come before VectorOperationsMixin to override bulk_embed_and_store
    VectorOperationsMixin
):
    """Asynchronous SQLite database backend using aiosqlite."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async SQLite database.

        Args:
            config: Configuration with the following optional keys:
                - path: Database file path (default: ":memory:")
                - table: Table name (default: "records")
                - timeout: Connection timeout in seconds (default: 5.0)
                - journal_mode: Journal mode (WAL, DELETE, etc.) (default: WAL for file-based databases)
                - synchronous: Synchronous mode (NORMAL, FULL, OFF) (default: NORMAL)
                - pool_size: Number of connections in the pool (default: 5)
        """
        super().__init__(config)
        config = config or {}
        self.db_path = config.get("path", ":memory:")
        self.table_name = config.get("table", "records")
        self.timeout = config.get("timeout", 5.0)
        self.journal_mode = config.get("journal_mode", "WAL" if self.db_path != ":memory:" else None)
        self.synchronous = config.get("synchronous", "NORMAL")
        self.pool_size = config.get("pool_size", 5)

        # Start with the standard query builder; it is customized after the mixins are initialized
        self.query_builder = SQLQueryBuilder(self.table_name, dialect="sqlite", param_style="qmark")
        self.table_manager = SQLTableManager(self.table_name, dialect="sqlite")

        self.db: aiosqlite.Connection | None = None
        self._connected = False

        # Initialize vector support
        self._parse_vector_config(config)
        self._init_vector_state()

    @classmethod
    def from_config(cls, config: dict) -> AsyncSQLiteDatabase:
        """Create from config dictionary."""
        return cls(config)

    async def connect(self) -> None:
        """Connect to the SQLite database."""
        if self._connected:
            return

        # Create the parent directory if needed for a file-based database
        if self.db_path != ":memory:":
            db_file = Path(self.db_path)
            db_file.parent.mkdir(parents=True, exist_ok=True)

        # Connect to the database
        self.db = await aiosqlite.connect(
            self.db_path,
            timeout=self.timeout
        )

        # Enable row factory for dict-like access
        self.db.row_factory = aiosqlite.Row

        # Configure SQLite for better performance
        await self._configure_sqlite()

        # Create the table if it doesn't exist
        await self._ensure_table()

        self._connected = True
        logger.info(f"Connected to async SQLite database: {self.db_path}")

    async def close(self) -> None:
        """Close the database connection."""
        if self.db:
            await self.db.close()
            self.db = None
            self._connected = False
            logger.info(f"Disconnected from async SQLite database: {self.db_path}")

    async def _configure_sqlite(self) -> None:
        """Configure SQLite settings for performance."""
        if not self.db:
            return

        # Set journal mode if specified
        if self.journal_mode:
            await self.db.execute(f"PRAGMA journal_mode = {self.journal_mode}")
            logger.debug(f"Set journal_mode to {self.journal_mode}")

        # Set synchronous mode
        await self.db.execute(f"PRAGMA synchronous = {self.synchronous}")
        logger.debug(f"Set synchronous to {self.synchronous}")

        # Enable foreign key enforcement
        await self.db.execute("PRAGMA foreign_keys = ON")

        # Keep temp data in memory and allow a large (~30 GB) memory map
        await self.db.execute("PRAGMA temp_store = MEMORY")
        await self.db.execute("PRAGMA mmap_size = 30000000000")

        await self.db.commit()

    async def _ensure_table(self) -> None:
        """Ensure the table exists."""
        if not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

        await self.db.executescript(self.table_manager.get_create_table_sql())
        await self.db.commit()

    def _check_connection(self) -> None:
        """Check that the database is connected."""
        if not self._connected or not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

    async def create(self, record: Record) -> str:
        """Create a new record."""
        self._check_connection()

        query, params = self.query_builder.build_create_query(record)

        try:
            await self.db.execute(query, params)
            await self.db.commit()

            # The insert doesn't use a RETURNING clause, so return the ID we generated
            record_id = params[0]  # ID is the first parameter
            return record_id
        except aiosqlite.IntegrityError as e:
            await self.db.rollback()
            raise ValueError(f"Record with ID {params[0]} already exists") from e

    async def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        self._check_connection()

        query, params = self.query_builder.build_read_query(id)

        async with self.db.execute(query, params) as cursor:
            row = await cursor.fetchone()

        if row:
            return SQLQueryBuilder.row_to_record(dict(row))
        return None

    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record."""
        self._check_connection()

        query, params = self.query_builder.build_update_query(id, record)

        cursor = await self.db.execute(query, params)
        await self.db.commit()
        return cursor.rowcount > 0

    async def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        self._check_connection()

        query, params = self.query_builder.build_delete_query(id)

        cursor = await self.db.execute(query, params)
        await self.db.commit()
        return cursor.rowcount > 0

    async def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()

        query, params = self.query_builder.build_exists_query(id)

        async with self.db.execute(query, params) as cursor:
            result = await cursor.fetchone()
            return result is not None

    async def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching a query."""
        self._check_connection()

        # Handle ComplexQuery with native SQL support
        if isinstance(query, ComplexQuery):
            sql_query, params = self.query_builder.build_complex_search_query(query)
        else:
            sql_query, params = self.query_builder.build_search_query(query)

        async with self.db.execute(sql_query, params) as cursor:
            rows = await cursor.fetchall()

        records = [SQLQueryBuilder.row_to_record(dict(row)) for row in rows]

        # Apply field projection if specified
        if query.fields:
            records = [r.project(query.fields) for r in records]

        return records

    async def count(self, query: Query | None = None) -> int:
        """Count records matching a query."""
        self._check_connection()

        sql_query, params = self.query_builder.build_count_query(query)

        async with self.db.execute(sql_query, params) as cursor:
            result = await cursor.fetchone()
            return result[0] if result else 0

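    # Illustrative search sketch using the fluent Query API seen in
    # stream_read below (``offset``/``limit``); other filtering options
    # depend on the Query class in ``..query``:
    #
    #     first_page = await db.search(Query().offset(0).limit(20))
    #     total = await db.count()
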
    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently using a single query.

        Uses a multi-value INSERT for better performance.
        """
        if not records:
            return []

        self._check_connection()

        # Use the shared batch create query builder
        query, params, ids = self.query_builder.build_batch_create_query(records)

        # Execute the batch insert in a transaction
        await self.db.execute("BEGIN TRANSACTION")

        try:
            await self.db.execute(query, params)
            await self.db.commit()

            # Return the generated IDs
            return ids
        except Exception:
            await self.db.rollback()
            raise

    async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using a single query.

        Uses CASE expressions for batch updates, similar to the PostgreSQL backend.
        """
        if not updates:
            return []

        self._check_connection()

        # Use the shared batch update query builder
        query, params = self.query_builder.build_batch_update_query(updates)

        # Execute the batch update in a transaction
        await self.db.execute("BEGIN TRANSACTION")

        try:
            await self.db.execute(query, params)
            await self.db.commit()

            # Check which records were actually updated. The UPDATE doesn't use
            # a RETURNING clause, so verify which of the requested IDs exist.
            update_ids = [record_id for record_id, _ in updates]
            placeholders = ", ".join(["?" for _ in update_ids])
            check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})"

            async with self.db.execute(check_query, update_ids) as check_cursor:
                rows = await check_cursor.fetchall()
                existing_ids = {row[0] for row in rows}

            # Return a result for each requested update
            results = []
            for record_id, _ in updates:
                results.append(record_id in existing_ids)

            return results
        except Exception:
            await self.db.rollback()
            raise

    async def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently using a single query.

        Uses a single DELETE with an IN clause for better performance.
        """
        if not ids:
            return []

        self._check_connection()

        # Check which IDs exist before deletion
        placeholders = ", ".join(["?" for _ in ids])
        check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})"

        async with self.db.execute(check_query, ids) as cursor:
            rows = await cursor.fetchall()
            existing_ids = {row[0] for row in rows}

        # Use the shared batch delete query builder
        query, params = self.query_builder.build_batch_delete_query(ids)

        # Execute the batch delete in a transaction
        await self.db.execute("BEGIN TRANSACTION")

        try:
            await self.db.execute(query, params)
            await self.db.commit()

            # Return results based on which IDs existed
            results = []
            for id in ids:
                results.append(id in existing_ids)

            return results
        except Exception:
            await self.db.rollback()
            raise

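    # Illustrative batch usage (a sketch; ``records`` is a list of ``Record``s):
    #
    #     ids = await db.create_batch(records)
    #     updated = await db.update_batch(list(zip(ids, records)))  # one bool per ID
    #     deleted = await db.delete_batch(ids)  # one bool per ID
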
    def _initialize(self) -> None:
        """No-op: connection setup is handled in connect()."""
        pass

    async def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()

        async with self.db.execute(f"SELECT COUNT(*) FROM {self.table_name}") as cursor:
            result = await cursor.fetchone()
            return result[0] if result else 0

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records from the database."""
        from ..streaming import StreamConfig

        config = config or StreamConfig()
        query = query or Query()

        # Page through results with offset/limit, yielding individual records
        offset = 0
        while True:
            # Fetch a batch
            query_copy = query.copy()
            query_copy.offset(offset).limit(config.batch_size)
            batch = await self.search(query_copy)

            if not batch:
                break

            for record in batch:
                yield record

            offset += len(batch)

            # If we got fewer than batch_size records, we're done
            if len(batch) < config.batch_size:
                break

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into the database."""
        import time

        from ..streaming import StreamConfig, StreamResult

        config = config or StreamConfig()
        batch = []
        total_written = 0
        start_time = time.time()

        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write the full batch
                await self.create_batch(batch)
                total_written += len(batch)
                batch = []

        # Write any remaining records
        if batch:
            await self.create_batch(batch)
            total_written += len(batch)

        elapsed = time.time() - start_time

        return StreamResult(
            total_processed=total_written,
            successful=total_written,
            failed=0,
            duration=elapsed,
            total_batches=(total_written + config.batch_size - 1) // config.batch_size
        )

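    # Illustrative streaming sketch (``record_gen`` and ``process`` are
    # hypothetical, and StreamConfig is assumed to accept ``batch_size``,
    # matching the attribute accessed above):
    #
    #     from dataknobs_data.streaming import StreamConfig
    #
    #     result = await db.stream_write(record_gen(), StreamConfig(batch_size=500))
    #     async for record in db.stream_read(config=StreamConfig(batch_size=500)):
    #         process(record)
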
    async def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform async vector similarity search using Python-based calculations.

        Delegates to PythonVectorSearchMixin for the implementation.

        Args:
            query_vector: Query vector
            vector_field: Name of the vector field to search
            k: Number of results to return
            filter: Optional filter conditions
            metric: Distance metric (uses the instance default if not specified)
            **kwargs: Additional arguments for compatibility

        Returns:
            List of VectorSearchResult objects with scores
        """
        self._check_connection()

        # Delegate to the mixin's implementation
        return await self.python_vector_search_async(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )
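
    # Illustrative vector search sketch (``query_vector`` and the dict-style
    # "category" filter are hypothetical; results carry scores per the
    # docstring above):
    #
    #     results = await db.vector_search(
    #         query_vector,
    #         vector_field="embedding",
    #         k=5,
    #         filter={"category": "docs"},
    #     )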