Coverage for src/dataknobs_data/backends/elasticsearch_mixins.py: 21%

141 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-29 14:14 -0600

1"""Shared mixins for Elasticsearch backend implementations.""" 

2 

3from __future__ import annotations 

4 

5import logging 

6from typing import TYPE_CHECKING, Any 

7 

8import numpy as np 

9 

10from ..fields import Field, FieldType, VectorField 

11from ..records import Record 

12 

13if TYPE_CHECKING: 

14 from ..query import Query 

15 

16logger = logging.getLogger(__name__) 

17 

18 

class ElasticsearchBaseConfig:
    """Mixin providing configuration parsing for Elasticsearch backends."""

    def _parse_elasticsearch_config(self, config: dict[str, Any]) -> tuple[str, int, str, dict]:
        """Extract connection details and extra options from a config dict.

        Args:
            config: Configuration dictionary.

        Returns:
            Tuple of (host, port, index_name, extra_config).
        """
        # Default index settings used when none are supplied.
        default_settings = {
            "number_of_shards": 1,
            "number_of_replicas": 0,
        }
        extra_config = {
            "refresh": config.get("refresh", True),
            "settings": config.get("settings", default_settings),
            "mappings": config.get("mappings"),
        }
        return (
            config.get("host", "localhost"),
            config.get("port", 9200),
            config.get("index", "records"),
            extra_config,
        )

46 

47 

48class ElasticsearchIndexManager: 

49 """Mixin for Elasticsearch index management.""" 

50 

51 @staticmethod 

52 def get_index_mappings(vector_fields: dict[str, int] | None = None) -> dict: 

53 """Get index mappings with vector field support. 

54  

55 Args: 

56 vector_fields: Dict mapping vector field names to dimensions 

57  

58 Returns: 

59 Elasticsearch mappings dictionary 

60 """ 

61 mappings = { 

62 "properties": { 

63 "id": {"type": "keyword"}, 

64 "data": { 

65 "type": "object", 

66 "properties": {} 

67 }, 

68 "metadata": {"type": "object", "enabled": True}, 

69 "created_at": {"type": "date"}, 

70 "updated_at": {"type": "date"}, 

71 } 

72 } 

73 

74 # Add vector field mappings if specified 

75 if vector_fields: 

76 for field_name, dimensions in vector_fields.items(): 

77 # Use dense_vector type for vector fields nested under data 

78 data_props = mappings["properties"]["data"]["properties"] # type: ignore[index] 

79 data_props[field_name] = { 

80 "type": "dense_vector", 

81 "dims": dimensions, 

82 "index": True, 

83 "similarity": "cosine" # Default similarity 

84 } 

85 

86 return mappings 

87 

88 @staticmethod 

89 def get_knn_index_settings() -> dict: 

90 """Get index settings optimized for KNN search. 

91  

92 Returns: 

93 Index settings dictionary 

94 """ 

95 return { 

96 "number_of_shards": 1, 

97 "number_of_replicas": 0, 

98 # Note: "knn" setting is not needed for standard Elasticsearch 

99 # KNN is enabled by having dense_vector fields with index=true 

100 } 

101 

102 

