Coverage for src/dataknobs_data/backends/s3_async.py: 15%

318 statements  

coverage.py v7.11.3, created at 2025-11-13 11:23 -0700

1"""Native async S3 backend implementation with aioboto3 and connection pooling.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import json 

7import logging 

8import time 

9import uuid 

10from datetime import datetime 

11from typing import Any, TYPE_CHECKING, cast, Callable, Awaitable 

12 

13from dataknobs_config import ConfigurableBase 

14 

15from ..database import AsyncDatabase 

16from ..pooling import ConnectionPoolManager 

17from ..pooling.s3 import S3PoolConfig, create_aioboto3_session, validate_s3_session 

18from ..query import Operator, Query, SortOrder 

19from ..records import Record 

20from ..streaming import StreamConfig, StreamResult, async_process_batch_with_fallback 

21from ..vector import VectorOperationsMixin 

22from ..vector.bulk_embed_mixin import BulkEmbedMixin 

23from ..vector.python_vector_search import PythonVectorSearchMixin 

24from .sqlite_mixins import SQLiteVectorSupport 

25from .vector_config_mixin import VectorConfigMixin 

26 

27if TYPE_CHECKING: 

28 from collections.abc import AsyncIterator 

29 

30 

31logger = logging.getLogger(__name__) 

32 

33# Global pool manager for S3 sessions 

34_session_manager = ConnectionPoolManager() 

35 

36 

class AsyncS3Database(  # type: ignore[misc]
    AsyncDatabase,
    ConfigurableBase,
    VectorConfigMixin,
    SQLiteVectorSupport,
    PythonVectorSearchMixin,
    BulkEmbedMixin,
    VectorOperationsMixin
):
    """Native async S3 database backend with aioboto3 and session pooling."""

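    # A minimal usage sketch ("record" is an existing Record instance;
    # the bucket name and prefix are placeholders):
    #
    #     db = AsyncS3Database({"bucket": "my-bucket", "prefix": "records"})
    #     await db.connect()
    #     record_id = await db.create(record)
    #     fetched = await db.read(record_id)
    #     await db.close()
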

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async S3 database."""
        super().__init__(config)

        if not config or "bucket" not in config:
            raise ValueError("S3 backend requires 'bucket' in configuration")

        self._pool_config = S3PoolConfig.from_dict(config)
        self._session = None
        self._connected = False

        # Initialize vector support
        self._parse_vector_config(config or {})
        self._init_vector_state()  # From SQLiteVectorSupport


    @classmethod
    def from_config(cls, config: dict) -> AsyncS3Database:
        """Create from config dictionary."""
        return cls(config)


    async def connect(self) -> None:
        """Connect to S3 service."""
        if self._connected:
            return

        # Get or create a session for the current event loop
        from ..pooling import BasePoolConfig
        self._session = await _session_manager.get_pool(
            self._pool_config,
            cast("Callable[[BasePoolConfig], Awaitable[Any]]", create_aioboto3_session),
            lambda session: validate_s3_session(session, self._pool_config)
        )

        self._connected = True


    async def close(self) -> None:
        """Close the S3 connection."""
        if self._connected:
            self._session = None
            self._connected = False


    def _initialize(self) -> None:
        """Initialization is handled in connect()."""


    def _check_connection(self) -> None:
        """Check if database is connected."""
        if not self._connected or not self._session:
            raise RuntimeError("Database not connected. Call connect() first.")


    def _get_key(self, id: str) -> str:
        """Get the S3 key for a given record ID."""
        if self._pool_config.prefix:
            return f"{self._pool_config.prefix}/{id}.json"
        return f"{id}.json"

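    # Key layout: with prefix "records" and id "abc", _get_key yields
    # "records/abc.json"; with no prefix configured, just "abc.json".
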

    def _record_to_s3_object(self, record: Record) -> dict[str, Any]:
        """Convert a Record to an S3 object."""
        # Use Record's built-in serialization which handles VectorFields
        record_dict = record.to_dict(include_metadata=True, flatten=False)

        # Add timestamps
        now = datetime.utcnow().isoformat()
        if "metadata" not in record_dict:
            record_dict["metadata"] = {}
        record_dict["metadata"]["created_at"] = record_dict["metadata"].get("created_at", now)
        record_dict["metadata"]["updated_at"] = now

        return record_dict

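    # Illustrative stored-object shape (the field layout comes from
    # Record.to_dict and may differ; the metadata keys are the ones set here):
    #
    #     {"fields": {...},
    #      "metadata": {"id": "...", "created_at": "...", "updated_at": "..."}}
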

    def _s3_object_to_record(self, obj: dict[str, Any]) -> Record:
        """Convert an S3 object to a Record."""
        # Use Record's built-in deserialization
        return Record.from_dict(obj)


    async def create(self, record: Record) -> str:
        """Create a new record in S3."""
        self._check_connection()

        # Use centralized method to prepare record
        record_copy, storage_id = self._prepare_record_for_storage(record)
        key = self._get_key(storage_id)
        obj = self._record_to_s3_object(record_copy)

        # Add ID to metadata
        obj["metadata"]["id"] = storage_id

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            await s3.put_object(
                Bucket=self._pool_config.bucket,
                Key=key,
                Body=json.dumps(obj),
                ContentType="application/json"
            )

        return storage_id


    async def read(self, id: str) -> Record | None:
        """Read a record from S3."""
        self._check_connection()

        key = self._get_key(id)

        try:
            async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
                response = await s3.get_object(
                    Bucket=self._pool_config.bucket,
                    Key=key
                )

                # Read the object body
                body = await response['Body'].read()
                obj = json.loads(body)

                record = self._s3_object_to_record(obj)
                # Use centralized method to prepare record
                record = self._prepare_record_from_storage(record, id)
                # Ensure ID is in metadata
                record.metadata["id"] = id

                return record
        except Exception:
            return None


    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record in S3."""
        self._check_connection()

        # Check if record exists
        if not await self.exists(id):
            return False

        key = self._get_key(id)
        obj = self._record_to_s3_object(record)

        # Preserve ID in metadata
        obj["metadata"]["id"] = id

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            await s3.put_object(
                Bucket=self._pool_config.bucket,
                Key=key,
                Body=json.dumps(obj),
                ContentType="application/json"
            )

        return True


    async def delete(self, id: str) -> bool:
        """Delete a record from S3."""
        self._check_connection()

        key = self._get_key(id)

        try:
            async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
                await s3.delete_object(
                    Bucket=self._pool_config.bucket,
                    Key=key
                )
            return True
        except Exception:
            return False

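    # Note: S3's delete_object succeeds even when the key does not exist, so
    # delete() returns True for records that were never stored.
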

    async def exists(self, id: str) -> bool:
        """Check if a record exists in S3."""
        self._check_connection()

        key = self._get_key(id)

        try:
            async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
                await s3.head_object(
                    Bucket=self._pool_config.bucket,
                    Key=key
                )
            return True
        except Exception:
            return False


    async def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
        """Update or insert a record.

        Can be called as:
        - upsert(id, record) - explicit ID and record
        - upsert(record) - extract ID from record using Record's built-in logic
        """
        self._check_connection()

        # Determine ID and record based on arguments
        if isinstance(id_or_record, str):
            id = id_or_record
            if record is None:
                raise ValueError("Record required when ID is provided")
        else:
            record = id_or_record
            id = record.id
            if id is None:
                id = str(uuid.uuid4())  # type: ignore[unreachable]
                record.storage_id = id

        key = self._get_key(id)
        obj = self._record_to_s3_object(record)

        # Add ID to metadata
        obj["metadata"]["id"] = id

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            await s3.put_object(
                Bucket=self._pool_config.bucket,
                Key=key,
                Body=json.dumps(obj),
                ContentType="application/json"
            )

        return id

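    # Both call forms in one sketch (the ID is illustrative):
    #
    #     await db.upsert("user-1", record)   # explicit ID
    #     await db.upsert(record)             # ID taken from record.id
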

    async def search(self, query: Query) -> list[Record]:
        """Search for records matching the query."""
        self._check_connection()

        # S3 doesn't support complex queries, so we need to list and filter
        records = []

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            # List all objects
            paginator = s3.get_paginator('list_objects_v2')

            params = {
                'Bucket': self._pool_config.bucket,
            }
            if self._pool_config.prefix:
                params['Prefix'] = self._pool_config.prefix

            async for page in paginator.paginate(**params):
                if 'Contents' not in page:
                    continue

                # Process each object
                for obj_summary in page['Contents']:
                    key = obj_summary['Key']

                    # Skip non-JSON files
                    if not key.endswith('.json'):
                        continue

                    # Get the object
                    response = await s3.get_object(
                        Bucket=self._pool_config.bucket,
                        Key=key
                    )

                    body = await response['Body'].read()
                    obj = json.loads(body)
                    record = self._s3_object_to_record(obj)

                    # Extract ID from key (handle an unset prefix safely)
                    prefix = f"{self._pool_config.prefix}/" if self._pool_config.prefix else ""
                    id = key.removeprefix(prefix).removesuffix('.json')
                    record.metadata["id"] = id

                    # Apply filters
                    if self._matches_filters(record, query.filters):
                        records.append(record)

        # Apply sorting
        if query.sort_specs:
            for sort_spec in reversed(query.sort_specs):
                reverse = sort_spec.order == SortOrder.DESC
                records.sort(
                    key=lambda r: (r.get_field(sort_spec.field).value if r.get_field(sort_spec.field) else "") or "",
                    reverse=reverse
                )

        # Apply offset and limit
        if query.offset_value:
            records = records[query.offset_value:]
        if query.limit_value:
            records = records[:query.limit_value]

        # Apply field projection
        if query.fields:
            records = [r.project(query.fields) for r in records]

        return records

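    # Note: search() is a full scan. Every .json object under the prefix is
    # downloaded and filtered locally, so cost grows linearly with the number
    # of stored records.
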

    def _matches_filters(self, record: Record, filters: list) -> bool:
        """Check if a record matches all filters."""
        for f in filters:
            field = record.get_field(f.field)
            if not field:
                return False

            value = field.value

            if f.operator == Operator.EQ:
                if value != f.value:
                    return False
            elif f.operator == Operator.NEQ:
                if value == f.value:
                    return False
            elif f.operator == Operator.GT:
                if value <= f.value:
                    return False
            elif f.operator == Operator.LT:
                if value >= f.value:
                    return False
            elif f.operator == Operator.GTE:
                if value < f.value:
                    return False
            elif f.operator == Operator.LTE:
                if value > f.value:
                    return False
            elif f.operator == Operator.LIKE:
                if str(f.value) not in str(value):
                    return False
            elif f.operator == Operator.IN:
                if value not in f.value:
                    return False
            elif f.operator == Operator.NOT_IN:
                if value in f.value:
                    return False

        return True


    async def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()

        count = 0
        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            paginator = s3.get_paginator('list_objects_v2')

            params = {
                'Bucket': self._pool_config.bucket,
            }
            if self._pool_config.prefix:
                params['Prefix'] = self._pool_config.prefix

            async for page in paginator.paginate(**params):
                if 'Contents' in page:
                    for obj in page['Contents']:
                        if obj['Key'].endswith('.json'):
                            count += 1

        return count


    async def clear(self) -> int:
        """Clear all records from the database."""
        self._check_connection()

        count = 0
        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            # List and delete all objects
            paginator = s3.get_paginator('list_objects_v2')

            params = {
                'Bucket': self._pool_config.bucket,
            }
            if self._pool_config.prefix:
                params['Prefix'] = self._pool_config.prefix

            async for page in paginator.paginate(**params):
                if 'Contents' not in page:
                    continue

                # Build delete request
                objects_to_delete = []
                for obj in page['Contents']:
                    if obj['Key'].endswith('.json'):
                        objects_to_delete.append({'Key': obj['Key']})
                        count += 1

                # Delete in batch
                if objects_to_delete:
                    await s3.delete_objects(
                        Bucket=self._pool_config.bucket,
                        Delete={'Objects': objects_to_delete}
                    )

        return count

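    # Note: delete_objects accepts at most 1000 keys per request; list_objects_v2
    # pages contain at most 1000 keys, so each per-page delete stays within the
    # limit.
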

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records from S3."""
        self._check_connection()
        config = config or StreamConfig()

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            paginator = s3.get_paginator('list_objects_v2')

            params = {
                'Bucket': self._pool_config.bucket,
                'MaxKeys': config.batch_size
            }
            if self._pool_config.prefix:
                params['Prefix'] = self._pool_config.prefix

            async for page in paginator.paginate(**params):
                if 'Contents' not in page:
                    continue

                for obj_summary in page['Contents']:
                    key = obj_summary['Key']

                    if not key.endswith('.json'):
                        continue

                    # Get the object
                    response = await s3.get_object(
                        Bucket=self._pool_config.bucket,
                        Key=key
                    )

                    body = await response['Body'].read()
                    obj = json.loads(body)
                    record = self._s3_object_to_record(obj)

                    # Extract ID from key (handle an unset prefix safely)
                    prefix = f"{self._pool_config.prefix}/" if self._pool_config.prefix else ""
                    record.metadata["id"] = key.removeprefix(prefix).removesuffix('.json')

                    # Apply filters if query provided
                    if query and query.filters:
                        if not self._matches_filters(record, query.filters):
                            continue

                    # Apply field projection
                    if query and query.fields:
                        record = record.project(query.fields)

                    yield record

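    # Streaming usage sketch (StreamConfig values are illustrative):
    #
    #     async for record in db.stream_read(config=StreamConfig(batch_size=100)):
    #         handle(record)  # "handle" is a placeholder
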

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into S3."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        # Write a batch with graceful fallback
        async def batch_func(b):
            await self._write_batch(b)
            return [r.id for r in b]

        batch = []
        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                continue_processing = await async_process_batch_with_fallback(
                    batch,
                    batch_func,
                    self.create,
                    result,
                    config
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write remaining batch
        if batch and not quitting:
            await async_process_batch_with_fallback(
                batch,
                batch_func,
                self.create,
                result,
                config
            )

        result.duration = time.time() - start_time
        return result

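    # Judging by its arguments, async_process_batch_with_fallback attempts the
    # whole batch first and falls back to per-record writes via self.create when
    # the batch fails; see ..streaming for the exact semantics.
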

    async def _write_batch(self, records: list[Record]) -> None:
        """Write a batch of records to S3."""
        if not records:
            return

        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            # Write each record (S3 doesn't have native batch write)
            # We could potentially use multipart upload for very large batches
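            # Note: each record is stored under a freshly generated UUID, so any
            # id already carried by the record is not reused by this batch path.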

            tasks = []
            for record in records:
                id = str(uuid.uuid4())
                key = self._get_key(id)
                obj = self._record_to_s3_object(record)
                obj["metadata"]["id"] = id

                task = s3.put_object(
                    Bucket=self._pool_config.bucket,
                    Key=key,
                    Body=json.dumps(obj),
                    ContentType="application/json"
                )
                tasks.append(task)

            # Execute all uploads concurrently
            await asyncio.gather(*tasks)


    async def list_all(self) -> list[str]:
        """List all record IDs in the database."""
        self._check_connection()

        ids = []
        async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3:
            paginator = s3.get_paginator('list_objects_v2')

            params = {
                'Bucket': self._pool_config.bucket,
            }
            if self._pool_config.prefix:
                params['Prefix'] = self._pool_config.prefix

            async for page in paginator.paginate(**params):
                if 'Contents' not in page:
                    continue

                for obj in page['Contents']:
                    key = obj['Key']
                    if key.endswith('.json'):
                        # Extract ID from key (handle an unset prefix safely)
                        prefix = f"{self._pool_config.prefix}/" if self._pool_config.prefix else ""
                        ids.append(key.removeprefix(prefix).removesuffix('.json'))

        return ids


    async def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform vector similarity search using Python calculations.

        WARNING: This implementation downloads all records from S3 to perform
        the search locally. This is inefficient for large datasets. Consider
        using a vector-enabled backend like PostgreSQL or Elasticsearch for
        production use with large datasets.

        Future optimization: Override this method to use AWS OpenSearch or
        similar vector-enabled service when available.
        """
        return await self.python_vector_search_async(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )
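
    # Vector search call sketch (the query vector, field name, and k are
    # illustrative; records are assumed to carry an "embedding" vector field):
    #
    #     hits = await db.vector_search([0.1, 0.2, 0.3], vector_field="embedding", k=5)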