Coverage for src/dataknobs_data/backends/s3.py: 15% of 315 statements
1"""S3 backend implementation with proper connection management."""
3from __future__ import annotations
5import json
6import logging
7import time
8from concurrent.futures import ThreadPoolExecutor, as_completed
9from datetime import datetime
10from typing import Any, TYPE_CHECKING
11from uuid import uuid4
13from dataknobs_config import ConfigurableBase
15from dataknobs_data.database import SyncDatabase
16from dataknobs_data.query import Query
17from dataknobs_data.records import Record
18from dataknobs_data.streaming import StreamConfig, StreamResult, process_batch_with_fallback
20from ..vector import VectorOperationsMixin
21from ..vector.bulk_embed_mixin import BulkEmbedMixin
22from ..vector.python_vector_search import PythonVectorSearchMixin
23from .sqlite_mixins import SQLiteVectorSupport
24from .vector_config_mixin import VectorConfigMixin
26if TYPE_CHECKING:
27 from collections.abc import Iterator
30logger = logging.getLogger(__name__)

class SyncS3Database(  # type: ignore[misc]
    SyncDatabase,
    ConfigurableBase,
    VectorConfigMixin,
    SQLiteVectorSupport,
    PythonVectorSearchMixin,
    BulkEmbedMixin,
    VectorOperationsMixin
):
42 """S3-based database backend with proper connection management.
44 Stores records as JSON objects in S3 with metadata as tags.
45 """

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize S3 database configuration.

        Args:
            config: Configuration dictionary
        """
        super().__init__(config)

        # Connection state
        self.s3_client = None
        self._connected = False

        # Cache for performance
        self._index_cache = {}
        self._cache_dirty = True

        # Store configuration for later connection
        self.bucket = self.config.get("bucket")
        if not self.bucket:
            raise ValueError("S3 bucket name is required in configuration")

        # Optional configuration with defaults
        self.prefix = self.config.get("prefix", "records/").rstrip("/") + "/"
        self.region = self.config.get("region", "us-east-1")
        self.endpoint_url = self.config.get("endpoint_url")
        self.max_workers = self.config.get("max_workers", 10)
        self.multipart_threshold = self.config.get("multipart_threshold", 8 * 1024 * 1024)
        self.multipart_chunksize = self.config.get("multipart_chunksize", 8 * 1024 * 1024)
        self.max_retries = self.config.get("max_retries", 3)

        # AWS credentials (will use environment/IAM role if not provided)
        self.aws_access_key_id = self.config.get("access_key_id")
        self.aws_secret_access_key = self.config.get("secret_access_key")
        self.aws_session_token = self.config.get("session_token")

        # Initialize vector support
        self._parse_vector_config(config or {})
        self._init_vector_state()

    @classmethod
    def from_config(cls, config: dict) -> SyncS3Database:
        """Create instance from configuration dictionary."""
        return cls(config)

    def connect(self) -> None:
        """Connect to S3 service."""
        if self._connected:
            return  # Already connected

        import boto3
        from botocore.config import Config as BotoConfig
        from botocore.exceptions import ClientError

        # Configure boto3 client
        boto_config = BotoConfig(
            region_name=self.region,
            max_pool_connections=self.max_workers,
            retries={'max_attempts': self.max_retries}
        )

        client_kwargs = {
            "config": boto_config,
            "use_ssl": not bool(self.endpoint_url)  # Disable SSL for local testing
        }

        if self.endpoint_url:
            client_kwargs["endpoint_url"] = self.endpoint_url

        if self.aws_access_key_id and self.aws_secret_access_key:
            client_kwargs["aws_access_key_id"] = self.aws_access_key_id
            client_kwargs["aws_secret_access_key"] = self.aws_secret_access_key

        if self.aws_session_token:
            client_kwargs["aws_session_token"] = self.aws_session_token

        # Create S3 client
        self.s3_client = boto3.client("s3", **client_kwargs)
        self.ClientError = ClientError

        # Verify bucket exists or create it
        self._ensure_bucket_exists()

        self._connected = True
        logger.info(f"Connected to S3 with bucket={self.bucket}, prefix={self.prefix}")

    def close(self) -> None:
        """Close the S3 connection."""
        if self.s3_client:
            # S3 client doesn't need explicit closing, but clear cache
            self._index_cache = {}  # type: ignore[unreachable]
            self._connected = False
            logger.info(f"Closed S3 connection to bucket={self.bucket}")

    def _initialize(self) -> None:
        """Initialize method - connection setup moved to connect()."""
        pass

    def _check_connection(self) -> None:
        """Check if S3 client is connected."""
        if not self._connected or not self.s3_client:
            raise RuntimeError("S3 not connected. Call connect() first.")

    def _ensure_bucket_exists(self):
        """Ensure the S3 bucket exists, create if necessary."""
        try:
            self.s3_client.head_bucket(Bucket=self.bucket)
            logger.debug(f"Bucket {self.bucket} exists")
        except self.ClientError as e:
            error_code = e.response['Error']['Code']
            if error_code == '404':
                # Bucket doesn't exist, create it
                logger.info(f"Creating bucket {self.bucket}")
                if self.region == 'us-east-1':
                    self.s3_client.create_bucket(Bucket=self.bucket)
                else:
                    self.s3_client.create_bucket(
                        Bucket=self.bucket,
                        CreateBucketConfiguration={'LocationConstraint': self.region}
                    )
            else:
                raise

    def _get_object_key(self, record_id: str) -> str:
        """Generate S3 object key for a record ID."""
        return f"{self.prefix}{record_id}.json"

    def _record_to_s3_object(self, record: Record) -> dict[str, Any]:
        """Convert a Record to S3 object data."""
        # Use Record's built-in serialization which handles VectorFields
        # Use non-flattened format to preserve field metadata
        record_dict = record.to_dict(include_metadata=True, flatten=False)

        return record_dict

    def _s3_object_to_record(self, obj_data: dict[str, Any]) -> Record:
        """Convert S3 object data to a Record."""
        # Use Record's built-in deserialization
        return Record.from_dict(obj_data)

    def create(self, record: Record) -> str:
        """Create a new record in S3."""
        self._check_connection()

        # Use centralized method to prepare record
        record_copy, storage_id = self._prepare_record_for_storage(record)
        key = self._get_object_key(storage_id)

        # Set metadata
        record_copy.metadata = record_copy.metadata or {}
        record_copy.metadata["id"] = storage_id
        now = datetime.utcnow()
        record_copy.metadata["created_at"] = now.isoformat()
        record_copy.metadata["updated_at"] = now.isoformat()

        # Convert record to JSON
        obj_data = self._record_to_s3_object(record_copy)
        body = json.dumps(obj_data)

        # Store in S3
        self.s3_client.put_object(
            Bucket=self.bucket,
            Key=key,
            Body=body,
            ContentType='application/json'
        )

        # Invalidate cache
        self._cache_dirty = True

        logger.debug(f"Created record {storage_id} at {key}")
        return storage_id

    def read(self, id: str) -> Record | None:
        """Read a record from S3."""
        self._check_connection()

        key = self._get_object_key(id)

        try:
            response = self.s3_client.get_object(Bucket=self.bucket, Key=key)
            body = response['Body'].read()
            obj_data = json.loads(body)
            record = self._s3_object_to_record(obj_data)
            # Use centralized method to prepare record
            return self._prepare_record_from_storage(record, id)
        except self.ClientError as e:
            if e.response['Error']['Code'] == 'NoSuchKey':
                return None
            raise

    def update(self, id: str, record: Record) -> bool:
        """Update an existing record in S3."""
        self._check_connection()

        key = self._get_object_key(id)

        # Check if exists and get existing metadata
        try:
            response = self.s3_client.get_object(Bucket=self.bucket, Key=key)
            existing_data = json.loads(response['Body'].read())
            existing_metadata = existing_data.get("metadata", {})
        except self.ClientError as e:
            if e.response['Error']['Code'] == 'NoSuchKey':
                return False
            raise

        # Preserve and update metadata
        record.metadata = record.metadata or {}
        record.metadata["id"] = id
        record.metadata["created_at"] = existing_metadata.get("created_at", datetime.utcnow().isoformat())
        record.metadata["updated_at"] = datetime.utcnow().isoformat()

        # Update the object
        obj_data = self._record_to_s3_object(record)
        body = json.dumps(obj_data)

        self.s3_client.put_object(
            Bucket=self.bucket,
            Key=key,
            Body=body,
            ContentType='application/json'
        )

        # Invalidate cache
        self._cache_dirty = True

        logger.debug(f"Updated record {id} at {key}")
        return True

    def delete(self, id: str) -> bool:
        """Delete a record from S3."""
        self._check_connection()

        key = self._get_object_key(id)

        # Check if exists
        try:
            self.s3_client.head_object(Bucket=self.bucket, Key=key)
        except self.ClientError as e:
            if e.response['Error']['Code'] == '404':
                return False
            raise

        # Delete the object
        self.s3_client.delete_object(Bucket=self.bucket, Key=key)

        # Invalidate cache
        self._cache_dirty = True

        logger.debug(f"Deleted record {id} at {key}")
        return True

    def list_all(self) -> list[str]:
        """List all record IDs in the database.

        Returns:
            List of all record IDs
        """
        self._check_connection()
        record_ids = []

        # Use paginator for large buckets
        paginator = self.s3_client.get_paginator('list_objects_v2')
        page_iterator = paginator.paginate(
            Bucket=self.bucket,
            Prefix=self.prefix
        )

        for page in page_iterator:
            if 'Contents' in page:
                for obj in page['Contents']:
                    key = obj['Key']
                    # Extract record ID from key
                    if key.startswith(self.prefix) and key.endswith('.json'):
                        record_id = key[len(self.prefix):-5]  # Remove prefix and .json
                        record_ids.append(record_id)

        logger.debug(f"Listed {len(record_ids)} records from S3")
        return record_ids

    def exists(self, id: str) -> bool:
        """Check if a record exists in S3."""
        self._check_connection()

        key = self._get_object_key(id)

        try:
            self.s3_client.head_object(Bucket=self.bucket, Key=key)
            return True
        except self.ClientError as e:
            if e.response['Error']['Code'] == '404':
                return False
            raise

    def search(self, query: Query) -> list[Record]:
        """Search for records matching the query.

        Note: S3 doesn't support complex queries, so we need to list and filter.
        """
        self._check_connection()

        # List all objects with the prefix
        records = []
        paginator = self.s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix)

        for page in pages:
            if 'Contents' not in page:
                continue

            # Fetch objects in parallel
            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                futures = []
                for obj in page['Contents']:
                    if obj['Key'].endswith('.json'):
                        future = executor.submit(self._fetch_and_filter, obj['Key'], query)
                        futures.append(future)

                for future in as_completed(futures):
                    record = future.result()
                    if record:
                        records.append(record)

        # Apply sorting if specified
        if query.sort_specs:
            for sort_spec in reversed(query.sort_specs):
                reverse = sort_spec.order.value == "desc"
                # Special handling for 'id' field
                if sort_spec.field == 'id':
                    records.sort(
                        key=lambda r: r.id or "",
                        reverse=reverse
                    )
                else:
                    records.sort(
                        key=lambda r: r.get_value(sort_spec.field, ""),
                        reverse=reverse
                    )

        # Apply offset and limit
        if query.offset_value:
            records = records[query.offset_value:]
        if query.limit_value:
            records = records[:query.limit_value]

        # Apply field projection
        if query.fields:
            records = [r.project(query.fields) for r in records]

        return records

    def _fetch_and_filter(self, key: str, query: Query) -> Record | None:
        """Fetch an object and apply query filters."""
        try:
            response = self.s3_client.get_object(Bucket=self.bucket, Key=key)
            body = response['Body'].read()
            obj_data = json.loads(body)
            record = self._s3_object_to_record(obj_data)

            # Apply filters
            for filter in query.filters:
                # Special handling for 'id' field
                if filter.field == 'id':
                    field_value = record.id
                else:
                    field_value = record.get_value(filter.field)
                if not filter.matches(field_value):
                    return None

            return record
        except Exception as e:
            logger.warning(f"Error fetching {key}: {e}")
            return None

    def _count_all(self) -> int:
        """Count all records in S3."""
        self._check_connection()

        count = 0
        paginator = self.s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix)

        for page in pages:
            if 'Contents' in page:
                count += sum(1 for obj in page['Contents'] if obj['Key'].endswith('.json'))

        return count

    def clear(self) -> int:
        """Clear all records from S3."""
        self._check_connection()

        # List and delete all objects
        count = 0
        paginator = self.s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix)

        for page in pages:
            if 'Contents' not in page:
                continue

            # Delete in batches
            objects = [{'Key': obj['Key']} for obj in page['Contents'] if obj['Key'].endswith('.json')]
            if objects:
                self.s3_client.delete_objects(
                    Bucket=self.bucket,
                    Delete={'Objects': objects}
                )
                count += len(objects)

        # Clear cache
        self._index_cache = {}
        self._cache_dirty = True

        logger.info(f"Cleared {count} records from S3")
        return count

    def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> Iterator[Record]:
        """Stream records from S3."""
        self._check_connection()
        config = config or StreamConfig()

        # List objects and stream them
        paginator = self.s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix)

        batch = []
        for page in pages:
            if 'Contents' not in page:
                continue

            for obj in page['Contents']:
                if not obj['Key'].endswith('.json'):
                    continue

                record = self._fetch_and_filter(obj['Key'], query or Query())
                if record:
                    batch.append(record)

                if len(batch) >= config.batch_size:
                    for r in batch:
                        yield r
                    batch = []

        # Yield remaining records
        for r in batch:
            yield r

    def stream_write(
        self,
        records: Iterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into S3."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        batch = []
        for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write batch with graceful fallback
                # Use lambda wrapper for _write_batch
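                # (Assumed behavior, inferred from the helper's name and the
                # self.create argument: on a failed batch write it retries the
                # records one at a time via self.create.)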
                continue_processing = process_batch_with_fallback(
                    batch,
                    lambda b: self._write_batch(b),
                    self.create,
                    result,
                    config
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write remaining batch
        if batch and not quitting:
            process_batch_with_fallback(
                batch,
                lambda b: self._write_batch(b),
                self.create,
                result,
                config
            )

        result.duration = time.time() - start_time
        return result

    def _write_batch(self, records: list[Record]) -> list[str]:
        """Write a batch of records to S3.

        Returns:
            List of created record IDs
        """
        ids = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = []
            for record in records:
                record_id = str(uuid4())
                ids.append(record_id)
                future = executor.submit(self._write_single, record_id, record)
                futures.append(future)

            # Wait for all writes to complete
            for future in as_completed(futures):
                future.result()  # This will raise if there was an error

        return ids

    def _write_single(self, record_id: str, record: Record) -> None:
        """Write a single record to S3."""
        # Set metadata
        record.metadata = record.metadata or {}
        record.metadata["id"] = record_id
        now = datetime.utcnow()
        record.metadata["created_at"] = now.isoformat()
        record.metadata["updated_at"] = now.isoformat()

        key = self._get_object_key(record_id)
        obj_data = self._record_to_s3_object(record)
        body = json.dumps(obj_data)

        self.s3_client.put_object(
            Bucket=self.bucket,
            Key=key,
            Body=body,
            ContentType='application/json'
        )

    def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform vector similarity search using Python calculations.

        WARNING: This implementation downloads all records from S3 and performs
        the search locally, which is inefficient for large datasets. Consider a
        vector-enabled backend such as PostgreSQL or Elasticsearch for
        production use at scale.
        """
        return self.python_vector_search_sync(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )

# Import the native async implementation
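
# Illustrative usage sketch, assuming an S3-compatible endpoint reachable at
# http://localhost:9000 (e.g. MinIO) with the test credentials shown below,
# and assuming Record accepts a plain dict of field values in its constructor;
# adjust both to your environment before running.
if __name__ == "__main__":
    db = SyncS3Database({
        "bucket": "example-bucket",
        "prefix": "records/",
        "endpoint_url": "http://localhost:9000",  # assumed local MinIO endpoint
        "access_key_id": "minioadmin",            # assumed test credential
        "secret_access_key": "minioadmin",        # assumed test credential
    })
    db.connect()
    try:
        # Record construction from a plain dict is an assumption about the
        # dataknobs_data.records API.
        record_id = db.create(Record({"name": "example", "value": 42}))
        print(db.read(record_id))
        print(db.list_all())
    finally:
        db.close()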