Coverage for src/dataknobs_data/backends/elasticsearch_mixins.py: 21%
141 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 14:14 -0600
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 14:14 -0600
1"""Shared mixins for Elasticsearch backend implementations."""
3from __future__ import annotations
5import logging
6from typing import TYPE_CHECKING, Any
8import numpy as np
10from ..fields import Field, FieldType, VectorField
11from ..records import Record
13if TYPE_CHECKING:
14 from ..query import Query
16logger = logging.getLogger(__name__)
class ElasticsearchBaseConfig:
    """Mixin for parsing Elasticsearch configuration."""

    def _parse_elasticsearch_config(self, config: dict[str, Any]) -> tuple[str, int, str, dict]:
        """Parse Elasticsearch configuration.

        Args:
            config: Configuration dictionary

        Returns:
            Tuple of (host, port, index_name, extra_config)
        """
        # Built fresh per call so callers mutating the settings dict
        # cannot affect later parses.
        default_settings = {
            "number_of_shards": 1,
            "number_of_replicas": 0,
        }

        # Remaining options that are not part of the connection triple.
        extras = {
            "refresh": config.get("refresh", True),
            "settings": config.get("settings", default_settings),
            "mappings": config.get("mappings"),
        }

        return (
            config.get("host", "localhost"),
            config.get("port", 9200),
            config.get("index", "records"),
            extras,
        )
class ElasticsearchIndexManager:
    """Mixin for Elasticsearch index management."""

    @staticmethod
    def get_index_mappings(vector_fields: dict[str, int] | None = None) -> dict:
        """Get index mappings with vector field support.

        Args:
            vector_fields: Dict mapping vector field names to dimensions

        Returns:
            Elasticsearch mappings dictionary
        """
        # Vector fields live nested under "data"; each becomes a
        # dense_vector property with indexing enabled.
        vector_props = {
            name: {
                "type": "dense_vector",
                "dims": dims,
                "index": True,
                "similarity": "cosine"  # Default similarity
            }
            for name, dims in (vector_fields or {}).items()
        }

        return {
            "properties": {
                "id": {"type": "keyword"},
                "data": {
                    "type": "object",
                    "properties": vector_props,
                },
                "metadata": {"type": "object", "enabled": True},
                "created_at": {"type": "date"},
                "updated_at": {"type": "date"},
            }
        }

    @staticmethod
    def get_knn_index_settings() -> dict:
        """Get index settings optimized for KNN search.

        Returns:
            Index settings dictionary
        """
        # Note: "knn" setting is not needed for standard Elasticsearch;
        # KNN is enabled by having dense_vector fields with index=true.
        return {
            "number_of_shards": 1,
            "number_of_replicas": 0,
        }
class ElasticsearchVectorSupport:
    """Mixin for vector field detection and tracking."""

    def __init__(self):
        """Initialize vector support tracking."""
        # Cooperative init: as a mixin this class may sit in the middle of
        # an MRO chain, so delegate upward before adding our own state.
        super().__init__()
        self.vector_fields: dict[str, int] = {}  # field_name -> dimensions
        self.vector_enabled = False

    def _detect_vector_fields(self, record: Record) -> dict[str, int]:
        """Detect vector fields in a record.

        Args:
            record: Record to examine

        Returns:
            Dict mapping field names to dimensions
        """
        vector_fields: dict[str, int] = {}

        for field_name, field_obj in record.fields.items():
            # Guard clauses: only typed VectorFields with a concrete value count.
            if field_obj.type not in (FieldType.VECTOR, FieldType.SPARSE_VECTOR):
                continue
            if not isinstance(field_obj, VectorField) or field_obj.value is None:
                continue

            value = field_obj.value
            if isinstance(value, list):
                dims = len(value)
            elif isinstance(value, np.ndarray):
                # First axis length; assumes 1-D embedding vectors.
                dims = value.shape[0]
            else:
                # Unsupported value container; skip silently (matches prior behavior).
                continue

            vector_fields[field_name] = dims
            logger.debug(f"Detected vector field '{field_name}' with {dims} dimensions")

        return vector_fields

    def _has_vector_fields(self, record: Record) -> bool:
        """Check if a record has vector fields.

        Args:
            record: Record to check

        Returns:
            True if record has vector fields
        """
        return len(self._detect_vector_fields(record)) > 0

    def _update_vector_tracking(self, record: Record) -> None:
        """Update tracking of vector fields from a record.

        Args:
            record: Record to examine
        """
        detected = self._detect_vector_fields(record)
        for field_name, dims in detected.items():
            # First sighting wins: dimensions are never overwritten.
            if field_name not in self.vector_fields:
                self.vector_fields[field_name] = dims
                logger.info(f"Tracking new vector field '{field_name}' with {dims} dimensions")
