Coverage for src/dataknobs_data/backends/s3.py: 15%

315 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-29 14:14 -0600

1"""S3 backend implementation with proper connection management.""" 

2 

from __future__ import annotations

import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from typing import Any, TYPE_CHECKING
from uuid import uuid4

from dataknobs_config import ConfigurableBase

from dataknobs_data.database import SyncDatabase
from dataknobs_data.query import Query
from dataknobs_data.records import Record
from dataknobs_data.streaming import StreamConfig, StreamResult, process_batch_with_fallback

from ..vector import VectorOperationsMixin
from ..vector.bulk_embed_mixin import BulkEmbedMixin
from ..vector.python_vector_search import PythonVectorSearchMixin
from .sqlite_mixins import SQLiteVectorSupport
from .vector_config_mixin import VectorConfigMixin

25 

if TYPE_CHECKING:
    # Only needed for annotations; avoids a runtime import
    from collections.abc import Iterator


# Module-level logger following the standard per-module pattern
logger = logging.getLogger(__name__)

31 

32 

class SyncS3Database(  # type: ignore[misc]
    SyncDatabase,
    ConfigurableBase,
    VectorConfigMixin,
    SQLiteVectorSupport,
    PythonVectorSearchMixin,
    BulkEmbedMixin,
    VectorOperationsMixin
):
    """S3-based database backend with proper connection management.

    Each record is stored as a single JSON object under a configurable key
    prefix in one bucket; record metadata (id, created_at, updated_at) is
    embedded in the stored JSON document itself.

    NOTE(review): SQLiteVectorSupport is mixed into this S3 backend —
    presumably it provides backend-agnostic vector bookkeeping; confirm
    the SQLite-named mixin is intentional here.
    """

46 

47 def __init__(self, config: dict[str, Any] | None = None): 

48 """Initialize S3 database configuration. 

49  

50 Args: 

51 config: Configuration dictionary 

52 """ 

53 super().__init__(config) 

54 

55 # Connection state 

56 self.s3_client = None 

57 self._connected = False 

58 

59 # Cache for performance 

60 self._index_cache = {} 

61 self._cache_dirty = True 

62 

63 # Store configuration for later connection 

64 self.bucket = self.config.get("bucket") 

65 if not self.bucket: 

66 raise ValueError("S3 bucket name is required in configuration") 

67 

68 # Optional configuration with defaults 

69 self.prefix = self.config.get("prefix", "records/").rstrip("/") + "/" 

70 self.region = self.config.get("region", "us-east-1") 

71 self.endpoint_url = self.config.get("endpoint_url") 

72 self.max_workers = self.config.get("max_workers", 10) 

73 self.multipart_threshold = self.config.get("multipart_threshold", 8 * 1024 * 1024) 

74 self.multipart_chunksize = self.config.get("multipart_chunksize", 8 * 1024 * 1024) 

75 self.max_retries = self.config.get("max_retries", 3) 

76 

77 # AWS credentials (will use environment/IAM role if not provided) 

78 self.aws_access_key_id = self.config.get("access_key_id") 

79 self.aws_secret_access_key = self.config.get("secret_access_key") 

80 self.aws_session_token = self.config.get("session_token") 

81 

82 # Initialize vector support 

83 self._parse_vector_config(config or {}) 

84 self._init_vector_state() 

85 

86 @classmethod 

87 def from_config(cls, config: dict) -> SyncS3Database: 

88 """Create instance from configuration dictionary.""" 

89 return cls(config) 

90 

91 def connect(self) -> None: 

92 """Connect to S3 service.""" 

93 if self._connected: 

94 return # Already connected 

95 

96 import boto3 

97 from botocore.config import Config as BotoConfig 

98 from botocore.exceptions import ClientError 

99 

100 # Configure boto3 client 

101 boto_config = BotoConfig( 

102 region_name=self.region, 

103 max_pool_connections=self.max_workers, 

104 retries={'max_attempts': self.max_retries} 

105 ) 

106 

107 client_kwargs = { 

108 "config": boto_config, 

109 "use_ssl": not bool(self.endpoint_url) # Disable SSL for local testing 

110 } 

