Coverage for src/dataknobs_data/backends/s3_async.py: 16%
308 statements
coverage.py v7.10.3, created at 2025-08-31 15:06 -0600
1"""Native async S3 backend implementation with aioboto3 and connection pooling."""
3from __future__ import annotations
5import asyncio
6import json
7import logging
8import time
9import uuid
10from datetime import datetime
11from typing import Any, TYPE_CHECKING, cast, Callable, Awaitable
13from dataknobs_config import ConfigurableBase
15from ..database import AsyncDatabase
16from ..pooling import ConnectionPoolManager
17from ..pooling.s3 import S3PoolConfig, create_aioboto3_session, validate_s3_session
18from ..query import Operator, Query, SortOrder
19from ..records import Record
20from ..streaming import StreamConfig, StreamResult, async_process_batch_with_fallback
21from ..vector import VectorOperationsMixin
22from ..vector.bulk_embed_mixin import BulkEmbedMixin
23from ..vector.python_vector_search import PythonVectorSearchMixin
24from .sqlite_mixins import SQLiteVectorSupport
25from .vector_config_mixin import VectorConfigMixin
27if TYPE_CHECKING:
28 from collections.abc import AsyncIterator
31logger = logging.getLogger(__name__)
33# Global pool manager for S3 sessions
34_session_manager = ConnectionPoolManager()


class AsyncS3Database(  # type: ignore[misc]
    AsyncDatabase,
    ConfigurableBase,
    VectorConfigMixin,
    SQLiteVectorSupport,
    PythonVectorSearchMixin,
    BulkEmbedMixin,
    VectorOperationsMixin
):
    """Native async S3 database backend with aioboto3 and session pooling."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async S3 database."""
        super().__init__(config)

        if not config or "bucket" not in config:
            raise ValueError("S3 backend requires 'bucket' in configuration")

        self._pool_config = S3PoolConfig.from_dict(config)
        self._session = None
        self._connected = False

        # Initialize vector support
        self._parse_vector_config(config or {})
        self._init_vector_state()  # From SQLiteVectorSupport

    @classmethod
    def from_config(cls, config: dict) -> AsyncS3Database:
        """Create from config dictionary."""
        return cls(config)
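
    # Example config shape (illustrative sketch): "bucket" is the only key this
    # class itself requires; "prefix" and "endpoint_url" are assumptions about
    # keys S3PoolConfig.from_dict understands, based on the pool-config
    # attributes used below (e.g. a key prefix, or a local MinIO endpoint).
    #
    #     {
    #         "bucket": "my-bucket",
    #         "prefix": "records",
    #         "endpoint_url": "http://localhost:9000",
    #     }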

    async def connect(self) -> None:
        """Connect to S3 service."""
        if self._connected:
            return

        # Get or create session for current event loop
        from ..pooling import BasePoolConfig
        self._session = await _session_manager.get_pool(
            self._pool_config,
            cast(Callable[[BasePoolConfig], Awaitable[Any]], create_aioboto3_session),
            lambda session: validate_s3_session(session, self._pool_config)
        )

        self._connected = True

    async def close(self) -> None:
        """Close the S3 connection."""
        if self._connected:
            self._session = None
            self._connected = False

    def _initialize(self) -> None:
        """Initialization is handled in connect()."""
        pass

    def _check_connection(self) -> None:
        """Check if the database is connected."""
        if not self._connected or not self._session:
            raise RuntimeError("Database not connected. Call connect() first.")

    def _get_key(self, id: str) -> str:
        """Get the S3 key for a given record ID."""
        if self._pool_config.prefix:
            return f"{self._pool_config.prefix}/{id}.json"
        return f"{id}.json"

    def _record_to_s3_object(self, record: Record) -> dict[str, Any]:
        """Convert a Record to an S3 object."""
        # Use Record's built-in serialization which handles VectorFields
        record_dict = record.to_dict(include_metadata=True, flatten=False)

        # Add timestamps
        now = datetime.utcnow().isoformat()
        if "metadata" not in record_dict:
            record_dict["metadata"] = {}
        record_dict["metadata"]["created_at"] = record_dict["metadata"].get("created_at", now)
        record_dict["metadata"]["updated_at"] = now

        return record_dict

    def _s3_object_to_record(self, obj: dict[str, Any]) -> Record:
        """Convert an S3 object to a Record."""
        # Use Record's built-in deserialization
        return Record.from_dict(obj)

    async def create(self, record: Record) -> str:
        """Create a new record in S3."""
        self._check_connection()

        # Use centralized method to prepare record
        record_copy, storage_id = self._prepare_record_for_storage(record)
        key = self._get_key(storage_id)
        obj = self._record_to_s3_object(record_copy)

        # Add ID to metadata
        obj["metadata"]["id"] = storage_id

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            await s3.put_object(
                Bucket=self._pool_config.bucket,
                Key=key,
                Body=json.dumps(obj),
                ContentType="application/json"
            )

        return storage_id

    async def read(self, id: str) -> Record | None:
        """Read a record from S3."""
        self._check_connection()

        key = self._get_key(id)

        try:
            async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
                response = await s3.get_object(
                    Bucket=self._pool_config.bucket,
                    Key=key
                )

                # Read the object body
                body = await response['Body'].read()
                obj = json.loads(body)

                record = self._s3_object_to_record(obj)
                # Use centralized method to prepare record
                record = self._prepare_record_from_storage(record, id)
                # Ensure ID is in metadata
                record.metadata["id"] = id

                return record
        except Exception:
            return None

    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record in S3."""
        self._check_connection()

        # Check if record exists
        if not await self.exists(id):
            return False

        key = self._get_key(id)
        obj = self._record_to_s3_object(record)

        # Preserve ID in metadata
        obj["metadata"]["id"] = id

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            await s3.put_object(
                Bucket=self._pool_config.bucket,
                Key=key,
                Body=json.dumps(obj),
                ContentType="application/json"
            )

        return True

    async def delete(self, id: str) -> bool:
        """Delete a record from S3."""
        self._check_connection()

        key = self._get_key(id)

        try:
            async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
                await s3.delete_object(
                    Bucket=self._pool_config.bucket,
                    Key=key
                )
            return True
        except Exception:
            return False

    async def exists(self, id: str) -> bool:
        """Check if a record exists in S3."""
        self._check_connection()

        key = self._get_key(id)

        try:
            async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
                await s3.head_object(
                    Bucket=self._pool_config.bucket,
                    Key=key
                )
            return True
        except Exception:
            return False

    async def upsert(self, id: str, record: Record) -> str:
        """Update or insert a record with a specific ID."""
        self._check_connection()

        key = self._get_key(id)
        obj = self._record_to_s3_object(record)

        # Add ID to metadata
        obj["metadata"]["id"] = id

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            await s3.put_object(
                Bucket=self._pool_config.bucket,
                Key=key,
                Body=json.dumps(obj),
                ContentType="application/json"
            )

        return id

    async def search(self, query: Query) -> list[Record]:
        """Search for records matching the query."""
        self._check_connection()

        # S3 doesn't support complex queries, so we need to list and filter
        records = []

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            # List all objects
            paginator = s3.get_paginator('list_objects_v2')

            params = {
                'Bucket': self._pool_config.bucket,
            }
            if self._pool_config.prefix:
                params['Prefix'] = self._pool_config.prefix

            async for page in paginator.paginate(**params):
                if 'Contents' not in page:
                    continue

                # Process each object
                for obj_summary in page['Contents']:
                    key = obj_summary['Key']

                    # Skip non-JSON files
                    if not key.endswith('.json'):
                        continue

                    # Get the object
                    response = await s3.get_object(
                        Bucket=self._pool_config.bucket,
                        Key=key
                    )

                    body = await response['Body'].read()
                    obj = json.loads(body)
                    record = self._s3_object_to_record(obj)

                    # Extract ID from key
                    id = key.replace(self._pool_config.prefix + '/', '').replace('.json', '')
                    record.metadata["id"] = id

                    # Apply filters
                    if self._matches_filters(record, query.filters):
                        records.append(record)

        # Apply sorting
        if query.sort_specs:
            for sort_spec in reversed(query.sort_specs):
                reverse = sort_spec.order == SortOrder.DESC
                records.sort(
                    key=lambda r: (r.get_field(sort_spec.field).value if r.get_field(sort_spec.field) else "") or "",
                    reverse=reverse
                )

        # Apply offset and limit
        if query.offset_value:
            records = records[query.offset_value:]
        if query.limit_value:
            records = records[:query.limit_value]

        # Apply field projection
        if query.fields:
            records = [r.project(query.fields) for r in records]

        return records

    def _matches_filters(self, record: Record, filters: list) -> bool:
        """Check if a record matches all filters."""
        for filter in filters:
            field = record.get_field(filter.field)
            if not field:
                return False

            value = field.value

            if filter.operator == Operator.EQ:
                if value != filter.value:
                    return False
            elif filter.operator == Operator.NEQ:
                if value == filter.value:
                    return False
            elif filter.operator == Operator.GT:
                if value <= filter.value:
                    return False
            elif filter.operator == Operator.LT:
                if value >= filter.value:
                    return False
            elif filter.operator == Operator.GTE:
                if value < filter.value:
                    return False
            elif filter.operator == Operator.LTE:
                if value > filter.value:
                    return False
            elif filter.operator == Operator.LIKE:
                if str(filter.value) not in str(value):
                    return False
            elif filter.operator == Operator.IN:
                if value not in filter.value:
                    return False
            elif filter.operator == Operator.NOT_IN:
                if value in filter.value:
                    return False

        return True

    async def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()

        count = 0
        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            paginator = s3.get_paginator('list_objects_v2')

            params = {
                'Bucket': self._pool_config.bucket,
            }
            if self._pool_config.prefix:
                params['Prefix'] = self._pool_config.prefix

            async for page in paginator.paginate(**params):
                if 'Contents' in page:
                    for obj in page['Contents']:
                        if obj['Key'].endswith('.json'):
                            count += 1

        return count

    async def clear(self) -> int:
        """Clear all records from the database."""
        self._check_connection()

        count = 0
        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            # List and delete all objects
            paginator = s3.get_paginator('list_objects_v2')

            params = {
                'Bucket': self._pool_config.bucket,
            }
            if self._pool_config.prefix:
                params['Prefix'] = self._pool_config.prefix

            async for page in paginator.paginate(**params):
                if 'Contents' not in page:
                    continue

                # Build delete request
                objects_to_delete = []
                for obj in page['Contents']:
                    if obj['Key'].endswith('.json'):
                        objects_to_delete.append({'Key': obj['Key']})
                        count += 1

                # Delete in batch
                if objects_to_delete:
                    await s3.delete_objects(
                        Bucket=self._pool_config.bucket,
                        Delete={'Objects': objects_to_delete}
                    )

        return count

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records from S3."""
        self._check_connection()
        config = config or StreamConfig()

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            paginator = s3.get_paginator('list_objects_v2')

            params = {
                'Bucket': self._pool_config.bucket,
                'MaxKeys': config.batch_size
            }
            if self._pool_config.prefix:
                params['Prefix'] = self._pool_config.prefix

            async for page in paginator.paginate(**params):
                if 'Contents' not in page:
                    continue

                for obj_summary in page['Contents']:
                    key = obj_summary['Key']

                    if not key.endswith('.json'):
                        continue

                    # Get the object
                    response = await s3.get_object(
                        Bucket=self._pool_config.bucket,
                        Key=key
                    )

                    body = await response['Body'].read()
                    obj = json.loads(body)
                    record = self._s3_object_to_record(obj)

                    # Extract ID from key
                    id = key.replace(self._pool_config.prefix + '/', '').replace('.json', '')
                    record.metadata["id"] = id

                    # Apply filters if query provided
                    if query and query.filters:
                        if not self._matches_filters(record, query.filters):
                            continue

                    # Apply field projection
                    if query and query.fields:
                        record = record.project(query.fields)

                    yield record

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into S3."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        batch = []
        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write batch with graceful fallback
                async def batch_func(b):
                    await self._write_batch(b)
                    return [r.id for r in b]

                continue_processing = await async_process_batch_with_fallback(
                    batch,
                    batch_func,
                    self.create,
                    result,
                    config
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write remaining batch
        if batch and not quitting:
            async def batch_func(b):
                await self._write_batch(b)
                return [r.id for r in b]

            await async_process_batch_with_fallback(
                batch,
                batch_func,
                self.create,
                result,
                config
            )

        result.duration = time.time() - start_time
        return result

    async def _write_batch(self, records: list[Record]) -> None:
        """Write a batch of records to S3."""
        if not records:
            return

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            # Write each record (S3 doesn't have native batch write)
            # We could potentially use multipart upload for very large batches
            tasks = []
            for record in records:
                id = str(uuid.uuid4())
                key = self._get_key(id)
                obj = self._record_to_s3_object(record)
                obj["metadata"]["id"] = id

                task = s3.put_object(
                    Bucket=self._pool_config.bucket,
                    Key=key,
                    Body=json.dumps(obj),
                    ContentType="application/json"
                )
                tasks.append(task)

            # Execute all uploads concurrently
            await asyncio.gather(*tasks)

    async def list_all(self) -> list[str]:
        """List all record IDs in the database."""
        self._check_connection()

        ids = []
        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            paginator = s3.get_paginator('list_objects_v2')

            params = {
                'Bucket': self._pool_config.bucket,
            }
            if self._pool_config.prefix:
                params['Prefix'] = self._pool_config.prefix

            async for page in paginator.paginate(**params):
                if 'Contents' not in page:
                    continue

                for obj in page['Contents']:
                    key = obj['Key']
                    if key.endswith('.json'):
                        # Extract ID from key
                        id = key.replace(self._pool_config.prefix + '/', '').replace('.json', '')
                        ids.append(id)

        return ids

    async def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform vector similarity search using Python calculations.

        WARNING: This implementation downloads all records from S3 to perform
        the search locally. This is inefficient for large datasets. Consider
        using a vector-enabled backend like PostgreSQL or Elasticsearch for
        production use with large datasets.

        Future optimization: Override this method to use AWS OpenSearch or a
        similar vector-enabled service when available.
        """
        return await self.python_vector_search_async(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )
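

# Minimal usage sketch (illustrative only, not part of the backend API).
# It assumes a reachable bucket whose credentials are resolved through the
# usual AWS configuration chain; the "bucket" and "prefix" values below are
# placeholders.
async def _example_usage() -> None:
    db = AsyncS3Database({"bucket": "my-bucket", "prefix": "records"})
    await db.connect()
    try:
        ids = await db.list_all()           # all stored record IDs
        if ids:
            record = await db.read(ids[0])  # Record, or None if the read fails
            logger.info("read %s: %s", ids[0], record is not None)
    finally:
        await db.close()


if __name__ == "__main__":  # pragma: no cover
    asyncio.run(_example_usage())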