class ElasticsearchErrorHandler:
    """Mixin for consistent error handling."""

    @staticmethod
    def _handle_elasticsearch_error(error: Exception, operation: str) -> None:
        """Handle Elasticsearch errors consistently.

        Args:
            error: The exception that occurred
            operation: Description of the operation that failed
        """
        from elasticsearch import (
            ConnectionError,
            NotFoundError,
            RequestError,
            TransportError,
        )

        # Ordered dispatch table: specific exception types are listed before
        # TransportError, which is a base class of the others in the ES client.
        dispatch = [
            (ConnectionError, logger.error,
             f"Connection error during {operation}: {error}",
             RuntimeError(f"Failed to connect to Elasticsearch: {error}")),
            (NotFoundError, logger.warning,
             f"Resource not found during {operation}: {error}",
             ValueError(f"Resource not found: {error}")),
            (RequestError, logger.error,
             f"Bad request during {operation}: {error}",
             ValueError(f"Invalid request: {error}")),
            (TransportError, logger.error,
             f"Transport error during {operation}: {error}",
             RuntimeError(f"Elasticsearch transport error: {error}")),
        ]
        for exc_type, log_fn, log_msg, wrapped in dispatch:
            if isinstance(error, exc_type):
                log_fn(log_msg)
                raise wrapped from error

        # Anything unrecognized is logged and re-raised unchanged.
        logger.error(f"Unexpected error during {operation}: {error}")
        raise error