class ElasticsearchVectorSupport:
    """Mixin that detects and tracks vector fields on records."""

    def __init__(self):
        """Initialize empty vector-field tracking state."""
        # Maps field name -> vector dimensionality, populated lazily as
        # records pass through _update_vector_tracking.
        self.vector_fields: dict[str, int] = {}
        self.vector_enabled = False

    def _detect_vector_fields(self, record: Record) -> dict[str, int]:
        """Find the vector fields on a record and their dimensions.

        Args:
            record: Record to examine

        Returns:
            Dict mapping field names to dimensions
        """
        detected: dict[str, int] = {}

        for name, field_obj in record.fields.items():
            if field_obj.type not in (FieldType.VECTOR, FieldType.SPARSE_VECTOR):
                continue
            if not isinstance(field_obj, VectorField) or field_obj.value is None:
                continue
            # Dimensions come from the stored value; other value types
            # (e.g. unset/opaque) are skipped.
            value = field_obj.value
            if isinstance(value, list):
                dims = len(value)
            elif isinstance(value, np.ndarray):
                dims = value.shape[0]
            else:
                continue
            detected[name] = dims
            logger.debug(f"Detected vector field '{name}' with {dims} dimensions")

        return detected

    def _has_vector_fields(self, record: Record) -> bool:
        """Check if a record has vector fields.

        Args:
            record: Record to check

        Returns:
            True if record has vector fields
        """
        return bool(self._detect_vector_fields(record))

    def _update_vector_tracking(self, record: Record) -> None:
        """Record the dimensions of any newly seen vector fields.

        Args:
            record: Record to examine
        """
        for name, dims in self._detect_vector_fields(record).items():
            if name in self.vector_fields:
                continue  # already tracked; keep the first-seen dimensions
            self.vector_fields[name] = dims
            logger.info(f"Tracking new vector field '{name}' with {dims} dimensions")

155 

156 

class ElasticsearchErrorHandler:
    """Mixin for consistent error handling."""

    @staticmethod
    def _handle_elasticsearch_error(error: Exception, operation: str) -> None:
        """Translate Elasticsearch client errors into standard exceptions.

        Args:
            error: The exception that occurred
            operation: Description of the operation that failed

        Raises:
            RuntimeError: For connection and transport failures.
            ValueError: For missing resources and bad requests.
        """
        # Imported lazily so this mixin can be loaded even when the
        # elasticsearch client library is not installed.
        from elasticsearch import (
            ConnectionError,
            NotFoundError,
            RequestError,
            TransportError,
        )

        # Order matters: the more specific client errors are checked before
        # the broader TransportError.
        if isinstance(error, ConnectionError):
            logger.error(f"Connection error during {operation}: {error}")
            raise RuntimeError(f"Failed to connect to Elasticsearch: {error}") from error
        if isinstance(error, NotFoundError):
            logger.warning(f"Resource not found during {operation}: {error}")
            raise ValueError(f"Resource not found: {error}") from error
        if isinstance(error, RequestError):
            logger.error(f"Bad request during {operation}: {error}")
            raise ValueError(f"Invalid request: {error}") from error
        if isinstance(error, TransportError):
            logger.error(f"Transport error during {operation}: {error}")
            raise RuntimeError(f"Elasticsearch transport error: {error}") from error

        # Anything else is re-raised unchanged.
        logger.error(f"Unexpected error during {operation}: {error}")
        raise error

190 

191 

