Coverage for src/dataknobs_data/backends/elasticsearch_async.py: 14%
316 statements
1"""Native async Elasticsearch backend implementation with connection pooling."""
3from __future__ import annotations
5import logging
6import time
7from typing import TYPE_CHECKING, Any, cast
9from dataknobs_config import ConfigurableBase
11from ..database import AsyncDatabase
12from ..pooling import ConnectionPoolManager
13from ..pooling.elasticsearch import (
14 ElasticsearchPoolConfig,
15 close_elasticsearch_client,
16 create_async_elasticsearch_client,
17 validate_elasticsearch_client,
18)
19from ..query import Operator, Query, SortOrder
20from ..streaming import StreamConfig, StreamResult, async_process_batch_with_fallback
21from ..vector.mixins import VectorOperationsMixin
22from ..vector.types import DistanceMetric, VectorSearchResult
23from .elasticsearch_mixins import (
24 ElasticsearchBaseConfig,
25 ElasticsearchErrorHandler,
26 ElasticsearchIndexManager,
27 ElasticsearchQueryBuilder,
28 ElasticsearchRecordSerializer,
29 ElasticsearchVectorSupport,
30)
32if TYPE_CHECKING:
33 import numpy as np
34 from collections.abc import AsyncIterator, Callable, Awaitable
35 from ..records import Record
37logger = logging.getLogger(__name__)
# Global pool manager for Elasticsearch clients
_client_manager = ConnectionPoolManager()
class AsyncElasticsearchDatabase(
    AsyncDatabase,
    ConfigurableBase,
    VectorOperationsMixin,
    ElasticsearchBaseConfig,
    ElasticsearchIndexManager,
    ElasticsearchVectorSupport,
    ElasticsearchErrorHandler,
    ElasticsearchRecordSerializer,
    ElasticsearchQueryBuilder,
):
    """Native async Elasticsearch database backend with connection pooling."""
    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async Elasticsearch database."""
        super().__init__(config)

        # Initialize vector support
        self.vector_fields = {}  # field_name -> dimensions
        self.vector_enabled = False

        config = config or {}
        self._pool_config = ElasticsearchPoolConfig.from_dict(config)
        self.index_name = self._pool_config.index
        self.refresh = config.get("refresh", True)
        self._client = None
        self._connected = False
    @classmethod
    def from_config(cls, config: dict) -> AsyncElasticsearchDatabase:
        """Create from config dictionary."""
        return cls(config)
    async def connect(self) -> None:
        """Connect to the Elasticsearch database."""
        if self._connected:
            return

        # Get or create a client for the current event loop
        from ..pooling import BasePoolConfig

        self._client = await _client_manager.get_pool(
            self._pool_config,
            cast("Callable[[BasePoolConfig], Awaitable[Any]]", create_async_elasticsearch_client),
            validate_elasticsearch_client,
            close_elasticsearch_client,
        )

        # Ensure index exists
        await self._ensure_index()
        self._connected = True
    async def close(self) -> None:
        """Close the database connection."""
        if self._connected:
            # Note: the client is managed by the pool manager, so we don't close it
            # here; just drop our reference and mark the database as disconnected.
            self._client = None
            self._connected = False
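    # Typical lifecycle (a sketch; config keys other than "refresh" are parsed by
    # ElasticsearchPoolConfig.from_dict, and the exact "index" key is an assumption):
    #
    #     db = AsyncElasticsearchDatabase({"index": "my-index"})
    #     await db.connect()
    #     try:
    #         ...  # CRUD, search, and streaming calls
    #     finally:
    #         await db.close()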
    def _initialize(self) -> None:
        """Initialization is handled in connect()."""
    async def _ensure_index(self) -> None:
        """Ensure the index exists with proper mappings."""
        if not self._client:
            raise RuntimeError("Database not connected. Call connect() first.")

        # Check if index exists
        if not await self._client.indices.exists(index=self.index_name):  # type: ignore[unreachable]
            # Get mappings with vector field support
            mappings = self.get_index_mappings(self.vector_fields)

            # Get settings optimized for KNN if we have vector fields
            settings = self.get_knn_index_settings() if self.vector_fields else {
                "number_of_shards": 1,
                "number_of_replicas": 0,
            }

            await self._client.indices.create(
                index=self.index_name,
                mappings=mappings,
                settings=settings,
            )

            if self.vector_fields:
                self.vector_enabled = True
                logger.info(f"Created index '{self.index_name}' with vector support")
    def _check_connection(self) -> None:
        """Check if database is connected."""
        if not self._connected or not self._client:
            raise RuntimeError("Database not connected. Call connect() first.")
    def _record_to_doc(self, record: Record) -> dict[str, Any]:
        """Convert a Record to an Elasticsearch document."""
        # Update vector tracking if needed
        if self._has_vector_fields(record):
            self._update_vector_tracking(record)

            # Add vector field metadata to record metadata
            if "vector_fields" not in record.metadata:
                record.metadata["vector_fields"] = {}

            for field_name in self.vector_fields:
                if field_name in record.fields:
                    field = record.fields[field_name]
                    if hasattr(field, "source_field"):
                        record.metadata["vector_fields"][field_name] = {
                            "type": "vector",
                            "dimensions": self.vector_fields[field_name],
                            "source_field": field.source_field,
                            "model": getattr(field, "model_name", None),
                            "model_version": getattr(field, "model_version", None),
                        }

        return self._record_to_document(record)
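    # Note: serialized documents are expected to nest the record's fields under a
    # top-level "data" key; the query builders below rely on this via the
    # "data.{field}" paths used in search() and stream_read().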
    def _doc_to_record(self, doc: dict[str, Any]) -> Record:
        """Convert an Elasticsearch document to a Record."""
        doc_id = doc.get("_id")
        record = self._document_to_record(doc, doc_id)

        # Add score if present
        if "_score" in doc:
            record.metadata["_score"] = doc["_score"]

        return record
    async def create(self, record: Record) -> str:
        """Create a new record."""
        self._check_connection()
        doc = self._record_to_doc(record)

        # Create document with explicit ID if record has one
        kwargs = {
            "index": self.index_name,
            "document": doc,
            "refresh": self.refresh,
        }
        if record.id:
            kwargs["id"] = record.id

        response = await self._client.index(**kwargs)
        return response["_id"]
    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records in batch."""
        self._check_connection()

        ids = []
        operations = []

        for record in records:
            doc = self._record_to_doc(record)
            operations.append({"index": {"_index": self.index_name}})
            operations.append(doc)

        if operations:
            response = await self._client.bulk(
                operations=operations,
                refresh=self.refresh,
            )

            # Extract IDs from response
            for item in response.get("items", []):
                if "index" in item and "_id" in item["index"]:
                    ids.append(item["index"]["_id"])

        return ids
    async def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        self._check_connection()

        try:
            response = await self._client.get(
                index=self.index_name,
                id=id,
            )
            return self._doc_to_record(response)
        except Exception as e:
            # Log the error for debugging
            logger.debug(f"Error reading document {id}: {e}")
            return None
    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record."""
        self._check_connection()
        doc = self._record_to_doc(record)

        try:
            await self._client.update(
                index=self.index_name,
                id=id,
                doc=doc,
                refresh=self.refresh,
            )
            return True
        except Exception:
            return False
    async def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        self._check_connection()

        try:
            await self._client.delete(
                index=self.index_name,
                id=id,
                refresh=self.refresh,
            )
            return True
        except Exception:
            return False
    async def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()

        return await self._client.exists(
            index=self.index_name,
            id=id,
        )
    async def upsert(self, id: str, record: Record) -> str:
        """Update or insert a record with a specific ID."""
        self._check_connection()
        doc = self._record_to_doc(record)

        await self._client.index(
            index=self.index_name,
            id=id,
            document=doc,
            refresh=self.refresh,
        )

        return id
    async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using the bulk API.

        Uses AsyncElasticsearch's bulk API for efficient batch updates.

        Args:
            updates: List of (id, record) tuples to update

        Returns:
            List of success flags for each update
        """
        if not updates:
            return []

        self._check_connection()

        # Build bulk operations for AsyncElasticsearch
        operations: list[dict[str, Any]] = []
        for record_id, record in updates:
            # Add update operation
            operations.append({
                "update": {
                    "_index": self.index_name,
                    "_id": record_id,
                }
            })
            # Add document data
            doc = self._record_to_doc(record)
            operations.append({
                "doc": doc,
                "doc_as_upsert": False,  # Don't create the document if it doesn't exist
            })

        try:
            # Execute bulk update using AsyncElasticsearch
            response = await self._client.bulk(
                operations=operations,
                refresh=self.refresh,
            )

            # Process the response to determine which updates succeeded
            results = []
            if response.get("items"):
                for item in response["items"]:
                    if "update" in item:
                        update_result = item["update"]
                        # Successful updates report status 200; missing documents report 404
                        results.append(update_result.get("status") == 200)
                    else:
                        results.append(False)
            else:
                # If there are no items in the response, mark all as failed
                results = [False] * len(updates)

            return results

        except Exception as e:
            # If the bulk operation itself fails, mark all updates as failed
            logger.error(f"Bulk update failed: {e}")
            return [False] * len(updates)
    async def search(self, query: Query) -> list[Record]:
        """Search for records matching the query."""
        self._check_connection()

        # Build Elasticsearch query
        es_query = {"bool": {"must": []}}

        for filter in query.filters:
            field_path = f"data.{filter.field}"

            if filter.operator == Operator.EQ:
                # For string values, use the keyword sub-field for exact matching
                if isinstance(filter.value, str):
                    field_path = f"{field_path}.keyword"
                es_query["bool"]["must"].append({"term": {field_path: filter.value}})
            elif filter.operator == Operator.NEQ:
                es_query["bool"]["must_not"] = es_query["bool"].get("must_not", [])
                es_query["bool"]["must_not"].append({"term": {field_path: filter.value}})
            elif filter.operator == Operator.GT:
                es_query["bool"]["must"].append({"range": {field_path: {"gt": filter.value}}})
            elif filter.operator == Operator.LT:
                es_query["bool"]["must"].append({"range": {field_path: {"lt": filter.value}}})
            elif filter.operator == Operator.GTE:
                es_query["bool"]["must"].append({"range": {field_path: {"gte": filter.value}}})
            elif filter.operator == Operator.LTE:
                es_query["bool"]["must"].append({"range": {field_path: {"lte": filter.value}}})
            elif filter.operator == Operator.LIKE:
                es_query["bool"]["must"].append({"wildcard": {field_path: f"*{filter.value}*"}})
            elif filter.operator == Operator.IN:
                es_query["bool"]["must"].append({"terms": {field_path: filter.value}})
            elif filter.operator == Operator.NOT_IN:
                es_query["bool"]["must_not"] = es_query["bool"].get("must_not", [])
                es_query["bool"]["must_not"].append({"terms": {field_path: filter.value}})
            elif filter.operator == Operator.BETWEEN:
                # Use Elasticsearch's native range query for an efficient BETWEEN
                if isinstance(filter.value, (list, tuple)) and len(filter.value) == 2:
                    lower, upper = filter.value
                    es_query["bool"]["must"].append({
                        "range": {
                            field_path: {
                                "gte": lower,
                                "lte": upper,
                            }
                        }
                    })
            elif filter.operator == Operator.NOT_BETWEEN:
                # NOT BETWEEN using must_not with a range query
                if isinstance(filter.value, (list, tuple)) and len(filter.value) == 2:
                    lower, upper = filter.value
                    es_query["bool"]["must_not"] = es_query["bool"].get("must_not", [])
                    es_query["bool"]["must_not"].append({
                        "range": {
                            field_path: {
                                "gte": lower,
                                "lte": upper,
                            }
                        }
                    })

        # If no filters, use match_all
        if not es_query["bool"]["must"] and "must_not" not in es_query["bool"]:
            es_query = {"match_all": {}}

        # Build sort
        sort = []
        if query.sort_specs:
            for sort_spec in query.sort_specs:
                direction = "desc" if sort_spec.order == SortOrder.DESC else "asc"
                sort.append({f"data.{sort_spec.field}": {"order": direction}})

        # Size and offset for pagination
        size = query.limit_value if query.limit_value else 10000
        from_param = query.offset_value if query.offset_value else 0

        # Execute search
        response = await self._client.search(
            index=self.index_name,
            query=es_query,
            sort=sort if sort else None,
            size=size,
            from_=from_param,
        )

        # Convert to records
        records = []
        for hit in response["hits"]["hits"]:
            record = self._doc_to_record(hit)

            # Apply field projection if specified
            if query.fields:
                record = record.project(query.fields)

            records.append(record)

        return records
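    # For example, a single EQ filter on a string field "status" (the exact Query
    # construction API is not shown in this module) yields a request body shaped like:
    #
    #     {"query": {"bool": {"must": [{"term": {"data.status.keyword": "active"}}]}}}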
    async def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()

        response = await self._client.count(index=self.index_name)
        return response["count"]
    async def clear(self) -> int:
        """Clear all records from the database."""
        self._check_connection()

        # Get count before deletion
        count = await self._count_all()

        # Delete all documents via delete-by-query
        response = await self._client.delete_by_query(
            index=self.index_name,
            query={"match_all": {}},
            refresh=self.refresh,
        )

        return response.get("deleted", count)
    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None,
    ) -> AsyncIterator[Record]:
        """Stream records from Elasticsearch using the scroll API."""
        self._check_connection()
        config = config or StreamConfig()

        # Build query
        es_query = {"match_all": {}}
        if query and query.filters:
            es_query = {"bool": {"must": []}}
            for filter in query.filters:
                field_path = f"data.{filter.field}"
                if filter.operator == Operator.EQ:
                    es_query["bool"]["must"].append({"term": {field_path: filter.value}})

        # Initial search with scroll
        response = await self._client.search(
            index=self.index_name,
            query=es_query,
            scroll="2m",
            size=config.batch_size,
        )

        scroll_id = response["_scroll_id"]
        hits = response["hits"]["hits"]

        try:
            while hits:
                for hit in hits:
                    record = self._doc_to_record(hit)
                    if query and query.fields:
                        record = record.project(query.fields)
                    yield record

                # Get next batch
                response = await self._client.scroll(
                    scroll_id=scroll_id,
                    scroll="2m",
                )
                hits = response["hits"]["hits"]
        finally:
            # Clear the scroll context on the server
            await self._client.clear_scroll(scroll_id=scroll_id)
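    # Consuming the stream (a sketch; assumes StreamConfig accepts a batch_size
    # keyword, and `handle` is a hypothetical per-record callback):
    #
    #     async for record in db.stream_read(config=StreamConfig(batch_size=500)):
    #         handle(record)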
    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None,
    ) -> StreamResult:
        """Stream records into Elasticsearch using the bulk API."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        batch = []
        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write batch with graceful per-record fallback
                async def batch_func(b):
                    await self._write_batch(b)
                    return [r.id for r in b]

                continue_processing = await async_process_batch_with_fallback(
                    batch,
                    batch_func,
                    self.create,
                    result,
                    config,
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write remaining batch
        if batch and not quitting:
            async def batch_func(b):
                await self._write_batch(b)
                return [r.id for r in b]

            await async_process_batch_with_fallback(
                batch,
                batch_func,
                self.create,
                result,
                config,
            )

        result.duration = time.time() - start_time
        return result
    async def _write_batch(self, records: list[Record]) -> None:
        """Write a batch of records using the bulk API."""
        if not records:
            return

        # Build bulk operations
        operations = []
        for record in records:
            doc = self._record_to_doc(record)
            operations.append({"index": {"_index": self.index_name}})
            operations.append(doc)

        # Execute bulk
        await self._client.bulk(
            operations=operations,
            refresh=self.refresh,
        )
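    # The bulk payload alternates action and document entries, e.g.:
    #
    #     [{"index": {"_index": "<index>"}}, {<doc>},
    #      {"index": {"_index": "<index>"}}, {<doc>}, ...]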
    async def vector_search(
        self,
        query_vector: np.ndarray | list[float],
        vector_field: str = "embedding",
        k: int = 10,
        metric: DistanceMetric = DistanceMetric.COSINE,
        filter: Query | None = None,
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> list[VectorSearchResult]:
        """Search for similar vectors using Elasticsearch KNN.

        Args:
            query_vector: The vector to search for
            vector_field: Name of the vector field to search
            k: Number of results to return
            metric: Distance metric to use
            filter: Optional query filter to apply before vector search
            include_source: Whether to include the source document in results
            score_threshold: Optional minimum similarity score

        Returns:
            List of search results ordered by similarity
        """
        self._check_connection()

        # Import vector utilities
        from ..vector.elasticsearch_utils import build_knn_query

        # Build filter query if provided
        filter_query = self._build_filter_query(filter) if filter else None

        # Build KNN query
        query = build_knn_query(
            query_vector=query_vector,
            field_name=vector_field,
            k=k,
            filter_query=filter_query,
        )

        # Execute search
        try:
            response = await self._client.search(
                index=self.index_name,
                **query,  # Unpack the query dict directly
                size=k,
                _source=include_source,
            )
        except Exception as e:
            self._handle_elasticsearch_error(e, "vector search")
            return []

        # Process results
        results = []
        for hit in response.get("hits", {}).get("hits", []):
            score = hit.get("_score", 0.0)

            # Apply score threshold if specified
            if score_threshold is not None and score < score_threshold:
                continue

            # Convert document to record if source included
            record = None
            if include_source:
                record = self._doc_to_record(hit)

                # Set the storage ID on the record if we have one
                if record and not record.has_storage_id():
                    record.storage_id = hit["_id"]

            # Skip if no record (shouldn't happen when include_source is True)
            if record is None:
                continue

            results.append(VectorSearchResult(
                record=record,
                score=score,
                vector_field=vector_field,
                metadata={
                    "index": self.index_name,
                    "metric": metric.value,
                    "doc_id": hit["_id"],
                },
            ))

        return results
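    # Example call (a sketch; assumes an existing dense-vector mapping for
    # "embedding" with a matching dimensionality, here 384):
    #
    #     hits = await db.vector_search([0.1] * 384, vector_field="embedding", k=5)
    #     for hit in hits:
    #         print(hit.score, hit.metadata["doc_id"])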
    async def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Any | None = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records.

        Args:
            records: Records to process
            text_field: Field name(s) containing text to embed
            vector_field: Field name to store vectors in
            embedding_fn: Function to generate embeddings
            batch_size: Number of records to process at once
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            List of record IDs that were processed
        """
        # This is a stub implementation; a full implementation would require
        # an actual embedding function.
        logger.warning("bulk_embed_and_store is not fully implemented for Elasticsearch")
        return []
    async def create_vector_index(
        self,
        vector_field: str = "embedding",
        dimensions: int | None = None,
        metric: DistanceMetric = DistanceMetric.COSINE,
        index_type: str = "auto",
        **kwargs: Any,
    ) -> bool:
        """Create or update the index mapping for a vector field.

        Args:
            vector_field: Name of the vector field to index
            dimensions: Number of dimensions
            metric: Distance metric for the index
            index_type: Type of index (ignored for ES, which always uses HNSW)
            **kwargs: Additional index parameters

        Returns:
            True if the mapping was created/updated successfully
        """
        self._check_connection()

        if not dimensions:
            if vector_field not in self.vector_fields:
                raise ValueError(f"Unknown dimensions for field '{vector_field}'")
            dimensions = self.vector_fields[vector_field]

        # Import vector utilities
        from ..vector.elasticsearch_utils import (
            get_similarity_for_metric,
            get_vector_mapping,
        )

        # Get similarity function for metric
        similarity = get_similarity_for_metric(metric)

        # Build mapping for the vector field
        mapping = get_vector_mapping(dimensions, similarity)

        # Update index mapping
        try:
            await self._client.indices.put_mapping(
                index=self.index_name,
                properties={
                    f"data.{vector_field}": mapping,
                },
            )

            # Track the vector field
            self.vector_fields[vector_field] = dimensions
            self.vector_enabled = True

            logger.info(
                f"Created vector mapping for field '{vector_field}' with {dimensions} dimensions"
            )
            return True

        except Exception as e:
            self._handle_elasticsearch_error(e, "create vector index")
            return False
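# A minimal end-to-end sketch (assumptions: a reachable Elasticsearch instance,
# an "index" config key understood by ElasticsearchPoolConfig, and a Record
# instance built elsewhere -- the Record constructor is not shown in this module):
#
#     import asyncio
#
#     async def main():
#         db = AsyncElasticsearchDatabase({"index": "demo-index"})
#         await db.connect()
#         try:
#             record_id = await db.create(record)  # `record` is a Record instance
#             print(await db.read(record_id))
#         finally:
#             await db.close()
#
#     asyncio.run(main())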