Coverage for src/dataknobs_data/backends/s3.py: 15%

311 statements  

coverage.py v7.10.3, created at 2025-08-31 15:06 -0600

1"""S3 backend implementation with proper connection management.""" 

2 

3from __future__ import annotations 

4 

5import json 

6import logging 

7import time 

8from concurrent.futures import ThreadPoolExecutor, as_completed 

9from datetime import datetime 

10from typing import Any, TYPE_CHECKING 

11from uuid import uuid4 

12 

13from dataknobs_config import ConfigurableBase 

14 

15from dataknobs_data.database import SyncDatabase 

16from dataknobs_data.query import Query 

17from dataknobs_data.records import Record 

18from dataknobs_data.streaming import StreamConfig, StreamResult, process_batch_with_fallback 

19 

20from ..vector import VectorOperationsMixin 

21from ..vector.bulk_embed_mixin import BulkEmbedMixin 

22from ..vector.python_vector_search import PythonVectorSearchMixin 

23from .sqlite_mixins import SQLiteVectorSupport 

24from .vector_config_mixin import VectorConfigMixin 

25 

26if TYPE_CHECKING: 

27 from collections.abc import Iterator 

28 

29 

30logger = logging.getLogger(__name__) 

31 

32 

33class SyncS3Database( # type: ignore[misc] 

34 SyncDatabase, 

35 ConfigurableBase, 

36 VectorConfigMixin, 

37 SQLiteVectorSupport, 

38 PythonVectorSearchMixin, 

39 BulkEmbedMixin, 

40 VectorOperationsMixin 

41): 

42 """S3-based database backend with proper connection management. 

43  

44 Stores records as JSON objects in S3 with metadata as tags. 

45 """ 

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize S3 database configuration.

        Args:
            config: Configuration dictionary
        """
        super().__init__(config)

        # Connection state
        self.s3_client = None
        self._connected = False

        # Cache for performance
        self._index_cache = {}
        self._cache_dirty = True

        # Store configuration for later connection
        self.bucket = self.config.get("bucket")
        if not self.bucket:
            raise ValueError("S3 bucket name is required in configuration")

        # Optional configuration with defaults
        self.prefix = self.config.get("prefix", "records/").rstrip("/") + "/"
        self.region = self.config.get("region", "us-east-1")
        self.endpoint_url = self.config.get("endpoint_url")
        self.max_workers = self.config.get("max_workers", 10)
        self.multipart_threshold = self.config.get("multipart_threshold", 8 * 1024 * 1024)
        self.multipart_chunksize = self.config.get("multipart_chunksize", 8 * 1024 * 1024)
        self.max_retries = self.config.get("max_retries", 3)

        # AWS credentials (will use environment/IAM role if not provided)
        self.aws_access_key_id = self.config.get("access_key_id")
        self.aws_secret_access_key = self.config.get("secret_access_key")
        self.aws_session_token = self.config.get("session_token")

        # Initialize vector support
        self._parse_vector_config(config or {})
        self._init_vector_state()

    @classmethod
    def from_config(cls, config: dict) -> SyncS3Database:
        """Create instance from configuration dictionary."""
        return cls(config)

    def connect(self) -> None:
        """Connect to S3 service."""
        if self._connected:
            return  # Already connected

        import boto3
        from botocore.config import Config as BotoConfig
        from botocore.exceptions import ClientError

        # Configure boto3 client
        boto_config = BotoConfig(
            region_name=self.region,
            max_pool_connections=self.max_workers,
            retries={'max_attempts': self.max_retries}
        )

        client_kwargs = {
            "config": boto_config,
            "use_ssl": not bool(self.endpoint_url)  # Disable SSL for local testing
        }

        if self.endpoint_url:
            client_kwargs["endpoint_url"] = self.endpoint_url

        if self.aws_access_key_id and self.aws_secret_access_key:
            client_kwargs["aws_access_key_id"] = self.aws_access_key_id
            client_kwargs["aws_secret_access_key"] = self.aws_secret_access_key

        if self.aws_session_token:
            client_kwargs["aws_session_token"] = self.aws_session_token

        # Create S3 client
        self.s3_client = boto3.client("s3", **client_kwargs)
        self.ClientError = ClientError

        # Verify bucket exists or create it
        self._ensure_bucket_exists()

        self._connected = True
        logger.info(f"Connected to S3 with bucket={self.bucket}, prefix={self.prefix}")
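
    # Connection sketch for a local/non-AWS endpoint (hedged: the MinIO-style
    # host and credentials are placeholders, not values from this project):
    #
    #   db = SyncS3Database({
    #       "bucket": "test-bucket",
    #       "endpoint_url": "http://localhost:9000",  # also sets use_ssl=False above
    #       "access_key_id": "minioadmin",
    #       "secret_access_key": "minioadmin",
    #   })
    #   db.connect()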

    def close(self) -> None:
        """Close the S3 connection."""
        if self.s3_client:
            # S3 client doesn't need explicit closing, but clear cache
            self._index_cache = {}  # type: ignore[unreachable]
            self._connected = False
            logger.info(f"Closed S3 connection to bucket={self.bucket}")

    def _initialize(self) -> None:
        """Initialize method - connection setup moved to connect()."""
        pass

    def _check_connection(self) -> None:
        """Check if S3 client is connected."""
        if not self._connected or not self.s3_client:
            raise RuntimeError("S3 not connected. Call connect() first.")

    def _ensure_bucket_exists(self):
        """Ensure the S3 bucket exists, create if necessary."""
        try:
            self.s3_client.head_bucket(Bucket=self.bucket)
            logger.debug(f"Bucket {self.bucket} exists")
        except self.ClientError as e:
            error_code = e.response['Error']['Code']
            if error_code == '404':
                # Bucket doesn't exist, create it
                logger.info(f"Creating bucket {self.bucket}")
                if self.region == 'us-east-1':
                    self.s3_client.create_bucket(Bucket=self.bucket)
                else:
                    self.s3_client.create_bucket(
                        Bucket=self.bucket,
                        CreateBucketConfiguration={'LocationConstraint': self.region}
                    )
            else:
                raise

    def _get_object_key(self, record_id: str) -> str:
        """Generate S3 object key for a record ID."""
        return f"{self.prefix}{record_id}.json"
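
    # For example, with the default configuration, record id "abc123" maps to
    # the object key "records/abc123.json".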

    def _record_to_s3_object(self, record: Record) -> dict[str, Any]:
        """Convert a Record to S3 object data."""
        # Use Record's built-in serialization which handles VectorFields
        # Use non-flattened format to preserve field metadata
        record_dict = record.to_dict(include_metadata=True, flatten=False)

        return record_dict

    def _s3_object_to_record(self, obj_data: dict[str, Any]) -> Record:
        """Convert S3 object data to a Record."""
        # Use Record's built-in deserialization
        return Record.from_dict(obj_data)

    def create(self, record: Record) -> str:
        """Create a new record in S3."""
        self._check_connection()

        # Use centralized method to prepare record
        record_copy, storage_id = self._prepare_record_for_storage(record)
        key = self._get_object_key(storage_id)

        # Set metadata
        record_copy.metadata = record_copy.metadata or {}
        record_copy.metadata["id"] = storage_id
        now = datetime.utcnow()
        record_copy.metadata["created_at"] = now.isoformat()
        record_copy.metadata["updated_at"] = now.isoformat()

        # Convert record to JSON
        obj_data = self._record_to_s3_object(record_copy)
        body = json.dumps(obj_data)

        # Store in S3
        self.s3_client.put_object(
            Bucket=self.bucket,
            Key=key,
            Body=body,
            ContentType='application/json'
        )

        # Invalidate cache
        self._cache_dirty = True

        logger.debug(f"Created record {storage_id} at {key}")
        return storage_id

    def read(self, id: str) -> Record | None:
        """Read a record from S3."""
        self._check_connection()

        key = self._get_object_key(id)

        try:
            response = self.s3_client.get_object(Bucket=self.bucket, Key=key)
            body = response['Body'].read()
            obj_data = json.loads(body)
            record = self._s3_object_to_record(obj_data)
            # Use centralized method to prepare record
            return self._prepare_record_from_storage(record, id)
        except self.ClientError as e:
            if e.response['Error']['Code'] == 'NoSuchKey':
                return None
            raise

    def update(self, id: str, record: Record) -> bool:
        """Update an existing record in S3."""
        self._check_connection()

        key = self._get_object_key(id)

        # Check if exists and get existing metadata
        try:
            response = self.s3_client.get_object(Bucket=self.bucket, Key=key)
            existing_data = json.loads(response['Body'].read())
            existing_metadata = existing_data.get("metadata", {})
        except self.ClientError as e:
            if e.response['Error']['Code'] == 'NoSuchKey':
                return False
            raise

        # Preserve and update metadata
        record.metadata = record.metadata or {}
        record.metadata["id"] = id
        record.metadata["created_at"] = existing_metadata.get("created_at", datetime.utcnow().isoformat())
        record.metadata["updated_at"] = datetime.utcnow().isoformat()

        # Update the object
        obj_data = self._record_to_s3_object(record)
        body = json.dumps(obj_data)

        self.s3_client.put_object(
            Bucket=self.bucket,
            Key=key,
            Body=body,
            ContentType='application/json'
        )

        # Invalidate cache
        self._cache_dirty = True

        logger.debug(f"Updated record {id} at {key}")
        return True

    def delete(self, id: str) -> bool:
        """Delete a record from S3."""
        self._check_connection()

        key = self._get_object_key(id)

        # Check if exists
        try:
            self.s3_client.head_object(Bucket=self.bucket, Key=key)
        except self.ClientError as e:
            if e.response['Error']['Code'] == '404':
                return False
            raise

        # Delete the object
        self.s3_client.delete_object(Bucket=self.bucket, Key=key)

        # Invalidate cache
        self._cache_dirty = True

        logger.debug(f"Deleted record {id} at {key}")
        return True

    def list_all(self) -> list[str]:
        """List all record IDs in the database.

        Returns:
            List of all record IDs
        """
        self._check_connection()
        record_ids = []

        # Use paginator for large buckets
        paginator = self.s3_client.get_paginator('list_objects_v2')
        page_iterator = paginator.paginate(
            Bucket=self.bucket,
            Prefix=self.prefix
        )

        for page in page_iterator:
            if 'Contents' in page:
                for obj in page['Contents']:
                    key = obj['Key']
                    # Extract record ID from key
                    if key.startswith(self.prefix) and key.endswith('.json'):
                        record_id = key[len(self.prefix):-5]  # Remove prefix and .json
                        record_ids.append(record_id)

        logger.debug(f"Listed {len(record_ids)} records from S3")
        return record_ids

    def exists(self, id: str) -> bool:
        """Check if a record exists in S3."""
        self._check_connection()

        key = self._get_object_key(id)

        try:
            self.s3_client.head_object(Bucket=self.bucket, Key=key)
            return True
        except self.ClientError as e:
            if e.response['Error']['Code'] == '404':
                return False
            raise

    def search(self, query: Query) -> list[Record]:
        """Search for records matching the query.

        Note: S3 doesn't support complex queries, so we need to list and filter.
        """
        self._check_connection()

        # List all objects with the prefix
        records = []
        paginator = self.s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix)

        for page in pages:
            if 'Contents' not in page:
                continue

            # Fetch objects in parallel
            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                futures = []
                for obj in page['Contents']:
                    if obj['Key'].endswith('.json'):
                        future = executor.submit(self._fetch_and_filter, obj['Key'], query)
                        futures.append(future)

                for future in as_completed(futures):
                    record = future.result()
                    if record:
                        records.append(record)

        # Apply sorting if specified
        if query.sort_specs:
            for sort_spec in reversed(query.sort_specs):
                reverse = sort_spec.order.value == "desc"
                records.sort(
                    key=lambda r: r.get_value(sort_spec.field, ""),
                    reverse=reverse
                )

        # Apply offset and limit
        if query.offset_value:
            records = records[query.offset_value:]
        if query.limit_value:
            records = records[:query.limit_value]

        # Apply field projection
        if query.fields:
            records = [r.project(query.fields) for r in records]

        return records
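
    # Note: search() is a full scan. Every .json object under the prefix is
    # fetched (in parallel, bounded by max_workers) and filtered client-side,
    # so cost grows linearly with the number of stored records.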

    def _fetch_and_filter(self, key: str, query: Query) -> Record | None:
        """Fetch an object and apply query filters."""
        try:
            response = self.s3_client.get_object(Bucket=self.bucket, Key=key)
            body = response['Body'].read()
            obj_data = json.loads(body)
            record = self._s3_object_to_record(obj_data)

            # Apply filters
            for query_filter in query.filters:
                field_value = record.get_value(query_filter.field)
                if not query_filter.matches(field_value):
                    return None

            return record
        except Exception as e:
            logger.warning(f"Error fetching {key}: {e}")
            return None

    def _count_all(self) -> int:
        """Count all records in S3."""
        self._check_connection()

        count = 0
        paginator = self.s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix)

        for page in pages:
            if 'Contents' in page:
                count += sum(1 for obj in page['Contents'] if obj['Key'].endswith('.json'))

        return count

    def clear(self) -> int:
        """Clear all records from S3."""
        self._check_connection()

        # List and delete all objects
        count = 0
        paginator = self.s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix)

        for page in pages:
            if 'Contents' not in page:
                continue

            # Delete in batches
            objects = [{'Key': obj['Key']} for obj in page['Contents'] if obj['Key'].endswith('.json')]
            if objects:
                self.s3_client.delete_objects(
                    Bucket=self.bucket,
                    Delete={'Objects': objects}
                )
                count += len(objects)

        # Clear cache
        self._index_cache = {}
        self._cache_dirty = True

        logger.info(f"Cleared {count} records from S3")
        return count
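
    # Note: delete_objects() accepts at most 1,000 keys per call, and
    # list_objects_v2 pages also cap at 1,000 keys, so deleting one page at a
    # time stays within that limit.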

    def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> Iterator[Record]:
        """Stream records from S3."""
        self._check_connection()
        config = config or StreamConfig()

        # List objects and stream them
        paginator = self.s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix)

        batch = []
        for page in pages:
            if 'Contents' not in page:
                continue

            for obj in page['Contents']:
                if not obj['Key'].endswith('.json'):
                    continue

                record = self._fetch_and_filter(obj['Key'], query or Query())
                if record:
                    batch.append(record)

                if len(batch) >= config.batch_size:
                    for r in batch:
                        yield r
                    batch = []

        # Yield remaining records
        for r in batch:
            yield r
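
    # A streaming sketch (hedged: batch_size is the only StreamConfig field
    # this backend reads directly; `handle` is a placeholder for caller code):
    #
    #   for rec in db.stream_read(config=StreamConfig(batch_size=100)):
    #       handle(rec)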

    def stream_write(
        self,
        records: Iterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into S3."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        batch = []
        for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write batch with graceful fallback
                # Use lambda wrapper for _write_batch
                continue_processing = process_batch_with_fallback(
                    batch,
                    lambda b: self._write_batch(b),
                    self.create,
                    result,
                    config
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write remaining batch
        if batch and not quitting:
            process_batch_with_fallback(
                batch,
                lambda b: self._write_batch(b),
                self.create,
                result,
                config
            )

        result.duration = time.time() - start_time
        return result
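
    # Judging by its name and arguments, process_batch_with_fallback appears
    # to attempt the bulk _write_batch first and fall back to per-record
    # create() on failure; a False return stops the stream early (see
    # `quitting` above).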

    def _write_batch(self, records: list[Record]) -> list[str]:
        """Write a batch of records to S3.

        Returns:
            List of created record IDs
        """
        ids = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = []
            for record in records:
                record_id = str(uuid4())
                ids.append(record_id)
                future = executor.submit(self._write_single, record_id, record)
                futures.append(future)

            # Wait for all writes to complete
            for future in as_completed(futures):
                future.result()  # This will raise if there was an error

        return ids

    def _write_single(self, record_id: str, record: Record) -> None:
        """Write a single record to S3."""
        # Set metadata
        record.metadata = record.metadata or {}
        record.metadata["id"] = record_id
        now = datetime.utcnow()
        record.metadata["created_at"] = now.isoformat()
        record.metadata["updated_at"] = now.isoformat()

        key = self._get_object_key(record_id)
        obj_data = self._record_to_s3_object(record)
        body = json.dumps(obj_data)

        self.s3_client.put_object(
            Bucket=self.bucket,
            Key=key,
            Body=body,
            ContentType='application/json'
        )

    def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform vector similarity search using Python calculations.

        WARNING: This implementation downloads all records from S3 to perform
        the search locally. This is inefficient for large datasets. Consider
        using a vector-enabled backend like PostgreSQL or Elasticsearch for
        production use with large datasets.
        """
        return self.python_vector_search_sync(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )
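
    # A vector-search sketch (hedged: assumes records were stored with an
    # "embedding" field and that the query vector's dimensionality matches):
    #
    #   hits = db.vector_search(query_vector=[0.1, 0.2, 0.3], k=5)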


# Import the native async implementation