111 

112 if self.endpoint_url: 

113 client_kwargs["endpoint_url"] = self.endpoint_url 

114 

115 if self.aws_access_key_id and self.aws_secret_access_key: 

116 client_kwargs["aws_access_key_id"] = self.aws_access_key_id 

117 client_kwargs["aws_secret_access_key"] = self.aws_secret_access_key 

118 

119 if self.aws_session_token: 

120 client_kwargs["aws_session_token"] = self.aws_session_token 

121 

122 # Create S3 client 

123 self.s3_client = boto3.client("s3", **client_kwargs) 

124 self.ClientError = ClientError 

125 

126 # Verify bucket exists or create it 

127 self._ensure_bucket_exists() 

128 

129 self._connected = True 

130 logger.info(f"Connected to S3 with bucket={self.bucket}, prefix={self.prefix}") 

131 

132 def close(self) -> None: 

133 """Close the S3 connection.""" 

134 if self.s3_client: 

135 # S3 client doesn't need explicit closing, but clear cache 

136 self._index_cache = {} # type: ignore[unreachable] 

137 self._connected = False 

138 logger.info(f"Closed S3 connection to bucket={self.bucket}") 

139 

140 def _initialize(self) -> None: 

141 """Initialize method - connection setup moved to connect().""" 

142 pass 

143 

144 def _check_connection(self) -> None: 

145 """Check if S3 client is connected.""" 

146 if not self._connected or not self.s3_client: 

147 raise RuntimeError("S3 not connected. Call connect() first.") 

148 

149 def _ensure_bucket_exists(self): 

150 """Ensure the S3 bucket exists, create if necessary.""" 

151 try: 

152 self.s3_client.head_bucket(Bucket=self.bucket) 

153 logger.debug(f"Bucket {self.bucket} exists") 

154 except self.ClientError as e: 

155 error_code = e.response['Error']['Code'] 

156 if error_code == '404': 

157 # Bucket doesn't exist, create it 

158 logger.info(f"Creating bucket {self.bucket}") 

159 if self.region == 'us-east-1': 

160 self.s3_client.create_bucket(Bucket=self.bucket) 

161 else: 

162 self.s3_client.create_bucket( 

163 Bucket=self.bucket, 

164 CreateBucketConfiguration={'LocationConstraint': self.region} 

165 ) 

166 else: 

167 raise 

168 

169 def _get_object_key(self, record_id: str) -> str: 

170 """Generate S3 object key for a record ID.""" 

171 return f"{self.prefix}{record_id}.json" 

172 

173 def _record_to_s3_object(self, record: Record) -> dict[str, Any]: 

174 """Convert a Record to S3 object data.""" 

175 # Use Record's built-in serialization which handles VectorFields 

176 # Use non-flattened format to preserve field metadata 

177 record_dict = record.to_dict(include_metadata=True, flatten=False) 

178 

179 return record_dict 

180 

181 def _s3_object_to_record(self, obj_data: dict[str, Any]) -> Record: 

182 """Convert S3 object data to a Record.""" 

183 # Use Record's built-in deserialization 

184 return Record.from_dict(obj_data) 

185 

186 def create(self, record: Record) -> str: 

187 """Create a new record in S3.""" 

188 self._check_connection() 

189 

190 # Use centralized method to prepare record 

191 record_copy, storage_id = self._prepare_record_for_storage(record) 

192 key = self._get_object_key(storage_id) 

193 

194 # Set metadata 

195 record_copy.metadata = record_copy.metadata or {} 

196 record_copy.metadata["id"] = storage_id 

197 now = datetime.utcnow() 

198 record_copy.metadata["created_at"] = now.isoformat() 

199 record_copy.metadata["updated_at"] = now.isoformat() 

200 

201 # Convert record to JSON 

202 obj_data = self._record_to_s3_object(record_copy) 

203 body = json.dumps(obj_data) 

204 

205 # Store in S3 

206 self.s3_client.put_object( 

207 Bucket=self.bucket, 

208 Key=key, 

209 Body=body, 

210 ContentType='application/json' 

211 ) 

