Coverage for src / dataknobs_data / vector / sync.py: 18%
276 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-26 15:45 -0700
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-26 15:45 -0700
1"""Synchronization tools for keeping vectors up to date with text changes."""
3from __future__ import annotations
5import asyncio
6import hashlib
7import logging
8from collections import defaultdict
9from dataclasses import dataclass, field
10from datetime import datetime
11from typing import TYPE_CHECKING, Any
13import numpy as np
15from ..fields import VectorField
16from ..records import Record
18if TYPE_CHECKING:
19 from collections.abc import Callable, Coroutine
21 from ..database import Database
23logger = logging.getLogger(__name__)
@dataclass
class SyncConfig:
    """Settings that control how vectors are kept in sync with their text.

    The values are plain dataclass fields; call :meth:`validate` to reject
    out-of-range settings.
    """

    # Embed a record's vectors when the record is created.
    auto_embed_on_create: bool = True
    # Re-embed when a source text field changes on update.
    auto_update_on_text_change: bool = True
    # Number of records handled per batch in bulk synchronization.
    batch_size: int = 100
    # Compare the stored model version against the current one.
    track_model_version: bool = True
    # Seconds allowed for one awaited embedding call.
    embedding_timeout: float = 30.0
    # Number of embedding attempts before giving up.
    max_retries: int = 3
    # Seconds to pause between embedding attempts.
    retry_delay: float = 1.0

    def validate(self) -> None:
        """Check that the numeric settings are within their allowed ranges.

        Raises:
            ValueError: If ``batch_size`` or ``embedding_timeout`` is not
                positive, or ``max_retries`` is negative. Only the first
                violation found (in that order) is reported.
        """
        problem = None
        if self.batch_size <= 0:
            problem = f"Batch size must be positive, got {self.batch_size}"
        elif self.embedding_timeout <= 0:
            problem = f"Embedding timeout must be positive, got {self.embedding_timeout}"
        elif self.max_retries < 0:
            problem = f"Max retries cannot be negative, got {self.max_retries}"

        if problem is not None:
            raise ValueError(problem)
@dataclass
class SyncStatus:
    """Counters, errors and timestamps describing one synchronization run."""

    # Per-run record counters.
    total_records: int = 0
    processed_records: int = 0
    updated_records: int = 0
    failed_records: int = 0
    skipped_records: int = 0
    # One dict per recorded failure.
    errors: list[dict[str, Any]] = field(default_factory=list)
    # Run boundaries; None until set.
    start_time: datetime | None = None
    end_time: datetime | None = None

    @property
    def success_rate(self) -> float:
        """Fraction of processed records that did not fail.

        Returns 0.0 when no records have been processed.
        """
        processed = self.processed_records
        if processed == 0:
            return 0.0
        succeeded = processed - self.failed_records
        return succeeded / processed

    @property
    def duration(self) -> float | None:
        """Elapsed seconds between start and end, or None if either is unset."""
        started, finished = self.start_time, self.end_time
        if not (started and finished):
            return None
        elapsed = finished - started
        return elapsed.total_seconds()

    def to_dict(self) -> dict[str, Any]:
        """Return the status as a dictionary, with timestamps in ISO format."""
        direct_keys = (
            "total_records",
            "processed_records",
            "updated_records",
            "failed_records",
            "skipped_records",
            "success_rate",
            "duration",
            "errors",
        )
        snapshot: dict[str, Any] = {key: getattr(self, key) for key in direct_keys}

        # Timestamps are serialized as ISO strings; unset ones stay None.
        if self.start_time:
            snapshot["start_time"] = self.start_time.isoformat()
        else:
            snapshot["start_time"] = None
        if self.end_time:
            snapshot["end_time"] = self.end_time.isoformat()
        else:
            snapshot["end_time"] = None
        return snapshot
