Coverage for src/dataknobs_data/backends/s3_async.py: 16%

308 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-31 15:06 -0600

1"""Native async S3 backend implementation with aioboto3 and connection pooling.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import json 

7import logging 

8import time 

9import uuid 

10from datetime import datetime 

11from typing import Any, TYPE_CHECKING, cast, Callable, Awaitable 

12 

13from dataknobs_config import ConfigurableBase 

14 

15from ..database import AsyncDatabase 

16from ..pooling import ConnectionPoolManager 

17from ..pooling.s3 import S3PoolConfig, create_aioboto3_session, validate_s3_session 

18from ..query import Operator, Query, SortOrder 

19from ..records import Record 

20from ..streaming import StreamConfig, StreamResult, async_process_batch_with_fallback 

21from ..vector import VectorOperationsMixin 

22from ..vector.bulk_embed_mixin import BulkEmbedMixin 

23from ..vector.python_vector_search import PythonVectorSearchMixin 

24from .sqlite_mixins import SQLiteVectorSupport 

25from .vector_config_mixin import VectorConfigMixin 

26 

27if TYPE_CHECKING: 

28 from collections.abc import AsyncIterator 

29 

30 

31logger = logging.getLogger(__name__) 

32 

33# Global pool manager for S3 sessions 

34_session_manager = ConnectionPoolManager() 

35 

36 

37class AsyncS3Database( # type: ignore[misc] 

38 AsyncDatabase, 

39 ConfigurableBase, 

40 VectorConfigMixin, 

41 SQLiteVectorSupport, 

42 PythonVectorSearchMixin, 

43 BulkEmbedMixin, 

44 VectorOperationsMixin 

45): 

46 """Native async S3 database backend with aioboto3 and session pooling.""" 

47 

48 def __init__(self, config: dict[str, Any] | None = None): 

49 """Initialize async S3 database.""" 

50 super().__init__(config) 

51 

52 if not config or "bucket" not in config: 

53 raise ValueError("S3 backend requires 'bucket' in configuration") 

54 

55 self._pool_config = S3PoolConfig.from_dict(config) 

56 self._session = None 

57 self._connected = False 

58 

59 # Initialize vector support 

60 self._parse_vector_config(config or {}) 

61 self._init_vector_state() # From SQLiteVectorSupport 

62 

63 @classmethod 

64 def from_config(cls, config: dict) -> AsyncS3Database: 

65 """Create from config dictionary.""" 

66 return cls(config) 

67 

68 async def connect(self) -> None: 

69 """Connect to S3 service.""" 

70 if self._connected: 

71 return 

72 

73 # Get or create session for current event loop 

74 from ..pooling import BasePoolConfig 

75 self._session = await _session_manager.get_pool( 

76 self._pool_config, 

77 cast(Callable[[BasePoolConfig], Awaitable[Any]], create_aioboto3_session), 

78 lambda session: validate_s3_session(session, self._pool_config) 

79 ) 

80 

81 self._connected = True 

82 

83 async def close(self) -> None: 

84 """Close the S3 connection.""" 

85 if self._connected: 

86 self._session = None 

87 self._connected = False 

88 

89 def _initialize(self) -> None: 

90 """Initialize is handled in connect.""" 

91 pass 

92 

93 def _check_connection(self) -> None: 

94 """Check if database is connected.""" 

95 if not self._connected or not self._session: 

96 raise RuntimeError("Database not connected. Call connect() first.") 

97 

98 def _get_key(self, id: str) -> str: 

99 """Get the S3 key for a given record ID.""" 

100 if self._pool_config.prefix: 

101 return f"{self._pool_config.prefix}/{id}.json" 

102 return f"{id}.json" 

103 

104 def _record_to_s3_object(self, record: Record) -> dict[str, Any]: 

105 """Convert a Record to an S3 object.""" 

106 # Use Record's built-in serialization which handles VectorFields 

107 record_dict = record.to_dict(include_metadata=True, flatten=False) 

108 

109 # Add timestamps 

110 now = datetime.utcnow().isoformat() 

111 if "metadata" not in record_dict: 

112 record_dict["metadata"] = {} 

113 record_dict["metadata"]["created_at"] = record_dict["metadata"].get("created_at", now) 

114 record_dict["metadata"]["updated_at"] = now 

115 

116 return record_dict 

117 

118 def _s3_object_to_record(self, obj: dict[str, Any]) -> Record: 

119 """Convert an S3 object to a Record.""" 

120 # Use Record's built-in deserialization 

121 return Record.from_dict(obj) 

122 

123 async def create(self, record: Record) -> str: 

124 """Create a new record in S3.""" 

125 self._check_connection() 

126 

127 # Use centralized method to prepare record 

128 record_copy, storage_id = self._prepare_record_for_storage(record) 

129 key = self._get_key(storage_id) 

130 obj = self._record_to_s3_object(record_copy) 

131 

132 # Add ID to metadata 

133 obj["metadata"]["id"] = storage_id 

134 

135 async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3: 

136 await s3.put_object( 

137 Bucket=self._pool_config.bucket, 

138 Key=key, 

139 Body=json.dumps(obj), 

140 ContentType="application/json" 

141 ) 

142 

143 return storage_id 

144 

145 async def read(self, id: str) -> Record | None: 

146 """Read a record from S3.""" 

147 self._check_connection() 

148 

149 key = self._get_key(id) 

150 

151 try: 

152 async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3: 

153 response = await s3.get_object( 

154 Bucket=self._pool_config.bucket, 

155 Key=key 

156 ) 

157 

158 # Read the object body 

159 body = await response['Body'].read() 

160 obj = json.loads(body) 

161 

162 record = self._s3_object_to_record(obj) 

163 # Use centralized method to prepare record 

164 record = self._prepare_record_from_storage(record, id) 

165 # Ensure ID is in metadata 

166 record.metadata["id"] = id 

167 

168 return record 

169 except Exception: 

170 return None 

171 

172 async def update(self, id: str, record: Record) -> bool: 

173 """Update an existing record in S3.""" 

174 self._check_connection() 

175 

176 # Check if record exists 

177 if not await self.exists(id): 

178 return False 

179 

180 key = self._get_key(id) 

181 obj = self._record_to_s3_object(record) 

182 

183 # Preserve ID in metadata 

184 obj["metadata"]["id"] = id 

185 

186 async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3: 

187 await s3.put_object( 

188 Bucket=self._pool_config.bucket, 

189 Key=key, 

190 Body=json.dumps(obj), 

191 ContentType="application/json" 

192 ) 

193 

194 return True 

195 

196 async def delete(self, id: str) -> bool: 

197 """Delete a record from S3.""" 

198 self._check_connection() 

199 

200 key = self._get_key(id) 

201 

202 try: 

203 async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3: 

204 await s3.delete_object( 

205 Bucket=self._pool_config.bucket, 

206 Key=key 

207 ) 

208 return True 

209 except Exception: 

210 return False 

211 

212 async def exists(self, id: str) -> bool: 

213 """Check if a record exists in S3.""" 

214 self._check_connection() 

215 

216 key = self._get_key(id) 

217 

218 try: 

219 async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3: 

220 await s3.head_object( 

221 Bucket=self._pool_config.bucket, 

222 Key=key 

223 ) 

224 return True 

225 except Exception: 

226 return False 

227 

228 async def upsert(self, id: str, record: Record) -> str: 

229 """Update or insert a record with a specific ID.""" 

230 self._check_connection() 

231 

232 key = self._get_key(id) 

233 obj = self._record_to_s3_object(record) 

234 

235 # Add ID to metadata 

236 obj["metadata"]["id"] = id 

237 

238 async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3: 

239 await s3.put_object( 

240 Bucket=self._pool_config.bucket, 

241 Key=key, 

242 Body=json.dumps(obj), 

243 ContentType="application/json" 

244 ) 

245 

246 return id 

247 

248 async def search(self, query: Query) -> list[Record]: 

249 """Search for records matching the query.""" 

250 self._check_connection() 

251 

252 # S3 doesn't support complex queries, so we need to list and filter 

253 records = [] 

254 

255 async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3: 

256 # List all objects 

257 paginator = s3.get_paginator('list_objects_v2') 

258 

259 params = { 

260 'Bucket': self._pool_config.bucket, 

261 } 

262 if self._pool_config.prefix: 

263 params['Prefix'] = self._pool_config.prefix 

264 

265 async for page in paginator.paginate(**params): 

266 if 'Contents' not in page: 

267 continue 

268 

269 # Process each object 

270 for obj_summary in page['Contents']: 

271 key = obj_summary['Key'] 

272 

273 # Skip non-JSON files 

274 if not key.endswith('.json'): 

275 continue 

276 

277 # Get the object 

278 response = await s3.get_object( 

279 Bucket=self._pool_config.bucket, 

280 Key=key 

281 ) 

282 

283 body = await response['Body'].read() 

284 obj = json.loads(body) 

285 record = self._s3_object_to_record(obj) 

286 

287 # Extract ID from key 

288 id = key.replace(self._pool_config.prefix + '/', '').replace('.json', '') 

289 record.metadata["id"] = id 

290 

291 # Apply filters 

292 if self._matches_filters(record, query.filters): 

293 records.append(record) 

294 

295 # Apply sorting 

296 if query.sort_specs: 

297 for sort_spec in reversed(query.sort_specs): 

298 reverse = sort_spec.order == SortOrder.DESC 

299 records.sort( 

300 key=lambda r: (r.get_field(sort_spec.field).value if r.get_field(sort_spec.field) else "") or "", 

301 reverse=reverse 

302 ) 

303 

304 # Apply offset and limit 

305 if query.offset_value: 

306 records = records[query.offset_value:] 

307 if query.limit_value: 

308 records = records[:query.limit_value] 

309 

310 # Apply field projection 

311 if query.fields: 

312 records = [r.project(query.fields) for r in records] 

313 

314 return records 

315 

316 def _matches_filters(self, record: Record, filters: list) -> bool: 

317 """Check if a record matches all filters.""" 

318 for filter in filters: 

319 field = record.get_field(filter.field) 

320 if not field: 

321 return False 

322 

323 value = field.value 

324 

325 if filter.operator == Operator.EQ: 

326 if value != filter.value: 

327 return False 

328 elif filter.operator == Operator.NEQ: 

329 if value == filter.value: 

330 return False 

331 elif filter.operator == Operator.GT: 

332 if value <= filter.value: 

333 return False 

334 elif filter.operator == Operator.LT: 

335 if value >= filter.value: 

336 return False 

337 elif filter.operator == Operator.GTE: 

338 if value < filter.value: 

339 return False 

340 elif filter.operator == Operator.LTE: 

341 if value > filter.value: 

342 return False 

343 elif filter.operator == Operator.LIKE: 

344 if str(filter.value) not in str(value): 

345 return False 

346 elif filter.operator == Operator.IN: 

347 if value not in filter.value: 

348 return False 

349 elif filter.operator == Operator.NOT_IN: 

350 if value in filter.value: 

351 return False 

352 

353 return True 

354 

355 async def _count_all(self) -> int: 

356 """Count all records in the database.""" 

357 self._check_connection() 

358 

359 count = 0 

360 async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3: 

361 paginator = s3.get_paginator('list_objects_v2') 

362 

363 params = { 

364 'Bucket': self._pool_config.bucket, 

365 } 

366 if self._pool_config.prefix: 

367 params['Prefix'] = self._pool_config.prefix 

368 

369 async for page in paginator.paginate(**params): 

370 if 'Contents' in page: 

371 for obj in page['Contents']: 

372 if obj['Key'].endswith('.json'): 

373 count += 1 

374 

375 return count 

376 

377 async def clear(self) -> int: 

378 """Clear all records from the database.""" 

379 self._check_connection() 

380 

381 count = 0 

382 async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3: 

383 # List and delete all objects 

384 paginator = s3.get_paginator('list_objects_v2') 

385 

386 params = { 

387 'Bucket': self._pool_config.bucket, 

388 } 

389 if self._pool_config.prefix: 

390 params['Prefix'] = self._pool_config.prefix 

391 

392 async for page in paginator.paginate(**params): 

393 if 'Contents' not in page: 

394 continue 

395 

396 # Build delete request 

397 objects_to_delete = [] 

398 for obj in page['Contents']: 

399 if obj['Key'].endswith('.json'): 

400 objects_to_delete.append({'Key': obj['Key']}) 

401 count += 1 

402 

403 # Delete in batch 

404 if objects_to_delete: 

405 await s3.delete_objects( 

406 Bucket=self._pool_config.bucket, 

407 Delete={'Objects': objects_to_delete} 

408 ) 

409 

410 return count 

411 

412 async def stream_read( 

413 self, 

414 query: Query | None = None, 

415 config: StreamConfig | None = None 

416 ) -> AsyncIterator[Record]: 

417 """Stream records from S3.""" 

418 self._check_connection() 

419 config = config or StreamConfig() 

420 

421 async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3: 

422 paginator = s3.get_paginator('list_objects_v2') 

423 

424 params = { 

425 'Bucket': self._pool_config.bucket, 

426 'MaxKeys': config.batch_size 

427 } 

428 if self._pool_config.prefix: 

429 params['Prefix'] = self._pool_config.prefix 

430 

431 async for page in paginator.paginate(**params): 

432 if 'Contents' not in page: 

433 continue 

434 

435 for obj_summary in page['Contents']: 

436 key = obj_summary['Key'] 

437 

438 if not key.endswith('.json'): 

439 continue 

440 

441 # Get the object 

442 response = await s3.get_object( 

443 Bucket=self._pool_config.bucket, 

444 Key=key 

445 ) 

446 

447 body = await response['Body'].read() 

448 obj = json.loads(body) 

449 record = self._s3_object_to_record(obj) 

450 

451 # Extract ID from key 

452 id = key.replace(self._pool_config.prefix + '/', '').replace('.json', '') 

453 record.metadata["id"] = id 

454 

455 # Apply filters if query provided 

456 if query and query.filters: 

457 if not self._matches_filters(record, query.filters): 

458 continue 

459 

460 # Apply field projection 

461 if query and query.fields: 

462 record = record.project(query.fields) 

463 

464 yield record 

465 

466 async def stream_write( 

467 self, 

468 records: AsyncIterator[Record], 

469 config: StreamConfig | None = None 

470 ) -> StreamResult: 

471 """Stream records into S3.""" 

472 self._check_connection() 

473 config = config or StreamConfig() 

474 result = StreamResult() 

475 start_time = time.time() 

476 quitting = False 

477 

478 batch = [] 

479 async for record in records: 

480 batch.append(record) 

481 

482 if len(batch) >= config.batch_size: 

483 # Write batch with graceful fallback 

484 async def batch_func(b): 

485 await self._write_batch(b) 

486 return [r.id for r in b] 

487 

488 continue_processing = await async_process_batch_with_fallback( 

489 batch, 

490 batch_func, 

491 self.create, 

492 result, 

493 config 

494 ) 

495 

496 if not continue_processing: 

497 quitting = True 

498 break 

499 

500 batch = [] 

501 

502 # Write remaining batch 

503 if batch and not quitting: 

504 async def batch_func(b): 

505 await self._write_batch(b) 

506 return [r.id for r in b] 

507 

508 await async_process_batch_with_fallback( 

509 batch, 

510 batch_func, 

511 self.create, 

512 result, 

513 config 

514 ) 

515 

516 result.duration = time.time() - start_time 

517 return result 

518 

519 async def _write_batch(self, records: list[Record]) -> None: 

520 """Write a batch of records to S3.""" 

521 if not records: 

522 return 

523 

524 async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3: 

525 # Write each record (S3 doesn't have native batch write) 

526 # We could potentially use multipart upload for very large batches 

527 tasks = [] 

528 for record in records: 

529 id = str(uuid.uuid4()) 

530 key = self._get_key(id) 

531 obj = self._record_to_s3_object(record) 

532 obj["metadata"]["id"] = id 

533 

534 task = s3.put_object( 

535 Bucket=self._pool_config.bucket, 

536 Key=key, 

537 Body=json.dumps(obj), 

538 ContentType="application/json" 

539 ) 

540 tasks.append(task) 

541 

542 # Execute all uploads concurrently 

543 await asyncio.gather(*tasks) 

544 

545 async def list_all(self) -> list[str]: 

546 """List all record IDs in the database.""" 

547 self._check_connection() 

548 

549 ids = [] 

550 async with self._session.client("s3", endpoint_url=self._pool_config.endpoint_url) as s3: 

551 paginator = s3.get_paginator('list_objects_v2') 

552 

553 params = { 

554 'Bucket': self._pool_config.bucket, 

555 } 

556 if self._pool_config.prefix: 

557 params['Prefix'] = self._pool_config.prefix 

558 

559 async for page in paginator.paginate(**params): 

560 if 'Contents' not in page: 

561 continue 

562 

563 for obj in page['Contents']: 

564 key = obj['Key'] 

565 if key.endswith('.json'): 

566 # Extract ID from key 

567 id = key.replace(self._pool_config.prefix + '/', '').replace('.json', '') 

568 ids.append(id) 

569 

570 return ids 

571 

572 async def vector_search( 

573 self, 

574 query_vector, 

575 vector_field: str = "embedding", 

576 k: int = 10, 

577 filter=None, 

578 metric=None, 

579 **kwargs 

580 ): 

581 """Perform vector similarity search using Python calculations. 

582  

583 WARNING: This implementation downloads all records from S3 to perform 

584 the search locally. This is inefficient for large datasets. Consider 

585 using a vector-enabled backend like PostgreSQL or Elasticsearch for 

586 production use with large datasets. 

587  

588 Future optimization: Override this method to use AWS OpenSearch or 

589 similar vector-enabled service when available. 

590 """ 

591 return await self.python_vector_search_async( 

592 query_vector=query_vector, 

593 vector_field=vector_field, 

594 k=k, 

595 filter=filter, 

596 metric=metric, 

597 **kwargs 

598 )