Coverage for src/dataknobs_data/backends/s3_async.py: 15%
318 statements
1"""Native async S3 backend implementation with aioboto3 and connection pooling."""
3from __future__ import annotations
5import asyncio
6import json
7import logging
8import time
9import uuid
10from datetime import datetime
11from typing import Any, TYPE_CHECKING, cast, Callable, Awaitable
13from dataknobs_config import ConfigurableBase
15from ..database import AsyncDatabase
16from ..pooling import ConnectionPoolManager
17from ..pooling.s3 import S3PoolConfig, create_aioboto3_session, validate_s3_session
18from ..query import Operator, Query, SortOrder
19from ..records import Record
20from ..streaming import StreamConfig, StreamResult, async_process_batch_with_fallback
21from ..vector import VectorOperationsMixin
22from ..vector.bulk_embed_mixin import BulkEmbedMixin
23from ..vector.python_vector_search import PythonVectorSearchMixin
24from .sqlite_mixins import SQLiteVectorSupport
25from .vector_config_mixin import VectorConfigMixin
27if TYPE_CHECKING:
28 from collections.abc import AsyncIterator
31logger = logging.getLogger(__name__)
33# Global pool manager for S3 sessions
34_session_manager = ConnectionPoolManager()


class AsyncS3Database(  # type: ignore[misc]
    AsyncDatabase,
    ConfigurableBase,
    VectorConfigMixin,
    SQLiteVectorSupport,
    PythonVectorSearchMixin,
    BulkEmbedMixin,
    VectorOperationsMixin
):
    """Native async S3 database backend with aioboto3 and session pooling."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async S3 database."""
        super().__init__(config)

        if not config or "bucket" not in config:
            raise ValueError("S3 backend requires 'bucket' in configuration")

        self._pool_config = S3PoolConfig.from_dict(config)
        self._session = None
        self._connected = False

        # Initialize vector support
        self._parse_vector_config(config or {})
        self._init_vector_state()  # From SQLiteVectorSupport
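
    # Illustrative configuration (a sketch: 'bucket' is the only key the constructor
    # itself requires; the remaining keys are assumptions about what
    # S3PoolConfig.from_dict accepts, based on how _pool_config is used below):
    #
    #     {
    #         "bucket": "my-records",                   # required
    #         "prefix": "records",                      # optional key prefix (see _get_key)
    #         "endpoint_url": "http://localhost:9000",  # optional, e.g. MinIO or localstack
    #     }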

    @classmethod
    def from_config(cls, config: dict) -> AsyncS3Database:
        """Create from config dictionary."""
        return cls(config)
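
    # Minimal usage sketch (illustrative only; assumes a running event loop and a
    # reachable bucket, and Record construction is schematic):
    #
    #     db = AsyncS3Database.from_config({"bucket": "my-records"})
    #     await db.connect()
    #     record_id = await db.create(record)      # returns the storage ID
    #     fetched = await db.read(record_id)       # Record | None
    #     await db.update(record_id, fetched)      # True only if the key already exists
    #     await db.delete(record_id)               # True unless the call errored
    #     await db.close()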

    async def connect(self) -> None:
        """Connect to S3 service."""
        if self._connected:
            return

        # Get or create session for current event loop
        from ..pooling import BasePoolConfig
        self._session = await _session_manager.get_pool(
            self._pool_config,
            cast("Callable[[BasePoolConfig], Awaitable[Any]]", create_aioboto3_session),
            lambda session: validate_s3_session(session, self._pool_config)
        )

        self._connected = True
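
    # Note: connect() does not hold a long-lived network connection. It obtains an
    # aioboto3 session from the module-level ConnectionPoolManager (keyed to the
    # current event loop), and each operation below opens a short-lived client via
    # `self._session.client(...)` used as an async context manager.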

    async def close(self) -> None:
        """Close the S3 connection."""
        if self._connected:
            self._session = None
            self._connected = False

    def _initialize(self) -> None:
        """Initialize is handled in connect."""
        pass

    def _check_connection(self) -> None:
        """Check if database is connected."""
        if not self._connected or not self._session:
            raise RuntimeError("Database not connected. Call connect() first.")

    def _get_key(self, id: str) -> str:
        """Get the S3 key for a given record ID."""
        if self._pool_config.prefix:
            return f"{self._pool_config.prefix}/{id}.json"
        return f"{id}.json"
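
    # Resulting key layout (for example, with prefix="records" and id="abc123"):
    #
    #     records/abc123.json      # when a prefix is configured
    #     abc123.json              # when no prefix is configured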

    def _record_to_s3_object(self, record: Record) -> dict[str, Any]:
        """Convert a Record to an S3 object."""
        # Use Record's built-in serialization which handles VectorFields
        record_dict = record.to_dict(include_metadata=True, flatten=False)

        # Add timestamps
        now = datetime.utcnow().isoformat()
        if "metadata" not in record_dict:
            record_dict["metadata"] = {}
        record_dict["metadata"]["created_at"] = record_dict["metadata"].get("created_at", now)
        record_dict["metadata"]["updated_at"] = now

        return record_dict

    def _s3_object_to_record(self, obj: dict[str, Any]) -> Record:
        """Convert an S3 object to a Record."""
        # Use Record's built-in deserialization
        return Record.from_dict(obj)

    async def create(self, record: Record) -> str:
        """Create a new record in S3."""
        self._check_connection()

        # Use centralized method to prepare record
        record_copy, storage_id = self._prepare_record_for_storage(record)
        key = self._get_key(storage_id)
        obj = self._record_to_s3_object(record_copy)

        # Add ID to metadata
        obj["metadata"]["id"] = storage_id

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            await s3.put_object(
                Bucket=self._pool_config.bucket,
                Key=key,
                Body=json.dumps(obj),
                ContentType="application/json"
            )

        return storage_id

    async def read(self, id: str) -> Record | None:
        """Read a record from S3."""
        self._check_connection()

        key = self._get_key(id)

        try:
            async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
                response = await s3.get_object(
                    Bucket=self._pool_config.bucket,
                    Key=key
                )

                # Read the object body
                body = await response['Body'].read()
                obj = json.loads(body)

                record = self._s3_object_to_record(obj)
                # Use centralized method to prepare record
                record = self._prepare_record_from_storage(record, id)
                # Ensure ID is in metadata
                record.metadata["id"] = id

                return record
        except Exception:
            return None

    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record in S3."""
        self._check_connection()

        # Check if record exists
        if not await self.exists(id):
            return False

        key = self._get_key(id)
        obj = self._record_to_s3_object(record)

        # Preserve ID in metadata
        obj["metadata"]["id"] = id

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            await s3.put_object(
                Bucket=self._pool_config.bucket,
                Key=key,
                Body=json.dumps(obj),
                ContentType="application/json"
            )

        return True

    async def delete(self, id: str) -> bool:
        """Delete a record from S3."""
        self._check_connection()

        key = self._get_key(id)

        try:
            async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
                await s3.delete_object(
                    Bucket=self._pool_config.bucket,
                    Key=key
                )
            return True
        except Exception:
            return False

    async def exists(self, id: str) -> bool:
        """Check if a record exists in S3."""
        self._check_connection()

        key = self._get_key(id)

        try:
            async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
                await s3.head_object(
                    Bucket=self._pool_config.bucket,
                    Key=key
                )
            return True
        except Exception:
            return False

    async def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
        """Update or insert a record.

        Can be called as:
        - upsert(id, record) - explicit ID and record
        - upsert(record) - extract ID from record using Record's built-in logic
        """
        self._check_connection()

        # Determine ID and record based on arguments
        if isinstance(id_or_record, str):
            id = id_or_record
            if record is None:
                raise ValueError("Record required when ID is provided")
        else:
            record = id_or_record
            id = record.id
            if id is None:
                import uuid  # type: ignore[unreachable]
                id = str(uuid.uuid4())
                record.storage_id = id

        key = self._get_key(id)
        obj = self._record_to_s3_object(record)

        # Add ID to metadata
        obj["metadata"]["id"] = id

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            await s3.put_object(
                Bucket=self._pool_config.bucket,
                Key=key,
                Body=json.dumps(obj),
                ContentType="application/json"
            )

        return id
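
    # Both upsert call forms described in the docstring (sketch):
    #
    #     await db.upsert("user-42", record)   # explicit ID
    #     await db.upsert(record)              # ID taken from record.id, or a new UUID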

    async def search(self, query: Query) -> list[Record]:
        """Search for records matching the query."""
        self._check_connection()

        # S3 doesn't support complex queries, so we need to list and filter
        records = []

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            # List all objects
            paginator = s3.get_paginator('list_objects_v2')

            params = {
                'Bucket': self._pool_config.bucket,
            }
            if self._pool_config.prefix:
                params['Prefix'] = self._pool_config.prefix

            async for page in paginator.paginate(**params):
                if 'Contents' not in page:
                    continue

                # Process each object
                for obj_summary in page['Contents']:
                    key = obj_summary['Key']

                    # Skip non-JSON files
                    if not key.endswith('.json'):
                        continue

                    # Get the object
                    response = await s3.get_object(
                        Bucket=self._pool_config.bucket,
                        Key=key
                    )

                    body = await response['Body'].read()
                    obj = json.loads(body)
                    record = self._s3_object_to_record(obj)

                    # Extract ID from key (strip the optional prefix and '.json' suffix)
                    id = key
                    if self._pool_config.prefix:
                        id = id.replace(self._pool_config.prefix + '/', '')
                    id = id.replace('.json', '')
                    record.metadata["id"] = id

                    # Apply filters
                    if self._matches_filters(record, query.filters):
                        records.append(record)

        # Apply sorting
        if query.sort_specs:
            for sort_spec in reversed(query.sort_specs):
                reverse = sort_spec.order == SortOrder.DESC
                records.sort(
                    key=lambda r: (r.get_field(sort_spec.field).value if r.get_field(sort_spec.field) else "") or "",
                    reverse=reverse
                )

        # Apply offset and limit
        if query.offset_value:
            records = records[query.offset_value:]
        if query.limit_value:
            records = records[:query.limit_value]

        # Apply field projection
        if query.fields:
            records = [r.project(query.fields) for r in records]

        return records

    def _matches_filters(self, record: Record, filters: list) -> bool:
        """Check if a record matches all filters."""
        for filter in filters:
            field = record.get_field(filter.field)
            if not field:
                return False

            value = field.value

            if filter.operator == Operator.EQ:
                if value != filter.value:
                    return False
            elif filter.operator == Operator.NEQ:
                if value == filter.value:
                    return False
            elif filter.operator == Operator.GT:
                if value <= filter.value:
                    return False
            elif filter.operator == Operator.LT:
                if value >= filter.value:
                    return False
            elif filter.operator == Operator.GTE:
                if value < filter.value:
                    return False
            elif filter.operator == Operator.LTE:
                if value > filter.value:
                    return False
            elif filter.operator == Operator.LIKE:
                if str(filter.value) not in str(value):
                    return False
            elif filter.operator == Operator.IN:
                if value not in filter.value:
                    return False
            elif filter.operator == Operator.NOT_IN:
                if value in filter.value:
                    return False

        return True
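
    # Filter semantics used above: EQ/NEQ/GT/LT/GTE/LTE compare field values directly,
    # LIKE is implemented as plain substring containment (no SQL wildcards), and
    # IN/NOT_IN test membership in the filter's value. A record must satisfy every
    # filter to match, and a missing field never matches.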

    async def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()

        count = 0
        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            paginator = s3.get_paginator('list_objects_v2')

            params = {
                'Bucket': self._pool_config.bucket,
            }
            if self._pool_config.prefix:
                params['Prefix'] = self._pool_config.prefix

            async for page in paginator.paginate(**params):
                if 'Contents' in page:
                    for obj in page['Contents']:
                        if obj['Key'].endswith('.json'):
                            count += 1

        return count

    async def clear(self) -> int:
        """Clear all records from the database."""
        self._check_connection()

        count = 0
        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            # List and delete all objects
            paginator = s3.get_paginator('list_objects_v2')

            params = {
                'Bucket': self._pool_config.bucket,
            }
            if self._pool_config.prefix:
                params['Prefix'] = self._pool_config.prefix

            async for page in paginator.paginate(**params):
                if 'Contents' not in page:
                    continue

                # Build delete request
                objects_to_delete = []
                for obj in page['Contents']:
                    if obj['Key'].endswith('.json'):
                        objects_to_delete.append({'Key': obj['Key']})
                        count += 1

                # Delete in batch
                if objects_to_delete:
                    await s3.delete_objects(
                        Bucket=self._pool_config.bucket,
                        Delete={'Objects': objects_to_delete}
                    )

        return count
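
    # Note: clear() issues one delete_objects call per listing page, so a large
    # prefix is removed in page-sized chunks rather than in a single request.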

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records from S3."""
        self._check_connection()
        config = config or StreamConfig()

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            paginator = s3.get_paginator('list_objects_v2')

            params = {
                'Bucket': self._pool_config.bucket,
                'MaxKeys': config.batch_size
            }
            if self._pool_config.prefix:
                params['Prefix'] = self._pool_config.prefix

            async for page in paginator.paginate(**params):
                if 'Contents' not in page:
                    continue

                for obj_summary in page['Contents']:
                    key = obj_summary['Key']

                    if not key.endswith('.json'):
                        continue

                    # Get the object
                    response = await s3.get_object(
                        Bucket=self._pool_config.bucket,
                        Key=key
                    )

                    body = await response['Body'].read()
                    obj = json.loads(body)
                    record = self._s3_object_to_record(obj)

                    # Extract ID from key (strip the optional prefix and '.json' suffix)
                    id = key
                    if self._pool_config.prefix:
                        id = id.replace(self._pool_config.prefix + '/', '')
                    id = id.replace('.json', '')
                    record.metadata["id"] = id

                    # Apply filters if query provided
                    if query and query.filters:
                        if not self._matches_filters(record, query.filters):
                            continue

                    # Apply field projection
                    if query and query.fields:
                        record = record.project(query.fields)

                    yield record
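
    # Consumption sketch for stream_read (illustrative; `handle` is a placeholder):
    #
    #     async for rec in db.stream_read(config=StreamConfig(batch_size=500)):
    #         await handle(rec)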

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into S3."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        batch = []
        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write batch with graceful fallback
                async def batch_func(b):
                    await self._write_batch(b)
                    return [r.id for r in b]

                continue_processing = await async_process_batch_with_fallback(
                    batch,
                    batch_func,
                    self.create,
                    result,
                    config
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write remaining batch
        if batch and not quitting:
            async def batch_func(b):
                await self._write_batch(b)
                return [r.id for r in b]

            await async_process_batch_with_fallback(
                batch,
                batch_func,
                self.create,
                result,
                config
            )

        result.duration = time.time() - start_time
        return result
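
    # stream_write expects an async iterator of Records; a schematic producer
    # (`make_records` and the Record construction are placeholders) might look like:
    #
    #     async def make_records():
    #         for payload in payloads:
    #             yield Record(payload)
    #
    #     result = await db.stream_write(make_records(), StreamConfig(batch_size=100))
    #     print(result.duration)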

    async def _write_batch(self, records: list[Record]) -> None:
        """Write a batch of records to S3."""
        if not records:
            return

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            # Write each record (S3 doesn't have native batch write)
            # We could potentially use multipart upload for very large batches
            tasks = []
            for record in records:
                id = str(uuid.uuid4())
                key = self._get_key(id)
                obj = self._record_to_s3_object(record)
                obj["metadata"]["id"] = id

                task = s3.put_object(
                    Bucket=self._pool_config.bucket,
                    Key=key,
                    Body=json.dumps(obj),
                    ContentType="application/json"
                )
                tasks.append(task)

            # Execute all uploads concurrently
            await asyncio.gather(*tasks)

    async def list_all(self) -> list[str]:
        """List all record IDs in the database."""
        self._check_connection()

        ids = []
        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            paginator = s3.get_paginator('list_objects_v2')

            params = {
                'Bucket': self._pool_config.bucket,
            }
            if self._pool_config.prefix:
                params['Prefix'] = self._pool_config.prefix

            async for page in paginator.paginate(**params):
                if 'Contents' not in page:
                    continue

                for obj in page['Contents']:
                    key = obj['Key']
                    if key.endswith('.json'):
                        # Extract ID from key (strip the optional prefix and '.json' suffix)
                        id = key
                        if self._pool_config.prefix:
                            id = id.replace(self._pool_config.prefix + '/', '')
                        id = id.replace('.json', '')
                        ids.append(id)

        return ids

    async def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform vector similarity search using Python calculations.

        WARNING: This implementation downloads all records from S3 to perform
        the search locally. This is inefficient for large datasets. Consider
        using a vector-enabled backend like PostgreSQL or Elasticsearch for
        production use with large datasets.

        Future optimization: Override this method to use AWS OpenSearch or a
        similar vector-enabled service when available.
        """
        return await self.python_vector_search_async(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )
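
    # vector_search usage sketch (illustrative; remember this scans every stored
    # record, so keep it to small datasets):
    #
    #     hits = await db.vector_search([0.1, 0.2, 0.3], vector_field="embedding", k=5)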