91class VectorTextSynchronizer:
92 """Synchronizes vector embeddings with their source text fields."""
94 def __init__(
95 self,
96 database: Database,
97 embedding_fn: Callable[[str], np.ndarray] | Callable[[str], Coroutine[Any, Any, np.ndarray]],
98 text_fields: list[str] | str | None = None,
99 vector_field: str = "embedding",
100 field_separator: str = " ",
101 auto_sync: bool = True,
102 batch_size: int = 100,
103 model_name: str | None = None,
104 model_version: str | None = None,
105 config: SyncConfig | None = None,
106 ):
107 """Initialize the synchronizer with simplified API.
109 Args:
110 database: The database to synchronize
111 embedding_fn: Function to generate embeddings from text
112 text_fields: Fields to concatenate for embedding (if None, uses all text fields)
113 vector_field: Name of the vector field to store embeddings
114 field_separator: Separator for concatenating text fields
115 auto_sync: Whether to auto-sync on create/update
116 batch_size: Batch size for bulk operations
117 model_name: Name of the embedding model
118 model_version: Version of the embedding model
119 config: Advanced configuration object (overrides other params)
120 """
121 self.database = database
122 self.embedding_fn = embedding_fn
123 self.embedding_function = embedding_fn # Alias for compatibility
125 # Handle text_fields
126 if isinstance(text_fields, str):
127 text_fields = [text_fields]
128 self.text_fields = text_fields or []
130 self.vector_field = vector_field
131 self.field_separator = field_separator
132 self.auto_sync = auto_sync
133 self.batch_size = batch_size
134 self.model_name = model_name
135 self.model_version = model_version
137 # Use config if provided, otherwise create from params
138 if config:
139 self.config = config
140 else:
141 self.config = SyncConfig(
142 auto_embed_on_create=auto_sync,
143 auto_update_on_text_change=auto_sync,
144 batch_size=batch_size,
145 )
146 self.config.validate()
148 # Track vector fields and their source fields
149 self._vector_fields: dict[str, dict[str, Any]] = {}
150 self._source_fields: dict[str, list[str]] = defaultdict(list)
151 self._initialize_field_mappings()
153 def _initialize_field_mappings(self) -> None:
154 """Initialize mappings between vector fields and source fields."""
155 # Use schema if available
156 for field_name, field_schema in self.database.schema.fields.items():
157 if field_schema.is_vector_field():
158 self._vector_fields[field_name] = {
159 "dimensions": field_schema.get_dimensions() or 384,
160 "source_field": field_schema.get_source_field(),
161 }
162 source = field_schema.get_source_field()
163 if source:
164 self._source_fields[source].append(field_name)
166 def _compute_content_hash(self, content: str) -> str:
167 """Compute a hash of the content for change detection."""
168 return hashlib.md5(content.encode()).hexdigest()
170 def _has_current_vector(self, record: Record, vector_field: str) -> bool:
171 """Check if a record has a current vector for the given field.
173 Args:
174 record: The record to check
175 vector_field: Name of the vector field
177 Returns:
178 True if the vector is current, False otherwise
179 """
180 # Check if field exists
181 field_obj = record.fields.get(vector_field)
182 if not field_obj:
183 return False
185 # Get the vector value
186 vector_value = None
187 if isinstance(field_obj, VectorField):
188 vector_value = field_obj.value
189 if vector_value is None:
190 return False
192 # For VectorField, check model version if tracking is enabled
193 if self.config.track_model_version and self.model_version:
194 stored_version = field_obj.model_version
195 if stored_version != self.model_version:
196 return False
197 else:
198 # Plain value (list or array)
199 vector_value = field_obj.value
200 if vector_value is None:
201 return False
202 if not isinstance(vector_value, (list, np.ndarray)):
203 return False
205 # For plain values, check metadata and content hash separately
206 if self.config.track_model_version and self.model_version:
207 metadata_field = f"{vector_field}_metadata"
208 metadata = record.get_value(metadata_field)
209 if not metadata or not isinstance(metadata, dict):
210 return False
211 stored_version = metadata.get("model_version")
212 if stored_version != self.model_version:
213 return False
215 # Check content hash if source field exists
216 field_info = self._vector_fields.get(vector_field)
217 if field_info and field_info.get("source_field"):
218 source_content = record.get_value(field_info["source_field"], "")
219 if source_content:
220 # For VectorField objects, we don't check content hash
221 # as they're considered immutable once created
222 if isinstance(field_obj, VectorField):
223 # VectorField with matching version is considered current
224 return True
226 # For plain values, check the content hash field
227 hash_field = f"{vector_field}_content_hash"
228 stored_hash = record.get_value(hash_field)
229 current_hash = self._compute_content_hash(str(source_content))
230 if stored_hash != current_hash:
231 return False
233 return True
235 def _needs_update(self, record: Record, vector_field: str) -> bool:
236 """Check if a vector field needs to be updated.
238 Args:
239 record: The record to check
240 vector_field: Name of the vector field
242 Returns:
243 True if the vector needs updating, False otherwise
244 """
245 return not self._has_current_vector(record, vector_field)
247 async def _embed_text(self, text: str) -> np.ndarray | None:
248 """Generate embedding for text with error handling.
250 Args:
251 text: Text to embed
253 Returns:
254 Embedding vector or None if failed
255 """
256 if not text:
257 return None
259 for attempt in range(self.config.max_retries):
260 try:
261 if asyncio.iscoroutinefunction(self.embedding_fn):
262 result = await asyncio.wait_for(
263 self.embedding_fn(text),
264 timeout=self.config.embedding_timeout
265 )
266 else:
267 result = await asyncio.to_thread(self.embedding_fn, text)
269 if isinstance(result, np.ndarray):
270 return result
271 elif isinstance(result, list):
272 return np.array(result)
273 else:
274 logger.error(f"Embedding function returned unexpected type: {type(result)}")
275 return None
277 except asyncio.TimeoutError:
278 logger.warning(f"Embedding timeout on attempt {attempt + 1}")
279 if attempt < self.config.max_retries - 1:
280 await asyncio.sleep(self.config.retry_delay)
281 except Exception as e:
282 logger.error(f"Embedding error on attempt {attempt + 1}: {e}")
283 if attempt < self.config.max_retries - 1:
284 await asyncio.sleep(self.config.retry_delay)
286 return None
288 async def sync_record(
289 self,
290 record_or_id: Record | str,
291 force: bool = False
292 ) -> tuple[bool, list[str]]:
293 """Synchronize vectors for a single record.
295 Args:
296 record_or_id: The record or record ID to synchronize
297 force: Force update even if vectors appear current
299 Returns:
300 Tuple of (success, list of updated fields)
301 """
302 # Get record if ID provided
303 if isinstance(record_or_id, str):
304 record = await self.database.read(record_or_id)
305 if not record:
306 return False, []
307 record_id = record_or_id
308 else:
309 record = record_or_id
310 record_id = record.id
312 updated_fields = []
313 failed_fields = []
315 # If text_fields are specified, use them for the default vector field
316 if self.text_fields:
317 text_parts = []
318 for field in self.text_fields:
319 value = record.get_value(field)
320 if value:
321 text_parts.append(str(value))
323 if text_parts:
324 text = self.field_separator.join(text_parts)
325 embedding = await self._embed_text(text)
326 if embedding is not None:
327 from ..fields import VectorField
328 # Compute content hash for change tracking
329 content_hash = self._compute_content_hash(text)
330 vector_field_obj = VectorField(
331 value=embedding,
332 name=self.vector_field,
333 source_field=self.text_fields[0] if len(self.text_fields) == 1 else None,
334 model_name=self.model_name,
335 model_version=self.model_version,
336 metadata={"content_hash": content_hash}
337 )
338 record.fields[self.vector_field] = vector_field_obj
339 updated_fields.append(self.vector_field)
340 else:
341 # Embedding generation failed
342 failed_fields.append(self.vector_field)
344 # Also process vector fields defined in schema with source fields
345 for vector_field_name, field_info in self._vector_fields.items():
346 source_field = field_info.get("source_field")
347 if source_field and (force or self._needs_update(record, vector_field_name)):
348 source_value = record.get_value(source_field)
349 if source_value:
350 source_text = str(source_value)
351 embedding = await self._embed_text(source_text)
352 if embedding is not None:
353 from ..fields import VectorField
354 # Compute content hash for change tracking
355 content_hash = self._compute_content_hash(source_text)
356 vector_field_obj = VectorField(
357 value=embedding,
358 name=vector_field_name,
359 source_field=source_field,
360 model_name=self.model_name,
361 model_version=self.model_version,
362 metadata={"content_hash": content_hash}
363 )
364 record.fields[vector_field_name] = vector_field_obj
365 updated_fields.append(vector_field_name)
366 else:
367 # Embedding generation failed
368 failed_fields.append(vector_field_name)
370 # Save to database if any fields were updated
371 if updated_fields:
372 # Use storage_id if available, otherwise fall back to record.id
373 update_id = record.storage_id if record.has_storage_id() else record_id
374 await self.database.update(update_id, record)
376 # Return success=False if there were failures and no successes
377 success = len(failed_fields) == 0 or len(updated_fields) > 0
378 return success, updated_fields
380 async def sync_all(
381 self,
382 batch_size: int | None = None,
383 force: bool = False,
384 progress_callback: Callable[[int, int], None] | None = None,
385 ) -> dict[str, Any]:
386 """Synchronize all records in the database.
388 Args:
389 batch_size: Batch size for processing (uses self.batch_size if None)
390 force: Force update even if vectors appear current
391 progress_callback: Callback for progress updates (done, total)
393 Returns:
394 Dictionary with sync results
395 """
396 from ..query import Query
398 batch_size = batch_size or self.batch_size
400 # Get all records
401 all_records = await self.database.search(Query())
402 total = len(all_records)
404 processed = 0
405 updated = 0
406 failed = 0
408 # Process in batches
409 for i in range(0, total, batch_size):
410 batch = all_records[i:i + batch_size]
412 for record in batch:
413 success, fields = await self.sync_record(record, force=force)
415 processed += 1
416 if success and fields:
417 updated += 1
418 elif not success:
419 failed += 1
421 if progress_callback:
422 progress_callback(processed, total)
424 return {
425 "processed": processed,
426 "updated": updated,
427 "failed": failed,
428 "total": total,
429 }
431 async def bulk_sync(
432 self,
433 records: list[Record] | None = None,
434 force: bool = False,
435 progress_callback: Callable[[SyncStatus], None] | None = None,
436 ) -> SyncStatus:
437 """Synchronize vectors for multiple records in batches.
439 Args:
440 records: Records to sync (None for all records in database)
441 force: Force update even if vectors appear current
442 progress_callback: Callback for progress updates
444 Returns:
445 Synchronization status
446 """
447 status = SyncStatus(start_time=datetime.utcnow())
449 try:
450 # Get records if not provided
451 if records is None:
452 records = await self.database.all()
454 status.total_records = len(records)
456 # Process in batches
457 for i in range(0, len(records), self.config.batch_size):
458 batch = records[i:i + self.config.batch_size]
460 for record in batch:
461 try:
462 success, updated_fields = await self.sync_record(record, force)
463 status.processed_records += 1
465 if updated_fields:
466 # sync_record already updates the database
467 status.updated_records += 1
468 elif success:
469 status.skipped_records += 1
470 else:
471 status.failed_records += 1
473 except Exception as e:
474 status.failed_records += 1
475 status.errors.append({
476 "record_id": record.id,
477 "error": str(e),
478 })
479 logger.error(f"Failed to sync record {record.id}: {e}")
481 # Call progress callback
482 if progress_callback:
483 progress_callback(status)
485 finally:
486 status.end_time = datetime.utcnow()
488 logger.info(
489 f"Sync completed: {status.updated_records} updated, "
490 f"{status.skipped_records} skipped, {status.failed_records} failed"
491 )
493 return status
495 async def sync_on_update(
496 self,
497 record_id: str,
498 old_data: dict[str, Any],
499 new_data: dict[str, Any],
500 ) -> bool:
501 """Handle record updates and sync vectors if needed.
503 Args:
504 record_id: ID of the updated record
505 old_data: Previous data
506 new_data: New data
508 Returns:
509 True if sync was performed, False otherwise
510 """
511 if not self.config.auto_update_on_text_change:
512 return False
514 # Check if any source fields changed
515 fields_to_update = set()
516 for source_field, vector_fields in self._source_fields.items():
517 old_value = old_data.get(source_field)
518 new_value = new_data.get(source_field)
520 if old_value != new_value:
521 fields_to_update.update(vector_fields)
523 if not fields_to_update:
524 return False
526 # Create record and sync
527 record = Record(id=record_id, data=new_data)
528 _success, updated_fields = await self.sync_record(record, force=True)
530 if updated_fields:
531 # Update only the vector fields
532 update_data = {
533 field: record.get_value(field)
534 for field in updated_fields
535 if record.get_value(field) is not None
536 }
538 # Include metadata fields
539 for field in updated_fields:
540 metadata_field = f"{field}_metadata"
541 metadata_value = record.get_value(metadata_field)
542 if metadata_value is not None:
543 update_data[metadata_field] = metadata_value
545 hash_field = f"{field}_content_hash"
546 hash_value = record.get_value(hash_field)
547 if hash_value is not None:
548 update_data[hash_field] = hash_value
550 # Get the existing record and update it
551 existing_record = await self.database.read(record_id)
552 if existing_record:
553 for key, value in update_data.items():
554 existing_record.set_value(key, value)
555 await self.database.update(record_id, existing_record)
556 return True
558 return False
560 async def sync_on_create(self, record: Record) -> bool:
561 """Handle record creation and sync vectors if needed.
563 Args:
564 record: The newly created record
566 Returns:
567 True if sync was performed, False otherwise
568 """
569 if not self.config.auto_embed_on_create:
570 return False
572 _success, updated_fields = await self.sync_record(record)
574 if updated_fields:
575 # Update the record with vector data
576 await self.database.update(record.id, record)
577 return True
579 return False
581 @classmethod
582 def from_config(
583 cls,
584 database: Database,
585 embedding_fn: Callable[[str], np.ndarray] | Callable[[str], Coroutine[Any, Any, np.ndarray]],
586 config: SyncConfig,
587 text_fields: list[str] | None = None,
588 vector_field: str = "embedding",
589 model_name: str | None = None,
590 model_version: str | None = None,
591 ) -> VectorTextSynchronizer:
592 """Create synchronizer from a config object for advanced use cases.
594 Args:
595 database: The database to synchronize
596 embedding_fn: Function to generate embeddings from text
597 config: Synchronization configuration
598 text_fields: Text field names (optional)
599 vector_field: Name of the vector field
600 model_name: Name of the embedding model
601 model_version: Version of the embedding model
603 Returns:
604 Configured VectorTextSynchronizer instance
605 """
606 return cls(
607 database=database,
608 embedding_fn=embedding_fn,
609 text_fields=text_fields,
610 vector_field=vector_field,
611 auto_sync=config.auto_embed_on_create,
612 batch_size=config.batch_size,
613 model_name=model_name,
614 model_version=model_version,
615 config=config,
616 )