class ElasticsearchRecordSerializer:
    """Mixin for record serialization with vector field handling."""

    @staticmethod
    def _record_to_document(record: Record) -> dict[str, Any]:
        """Convert a record to an Elasticsearch document.

        Args:
            record: Record to convert

        Returns:
            Document dictionary for Elasticsearch: field values under
            "data", record metadata under "metadata", plus optional
            ISO-formatted timestamps and "id".
        """
        # Serialize the record data
        data_dict = {}

        for field_name, field_obj in record.fields.items():
            if isinstance(field_obj, VectorField) and field_obj.value is not None:
                # Convert numpy arrays to lists for JSON serialization
                if isinstance(field_obj.value, np.ndarray):
                    data_dict[field_name] = field_obj.value.tolist()
                else:
                    data_dict[field_name] = field_obj.value
            else:
                # Non-vector fields (and empty vector fields) stored as-is
                data_dict[field_name] = field_obj.value

        # Create the document
        doc = {
            "data": data_dict,
            "metadata": record.metadata,
        }

        # Add timestamps if they exist as attributes; stored as ISO-8601
        # strings so the "date" mapping accepts them
        if hasattr(record, "created_at") and record.created_at:
            doc["created_at"] = record.created_at.isoformat()
        if hasattr(record, "updated_at") and record.updated_at:
            doc["updated_at"] = record.updated_at.isoformat()

        # Add ID if present
        if record.id:
            doc["id"] = record.id

        return doc

    @staticmethod
    def _document_to_record(doc: dict[str, Any], doc_id: str | None = None) -> Record:
        """Convert an Elasticsearch document to a record.

        Args:
            doc: Document from Elasticsearch — either a full hit containing
                "_source" or the bare source dict itself
            doc_id: Document ID from Elasticsearch

        Returns:
            Record instance
        """
        # Get the source data; accept both full hits and bare source dicts
        source = doc.get("_source", doc)

        # Extract data and metadata
        data = source.get("data", {})
        metadata = source.get("metadata", {})

        # Create fields
        fields = {}
        for field_name, value in data.items():
            # Check if this is a vector field based on metadata
            field_meta = metadata.get("vector_fields", {}).get(field_name, {})

            # NOTE(review): the fallback heuristic treats ANY non-empty list
            # of numbers (including all-int lists) as a vector, even without
            # vector metadata — confirm this is intended for int lists.
            if field_meta.get("type") == "vector" or (
                isinstance(value, list) and len(value) > 0 and
                all(isinstance(v, (int, float)) for v in value)
            ):
                # This looks like a vector field
                vector_value = np.array(value, dtype=np.float32) if value else np.array([], dtype=np.float32)
                fields[field_name] = VectorField(
                    name=field_name,
                    value=vector_value,
                    source_field=field_meta.get("source_field"),
                    model_name=field_meta.get("model"),
                    model_version=field_meta.get("model_version"),
                )
            else:
                # Regular field - infer type from value.
                # bool is tested before int because bool subclasses int.
                field_type = FieldType.STRING  # default
                if isinstance(value, bool):
                    field_type = FieldType.BOOLEAN
                elif isinstance(value, int):
                    field_type = FieldType.INTEGER
                elif isinstance(value, float):
                    field_type = FieldType.FLOAT
                elif isinstance(value, dict) or (isinstance(value, (list, tuple)) and not all(isinstance(v, (int, float)) for v in value)):
                    # dicts, and lists/tuples with non-numeric elements;
                    # note an empty list falls through to STRING here
                    field_type = FieldType.JSON

                fields[field_name] = Field(
                    name=field_name,
                    value=value,
                    type=field_type,
                )

        # Create the record - pass fields as OrderedDict since they're Field objects
        from collections import OrderedDict
        record = Record(data=OrderedDict(fields), metadata=metadata)

        # Set ID from document: explicit doc_id wins, then the hit's "_id",
        # then an "id" stored inside the source itself
        if doc_id:
            record.id = doc_id
        elif "_id" in doc:
            record.id = doc["_id"]
        elif "id" in source:
            record.id = source["id"]

        # Set timestamps if available (as attributes, not fields)
        if source.get("created_at"):
            from datetime import datetime
            record.created_at = datetime.fromisoformat(source["created_at"])

        if source.get("updated_at"):
            from datetime import datetime
            record.updated_at = datetime.fromisoformat(source["updated_at"])

        return record

313 

314 

class ElasticsearchQueryBuilder:
    """Mixin for building Elasticsearch queries."""

    @staticmethod
    def _build_filter_query(filter_query: Query | None) -> dict[str, Any] | None:
        """Translate a Query object into an Elasticsearch bool query.

        Args:
            filter_query: Query object to convert

        Returns:
            Elasticsearch query dict, or None when there is nothing to filter
        """
        if not filter_query:
            return None

        # TODO: Implement full query translation
        # For now, just support simple field equality
        from ..query import Operator

        clauses: list[dict[str, Any]] = []

        for item in (filter_query.filters or []):
            # Record fields live under the "data" object in the document.
            path = f"data.{item.field}"
            op = item.operator

            if op == Operator.EQ:
                # Use match query for text fields to handle analyzed text
                clauses.append({"match": {path: item.value}})
            elif op == Operator.IN:
                clauses.append({"terms": {path: item.value}})
            elif op == Operator.GT:
                clauses.append({"range": {path: {"gt": item.value}}})
            elif op == Operator.LT:
                clauses.append({"range": {path: {"lt": item.value}}})
            # Unsupported operators are silently skipped (existing behavior).

        return {"bool": {"must": clauses}} if clauses else None