Coverage for src/dataknobs_data/backends/elasticsearch_async.py: 12%
381 statements
coverage.py v7.13.0, created at 2025-12-26 16:34 -0700
1"""Native async Elasticsearch backend implementation with connection pooling."""
3from __future__ import annotations
5import logging
6import time
7from typing import TYPE_CHECKING, Any, cast
9from dataknobs_config import ConfigurableBase
11from ..database import AsyncDatabase
12from ..pooling import ConnectionPoolManager
13from ..pooling.elasticsearch import (
14 ElasticsearchPoolConfig,
15 close_elasticsearch_client,
16 create_async_elasticsearch_client,
17 validate_elasticsearch_client,
18)
19from ..query import Operator, Query, SortOrder
20from ..streaming import StreamConfig, StreamResult, async_process_batch_with_fallback
21from ..vector.mixins import VectorOperationsMixin
22from ..vector.types import DistanceMetric, VectorSearchResult
23from .elasticsearch_mixins import (
24 ElasticsearchBaseConfig,
25 ElasticsearchErrorHandler,
26 ElasticsearchIndexManager,
27 ElasticsearchQueryBuilder,
28 ElasticsearchRecordSerializer,
29 ElasticsearchVectorSupport,
30)
32if TYPE_CHECKING:
33 import numpy as np
34 from collections.abc import AsyncIterator, Callable, Awaitable
35 from ..records import Record
37logger = logging.getLogger(__name__)
39# Global pool manager for Elasticsearch clients
40_client_manager = ConnectionPoolManager()
class AsyncElasticsearchDatabase(
    AsyncDatabase,
    ConfigurableBase,
    VectorOperationsMixin,
    ElasticsearchBaseConfig,
    ElasticsearchIndexManager,
    ElasticsearchVectorSupport,
    ElasticsearchErrorHandler,
    ElasticsearchRecordSerializer,
    ElasticsearchQueryBuilder,
):
    """Native async Elasticsearch database backend with connection pooling."""
    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async Elasticsearch database."""
        super().__init__(config)

        # Initialize vector support
        self.vector_fields = {}  # field_name -> dimensions
        self.vector_enabled = False

        config = config or {}
        self._pool_config = ElasticsearchPoolConfig.from_dict(config)
        self.index_name = self._pool_config.index
        self.refresh = config.get("refresh", True)
        self._client = None
        self._connected = False
    @classmethod
    def from_config(cls, config: dict) -> AsyncElasticsearchDatabase:
        """Create from config dictionary."""
        return cls(config)
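
    # Usage sketch (illustrative, not part of this module): a minimal
    # connect/use/close lifecycle. "index" and "refresh" are keys this class
    # reads; "hosts" is a hedged assumption about ElasticsearchPoolConfig.
    #
    #   import asyncio
    #
    #   async def main() -> None:
    #       db = AsyncElasticsearchDatabase.from_config({
    #           "hosts": ["http://localhost:9200"],
    #           "index": "docs",
    #           "refresh": True,
    #       })
    #       await db.connect()
    #       try:
    #           ...  # CRUD, search, and vector calls go here
    #       finally:
    #           await db.close()
    #
    #   asyncio.run(main())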
    async def connect(self) -> None:
        """Connect to the Elasticsearch database."""
        if self._connected:
            return

        # Get or create a client for the current event loop
        from ..pooling import BasePoolConfig

        self._client = await _client_manager.get_pool(
            self._pool_config,
            cast("Callable[[BasePoolConfig], Awaitable[Any]]", create_async_elasticsearch_client),
            validate_elasticsearch_client,
            close_elasticsearch_client,
        )

        # Ensure index exists
        await self._ensure_index()
        self._connected = True
    async def close(self) -> None:
        """Close the database connection."""
        if self._connected:
            # Note: the client is managed by the pool manager, so we don't
            # close it here; just drop the reference and mark as disconnected.
            self._client = None
            self._connected = False

    def _initialize(self) -> None:
        """Initialization is handled in connect()."""
    async def _ensure_index(self) -> None:
        """Ensure the index exists with proper mappings."""
        if not self._client:
            raise RuntimeError("Database not connected. Call connect() first.")

        # Check if index exists
        if not await self._client.indices.exists(index=self.index_name):  # type: ignore[unreachable]
            # Get mappings with vector field support
            mappings = self.get_index_mappings(self.vector_fields)

            # Get settings optimized for KNN if we have vector fields
            settings = self.get_knn_index_settings() if self.vector_fields else {
                "number_of_shards": 1,
                "number_of_replicas": 0,
            }

            await self._client.indices.create(
                index=self.index_name,
                mappings=mappings,
                settings=settings,
            )

            if self.vector_fields:
                self.vector_enabled = True
                logger.info(f"Created index '{self.index_name}' with vector support")
    def _check_connection(self) -> None:
        """Check if database is connected."""
        if not self._connected or not self._client:
            raise RuntimeError("Database not connected. Call connect() first.")
    def _record_to_doc(self, record: Record) -> dict[str, Any]:
        """Convert a Record to an Elasticsearch document."""
        # Update vector tracking if needed
        if self._has_vector_fields(record):
            self._update_vector_tracking(record)

        # Add vector field metadata to record metadata
        if "vector_fields" not in record.metadata:
            record.metadata["vector_fields"] = {}

        for field_name in self.vector_fields:
            if field_name in record.fields:
                field = record.fields[field_name]
                if hasattr(field, "source_field"):
                    record.metadata["vector_fields"][field_name] = {
                        "type": "vector",
                        "dimensions": self.vector_fields[field_name],
                        "source_field": field.source_field,
                        "model": getattr(field, "model_name", None),
                        "model_version": getattr(field, "model_version", None),
                    }

        return self._record_to_document(record)
    def _doc_to_record(self, doc: dict[str, Any]) -> Record:
        """Convert an Elasticsearch document to a Record."""
        doc_id = doc.get("_id")
        record = self._document_to_record(doc, doc_id)

        # Add score if present
        if "_score" in doc:
            record.metadata["_score"] = doc["_score"]

        return record
    async def create(self, record: Record) -> str:
        """Create a new record."""
        self._check_connection()
        doc = self._record_to_doc(record)

        # Create document with explicit ID if record has one
        kwargs = {
            "index": self.index_name,
            "document": doc,
            "refresh": self.refresh,
        }
        if record.id:
            kwargs["id"] = record.id

        response = await self._client.index(**kwargs)

        return response["_id"]
    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records in batch."""
        self._check_connection()

        ids = []
        operations = []

        for record in records:
            doc = self._record_to_doc(record)
            operations.append({"index": {"_index": self.index_name}})
            operations.append(doc)

        if operations:
            response = await self._client.bulk(
                operations=operations,
                refresh=self.refresh,
            )

            # Extract IDs from response
            for item in response.get("items", []):
                if "index" in item and "_id" in item["index"]:
                    ids.append(item["index"]["_id"])

        return ids
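
    # Sketch of single vs. batch creation (hedged: the Record construction
    # below is illustrative; the real constructor lives in ..records).
    #
    #   rec_id = await db.create(Record(fields={"name": "ada"}))
    #   ids = await db.create_batch([Record(fields={"n": i}) for i in range(500)])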
    async def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        self._check_connection()

        try:
            response = await self._client.get(
                index=self.index_name,
                id=id,
            )
            return self._doc_to_record(response)
        except Exception as e:
            # Log the error for debugging
            logger.debug(f"Error reading document {id}: {e}")
            return None
    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record."""
        self._check_connection()
        doc = self._record_to_doc(record)

        try:
            await self._client.update(
                index=self.index_name,
                id=id,
                doc=doc,
                refresh=self.refresh,
            )
            return True
        except Exception:
            return False
    async def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        self._check_connection()

        try:
            await self._client.delete(
                index=self.index_name,
                id=id,
                refresh=self.refresh,
            )
            return True
        except Exception:
            return False
    async def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()

        return await self._client.exists(
            index=self.index_name,
            id=id,
        )
    async def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
        """Update or insert a record.

        Can be called as:
        - upsert(id, record) - explicit ID and record
        - upsert(record) - extract ID from record using Record's built-in logic
        """
        self._check_connection()

        # Determine ID and record based on arguments
        if isinstance(id_or_record, str):
            id = id_or_record
            if record is None:
                raise ValueError("Record required when ID is provided")
        else:
            record = id_or_record
            id = record.id
            if id is None:
                import uuid  # type: ignore[unreachable]

                id = str(uuid.uuid4())
                record.storage_id = id

        doc = self._record_to_doc(record)

        await self._client.index(
            index=self.index_name,
            id=id,
            document=doc,
            refresh=self.refresh,
        )

        return id
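
    # The two upsert call shapes documented above, as a short sketch:
    #
    #   doc_id = await db.upsert("user-42", record)  # explicit ID
    #   doc_id = await db.upsert(record)             # ID from record.id, or a
    #                                                # fresh uuid4 if it is None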
    async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using the bulk API.

        Uses AsyncElasticsearch's bulk API for efficient batch updates.

        Args:
            updates: List of (id, record) tuples to update

        Returns:
            List of success flags for each update
        """
        if not updates:
            return []

        self._check_connection()

        # Build bulk operations for AsyncElasticsearch
        operations: list[dict[str, Any]] = []
        for record_id, record in updates:
            # Add update operation
            operations.append({
                "update": {
                    "_index": self.index_name,
                    "_id": record_id,
                }
            })
            # Add document data
            doc = self._record_to_doc(record)
            operations.append({
                "doc": doc,
                "doc_as_upsert": False,  # Don't create if it doesn't exist
            })

        try:
            # Execute bulk update using AsyncElasticsearch
            response = await self._client.bulk(
                operations=operations,
                refresh=self.refresh,
            )

            # Process the response to determine which updates succeeded:
            # HTTP 200 means the update was applied; 404 (document missing)
            # and any other status count as failure.
            results = []
            if response.get("items"):
                for item in response["items"]:
                    if "update" in item:
                        update_result = item["update"]
                        results.append(update_result.get("status") == 200)
                    else:
                        results.append(False)
            else:
                # If no items in response, mark all as failed
                results = [False] * len(updates)

            return results

        except Exception as e:
            # If the bulk operation itself fails, mark all updates as failed
            logger.error(f"Bulk update failed: {e}")
            return [False] * len(updates)
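
    # Bulk-update sketch: pair each document ID with its replacement Record
    # and inspect the per-item flags (example data is hypothetical).
    #
    #   updates = [("id-1", rec1), ("id-2", rec2)]
    #   flags = await db.update_batch(updates)
    #   failed_ids = [doc_id for (doc_id, _), ok in zip(updates, flags) if not ok]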
    async def search(self, query: Query) -> list[Record]:
        """Search for records matching the query."""
        self._check_connection()

        # Build Elasticsearch query
        es_query = {"bool": {"must": []}}

        for filter in query.filters:
            field_path = f"data.{filter.field}"

            if filter.operator == Operator.EQ:
                # For string values, use the keyword sub-field for exact matching
                if isinstance(filter.value, str):
                    field_path = f"{field_path}.keyword"
                es_query["bool"]["must"].append({"term": {field_path: filter.value}})
            elif filter.operator == Operator.NEQ:
                # Mirror EQ: exact string comparison also uses the keyword sub-field
                if isinstance(filter.value, str):
                    field_path = f"{field_path}.keyword"
                es_query["bool"].setdefault("must_not", []).append({"term": {field_path: filter.value}})
            elif filter.operator == Operator.GT:
                es_query["bool"]["must"].append({"range": {field_path: {"gt": filter.value}}})
            elif filter.operator == Operator.LT:
                es_query["bool"]["must"].append({"range": {field_path: {"lt": filter.value}}})
            elif filter.operator == Operator.GTE:
                es_query["bool"]["must"].append({"range": {field_path: {"gte": filter.value}}})
            elif filter.operator == Operator.LTE:
                es_query["bool"]["must"].append({"range": {field_path: {"lte": filter.value}}})
            elif filter.operator == Operator.LIKE:
                es_query["bool"]["must"].append({"wildcard": {field_path: f"*{filter.value}*"}})
            elif filter.operator == Operator.IN:
                es_query["bool"]["must"].append({"terms": {field_path: filter.value}})
            elif filter.operator == Operator.NOT_IN:
                es_query["bool"].setdefault("must_not", []).append({"terms": {field_path: filter.value}})
            elif filter.operator == Operator.BETWEEN:
                # Use Elasticsearch's native range query for an efficient BETWEEN
                if isinstance(filter.value, (list, tuple)) and len(filter.value) == 2:
                    lower, upper = filter.value
                    es_query["bool"]["must"].append({
                        "range": {
                            field_path: {
                                "gte": lower,
                                "lte": upper,
                            }
                        }
                    })
            elif filter.operator == Operator.NOT_BETWEEN:
                # NOT BETWEEN using must_not with a range query
                if isinstance(filter.value, (list, tuple)) and len(filter.value) == 2:
                    lower, upper = filter.value
                    es_query["bool"].setdefault("must_not", []).append({
                        "range": {
                            field_path: {
                                "gte": lower,
                                "lte": upper,
                            }
                        }
                    })

        # If no filters, use match_all
        if not es_query["bool"]["must"] and "must_not" not in es_query["bool"]:
            es_query = {"match_all": {}}

        # Build sort
        sort = []
        if query.sort_specs:
            for sort_spec in query.sort_specs:
                direction = "desc" if sort_spec.order == SortOrder.DESC else "asc"
                sort.append({f"data.{sort_spec.field}": {"order": direction}})

        # Size and offset for pagination
        size = query.limit_value if query.limit_value else 10000
        from_param = query.offset_value if query.offset_value else 0

        # Execute search
        response = await self._client.search(
            index=self.index_name,
            query=es_query,
            sort=sort if sort else None,
            size=size,
            from_=from_param,
        )

        # Convert to records
        records = []
        for hit in response["hits"]["hits"]:
            record = self._doc_to_record(hit)

            # Apply field projection if specified
            if query.fields:
                record = record.project(query.fields)

            records.append(record)

        return records
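
    # Filter-translation sketch (hedged: the fluent Query construction is
    # illustrative; this method only relies on query.filters, sort_specs,
    # limit_value, offset_value, and fields). A query such as
    #
    #   q = Query().filter("age", Operator.GTE, 21).sort("name", SortOrder.ASC).limit(10)
    #
    # would translate to roughly:
    #
    #   {"bool": {"must": [{"range": {"data.age": {"gte": 21}}}]}}
    #   with sort=[{"data.name": {"order": "asc"}}], size=10, from_=0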
    async def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()

        response = await self._client.count(index=self.index_name)
        return response["count"]

    async def clear(self) -> int:
        """Clear all records from the database."""
        self._check_connection()

        # Get count before deletion
        count = await self._count_all()

        # Delete all documents via delete-by-query
        response = await self._client.delete_by_query(
            index=self.index_name,
            query={"match_all": {}},
            refresh=self.refresh,
        )

        return response.get("deleted", count)
    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None,
    ) -> AsyncIterator[Record]:
        """Stream records from Elasticsearch using the scroll API."""
        self._check_connection()
        config = config or StreamConfig()

        # Build query (only EQ filters are supported on the streaming path)
        es_query = {"match_all": {}}
        if query and query.filters:
            es_query = {"bool": {"must": []}}
            for filter in query.filters:
                field_path = f"data.{filter.field}"
                if filter.operator == Operator.EQ:
                    # Mirror search(): exact string matching uses the keyword sub-field
                    if isinstance(filter.value, str):
                        field_path = f"{field_path}.keyword"
                    es_query["bool"]["must"].append({"term": {field_path: filter.value}})

        # Initial search with scroll
        response = await self._client.search(
            index=self.index_name,
            query=es_query,
            scroll="2m",
            size=config.batch_size,
        )

        scroll_id = response["_scroll_id"]
        hits = response["hits"]["hits"]

        try:
            while hits:
                for hit in hits:
                    record = self._doc_to_record(hit)
                    if query and query.fields:
                        record = record.project(query.fields)
                    yield record

                # Get next batch
                response = await self._client.scroll(
                    scroll_id=scroll_id,
                    scroll="2m",
                )
                hits = response["hits"]["hits"]
        finally:
            # Clear scroll
            await self._client.clear_scroll(scroll_id=scroll_id)
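
    # Streaming-read sketch: records arrive lazily in scroll batches of
    # config.batch_size (the consumer function is hypothetical).
    #
    #   async for record in db.stream_read(config=StreamConfig(batch_size=1000)):
    #       process(record)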
    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None,
    ) -> StreamResult:
        """Stream records into Elasticsearch using the bulk API."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        async def batch_func(b: list[Record]) -> list[str]:
            # Bulk-write the whole batch and report the IDs that were written
            await self._write_batch(b)
            return [r.id for r in b]

        batch = []
        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write batch with graceful fallback
                continue_processing = await async_process_batch_with_fallback(
                    batch,
                    batch_func,
                    self.create,
                    result,
                    config,
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write remaining batch
        if batch and not quitting:
            await async_process_batch_with_fallback(
                batch,
                batch_func,
                self.create,
                result,
                config,
            )

        result.duration = time.time() - start_time
        return result
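
    # Streaming-write sketch: any async iterator of Records works as a source
    # (make_record is a hypothetical factory).
    #
    #   async def gen():
    #       for i in range(10_000):
    #           yield make_record(i)
    #
    #   result = await db.stream_write(gen(), StreamConfig(batch_size=500))
    #   print(f"wrote in {result.duration:.1f}s")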
    async def _write_batch(self, records: list[Record]) -> None:
        """Write a batch of records using the bulk API."""
        if not records:
            return

        # Build bulk operations
        operations = []
        for record in records:
            doc = self._record_to_doc(record)
            operations.append({"index": {"_index": self.index_name}})
            operations.append(doc)

        # Execute bulk
        await self._client.bulk(
            operations=operations,
            refresh=self.refresh,
        )
    async def vector_search(
        self,
        query_vector: np.ndarray | list[float],
        vector_field: str = "embedding",
        k: int = 10,
        metric: DistanceMetric = DistanceMetric.COSINE,
        filter: Query | None = None,
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> list[VectorSearchResult]:
        """Search for similar vectors using Elasticsearch KNN.

        Args:
            query_vector: The vector to search for
            vector_field: Name of the vector field to search
            k: Number of results to return
            metric: Distance metric to use
            filter: Optional query filter to apply before vector search
            include_source: Whether to include source document in results
            score_threshold: Optional minimum similarity score

        Returns:
            List of search results ordered by similarity
        """
        self._check_connection()

        # Import vector utilities
        from ..vector.elasticsearch_utils import build_knn_query

        # Build filter query if provided
        filter_query = self._build_filter_query(filter) if filter else None

        # Build KNN query
        query = build_knn_query(
            query_vector=query_vector,
            field_name=vector_field,
            k=k,
            filter_query=filter_query,
        )

        # Execute search
        try:
            response = await self._client.search(
                index=self.index_name,
                **query,  # Unpack the query dict directly
                size=k,
                _source=include_source,
            )
        except Exception as e:
            self._handle_elasticsearch_error(e, "vector search")
            return []

        # Process results
        results = []
        for hit in response.get("hits", {}).get("hits", []):
            score = hit.get("_score", 0.0)

            # Apply score threshold if specified
            if score_threshold is not None and score < score_threshold:
                continue

            # Convert document to record if source included
            record = None
            if include_source:
                record = self._doc_to_record(hit)

                # Set the storage ID on the record if we have one
                if record and not record.has_storage_id():
                    record.storage_id = hit["_id"]

            # Skip if no record (shouldn't happen if include_source is True)
            if record is None:
                continue

            results.append(VectorSearchResult(
                record=record,
                score=score,
                vector_field=vector_field,
                metadata={
                    "index": self.index_name,
                    "metric": metric.value,
                    "doc_id": hit["_id"],
                },
            ))

        return results
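
    # KNN search sketch (the 384-dim query vector and field name are
    # illustrative, not requirements of this method):
    #
    #   hits = await db.vector_search(
    #       query_vector=[0.1] * 384,
    #       vector_field="embedding",
    #       k=5,
    #       score_threshold=0.7,
    #   )
    #   for hit in hits:
    #       print(hit.score, hit.metadata["doc_id"])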
    async def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Any | None = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records.

        Args:
            records: Records to process
            text_field: Field name(s) containing text to embed
            vector_field: Field name to store vectors in
            embedding_fn: Function to generate embeddings
            batch_size: Number of records to process at once
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            List of record IDs that were processed
        """
        # This is a stub implementation; a full implementation would require
        # an actual embedding function.
        logger.warning("bulk_embed_and_store is not fully implemented for Elasticsearch")
        return []
    async def create_vector_index(
        self,
        vector_field: str = "embedding",
        dimensions: int | None = None,
        metric: DistanceMetric = DistanceMetric.COSINE,
        index_type: str = "auto",
        **kwargs: Any,
    ) -> bool:
        """Create or update index mapping for a vector field.

        Args:
            vector_field: Name of the vector field to index
            dimensions: Number of dimensions
            metric: Distance metric for the index
            index_type: Type of index (ignored for ES, which always uses HNSW)
            **kwargs: Additional index parameters

        Returns:
            True if the index was created/updated successfully
        """
        self._check_connection()

        if not dimensions:
            if vector_field not in self.vector_fields:
                raise ValueError(f"Unknown dimensions for field '{vector_field}'")
            dimensions = self.vector_fields[vector_field]

        # Import vector utilities
        from ..vector.elasticsearch_utils import (
            get_similarity_for_metric,
            get_vector_mapping,
        )

        # Get similarity function for metric
        similarity = get_similarity_for_metric(metric)

        # Build mapping for the vector field
        mapping = get_vector_mapping(dimensions, similarity)

        # Update index mapping
        try:
            await self._client.indices.put_mapping(
                index=self.index_name,
                properties={
                    f"data.{vector_field}": mapping,
                },
            )

            # Track the vector field
            self.vector_fields[vector_field] = dimensions
            self.vector_enabled = True

            logger.info(f"Created vector mapping for field '{vector_field}' with {dimensions} dimensions")
            return True

        except Exception as e:
            self._handle_elasticsearch_error(e, "create vector index")
            return False
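
    # Mapping-creation sketch: register a cosine-similarity vector field
    # before indexing vectors (the 768 dimensions are illustrative).
    #
    #   ok = await db.create_vector_index(
    #       vector_field="embedding",
    #       dimensions=768,
    #       metric=DistanceMetric.COSINE,
    #   )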
    async def _supports_native_hybrid(self) -> bool:
        """Check if this Elasticsearch backend supports native hybrid search.

        Elasticsearch 8.x supports native RRF hybrid search.

        Returns:
            True, since Elasticsearch supports native hybrid search
        """
        return True
    async def hybrid_search(
        self,
        query_text: str,
        query_vector: np.ndarray | list[float],
        text_fields: list[str] | None = None,
        vector_field: str = "embedding",
        k: int = 10,
        config: Any = None,  # HybridSearchConfig
        filter: Query | None = None,
        metric: DistanceMetric = DistanceMetric.COSINE,
    ) -> list[Any]:  # list[HybridSearchResult]
        """Perform native Elasticsearch hybrid search using RRF.

        Uses Elasticsearch's native RRF (Reciprocal Rank Fusion) for combining
        BM25 text search with KNN vector search. This is more efficient than
        client-side fusion because it executes in a single request.

        Args:
            query_text: Text query for BM25 matching
            query_vector: Vector for KNN similarity search
            text_fields: Fields to search for text matching
            vector_field: Name of the vector field to search
            k: Number of results to return
            config: Hybrid search configuration (weights, fusion strategy)
            filter: Optional additional filters to apply
            metric: Distance metric for vector search

        Returns:
            List of HybridSearchResult ordered by RRF score (descending)
        """
        from ..vector.hybrid import (
            FusionStrategy,
            HybridSearchConfig,
            HybridSearchResult,
        )

        self._check_connection()

        config = config or HybridSearchConfig()

        # If not using the native strategy, fall back to the mixin implementation
        if config.fusion_strategy != FusionStrategy.NATIVE:
            return await VectorOperationsMixin.hybrid_search(
                self,
                query_text=query_text,
                query_vector=query_vector,
                text_fields=text_fields,
                vector_field=vector_field,
                k=k,
                config=config,
                filter=filter,
                metric=metric,
            )

        # Use config.text_fields if provided, otherwise the parameter
        search_text_fields = config.text_fields or text_fields or ["content", "title", "text"]

        # Build filter query if provided
        filter_query = self._build_filter_query(filter) if filter else None

        # Build text search query with multi_match
        text_query: dict[str, Any] = {
            "multi_match": {
                "query": query_text,
                "fields": [f"data.{f}" for f in search_text_fields],
                "type": "best_fields",
                "operator": "or",
            }
        }

        # Build the RRF retriever combining both searches.
        # Note: the RRF retriever requires Elasticsearch 8.8+ with an
        # appropriate license; older versions raise and trigger the fallback.
        try:
            body: dict[str, Any] = {
                "retriever": {
                    "rrf": {
                        "retrievers": [
                            {
                                "standard": {
                                    "query": text_query,
                                }
                            },
                            {
                                "knn": {
                                    "field": f"data.{vector_field}",
                                    "query_vector": query_vector.tolist() if hasattr(query_vector, "tolist") else list(query_vector),
                                    "k": k,
                                    "num_candidates": k * 3,
                                }
                            },
                        ],
                        "rank_constant": config.rrf_k,
                        "rank_window_size": k * 3,
                    }
                },
                "size": k,
            }

            if filter_query:
                body["post_filter"] = filter_query

            response = await self._client.search(
                index=self.index_name,
                body=body,
            )
        except Exception as e:
            # Fall back to client-side fusion if native RRF is not available
            logger.warning(f"Native RRF not available ({e}), falling back to client-side fusion")
            return await VectorOperationsMixin.hybrid_search(
                self,
                query_text=query_text,
                query_vector=query_vector,
                text_fields=text_fields,
                vector_field=vector_field,
                k=k,
                config=HybridSearchConfig(
                    text_weight=config.text_weight,
                    vector_weight=config.vector_weight,
                    fusion_strategy=FusionStrategy.RRF,
                    rrf_k=config.rrf_k,
                    text_fields=config.text_fields,
                ),
                filter=filter,
                metric=metric,
            )

        # Process results
        results: list[HybridSearchResult] = []
        hits = response.get("hits", {}).get("hits", [])

        for i, hit in enumerate(hits):
            record = self._doc_to_record(hit)
            if record:
                if not record.has_storage_id():
                    record.storage_id = hit["_id"]

                # RRF exposes only the fused score, not the per-retriever scores
                combined_score = hit.get("_score", 1.0 / (config.rrf_k + i + 1))

                results.append(HybridSearchResult(
                    record=record,
                    combined_score=combined_score,
                    text_score=None,  # Not available with native RRF
                    vector_score=None,  # Not available with native RRF
                    text_rank=None,
                    vector_rank=None,
                    metadata={
                        "fusion_strategy": "native_rrf",
                        "index": self.index_name,
                        "doc_id": hit["_id"],
                    },
                ))

        return results
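
    # Hybrid-search sketch (hedged: the query vector is illustrative, and the
    # HybridSearchConfig fields shown are the ones this method reads):
    #
    #   from dataknobs_data.vector.hybrid import FusionStrategy, HybridSearchConfig
    #
    #   results = await db.hybrid_search(
    #       query_text="vector databases",
    #       query_vector=[0.1] * 384,
    #       text_fields=["title", "content"],
    #       k=10,
    #       config=HybridSearchConfig(fusion_strategy=FusionStrategy.NATIVE),
    #   )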
    async def _text_search_for_hybrid(
        self,
        query_text: str,
        text_fields: list[str] | None,
        k: int,
        filter: Query | None = None,
    ) -> list[tuple[Record, float]]:
        """Perform BM25 text search for hybrid search fusion.

        Uses Elasticsearch's native BM25 scoring for text relevance.

        Args:
            query_text: Text to search for
            text_fields: Fields to search in
            k: Maximum results to return
            filter: Additional filters

        Returns:
            List of (record, score) tuples ordered by BM25 relevance
        """
        self._check_connection()

        search_fields = text_fields or ["content", "title", "text"]

        # Build multi_match query
        query: dict[str, Any] = {
            "multi_match": {
                "query": query_text,
                "fields": [f"data.{f}" for f in search_fields],
                "type": "best_fields",
                "operator": "or",
            }
        }

        # Build filter if provided
        filter_query = self._build_filter_query(filter) if filter else None

        body: dict[str, Any] = {
            "query": query if not filter_query else {
                "bool": {
                    "must": query,
                    "filter": filter_query,
                }
            },
            "size": k,
        }

        try:
            response = await self._client.search(
                index=self.index_name,
                body=body,
            )
        except Exception as e:
            self._handle_elasticsearch_error(e, "text search for hybrid")
            return []

        # Process results
        results: list[tuple[Record, float]] = []
        hits = response.get("hits", {}).get("hits", [])
        max_score = response.get("hits", {}).get("max_score", 1.0) or 1.0

        for hit in hits:
            record = self._doc_to_record(hit)
            if record:
                if not record.has_storage_id():
                    record.storage_id = hit["_id"]

                # Normalize BM25 score to the 0-1 range
                score = hit.get("_score", 0.0) / max_score if max_score > 0 else 0.0
                results.append((record, score))

        return results