Coverage for src/dataknobs_data/backends/elasticsearch.py: 9%
526 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:23 -0700
1"""Elasticsearch backend implementation for the data package."""
3from __future__ import annotations
5import logging
6import uuid
7from typing import TYPE_CHECKING, Any
9from dataknobs_config import ConfigurableBase
11from dataknobs_utils.elasticsearch_utils import SimplifiedElasticsearchIndex
13from ..database import SyncDatabase
14from ..exceptions import DatabaseError
15from ..query import Operator, Query, SortOrder
16from ..query_logic import ComplexQuery
17from ..streaming import StreamConfig, StreamingMixin, StreamResult
18from ..vector.types import DistanceMetric, VectorSearchResult
19from .elasticsearch_mixins import (
20 ElasticsearchBaseConfig,
21 ElasticsearchErrorHandler,
22 ElasticsearchIndexManager,
23 ElasticsearchQueryBuilder,
24 ElasticsearchRecordSerializer,
25 ElasticsearchVectorSupport,
26)
27from .vector_config_mixin import VectorConfigMixin
29if TYPE_CHECKING:
30 import numpy as np
31 from collections.abc import Iterator
32 from ..records import Record
34logger = logging.getLogger(__name__)
37class SyncElasticsearchDatabase(
38 SyncDatabase,
39 StreamingMixin,
40 ConfigurableBase,
41 VectorConfigMixin,
42 ElasticsearchBaseConfig,
43 ElasticsearchIndexManager,
44 ElasticsearchVectorSupport,
45 ElasticsearchErrorHandler,
46 ElasticsearchRecordSerializer,
47 ElasticsearchQueryBuilder,
48):
49 """Synchronous Elasticsearch database backend."""
51 def __init__(self, config: dict[str, Any] | None = None):
52 """Initialize Elasticsearch database.
54 Args:
55 config: Configuration with the following optional keys:
56 - host: Elasticsearch host (default: localhost)
57 - port: Elasticsearch port (default: 9200)
58 - index: Index name (default: "records")
59 - refresh: Whether to refresh after write operations (default: True)
60 - settings: Index settings dict
61 - mappings: Index mappings dict
62 """
63 super().__init__(config)
65 # Parse vector configuration using the mixin
66 self._parse_vector_config(config)
68 # Initialize vector support
69 self.vector_fields = {} # field_name -> dimensions
71 self.es_index = None # Will be initialized in connect()
72 self._connected = False
    @classmethod
    def from_config(cls, config: dict) -> SyncElasticsearchDatabase:
        """Alternate constructor: build an instance from a config dictionary.

        Args:
            config: Same keys accepted by ``__init__``.

        Returns:
            A new, not-yet-connected SyncElasticsearchDatabase.
        """
        return cls(config)
