Coverage for src/dataknobs_data/vector/migration.py: 18%
497 statements
1"""Migration tools for adding vector support to existing data."""
3from __future__ import annotations
5import asyncio
6import logging
7from dataclasses import dataclass, field
8from datetime import datetime
9from typing import TYPE_CHECKING, Any
11import numpy as np
13from ..fields import FieldType
14from ..query import Query
15from ..records import Record
16from ..schema import FieldSchema
17from .sync import SyncConfig, VectorTextSynchronizer
18from .types import VectorMetadata
20if TYPE_CHECKING:
21 from collections.abc import Callable, Coroutine
23 from ..database import Database
25logger = logging.getLogger(__name__)
28@dataclass
29class MigrationConfig:
30 """Configuration for vector migration."""
32 batch_size: int = 100
33 max_workers: int = 4
34 checkpoint_interval: int = 1000
35 enable_rollback: bool = True
36 verify_migration: bool = True
37 retry_failed: bool = True
38 max_retries: int = 3
39 max_consecutive_failures: int = 5 # Fail fast after this many consecutive failures
41 def validate(self) -> None:
42 """Validate configuration parameters."""
43 if self.batch_size <= 0:
44 raise ValueError(f"Batch size must be positive, got {self.batch_size}")
45 if self.max_workers <= 0:
46 raise ValueError(f"Max workers must be positive, got {self.max_workers}")
49@dataclass
50class MigrationStatus:
51 """Status of a migration operation."""
53 total_records: int = 0
54 migrated_records: int = 0
55 verified_records: int = 0
56 failed_records: int = 0
57 rollback_records: int = 0
58 errors: list[dict[str, Any]] = field(default_factory=list)
59 checkpoints: list[dict[str, Any]] = field(default_factory=list)
60 start_time: datetime | None = None
61 end_time: datetime | None = None
63 @property
64 def total_processed(self) -> int:
65 """Total number of processed records (migrated + failed)."""
66 return self.migrated_records + self.failed_records
68 @property
69 def failed_count(self) -> int:
70 """Alias for failed_records for compatibility."""
71 return self.failed_records
73 @property
74 def success_rate(self) -> float:
75 """Calculate the success rate of the migration."""
76 if self.total_records == 0:
77 return 0.0
78 return self.migrated_records / self.total_records
80 @property
81 def duration(self) -> float | None:
82 """Calculate the duration of the migration in seconds."""
83 if self.start_time and self.end_time:
84 return (self.end_time - self.start_time).total_seconds()
85 return None
87 @property
88 def records_per_second(self) -> float:
89 """Calculate the migration speed."""
90 duration = self.duration
91 if duration and duration > 0:
92 return self.migrated_records / duration
93 return 0.0
95 def add_checkpoint(self, name: str, record_id: str | None = None) -> None:
96 """Add a checkpoint to the migration."""
97 self.checkpoints.append({
98 "name": name,
99 "record_id": record_id,
100 "timestamp": datetime.utcnow().isoformat(),
101 "migrated": self.migrated_records,
102 "failed": self.failed_records,
103 })
105 def to_dict(self) -> dict[str, Any]:
106 """Convert to dictionary representation."""
107 return {
108 "total_records": self.total_records,
109 "migrated_records": self.migrated_records,
110 "verified_records": self.verified_records,
111 "failed_records": self.failed_records,
112 "rollback_records": self.rollback_records,
113 "success_rate": self.success_rate,
114 "duration": self.duration,
115 "records_per_second": self.records_per_second,
116 "errors": self.errors,
117 "checkpoints": self.checkpoints,
118 "start_time": self.start_time.isoformat() if self.start_time else None,
119 "end_time": self.end_time.isoformat() if self.end_time else None,
120 }
123class VectorMigration:
124 """Manages migration of existing data to include vector embeddings."""
126 def __init__(
127 self,
128 source_db: Database,
129 target_db: Database | None = None,
        embedding_fn: Callable[[str], np.ndarray] | Callable[[str], Coroutine[Any, Any, np.ndarray]] | None = None,
        text_fields: list[str] | None = None,
        vector_field: str = "embedding",
        field_separator: str = " ",
        batch_size: int = 100,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        model_name: str | None = None,
        model_version: str | None = None,
        config: MigrationConfig | None = None,
    ):
        """Initialize the migration manager with simplified API.

        Args:
            source_db: Source database to migrate from
            target_db: Target database (None to migrate in-place)
            embedding_fn: Function to generate embeddings
            text_fields: Fields to concatenate for embedding
            vector_field: Name of the vector field to create
            field_separator: Separator for concatenating text fields
            batch_size: Batch size for processing
            max_retries: Maximum retry attempts
            retry_delay: Delay between retries
            model_name: Name of the embedding model
            model_version: Version of the embedding model
            config: Advanced configuration (overrides other params)
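
        Example:
            A minimal sketch (assumes ``db`` is an existing Database instance
            and ``model.encode`` returns a numpy array for a string)::

                migration = VectorMigration(
                    source_db=db,
                    embedding_fn=model.encode,
                    text_fields=["title", "body"],
                )
                status = await migration.run()
                print(f"migrated: {status.migrated_records}, failed: {status.failed_records}")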
156 """
157 self.source_db = source_db
158 self.target_db = target_db or source_db
159 self.embedding_fn = embedding_fn
160 self.embedding_function = embedding_fn # Alias for compatibility
161 self.text_fields = text_fields or []
162 self.vector_field = vector_field
163 self.field_separator = field_separator
164 self.batch_size = batch_size
165 self.max_retries = max_retries
166 self.retry_delay = retry_delay
167 self.model_name = model_name
168 self.model_version = model_version
170 # Use config if provided, otherwise create from params
171 if config:
172 self.config = config
173 else:
174 self.config = MigrationConfig(
175 batch_size=batch_size,
176 max_retries=max_retries,
177 )
178 self.config.validate()
180 # Migration status
181 self.status = MigrationStatus()
183 # Track rollback data if enabled
184 self._rollback_data: dict[str, dict[str, Any]] = {}
186 async def run(
187 self,
188 progress_callback: Callable[[MigrationStatus], None] | None = None
189 ) -> MigrationStatus:
190 """Run the complete migration.
192 Args:
193 progress_callback: Optional callback for progress updates
195 Returns:
196 Migration status
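
        Example:
            An illustrative progress hook (assumes ``migration`` was built as
            in the constructor example above)::

                def on_progress(status: MigrationStatus) -> None:
                    print(f"{status.migrated_records}/{status.total_records} migrated")

                status = await migration.run(progress_callback=on_progress)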
197 """
198 self.status = MigrationStatus(start_time=datetime.utcnow())
200 try:
201 # Get all records from source
202 all_records = await self.source_db.search(Query())
203 self.status.total_records = len(all_records)
205 # Process in batches
206 for i in range(0, len(all_records), self.batch_size):
207 batch = all_records[i:i + self.batch_size]
209 for record in batch:
210 try:
211 # Concatenate text fields
212 text_parts = []
213 for field in self.text_fields:
214 value = record.get_value(field)
215 if value:
216 text_parts.append(str(value))
218 if text_parts:
219 text = self.field_separator.join(text_parts)
221 # Generate embedding
222 if asyncio.iscoroutinefunction(self.embedding_fn):
223 embedding = await self.embedding_fn(text)
224 else:
225 embedding = await asyncio.to_thread(self.embedding_fn, text)
227 # Create VectorField
228 from ..fields import VectorField
229 vector_field_obj = VectorField(
230 value=embedding,
231 name=self.vector_field,
232 source_field=self.text_fields[0] if len(self.text_fields) == 1 else None,
233 model_name=self.model_name,
234 model_version=self.model_version,
235 )
237 # Add to record
238 record.fields[self.vector_field] = vector_field_obj
240 # Create in target database
241 await self.target_db.create(record)
242 self.status.migrated_records += 1
244 except Exception as e:
245 logger.error(f"Failed to migrate record {record.id}: {e}")
246 self.status.failed_records += 1
247 self.status.errors.append({"record_id": record.id, "error": str(e)})
249 if progress_callback:
250 progress_callback(self.status)
252 self.status.end_time = datetime.utcnow()
253 return self.status
255 except Exception as e:
256 logger.error(f"Migration failed: {e}")
257 self.status.failed_records = self.status.total_records - self.status.migrated_records
258 self.status.end_time = datetime.utcnow()
259 return self.status
    async def start(self) -> None:
        """Start migration (for compatibility)."""
        # No-op: run() performs the entire migration before returning.
        pass

    async def wait_for_completion(self, progress_callback: Callable[[MigrationStatus], None] | None = None) -> MigrationStatus:
        """Wait for migration completion (for compatibility)."""
        # run() completes the migration before returning, so just report the current status.
        return self.status
    async def add_vectors_to_existing(
        self,
        vector_fields: dict[str, str],  # vector_field -> source_field mapping
        filter_query: dict[str, Any] | None = None,
        progress_callback: Callable[[MigrationStatus], None] | None = None,
    ) -> MigrationStatus:
        """Add vector fields to existing records.

        Args:
            vector_fields: Mapping of vector field names to source text fields
            filter_query: Optional filter to select records to migrate
            progress_callback: Callback for progress updates

        Returns:
            Migration status
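
        Example:
            A minimal sketch (assumes ``migration`` was constructed with an
            ``embedding_fn``, and that the ``title``/``body`` text fields and
            the ``status`` filter field are illustrative)::

                status = await migration.add_vectors_to_existing(
                    vector_fields={"title_vector": "title", "body_vector": "body"},
                    filter_query={"status": "published"},
                )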
286 """
287 if not self.embedding_fn:
288 raise ValueError("Embedding function required for adding vectors")
290 status = MigrationStatus(start_time=datetime.utcnow())
292 try:
293 # Get records to migrate
294 if filter_query:
295 # Convert filter_query dict to Query object
296 query = Query()
297 for field, value in filter_query.items():
298 query = query.filter(field, "==", value)
299 records = await self.source_db.search(query)
300 else:
301 records = await self.source_db.all()
303 status.total_records = len(records)
304 logger.info(f"Starting migration of {status.total_records} records")
306 # Create synchronizer with wrapped embedding function
307 sync_config = SyncConfig(
308 batch_size=self.config.batch_size,
309 max_retries=self.config.max_retries,
310 )
312 # Track the last embedding exception
313 last_embedding_exception = None
315 # Create wrapper that captures exceptions
316 async def embedding_wrapper(text: str) -> np.ndarray:
317 nonlocal last_embedding_exception
318 try:
319 if asyncio.iscoroutinefunction(self.embedding_fn):
320 result = await self.embedding_fn(text)
321 else:
322 result = await asyncio.to_thread(self.embedding_fn, text)
323 return result
324 except Exception as e:
325 last_embedding_exception = e
326 raise
328 synchronizer = VectorTextSynchronizer(
329 database=self.target_db,
330 embedding_fn=embedding_wrapper,
331 config=sync_config,
332 model_name=self.model_name,
333 model_version=self.model_version,
334 )
336 # Process in batches
337 consecutive_batch_failures = 0
338 for i in range(0, len(records), self.config.batch_size):
339 batch = records[i:i + self.config.batch_size]
341 # Process batch with workers
342 tasks = []
343 for record in batch:
344 # Store original data for rollback
345 if self.config.enable_rollback:
346 # Store original field values for rollback
347 self._rollback_data[record.id] = {
348 field_name: record.get_value(field_name)
349 for field_name in record.fields.keys()
350 }
352 # Add vector fields to record
353 for vector_field, source_field in vector_fields.items():
354 if record.get_value(source_field) is None:
355 continue
357 # Add vector field schema if needed
358 if vector_field not in self.target_db.schema.fields:
359 source_text = record.get_value(source_field)
360 if source_text:
361 # Get dimensions from first embedding
362 sample_embedding = await self._get_embedding(str(source_text))
363 if sample_embedding is not None:
364 dimensions = len(sample_embedding)
365 # Add schema for vector field
366 field_schema = FieldSchema(
367 name=vector_field,
368 type=FieldType.VECTOR,
369 metadata={
370 "dimensions": dimensions,
371 "source_field": source_field,
372 }
373 )
374 self.target_db.add_field_schema(field_schema)
376 # Create migration task
377 task = self._migrate_record(
378 synchronizer,
379 record,
380 vector_fields,
381 status,
382 )
383 tasks.append(task)
385 # Wait for batch to complete
386 results = await asyncio.gather(*tasks, return_exceptions=False)
388 # Check for batch failures and fail fast if needed
389 batch_failed_count = sum(1 for r in results if r is False)
390 if batch_failed_count == len(results) and len(results) > 0:
391 consecutive_batch_failures += 1
392 # If multiple consecutive batches completely fail, re-raise the last exception
393 if consecutive_batch_failures >= 2 and self.config.enable_rollback:
394 if last_embedding_exception:
395 raise last_embedding_exception
396 else:
397 raise Exception("Migration failed: consecutive batch failures")
398 else:
399 consecutive_batch_failures = 0
401 # Checkpoint if needed
402 if status.migrated_records % self.config.checkpoint_interval == 0:
403 status.add_checkpoint(
404 f"Batch {i // self.config.batch_size + 1}",
405 batch[-1].id if batch else None,
406 )
407 if progress_callback:
408 progress_callback(status)
410 # Verify migration if enabled
411 if self.config.verify_migration:
412 await self._verify_migration(vector_fields, records, status)
414 except Exception as e:
415 logger.error(f"Migration failed: {e}")
416 if self.config.enable_rollback:
417 await self._rollback(status)
418 raise
420 finally:
421 status.end_time = datetime.utcnow()
423 logger.info(
424 f"Migration completed: {status.migrated_records}/{status.total_records} "
425 f"migrated, {status.failed_records} failed"
426 )
428 return status
430 async def _get_embedding(self, text: str) -> np.ndarray | None:
431 """Get embedding for text."""
432 try:
433 if asyncio.iscoroutinefunction(self.embedding_fn):
434 result = await self.embedding_fn(text)
435 else:
436 result = await asyncio.to_thread(self.embedding_fn, text)
438 if isinstance(result, np.ndarray):
439 return result
440 elif isinstance(result, list):
441 return np.array(result)
442 return None
444 except Exception as e:
445 logger.error(f"Failed to get embedding: {e}")
446 return None
448 async def _migrate_record(
449 self,
450 synchronizer: VectorTextSynchronizer,
451 record: Record,
452 vector_fields: dict[str, str],
453 status: MigrationStatus,
454 ) -> bool:
455 """Migrate a single record.
457 Returns:
458 True if migration succeeded, False otherwise
459 """
460 try:
461 # Sync vectors
462 success, updated_fields = await synchronizer.sync_record(record, force=True)
464 if success and updated_fields:
465 # Update record in target database
466 await self.target_db.update(record.id, record)
467 status.migrated_records += 1
468 return True
469 else:
470 status.failed_records += 1
471 status.errors.append({
472 "record_id": record.id,
473 "error": "Failed to generate vectors",
474 })
475 return False
477 except Exception as e:
478 status.failed_records += 1
479 status.errors.append({
480 "record_id": record.id,
481 "error": str(e),
482 })
483 logger.error(f"Failed to migrate record {record.id}: {e}")
484 return False
486 async def _verify_migration(
487 self,
488 vector_fields: dict[str, str],
489 records: list[Record],
490 status: MigrationStatus,
491 ) -> None:
492 """Verify that migration was successful."""
493 logger.info("Verifying migration...")
495 for record in records:
496 try:
497 # Get updated record
498 migrated = await self.target_db.read(record.id)
500 # Check vector fields
501 all_present = True
502 for vector_field, source_field in vector_fields.items():
503 source_value = record.get_value(source_field)
504 if source_value:
505 # Check if vector field exists (could be in fields or data)
506 vector_data = migrated.get_value(vector_field)
507 if vector_data is None:
508 all_present = False
509 break
511 # For VectorField objects, check the value
512 from ..fields import VectorField
513 if isinstance(migrated.fields.get(vector_field), VectorField):
514 vector_data = migrated.fields[vector_field].value
516 if not isinstance(vector_data, (list, np.ndarray)):
517 all_present = False
518 break
520 if all_present:
521 status.verified_records += 1
523 except Exception as e:
524 logger.error(f"Failed to verify record {record.id}: {e}")
526 async def _rollback(self, status: MigrationStatus) -> None:
527 """Rollback migration on failure."""
528 if not self._rollback_data:
529 return
531 logger.info(f"Rolling back {len(self._rollback_data)} records...")
533 for record_id, original_data in self._rollback_data.items():
534 try:
535 # Restore original record
536 original_record = Record(id=record_id, data=original_data)
537 await self.target_db.update(record_id, original_record)
538 status.rollback_records += 1
539 except Exception as e:
540 logger.error(f"Failed to rollback record {record_id}: {e}")
542 async def migrate_between_backends(
543 self,
544 field_mapping: dict[str, str] | None = None,
545 transform_fn: Callable[[Record], Record] | None = None,
546 progress_callback: Callable[[MigrationStatus], None] | None = None,
547 ) -> MigrationStatus:
548 """Migrate vector data between different backends.
550 Args:
551 field_mapping: Optional field name mapping
552 transform_fn: Optional record transformation function
553 progress_callback: Callback for progress updates
555 Returns:
556 Migration status
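
        Example:
            A minimal sketch (assumes ``pg_db`` and ``es_db`` are two already
            configured Database instances and the field names are illustrative)::

                migration = VectorMigration(source_db=pg_db, target_db=es_db)
                status = await migration.migrate_between_backends(
                    field_mapping={"embedding": "vector"},
                    transform_fn=lambda record: record,
                )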
557 """
558 status = MigrationStatus(start_time=datetime.utcnow())
560 try:
561 # Get all records with vectors
562 records = await self.source_db.all()
563 status.total_records = len(records)
565 logger.info(
566 f"Migrating {status.total_records} records from "
567 f"{self.source_db.__class__.__name__} to "
568 f"{self.target_db.__class__.__name__}"
569 )
571 # Process in batches
572 for i in range(0, len(records), self.config.batch_size):
573 batch = records[i:i + self.config.batch_size]
575 for original_record in batch:
576 try:
577 record = original_record
578 # Apply field mapping
579 if field_mapping:
580 new_data = {}
581 for old_field, new_field in field_mapping.items():
582 old_value = record.get_value(old_field)
583 if old_value is not None:
584 new_data[new_field] = old_value
585 # Update record with new field mapping
586 for field_name, value in new_data.items():
587 record.set_value(field_name, value)
589 # Apply transformation
590 if transform_fn:
591 record = transform_fn(record)
593 # Create in target database
594 await self.target_db.create(record)
595 status.migrated_records += 1
597 except Exception as e:
598 status.failed_records += 1
599 status.errors.append({
600 "record_id": record.id,
601 "error": str(e),
602 })
603 logger.error(f"Failed to migrate record {record.id}: {e}")
605 # Progress update
606 if progress_callback:
607 progress_callback(status)
609 finally:
610 status.end_time = datetime.utcnow()
612 return status
614 @classmethod
615 def from_config(
616 cls,
617 source_db: Database,
618 target_db: Database | None,
619 embedding_fn: Callable[[str], np.ndarray] | Callable[[str], Coroutine[Any, Any, np.ndarray]],
620 config: MigrationConfig,
621 text_fields: list[str] | None = None,
622 vector_field: str = "embedding",
623 model_name: str | None = None,
624 model_version: str | None = None,
625 ) -> VectorMigration:
626 """Create migration from a config object for advanced use cases.
628 Args:
629 source_db: Source database
630 target_db: Target database (None for in-place)
631 embedding_fn: Function to generate embeddings
632 config: Migration configuration
633 text_fields: Text field names (optional)
634 vector_field: Name of the vector field
635 model_name: Name of the embedding model
636 model_version: Version of the embedding model
638 Returns:
639 Configured VectorMigration instance
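
        Example:
            A minimal sketch (assumes ``db`` and ``model.encode`` as in the
            constructor example)::

                config = MigrationConfig(batch_size=50, enable_rollback=True)
                migration = VectorMigration.from_config(
                    source_db=db,
                    target_db=None,
                    embedding_fn=model.encode,
                    config=config,
                    text_fields=["content"],
                )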
640 """
641 return cls(
642 source_db=source_db,
643 target_db=target_db,
644 embedding_fn=embedding_fn,
645 text_fields=text_fields,
646 vector_field=vector_field,
647 batch_size=config.batch_size,
648 model_name=model_name,
649 model_version=model_version,
650 config=config,
651 )
654class IncrementalVectorizer:
655 """Manages incremental vectorization of large datasets.
657 Examples:
658 # Simple usage with single field
659 vectorizer = IncrementalVectorizer(
660 db,
661 embedding_fn=model.encode,
662 text_fields="content" # Can be string or list
663 )
664 result = await vectorizer.run()
666 # Resume from checkpoint
        result = await vectorizer.run_with_checkpoint(resume_from=last_checkpoint)
        # Process limited batch
        result = await vectorizer.run_batch(limit=1000)
    """

    def __init__(
        self,
        database: Database,
        embedding_fn: Callable[[str], np.ndarray] | Callable[[str], Coroutine[Any, Any, np.ndarray]],
        text_fields: list[str] | str | None = None,  # Support multiple fields
        vector_field: str = "embedding",  # Sensible default
        field_separator: str = " ",
        batch_size: int = 100,
        checkpoint_interval: int = 1000,
        max_workers: int = 4,
        model_name: str | None = None,
        model_version: str | None = None,
    ):
        """Initialize the incremental vectorizer with simplified parameters.

        Args:
            database: Database to vectorize
            embedding_fn: Function to generate embeddings
            text_fields: Text field names to concatenate for embeddings
            vector_field: Name of the vector field to create
            field_separator: Separator for concatenating multiple text fields
            batch_size: Size of processing batches
            checkpoint_interval: Records between checkpoints
            max_workers: Maximum concurrent workers
            model_name: Name of the embedding model
            model_version: Version of the embedding model
        """
        self.database = database
        self.embedding_fn = embedding_fn
        self.embedding_function = embedding_fn  # Alias for compatibility

        # Handle text fields
        if isinstance(text_fields, str):
            text_fields = [text_fields]
        elif text_fields is None:
            # Try to auto-detect from database schema
            text_fields = self._detect_text_fields()
        self.text_fields = text_fields

        self.vector_field = vector_field
        self.field_separator = field_separator
        self.batch_size = batch_size
        self.checkpoint_interval = checkpoint_interval
        self.max_workers = max_workers
        self.model_name = model_name
        self.model_version = model_version

        # Processing state
        self._queue: asyncio.Queue[Record] = asyncio.Queue()
        self._processing_task: asyncio.Task | None = None
        self._workers: list[asyncio.Task] = []
        self._shutdown_event = asyncio.Event()
        self._stats = {
            "processed": 0,
            "failed": 0,
            "queued": 0,
        }
        self._last_checkpoint: str | None = None
        self._progress: VectorizationProgress | None = None

    def _detect_text_fields(self) -> list[str]:
        """Auto-detect text fields from database schema."""
        text_fields = []
        if hasattr(self.database, 'schema') and self.database.schema:
            for field_name, field_schema in self.database.schema.fields.items():
                if field_schema.type in (FieldType.STRING, FieldType.TEXT):
                    text_fields.append(field_name)

        # Default to common field names if no schema
        if not text_fields:
            text_fields = ["content", "text", "description"]

        return text_fields

    async def _worker(self, worker_id: int) -> None:
        """Worker task for processing records."""
        logger.info(f"Worker {worker_id} started")

        while not self._shutdown_event.is_set():
            try:
                # Get record from queue with timeout
                try:
                    record = await asyncio.wait_for(
                        self._queue.get(),
                        timeout=1.0
                    )
                except asyncio.TimeoutError:
                    continue

                # Process record
                await self._process_record(record)
                self._stats["processed"] += 1

            except Exception as e:
                logger.error(f"Worker {worker_id} error: {e}")
                self._stats["failed"] += 1

        logger.info(f"Worker {worker_id} stopped")

    async def _process_record(self, record: Record) -> None:
        """Process a single record to add vectors."""
        try:
            # Get source text from multiple fields
            text_parts = []
            for field in self.text_fields:
                value = record.get_value(field)
                if value:
                    text_parts.append(str(value))

            if not text_parts:
                return

            source_text = self.field_separator.join(text_parts)

            # Check if vector already exists
            vector_data = record.get_value(self.vector_field)
            if vector_data is not None:
                if vector_data and isinstance(vector_data, (list, np.ndarray)):
                    return

            # Generate embedding
            if asyncio.iscoroutinefunction(self.embedding_fn):
                embedding = await self.embedding_fn(str(source_text))
            else:
                embedding = await asyncio.to_thread(self.embedding_fn, str(source_text))

            if embedding is None:
                return

            # Update record
            update_data = {
                self.vector_field: embedding.tolist() if isinstance(embedding, np.ndarray) else embedding,
            }

            # Add metadata
            if self.model_name:
                metadata = VectorMetadata(
                    dimensions=len(embedding),
                    source_field=self.field_separator.join(self.text_fields),
                    model_name=self.model_name,
                    model_version=self.model_version,
                    updated_at=datetime.utcnow().isoformat(),
                )
                update_data[f"{self.vector_field}_metadata"] = metadata.to_dict()

            # Update the record with the new vector data
            for key, value in update_data.items():
                record.set_value(key, value)
            await self.database.update(record.id, record)

        except Exception as e:
            logger.error(f"Failed to process record {record.id}: {e}")
            raise

    async def start(self) -> None:
        """Start incremental vectorization."""
        if self._processing_task and not self._processing_task.done():
            logger.warning("Vectorization already running")
            return

        self._shutdown_event.clear()

        # Start workers
        self._workers = [
            asyncio.create_task(self._worker(i))
            for i in range(self.max_workers)
        ]

        # Start queue loader
        self._processing_task = asyncio.create_task(self._load_queue())

        logger.info(f"Started incremental vectorization with {self.max_workers} workers")

    async def _load_queue(self) -> None:
        """Load records into processing queue."""
        while not self._shutdown_event.is_set():
            try:
                # Get records without vectors that have at least one text field
                filter_query = {
                    self.vector_field: {"$exists": False},
                    "$or": [
                        {field: {"$exists": True, "$ne": ""}}
                        for field in self.text_fields
                    ],
                }

                records = await self.database.filter(filter_query, limit=self.batch_size)

                if not records:
                    # No more records to process
                    await asyncio.sleep(60)  # Check again in a minute
                    continue

                # Add to queue
                for record in records:
                    await self._queue.put(record)
                    self._stats["queued"] += 1

            except Exception as e:
                logger.error(f"Failed to load queue: {e}")
                await asyncio.sleep(10)

    async def stop(self, timeout: float = 30.0) -> None:
        """Stop incremental vectorization.

        Args:
            timeout: Maximum time to wait for graceful shutdown
        """
        if not self._processing_task:
            return

        logger.info("Stopping incremental vectorization...")
        self._shutdown_event.set()

        # Cancel queue loader
        self._processing_task.cancel()
        try:
            await self._processing_task
        except asyncio.CancelledError:
            pass

        # Wait for workers to finish
        try:
            await asyncio.wait_for(
                asyncio.gather(*self._workers),
                timeout=timeout
            )
        except asyncio.TimeoutError:
            logger.warning("Workers did not stop gracefully, cancelling")
            for worker in self._workers:
                worker.cancel()
            await asyncio.gather(*self._workers, return_exceptions=True)

        self._workers.clear()
        self._processing_task = None

    async def run(
        self,
        progress_callback: Callable[[int, int, list], None] | None = None,
        max_workers: int | None = None,
    ) -> dict[str, Any]:
        """Run the complete vectorization.

        Args:
            progress_callback: Optional callback (completed, total, current_batch)
            max_workers: Override default max_workers

        Returns:
            Results dictionary
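
        Example:
            A minimal sketch (assumes ``db`` and ``model.encode`` as in the
            class docstring)::

                vectorizer = IncrementalVectorizer(db, embedding_fn=model.encode, text_fields=["title", "body"])
                result = await vectorizer.run(
                    progress_callback=lambda done, total, batch: print(f"{done}/{total}"),
                )
                print(result["processed"], result["failed"])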
923 """
924 if max_workers:
925 self.max_workers = max_workers
927 # Get all records that need vectors
928 from ..query import Query
929 all_records = await self.database.search(Query())
931 to_process = []
932 for record in all_records:
933 # Check if needs vectorization
934 if self.vector_field not in record.fields:
935 # Check if has text to vectorize
936 has_text = False
937 for field in self.text_fields:
938 if record.get_value(field):
939 has_text = True
940 break
941 if has_text:
942 to_process.append(record)
944 total = len(to_process)
945 processed = 0
946 failed = 0
948 # Process in batches
949 for i in range(0, total, self.batch_size):
950 batch = to_process[i:i + self.batch_size]
952 for record in batch:
953 try:
954 await self._process_record(record)
955 processed += 1
956 except Exception as e:
957 logger.error(f"Failed to process record {record.id}: {e}")
958 failed += 1
960 if progress_callback:
961 if asyncio.iscoroutinefunction(progress_callback):
962 await progress_callback(processed, total, batch)
963 else:
964 progress_callback(processed, total, batch)
966 return {
967 "processed": processed,
968 "failed": failed,
969 "total": total,
970 }
972 async def get_status(self) -> dict[str, Any]:
973 """Get current vectorization status.
975 Returns:
976 Status dictionary
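
        Example:
            Illustrative shape of the returned dictionary (values are made up)::

                {
                    "total": 1000,
                    "completed": 250,
                    "remaining": 750,
                    "percentage": 25.0,
                }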
977 """
978 # Count records with and without vectors
979 from ..query import Query
980 all_records = await self.database.search(Query())
982 total = 0
983 completed = 0
985 for record in all_records:
986 # Check if has text fields
987 has_text = False
988 for field_name in self.text_fields:
989 if record.get_value(field_name):
990 has_text = True
991 break
993 if has_text:
994 total += 1
995 if self.vector_field in record.fields:
996 completed += 1
998 return {
999 "total": total,
1000 "completed": completed,
1001 "remaining": total - completed,
1002 "percentage": (completed / total * 100) if total > 0 else 0,
1003 }
1005 def get_stats(self) -> dict[str, Any]:
1006 """Get vectorization statistics.
1008 Returns:
1009 Dictionary of statistics
1010 """
1011 return {
1012 **self._stats,
1013 "queue_size": self._queue.qsize(),
1014 "workers": len(self._workers),
1015 "is_running": bool(
1016 self._processing_task and not self._processing_task.done()
1017 ),
1018 }
1020 async def wait_for_completion(self, check_interval: float = 5.0) -> None:
1021 """Wait for all queued records to be processed.
1023 Args:
1024 check_interval: Seconds between queue checks
1025 """
1026 while self._queue.qsize() > 0:
1027 await asyncio.sleep(check_interval)
1029 logger.info("All queued records processed")
1031 async def run_with_checkpoint(self, resume_from: str | None = None) -> VectorizationResult:
1032 """Run the complete vectorization with checkpoint support.
1034 Args:
1035 resume_from: Optional checkpoint ID to resume from
1037 Returns:
1038 Vectorization result with statistics
1039 """
1040 await self.start()
1041 await self.wait_for_completion()
1043 return VectorizationResult(
1044 processed=self._stats["processed"],
1045 failed=self._stats["failed"],
1046 checkpoint=self._last_checkpoint,
1047 )
1049 async def run_batch(self, limit: int | None = None) -> VectorizationResult:
1050 """Process a limited number of records.
1052 Args:
1053 limit: Maximum number of records to process
1055 Returns:
1056 Vectorization result with statistics
1057 """
1058 # Temporarily modify batch size if limit provided
1059 original_batch_size = self.batch_size
1060 if limit:
1061 self.batch_size = min(self.batch_size, limit)
1063 try:
1064 await self.start()
1066 # Wait for limited processing
1067 while self._stats["processed"] < (limit or float('inf')):
1068 if self._queue.empty() and self._processing_task.done():
1069 break
1070 await asyncio.sleep(0.1)
1072 await self.stop()
1074 return VectorizationResult(
1075 processed=self._stats["processed"],
1076 failed=self._stats["failed"],
1077 checkpoint=self._last_checkpoint,
1078 )
1079 finally:
1080 self.batch_size = original_batch_size
1082 @property
1083 def progress(self) -> VectorizationProgress:
1084 """Get current progress."""
1085 return VectorizationProgress(
1086 total_records=self._stats.get("total", 0),
1087 processed_records=self._stats["processed"],
1088 failed_records=self._stats["failed"],
1089 queued_records=self._queue.qsize(),
1090 checkpoint=self._last_checkpoint,
1091 )
1093 async def get_checkpoint(self) -> str:
1094 """Get checkpoint ID for resuming."""
1095 # Save current progress as checkpoint
1096 self._last_checkpoint = f"checkpoint_{self._stats['processed']}"
1097 return self._last_checkpoint
1100@dataclass
1101class VectorizationResult:
1102 """Result of a vectorization operation."""
1103 processed: int
1104 failed: int
1105 checkpoint: str | None = None
1108@dataclass
1109class VectorizationProgress:
1110 """Current progress of vectorization."""
1111 total_records: int
1112 processed_records: int
1113 failed_records: int
1114 queued_records: int
1115 checkpoint: str | None = None