212 

213 # Invalidate cache 

214 self._cache_dirty = True 

215 

216 logger.debug(f"Created record {storage_id} at {key}") 

217 return storage_id 

218 

219 def read(self, id: str) -> Record | None: 

220 """Read a record from S3.""" 

221 self._check_connection() 

222 

223 key = self._get_object_key(id) 

224 

225 try: 

226 response = self.s3_client.get_object(Bucket=self.bucket, Key=key) 

227 body = response['Body'].read() 

228 obj_data = json.loads(body) 

229 record = self._s3_object_to_record(obj_data) 

230 # Use centralized method to prepare record 

231 return self._prepare_record_from_storage(record, id) 

232 except self.ClientError as e: 

233 if e.response['Error']['Code'] == 'NoSuchKey': 

234 return None 

235 raise 

236 

237 def update(self, id: str, record: Record) -> bool: 

238 """Update an existing record in S3.""" 

239 self._check_connection() 

240 

241 key = self._get_object_key(id) 

242 

243 # Check if exists and get existing metadata 

244 try: 

245 response = self.s3_client.get_object(Bucket=self.bucket, Key=key) 

246 existing_data = json.loads(response['Body'].read()) 

247 existing_metadata = existing_data.get("metadata", {}) 

248 except self.ClientError as e: 

249 if e.response['Error']['Code'] == 'NoSuchKey': 

250 return False 

251 raise 

252 

253 # Preserve and update metadata 

254 record.metadata = record.metadata or {} 

255 record.metadata["id"] = id 

256 record.metadata["created_at"] = existing_metadata.get("created_at", datetime.utcnow().isoformat()) 

257 record.metadata["updated_at"] = datetime.utcnow().isoformat() 

258 

259 # Update the object 

260 obj_data = self._record_to_s3_object(record) 

261 body = json.dumps(obj_data) 

262 

263 self.s3_client.put_object( 

264 Bucket=self.bucket, 

265 Key=key, 

266 Body=body, 

267 ContentType='application/json' 

268 ) 

269 

270 # Invalidate cache 

271 self._cache_dirty = True 

272 

273 logger.debug(f"Updated record {id} at {key}") 

274 return True 

275 

276 def delete(self, id: str) -> bool: 

277 """Delete a record from S3.""" 

278 self._check_connection() 

279 

280 key = self._get_object_key(id) 

281 

282 # Check if exists 

283 try: 

284 self.s3_client.head_object(Bucket=self.bucket, Key=key) 

285 except self.ClientError as e: 

286 if e.response['Error']['Code'] == '404': 

287 return False 

288 raise 

289 

290 # Delete the object 

291 self.s3_client.delete_object(Bucket=self.bucket, Key=key) 

292 

293 # Invalidate cache 

294 self._cache_dirty = True 

295 

296 logger.debug(f"Deleted record {id} at {key}") 

297 return True 

298 

299 def list_all(self) -> list[str]: 

300 """List all record IDs in the database. 

301  

302 Returns: 

303 List of all record IDs 

304 """ 

305 self._check_connection() 

306 record_ids = [] 

307 

308 # Use paginator for large buckets 

309 paginator = self.s3_client.get_paginator('list_objects_v2') 

310 page_iterator = paginator.paginate( 

311 Bucket=self.bucket, 

312 Prefix=self.prefix 

313 ) 

314 

315 for page in page_iterator: 

316 if 'Contents' in page: 

317 for obj in page['Contents']: 

318 key = obj['Key'] 

319 # Extract record ID from key 

320 if key.startswith(self.prefix) and key.endswith('.json'): 

321 record_id = key[len(self.prefix):-5] # Remove prefix and .json 

322 record_ids.append(record_id) 

323 

324 logger.debug(f"Listed {len(record_ids)} records from S3") 

325 return record_ids 

326 

327 def exists(self, id: str) -> bool: 

328 """Check if a record exists in S3.""" 

329 self._check_connection() 

330 

331 key = self._get_object_key(id) 

332 

333 try: 

334 self.s3_client.head_object(Bucket=self.bucket, Key=key) 

