Coverage for src/dataknobs_data/backends/elasticsearch_async.py: 13%
326 statements
coverage.py v7.11.0, created at 2025-10-29 14:14 -0600

"""Native async Elasticsearch backend implementation with connection pooling."""

from __future__ import annotations

import logging
import time
from typing import TYPE_CHECKING, Any, cast

from dataknobs_config import ConfigurableBase

from ..database import AsyncDatabase
from ..pooling import ConnectionPoolManager
from ..pooling.elasticsearch import (
    ElasticsearchPoolConfig,
    close_elasticsearch_client,
    create_async_elasticsearch_client,
    validate_elasticsearch_client,
)
from ..query import Operator, Query, SortOrder
from ..streaming import StreamConfig, StreamResult, async_process_batch_with_fallback
from ..vector.mixins import VectorOperationsMixin
from ..vector.types import DistanceMetric, VectorSearchResult
from .elasticsearch_mixins import (
    ElasticsearchBaseConfig,
    ElasticsearchErrorHandler,
    ElasticsearchIndexManager,
    ElasticsearchQueryBuilder,
    ElasticsearchRecordSerializer,
    ElasticsearchVectorSupport,
)

if TYPE_CHECKING:
    from collections.abc import AsyncIterator, Awaitable, Callable

    import numpy as np

    from ..records import Record

logger = logging.getLogger(__name__)

# Global pool manager for Elasticsearch clients
_client_manager = ConnectionPoolManager()

class AsyncElasticsearchDatabase(
    AsyncDatabase,
    ConfigurableBase,
    VectorOperationsMixin,
    ElasticsearchBaseConfig,
    ElasticsearchIndexManager,
    ElasticsearchVectorSupport,
    ElasticsearchErrorHandler,
    ElasticsearchRecordSerializer,
    ElasticsearchQueryBuilder,
):
    """Native async Elasticsearch database backend with connection pooling."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async Elasticsearch database."""
        super().__init__(config)

        # Initialize vector support
        self.vector_fields = {}  # field_name -> dimensions
        self.vector_enabled = False

        config = config or {}
        self._pool_config = ElasticsearchPoolConfig.from_dict(config)
        self.index_name = self._pool_config.index
        self.refresh = config.get("refresh", True)
        self._client = None
        self._connected = False

    @classmethod
    def from_config(cls, config: dict) -> AsyncElasticsearchDatabase:
        """Create from config dictionary."""
        return cls(config)

    async def connect(self) -> None:
        """Connect to the Elasticsearch database."""
        if self._connected:
            return

        # Get or create client for current event loop
        from ..pooling import BasePoolConfig

        self._client = await _client_manager.get_pool(
            self._pool_config,
            cast("Callable[[BasePoolConfig], Awaitable[Any]]", create_async_elasticsearch_client),
            validate_elasticsearch_client,
            close_elasticsearch_client,
        )

        # Ensure index exists
        await self._ensure_index()
        self._connected = True

    async def close(self) -> None:
        """Close the database connection."""
        if self._connected:
            # Note: The client is managed by the pool manager, so we don't close it here.
            # Just mark as disconnected.
            self._client = None
            self._connected = False
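
    # Connection lifecycle sketch; illustrative, not part of the module. The
    # config keys shown ("index", "refresh") are the ones this class reads
    # directly; host/auth settings are assumed to be consumed by
    # ElasticsearchPoolConfig.from_dict.
    #
    #     db = AsyncElasticsearchDatabase({"index": "my-index", "refresh": True})
    #     await db.connect()       # pools or reuses a client for this event loop
    #     try:
    #         ...                  # CRUD, search, streaming
    #     finally:
    #         await db.close()     # drops our reference; the pool owns the client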

    def _initialize(self) -> None:
        """Initialization is handled in connect()."""

    async def _ensure_index(self) -> None:
        """Ensure the index exists with proper mappings."""
        if not self._client:
            raise RuntimeError("Database not connected. Call connect() first.")

        # Check if index exists
        if not await self._client.indices.exists(index=self.index_name):  # type: ignore[unreachable]
            # Get mappings with vector field support
            mappings = self.get_index_mappings(self.vector_fields)

            # Get settings optimized for KNN if we have vector fields
            settings = self.get_knn_index_settings() if self.vector_fields else {
                "number_of_shards": 1,
                "number_of_replicas": 0,
            }

            await self._client.indices.create(
                index=self.index_name,
                mappings=mappings,
                settings=settings,
            )

            if self.vector_fields:
                self.vector_enabled = True
                logger.info(f"Created index '{self.index_name}' with vector support")

    def _check_connection(self) -> None:
        """Check if database is connected."""
        if not self._connected or not self._client:
            raise RuntimeError("Database not connected. Call connect() first.")

    def _record_to_doc(self, record: Record) -> dict[str, Any]:
        """Convert a Record to an Elasticsearch document."""
        # Update vector tracking if needed
        if self._has_vector_fields(record):
            self._update_vector_tracking(record)

        # Add vector field metadata to record metadata
        if "vector_fields" not in record.metadata:
            record.metadata["vector_fields"] = {}

        for field_name in self.vector_fields:
            if field_name in record.fields:
                field = record.fields[field_name]
                if hasattr(field, "source_field"):
                    record.metadata["vector_fields"][field_name] = {
                        "type": "vector",
                        "dimensions": self.vector_fields[field_name],
                        "source_field": field.source_field,
                        "model": getattr(field, "model_name", None),
                        "model_version": getattr(field, "model_version", None),
                    }

        return self._record_to_document(record)

    def _doc_to_record(self, doc: dict[str, Any]) -> Record:
        """Convert an Elasticsearch document to a Record."""
        doc_id = doc.get("_id")
        record = self._document_to_record(doc, doc_id)

        # Add score if present
        if "_score" in doc:
            record.metadata["_score"] = doc["_score"]

        return record

    async def create(self, record: Record) -> str:
        """Create a new record."""
        self._check_connection()
        doc = self._record_to_doc(record)

        # Create document with explicit ID if record has one
        kwargs = {
            "index": self.index_name,
            "document": doc,
            "refresh": self.refresh,
        }
        if record.id:
            kwargs["id"] = record.id

        response = await self._client.index(**kwargs)

        return response["_id"]

    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records in batch."""
        self._check_connection()

        ids = []
        operations = []

        for record in records:
            doc = self._record_to_doc(record)
            operations.append({"index": {"_index": self.index_name}})
            operations.append(doc)

        if operations:
            response = await self._client.bulk(
                operations=operations,
                refresh=self.refresh,
            )

            # Extract IDs from response
            for item in response.get("items", []):
                if "index" in item and "_id" in item["index"]:
                    ids.append(item["index"]["_id"])

        return ids
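
    # Sketch contrasting create() and create_batch(); make_record() is a
    # hypothetical stand-in for however Record instances are built here.
    #
    #     one_id = await db.create(make_record({"name": "a"}))
    #     many_ids = await db.create_batch([make_record({"name": n}) for n in "abc"])
    #     assert len(many_ids) == 3    # one "_id" per bulk "index" action item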

    async def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        self._check_connection()

        try:
            response = await self._client.get(
                index=self.index_name,
                id=id,
            )
            return self._doc_to_record(response)
        except Exception as e:
            # Log the error for debugging
            logger.debug(f"Error reading document {id}: {e}")
            return None

    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record."""
        self._check_connection()
        doc = self._record_to_doc(record)

        try:
            await self._client.update(
                index=self.index_name,
                id=id,
                doc=doc,
                refresh=self.refresh,
            )
            return True
        except Exception:
            return False

    async def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        self._check_connection()

        try:
            await self._client.delete(
                index=self.index_name,
                id=id,
                refresh=self.refresh,
            )
            return True
        except Exception:
            return False

    async def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()

        return await self._client.exists(
            index=self.index_name,
            id=id,
        )

    async def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
        """Update or insert a record.

        Can be called as:
        - upsert(id, record) - explicit ID and record
        - upsert(record) - extract ID from record using Record's built-in logic
        """
        self._check_connection()

        # Determine ID and record based on arguments
        if isinstance(id_or_record, str):
            id = id_or_record
            if record is None:
                raise ValueError("Record required when ID is provided")
        else:
            record = id_or_record
            id = record.id
            if id is None:
                import uuid  # type: ignore[unreachable]

                id = str(uuid.uuid4())
                record.storage_id = id

        doc = self._record_to_doc(record)

        await self._client.index(
            index=self.index_name,
            id=id,
            document=doc,
            refresh=self.refresh,
        )

        return id
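
    # Both upsert() calling conventions from the docstring above; `rec` is an
    # illustrative Record.
    #
    #     id1 = await db.upsert("explicit-id", rec)   # explicit ID form
    #     id2 = await db.upsert(rec)                  # uses rec.id, or a fresh uuid4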

    async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using AsyncElasticsearch's bulk API.

        Args:
            updates: List of (id, record) tuples to update

        Returns:
            List of success flags for each update
        """
        if not updates:
            return []

        self._check_connection()

        # Build bulk operations for AsyncElasticsearch
        operations: list[dict[str, Any]] = []
        for record_id, record in updates:
            # Add update action
            operations.append({
                "update": {
                    "_index": self.index_name,
                    "_id": record_id,
                }
            })
            # Add document data
            doc = self._record_to_doc(record)
            operations.append({
                "doc": doc,
                "doc_as_upsert": False,  # Don't create if it doesn't exist
            })

        try:
            # Execute bulk update using AsyncElasticsearch
            response = await self._client.bulk(
                operations=operations,
                refresh=self.refresh,
            )

            # Process the response to determine which updates succeeded
            results = []
            if response.get("items"):
                for item in response["items"]:
                    if "update" in item:
                        update_result = item["update"]
                        # Status 200 means the update succeeded; 404 means not found
                        results.append(update_result.get("status") == 200)
                    else:
                        results.append(False)
            else:
                # If no items in response, mark all as failed
                results = [False] * len(updates)

            return results

        except Exception as e:
            # If the bulk operation fails outright, mark all as failed
            logger.error(f"Bulk update failed: {e}")
            return [False] * len(updates)
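
    # Per-item semantics sketch: each bulk "update" item carries an HTTP-style
    # status, so a partial failure (IDs hypothetical) looks like:
    #
    #     flags = await db.update_batch([("existing-id", rec1), ("missing-id", rec2)])
    #     # -> [True, False]: status 200 for the first item, 404 for the second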

    async def search(self, query: Query) -> list[Record]:
        """Search for records matching the query."""
        self._check_connection()

        # Build Elasticsearch query
        es_query = {"bool": {"must": []}}

        for filter in query.filters:
            field_path = f"data.{filter.field}"

            if filter.operator == Operator.EQ:
                # For string values, use the keyword sub-field for exact matching
                if isinstance(filter.value, str):
                    field_path = f"{field_path}.keyword"
                es_query["bool"]["must"].append({"term": {field_path: filter.value}})
            elif filter.operator == Operator.NEQ:
                es_query["bool"]["must_not"] = es_query["bool"].get("must_not", [])
                es_query["bool"]["must_not"].append({"term": {field_path: filter.value}})
            elif filter.operator == Operator.GT:
                es_query["bool"]["must"].append({"range": {field_path: {"gt": filter.value}}})
            elif filter.operator == Operator.LT:
                es_query["bool"]["must"].append({"range": {field_path: {"lt": filter.value}}})
            elif filter.operator == Operator.GTE:
                es_query["bool"]["must"].append({"range": {field_path: {"gte": filter.value}}})
            elif filter.operator == Operator.LTE:
                es_query["bool"]["must"].append({"range": {field_path: {"lte": filter.value}}})
            elif filter.operator == Operator.LIKE:
                es_query["bool"]["must"].append({"wildcard": {field_path: f"*{filter.value}*"}})
            elif filter.operator == Operator.IN:
                es_query["bool"]["must"].append({"terms": {field_path: filter.value}})
            elif filter.operator == Operator.NOT_IN:
                es_query["bool"]["must_not"] = es_query["bool"].get("must_not", [])
                es_query["bool"]["must_not"].append({"terms": {field_path: filter.value}})
            elif filter.operator == Operator.BETWEEN:
                # Use Elasticsearch's native range query for efficient BETWEEN
                if isinstance(filter.value, (list, tuple)) and len(filter.value) == 2:
                    lower, upper = filter.value
                    es_query["bool"]["must"].append({
                        "range": {
                            field_path: {
                                "gte": lower,
                                "lte": upper,
                            }
                        }
                    })
            elif filter.operator == Operator.NOT_BETWEEN:
                # NOT BETWEEN using must_not with range
                if isinstance(filter.value, (list, tuple)) and len(filter.value) == 2:
                    lower, upper = filter.value
                    es_query["bool"]["must_not"] = es_query["bool"].get("must_not", [])
                    es_query["bool"]["must_not"].append({
                        "range": {
                            field_path: {
                                "gte": lower,
                                "lte": upper,
                            }
                        }
                    })

        # If no filters, use match_all
        if not es_query["bool"]["must"] and "must_not" not in es_query["bool"]:
            es_query = {"match_all": {}}

        # Build sort
        sort = []
        if query.sort_specs:
            for sort_spec in query.sort_specs:
                direction = "desc" if sort_spec.order == SortOrder.DESC else "asc"
                sort.append({f"data.{sort_spec.field}": {"order": direction}})

        # Add size and from for pagination
        size = query.limit_value if query.limit_value else 10000
        from_param = query.offset_value if query.offset_value else 0

        # Execute search
        response = await self._client.search(
            index=self.index_name,
            query=es_query,
            sort=sort if sort else None,
            size=size,
            from_=from_param,
        )

        # Convert to records
        records = []
        for hit in response["hits"]["hits"]:
            record = self._doc_to_record(hit)

            # Apply field projection if specified
            if query.fields:
                record = record.project(query.fields)

            records.append(record)

        return records
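
    # For reference, a Query carrying EQ("status", "active") and
    # BETWEEN("age", (18, 65)) filters is translated by the loop above into
    # roughly this Elasticsearch body (note the ".keyword" sub-field used for
    # exact string matches):
    #
    #     {"bool": {"must": [
    #         {"term": {"data.status.keyword": "active"}},
    #         {"range": {"data.age": {"gte": 18, "lte": 65}}},
    #     ]}}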

    async def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()

        response = await self._client.count(index=self.index_name)
        return response["count"]

    async def clear(self) -> int:
        """Clear all records from the database."""
        self._check_connection()

        # Get count before deletion
        count = await self._count_all()

        # Delete all documents by query
        response = await self._client.delete_by_query(
            index=self.index_name,
            query={"match_all": {}},
            refresh=self.refresh,
        )

        return response.get("deleted", count)

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None,
    ) -> AsyncIterator[Record]:
        """Stream records from Elasticsearch using the scroll API."""
        self._check_connection()
        config = config or StreamConfig()

        # Build query (only EQ filters are translated in the scroll path)
        es_query = {"match_all": {}}
        if query and query.filters:
            es_query = {"bool": {"must": []}}
            for filter in query.filters:
                field_path = f"data.{filter.field}"
                if filter.operator == Operator.EQ:
                    es_query["bool"]["must"].append({"term": {field_path: filter.value}})

        # Initial search with scroll
        response = await self._client.search(
            index=self.index_name,
            query=es_query,
            scroll="2m",
            size=config.batch_size,
        )

        scroll_id = response["_scroll_id"]
        hits = response["hits"]["hits"]

        try:
            while hits:
                for hit in hits:
                    record = self._doc_to_record(hit)
                    if query and query.fields:
                        record = record.project(query.fields)
                    yield record

                # Get next batch
                response = await self._client.scroll(
                    scroll_id=scroll_id,
                    scroll="2m",
                )
                hits = response["hits"]["hits"]
        finally:
            # Clear scroll
            await self._client.clear_scroll(scroll_id=scroll_id)
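
    # Consuming the scroll-backed stream; illustrative, with a hypothetical
    # per-record handler.
    #
    #     async for record in db.stream_read(config=StreamConfig(batch_size=500)):
    #         handle(record)
    #
    # The finally block above clears the scroll context when iteration ends or
    # the generator is finalized early.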

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None,
    ) -> StreamResult:
        """Stream records into Elasticsearch using the bulk API."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        batch = []
        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write batch with graceful fallback
                async def batch_func(b):
                    await self._write_batch(b)
                    return [r.id for r in b]

                continue_processing = await async_process_batch_with_fallback(
                    batch,
                    batch_func,
                    self.create,
                    result,
                    config,
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write remaining batch
        if batch and not quitting:
            async def batch_func(b):
                await self._write_batch(b)
                return [r.id for r in b]

            await async_process_batch_with_fallback(
                batch,
                batch_func,
                self.create,
                result,
                config,
            )

        result.duration = time.time() - start_time
        return result
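
    # Feeding stream_write() from an async generator; illustrative, with a
    # hypothetical record factory.
    #
    #     async def gen():
    #         for i in range(10_000):
    #             yield make_record({"i": i})
    #
    #     result = await db.stream_write(gen(), StreamConfig(batch_size=1000))
    #     print(result.duration)
    #
    # Full batches go through the bulk API; judging by its name and arguments,
    # async_process_batch_with_fallback can fall back to per-record
    # self.create calls when a bulk write fails.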

    async def _write_batch(self, records: list[Record]) -> None:
        """Write a batch of records using bulk API."""
        if not records:
            return

        # Build bulk operations
        operations = []
        for record in records:
            doc = self._record_to_doc(record)
            operations.append({"index": {"_index": self.index_name}})
            operations.append(doc)

        # Execute bulk
        await self._client.bulk(
            operations=operations,
            refresh=self.refresh,
        )

    async def vector_search(
        self,
        query_vector: np.ndarray | list[float],
        vector_field: str = "embedding",
        k: int = 10,
        metric: DistanceMetric = DistanceMetric.COSINE,
        filter: Query | None = None,
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> list[VectorSearchResult]:
        """Search for similar vectors using Elasticsearch KNN.

        Args:
            query_vector: The vector to search for
            vector_field: Name of the vector field to search
            k: Number of results to return
            metric: Distance metric to use
            filter: Optional query filter to apply before vector search
            include_source: Whether to include source document in results
            score_threshold: Optional minimum similarity score

        Returns:
            List of search results ordered by similarity
        """
        self._check_connection()

        # Import vector utilities
        from ..vector.elasticsearch_utils import build_knn_query

        # Build filter query if provided
        filter_query = self._build_filter_query(filter) if filter else None

        # Build KNN query
        query = build_knn_query(
            query_vector=query_vector,
            field_name=vector_field,
            k=k,
            filter_query=filter_query,
        )

        # Execute search
        try:
            response = await self._client.search(
                index=self.index_name,
                **query,  # Unpack the query dict directly
                size=k,
                _source=include_source,
            )
        except Exception as e:
            self._handle_elasticsearch_error(e, "vector search")
            return []

        # Process results
        results = []
        for hit in response.get("hits", {}).get("hits", []):
            score = hit.get("_score", 0.0)

            # Apply score threshold if specified
            if score_threshold is not None and score < score_threshold:
                continue

            # Convert document to record if source included
            record = None
            if include_source:
                record = self._doc_to_record(hit)

                # Set the storage ID on the record if we have one
                if record and not record.has_storage_id():
                    record.storage_id = hit["_id"]

            # Skip if no record (record stays None when include_source is False)
            if record is None:
                continue

            results.append(VectorSearchResult(
                record=record,
                score=score,
                vector_field=vector_field,
                metadata={
                    "index": self.index_name,
                    "metric": metric.value,
                    "doc_id": hit["_id"],
                },
            ))

        return results
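
    # Illustrative KNN call; assumes a 384-dimension "embedding" field has
    # already been mapped (see create_vector_index below).
    #
    #     hits = await db.vector_search(
    #         query_vector=[0.1] * 384,
    #         vector_field="embedding",
    #         k=5,
    #         score_threshold=0.7,
    #     )
    #     for h in hits:
    #         print(h.score, h.metadata["doc_id"])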

    async def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Any | None = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records.

        Args:
            records: Records to process
            text_field: Field name(s) containing text to embed
            vector_field: Field name to store vectors in
            embedding_fn: Function to generate embeddings
            batch_size: Number of records to process at once
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            List of record IDs that were processed
        """
        # This is a stub implementation.
        # A full implementation would require an actual embedding function.
        logger.warning("bulk_embed_and_store is not fully implemented for Elasticsearch")
        return []

    async def create_vector_index(
        self,
        vector_field: str = "embedding",
        dimensions: int | None = None,
        metric: DistanceMetric = DistanceMetric.COSINE,
        index_type: str = "auto",
        **kwargs: Any,
    ) -> bool:
        """Create or update index mapping for vector field.

        Args:
            vector_field: Name of the vector field to index
            dimensions: Number of dimensions
            metric: Distance metric for the index
            index_type: Type of index (ignored for ES, which always uses HNSW)
            **kwargs: Additional index parameters

        Returns:
            True if index was created/updated successfully
        """
        self._check_connection()

        if not dimensions:
            if vector_field not in self.vector_fields:
                raise ValueError(f"Unknown dimensions for field '{vector_field}'")
            dimensions = self.vector_fields[vector_field]

        # Import vector utilities
        from ..vector.elasticsearch_utils import (
            get_similarity_for_metric,
            get_vector_mapping,
        )

        # Get similarity function for metric
        similarity = get_similarity_for_metric(metric)

        # Build mapping for the vector field
        mapping = get_vector_mapping(dimensions, similarity)

        # Update index mapping
        try:
            await self._client.indices.put_mapping(
                index=self.index_name,
                properties={
                    f"data.{vector_field}": mapping,
                },
            )

            # Track the vector field
            self.vector_fields[vector_field] = dimensions
            self.vector_enabled = True

            logger.info(
                f"Created vector mapping for field '{vector_field}' "
                f"with {dimensions} dimensions"
            )
            return True

        except Exception as e:
            self._handle_elasticsearch_error(e, "create vector index")
            return False
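
# Minimal end-to-end vector sketch tying the pieces together; everything here
# is illustrative, and the dimension count is an assumption.
#
#     db = AsyncElasticsearchDatabase({"index": "docs"})
#     await db.connect()
#     await db.create_vector_index("embedding", dimensions=384)
#     # ... store records whose "embedding" field holds 384-dim vectors ...
#     results = await db.vector_search([0.0] * 384, "embedding", k=10)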