Coverage for src/dataknobs_data/backends/s3.py: 15%
311 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-31 15:06 -0600
"""S3 backend implementation with proper connection management."""
from __future__ import annotations

import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from uuid import uuid4

from dataknobs_config import ConfigurableBase

from dataknobs_data.database import SyncDatabase
from dataknobs_data.query import Query
from dataknobs_data.records import Record
from dataknobs_data.streaming import StreamConfig, StreamResult, process_batch_with_fallback

from ..vector import VectorOperationsMixin
from ..vector.bulk_embed_mixin import BulkEmbedMixin
from ..vector.python_vector_search import PythonVectorSearchMixin
from .sqlite_mixins import SQLiteVectorSupport
from .vector_config_mixin import VectorConfigMixin

if TYPE_CHECKING:
    from collections.abc import Iterator
30logger = logging.getLogger(__name__)
class SyncS3Database(  # type: ignore[misc]
    SyncDatabase,
    ConfigurableBase,
    VectorConfigMixin,
    SQLiteVectorSupport,
    PythonVectorSearchMixin,
    BulkEmbedMixin,
    VectorOperationsMixin
):
    """S3-based database backend with proper connection management.

    Stores each record as a single JSON object keyed as
    ``{prefix}{record_id}.json`` in the configured bucket.  Queries are
    evaluated client-side by listing and fetching objects, so this backend
    is best suited to small and medium datasets.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize S3 database configuration (no network I/O here).

        Args:
            config: Configuration dictionary.  ``bucket`` is required;
                everything else is optional with sensible defaults.

        Raises:
            ValueError: If no bucket name is supplied.
        """
        super().__init__(config)

        # Connection state -- the boto3 client is created lazily in connect().
        self.s3_client = None
        self._connected = False

        # Listing cache; marked dirty on every mutation.
        self._index_cache = {}
        self._cache_dirty = True

        # Required configuration.
        self.bucket = self.config.get("bucket")
        if not self.bucket:
            raise ValueError("S3 bucket name is required in configuration")

        # Optional configuration with defaults.  Normalize the prefix so it
        # always ends with exactly one "/".
        self.prefix = self.config.get("prefix", "records/").rstrip("/") + "/"
        self.region = self.config.get("region", "us-east-1")
        self.endpoint_url = self.config.get("endpoint_url")
        self.max_workers = self.config.get("max_workers", 10)
        self.multipart_threshold = self.config.get("multipart_threshold", 8 * 1024 * 1024)
        self.multipart_chunksize = self.config.get("multipart_chunksize", 8 * 1024 * 1024)
        self.max_retries = self.config.get("max_retries", 3)

        # AWS credentials (environment / IAM role is used when not provided).
        self.aws_access_key_id = self.config.get("access_key_id")
        self.aws_secret_access_key = self.config.get("secret_access_key")
        self.aws_session_token = self.config.get("session_token")

        # Initialize vector support.
        self._parse_vector_config(config or {})
        self._init_vector_state()

    @classmethod
    def from_config(cls, config: dict) -> SyncS3Database:
        """Create instance from configuration dictionary."""
        return cls(config)

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as an ISO-8601 string.

        Uses a timezone-aware datetime; ``datetime.utcnow()`` is deprecated
        and yields naive (ambiguous) timestamps.
        """
        return datetime.now(timezone.utc).isoformat()

    def connect(self) -> None:
        """Connect to S3 service (idempotent)."""
        if self._connected:
            return  # Already connected

        # Imported lazily so the module can load without boto3 installed.
        import boto3
        from botocore.config import Config as BotoConfig
        from botocore.exceptions import ClientError

        # Configure boto3 client
        boto_config = BotoConfig(
            region_name=self.region,
            max_pool_connections=self.max_workers,
            retries={'max_attempts': self.max_retries}
        )

        client_kwargs = {
            "config": boto_config,
            "use_ssl": not bool(self.endpoint_url)  # Disable SSL for local testing
        }

        if self.endpoint_url:
            client_kwargs["endpoint_url"] = self.endpoint_url

        if self.aws_access_key_id and self.aws_secret_access_key:
            client_kwargs["aws_access_key_id"] = self.aws_access_key_id
            client_kwargs["aws_secret_access_key"] = self.aws_secret_access_key

        if self.aws_session_token:
            client_kwargs["aws_session_token"] = self.aws_session_token

        # Create S3 client and keep the exception class for later handlers.
        self.s3_client = boto3.client("s3", **client_kwargs)
        self.ClientError = ClientError

        # Verify bucket exists or create it
        self._ensure_bucket_exists()

        self._connected = True
        logger.info(f"Connected to S3 with bucket={self.bucket}, prefix={self.prefix}")

    def close(self) -> None:
        """Close the S3 connection.

        boto3 clients need no explicit close; we drop the client reference
        and clear cached state so a later connect() starts fresh.
        """
        if self.s3_client:
            self._index_cache = {}
            self.s3_client = None
            self._connected = False
            logger.info(f"Closed S3 connection to bucket={self.bucket}")

    def _initialize(self) -> None:
        """Initialize method - connection setup moved to connect()."""
        pass

    def _check_connection(self) -> None:
        """Raise if connect() has not been called successfully."""
        if not self._connected or not self.s3_client:
            raise RuntimeError("S3 not connected. Call connect() first.")

    def _ensure_bucket_exists(self):
        """Ensure the S3 bucket exists, create if necessary."""
        try:
            self.s3_client.head_bucket(Bucket=self.bucket)
            logger.debug(f"Bucket {self.bucket} exists")
        except self.ClientError as e:
            error_code = e.response['Error']['Code']
            if error_code == '404':
                # Bucket doesn't exist, create it.  us-east-1 must NOT pass a
                # LocationConstraint (the S3 API rejects it for that region).
                logger.info(f"Creating bucket {self.bucket}")
                if self.region == 'us-east-1':
                    self.s3_client.create_bucket(Bucket=self.bucket)
                else:
                    self.s3_client.create_bucket(
                        Bucket=self.bucket,
                        CreateBucketConfiguration={'LocationConstraint': self.region}
                    )
            else:
                raise

    def _get_object_key(self, record_id: str) -> str:
        """Generate S3 object key for a record ID."""
        return f"{self.prefix}{record_id}.json"

    def _record_to_s3_object(self, record: Record) -> dict[str, Any]:
        """Convert a Record to S3 object data.

        Uses Record's built-in serialization (handles VectorFields) in
        non-flattened form to preserve field metadata.
        """
        return record.to_dict(include_metadata=True, flatten=False)

    def _s3_object_to_record(self, obj_data: dict[str, Any]) -> Record:
        """Convert S3 object data to a Record."""
        return Record.from_dict(obj_data)

    def _put_record_object(self, key: str, record: Record) -> None:
        """Serialize *record* and upload it to S3 at *key* as JSON.

        Shared by create(), update() and _write_single() so the wire format
        is defined in exactly one place.
        """
        body = json.dumps(self._record_to_s3_object(record))
        self.s3_client.put_object(
            Bucket=self.bucket,
            Key=key,
            Body=body,
            ContentType='application/json'
        )

    def create(self, record: Record) -> str:
        """Create a new record in S3 and return its storage id."""
        self._check_connection()

        # Use centralized method to prepare record (copies it and assigns
        # a storage id).
        record_copy, storage_id = self._prepare_record_for_storage(record)
        key = self._get_object_key(storage_id)

        # Stamp bookkeeping metadata on the copy.
        now_iso = self._utc_now_iso()
        record_copy.metadata = record_copy.metadata or {}
        record_copy.metadata["id"] = storage_id
        record_copy.metadata["created_at"] = now_iso
        record_copy.metadata["updated_at"] = now_iso

        self._put_record_object(key, record_copy)

        # Any listing-derived cache is now stale.
        self._cache_dirty = True

        logger.debug(f"Created record {storage_id} at {key}")
        return storage_id

    def read(self, id: str) -> Record | None:
        """Read a record from S3; return None when it does not exist."""
        self._check_connection()

        key = self._get_object_key(id)

        try:
            response = self.s3_client.get_object(Bucket=self.bucket, Key=key)
            obj_data = json.loads(response['Body'].read())
            record = self._s3_object_to_record(obj_data)
            # Use centralized method to prepare record
            return self._prepare_record_from_storage(record, id)
        except self.ClientError as e:
            if e.response['Error']['Code'] == 'NoSuchKey':
                return None
            raise

    def update(self, id: str, record: Record) -> bool:
        """Update an existing record in S3.

        Returns False when no record with *id* exists.
        NOTE(review): the caller's record has its metadata mutated in
        place -- confirm callers expect this.
        """
        self._check_connection()

        key = self._get_object_key(id)

        # Fetch the existing object both to confirm existence and to
        # preserve its original created_at timestamp.
        try:
            response = self.s3_client.get_object(Bucket=self.bucket, Key=key)
            existing_data = json.loads(response['Body'].read())
            existing_metadata = existing_data.get("metadata", {})
        except self.ClientError as e:
            if e.response['Error']['Code'] == 'NoSuchKey':
                return False
            raise

        # Preserve created_at; refresh updated_at.
        record.metadata = record.metadata or {}
        record.metadata["id"] = id
        record.metadata["created_at"] = existing_metadata.get("created_at", self._utc_now_iso())
        record.metadata["updated_at"] = self._utc_now_iso()

        self._put_record_object(key, record)

        self._cache_dirty = True

        logger.debug(f"Updated record {id} at {key}")
        return True

    def delete(self, id: str) -> bool:
        """Delete a record from S3; return False when it does not exist."""
        self._check_connection()

        key = self._get_object_key(id)

        # head_object reports missing keys as HTTP 404 (not 'NoSuchKey').
        try:
            self.s3_client.head_object(Bucket=self.bucket, Key=key)
        except self.ClientError as e:
            if e.response['Error']['Code'] == '404':
                return False
            raise

        self.s3_client.delete_object(Bucket=self.bucket, Key=key)

        self._cache_dirty = True

        logger.debug(f"Deleted record {id} at {key}")
        return True

    def list_all(self) -> list[str]:
        """List all record IDs in the database.

        Returns:
            List of all record IDs
        """
        self._check_connection()
        record_ids = []

        # Use paginator so buckets larger than one listing page are covered.
        paginator = self.s3_client.get_paginator('list_objects_v2')
        page_iterator = paginator.paginate(
            Bucket=self.bucket,
            Prefix=self.prefix
        )

        for page in page_iterator:
            if 'Contents' in page:
                for obj in page['Contents']:
                    key = obj['Key']
                    # Extract record ID from key
                    if key.startswith(self.prefix) and key.endswith('.json'):
                        record_id = key[len(self.prefix):-5]  # Remove prefix and .json
                        record_ids.append(record_id)

        logger.debug(f"Listed {len(record_ids)} records from S3")
        return record_ids

    def exists(self, id: str) -> bool:
        """Check if a record exists in S3."""
        self._check_connection()

        key = self._get_object_key(id)

        try:
            self.s3_client.head_object(Bucket=self.bucket, Key=key)
            return True
        except self.ClientError as e:
            if e.response['Error']['Code'] == '404':
                return False
            raise

    def search(self, query: Query) -> list[Record]:
        """Search for records matching the query.

        Note: S3 doesn't support complex queries, so we need to list and filter.
        """
        self._check_connection()

        # List all objects with the prefix
        records = []
        paginator = self.s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix)

        for page in pages:
            if 'Contents' not in page:
                continue

            # Fetch this page's objects in parallel.
            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                futures = [
                    executor.submit(self._fetch_and_filter, obj['Key'], query)
                    for obj in page['Contents']
                    if obj['Key'].endswith('.json')
                ]

                for future in as_completed(futures):
                    record = future.result()
                    if record:
                        records.append(record)

        # Apply sorting if specified.  Sorting by each spec in reverse order
        # relies on sort stability to honor multi-key ordering.
        if query.sort_specs:
            for sort_spec in reversed(query.sort_specs):
                reverse = sort_spec.order.value == "desc"
                records.sort(
                    key=lambda r: r.get_value(sort_spec.field, ""),
                    reverse=reverse
                )

        # Apply offset and limit
        if query.offset_value:
            records = records[query.offset_value:]
        if query.limit_value:
            records = records[:query.limit_value]

        # Apply field projection
        if query.fields:
            records = [r.project(query.fields) for r in records]

        return records

    def _fetch_and_filter(self, key: str, query: Query) -> Record | None:
        """Fetch an object and apply query filters.

        Returns None when the record does not match, or when the fetch
        fails (best-effort: a bad object is logged and skipped, not fatal).
        """
        try:
            response = self.s3_client.get_object(Bucket=self.bucket, Key=key)
            obj_data = json.loads(response['Body'].read())
            record = self._s3_object_to_record(obj_data)

            # Apply filters ("flt" avoids shadowing the builtin filter()).
            for flt in query.filters:
                field_value = record.get_value(flt.field)
                if not flt.matches(field_value):
                    return None

            return record
        except Exception as e:
            logger.warning(f"Error fetching {key}: {e}")
            return None

    def _count_all(self) -> int:
        """Count all records in S3."""
        self._check_connection()

        count = 0
        paginator = self.s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix)

        for page in pages:
            if 'Contents' in page:
                count += sum(1 for obj in page['Contents'] if obj['Key'].endswith('.json'))

        return count

    def clear(self) -> int:
        """Clear all records from S3 and return the number deleted."""
        self._check_connection()

        # List and delete all objects
        count = 0
        paginator = self.s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix)

        for page in pages:
            if 'Contents' not in page:
                continue

            # Delete in batches (delete_objects accepts up to one listing
            # page's worth of keys at a time).
            objects = [{'Key': obj['Key']} for obj in page['Contents'] if obj['Key'].endswith('.json')]
            if objects:
                self.s3_client.delete_objects(
                    Bucket=self.bucket,
                    Delete={'Objects': objects}
                )
                count += len(objects)

        # Clear cache
        self._index_cache = {}
        self._cache_dirty = True

        logger.info(f"Cleared {count} records from S3")
        return count

    def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> Iterator[Record]:
        """Stream records from S3, yielding in batches of config.batch_size."""
        self._check_connection()
        config = config or StreamConfig()

        # List objects and stream them
        paginator = self.s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix)

        batch = []
        for page in pages:
            if 'Contents' not in page:
                continue

            for obj in page['Contents']:
                if not obj['Key'].endswith('.json'):
                    continue

                record = self._fetch_and_filter(obj['Key'], query or Query())
                if record:
                    batch.append(record)

                    if len(batch) >= config.batch_size:
                        yield from batch
                        batch = []

        # Yield remaining records
        yield from batch

    def stream_write(
        self,
        records: Iterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into S3.

        Writes in batches; on batch failure, process_batch_with_fallback
        retries record-by-record via self.create per the stream config.
        """
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        batch = []
        for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write batch with graceful per-record fallback.
                continue_processing = process_batch_with_fallback(
                    batch,
                    self._write_batch,
                    self.create,
                    result,
                    config
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write remaining batch
        if batch and not quitting:
            process_batch_with_fallback(
                batch,
                self._write_batch,
                self.create,
                result,
                config
            )

        result.duration = time.time() - start_time
        return result

    def _write_batch(self, records: list[Record]) -> list[str]:
        """Write a batch of records to S3 in parallel.

        Returns:
            List of created record IDs
        """
        ids = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = []
            for record in records:
                record_id = str(uuid4())
                ids.append(record_id)
                futures.append(executor.submit(self._write_single, record_id, record))

            # Wait for all writes to complete
            for future in as_completed(futures):
                future.result()  # This will raise if there was an error

        return ids

    def _write_single(self, record_id: str, record: Record) -> None:
        """Write a single record to S3 under *record_id*.

        NOTE(review): mutates the record's metadata in place, like update().
        """
        now_iso = self._utc_now_iso()
        record.metadata = record.metadata or {}
        record.metadata["id"] = record_id
        record.metadata["created_at"] = now_iso
        record.metadata["updated_at"] = now_iso

        self._put_record_object(self._get_object_key(record_id), record)

    def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform vector similarity search using Python calculations.

        WARNING: This implementation downloads all records from S3 to perform
        the search locally. This is inefficient for large datasets. Consider
        using a vector-enabled backend like PostgreSQL or Elasticsearch for
        production use with large datasets.
        """
        return self.python_vector_search_sync(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )
601# Import the native async implementation