335 return True 

336 except self.ClientError as e: 

337 if e.response['Error']['Code'] == '404': 

338 return False 

339 raise 

340 

341 def search(self, query: Query) -> list[Record]: 

342 """Search for records matching the query. 

343  

344 Note: S3 doesn't support complex queries, so we need to list and filter. 

345 """ 

346 self._check_connection() 

347 

348 # List all objects with the prefix 

349 records = [] 

350 paginator = self.s3_client.get_paginator('list_objects_v2') 

351 pages = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix) 

352 

353 for page in pages: 

354 if 'Contents' not in page: 

355 continue 

356 

357 # Fetch objects in parallel 

358 with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 

359 futures = [] 

360 for obj in page['Contents']: 

361 if obj['Key'].endswith('.json'): 

362 future = executor.submit(self._fetch_and_filter, obj['Key'], query) 

363 futures.append(future) 

364 

365 for future in as_completed(futures): 

366 record = future.result() 

367 if record: 

368 records.append(record) 

369 

370 # Apply sorting if specified 

371 if query.sort_specs: 

372 for sort_spec in reversed(query.sort_specs): 

373 reverse = sort_spec.order.value == "desc" 

374 # Special handling for 'id' field 

375 if sort_spec.field == 'id': 

376 records.sort( 

377 key=lambda r: r.id or "", 

378 reverse=reverse 

379 ) 

380 else: 

381 records.sort( 

382 key=lambda r: r.get_value(sort_spec.field, ""), 

383 reverse=reverse 

384 ) 

385 

386 # Apply offset and limit 

387 if query.offset_value: 

388 records = records[query.offset_value:] 

389 if query.limit_value: 

390 records = records[:query.limit_value] 

391 

392 # Apply field projection 

393 if query.fields: 

394 records = [r.project(query.fields) for r in records] 

395 

396 return records 

397 

398 def _fetch_and_filter(self, key: str, query: Query) -> Record | None: 

399 """Fetch an object and apply query filters.""" 

400 try: 

401 response = self.s3_client.get_object(Bucket=self.bucket, Key=key) 

402 body = response['Body'].read() 

403 obj_data = json.loads(body) 

404 record = self._s3_object_to_record(obj_data) 

405 

406 # Apply filters 

407 for filter in query.filters: 

408 # Special handling for 'id' field 

409 if filter.field == 'id': 

410 field_value = record.id 

411 else: 

412 field_value = record.get_value(filter.field) 

413 if not filter.matches(field_value): 

414 return None 

415 

416 return record 

417 except Exception as e: 

418 logger.warning(f"Error fetching {key}: {e}") 

419 return None 

420 

421 def _count_all(self) -> int: 

422 """Count all records in S3.""" 

423 self._check_connection() 

424 

425 count = 0 

426 paginator = self.s3_client.get_paginator('list_objects_v2') 

427 pages = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix) 

428 

429 for page in pages: 

430 if 'Contents' in page: 

431 count += sum(1 for obj in page['Contents'] if obj['Key'].endswith('.json')) 

432 

433 return count 

434 

435 def clear(self) -> int: 

436 """Clear all records from S3.""" 

437 self._check_connection() 

438 

439 # List and delete all objects 

440 count = 0 

441 paginator = self.s3_client.get_paginator('list_objects_v2') 

442 pages = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix) 

443 

444 for page in pages: 

445 if 'Contents' not in page: 

446 continue 

447 

448 # Delete in batches 

449 objects = [{'Key': obj['Key']} for obj in page['Contents'] if obj['Key'].endswith('.json')] 

450 if objects: 

451 self.s3_client.delete_objects( 

452 Bucket=self.bucket, 

453 Delete={'Objects': objects} 

454 ) 

455 count += len(objects) 

456 

457 # Clear cache 

458 self._index_cache = {} 

459 self._cache_dirty = True 

460 

461 logger.info(f"Cleared {count} records from S3") 

462 return count 

463 

464 def stream_read( 

465 self, 

466 query: Query | None = None, 

467 config: StreamConfig | None = None 

468 ) -> Iterator[Record]: 

