Coverage for src/dataknobs_data/backends/elasticsearch_async.py: 12%

381 statements  


1"""Native async Elasticsearch backend implementation with connection pooling.""" 

2 

3from __future__ import annotations 

4 

5import logging 

6import time 

7from typing import TYPE_CHECKING, Any, cast 

8 

9from dataknobs_config import ConfigurableBase 

10 

11from ..database import AsyncDatabase 

12from ..pooling import ConnectionPoolManager 

13from ..pooling.elasticsearch import ( 

14 ElasticsearchPoolConfig, 

15 close_elasticsearch_client, 

16 create_async_elasticsearch_client, 

17 validate_elasticsearch_client, 

18) 

19from ..query import Operator, Query, SortOrder 

20from ..streaming import StreamConfig, StreamResult, async_process_batch_with_fallback 

21from ..vector.mixins import VectorOperationsMixin 

22from ..vector.types import DistanceMetric, VectorSearchResult 

23from .elasticsearch_mixins import ( 

24 ElasticsearchBaseConfig, 

25 ElasticsearchErrorHandler, 

26 ElasticsearchIndexManager, 

27 ElasticsearchQueryBuilder, 

28 ElasticsearchRecordSerializer, 

29 ElasticsearchVectorSupport, 

30) 

31 

32if TYPE_CHECKING: 

33 import numpy as np 

34 from collections.abc import AsyncIterator, Callable, Awaitable 

35 from ..records import Record 

36 

37logger = logging.getLogger(__name__) 

38 

39# Global pool manager for Elasticsearch clients 

40_client_manager = ConnectionPoolManager() 

41 

42 

class AsyncElasticsearchDatabase(
    AsyncDatabase,
    ConfigurableBase,
    VectorOperationsMixin,
    ElasticsearchBaseConfig,
    ElasticsearchIndexManager,
    ElasticsearchVectorSupport,
    ElasticsearchErrorHandler,
    ElasticsearchRecordSerializer,
    ElasticsearchQueryBuilder,
):
    """Native async Elasticsearch database backend with connection pooling."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async Elasticsearch database."""
        super().__init__(config)

        # Initialize vector support
        self.vector_fields = {}  # field_name -> dimensions
        self.vector_enabled = False

        config = config or {}
        self._pool_config = ElasticsearchPoolConfig.from_dict(config)
        self.index_name = self._pool_config.index
        self.refresh = config.get("refresh", True)
        self._client = None
        self._connected = False

    @classmethod
    def from_config(cls, config: dict) -> AsyncElasticsearchDatabase:
        """Create from config dictionary."""
        return cls(config)

    async def connect(self) -> None:
        """Connect to the Elasticsearch database."""
        if self._connected:
            return

        # Get or create client for current event loop
        from ..pooling import BasePoolConfig

        self._client = await _client_manager.get_pool(
            self._pool_config,
            cast("Callable[[BasePoolConfig], Awaitable[Any]]", create_async_elasticsearch_client),
            validate_elasticsearch_client,
            close_elasticsearch_client,
        )

        # Ensure index exists
        await self._ensure_index()
        self._connected = True

    async def close(self) -> None:
        """Close the database connection."""
        if self._connected:
            # Note: the client is managed by the pool manager, so we don't
            # close it here; just mark this instance as disconnected.
            self._client = None
            self._connected = False
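
    # Connection lifecycle sketch. The config keys shown are illustrative
    # assumptions -- the accepted keys are defined by ElasticsearchPoolConfig:
    #
    #     db = AsyncElasticsearchDatabase.from_config(
    #         {"hosts": "http://localhost:9200", "index": "my-index"}
    #     )
    #     await db.connect()    # pools the client and ensures the index exists
    #     try:
    #         ...               # CRUD / search calls go here
    #     finally:
    #         await db.close()  # releases this handle; the pool owns the client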

    def _initialize(self) -> None:
        """Initialization is handled in connect()."""
        pass

    async def _ensure_index(self) -> None:
        """Ensure the index exists with proper mappings."""
        if not self._client:
            raise RuntimeError("Database not connected. Call connect() first.")

        # Check if index exists
        if not await self._client.indices.exists(index=self.index_name):  # type: ignore[unreachable]
            # Get mappings with vector field support
            mappings = self.get_index_mappings(self.vector_fields)

            # Get settings optimized for KNN if we have vector fields
            settings = self.get_knn_index_settings() if self.vector_fields else {
                "number_of_shards": 1,
                "number_of_replicas": 0,
            }

            await self._client.indices.create(
                index=self.index_name,
                mappings=mappings,
                settings=settings,
            )

            if self.vector_fields:
                self.vector_enabled = True
                logger.info(f"Created index '{self.index_name}' with vector support")

    def _check_connection(self) -> None:
        """Raise if the database is not connected."""
        if not self._connected or not self._client:
            raise RuntimeError("Database not connected. Call connect() first.")

    def _record_to_doc(self, record: Record) -> dict[str, Any]:
        """Convert a Record to an Elasticsearch document."""
        # Update vector tracking if needed
        if self._has_vector_fields(record):
            self._update_vector_tracking(record)

        # Add vector field metadata to record metadata
        if "vector_fields" not in record.metadata:
            record.metadata["vector_fields"] = {}

        for field_name in self.vector_fields:
            if field_name in record.fields:
                field = record.fields[field_name]
                if hasattr(field, "source_field"):
                    record.metadata["vector_fields"][field_name] = {
                        "type": "vector",
                        "dimensions": self.vector_fields[field_name],
                        "source_field": field.source_field,
                        "model": getattr(field, "model_name", None),
                        "model_version": getattr(field, "model_version", None),
                    }

        return self._record_to_document(record)

    def _doc_to_record(self, doc: dict[str, Any]) -> Record:
        """Convert an Elasticsearch document to a Record."""
        doc_id = doc.get("_id")
        record = self._document_to_record(doc, doc_id)

        # Add score if present
        if "_score" in doc:
            record.metadata["_score"] = doc["_score"]

        return record

    async def create(self, record: Record) -> str:
        """Create a new record."""
        self._check_connection()
        doc = self._record_to_doc(record)

        # Create document with explicit ID if record has one
        kwargs = {
            "index": self.index_name,
            "document": doc,
            "refresh": self.refresh,
        }
        if record.id:
            kwargs["id"] = record.id

        response = await self._client.index(**kwargs)

        return response["_id"]

    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records in batch."""
        self._check_connection()

        ids = []
        operations = []

        for record in records:
            doc = self._record_to_doc(record)
            operations.append({"index": {"_index": self.index_name}})
            operations.append(doc)

        if operations:
            response = await self._client.bulk(
                operations=operations,
                refresh=self.refresh,
            )

            # Extract IDs from response
            for item in response.get("items", []):
                if "index" in item and "_id" in item["index"]:
                    ids.append(item["index"]["_id"])

        return ids
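
    # For reference, the bulk payload built above alternates action and source
    # entries (shape shown for two records; the field layout is illustrative):
    #
    #     [
    #         {"index": {"_index": "my-index"}},
    #         {...serialized record 1...},
    #         {"index": {"_index": "my-index"}},
    #         {...serialized record 2...},
    #     ]
    #
    # Elasticsearch assigns the document IDs, which are then read back from
    # response["items"] in order.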

    async def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        self._check_connection()

        try:
            response = await self._client.get(
                index=self.index_name,
                id=id,
            )
            return self._doc_to_record(response)
        except Exception as e:
            # Log the error for debugging
            logger.debug(f"Error reading document {id}: {e}")
            return None

    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record."""
        self._check_connection()
        doc = self._record_to_doc(record)

        try:
            await self._client.update(
                index=self.index_name,
                id=id,
                doc=doc,
                refresh=self.refresh,
            )
            return True
        except Exception:
            return False

    async def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        self._check_connection()

        try:
            await self._client.delete(
                index=self.index_name,
                id=id,
                refresh=self.refresh,
            )
            return True
        except Exception:
            return False

    async def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()

        return await self._client.exists(
            index=self.index_name,
            id=id,
        )
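
    # CRUD round-trip sketch (Record construction is assumed; see ..records):
    #
    #     record_id = await db.create(record)
    #     fetched = await db.read(record_id)       # None if missing
    #     ok = await db.update(record_id, fetched)
    #     assert await db.exists(record_id)
    #     ok = await db.delete(record_id)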

    async def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
        """Update or insert a record.

        Can be called as:
        - upsert(id, record) - explicit ID and record
        - upsert(record) - extract ID from record using Record's built-in logic
        """
        self._check_connection()

        # Determine ID and record based on arguments
        if isinstance(id_or_record, str):
            id = id_or_record
            if record is None:
                raise ValueError("Record required when ID is provided")
        else:
            record = id_or_record
            id = record.id
            if id is None:
                import uuid  # type: ignore[unreachable]

                id = str(uuid.uuid4())
                record.storage_id = id

        doc = self._record_to_doc(record)

        await self._client.index(
            index=self.index_name,
            id=id,
            document=doc,
            refresh=self.refresh,
        )

        return id

    async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using the bulk API.

        Uses AsyncElasticsearch's bulk API for efficient batch updates.

        Args:
            updates: List of (id, record) tuples to update

        Returns:
            List of success flags for each update
        """
        if not updates:
            return []

        self._check_connection()

        # Build bulk operations for AsyncElasticsearch
        operations: list[dict[str, Any]] = []
        for record_id, record in updates:
            # Add update operation
            operations.append({
                "update": {
                    "_index": self.index_name,
                    "_id": record_id,
                }
            })
            # Add document data
            doc = self._record_to_doc(record)
            operations.append({
                "doc": doc,
                "doc_as_upsert": False,  # Don't create if it doesn't exist
            })

        try:
            # Execute bulk update using AsyncElasticsearch
            response = await self._client.bulk(
                operations=operations,
                refresh=self.refresh,
            )

            # Process the response to determine which updates succeeded
            results = []
            if response.get("items"):
                for item in response["items"]:
                    if "update" in item:
                        update_result = item["update"]
                        # A successful update reports status 200; a missing
                        # document reports 404.
                        results.append(update_result.get("status") == 200)
                    else:
                        results.append(False)
            else:
                # If no items in response, mark all as failed
                results = [False] * len(updates)

            return results

        except Exception as e:
            # If the bulk operation fails, mark all as failed
            logger.error(f"Bulk update failed: {e}")
            return [False] * len(updates)

    async def search(self, query: Query) -> list[Record]:
        """Search for records matching the query."""
        self._check_connection()

        # Build Elasticsearch query
        es_query = {"bool": {"must": []}}

        for filter in query.filters:
            field_path = f"data.{filter.field}"

            if filter.operator == Operator.EQ:
                # For string values, use the keyword sub-field for exact matching
                if isinstance(filter.value, str):
                    field_path = f"{field_path}.keyword"
                es_query["bool"]["must"].append({"term": {field_path: filter.value}})
            elif filter.operator == Operator.NEQ:
                es_query["bool"]["must_not"] = es_query["bool"].get("must_not", [])
                es_query["bool"]["must_not"].append({"term": {field_path: filter.value}})
            elif filter.operator == Operator.GT:
                es_query["bool"]["must"].append({"range": {field_path: {"gt": filter.value}}})
            elif filter.operator == Operator.LT:
                es_query["bool"]["must"].append({"range": {field_path: {"lt": filter.value}}})
            elif filter.operator == Operator.GTE:
                es_query["bool"]["must"].append({"range": {field_path: {"gte": filter.value}}})
            elif filter.operator == Operator.LTE:
                es_query["bool"]["must"].append({"range": {field_path: {"lte": filter.value}}})
            elif filter.operator == Operator.LIKE:
                es_query["bool"]["must"].append({"wildcard": {field_path: f"*{filter.value}*"}})
            elif filter.operator == Operator.IN:
                es_query["bool"]["must"].append({"terms": {field_path: filter.value}})
            elif filter.operator == Operator.NOT_IN:
                es_query["bool"]["must_not"] = es_query["bool"].get("must_not", [])
                es_query["bool"]["must_not"].append({"terms": {field_path: filter.value}})
            elif filter.operator == Operator.BETWEEN:
                # Use Elasticsearch's native range query for efficient BETWEEN
                if isinstance(filter.value, (list, tuple)) and len(filter.value) == 2:
                    lower, upper = filter.value
                    es_query["bool"]["must"].append({
                        "range": {
                            field_path: {
                                "gte": lower,
                                "lte": upper,
                            }
                        }
                    })
            elif filter.operator == Operator.NOT_BETWEEN:
                # NOT BETWEEN using must_not with range
                if isinstance(filter.value, (list, tuple)) and len(filter.value) == 2:
                    lower, upper = filter.value
                    es_query["bool"]["must_not"] = es_query["bool"].get("must_not", [])
                    es_query["bool"]["must_not"].append({
                        "range": {
                            field_path: {
                                "gte": lower,
                                "lte": upper,
                            }
                        }
                    })

        # If no filters, use match_all
        if not es_query["bool"]["must"] and "must_not" not in es_query["bool"]:
            es_query = {"match_all": {}}

        # Build sort
        sort = []
        if query.sort_specs:
            for sort_spec in query.sort_specs:
                direction = "desc" if sort_spec.order == SortOrder.DESC else "asc"
                sort.append({f"data.{sort_spec.field}": {"order": direction}})

        # Size and from for pagination
        size = query.limit_value if query.limit_value else 10000
        from_param = query.offset_value if query.offset_value else 0

        # Execute search
        response = await self._client.search(
            index=self.index_name,
            query=es_query,
            sort=sort if sort else None,
            size=size,
            from_=from_param,
        )

        # Convert to records
        records = []
        for hit in response["hits"]["hits"]:
            record = self._doc_to_record(hit)

            # Apply field projection if specified
            if query.fields:
                record = record.project(query.fields)

            records.append(record)

        return records
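
    # Query usage sketch. The filter-construction API is assumed from ..query;
    # treat the builder names here as illustrative, not the actual interface:
    #
    #     q = (Query()
    #          .filter("age", Operator.GTE, 21)
    #          .filter("city", Operator.EQ, "Denver")
    #          .sort("age", SortOrder.DESC)
    #          .limit(50))
    #     records = await db.search(q)
    #
    # Note that EQ on string values targets the ".keyword" sub-field, so it is
    # an exact match rather than an analyzed full-text match.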

    async def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()

        response = await self._client.count(index=self.index_name)
        return response["count"]

    async def clear(self) -> int:
        """Clear all records from the database."""
        self._check_connection()

        # Get count before deletion
        count = await self._count_all()

        # Delete by query - delete all documents
        response = await self._client.delete_by_query(
            index=self.index_name,
            query={"match_all": {}},
            refresh=self.refresh,
        )

        return response.get("deleted", count)

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None,
    ) -> AsyncIterator[Record]:
        """Stream records from Elasticsearch using the scroll API."""
        self._check_connection()
        config = config or StreamConfig()

        # Build query
        es_query = {"match_all": {}}
        if query and query.filters:
            es_query = {"bool": {"must": []}}
            for filter in query.filters:
                field_path = f"data.{filter.field}"
                if filter.operator == Operator.EQ:
                    es_query["bool"]["must"].append({"term": {field_path: filter.value}})

        # Initial search with scroll
        response = await self._client.search(
            index=self.index_name,
            query=es_query,
            scroll="2m",
            size=config.batch_size,
        )

        scroll_id = response["_scroll_id"]
        hits = response["hits"]["hits"]

        try:
            while hits:
                for hit in hits:
                    record = self._doc_to_record(hit)
                    if query and query.fields:
                        record = record.project(query.fields)
                    yield record

                # Get next batch
                response = await self._client.scroll(
                    scroll_id=scroll_id,
                    scroll="2m",
                )
                hits = response["hits"]["hits"]
        finally:
            # Clear scroll
            await self._client.clear_scroll(scroll_id=scroll_id)
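
    # Streaming read sketch -- iterate the index without materializing it:
    #
    #     async for record in db.stream_read(config=StreamConfig(batch_size=500)):
    #         handle(record)  # hypothetical per-record consumer
    #
    # Each scroll context is kept alive for 2 minutes per batch and is cleared
    # in the finally block even if the consumer stops early.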

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None,
    ) -> StreamResult:
        """Stream records into Elasticsearch using the bulk API."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        batch = []
        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write batch with graceful fallback
                async def batch_func(b):
                    await self._write_batch(b)
                    return [r.id for r in b]

                continue_processing = await async_process_batch_with_fallback(
                    batch,
                    batch_func,
                    self.create,
                    result,
                    config,
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write remaining batch
        if batch and not quitting:
            async def batch_func(b):
                await self._write_batch(b)
                return [r.id for r in b]

            await async_process_batch_with_fallback(
                batch,
                batch_func,
                self.create,
                result,
                config,
            )

        result.duration = time.time() - start_time
        return result

    async def _write_batch(self, records: list[Record]) -> None:
        """Write a batch of records using the bulk API."""
        if not records:
            return

        # Build bulk operations
        operations = []
        for record in records:
            doc = self._record_to_doc(record)
            operations.append({"index": {"_index": self.index_name}})
            operations.append(doc)

        # Execute bulk
        await self._client.bulk(
            operations=operations,
            refresh=self.refresh,
        )
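
    # Streaming write sketch (the record source is a hypothetical stand-in):
    #
    #     async def source() -> AsyncIterator[Record]:
    #         for record in records:
    #             yield record
    #
    #     result = await db.stream_write(source(), StreamConfig(batch_size=1000))
    #
    # Note that async_process_batch_with_fallback receives self.create as the
    # per-record fallback for batches that fail as a whole.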

    async def vector_search(
        self,
        query_vector: np.ndarray | list[float],
        vector_field: str = "embedding",
        k: int = 10,
        metric: DistanceMetric = DistanceMetric.COSINE,
        filter: Query | None = None,
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> list[VectorSearchResult]:
        """Search for similar vectors using Elasticsearch KNN.

        Args:
            query_vector: The vector to search for
            vector_field: Name of the vector field to search
            k: Number of results to return
            metric: Distance metric to use
            filter: Optional query filter to apply before vector search
            include_source: Whether to include the source document in results
            score_threshold: Optional minimum similarity score

        Returns:
            List of search results ordered by similarity
        """
        self._check_connection()

        # Import vector utilities
        from ..vector.elasticsearch_utils import (
            build_knn_query,
        )

        # Build filter query if provided
        filter_query = self._build_filter_query(filter) if filter else None

        # Build KNN query
        query = build_knn_query(
            query_vector=query_vector,
            field_name=vector_field,
            k=k,
            filter_query=filter_query,
        )

        # Execute search
        try:
            response = await self._client.search(
                index=self.index_name,
                **query,  # Unpack the query dict directly
                size=k,
                _source=include_source,
            )
        except Exception as e:
            self._handle_elasticsearch_error(e, "vector search")
            return []

        # Process results
        results = []
        for hit in response.get("hits", {}).get("hits", []):
            score = hit.get("_score", 0.0)

            # Apply score threshold if specified
            if score_threshold is not None and score < score_threshold:
                continue

            # Convert document to record if source included
            record = None
            if include_source:
                record = self._doc_to_record(hit)

            # Set the storage ID on the record if we have one
            if record and not record.has_storage_id():
                record.storage_id = hit["_id"]

            # Skip if no record (shouldn't happen if include_source is True)
            if record is None:
                continue

            results.append(VectorSearchResult(
                record=record,
                score=score,
                vector_field=vector_field,
                metadata={
                    "index": self.index_name,
                    "metric": metric.value,
                    "doc_id": hit["_id"],
                },
            ))

        return results
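
    # KNN search sketch (the embedder is a hypothetical stand-in; dimensions
    # must match the mapped vector field):
    #
    #     vec = embed("quarterly revenue summary")
    #     hits = await db.vector_search(vec, vector_field="embedding", k=5,
    #                                   score_threshold=0.7)
    #     for hit in hits:
    #         print(hit.score, hit.metadata["doc_id"])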

    async def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Any | None = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records.

        Args:
            records: Records to process
            text_field: Field name(s) containing text to embed
            vector_field: Field name to store vectors in
            embedding_fn: Function to generate embeddings
            batch_size: Number of records to process at once
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            List of record IDs that were processed
        """
        # This is a stub implementation; a full implementation would require
        # an actual embedding function.
        logger.warning("bulk_embed_and_store is not fully implemented for Elasticsearch")
        return []

    async def create_vector_index(
        self,
        vector_field: str = "embedding",
        dimensions: int | None = None,
        metric: DistanceMetric = DistanceMetric.COSINE,
        index_type: str = "auto",
        **kwargs: Any,
    ) -> bool:
        """Create or update the index mapping for a vector field.

        Args:
            vector_field: Name of the vector field to index
            dimensions: Number of dimensions
            metric: Distance metric for the index
            index_type: Type of index (ignored for ES, which always uses HNSW)
            **kwargs: Additional index parameters

        Returns:
            True if the index was created/updated successfully
        """
        self._check_connection()

        if not dimensions:
            if vector_field not in self.vector_fields:
                raise ValueError(f"Unknown dimensions for field '{vector_field}'")
            dimensions = self.vector_fields[vector_field]

        # Import vector utilities
        from ..vector.elasticsearch_utils import (
            get_similarity_for_metric,
            get_vector_mapping,
        )

        # Get similarity function for metric
        similarity = get_similarity_for_metric(metric)

        # Build mapping for the vector field
        mapping = get_vector_mapping(dimensions, similarity)

        # Update index mapping
        try:
            await self._client.indices.put_mapping(
                index=self.index_name,
                properties={
                    f"data.{vector_field}": mapping,
                },
            )

            # Track the vector field
            self.vector_fields[vector_field] = dimensions
            self.vector_enabled = True

            logger.info(f"Created vector mapping for field '{vector_field}' with {dimensions} dimensions")
            return True

        except Exception as e:
            self._handle_elasticsearch_error(e, "create vector index")
            return False
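
    # The mapping sent by put_mapping above is built by get_vector_mapping; a
    # standard Elasticsearch dense_vector mapping looks like the following
    # (illustrative -- the exact shape is owned by ..vector.elasticsearch_utils):
    #
    #     {
    #         "type": "dense_vector",
    #         "dims": 768,
    #         "index": True,
    #         "similarity": "cosine",
    #     }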

    async def _supports_native_hybrid(self) -> bool:
        """Check if this Elasticsearch backend supports native hybrid search.

        Elasticsearch 8.x supports native RRF hybrid search.

        Returns:
            True, since Elasticsearch supports native hybrid search
        """
        return True

    async def hybrid_search(
        self,
        query_text: str,
        query_vector: np.ndarray | list[float],
        text_fields: list[str] | None = None,
        vector_field: str = "embedding",
        k: int = 10,
        config: Any = None,  # HybridSearchConfig
        filter: Query | None = None,
        metric: DistanceMetric = DistanceMetric.COSINE,
    ) -> list[Any]:  # list[HybridSearchResult]
        """Perform native Elasticsearch hybrid search using RRF.

        Uses Elasticsearch's native RRF (Reciprocal Rank Fusion) to combine
        BM25 text search with KNN vector search. This is more efficient than
        client-side fusion because it executes in a single request.

        Args:
            query_text: Text query for BM25 matching
            query_vector: Vector for KNN similarity search
            text_fields: Fields to search for text matching
            vector_field: Name of the vector field to search
            k: Number of results to return
            config: Hybrid search configuration (weights, fusion strategy)
            filter: Optional additional filters to apply
            metric: Distance metric for vector search

        Returns:
            List of HybridSearchResult ordered by RRF score (descending)
        """
        from ..vector.hybrid import (
            FusionStrategy,
            HybridSearchConfig,
            HybridSearchResult,
        )

        self._check_connection()

        config = config or HybridSearchConfig()

        # If not using the native strategy, fall back to the mixin implementation
        if config.fusion_strategy != FusionStrategy.NATIVE:
            return await VectorOperationsMixin.hybrid_search(
                self,
                query_text=query_text,
                query_vector=query_vector,
                text_fields=text_fields,
                vector_field=vector_field,
                k=k,
                config=config,
                filter=filter,
                metric=metric,
            )

        # Use config.text_fields if provided, otherwise the parameter
        search_text_fields = config.text_fields or text_fields or ["content", "title", "text"]

        # Build filter query if provided
        filter_query = self._build_filter_query(filter) if filter else None

        # Build text search query with multi_match
        text_query: dict[str, Any] = {
            "multi_match": {
                "query": query_text,
                "fields": [f"data.{f}" for f in search_text_fields],
                "type": "best_fields",
                "operator": "or",
            }
        }

        # Build RRF query combining both searches.
        # Note: RRF requires Elasticsearch 8.8+ with an appropriate license;
        # for older versions we fall back to client-side fusion below.
        try:
            # Try native RRF (ES 8.8+)
            body: dict[str, Any] = {
                "retriever": {
                    "rrf": {
                        "retrievers": [
                            {
                                "standard": {
                                    "query": text_query,
                                }
                            },
                            {
                                "knn": {
                                    "field": f"data.{vector_field}",
                                    "query_vector": query_vector.tolist() if hasattr(query_vector, "tolist") else list(query_vector),
                                    "k": k,
                                    "num_candidates": k * 3,
                                }
                            },
                        ],
                        "rank_constant": config.rrf_k,
                        "rank_window_size": k * 3,
                    }
                },
                "size": k,
            }

            if filter_query:
                body["post_filter"] = filter_query

            response = await self._client.search(
                index=self.index_name,
                body=body,
            )
        except Exception as e:
            # Fall back to client-side fusion if native RRF is not available
            logger.warning(f"Native RRF not available ({e}), falling back to client-side fusion")
            return await VectorOperationsMixin.hybrid_search(
                self,
                query_text=query_text,
                query_vector=query_vector,
                text_fields=text_fields,
                vector_field=vector_field,
                k=k,
                config=HybridSearchConfig(
                    text_weight=config.text_weight,
                    vector_weight=config.vector_weight,
                    fusion_strategy=FusionStrategy.RRF,
                    rrf_k=config.rrf_k,
                    text_fields=config.text_fields,
                ),
                filter=filter,
                metric=metric,
            )

        # Process results
        results: list[HybridSearchResult] = []
        hits = response.get("hits", {}).get("hits", [])

        for i, hit in enumerate(hits):
            record = self._doc_to_record(hit)
            if record:
                if not record.has_storage_id():
                    record.storage_id = hit["_id"]

                # RRF doesn't expose the individual scores, only the fused score
                combined_score = hit.get("_score", 1.0 / (config.rrf_k + i + 1))

                results.append(HybridSearchResult(
                    record=record,
                    combined_score=combined_score,
                    text_score=None,  # Not available with native RRF
                    vector_score=None,  # Not available with native RRF
                    text_rank=None,
                    vector_rank=None,
                    metadata={
                        "fusion_strategy": "native_rrf",
                        "index": self.index_name,
                        "doc_id": hit["_id"],
                    },
                ))

        return results
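
    # Hybrid search sketch (the embedder is a hypothetical stand-in):
    #
    #     results = await db.hybrid_search(
    #         query_text="async connection pooling",
    #         query_vector=embed("async connection pooling"),
    #         text_fields=["title", "content"],
    #         k=10,
    #     )
    #
    # With FusionStrategy.NATIVE the RRF fusion runs server-side in a single
    # request; any other strategy routes through VectorOperationsMixin, and a
    # runtime failure of native RRF falls back to client-side RRF fusion.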

    async def _text_search_for_hybrid(
        self,
        query_text: str,
        text_fields: list[str] | None,
        k: int,
        filter: Query | None = None,
    ) -> list[tuple[Record, float]]:
        """Perform BM25 text search for hybrid search fusion.

        Uses Elasticsearch's native BM25 scoring for text relevance.

        Args:
            query_text: Text to search for
            text_fields: Fields to search in
            k: Maximum results to return
            filter: Additional filters

        Returns:
            List of (record, score) tuples ordered by BM25 relevance
        """
        self._check_connection()

        search_fields = text_fields or ["content", "title", "text"]

        # Build multi_match query
        query: dict[str, Any] = {
            "multi_match": {
                "query": query_text,
                "fields": [f"data.{f}" for f in search_fields],
                "type": "best_fields",
                "operator": "or",
            }
        }

        # Build filter if provided
        filter_query = self._build_filter_query(filter) if filter else None

        body: dict[str, Any] = {
            "query": query if not filter_query else {
                "bool": {
                    "must": query,
                    "filter": filter_query,
                }
            },
            "size": k,
        }

        try:
            response = await self._client.search(
                index=self.index_name,
                body=body,
            )
        except Exception as e:
            self._handle_elasticsearch_error(e, "text search for hybrid")
            return []

        # Process results
        results: list[tuple[Record, float]] = []
        hits = response.get("hits", {}).get("hits", [])
        max_score = response.get("hits", {}).get("max_score", 1.0) or 1.0

        for hit in hits:
            record = self._doc_to_record(hit)
            if record:
                if not record.has_storage_id():
                    record.storage_id = hit["_id"]

                # Normalize the BM25 score to the 0-1 range
                score = hit.get("_score", 0.0) / max_score if max_score > 0 else 0.0
                results.append((record, score))

        return results