class ElasticsearchRecordSerializer:
    """Mixin for record serialization with vector field handling."""

    @staticmethod
    def _record_to_document(record: Record) -> dict[str, Any]:
        """Convert a record to an Elasticsearch document.

        Args:
            record: Record to convert

        Returns:
            Document dictionary for Elasticsearch
        """
        data_dict = {}
        for field_name, field_obj in record.fields.items():
            value = field_obj.value
            # numpy arrays are not JSON-serializable; convert them to plain
            # lists regardless of field class (previously only VectorField
            # values were converted, leaving other ndarray values raw).
            if isinstance(value, np.ndarray):
                data_dict[field_name] = value.tolist()
            else:
                data_dict[field_name] = value

        doc: dict[str, Any] = {
            "data": data_dict,
            "metadata": record.metadata,
        }

        # Add timestamps if they exist as (truthy) attributes
        if getattr(record, "created_at", None):
            doc["created_at"] = record.created_at.isoformat()
        if getattr(record, "updated_at", None):
            doc["updated_at"] = record.updated_at.isoformat()

        # Add ID if present
        if record.id:
            doc["id"] = record.id

        return doc

    @staticmethod
    def _document_to_record(doc: dict[str, Any], doc_id: str | None = None) -> Record:
        """Convert an Elasticsearch document to a record.

        Args:
            doc: Document from Elasticsearch
            doc_id: Document ID from Elasticsearch

        Returns:
            Record instance
        """
        from collections import OrderedDict
        from datetime import datetime

        # A raw hit has the payload under "_source"; a bare document is
        # already the payload.
        source = doc.get("_source", doc)
        data = source.get("data", {})
        metadata = source.get("metadata", {})

        fields = {}
        for field_name, value in data.items():
            # Metadata recorded at index time (if any) for this field.
            field_meta = metadata.get("vector_fields", {}).get(field_name, {})

            if ElasticsearchRecordSerializer._looks_like_vector(field_meta, value):
                vector_value = (
                    np.array(value, dtype=np.float32)
                    if value
                    else np.array([], dtype=np.float32)
                )
                fields[field_name] = VectorField(
                    name=field_name,
                    value=vector_value,
                    source_field=field_meta.get("source_field"),
                    model_name=field_meta.get("model"),
                    model_version=field_meta.get("model_version"),
                )
            else:
                fields[field_name] = Field(
                    name=field_name,
                    value=value,
                    type=ElasticsearchRecordSerializer._infer_field_type(value),
                )

        # Pass fields as OrderedDict since they're already Field objects.
        record = Record(data=OrderedDict(fields), metadata=metadata)

        # ID precedence: explicit argument, then ES "_id", then stored "id".
        if doc_id:
            record.id = doc_id
        elif "_id" in doc:
            record.id = doc["_id"]
        elif "id" in source:
            record.id = source["id"]

        # Set timestamps if available (as attributes, not fields)
        if source.get("created_at"):
            record.created_at = datetime.fromisoformat(source["created_at"])
        if source.get("updated_at"):
            record.updated_at = datetime.fromisoformat(source["updated_at"])

        return record

    @staticmethod
    def _looks_like_vector(field_meta: dict[str, Any], value: Any) -> bool:
        """Return True if metadata or the value's shape marks a vector field.

        Booleans are excluded from the numeric check: isinstance(True, int)
        is True, so without the exclusion a list of booleans would be
        misclassified as a dense vector.
        """
        if field_meta.get("type") == "vector":
            return True
        return (
            isinstance(value, list)
            and len(value) > 0
            and all(
                isinstance(v, (int, float)) and not isinstance(v, bool)
                for v in value
            )
        )

    @staticmethod
    def _infer_field_type(value: Any) -> FieldType:
        """Infer a FieldType for a non-vector value (defaults to STRING)."""
        if isinstance(value, bool):  # must precede int: bool is an int subclass
            return FieldType.BOOLEAN
        if isinstance(value, int):
            return FieldType.INTEGER
        if isinstance(value, float):
            return FieldType.FLOAT
        if isinstance(value, dict):
            return FieldType.JSON
        # Non-numeric sequences (including lists of booleans) are JSON.
        if isinstance(value, (list, tuple)) and not all(
            isinstance(v, (int, float)) and not isinstance(v, bool) for v in value
        ):
            return FieldType.JSON
        return FieldType.STRING
class ElasticsearchQueryBuilder:
    """Mixin for building Elasticsearch queries."""

    @staticmethod
    def _build_filter_query(filter_query: Query | None) -> dict[str, Any] | None:
        """Build Elasticsearch filter query from Query object.

        Args:
            filter_query: Query object to convert

        Returns:
            Elasticsearch query dict or None
        """
        if not filter_query:
            return None

        # TODO: Implement full query translation
        # For now, just support simple field equality
        from ..query import Operator

        # Dispatch table: each supported operator maps to a clause factory.
        # EQ uses "match" so analyzed text fields still hit.
        clause_builders = {
            Operator.EQ: lambda path, val: {"match": {path: val}},
            Operator.IN: lambda path, val: {"terms": {path: val}},
            Operator.GT: lambda path, val: {"range": {path: {"gt": val}}},
            Operator.LT: lambda path, val: {"range": {path: {"lt": val}}},
        }

        must_clauses = []
        for item in (filter_query.filters or []):
            builder = clause_builders.get(item.operator)
            if builder is not None:
                # Record data is nested under "data" in the document.
                must_clauses.append(builder(f"data.{item.field}", item.value))

        return {"bool": {"must": must_clauses}} if must_clauses else None