469 """Stream records from S3.""" 

470 self._check_connection() 

471 config = config or StreamConfig() 

472 

473 # List objects and stream them 

474 paginator = self.s3_client.get_paginator('list_objects_v2') 

475 pages = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix) 

476 

477 batch = [] 

478 for page in pages: 

479 if 'Contents' not in page: 

480 continue 

481 

482 for obj in page['Contents']: 

483 if not obj['Key'].endswith('.json'): 

484 continue 

485 

486 record = self._fetch_and_filter(obj['Key'], query or Query()) 

487 if record: 

488 batch.append(record) 

489 

490 if len(batch) >= config.batch_size: 

491 for r in batch: 

492 yield r 

493 batch = [] 

494 

495 # Yield remaining records 

496 for r in batch: 

497 yield r 

498 

499 def stream_write( 

500 self, 

501 records: Iterator[Record], 

502 config: StreamConfig | None = None 

503 ) -> StreamResult: 

504 """Stream records into S3.""" 

505 self._check_connection() 

506 config = config or StreamConfig() 

507 result = StreamResult() 

508 start_time = time.time() 

509 quitting = False 

510 

511 batch = [] 

512 for record in records: 

513 batch.append(record) 

514 

515 if len(batch) >= config.batch_size: 

516 # Write batch with graceful fallback 

517 # Use lambda wrapper for _write_batch 

518 continue_processing = process_batch_with_fallback( 

519 batch, 

520 lambda b: self._write_batch(b), 

521 self.create, 

522 result, 

523 config 

524 ) 

525 

526 if not continue_processing: 

527 quitting = True 

528 break 

529 

530 batch = [] 

531 

532 # Write remaining batch 

533 if batch and not quitting: 

534 process_batch_with_fallback( 

535 batch, 

536 lambda b: self._write_batch(b), 

537 self.create, 

538 result, 

539 config 

540 ) 

541 

542 result.duration = time.time() - start_time 

543 return result 

544 

545 def _write_batch(self, records: list[Record]) -> list[str]: 

546 """Write a batch of records to S3. 

547  

548 Returns: 

549 List of created record IDs 

550 """ 

551 ids = [] 

552 with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 

553 futures = [] 

554 for record in records: 

555 record_id = str(uuid4()) 

556 ids.append(record_id) 

557 future = executor.submit(self._write_single, record_id, record) 

558 futures.append(future) 

559 

560 # Wait for all writes to complete 

561 for future in as_completed(futures): 

562 future.result() # This will raise if there was an error 

563 

564 return ids 

565 

566 def _write_single(self, record_id: str, record: Record) -> None: 

567 """Write a single record to S3.""" 

568 # Set metadata 

569 record.metadata = record.metadata or {} 

570 record.metadata["id"] = record_id 

571 now = datetime.utcnow() 

572 record.metadata["created_at"] = now.isoformat() 

573 record.metadata["updated_at"] = now.isoformat() 

574 

575 key = self._get_object_key(record_id) 

576 obj_data = self._record_to_s3_object(record) 

577 body = json.dumps(obj_data) 

578 

579 self.s3_client.put_object( 

580 Bucket=self.bucket, 

581 Key=key, 

582 Body=body, 

583 ContentType='application/json' 

584 ) 

585 

586 def vector_search( 

587 self, 

588 query_vector, 

589 vector_field: str = "embedding", 

590 k: int = 10, 

591 filter=None, 

592 metric=None, 

593 **kwargs 

594 ): 

595 """Perform vector similarity search using Python calculations. 

596  

597 WARNING: This implementation downloads all records from S3 to perform 

598 the search locally. This is inefficient for large datasets. Consider 

599 using a vector-enabled backend like PostgreSQL or Elasticsearch for 

600 production use with large datasets. 

601 """ 

602 return self.python_vector_search_sync( 

603 query_vector=query_vector, 

604 vector_field=vector_field, 

605 k=k, 

606 filter=filter, 

607 metric=metric, 

608 **kwargs 

609 ) 

610 

611 

612# Import the native async implementation