Coverage for src/dataknobs_data/backends/elasticsearch_async.py: 14%

316 statements  

coverage.py v7.10.3, created at 2025-08-31 15:06 -0600

1"""Native async Elasticsearch backend implementation with connection pooling.""" 

2 

3from __future__ import annotations 

4 

5import logging 

6import time 

7from typing import TYPE_CHECKING, Any, cast 

8 

9from dataknobs_config import ConfigurableBase 

10 

11from ..database import AsyncDatabase 

12from ..pooling import ConnectionPoolManager 

13from ..pooling.elasticsearch import ( 

14 ElasticsearchPoolConfig, 

15 close_elasticsearch_client, 

16 create_async_elasticsearch_client, 

17 validate_elasticsearch_client, 

18) 

19from ..query import Operator, Query, SortOrder 

20from ..streaming import StreamConfig, StreamResult, async_process_batch_with_fallback 

21from ..vector.mixins import VectorOperationsMixin 

22from ..vector.types import DistanceMetric, VectorSearchResult 

23from .elasticsearch_mixins import ( 

24 ElasticsearchBaseConfig, 

25 ElasticsearchErrorHandler, 

26 ElasticsearchIndexManager, 

27 ElasticsearchQueryBuilder, 

28 ElasticsearchRecordSerializer, 

29 ElasticsearchVectorSupport, 

30) 

31 

32if TYPE_CHECKING: 

33 import numpy as np 

34 from collections.abc import AsyncIterator, Callable, Awaitable 

35 from ..records import Record 

36 

37logger = logging.getLogger(__name__) 

38 

39# Global pool manager for Elasticsearch clients 

40_client_manager = ConnectionPoolManager() 

41 

42 

class AsyncElasticsearchDatabase(
    AsyncDatabase,
    ConfigurableBase,
    VectorOperationsMixin,
    ElasticsearchBaseConfig,
    ElasticsearchIndexManager,
    ElasticsearchVectorSupport,
    ElasticsearchErrorHandler,
    ElasticsearchRecordSerializer,
    ElasticsearchQueryBuilder,
):
    """Native async Elasticsearch database backend with connection pooling."""
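
    # Illustrative usage sketch (config keys beyond "index" and "refresh" depend
    # on ElasticsearchPoolConfig and are assumed here):
    #
    #     db = AsyncElasticsearchDatabase({"index": "documents", "refresh": True})
    #     await db.connect()
    #     record_id = await db.create(record)
    #     await db.close()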

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize the async Elasticsearch database."""
        super().__init__(config)

        # Initialize vector support
        self.vector_fields = {}  # field_name -> dimensions
        self.vector_enabled = False

        config = config or {}
        self._pool_config = ElasticsearchPoolConfig.from_dict(config)
        self.index_name = self._pool_config.index
        self.refresh = config.get("refresh", True)
        self._client = None
        self._connected = False

    @classmethod
    def from_config(cls, config: dict) -> AsyncElasticsearchDatabase:
        """Create an instance from a config dictionary."""
        return cls(config)

    async def connect(self) -> None:
        """Connect to the Elasticsearch database."""
        if self._connected:
            return

        # Get or create a client for the current event loop
        from ..pooling import BasePoolConfig
        self._client = await _client_manager.get_pool(
            self._pool_config,
            cast("Callable[[BasePoolConfig], Awaitable[Any]]", create_async_elasticsearch_client),
            validate_elasticsearch_client,
            close_elasticsearch_client,
        )

        # Ensure the index exists
        await self._ensure_index()
        self._connected = True

    async def close(self) -> None:
        """Close the database connection."""
        if self._connected:
            # The client is owned by the pool manager, so it isn't closed here;
            # this instance is simply marked as disconnected.
            self._client = None
            self._connected = False

    def _initialize(self) -> None:
        """Initialization is handled in connect()."""

    async def _ensure_index(self) -> None:
        """Ensure the index exists with proper mappings."""
        if not self._client:
            raise RuntimeError("Database not connected. Call connect() first.")

        # Check if the index exists
        if not await self._client.indices.exists(index=self.index_name):  # type: ignore[unreachable]
            # Get mappings with vector field support
            mappings = self.get_index_mappings(self.vector_fields)

            # Get settings optimized for KNN if we have vector fields
            settings = self.get_knn_index_settings() if self.vector_fields else {
                "number_of_shards": 1,
                "number_of_replicas": 0,
            }

            await self._client.indices.create(
                index=self.index_name,
                mappings=mappings,
                settings=settings,
            )

            if self.vector_fields:
                self.vector_enabled = True
                logger.info(f"Created index '{self.index_name}' with vector support")

    def _check_connection(self) -> None:
        """Check if the database is connected."""
        if not self._connected or not self._client:
            raise RuntimeError("Database not connected. Call connect() first.")

    def _record_to_doc(self, record: Record) -> dict[str, Any]:
        """Convert a Record to an Elasticsearch document."""
        # Update vector tracking if needed
        if self._has_vector_fields(record):
            self._update_vector_tracking(record)

        # Add vector field metadata to record metadata
        if "vector_fields" not in record.metadata:
            record.metadata["vector_fields"] = {}

        for field_name in self.vector_fields:
            if field_name in record.fields:
                field = record.fields[field_name]
                if hasattr(field, "source_field"):
                    record.metadata["vector_fields"][field_name] = {
                        "type": "vector",
                        "dimensions": self.vector_fields[field_name],
                        "source_field": field.source_field,
                        "model": getattr(field, "model_name", None),
                        "model_version": getattr(field, "model_version", None),
                    }
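
        # A tracked vector field ends up with a metadata entry shaped like this
        # (values are illustrative):
        #     {"type": "vector", "dimensions": 384, "source_field": "text",
        #      "model": "my-model", "model_version": "1"}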

        return self._record_to_document(record)

    def _doc_to_record(self, doc: dict[str, Any]) -> Record:
        """Convert an Elasticsearch document to a Record."""
        doc_id = doc.get("_id")
        record = self._document_to_record(doc, doc_id)

        # Add the score if present
        if "_score" in doc:
            record.metadata["_score"] = doc["_score"]

        return record

    async def create(self, record: Record) -> str:
        """Create a new record."""
        self._check_connection()
        doc = self._record_to_doc(record)

        # Create the document with an explicit ID if the record has one
        kwargs = {
            "index": self.index_name,
            "document": doc,
            "refresh": self.refresh,
        }
        if record.id:
            kwargs["id"] = record.id

        response = await self._client.index(**kwargs)

        return response["_id"]

    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records in batch."""
        self._check_connection()

        ids = []
        operations = []
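
        # The bulk API takes alternating action/document entries: each record
        # contributes an action line like {"index": {"_index": self.index_name}}
        # followed by its serialized document.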

        for record in records:
            doc = self._record_to_doc(record)
            operations.append({"index": {"_index": self.index_name}})
            operations.append(doc)

        if operations:
            response = await self._client.bulk(
                operations=operations,
                refresh=self.refresh,
            )

            # Extract IDs from the response
            for item in response.get("items", []):
                if "index" in item and "_id" in item["index"]:
                    ids.append(item["index"]["_id"])

        return ids

    async def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        self._check_connection()

        try:
            response = await self._client.get(
                index=self.index_name,
                id=id,
            )
            return self._doc_to_record(response)
        except Exception as e:
            # Log the error for debugging
            logger.debug(f"Error reading document {id}: {e}")
            return None

    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record."""
        self._check_connection()
        doc = self._record_to_doc(record)

        try:
            await self._client.update(
                index=self.index_name,
                id=id,
                doc=doc,
                refresh=self.refresh,
            )
            return True
        except Exception:
            return False

    async def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        self._check_connection()

        try:
            await self._client.delete(
                index=self.index_name,
                id=id,
                refresh=self.refresh,
            )
            return True
        except Exception:
            return False

    async def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()

        return await self._client.exists(
            index=self.index_name,
            id=id,
        )

    async def upsert(self, id: str, record: Record) -> str:
        """Update or insert a record with a specific ID."""
        self._check_connection()
        doc = self._record_to_doc(record)
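
        # Indexing with an explicit ID overwrites any existing document with that
        # ID, which gives this method its upsert semantics.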

        await self._client.index(
            index=self.index_name,
            id=id,
            document=doc,
            refresh=self.refresh,
        )

        return id

    async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using the bulk API.

        Uses AsyncElasticsearch's bulk API for efficient batch updates.

        Args:
            updates: List of (id, record) tuples to update

        Returns:
            List of success flags for each update
        """
        if not updates:
            return []

        self._check_connection()

        # Build bulk operations for AsyncElasticsearch
        operations: list[dict[str, Any]] = []
        for record_id, record in updates:
            # Add the update operation
            operations.append({
                "update": {
                    "_index": self.index_name,
                    "_id": record_id,
                }
            })
            # Add the document data
            doc = self._record_to_doc(record)
            operations.append({
                "doc": doc,
                "doc_as_upsert": False,  # Don't create the document if it doesn't exist
            })

        try:
            # Execute the bulk update using AsyncElasticsearch
            response = await self._client.bulk(
                operations=operations,
                refresh=self.refresh,
            )
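
            # Each item in the bulk response mirrors its request action; a successful
            # update comes back roughly as {"update": {"_id": "...", "status": 200}}
            # (shape illustrative).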

            # Process the response to determine which updates succeeded
            results = []
            if response.get("items"):
                for item in response["items"]:
                    if "update" in item:
                        update_result = item["update"]
                        # Check if the update succeeded (status 200) or the
                        # document was not found (status 404)
                        results.append(update_result.get("status") == 200)
                    else:
                        results.append(False)
            else:
                # If there are no items in the response, mark all as failed
                results = [False] * len(updates)

            return results

        except Exception as e:
            # If the bulk operation itself fails, mark all updates as failed
            logger.error(f"Bulk update failed: {e}")
            return [False] * len(updates)

    async def search(self, query: Query) -> list[Record]:
        """Search for records matching the query."""
        self._check_connection()

        # Build Elasticsearch query
        es_query = {"bool": {"must": []}}

        for filter in query.filters:
            field_path = f"data.{filter.field}"

            if filter.operator == Operator.EQ:
                # For string values, use the keyword sub-field for exact matching
                if isinstance(filter.value, str):
                    field_path = f"{field_path}.keyword"
                es_query["bool"]["must"].append({"term": {field_path: filter.value}})
            elif filter.operator == Operator.NEQ:
                es_query["bool"]["must_not"] = es_query["bool"].get("must_not", [])
                es_query["bool"]["must_not"].append({"term": {field_path: filter.value}})
            elif filter.operator == Operator.GT:
                es_query["bool"]["must"].append({"range": {field_path: {"gt": filter.value}}})
            elif filter.operator == Operator.LT:
                es_query["bool"]["must"].append({"range": {field_path: {"lt": filter.value}}})
            elif filter.operator == Operator.GTE:
                es_query["bool"]["must"].append({"range": {field_path: {"gte": filter.value}}})
            elif filter.operator == Operator.LTE:
                es_query["bool"]["must"].append({"range": {field_path: {"lte": filter.value}}})
            elif filter.operator == Operator.LIKE:
                es_query["bool"]["must"].append({"wildcard": {field_path: f"*{filter.value}*"}})
            elif filter.operator == Operator.IN:
                es_query["bool"]["must"].append({"terms": {field_path: filter.value}})
            elif filter.operator == Operator.NOT_IN:
                es_query["bool"]["must_not"] = es_query["bool"].get("must_not", [])
                es_query["bool"]["must_not"].append({"terms": {field_path: filter.value}})
            elif filter.operator == Operator.BETWEEN:
                # Use Elasticsearch's native range query for an efficient BETWEEN
                if isinstance(filter.value, (list, tuple)) and len(filter.value) == 2:
                    lower, upper = filter.value
                    es_query["bool"]["must"].append({
                        "range": {
                            field_path: {
                                "gte": lower,
                                "lte": upper,
                            }
                        }
                    })
            elif filter.operator == Operator.NOT_BETWEEN:
                # NOT BETWEEN: a range query under must_not
                if isinstance(filter.value, (list, tuple)) and len(filter.value) == 2:
                    lower, upper = filter.value
                    es_query["bool"]["must_not"] = es_query["bool"].get("must_not", [])
                    es_query["bool"]["must_not"].append({
                        "range": {
                            field_path: {
                                "gte": lower,
                                "lte": upper,
                            }
                        }
                    })
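
        # For example, a GTE filter on field "age" with value 21 yields:
        #     {"bool": {"must": [{"range": {"data.age": {"gte": 21}}}]}}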

        # If there are no filters, use match_all
        if not es_query["bool"]["must"] and "must_not" not in es_query["bool"]:
            es_query = {"match_all": {}}

        # Build sort
        sort = []
        if query.sort_specs:
            for sort_spec in query.sort_specs:
                direction = "desc" if sort_spec.order == SortOrder.DESC else "asc"
                sort.append({f"data.{sort_spec.field}": {"order": direction}})

        # Size and offset for pagination
        size = query.limit_value if query.limit_value else 10000
        from_param = query.offset_value if query.offset_value else 0

        # Execute search
        response = await self._client.search(
            index=self.index_name,
            query=es_query,
            sort=sort if sort else None,
            size=size,
            from_=from_param,
        )

        # Convert hits to records
        records = []
        for hit in response["hits"]["hits"]:
            record = self._doc_to_record(hit)

            # Apply field projection if specified
            if query.fields:
                record = record.project(query.fields)

            records.append(record)

        return records

    async def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()

        response = await self._client.count(index=self.index_name)
        return response["count"]

    async def clear(self) -> int:
        """Clear all records from the database."""
        self._check_connection()

        # Get the count before deletion
        count = await self._count_all()

        # Delete all documents by query
        response = await self._client.delete_by_query(
            index=self.index_name,
            query={"match_all": {}},
            refresh=self.refresh,
        )

        return response.get("deleted", count)

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None,
    ) -> AsyncIterator[Record]:
        """Stream records from Elasticsearch using the scroll API."""
        self._check_connection()
        config = config or StreamConfig()

        # Build the query (note: only EQ filters are translated here)
        es_query = {"match_all": {}}
        if query and query.filters:
            es_query = {"bool": {"must": []}}
            for filter in query.filters:
                field_path = f"data.{filter.field}"
                if filter.operator == Operator.EQ:
                    es_query["bool"]["must"].append({"term": {field_path: filter.value}})

        # Initial search that opens the scroll context
        response = await self._client.search(
            index=self.index_name,
            query=es_query,
            scroll="2m",
            size=config.batch_size,
        )

        scroll_id = response["_scroll_id"]
        hits = response["hits"]["hits"]

        try:
            while hits:
                for hit in hits:
                    record = self._doc_to_record(hit)
                    if query and query.fields:
                        record = record.project(query.fields)
                    yield record

                # Get the next batch
                response = await self._client.scroll(
                    scroll_id=scroll_id,
                    scroll="2m",
                )
                hits = response["hits"]["hits"]
        finally:
            # Clear the scroll context
            await self._client.clear_scroll(scroll_id=scroll_id)
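
    # Illustrative consumption of the scroll-backed stream (the handler name is
    # hypothetical):
    #
    #     async for record in db.stream_read(config=StreamConfig(batch_size=500)):
    #         await handle(record)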

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None,
    ) -> StreamResult:
        """Stream records into Elasticsearch using the bulk API."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        batch = []
        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write the batch with graceful fallback
                async def batch_func(b):
                    await self._write_batch(b)
                    return [r.id for r in b]

                continue_processing = await async_process_batch_with_fallback(
                    batch,
                    batch_func,
                    self.create,
                    result,
                    config,
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write any remaining batch
        if batch and not quitting:
            async def batch_func(b):
                await self._write_batch(b)
                return [r.id for r in b]

            await async_process_batch_with_fallback(
                batch,
                batch_func,
                self.create,
                result,
                config,
            )

        result.duration = time.time() - start_time
        return result
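
    # Illustrative usage with an async generator as the source (names are
    # hypothetical):
    #
    #     async def gen():
    #         for rec in pending_records:
    #             yield rec
    #
    #     result = await db.stream_write(gen(), StreamConfig(batch_size=1000))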

    async def _write_batch(self, records: list[Record]) -> None:
        """Write a batch of records using the bulk API."""
        if not records:
            return

        # Build bulk operations
        operations = []
        for record in records:
            doc = self._record_to_doc(record)
            operations.append({"index": {"_index": self.index_name}})
            operations.append(doc)

        # Execute bulk
        await self._client.bulk(
            operations=operations,
            refresh=self.refresh,
        )

    async def vector_search(
        self,
        query_vector: np.ndarray | list[float],
        vector_field: str = "embedding",
        k: int = 10,
        metric: DistanceMetric = DistanceMetric.COSINE,
        filter: Query | None = None,
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> list[VectorSearchResult]:
        """Search for similar vectors using Elasticsearch KNN.

        Args:
            query_vector: The vector to search for
            vector_field: Name of the vector field to search
            k: Number of results to return
            metric: Distance metric to use
            filter: Optional query filter to apply before vector search
            include_source: Whether to include the source document in results
            score_threshold: Optional minimum similarity score

        Returns:
            List of search results ordered by similarity
        """
        self._check_connection()

        # Import vector utilities
        from ..vector.elasticsearch_utils import build_knn_query

        # Build the filter query if provided
        filter_query = self._build_filter_query(filter) if filter else None

        # Build the KNN query
        query = build_knn_query(
            query_vector=query_vector,
            field_name=vector_field,
            k=k,
            filter_query=filter_query,
        )

        # Execute search
        try:
            response = await self._client.search(
                index=self.index_name,
                **query,  # Unpack the query dict directly
                size=k,
                _source=include_source,
            )
        except Exception as e:
            self._handle_elasticsearch_error(e, "vector search")
            return []

        # Process results
        results = []
        for hit in response.get("hits", {}).get("hits", []):
            score = hit.get("_score", 0.0)

            # Apply the score threshold if specified
            if score_threshold is not None and score < score_threshold:
                continue

            # Convert the document to a record if the source is included
            record = None
            if include_source:
                record = self._doc_to_record(hit)

            # Set the storage ID on the record if we have one
            if record and not record.has_storage_id():
                record.storage_id = hit["_id"]

            # Skip if no record (shouldn't happen if include_source is True)
            if record is None:
                continue

            results.append(VectorSearchResult(
                record=record,
                score=score,
                vector_field=vector_field,
                metadata={
                    "index": self.index_name,
                    "metric": metric.value,
                    "doc_id": hit["_id"],
                },
            ))

        return results
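
    # Illustrative call (field name, k, and threshold are example values; the
    # query vector must match the indexed dimensions):
    #
    #     hits = await db.vector_search(query_vector, vector_field="embedding",
    #                                   k=5, score_threshold=0.7)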

    async def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Any | None = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records.

        Args:
            records: Records to process
            text_field: Field name(s) containing text to embed
            vector_field: Field name to store vectors in
            embedding_fn: Function to generate embeddings
            batch_size: Number of records to process at once
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            List of record IDs that were processed
        """
        # This is a stub: a full implementation requires an actual embedding function
        logger.warning("bulk_embed_and_store is not fully implemented for Elasticsearch")
        return []

    async def create_vector_index(
        self,
        vector_field: str = "embedding",
        dimensions: int | None = None,
        metric: DistanceMetric = DistanceMetric.COSINE,
        index_type: str = "auto",
        **kwargs: Any,
    ) -> bool:
        """Create or update the index mapping for a vector field.

        Args:
            vector_field: Name of the vector field to index
            dimensions: Number of dimensions
            metric: Distance metric for the index
            index_type: Type of index (ignored for ES, which always uses HNSW)
            **kwargs: Additional index parameters

        Returns:
            True if the index was created/updated successfully
        """
        self._check_connection()

        if not dimensions:
            if vector_field not in self.vector_fields:
                raise ValueError(f"Unknown dimensions for field '{vector_field}'")
            dimensions = self.vector_fields[vector_field]

        # Import vector utilities
        from ..vector.elasticsearch_utils import (
            get_similarity_for_metric,
            get_vector_mapping,
        )

        # Get the similarity function for the metric
        similarity = get_similarity_for_metric(metric)

        # Build the mapping for the vector field
        mapping = get_vector_mapping(dimensions, similarity)

        # Update the index mapping
        try:
            await self._client.indices.put_mapping(
                index=self.index_name,
                properties={
                    f"data.{vector_field}": mapping,
                },
            )

            # Track the vector field
            self.vector_fields[vector_field] = dimensions
            self.vector_enabled = True

            logger.info(f"Created vector mapping for field '{vector_field}' with {dimensions} dimensions")
            return True

        except Exception as e:
            self._handle_elasticsearch_error(e, "create vector index")
            return False
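
    # Illustrative setup of a vector field before storing embeddings (the
    # dimension count is an example value):
    #
    #     ok = await db.create_vector_index("embedding", dimensions=384,
    #                                       metric=DistanceMetric.COSINE)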