Coverage for src/dataknobs_data/backends/elasticsearch_async.py: 13%

326 statements  

coverage.py v7.11.3, created at 2025-11-13 11:23 -0700

1"""Native async Elasticsearch backend implementation with connection pooling.""" 

2 

3from __future__ import annotations 

4 

5import logging 

6import time 

7from typing import TYPE_CHECKING, Any, cast 

8 

9from dataknobs_config import ConfigurableBase 

10 

11from ..database import AsyncDatabase 

12from ..pooling import ConnectionPoolManager 

13from ..pooling.elasticsearch import ( 

14 ElasticsearchPoolConfig, 

15 close_elasticsearch_client, 

16 create_async_elasticsearch_client, 

17 validate_elasticsearch_client, 

18) 

19from ..query import Operator, Query, SortOrder 

20from ..streaming import StreamConfig, StreamResult, async_process_batch_with_fallback 

21from ..vector.mixins import VectorOperationsMixin 

22from ..vector.types import DistanceMetric, VectorSearchResult 

23from .elasticsearch_mixins import ( 

24 ElasticsearchBaseConfig, 

25 ElasticsearchErrorHandler, 

26 ElasticsearchIndexManager, 

27 ElasticsearchQueryBuilder, 

28 ElasticsearchRecordSerializer, 

29 ElasticsearchVectorSupport, 

30) 

31 

32if TYPE_CHECKING: 

33 import numpy as np 

34 from collections.abc import AsyncIterator, Callable, Awaitable 

35 from ..records import Record 

36 

37logger = logging.getLogger(__name__) 

38 

39# Global pool manager for Elasticsearch clients 

40_client_manager = ConnectionPoolManager() 

41 

42 

class AsyncElasticsearchDatabase(
    AsyncDatabase,
    ConfigurableBase,
    VectorOperationsMixin,
    ElasticsearchBaseConfig,
    ElasticsearchIndexManager,
    ElasticsearchVectorSupport,
    ElasticsearchErrorHandler,
    ElasticsearchRecordSerializer,
    ElasticsearchQueryBuilder,
):
    """Native async Elasticsearch database backend with connection pooling."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async Elasticsearch database."""
        super().__init__(config)

        # Initialize vector support
        self.vector_fields = {}  # field_name -> dimensions
        self.vector_enabled = False

        config = config or {}
        self._pool_config = ElasticsearchPoolConfig.from_dict(config)
        self.index_name = self._pool_config.index
        self.refresh = config.get("refresh", True)
        self._client = None
        self._connected = False

    @classmethod
    def from_config(cls, config: dict) -> AsyncElasticsearchDatabase:
        """Create from config dictionary."""
        return cls(config)
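
    # Usage sketch: the keys below are illustrative; the authoritative set is
    # whatever ElasticsearchPoolConfig.from_dict accepts.
    #
    #     db = AsyncElasticsearchDatabase.from_config({
    #         "index": "my_records",
    #         "refresh": True,
    #     })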

    async def connect(self) -> None:
        """Connect to the Elasticsearch database."""
        if self._connected:
            return

        # Get or create client for current event loop
        from ..pooling import BasePoolConfig
        self._client = await _client_manager.get_pool(
            self._pool_config,
            cast("Callable[[BasePoolConfig], Awaitable[Any]]", create_async_elasticsearch_client),
            validate_elasticsearch_client,
            close_elasticsearch_client
        )

        # Ensure index exists
        await self._ensure_index()
        self._connected = True

    async def close(self) -> None:
        """Close the database connection."""
        if self._connected:
            # Note: The client is managed by the pool manager, so we don't close it here
            # Just mark as disconnected
            self._client = None
            self._connected = False
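
    # Lifecycle sketch: connect() is idempotent and close() only detaches the
    # pooled client (the pool manager owns its lifetime).
    #
    #     import asyncio
    #
    #     async def main():
    #         db = AsyncElasticsearchDatabase.from_config({"index": "my_records"})
    #         await db.connect()
    #         try:
    #             ...  # reads and writes go here
    #         finally:
    #             await db.close()
    #
    #     asyncio.run(main())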

    def _initialize(self) -> None:
        """Initialize is handled in connect."""
        pass

    async def _ensure_index(self) -> None:
        """Ensure the index exists with proper mappings."""
        if not self._client:
            raise RuntimeError("Database not connected. Call connect() first.")

        # Check if index exists
        if not await self._client.indices.exists(index=self.index_name):  # type: ignore[unreachable]
            # Get mappings with vector field support
            mappings = self.get_index_mappings(self.vector_fields)

            # Get settings optimized for KNN if we have vector fields
            settings = self.get_knn_index_settings() if self.vector_fields else {
                "number_of_shards": 1,
                "number_of_replicas": 0,
            }

            await self._client.indices.create(
                index=self.index_name,
                mappings=mappings,
                settings=settings
            )

            if self.vector_fields:
                self.vector_enabled = True
                logger.info(f"Created index '{self.index_name}' with vector support")

    def _check_connection(self) -> None:
        """Check if database is connected."""
        if not self._connected or not self._client:
            raise RuntimeError("Database not connected. Call connect() first.")

    def _record_to_doc(self, record: Record) -> dict[str, Any]:
        """Convert a Record to an Elasticsearch document."""
        # Update vector tracking if needed
        if self._has_vector_fields(record):
            self._update_vector_tracking(record)

        # Add vector field metadata to record metadata
        if "vector_fields" not in record.metadata:
            record.metadata["vector_fields"] = {}

        for field_name in self.vector_fields:
            if field_name in record.fields:
                field = record.fields[field_name]
                if hasattr(field, "source_field"):
                    record.metadata["vector_fields"][field_name] = {
                        "type": "vector",
                        "dimensions": self.vector_fields[field_name],
                        "source_field": field.source_field,
                        "model": getattr(field, "model_name", None),
                        "model_version": getattr(field, "model_version", None),
                    }

        return self._record_to_document(record)

    def _doc_to_record(self, doc: dict[str, Any]) -> Record:
        """Convert an Elasticsearch document to a Record."""
        doc_id = doc.get("_id")
        record = self._document_to_record(doc, doc_id)

        # Add score if present
        if "_score" in doc:
            record.metadata["_score"] = doc["_score"]

        return record
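
    # Document shape sketch: the `data.`-prefixed paths used by the query
    # builders below suggest stored documents look roughly like this
    # (assumed layout; the exact shape comes from _record_to_document):
    #
    #     {
    #         "data": {"name": "Ada", "age": 36, "embedding": [...]},
    #         "metadata": {"vector_fields": {...}}
    #     }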

    async def create(self, record: Record) -> str:
        """Create a new record."""
        self._check_connection()
        doc = self._record_to_doc(record)

        # Create document with explicit ID if record has one
        kwargs = {
            "index": self.index_name,
            "document": doc,
            "refresh": self.refresh
        }
        if record.id:
            kwargs["id"] = record.id

        response = await self._client.index(**kwargs)

        return response["_id"]
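
    # Create example (Record construction shown here is assumed; adapt to the
    # actual Record API):
    #
    #     record = Record(data={"name": "Ada", "age": 36})
    #     new_id = await db.create(record)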

    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records in batch."""
        self._check_connection()

        ids = []
        operations = []

        for record in records:
            doc = self._record_to_doc(record)
            operations.append({"index": {"_index": self.index_name}})
            operations.append(doc)

        if operations:
            response = await self._client.bulk(
                operations=operations,
                refresh=self.refresh
            )

            # Extract IDs from response
            for item in response.get("items", []):
                if "index" in item and "_id" in item["index"]:
                    ids.append(item["index"]["_id"])

        return ids
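
    # The bulk payload alternates action and document entries, the shape the
    # Elasticsearch bulk API expects:
    #
    #     [
    #         {"index": {"_index": "my_records"}},
    #         {...},  # document for record 1
    #         {"index": {"_index": "my_records"}},
    #         {...},  # document for record 2
    #     ]
    #
    # The response mirrors this order, so item N in response["items"]
    # corresponds to record N.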

    async def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        self._check_connection()

        try:
            response = await self._client.get(
                index=self.index_name,
                id=id
            )
            return self._doc_to_record(response)
        except Exception as e:
            # Log the error for debugging
            logger.debug(f"Error reading document {id}: {e}")
            return None

    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record."""
        self._check_connection()
        doc = self._record_to_doc(record)

        try:
            await self._client.update(
                index=self.index_name,
                id=id,
                doc=doc,
                refresh=self.refresh
            )
            return True
        except Exception:
            return False

    async def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        self._check_connection()

        try:
            await self._client.delete(
                index=self.index_name,
                id=id,
                refresh=self.refresh
            )
            return True
        except Exception:
            return False

    async def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()

        return await self._client.exists(
            index=self.index_name,
            id=id
        )
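
    # Round trip over the single-record operations above (continuing the
    # assumed example from create()):
    #
    #     if await db.exists(new_id):
    #         fetched = await db.read(new_id)
    #         await db.update(new_id, fetched)
    #     await db.delete(new_id)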

    async def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
        """Update or insert a record.

        Can be called as:
        - upsert(id, record) - explicit ID and record
        - upsert(record) - extract ID from record using Record's built-in logic
        """
        self._check_connection()

        # Determine ID and record based on arguments
        if isinstance(id_or_record, str):
            id = id_or_record
            if record is None:
                raise ValueError("Record required when ID is provided")
        else:
            record = id_or_record
            id = record.id
            if id is None:
                import uuid  # type: ignore[unreachable]
                id = str(uuid.uuid4())
                record.storage_id = id

        doc = self._record_to_doc(record)

        await self._client.index(
            index=self.index_name,
            id=id,
            document=doc,
            refresh=self.refresh
        )

        return id
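
    # Both call forms, per the docstring above:
    #
    #     await db.upsert("user-42", record)  # explicit ID
    #     await db.upsert(record)             # ID from the record, or a fresh
    #                                         # UUID if it has none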

    async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using the bulk API.

        Uses AsyncElasticsearch's bulk API for efficient batch updates.

        Args:
            updates: List of (id, record) tuples to update

        Returns:
            List of success flags for each update
        """
        if not updates:
            return []

        self._check_connection()

        # Build bulk operations for AsyncElasticsearch
        operations: list[dict[str, Any]] = []
        for record_id, record in updates:
            # Add update operation
            operations.append({
                "update": {
                    "_index": self.index_name,
                    "_id": record_id
                }
            })
            # Add document data
            doc = self._record_to_doc(record)
            operations.append({
                "doc": doc,
                "doc_as_upsert": False  # Don't create if doesn't exist
            })

        try:
            # Execute bulk update using AsyncElasticsearch
            response = await self._client.bulk(
                operations=operations,
                refresh=self.refresh
            )

            # Process the response to determine which updates succeeded:
            # status 200 means the update applied; 404 means not found
            results = []
            if response.get("items"):
                for item in response["items"]:
                    if "update" in item:
                        update_result = item["update"]
                        results.append(update_result.get("status") == 200)
                    else:
                        results.append(False)
            else:
                # If no items in response, mark all as failed
                results = [False] * len(updates)

            return results

        except Exception as e:
            # If the bulk operation fails, mark all as failed
            logger.error(f"Bulk update failed: {e}")
            return [False] * len(updates)
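
    # The returned flags line up with the input order:
    #
    #     updates = [(id_a, rec_a), (id_b, rec_b)]
    #     flags = await db.update_batch(updates)
    #     failed_ids = [rid for (rid, _), ok in zip(updates, flags) if not ok]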

    async def search(self, query: Query) -> list[Record]:
        """Search for records matching the query."""
        self._check_connection()

        # Build Elasticsearch query
        es_query = {"bool": {"must": []}}

        for filter in query.filters:
            field_path = f"data.{filter.field}"

            if filter.operator == Operator.EQ:
                # For string values, use keyword field for exact matching
                if isinstance(filter.value, str):
                    field_path = f"{field_path}.keyword"
                es_query["bool"]["must"].append({"term": {field_path: filter.value}})
            elif filter.operator == Operator.NEQ:
                es_query["bool"]["must_not"] = es_query["bool"].get("must_not", [])
                es_query["bool"]["must_not"].append({"term": {field_path: filter.value}})
            elif filter.operator == Operator.GT:
                es_query["bool"]["must"].append({"range": {field_path: {"gt": filter.value}}})
            elif filter.operator == Operator.LT:
                es_query["bool"]["must"].append({"range": {field_path: {"lt": filter.value}}})
            elif filter.operator == Operator.GTE:
                es_query["bool"]["must"].append({"range": {field_path: {"gte": filter.value}}})
            elif filter.operator == Operator.LTE:
                es_query["bool"]["must"].append({"range": {field_path: {"lte": filter.value}}})
            elif filter.operator == Operator.LIKE:
                es_query["bool"]["must"].append({"wildcard": {field_path: f"*{filter.value}*"}})
            elif filter.operator == Operator.IN:
                es_query["bool"]["must"].append({"terms": {field_path: filter.value}})
            elif filter.operator == Operator.NOT_IN:
                es_query["bool"]["must_not"] = es_query["bool"].get("must_not", [])
                es_query["bool"]["must_not"].append({"terms": {field_path: filter.value}})
            elif filter.operator == Operator.BETWEEN:
                # Use Elasticsearch's native range query for efficient BETWEEN
                if isinstance(filter.value, (list, tuple)) and len(filter.value) == 2:
                    lower, upper = filter.value
                    es_query["bool"]["must"].append({
                        "range": {
                            field_path: {
                                "gte": lower,
                                "lte": upper
                            }
                        }
                    })
            elif filter.operator == Operator.NOT_BETWEEN:
                # NOT BETWEEN using must_not with range
                if isinstance(filter.value, (list, tuple)) and len(filter.value) == 2:
                    lower, upper = filter.value
                    es_query["bool"]["must_not"] = es_query["bool"].get("must_not", [])
                    es_query["bool"]["must_not"].append({
                        "range": {
                            field_path: {
                                "gte": lower,
                                "lte": upper
                            }
                        }
                    })

        # If no filters, use match_all
        if not es_query["bool"]["must"] and "must_not" not in es_query["bool"]:
            es_query = {"match_all": {}}

        # Build sort
        sort = []
        if query.sort_specs:
            for sort_spec in query.sort_specs:
                direction = "desc" if sort_spec.order == SortOrder.DESC else "asc"
                sort.append({f"data.{sort_spec.field}": {"order": direction}})

        # Add size and from for pagination
        size = query.limit_value if query.limit_value else 10000
        from_param = query.offset_value if query.offset_value else 0

        # Execute search
        response = await self._client.search(
            index=self.index_name,
            query=es_query,
            sort=sort if sort else None,
            size=size,
            from_=from_param
        )

        # Convert to records
        records = []
        for hit in response["hits"]["hits"]:
            record = self._doc_to_record(hit)

            # Apply field projection if specified
            if query.fields:
                record = record.project(query.fields)

            records.append(record)

        return records
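
    # Translation example: filters EQ("name", "Ada") and GTE("age", 21)
    # become the following Elasticsearch query (string equality is routed
    # through the .keyword sub-field):
    #
    #     {"bool": {"must": [
    #         {"term": {"data.name.keyword": "Ada"}},
    #         {"range": {"data.age": {"gte": 21}}},
    #     ]}}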

    async def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()

        response = await self._client.count(index=self.index_name)
        return response["count"]

    async def clear(self) -> int:
        """Clear all records from the database."""
        self._check_connection()

        # Get count before deletion
        count = await self._count_all()

        # Delete by query - delete all documents
        response = await self._client.delete_by_query(
            index=self.index_name,
            query={"match_all": {}},
            refresh=self.refresh
        )

        return response.get("deleted", count)

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records from Elasticsearch using the scroll API."""
        self._check_connection()
        config = config or StreamConfig()

        # Build query
        es_query = {"match_all": {}}
        if query and query.filters:
            es_query = {"bool": {"must": []}}
            for filter in query.filters:
                field_path = f"data.{filter.field}"
                if filter.operator == Operator.EQ:
                    es_query["bool"]["must"].append({"term": {field_path: filter.value}})

        # Initial search with scroll
        response = await self._client.search(
            index=self.index_name,
            query=es_query,
            scroll="2m",
            size=config.batch_size
        )

        scroll_id = response["_scroll_id"]
        hits = response["hits"]["hits"]

        try:
            while hits:
                for hit in hits:
                    record = self._doc_to_record(hit)
                    if query and query.fields:
                        record = record.project(query.fields)
                    yield record

                # Get next batch
                response = await self._client.scroll(
                    scroll_id=scroll_id,
                    scroll="2m"
                )
                hits = response["hits"]["hits"]
        finally:
            # Clear scroll
            await self._client.clear_scroll(scroll_id=scroll_id)
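
    # Streaming read sketch (assuming StreamConfig accepts batch_size as a
    # keyword): records arrive lazily one scroll batch at a time, so the full
    # result set never has to fit in memory.
    #
    #     async for record in db.stream_read(config=StreamConfig(batch_size=500)):
    #         handle(record)  # handle() is a placeholder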

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into Elasticsearch using the bulk API."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        batch = []
        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write batch with graceful fallback
                async def batch_func(b):
                    await self._write_batch(b)
                    return [r.id for r in b]

                continue_processing = await async_process_batch_with_fallback(
                    batch,
                    batch_func,
                    self.create,
                    result,
                    config
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write remaining batch
        if batch and not quitting:
            async def batch_func(b):
                await self._write_batch(b)
                return [r.id for r in b]

            await async_process_batch_with_fallback(
                batch,
                batch_func,
                self.create,
                result,
                config
            )

        result.duration = time.time() - start_time
        return result
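
    # Streaming write sketch with an async generator as the source (Record
    # construction is assumed, as above):
    #
    #     async def gen():
    #         for i in range(10_000):
    #             yield Record(data={"n": i})
    #
    #     result = await db.stream_write(gen())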

    async def _write_batch(self, records: list[Record]) -> None:
        """Write a batch of records using the bulk API."""
        if not records:
            return

        # Build bulk operations
        operations = []
        for record in records:
            doc = self._record_to_doc(record)
            operations.append({"index": {"_index": self.index_name}})
            operations.append(doc)

        # Execute bulk
        await self._client.bulk(
            operations=operations,
            refresh=self.refresh
        )

    async def vector_search(
        self,
        query_vector: np.ndarray | list[float],
        vector_field: str = "embedding",
        k: int = 10,
        metric: DistanceMetric = DistanceMetric.COSINE,
        filter: Query | None = None,
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> list[VectorSearchResult]:
        """Search for similar vectors using Elasticsearch KNN.

        Args:
            query_vector: The vector to search for
            vector_field: Name of the vector field to search
            k: Number of results to return
            metric: Distance metric to use
            filter: Optional query filter to apply before vector search
            include_source: Whether to include source document in results
            score_threshold: Optional minimum similarity score

        Returns:
            List of search results ordered by similarity
        """
        self._check_connection()

        # Import vector utilities
        from ..vector.elasticsearch_utils import (
            build_knn_query,
        )

        # Build filter query if provided
        filter_query = self._build_filter_query(filter) if filter else None

        # Build KNN query
        query = build_knn_query(
            query_vector=query_vector,
            field_name=vector_field,
            k=k,
            filter_query=filter_query,
        )

        # Execute search
        try:
            response = await self._client.search(
                index=self.index_name,
                **query,  # Unpack the query dict directly
                size=k,
                _source=include_source,
            )
        except Exception as e:
            self._handle_elasticsearch_error(e, "vector search")
            return []

        # Process results
        results = []
        for hit in response.get("hits", {}).get("hits", []):
            score = hit.get("_score", 0.0)

            # Apply score threshold if specified
            if score_threshold is not None and score < score_threshold:
                continue

            # Convert document to record if source included
            record = None
            if include_source:
                record = self._doc_to_record(hit)

            # Set the storage ID on the record if we have one
            if record and not record.has_storage_id():
                record.storage_id = hit["_id"]

            # Skip if no record (record is None whenever include_source is False)
            if record is None:
                continue

            results.append(VectorSearchResult(
                record=record,
                score=score,
                vector_field=vector_field,
                metadata={
                    "index": self.index_name,
                    "metric": metric.value,
                    "doc_id": hit["_id"],
                },
            ))

        return results
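
    # KNN search sketch (embed() stands in for whatever produces the query
    # vector):
    #
    #     hits = await db.vector_search(
    #         query_vector=embed("rainy day reading"),
    #         vector_field="embedding",
    #         k=5,
    #         score_threshold=0.7,
    #     )
    #     for hit in hits:
    #         print(hit.score, hit.metadata["doc_id"])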

    async def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Any | None = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records.

        Args:
            records: Records to process
            text_field: Field name(s) containing text to embed
            vector_field: Field name to store vectors in
            embedding_fn: Function to generate embeddings
            batch_size: Number of records to process at once
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            List of record IDs that were processed
        """
        # This is a stub implementation
        # Full implementation would require an actual embedding function
        logger.warning("bulk_embed_and_store is not fully implemented for Elasticsearch")
        return []

    async def create_vector_index(
        self,
        vector_field: str = "embedding",
        dimensions: int | None = None,
        metric: DistanceMetric = DistanceMetric.COSINE,
        index_type: str = "auto",
        **kwargs: Any,
    ) -> bool:
        """Create or update the index mapping for a vector field.

        Args:
            vector_field: Name of the vector field to index
            dimensions: Number of dimensions
            metric: Distance metric for the index
            index_type: Type of index (ignored for ES, which always uses HNSW)
            **kwargs: Additional index parameters

        Returns:
            True if index was created/updated successfully
        """
        self._check_connection()

        if not dimensions:
            if vector_field not in self.vector_fields:
                raise ValueError(f"Unknown dimensions for field '{vector_field}'")
            dimensions = self.vector_fields[vector_field]

        # Import vector utilities
        from ..vector.elasticsearch_utils import (
            get_similarity_for_metric,
            get_vector_mapping,
        )

        # Get similarity function for metric
        similarity = get_similarity_for_metric(metric)

        # Build mapping for the vector field
        mapping = get_vector_mapping(dimensions, similarity)

        # Update index mapping
        try:
            await self._client.indices.put_mapping(
                index=self.index_name,
                properties={
                    f"data.{vector_field}": mapping
                }
            )

            # Track the vector field
            self.vector_fields[vector_field] = dimensions
            self.vector_enabled = True

            logger.info(f"Created vector mapping for field '{vector_field}' with {dimensions} dimensions")
            return True

        except Exception as e:
            self._handle_elasticsearch_error(e, "create vector index")
            return False
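
    # Vector setup sketch: declare the mapping once, then search against it.
    #
    #     ok = await db.create_vector_index("embedding", dimensions=384)
    #     if ok:
    #         hits = await db.vector_search(query_vector=vec, k=10)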