Coverage for src/dataknobs_data/backends/elasticsearch.py: 9% (487 statements)

1"""Elasticsearch backend implementation for the data package."""
3from __future__ import annotations
5import logging
6import uuid
7from typing import TYPE_CHECKING, Any
9from dataknobs_config import ConfigurableBase
11from dataknobs_utils.elasticsearch_utils import SimplifiedElasticsearchIndex
13from ..database import SyncDatabase
14from ..exceptions import DatabaseError
15from ..query import Operator, Query, SortOrder
16from ..query_logic import ComplexQuery
17from ..streaming import StreamConfig, StreamingMixin, StreamResult
18from ..vector.types import DistanceMetric, VectorSearchResult
19from .elasticsearch_mixins import (
20 ElasticsearchBaseConfig,
21 ElasticsearchErrorHandler,
22 ElasticsearchIndexManager,
23 ElasticsearchQueryBuilder,
24 ElasticsearchRecordSerializer,
25 ElasticsearchVectorSupport,
26)
27from .vector_config_mixin import VectorConfigMixin
29if TYPE_CHECKING:
30 import numpy as np
31 from collections.abc import Iterator
32 from ..records import Record
34logger = logging.getLogger(__name__)


class SyncElasticsearchDatabase(
    SyncDatabase,
    StreamingMixin,
    ConfigurableBase,
    VectorConfigMixin,
    ElasticsearchBaseConfig,
    ElasticsearchIndexManager,
    ElasticsearchVectorSupport,
    ElasticsearchErrorHandler,
    ElasticsearchRecordSerializer,
    ElasticsearchQueryBuilder,
):
    """Synchronous Elasticsearch database backend."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize Elasticsearch database.

        Args:
            config: Configuration with the following optional keys:
                - host: Elasticsearch host (default: localhost)
                - port: Elasticsearch port (default: 9200)
                - index: Index name (default: "records")
                - refresh: Whether to refresh after write operations (default: True)
                - settings: Index settings dict
                - mappings: Index mappings dict
        """
        super().__init__(config)

        # Parse vector configuration using the mixin
        self._parse_vector_config(config)

        # Initialize vector support
        self.vector_fields = {}  # field_name -> dimensions

        self.es_index = None  # Will be initialized in connect()
        self._connected = False

    @classmethod
    def from_config(cls, config: dict) -> SyncElasticsearchDatabase:
        """Create from config dictionary."""
        return cls(config)

    def connect(self) -> None:
        """Connect to the Elasticsearch database."""
        if self._connected:
            return  # Already connected

        # Initialize the Elasticsearch connection and index
        config = self.config.copy()

        # Extract configuration
        self.host = config.pop("host", "localhost")
        self.port = config.pop("port", 9200)
        self.index_name = config.pop("index", "records")
        self.refresh = config.pop("refresh", True)

        # If vector support is enabled but no vector fields are defined yet, set up a default
        if self._vector_enabled and not self.vector_fields:
            # Use a default embedding field with configurable dimensions
            default_dimensions = config.pop("vector_dimensions", 1536)  # Common default
            default_field = config.pop("default_vector_field", "embedding")
            self.vector_fields[default_field] = default_dimensions

        # Get mappings with vector field support
        base_mappings = self.get_index_mappings(self.vector_fields)

        # Allow custom mappings to override
        custom_mappings = config.pop("mappings", None)
        mappings = custom_mappings if custom_mappings else base_mappings

        # Get settings optimized for KNN if we have vector fields
        settings = config.pop("settings", None)
        if not settings:
            if self.vector_fields or self._vector_enabled:
                settings = self.get_knn_index_settings()
            else:
                settings = {
                    "number_of_shards": 1,
                    "number_of_replicas": 0,
                }

        # Initialize the Elasticsearch index
        self.es_index = SimplifiedElasticsearchIndex(
            index_name=self.index_name,
            host=self.host,
            port=self.port,
            settings=settings,
            mappings=mappings,
        )

        # Ensure the index exists
        if not self.es_index.exists():
            self.es_index.create()

        # Create an Elasticsearch client for bulk operations
        from elasticsearch import Elasticsearch
        self.es_client = Elasticsearch([f"http://{self.host}:{self.port}"])

        self._connected = True

    def close(self) -> None:
        """Close the database connection."""
        # SimplifiedElasticsearchIndex manages its own connections, so there is
        # nothing to tear down here beyond clearing the connected flag.
        self._connected = False

    def _initialize(self) -> None:
        """Initialize method - connection setup moved to connect()."""
        # Configuration parsing stays here if needed
        pass

    def _check_connection(self) -> None:
        """Check if the database is connected."""
        if not self._connected or not self.es_index:
            raise RuntimeError("Database not connected. Call connect() first.")

    def _record_to_doc(self, record: Record, id: str | None = None) -> dict[str, Any]:
        """Convert a Record to an Elasticsearch document."""
        # Work on a copy of the record to avoid modifying the original
        record_copy = record.copy(deep=True)

        # Update vector tracking if needed
        if self._has_vector_fields(record_copy):
            self._update_vector_tracking(record_copy)

            # Add vector field metadata to the copied record's metadata
            if "vector_fields" not in record_copy.metadata:
                record_copy.metadata["vector_fields"] = {}

            for field_name in self.vector_fields:
                if field_name in record_copy.fields:
                    field = record_copy.fields[field_name]
                    if hasattr(field, "source_field"):
                        record_copy.metadata["vector_fields"][field_name] = {
                            "type": "vector",
                            "dimensions": self.vector_fields[field_name],
                            "source_field": field.source_field,
                            "model": getattr(field, "model_name", None),
                            "model_version": getattr(field, "model_version", None),
                        }

        doc = self._record_to_document(record_copy)
        if id:
            doc["id"] = id
        elif not doc.get("id"):
            doc["id"] = str(uuid.uuid4())

        return doc

    def _doc_to_record(self, doc: dict[str, Any]) -> Record:
        """Convert an Elasticsearch document to a Record."""
        # Handle both search hits (with "_source") and bare documents
        if "_source" in doc:
            source_doc = doc
        else:
            source_doc = {"_source": doc}

        record = self._document_to_record(source_doc)

        # Add the relevance score if present
        if "_score" in doc:
            record.metadata["_score"] = doc.get("_score")

        return record

    def create(self, record: Record) -> str:
        """Create a new record."""
        # Use the record's ID if it has one; otherwise generate a new one
        id = record.id if record.id else str(uuid.uuid4())
        doc = self._record_to_doc(record, id)

        # Index the document
        response = self.es_index.index(
            body=doc,
            doc_id=id,
            refresh=self.refresh,
        )

        if not response.get("_id"):
            raise DatabaseError(f"Failed to create record: {response}")

        return response["_id"]

    def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        response = self.es_index.get(doc_id=id)

        if not response:
            return None

        doc = response.get("_source", {})
        return self._doc_to_record(doc)

    def update(self, id: str, record: Record) -> bool:
        """Update an existing record."""
        doc = self._record_to_doc(record, id)

        # Update the document
        success = self.es_index.update(
            doc_id=id,
            body={"doc": doc},
            refresh=self.refresh,
        )

        return success

    def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        success = self.es_index.delete(doc_id=id)

        # Refresh if needed
        if success and self.refresh:
            self.es_index.refresh()

        return success

    def exists(self, id: str) -> bool:
        """Check if a record exists."""
        return self.es_index.exists(doc_id=id)

    def upsert(self, id: str, record: Record) -> str:
        """Update or insert a record with a specific ID."""
        doc = self._record_to_doc(record, id)
        response = self.es_index.index(body=doc, doc_id=id, refresh=self.refresh)

        if response.get("_id"):
            return id
        raise DatabaseError(f"Failed to upsert record {id}: {response}")

    def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently using the bulk API.

        Uses Elasticsearch's bulk API for efficient batch creation.

        Args:
            records: List of records to create

        Returns:
            List of created record IDs
        """
        if not records:
            return []

        # Build bulk operations
        bulk_operations = []
        ids = []

        for record in records:
            # Generate an ID
            record_id = str(uuid.uuid4())
            ids.append(record_id)

            # Create the action dict for the bulk operation
            doc = self._record_to_doc(record, record_id)
            action = {
                "_op_type": "index",
                "_index": self.es_index.index_name,
                "_id": record_id,
                "_source": doc,
            }
            bulk_operations.append(action)

        # Execute the bulk create
        from elasticsearch import helpers

        try:
            # Use the bulk helper for creation
            # Note: helpers.BulkIndexError may be raised if raise_on_error=True
            success_count, errors = helpers.bulk(
                self.es_client,
                bulk_operations,
                refresh=self.refresh,
                raise_on_error=False,
                stats_only=False,
            )
            # Process results to return the IDs that actually succeeded
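            # Each error entry from helpers.bulk is expected to look roughly like
            #   {"index": {"_id": "...", "status": 400, "error": {...}}},
            # keyed by operation type; treat the exact shape as an assumption
            # that may vary across client versions.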
            if errors:
                # Some operations failed - determine which ones
                error_dict = {}
                for err in errors:
                    # An error dict can have "index", "create", "update", or "delete" keys
                    for op_type in ("index", "create"):
                        if op_type in err:
                            error_dict[err[op_type].get("_id")] = err
                            break

                # Skip failed records and return only the IDs that succeeded
                return [record_id for record_id in ids if record_id not in error_dict]

            # All succeeded
            return ids

        except Exception as e:
            # Check if this is a BulkIndexError from the helpers module
            if hasattr(e, "errors"):
                # Extract which operations failed and skip those records
                failed_ids = {err.get("index", {}).get("_id") for err in e.errors}
                return [record_id for record_id in ids if record_id not in failed_ids]

            # Complete failure - return an empty list
            return []

    def read_batch(self, ids: list[str]) -> list[Record | None]:
        """Read multiple records in batch."""
        return [self.read(id) for id in ids]

    def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently using the bulk API.

        Uses Elasticsearch's bulk API for efficient batch deletion.

        Args:
            ids: List of record IDs to delete

        Returns:
            List of success flags for each deletion
        """
        if not ids:
            return []

        # Build bulk operations
        bulk_operations = []
        for record_id in ids:
            # Create the action dict for the bulk delete
            action = {
                "_op_type": "delete",
                "_index": self.es_index.index_name,
                "_id": record_id,
            }
            bulk_operations.append(action)

        # Execute the bulk delete
        from elasticsearch import helpers

        try:
            # Use the bulk helper for deletion
            success_count, errors = helpers.bulk(
                self.es_client,
                bulk_operations,
                refresh=self.refresh,
                raise_on_error=False,
                stats_only=False,
            )

            # Process results to determine which deletes succeeded
            results = []
            if errors:
                error_dict = {}
                for err in errors:
                    if "delete" in err:
                        error_dict[err["delete"].get("_id")] = err

                for record_id in ids:
                    if record_id in error_dict:
                        # A "not found" (404) still counts as a successful delete
                        error = error_dict[record_id]
                        status = error.get("delete", {}).get("status")
                        results.append(status in (200, 404))
                    else:
                        results.append(True)
            else:
                # All operations completed (either deleted or not found)
                results = [True] * len(ids)

            return results

        except Exception as e:
            # Check if this is a BulkIndexError from the helpers module
            if hasattr(e, "errors"):
                # Mark only the failed operations as unsuccessful
                failed_ids = {err.get("delete", {}).get("_id") for err in e.errors}
                return [record_id not in failed_ids for record_id in ids]

            # If the bulk operation fails completely, mark all as failed
            return [False] * len(ids)

    def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using the bulk API.

        Uses Elasticsearch's bulk API for efficient batch updates.

        Args:
            updates: List of (id, record) tuples to update

        Returns:
            List of success flags for each update
        """
        if not updates:
            return []

        # Build bulk operations
        bulk_operations = []
        for record_id, record in updates:
            # Create the action dict for the bulk update
            doc = self._record_to_doc(record, record_id)
            action = {
                "_op_type": "update",
                "_index": self.es_index.index_name,
                "_id": record_id,
                "doc": doc,
                "doc_as_upsert": False,  # Don't create the document if it doesn't exist
            }
            bulk_operations.append(action)

        # Execute the bulk update
        from elasticsearch import helpers

        try:
            # Use the bulk helper for the update
            success_count, errors = helpers.bulk(
                self.es_client,
                bulk_operations,
                refresh=self.refresh,
                raise_on_error=False,
                stats_only=False,
            )

            # Process results to determine which updates succeeded
            results = []
            error_dict = {}
            if errors:
                for err in errors:
                    if "update" in err:
                        error_dict[err["update"]["_id"]] = err

            for record_id, _ in updates:
                # Check whether this ID had an error
                if record_id in error_dict:
                    error = error_dict[record_id]
                    # Only a 200 status counts as success for an update;
                    # a 404 (not found) is a failure
                    status = error.get("update", {}).get("status")
                    results.append(status == 200)
                else:
                    results.append(True)

            return results

        except Exception as e:
            # Check if this is a BulkIndexError from the helpers module
            if hasattr(e, "errors"):
                # Mark only the failed operations as unsuccessful
                failed_ids = {err["update"]["_id"] for err in e.errors}
                return [record_id not in failed_ids for record_id, _ in updates]

            # If the bulk operation fails completely, mark all as failed
            return [False] * len(updates)

    def _build_complex_es_query(self, condition: Any) -> dict[str, Any]:
        """Build an Elasticsearch query from complex boolean logic conditions.

        Args:
            condition: The Condition object (LogicCondition or FilterCondition)

        Returns:
            Elasticsearch query dict
        """
        from ..query_logic import FilterCondition, LogicCondition, LogicOperator

        # Handle FilterCondition (leaf node)
        if isinstance(condition, FilterCondition):
            return self._build_filter_es_query(condition.filter)

        # Handle LogicCondition (branch node)
        if isinstance(condition, LogicCondition):
            if condition.operator == LogicOperator.AND:
                # Build an AND query with must clauses
                must_clauses = []
                for sub_condition in condition.conditions:
                    sub_query = self._build_complex_es_query(sub_condition)
                    if sub_query:
                        must_clauses.append(sub_query)

                if not must_clauses:
                    return {"match_all": {}}
                if len(must_clauses) == 1:
                    return must_clauses[0]
                return {"bool": {"must": must_clauses}}

            if condition.operator == LogicOperator.OR:
                # Build an OR query with should clauses
                should_clauses = []
                for sub_condition in condition.conditions:
                    sub_query = self._build_complex_es_query(sub_condition)
                    if sub_query:
                        should_clauses.append(sub_query)

                if not should_clauses:
                    return {"match_all": {}}
                if len(should_clauses) == 1:
                    return should_clauses[0]
                return {"bool": {"should": should_clauses, "minimum_should_match": 1}}

            if condition.operator == LogicOperator.NOT:
                # Build a NOT query with must_not
                if condition.conditions:
                    sub_query = self._build_complex_es_query(condition.conditions[0])
                    if sub_query:
                        return {"bool": {"must_not": sub_query}}
                return {"match_all": {}}

        return {"match_all": {}}

    def _build_filter_es_query(self, filter_obj: Any) -> dict[str, Any]:
        """Build an Elasticsearch query for a single filter.

        Args:
            filter_obj: The Filter object

        Returns:
            Elasticsearch query dict for the filter
        """
        field_path = f"data.{filter_obj.field}"

        # For string fields in exact-match queries, use the .keyword suffix
        if filter_obj.operator in [Operator.EQ, Operator.NEQ, Operator.IN, Operator.NOT_IN]:
            if isinstance(filter_obj.value, str) or (
                isinstance(filter_obj.value, list)
                and filter_obj.value
                and isinstance(filter_obj.value[0], str)
            ):
                field_path = f"{field_path}.keyword"
        elif filter_obj.operator == Operator.LIKE:
            # Wildcard queries need .keyword for proper matching
            if isinstance(filter_obj.value, str):
                field_path = f"{field_path}.keyword"

        if filter_obj.operator == Operator.EQ:
            value = str(filter_obj.value).lower() if isinstance(filter_obj.value, bool) else filter_obj.value
            return {"term": {field_path: value}}
        elif filter_obj.operator == Operator.NEQ:
            value = str(filter_obj.value).lower() if isinstance(filter_obj.value, bool) else filter_obj.value
            return {"bool": {"must_not": {"term": {field_path: value}}}}
        elif filter_obj.operator == Operator.GT:
            return {"range": {field_path: {"gt": filter_obj.value}}}
        elif filter_obj.operator == Operator.GTE:
            return {"range": {field_path: {"gte": filter_obj.value}}}
        elif filter_obj.operator == Operator.LT:
            return {"range": {field_path: {"lt": filter_obj.value}}}
        elif filter_obj.operator == Operator.LTE:
            return {"range": {field_path: {"lte": filter_obj.value}}}
        elif filter_obj.operator == Operator.LIKE:
            pattern = filter_obj.value.replace("%", "*").replace("_", "?")
            return {"wildcard": {field_path: pattern}}
        elif filter_obj.operator == Operator.IN:
            return {"terms": {field_path: filter_obj.value}}
        elif filter_obj.operator == Operator.NOT_IN:
            return {"bool": {"must_not": {"terms": {field_path: filter_obj.value}}}}
        elif filter_obj.operator == Operator.EXISTS:
            return {"exists": {"field": field_path}}
        elif filter_obj.operator == Operator.NOT_EXISTS:
            return {"bool": {"must_not": {"exists": {"field": field_path}}}}
        elif filter_obj.operator == Operator.REGEX:
            return {"regexp": {field_path: filter_obj.value}}
        elif filter_obj.operator == Operator.BETWEEN:
            if isinstance(filter_obj.value, (list, tuple)) and len(filter_obj.value) == 2:
                lower, upper = filter_obj.value
                return {"range": {field_path: {"gte": lower, "lte": upper}}}
        elif filter_obj.operator == Operator.NOT_BETWEEN:
            if isinstance(filter_obj.value, (list, tuple)) and len(filter_obj.value) == 2:
                lower, upper = filter_obj.value
                return {"bool": {"must_not": {"range": {field_path: {"gte": lower, "lte": upper}}}}}

        return {"match_all": {}}

    def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching a query."""
        # Handle ComplexQuery with native Elasticsearch bool queries
        if isinstance(query, ComplexQuery):
            if query.condition:
                es_query = self._build_complex_es_query(query.condition)
            else:
                es_query = {"match_all": {}}
        else:
            # Build an Elasticsearch query from the simple Query object
            es_query = {"bool": {"must": []}}

            # Apply filters
            for filter_obj in query.filters:
                field_path = f"data.{filter_obj.field}"

                # For string fields in exact-match queries, use the .keyword suffix.
                # LIKE and REGEX need to use the text field, not keyword.
                if filter_obj.operator in [Operator.EQ, Operator.NEQ, Operator.IN, Operator.NOT_IN]:
                    if isinstance(filter_obj.value, str) or (
                        isinstance(filter_obj.value, list)
                        and filter_obj.value
                        and isinstance(filter_obj.value[0], str)
                    ):
                        field_path = f"{field_path}.keyword"
                elif filter_obj.operator == Operator.LIKE:
                    # Wildcard queries need .keyword for proper matching
                    if isinstance(filter_obj.value, str):
                        field_path = f"{field_path}.keyword"

                if filter_obj.operator == Operator.EQ:
                    # Handle boolean values correctly
                    value = str(filter_obj.value).lower() if isinstance(filter_obj.value, bool) else filter_obj.value
                    es_query["bool"]["must"].append({"term": {field_path: value}})
                elif filter_obj.operator == Operator.NEQ:
                    value = str(filter_obj.value).lower() if isinstance(filter_obj.value, bool) else filter_obj.value
                    es_query["bool"]["must"].append({"bool": {"must_not": {"term": {field_path: value}}}})
                elif filter_obj.operator == Operator.GT:
                    es_query["bool"]["must"].append({"range": {field_path: {"gt": filter_obj.value}}})
                elif filter_obj.operator == Operator.GTE:
                    es_query["bool"]["must"].append({"range": {field_path: {"gte": filter_obj.value}}})
                elif filter_obj.operator == Operator.LT:
                    es_query["bool"]["must"].append({"range": {field_path: {"lt": filter_obj.value}}})
                elif filter_obj.operator == Operator.LTE:
                    es_query["bool"]["must"].append({"range": {field_path: {"lte": filter_obj.value}}})
                elif filter_obj.operator == Operator.LIKE:
                    # Convert the SQL LIKE pattern to an Elasticsearch wildcard.
                    # field_path already has .keyword appended above for strings.
                    pattern = filter_obj.value.replace("%", "*").replace("_", "?")
                    es_query["bool"]["must"].append({"wildcard": {field_path: pattern}})
                elif filter_obj.operator == Operator.IN:
                    es_query["bool"]["must"].append({"terms": {field_path: filter_obj.value}})
                elif filter_obj.operator == Operator.NOT_IN:
                    es_query["bool"]["must"].append({"bool": {"must_not": {"terms": {field_path: filter_obj.value}}}})
                elif filter_obj.operator == Operator.EXISTS:
                    es_query["bool"]["must"].append({"exists": {"field": field_path}})
                elif filter_obj.operator == Operator.NOT_EXISTS:
                    es_query["bool"]["must"].append({"bool": {"must_not": {"exists": {"field": field_path}}}})
                elif filter_obj.operator == Operator.REGEX:
                    es_query["bool"]["must"].append({"regexp": {field_path: filter_obj.value}})
                elif filter_obj.operator == Operator.BETWEEN:
                    # Use Elasticsearch's native range query for an efficient BETWEEN
                    if isinstance(filter_obj.value, (list, tuple)) and len(filter_obj.value) == 2:
                        lower, upper = filter_obj.value
                        es_query["bool"]["must"].append(
                            {"range": {field_path: {"gte": lower, "lte": upper}}}
                        )
                elif filter_obj.operator == Operator.NOT_BETWEEN:
                    # NOT BETWEEN using bool must_not with a range query
                    if isinstance(filter_obj.value, (list, tuple)) and len(filter_obj.value) == 2:
                        lower, upper = filter_obj.value
                        es_query["bool"]["must"].append(
                            {"bool": {"must_not": {"range": {field_path: {"gte": lower, "lte": upper}}}}}
                        )

            # If there are no filters, match all
            if not es_query["bool"]["must"]:
                es_query = {"match_all": {}}

        # Build the sort clause
        sort = []
        if query.sort_specs:
            for sort_spec in query.sort_specs:
                field_path = f"data.{sort_spec.field}"
                # Don't add .keyword if the user already specified it or for common
                # numeric fields. This is a heuristic - ideally we'd check the mapping.
                numeric_fields = [
                    "age", "salary", "balance", "count", "score", "amount",
                    "price", "index", "id", "number", "total", "quantity",
                ]
                if (
                    not sort_spec.field.endswith(".keyword")
                    and not sort_spec.field.endswith(".raw")
                    and sort_spec.field.lower() not in numeric_fields
                ):
                    # Likely a text field; add .keyword for sorting
                    field_path = f"data.{sort_spec.field}.keyword"
                order = "desc" if sort_spec.order == SortOrder.DESC else "asc"
                sort.append({field_path: {"order": order}})

        # Build the search body
        search_body = {"query": es_query}
        if sort:
            search_body["sort"] = sort
        if query.limit_value:
            search_body["size"] = query.limit_value
        if query.offset_value:
            search_body["from"] = query.offset_value

        # Execute the search
        response = self.es_index.search(body=search_body)

        # Check that the response is valid (has the expected structure).
        # An empty result set is still a valid response.
        if not hasattr(response, "json") or response.json is None:
            raise DatabaseError(f"Invalid search response: {response}")

        # Check for actual errors in the response
        if "error" in response.json:
            raise DatabaseError(f"Failed to search records: {response.json['error']}")

        # Parse the results
        records = []
        hits = response.json.get("hits", {}).get("hits", [])
        for hit in hits:
            doc = hit.get("_source", {})
            records.append(self._doc_to_record(doc))

        # Apply field projection if specified
        if query.fields:
            for record in records:
                # Keep only the specified fields
                field_names = list(record.fields.keys())
                for field_name in field_names:
                    if field_name not in query.fields:
                        del record.fields[field_name]

        return records

    def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()
        return self.es_index.count()

    def count(self, query: Query | None = None) -> int:
        """Count records matching a query using Elasticsearch's efficient count API.

        Args:
            query: Optional search query (counts all records if None)

        Returns:
            Number of matching records
        """
        if not query or not query.filters:
            return self._count_all()

        # Build an Elasticsearch query from the Query object (same as search)
        es_query = {"bool": {"must": []}}

        for filter_obj in query.filters:
            field_path = f"data.{filter_obj.field}"

            # For string fields in exact-match queries, use the .keyword suffix.
            # LIKE and REGEX need different handling.
            if filter_obj.operator in [Operator.EQ, Operator.NEQ, Operator.IN, Operator.NOT_IN]:
                if isinstance(filter_obj.value, str) or (
                    isinstance(filter_obj.value, list)
                    and filter_obj.value
                    and isinstance(filter_obj.value[0], str)
                ):
                    field_path = f"{field_path}.keyword"
            elif filter_obj.operator == Operator.LIKE:
                # Wildcard queries need .keyword for proper matching
                if isinstance(filter_obj.value, str):
                    field_path = f"{field_path}.keyword"

            if filter_obj.operator == Operator.EQ:
                # Handle boolean values correctly
                value = str(filter_obj.value).lower() if isinstance(filter_obj.value, bool) else filter_obj.value
                es_query["bool"]["must"].append({"term": {field_path: value}})
            elif filter_obj.operator == Operator.NEQ:
                value = str(filter_obj.value).lower() if isinstance(filter_obj.value, bool) else filter_obj.value
                es_query["bool"]["must"].append({"bool": {"must_not": {"term": {field_path: value}}}})
            elif filter_obj.operator == Operator.GT:
                es_query["bool"]["must"].append({"range": {field_path: {"gt": filter_obj.value}}})
            elif filter_obj.operator == Operator.GTE:
                es_query["bool"]["must"].append({"range": {field_path: {"gte": filter_obj.value}}})
            elif filter_obj.operator == Operator.LT:
                es_query["bool"]["must"].append({"range": {field_path: {"lt": filter_obj.value}}})
            elif filter_obj.operator == Operator.LTE:
                es_query["bool"]["must"].append({"range": {field_path: {"lte": filter_obj.value}}})
            elif filter_obj.operator == Operator.LIKE:
                pattern = filter_obj.value.replace("%", "*").replace("_", "?")
                es_query["bool"]["must"].append({"wildcard": {field_path: pattern}})
            elif filter_obj.operator == Operator.IN:
                es_query["bool"]["must"].append({"terms": {field_path: filter_obj.value}})
            elif filter_obj.operator == Operator.NOT_IN:
                es_query["bool"]["must"].append({"bool": {"must_not": {"terms": {field_path: filter_obj.value}}}})
            elif filter_obj.operator == Operator.EXISTS:
                es_query["bool"]["must"].append({"exists": {"field": field_path}})
            elif filter_obj.operator == Operator.NOT_EXISTS:
                es_query["bool"]["must"].append({"bool": {"must_not": {"exists": {"field": field_path}}}})
            elif filter_obj.operator == Operator.REGEX:
                es_query["bool"]["must"].append({"regexp": {field_path: filter_obj.value}})

        # If no filters were added, use match_all
        if not es_query["bool"]["must"]:
            es_query = {"match_all": {}}

        # Count with the query
        return self.es_index.count(body={"query": es_query})

    def clear(self) -> int:
        """Clear all records from the database."""
        self._check_connection()
        # Get the count before deletion
        count = self._count_all()

        # Delete all documents by query
        response = self.es_index.delete_by_query(
            body={"query": {"match_all": {}}}
        )

        # Refresh if needed
        if self.refresh:
            self.es_index.refresh()

        return response.get("deleted", count)

    def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None,
    ) -> Iterator[Record]:
        """Stream records from Elasticsearch."""
        config = config or StreamConfig()

        # Use search to get all matching records
        if query:
            records = self.search(query)
        else:
            records = self.search(Query())

        # Yield records in batches for consistency
        for i in range(0, len(records), config.batch_size):
            batch = records[i:i + config.batch_size]
            yield from batch

    def stream_write(
        self,
        records: Iterator[Record],
        config: StreamConfig | None = None,
    ) -> StreamResult:
        """Stream records into Elasticsearch."""
        # Use the default implementation from the streaming mixin
        return self._default_stream_write(records, config)

    def vector_search(
        self,
        query_vector: np.ndarray | list[float],
        field_name: str = "embedding",
        k: int = 10,
        metric: DistanceMetric = DistanceMetric.COSINE,
        filter: Query | None = None,
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> list[VectorSearchResult]:
        """Search for similar vectors using Elasticsearch KNN.

        Note: This is a synchronous wrapper around the async implementation.
        For production use, consider the async version for better performance.

        Args:
            query_vector: The vector to search for
            field_name: Name of the vector field to search
            k: Number of results to return
            metric: Distance metric to use
            filter: Optional query filter to apply before the vector search
            include_source: Whether to include the source document in results
            score_threshold: Optional minimum similarity score

        Returns:
            List of search results ordered by similarity
        """
        self._check_connection()

        # Import vector utilities
        from ..vector.elasticsearch_utils import build_knn_query

        # Build the filter query if provided
        filter_query = self._build_filter_query(filter) if filter else None

        # Build the KNN query
        query = build_knn_query(
            query_vector=query_vector,
            field_name=field_name,
            k=k,
            filter_query=filter_query,
        )
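
        # build_knn_query is expected to return search kwargs along the lines of
        #   {"knn": {"field": "data.embedding", "query_vector": [...],
        #            "k": 10, "num_candidates": ...}}
        # The exact shape depends on ..vector.elasticsearch_utils and is an
        # assumption here, not a documented contract.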

        # Execute the search using the es_client
        try:
            response = self.es_client.search(
                index=self.index_name,
                **query,
                size=k,
                source=include_source,
            )
        except Exception as e:
            self._handle_elasticsearch_error(e, "vector search")
            return []

        # Process the results
        results = []
        for hit in response.get("hits", {}).get("hits", []):
            score = hit.get("_score", 0.0)

            # Apply the score threshold if specified
            if score_threshold is not None and score < score_threshold:
                continue

            # Convert the document to a record if the source is included
            record = None
            if include_source:
                record = self._doc_to_record(hit)
                # Set the storage ID on the record if we have one
                if not record.has_storage_id():
                    record.storage_id = hit["_id"]

            # Skip if there is no record (shouldn't happen if include_source is True)
            if record is None:
                continue

            results.append(VectorSearchResult(
                record=record,
                score=score,
                vector_field=field_name,
                metadata={
                    "index": self.index_name,
                    "metric": metric.value,
                    "doc_id": hit["_id"],
                },
            ))

        return results

    def create_vector_index(
        self,
        vector_field: str = "embedding",
        dimensions: int | None = None,
        metric: DistanceMetric = DistanceMetric.COSINE,
        index_type: str = "auto",
        **kwargs: Any,
    ) -> bool:
        """Create or update the index mapping for a vector field.

        Args:
            vector_field: Name of the vector field to index
            dimensions: Number of dimensions
            metric: Distance metric for the index
            index_type: Type of index (ignored for ES, which always uses HNSW)
            **kwargs: Additional index parameters

        Returns:
            True if the index was created/updated successfully
        """
        self._check_connection()

        if not dimensions:
            if vector_field not in self.vector_fields:
                raise ValueError(f"Unknown dimensions for field '{vector_field}'")
            dimensions = self.vector_fields[vector_field]

        # Import vector utilities
        from ..vector.elasticsearch_utils import (
            get_similarity_for_metric,
            get_vector_mapping,
        )

        # Get the similarity function for the metric
        similarity = get_similarity_for_metric(metric)

        # Build the mapping for the vector field
        mapping = get_vector_mapping(dimensions, similarity)
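
        # get_vector_mapping is assumed to produce a standard dense_vector
        # mapping such as
        #   {"type": "dense_vector", "dims": 1536, "index": True,
        #    "similarity": "cosine"}
        # (illustrative; the actual shape comes from ..vector.elasticsearch_utils).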

        # Update the index mapping using the es_client
        try:
            self.es_client.indices.put_mapping(
                index=self.index_name,
                properties={
                    f"data.{vector_field}": mapping,
                },
            )

            # Track the vector field
            self.vector_fields[vector_field] = dimensions
            self._vector_enabled = True

            logger.info(
                f"Created vector mapping for field '{vector_field}' "
                f"with {dimensions} dimensions"
            )
            return True

        except Exception as e:
            self._handle_elasticsearch_error(e, "create vector index")
            return False
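

# A minimal usage sketch (illustrative only; assumes a local Elasticsearch
# instance on localhost:9200 and elides construction of Record objects --
# adjust to the actual Record/Query APIs):
#
#     db = SyncElasticsearchDatabase({"host": "localhost", "index": "records"})
#     db.connect()
#     record_id = db.create(some_record)   # some_record: a Record instance
#     fetched = db.read(record_id)
#     db.close()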

# Import the native async implementation