79 def connect(self) -> None:
80 """Connect to the Elasticsearch database."""
81 if self._connected:
82 return # Already connected
84 # Initialize the Elasticsearch connection and index
85 config = self.config.copy()
87 # Extract configuration
88 self.host = config.pop("host", "localhost")
89 self.port = config.pop("port", 9200)
90 self.index_name = config.pop("index", "records")
91 self.refresh = config.pop("refresh", True)
93 # If vector is enabled but no vector fields defined yet, set up default
94 if self._vector_enabled and not self.vector_fields:
95 # Set a default embedding field with configurable dimensions
96 default_dimensions = config.pop("vector_dimensions", 1536) # Common default
97 default_field = config.pop("default_vector_field", "embedding")
98 self.vector_fields[default_field] = default_dimensions
100 # Get mappings with vector field support
101 base_mappings = self.get_index_mappings(self.vector_fields)
103 # Allow custom mappings to override
104 custom_mappings = config.pop("mappings", None)
105 if custom_mappings:
106 mappings = custom_mappings
107 else:
108 mappings = base_mappings
110 # Get settings optimized for KNN if we have vector fields
111 settings = config.pop("settings", None)
112 if not settings:
113 settings = self.get_knn_index_settings() if (self.vector_fields or self._vector_enabled) else {
114 "number_of_shards": 1,
115 "number_of_replicas": 0,
116 }
118 # Initialize the Elasticsearch index
119 self.es_index = SimplifiedElasticsearchIndex(
120 index_name=self.index_name,
121 host=self.host,
122 port=self.port,
123 settings=settings,
124 mappings=mappings,
125 )
127 # Ensure index exists
128 if not self.es_index.exists():
129 self.es_index.create()
131 # Create an Elasticsearch client for bulk operations
132 from elasticsearch import Elasticsearch
133 self.es_client = Elasticsearch([f"http://{self.host}:{self.port}"])
135 self._connected = True
137 def close(self) -> None:
138 """Close the database connection."""
139 if self.es_index:
140 # ElasticsearchIndex manages its own connections
141 self._connected = False # type: ignore[unreachable]
143 def _initialize(self) -> None:
144 """Initialize method - connection setup moved to connect()."""
145 # Configuration parsing stays here if needed
146 pass
148 def _check_connection(self) -> None:
149 """Check if database is connected."""
150 if not self._connected or not self.es_index:
151 raise RuntimeError("Database not connected. Call connect() first.")
153 def _record_to_doc(self, record: Record, id: str | None = None) -> dict[str, Any]:
154 """Convert a Record to an Elasticsearch document."""
155 # Create a copy of the record to avoid modifying the original
156 record_copy = record.copy(deep=True)
158 # Update vector tracking if needed
159 if self._has_vector_fields(record_copy):
160 self._update_vector_tracking(record_copy)
162 # Add vector field metadata to copied record metadata
163 if "vector_fields" not in record_copy.metadata:
164 record_copy.metadata["vector_fields"] = {}
166 for field_name in self.vector_fields:
167 if field_name in record_copy.fields:
168 field = record_copy.fields[field_name]
169 if hasattr(field, "source_field"):
170 record_copy.metadata["vector_fields"][field_name] = {
171 "type": "vector",
172 "dimensions": self.vector_fields[field_name],
173 "source_field": field.source_field,
174 "model": getattr(field, "model_name", None),
175 "model_version": getattr(field, "model_version", None),
176 }
178 doc = self._record_to_document(record_copy)
179 if id:
180 doc["id"] = id
181 elif not doc.get("id"):
182 doc["id"] = str(uuid.uuid4())
184 return doc
186 def _doc_to_record(self, doc: dict[str, Any]) -> Record:
187 """Convert an Elasticsearch document to a Record."""
188 # Handle both direct documents and search results
189 if "_source" in doc:
190 source_doc = doc
191 else:
192 source_doc = {"_source": doc}
194 record = self._document_to_record(source_doc)
196 # Add score if present
197 if "_score" in doc:
198 record.metadata["_score"] = doc.get("_score")
200 return record
202 def create(self, record: Record) -> str:
203 """Create a new record."""
204 # Use record's ID if it has one, otherwise generate a new one
205 id = record.id if record.id else str(uuid.uuid4())
206 doc = self._record_to_doc(record, id)
208 # Index the document
209 response = self.es_index.index(
210 body=doc,
211 doc_id=id,
212 refresh=self.refresh,
213 )
215 if not response.get("_id"):
216 raise DatabaseError(f"Failed to create record: {response}")
218 return response["_id"]
220 def read(self, id: str) -> Record | None:
221 """Read a record by ID."""
222 response = self.es_index.get(doc_id=id)
224 if not response:
225 return None
227 doc = response.get("_source", {})
228 return self._doc_to_record(doc)
230 def update(self, id: str, record: Record) -> bool:
231 """Update an existing record."""
232 doc = self._record_to_doc(record, id)
234 # Update the document
235 success = self.es_index.update(
236 doc_id=id,
237 body={"doc": doc},
238 refresh=self.refresh,
239 )
241 return success
243 def delete(self, id: str) -> bool:
244 """Delete a record by ID."""
245 success = self.es_index.delete(doc_id=id)
247 # Refresh if needed
248 if success and self.refresh:
249 self.es_index.refresh()
251 return success
    def exists(self, id: str) -> bool:
        """Return True if a document with the given ID exists in the index."""
        return self.es_index.exists(doc_id=id)
257 def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
258 """Update or insert a record.
260 Can be called as:
261 - upsert(id, record) - explicit ID and record
262 - upsert(record) - extract ID from record using Record's built-in logic
263 """
264 # Determine ID and record based on arguments
265 if isinstance(id_or_record, str):
266 id = id_or_record
267 if record is None:
268 raise ValueError("Record required when ID is provided")
269 else:
270 record = id_or_record
271 id = record.id
272 if id is None:
273 import uuid # type: ignore[unreachable]
274 id = str(uuid.uuid4())
275 record.storage_id = id
277 doc = self._record_to_doc(record, id)
278 response = self.es_index.index(body=doc, doc_id=id, refresh=self.refresh)
280 if response.get("_id"):
281 return id
282 else:
283 raise DatabaseError(f"Failed to upsert record {id}: {response}")
    def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently using the bulk API.

        Uses Elasticsearch's bulk helper with ``raise_on_error=False`` so
        partial failures are tolerated: only the IDs of records that were
        actually indexed are returned, and failed records are silently skipped.

        Args:
            records: List of records to create.

        Returns:
            List of created record IDs (a subset of the generated IDs when
            some bulk operations failed; empty on total failure).
        """
        if not records:
            return []

        # Build one "index" action per record, each with a freshly minted UUID.
        bulk_operations = []
        ids = []

        for record in records:
            # Generate ID
            record_id = str(uuid.uuid4())
            ids.append(record_id)

            # Create action dict for bulk operation
            doc = self._record_to_doc(record, record_id)
            action = {
                "_op_type": "index",
                "_index": self.es_index.index_name,
                "_id": record_id,
                "_source": doc
            }
            bulk_operations.append(action)

        # Execute bulk create
        from elasticsearch import helpers

        try:
            # raise_on_error=False means failures come back in `errors`
            # rather than as a BulkIndexError exception.
            _success_count, errors = helpers.bulk(
                self.es_client,
                bulk_operations,
                refresh=self.refresh,
                raise_on_error=False,
                stats_only=False
            )
            # Process results to return actual IDs
            if errors:
                # Map failed document IDs so we can filter them out below.
                error_dict = {}
                for err in errors:
                    # Error dict can have 'index', 'create', 'update', or 'delete' keys
                    for op_type in ['index', 'create']:
                        if op_type in err:
                            error_dict[err[op_type].get('_id')] = err
                            break

                result_ids = []
                for record_id in ids:
                    if record_id not in error_dict:
                        result_ids.append(record_id)
                    # Skip failed records
                return result_ids
            else:
                # All succeeded
                return ids

        except Exception as e:
            # Duck-typed check for helpers.BulkIndexError (it carries `.errors`).
            if hasattr(e, 'errors'):
                # Extract which operations succeeded
                failed_ids = {err.get('index', {}).get('_id') for err in e.errors}
                result_ids = []
                for record_id in ids:
                    if record_id not in failed_ids:
                        result_ids.append(record_id)
                    # Skip failed records
                return result_ids
            else:
                # Complete failure - return empty list
                # NOTE(review): this swallows connection errors too — callers
                # cannot distinguish "nothing created" from "cluster down".
                return []
367 def read_batch(self, ids: list[str]) -> list[Record | None]:
368 """Read multiple records in batch."""
369 records = []
370 for id in ids:
371 record = self.read(id)
372 records.append(record)
373 return records
    def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently using the bulk API.

        Uses Elasticsearch's bulk helper with ``raise_on_error=False``;
        a 404 ("not found") on an individual delete is still reported as
        success, since the end state (document absent) is achieved.

        Args:
            ids: List of record IDs to delete.

        Returns:
            List of success flags, one per input ID, in input order.
        """
        if not ids:
            return []

        # Build one "delete" action per ID.
        bulk_operations = []
        for record_id in ids:
            # Create action dict for bulk delete
            action = {
                "_op_type": "delete",
                "_index": self.es_index.index_name,
                "_id": record_id
            }
            bulk_operations.append(action)

        # Execute bulk delete
        from elasticsearch import helpers

        try:
            # Failures are returned in `errors` rather than raised.
            _success_count, errors = helpers.bulk(
                self.es_client,
                bulk_operations,
                refresh=self.refresh,
                raise_on_error=False,
                stats_only=False
            )

            # Process results to determine which deletes succeeded
            results = []
            if errors:
                error_dict = {}
                for err in errors:
                    if 'delete' in err:
                        error_dict[err['delete'].get('_id')] = err

                for record_id in ids:
                    if record_id in error_dict:
                        # 200 = deleted, 404 = already gone; both count as success.
                        error = error_dict[record_id]
                        status = error.get('delete', {}).get('status')
                        results.append(status == 200 or status == 404)
                    else:
                        results.append(True)
            else:
                # All operations completed (either deleted or not found)
                results = [True] * len(ids)

            return results

        except Exception as e:
            # Duck-typed check for helpers.BulkIndexError (it carries `.errors`).
            if hasattr(e, 'errors'):
                # Extract which operations failed
                results = []
                failed_ids = {err.get('delete', {}).get('_id') for err in e.errors}

                for record_id in ids:
                    results.append(record_id not in failed_ids)

                return results
            else:
                # If bulk operation completely fails, mark all as failed
                return [False] * len(ids)
    def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using the bulk API.

        Uses Elasticsearch's bulk helper with ``doc_as_upsert=False``, so an
        update for a missing document fails rather than creating it.

        Args:
            updates: List of (id, record) tuples to update.

        Returns:
            List of success flags, one per input tuple, in input order.
        """
        if not updates:
            return []

        # Build one "update" action per (id, record) pair.
        bulk_operations = []
        for record_id, record in updates:
            # Create action dict for bulk update
            doc = self._record_to_doc(record, record_id)
            action = {
                "_op_type": "update",
                "_index": self.es_index.index_name,
                "_id": record_id,
                "doc": doc,
                "doc_as_upsert": False  # Don't create if doesn't exist
            }
            bulk_operations.append(action)

        # Execute bulk update
        from elasticsearch import helpers

        try:
            # Failures are returned in `errors` rather than raised.
            _success_count, errors = helpers.bulk(
                self.es_client,
                bulk_operations,
                refresh=self.refresh,
                raise_on_error=False,
                stats_only=False
            )

            # Process results to determine which updates succeeded
            results = []
            error_dict = {}
            if errors:
                for err in errors:
                    if 'update' in err:
                        error_dict[err['update']['_id']] = err

            for record_id, _ in updates:
                # Check if this ID had an error
                if record_id in error_dict:
                    error = error_dict[record_id]
                    # Unlike delete, a 404 here means the update failed.
                    status = error.get('update', {}).get('status')
                    results.append(status == 200)  # Only 200 is success for update
                else:
                    results.append(True)

            return results

        except Exception as e:
            # Duck-typed check for helpers.BulkIndexError (it carries `.errors`).
            if hasattr(e, 'errors'):
                # Extract which operations failed
                results = []
                failed_ids = {err['update']['_id'] for err in e.errors}

                for record_id, _ in updates:
                    results.append(record_id not in failed_ids)

                return results
            else:
                # If bulk operation completely fails, mark all as failed
                return [False] * len(updates)
526 def _build_complex_es_query(self, condition: Any) -> dict[str, Any]:
527 """Build Elasticsearch query from complex boolean logic conditions.
529 Args:
530 condition: The Condition object (LogicCondition or FilterCondition)
532 Returns:
533 Elasticsearch query dict
534 """
535 from ..query_logic import FilterCondition, LogicCondition, LogicOperator
537 # Handle FilterCondition (leaf node)
538 if isinstance(condition, FilterCondition):
539 return self._build_filter_es_query(condition.filter)
541 # Handle LogicCondition (branch node)
542 elif isinstance(condition, LogicCondition):
543 if condition.operator == LogicOperator.AND:
544 # Build AND query with must clauses
545 must_clauses = []
546 for sub_condition in condition.conditions:
547 sub_query = self._build_complex_es_query(sub_condition)
548 if sub_query:
549 must_clauses.append(sub_query)
551 if not must_clauses:
552 return {"match_all": {}}
553 elif len(must_clauses) == 1:
554 return must_clauses[0]
555 else:
556 return {"bool": {"must": must_clauses}}
558 elif condition.operator == LogicOperator.OR:
559 # Build OR query with should clauses
560 should_clauses = []
561 for sub_condition in condition.conditions:
562 sub_query = self._build_complex_es_query(sub_condition)
563 if sub_query:
564 should_clauses.append(sub_query)
566 if not should_clauses:
567 return {"match_all": {}}
568 elif len(should_clauses) == 1:
569 return should_clauses[0]
570 else:
571 return {"bool": {"should": should_clauses, "minimum_should_match": 1}}
573 elif condition.operator == LogicOperator.NOT:
574 # Build NOT query with must_not
575 if condition.conditions:
576 sub_query = self._build_complex_es_query(condition.conditions[0])
577 if sub_query:
578 return {"bool": {"must_not": sub_query}}
580 return {"match_all": {}}
582 return {"match_all": {}}
584 def _build_filter_es_query(self, filter_obj: Any) -> dict[str, Any]:
585 """Build Elasticsearch query for a single filter.
587 Args:
588 filter_obj: The Filter object
590 Returns:
591 Elasticsearch query dict for the filter
592 """
593 # Special handling for 'id' field - use _id in Elasticsearch
594 if filter_obj.field == 'id':
595 field_path = "_id"
596 # _id field doesn't need .keyword suffix
597 else:
598 field_path = f"data.{filter_obj.field}"
600 # For string fields in exact match queries, use .keyword suffix
601 if filter_obj.operator in [Operator.EQ, Operator.NEQ, Operator.IN, Operator.NOT_IN]:
602 if isinstance(filter_obj.value, str) or (
603 isinstance(filter_obj.value, list) and
604 filter_obj.value and
605 isinstance(filter_obj.value[0], str)
606 ):
607 field_path = f"{field_path}.keyword"
608 elif filter_obj.operator in [Operator.LIKE, Operator.NOT_LIKE]:
609 # Wildcard needs .keyword for proper matching
610 if isinstance(filter_obj.value, str):
611 field_path = f"{field_path}.keyword"
613 if filter_obj.operator == Operator.EQ:
614 value = str(filter_obj.value).lower() if isinstance(filter_obj.value, bool) else filter_obj.value
615 return {"term": {field_path: value}}
616 elif filter_obj.operator == Operator.NEQ:
617 value = str(filter_obj.value).lower() if isinstance(filter_obj.value, bool) else filter_obj.value
618 return {"bool": {"must_not": {"term": {field_path: value}}}}
619 elif filter_obj.operator == Operator.GT:
620 return {"range": {field_path: {"gt": filter_obj.value}}}
621 elif filter_obj.operator == Operator.GTE:
622 return {"range": {field_path: {"gte": filter_obj.value}}}
623 elif filter_obj.operator == Operator.LT:
624 return {"range": {field_path: {"lt": filter_obj.value}}}
625 elif filter_obj.operator == Operator.LTE:
626 return {"range": {field_path: {"lte": filter_obj.value}}}
627 elif filter_obj.operator == Operator.LIKE:
628 pattern = filter_obj.value.replace("%", "*").replace("_", "?")
629 return {"wildcard": {field_path: pattern}}
630 elif filter_obj.operator == Operator.NOT_LIKE:
631 pattern = filter_obj.value.replace("%", "*").replace("_", "?")
632 return {"bool": {"must_not": {"wildcard": {field_path: pattern}}}}
633 elif filter_obj.operator == Operator.IN:
634 # Special handling for _id field - use ids query instead of terms
635 if filter_obj.field == 'id':
636 return {"ids": {"values": filter_obj.value}}
637 else:
638 return {"terms": {field_path: filter_obj.value}}
639 elif filter_obj.operator == Operator.NOT_IN:
640 # Special handling for _id field
641 if filter_obj.field == 'id':
642 return {"bool": {"must_not": {"ids": {"values": filter_obj.value}}}}
643 else:
644 return {"bool": {"must_not": {"terms": {field_path: filter_obj.value}}}}
645 elif filter_obj.operator == Operator.EXISTS:
646 return {"exists": {"field": field_path}}
647 elif filter_obj.operator == Operator.NOT_EXISTS:
648 return {"bool": {"must_not": {"exists": {"field": field_path}}}}
649 elif filter_obj.operator == Operator.REGEX:
650 return {"regexp": {field_path: filter_obj.value}}
651 elif filter_obj.operator == Operator.BETWEEN:
652 if isinstance(filter_obj.value, (list, tuple)) and len(filter_obj.value) == 2:
653 lower, upper = filter_obj.value
654 return {"range": {field_path: {"gte": lower, "lte": upper}}}
655 elif filter_obj.operator == Operator.NOT_BETWEEN:
656 if isinstance(filter_obj.value, (list, tuple)) and len(filter_obj.value) == 2:
657 lower, upper = filter_obj.value
658 return {"bool": {"must_not": {"range": {field_path: {"gte": lower, "lte":
659upper}}}}}
661 return {"match_all": {}}
664 def search(self, query: Query | ComplexQuery) -> list[Record]:
665 """Search for records matching a query."""
666 # Handle ComplexQuery with native Elasticsearch bool queries
667 if isinstance(query, ComplexQuery):
668 if query.condition:
669 es_query = self._build_complex_es_query(query.condition)
670 else:
671 es_query = {"match_all": {}}
672 else:
673 # Build Elasticsearch query from simple Query object
674 es_query = {"bool": {"must": []}}
676 # Apply filters
677 for filter_obj in query.filters:
678 # Special handling for 'id' field - use _id in Elasticsearch
679 if filter_obj.field == 'id':
680 field_path = "_id"
681 # _id field doesn't need .keyword suffix
682 else:
683 field_path = f"data.{filter_obj.field}"
685 # For string fields in exact match queries, use .keyword suffix
686 # LIKE and REGEX need to use the text field, not keyword
687 if filter_obj.operator in [Operator.EQ, Operator.NEQ, Operator.IN, Operator.NOT_IN]:
688 if isinstance(filter_obj.value, str) or (
689 isinstance(filter_obj.value, list) and
690 filter_obj.value and
691 isinstance(filter_obj.value[0], str)
692 ):
693 field_path = f"{field_path}.keyword"
694 elif filter_obj.operator in [Operator.LIKE, Operator.NOT_LIKE]:
695 # Wildcard needs .keyword for proper matching
696 if isinstance(filter_obj.value, str):
697 field_path = f"{field_path}.keyword"
699 if filter_obj.operator == Operator.EQ:
700 # Handle boolean values correctly
701 value = str(filter_obj.value).lower() if isinstance(filter_obj.value, bool) else filter_obj.value
702 es_query["bool"]["must"].append({"term": {field_path: value}})
703 elif filter_obj.operator == Operator.NEQ:
704 value = str(filter_obj.value).lower() if isinstance(filter_obj.value, bool) else filter_obj.value
705 es_query["bool"]["must"].append({"bool": {"must_not": {"term": {field_path: value}}}})
706 elif filter_obj.operator == Operator.GT:
707 es_query["bool"]["must"].append({"range": {field_path: {"gt": filter_obj.value}}})
708 elif filter_obj.operator == Operator.GTE:
709 es_query["bool"]["must"].append({"range": {field_path: {"gte": filter_obj.value}}})
710 elif filter_obj.operator == Operator.LT:
711 es_query["bool"]["must"].append({"range": {field_path: {"lt": filter_obj.value}}})
712 elif filter_obj.operator == Operator.LTE:
713 es_query["bool"]["must"].append({"range": {field_path: {"lte": filter_obj.value}}})
714 elif filter_obj.operator == Operator.LIKE:
715 # Convert SQL LIKE pattern to Elasticsearch wildcard
716 # Wildcard queries should use the keyword field for exact matching
717 pattern = filter_obj.value.replace("%", "*").replace("_", "?")
718 # Use the base field path for LIKE (already has .keyword added above if string)
719 es_query["bool"]["must"].append({"wildcard": {field_path: pattern}})
720 elif filter_obj.operator == Operator.NOT_LIKE:
721 pattern = filter_obj.value.replace("%", "*").replace("_", "?")
722 es_query["bool"]["must"].append({"bool": {"must_not": {"wildcard": {field_path: pattern}}}})
723 elif filter_obj.operator == Operator.IN:
724 # Special handling for _id field - use ids query instead of terms
725 if filter_obj.field == 'id':
726 es_query["bool"]["must"].append({"ids": {"values": filter_obj.value}})
727 else:
728 es_query["bool"]["must"].append({"terms": {field_path: filter_obj.value}})
729 elif filter_obj.operator == Operator.NOT_IN:
730 # Special handling for _id field
731 if filter_obj.field == 'id':
732 es_query["bool"]["must"].append({"bool": {"must_not": {"ids": {"values": filter_obj.value}}}})
733 else:
734 es_query["bool"]["must"].append({"bool": {"must_not": {"terms": {field_path: filter_obj.value}}}})
735 elif filter_obj.operator == Operator.EXISTS:
736 es_query["bool"]["must"].append({"exists": {"field": field_path}})
737 elif filter_obj.operator == Operator.NOT_EXISTS:
738 es_query["bool"]["must"].append({"bool": {"must_not": {"exists": {"field": field_path}}}})
739 elif filter_obj.operator == Operator.REGEX:
740 es_query["bool"]["must"].append({"regexp": {field_path: filter_obj.value}})
741 elif filter_obj.operator == Operator.BETWEEN:
742 # Use Elasticsearch's native range query for efficient BETWEEN
743 if isinstance(filter_obj.value, (list, tuple)) and len(filter_obj.value) == 2:
744 lower, upper = filter_obj.value
745 es_query["bool"]["must"].append({
746 "range": {
747 field_path: {
748 "gte": lower,
749 "lte": upper
750 }
751 }
752 })
753 elif filter_obj.operator == Operator.NOT_BETWEEN:
754 # NOT BETWEEN using bool must_not with range
755 if isinstance(filter_obj.value, (list, tuple)) and len(filter_obj.value) == 2:
756 lower, upper = filter_obj.value
757 es_query["bool"]["must"].append({
758 "bool": {
759 "must_not": {
760 "range": {
761 field_path: {
762 "gte": lower,
763 "lte": upper
764 }
765 }
766 }
767 }
768 })
770 # If no filters, match all
771 if not es_query["bool"]["must"]:
772 es_query = {"match_all": {}}
774 # Build sort
775 sort = []
776 if query.sort_specs:
777 for sort_spec in query.sort_specs:
778 # Special handling for 'id' field - sort by the id field in source data
779 # We can't sort by _id directly as it requires fielddata which is disabled by default
780 # The id field is already of type keyword, so no .keyword suffix needed
781 if sort_spec.field == 'id':
782 field_path = "id"
783 else:
784 field_path = f"data.{sort_spec.field}"
785 # Don't add .keyword if user already specified it or for common numeric fields
786 # This is a heuristic - ideally we'd check the mapping
787 numeric_fields = ['age', 'salary', 'balance', 'count', 'score', 'amount', 'price', 'index', 'number', 'total', 'quantity']
788 if (not sort_spec.field.endswith('.keyword') and
789 not sort_spec.field.endswith('.raw') and
790 sort_spec.field.lower() not in numeric_fields):
791 # Likely a text field, add .keyword for sorting
792 field_path = f"data.{sort_spec.field}.keyword"
793 order = "desc" if sort_spec.order == SortOrder.DESC else "asc"
794 sort.append({field_path: {"order": order}})
796 # Build search body
797 search_body = {"query": es_query}
798 if sort:
799 search_body["sort"] = sort
800 if query.limit_value:
801 search_body["size"] = query.limit_value
802 if query.offset_value:
803 search_body["from"] = query.offset_value
805 # Execute search
806 response = self.es_index.search(body=search_body)
808 # Check if the response is valid (has the expected structure)
809 # An empty result set is still a valid response
810 if not hasattr(response, 'json') or response.json is None:
811 raise DatabaseError(f"Invalid search response: {response}")
813 # Check for actual errors in the response
814 if 'error' in response.json:
815 raise DatabaseError(f"Failed to search records: {response.json['error']}")
817 # Parse results
818 records = []
819 hits = response.json.get("hits", {}).get("hits", [])
820 for hit in hits:
821 doc = hit.get("_source", {})
822 records.append(self._doc_to_record(doc))
824 # Apply field projection if specified
825 if query.fields:
826 for record in records:
827 # Keep only specified fields
828 field_names = list(record.fields.keys())
829 for field_name in field_names:
830 if field_name not in query.fields:
831 del record.fields[field_name]
833 return records
835 def _count_all(self) -> int:
836 """Count all records in the database."""
837 self._check_connection()
838 return self.es_index.count()
840 def count(self, query: Query | None = None) -> int:
841 """Count records matching a query using efficient Elasticsearch count.
843 Args:
844 query: Optional search query (counts all if None)
846 Returns:
847 Number of matching records
848 """
849 if not query or not query.filters:
850 return self._count_all()
852 # Build Elasticsearch query from Query object (same as search)
853 es_query = {"bool": {"must": []}}
855 for filter_obj in query.filters:
856 # Special handling for 'id' field - use _id in Elasticsearch
857 if filter_obj.field == 'id':
858 field_path = "_id"
859 # _id field doesn't need .keyword suffix
860 else:
861 field_path = f"data.{filter_obj.field}"
863 # For string fields in exact match queries, use .keyword suffix
864 # LIKE and REGEX need different handling
865 if filter_obj.operator in [Operator.EQ, Operator.NEQ, Operator.IN, Operator.NOT_IN]:
866 if isinstance(filter_obj.value, str) or (
867 isinstance(filter_obj.value, list) and
868 filter_obj.value and
869 isinstance(filter_obj.value[0], str)
870 ):
871 field_path = f"{field_path}.keyword"
872 elif filter_obj.operator in [Operator.LIKE, Operator.NOT_LIKE]:
873 # Wildcard needs .keyword for proper matching
874 if isinstance(filter_obj.value, str):
875 field_path = f"{field_path}.keyword"
877 if filter_obj.operator == Operator.EQ:
878 # Handle boolean values correctly
879 value = str(filter_obj.value).lower() if isinstance(filter_obj.value, bool) else filter_obj.value
880 es_query["bool"]["must"].append({"term": {field_path: value}})
881 elif filter_obj.operator == Operator.NEQ:
882 value = str(filter_obj.value).lower() if isinstance(filter_obj.value, bool) else filter_obj.value
883 es_query["bool"]["must"].append({"bool": {"must_not": {"term": {field_path: value}}}})
884 elif filter_obj.operator == Operator.GT:
885 es_query["bool"]["must"].append({"range": {field_path: {"gt": filter_obj.value}}})
886 elif filter_obj.operator == Operator.GTE:
887 es_query["bool"]["must"].append({"range": {field_path: {"gte": filter_obj.value}}})
888 elif filter_obj.operator == Operator.LT:
889 es_query["bool"]["must"].append({"range": {field_path: {"lt": filter_obj.value}}})
890 elif filter_obj.operator == Operator.LTE:
891 es_query["bool"]["must"].append({"range": {field_path: {"lte": filter_obj.value}}})
892 elif filter_obj.operator == Operator.LIKE:
893 pattern = filter_obj.value.replace("%", "*").replace("_", "?")
894 es_query["bool"]["must"].append({"wildcard": {field_path: pattern}})
895 elif filter_obj.operator == Operator.NOT_LIKE:
896 pattern = filter_obj.value.replace("%", "*").replace("_", "?")
897 es_query["bool"]["must"].append({"bool": {"must_not": {"wildcard": {field_path: pattern}}}})
898 elif filter_obj.operator == Operator.IN:
899 # Special handling for _id field - use ids query instead of terms
900 if filter_obj.field == 'id':
901 es_query["bool"]["must"].append({"ids": {"values": filter_obj.value}})
902 else:
903 es_query["bool"]["must"].append({"terms": {field_path: filter_obj.value}})
904 elif filter_obj.operator == Operator.NOT_IN:
905 # Special handling for _id field
906 if filter_obj.field == 'id':
907 es_query["bool"]["must"].append({"bool": {"must_not": {"ids": {"values": filter_obj.value}}}})
908 else:
909 es_query["bool"]["must"].append({"bool": {"must_not": {"terms": {field_path: filter_obj.value}}}})
910 elif filter_obj.operator == Operator.EXISTS:
911 es_query["bool"]["must"].append({"exists": {"field": field_path}})
912 elif filter_obj.operator == Operator.NOT_EXISTS:
913 es_query["bool"]["must"].append({"bool": {"must_not": {"exists": {"field": field_path}}}})
914 elif filter_obj.operator == Operator.REGEX:
915 es_query["bool"]["must"].append({"regexp": {field_path: filter_obj.value}})
917 # If no filters were added, use match_all
918 if not es_query["bool"]["must"]:
919 es_query = {"match_all": {}}
921 # Count with the query
922 return self.es_index.count(body={"query": es_query})
924 def clear(self) -> int:
925 """Clear all records from the database."""
926 self._check_connection()
927 # Get count before deletion
928 count = self._count_all()
930 # Delete by query - delete all documents
931 response = self.es_index.delete_by_query(
932 body={"query": {"match_all": {}}}
933 )
935 # Refresh if needed
936 if self.refresh:
937 self.es_index.refresh()
939 return response.get("deleted", count)
941 def stream_read(
942 self,
943 query: Query | None = None,
944 config: StreamConfig | None = None
945 ) -> Iterator[Record]:
946 """Stream records from Elasticsearch."""
947 config = config or StreamConfig()
949 # Use search to get all matching records
950 if query:
951 records = self.search(query)
952 else:
953 records = self.search(Query())
955 # Yield records in batches for consistency
956 for i in range(0, len(records), config.batch_size):
957 batch = records[i:i + config.batch_size]
958 for record in batch:
959 yield record
961 def stream_write(
962 self,
963 records: Iterator[Record],
964 config: StreamConfig | None = None
965 ) -> StreamResult:
966 """Stream records into Elasticsearch."""
967 # Use the default implementation from mixin
968 return self._default_stream_write(records, config)
970 def vector_search(
971 self,
972 query_vector: np.ndarray | list[float],
973 field_name: str = "embedding",
974 k: int = 10,
975 metric: DistanceMetric = DistanceMetric.COSINE,
976 filter: Query | None = None,
977 include_source: bool = True,
978 score_threshold: float | None = None,
979 ) -> list[VectorSearchResult]:
980 """Search for similar vectors using Elasticsearch KNN.
982 Note: This is a synchronous wrapper around the async implementation.
983 For production use, consider using the async version for better performance.
985 Args:
986 query_vector: The vector to search for
987 vector_field: Name of the vector field to search
988 k: Number of results to return
989 metric: Distance metric to use
990 filter: Optional query filter to apply before vector search
991 include_source: Whether to include source document in results
992 score_threshold: Optional minimum similarity score
994 Returns:
995 List of search results ordered by similarity
996 """
997 self._check_connection()
999 # Import vector utilities
1000 from ..vector.elasticsearch_utils import (
1001 build_knn_query,
1002 )
1004 # Build filter query if provided
1005 filter_query = self._build_filter_query(filter) if filter else None
1007 # Build KNN query
1008 query = build_knn_query(
1009 query_vector=query_vector,
1010 field_name=field_name,
1011 k=k,
1012 filter_query=filter_query,
1013 )
1015 # Execute search using the es_client
1016 try:
1017 response = self.es_client.search(
1018 index=self.index_name,
1019 **query,
1020 size=k,
1021 source=include_source,
1022 )
1023 except Exception as e:
1024 self._handle_elasticsearch_error(e, "vector search")
1025 return []
1027 # Process results
1028 results = []
1029 for hit in response.get("hits", {}).get("hits", []):
1030 score = hit.get("_score", 0.0)
1032 # Apply score threshold if specified
1033 if score_threshold is not None and score < score_threshold:
1034 continue
1036 # Convert document to record if source included
1037 record = None
1038 if include_source:
1039 record = self._doc_to_record(hit)
1040 # Set the storage ID on the record if we have one
1041 if not record.has_storage_id():
1042 record.storage_id = hit["_id"]
1044 # Skip if no record (shouldn't happen if include_source is True)
1045 if record is None:
1046 continue
1048 results.append(VectorSearchResult(
1049 record=record,
1050 score=score,
1051 vector_field=field_name,
1052 metadata={
1053 "index": self.index_name,
1054 "metric": metric.value,
1055 "doc_id": hit["_id"],
1056 },
1057 ))
1059 return results
1061 def create_vector_index(
1062 self,
1063 vector_field: str = "embedding",
1064 dimensions: int | None = None,
1065 metric: DistanceMetric = DistanceMetric.COSINE,
1066 index_type: str = "auto",
1067 **kwargs: Any,
1068 ) -> bool:
1069 """Create or update index mapping for vector field.
1071 Args:
1072 vector_field: Name of the vector field to index
1073 dimensions: Number of dimensions
1074 metric: Distance metric for the index
1075 index_type: Type of index (ignored for ES, always uses HNSW)
1076 **kwargs: Additional index parameters
1078 Returns:
1079 True if index was created/updated successfully
1080 """
1081 self._check_connection()
1083 if not dimensions:
1084 if vector_field not in self.vector_fields:
1085 raise ValueError(f"Unknown dimensions for field '{vector_field}'")
1086 dimensions = self.vector_fields[vector_field]
1088 # Import vector utilities
1089 from ..vector.elasticsearch_utils import (
1090 get_similarity_for_metric,
1091 get_vector_mapping,
1092 )
1094 # Get similarity function for metric
1095 similarity = get_similarity_for_metric(metric)
1097 # Build mapping for the vector field
1098 mapping = get_vector_mapping(dimensions, similarity)
1100 # Update index mapping using the es_client
1101 try:
1102 self.es_client.indices.put_mapping(
1103 index=self.index_name,
1104 properties={
1105 f"data.{vector_field}": mapping
1106 }
1107 )
1109 # Track the vector field
1110 self.vector_fields[vector_field] = dimensions
1111 self._vector_enabled = True
1113 logger.info(f"Created vector mapping for field '{vector_field}' with {dimensions} dimensions")
1114 return True
1116 except Exception as e:
1117 self._handle_elasticsearch_error(e, "create vector index")
1118 return False
1121# Import the native async implementation