Coverage for src/dataknobs_data/vector/migration.py: 18%
497 statements
coverage.py v7.11.3, created at 2025-11-13 11:23 -0700
1"""Migration tools for adding vector support to existing data."""
3from __future__ import annotations
5import asyncio
6import logging
7from dataclasses import dataclass, field
8from datetime import datetime
9from typing import TYPE_CHECKING, Any
11import numpy as np
13from ..fields import FieldType
14from ..query import Query
15from ..records import Record
16from ..schema import FieldSchema
17from .sync import SyncConfig, VectorTextSynchronizer
18from .types import VectorMetadata
20if TYPE_CHECKING:
21 from collections.abc import Callable, Coroutine
23 from ..database import Database
25logger = logging.getLogger(__name__)


@dataclass
class MigrationConfig:
    """Configuration for vector migration."""

    batch_size: int = 100
    max_workers: int = 4
    checkpoint_interval: int = 1000
    enable_rollback: bool = True
    verify_migration: bool = True
    retry_failed: bool = True
    max_retries: int = 3
    max_consecutive_failures: int = 5  # Fail fast after this many consecutive failures

    def validate(self) -> None:
        """Validate configuration parameters."""
        if self.batch_size <= 0:
            raise ValueError(f"Batch size must be positive, got {self.batch_size}")
        if self.max_workers <= 0:
            raise ValueError(f"Max workers must be positive, got {self.max_workers}")


@dataclass
class MigrationStatus:
    """Status of a migration operation."""

    total_records: int = 0
    migrated_records: int = 0
    verified_records: int = 0
    failed_records: int = 0
    rollback_records: int = 0
    errors: list[dict[str, Any]] = field(default_factory=list)
    checkpoints: list[dict[str, Any]] = field(default_factory=list)
    start_time: datetime | None = None
    end_time: datetime | None = None

    @property
    def total_processed(self) -> int:
        """Total number of processed records (migrated + failed)."""
        return self.migrated_records + self.failed_records

    @property
    def failed_count(self) -> int:
        """Alias for failed_records for compatibility."""
        return self.failed_records

    @property
    def success_rate(self) -> float:
        """Calculate the success rate of the migration."""
        if self.total_records == 0:
            return 0.0
        return self.migrated_records / self.total_records

    @property
    def duration(self) -> float | None:
        """Calculate the duration of the migration in seconds."""
        if self.start_time and self.end_time:
            return (self.end_time - self.start_time).total_seconds()
        return None

    @property
    def records_per_second(self) -> float:
        """Calculate the migration speed."""
        duration = self.duration
        if duration and duration > 0:
            return self.migrated_records / duration
        return 0.0

    def add_checkpoint(self, name: str, record_id: str | None = None) -> None:
        """Add a checkpoint to the migration."""
        self.checkpoints.append({
            "name": name,
            "record_id": record_id,
            "timestamp": datetime.utcnow().isoformat(),
            "migrated": self.migrated_records,
            "failed": self.failed_records,
        })

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary representation."""
        return {
            "total_records": self.total_records,
            "migrated_records": self.migrated_records,
            "verified_records": self.verified_records,
            "failed_records": self.failed_records,
            "rollback_records": self.rollback_records,
            "success_rate": self.success_rate,
            "duration": self.duration,
            "records_per_second": self.records_per_second,
            "errors": self.errors,
            "checkpoints": self.checkpoints,
            "start_time": self.start_time.isoformat() if self.start_time else None,
            "end_time": self.end_time.isoformat() if self.end_time else None,
        }


class VectorMigration:
    """Manages migration of existing data to include vector embeddings."""

    def __init__(
        self,
        source_db: Database,
        target_db: Database | None = None,
        embedding_fn: Callable[[str], np.ndarray] | Callable[[str], Coroutine[Any, Any, np.ndarray]] | None = None,
        text_fields: list[str] | None = None,
        vector_field: str = "embedding",
        field_separator: str = " ",
        batch_size: int = 100,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        model_name: str | None = None,
        model_version: str | None = None,
        config: MigrationConfig | None = None,
    ):
141 """Initialize the migration manager with simplified API.
143 Args:
144 source_db: Source database to migrate from
145 target_db: Target database (None to migrate in-place)
146 embedding_fn: Function to generate embeddings
147 text_fields: Fields to concatenate for embedding
148 vector_field: Name of the vector field to create
149 field_separator: Separator for concatenating text fields
150 batch_size: Batch size for processing
151 max_retries: Maximum retry attempts
152 retry_delay: Delay between retries
153 model_name: Name of the embedding model
154 model_version: Version of the embedding model
155 config: Advanced configuration (overrides other params)
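
        Example (illustrative sketch; ``source_db`` and ``embed`` are assumed
        stand-ins for a real database and embedding model):

            def embed(text: str) -> np.ndarray:
                # e.g. wrap a sentence-transformers model here
                return np.zeros(384, dtype=np.float32)

            migration = VectorMigration(
                source_db,
                embedding_fn=embed,
                text_fields=["title", "body"],
                vector_field="embedding",
            )
            status = await migration.run()
            print(status.success_rate)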
156 """
157 self.source_db = source_db
158 self.target_db = target_db or source_db
159 self.embedding_fn = embedding_fn
160 self.embedding_function = embedding_fn # Alias for compatibility
161 self.text_fields = text_fields or []
162 self.vector_field = vector_field
163 self.field_separator = field_separator
164 self.batch_size = batch_size
165 self.max_retries = max_retries
166 self.retry_delay = retry_delay
167 self.model_name = model_name
168 self.model_version = model_version
170 # Use config if provided, otherwise create from params
171 if config:
172 self.config = config
173 else:
174 self.config = MigrationConfig(
175 batch_size=batch_size,
176 max_retries=max_retries,
177 )
178 self.config.validate()
180 # Migration status
181 self.status = MigrationStatus()
183 # Track rollback data if enabled
184 self._rollback_data: dict[str, dict[str, Any]] = {}
186 async def run(
187 self,
188 progress_callback: Callable[[MigrationStatus], None] | None = None
189 ) -> MigrationStatus:
190 """Run the complete migration.
192 Args:
193 progress_callback: Optional callback for progress updates
195 Returns:
196 Migration status
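
        Example (illustrative; ``report`` is a hypothetical callback):

            def report(status: MigrationStatus) -> None:
                print(f"{status.migrated_records}/{status.total_records} migrated")

            status = await migration.run(progress_callback=report)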
197 """
198 self.status = MigrationStatus(start_time=datetime.utcnow())
200 try:
201 # Get all records from source
202 all_records = await self.source_db.search(Query())
203 self.status.total_records = len(all_records)
205 # Process in batches
206 for i in range(0, len(all_records), self.batch_size):
207 batch = all_records[i:i + self.batch_size]
209 for record in batch:
210 try:
211 # Concatenate text fields
212 text_parts = []
213 for field in self.text_fields:
214 value = record.get_value(field)
215 if value:
216 text_parts.append(str(value))
218 if text_parts:
219 text = self.field_separator.join(text_parts)
221 # Generate embedding
222 if asyncio.iscoroutinefunction(self.embedding_fn):
223 embedding = await self.embedding_fn(text)
224 else:
225 embedding = await asyncio.to_thread(self.embedding_fn, text)
227 # Create VectorField
228 from ..fields import VectorField
229 vector_field_obj = VectorField(
230 value=embedding,
231 name=self.vector_field,
232 source_field=self.text_fields[0] if len(self.text_fields) == 1 else None,
233 model_name=self.model_name,
234 model_version=self.model_version,
235 )
237 # Add to record
238 record.fields[self.vector_field] = vector_field_obj
240 # Create in target database
241 await self.target_db.create(record)
242 self.status.migrated_records += 1
244 except Exception as e:
245 logger.error(f"Failed to migrate record {record.id}: {e}")
246 self.status.failed_records += 1
247 self.status.errors.append({"record_id": record.id, "error": str(e)})
249 if progress_callback:
250 progress_callback(self.status)
252 self.status.end_time = datetime.utcnow()
253 return self.status
255 except Exception as e:
256 logger.error(f"Migration failed: {e}")
257 self.status.failed_records = self.status.total_records - self.status.migrated_records
258 self.status.end_time = datetime.utcnow()
259 return self.status
261 async def start(self) -> None:
262 """Start migration (for compatibility)."""
263 # Migration runs synchronously in run() method
264 pass
266 async def wait_for_completion(self, progress_callback: Callable[[MigrationStatus], None] | None = None) -> MigrationStatus:
267 """Wait for migration completion (for compatibility)."""
268 # Since run() is synchronous, just return current status
269 return self.status
271 async def add_vectors_to_existing(
272 self,
273 vector_fields: dict[str, str], # vector_field -> source_field mapping
274 filter_query: dict[str, Any] | None = None,
275 progress_callback: Callable[[MigrationStatus], None] | None = None,
276 ) -> MigrationStatus:
277 """Add vector fields to existing records.
279 Args:
280 vector_fields: Mapping of vector field names to source text fields
281 filter_query: Optional filter to select records to migrate
282 progress_callback: Callback for progress updates
284 Returns:
285 Migration status
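
        Example (illustrative; field names are assumptions for the sketch):

            status = await migration.add_vectors_to_existing(
                vector_fields={"title_vector": "title", "body_vector": "body"},
                filter_query={"status": "published"},
            )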
286 """
287 if not self.embedding_fn:
288 raise ValueError("Embedding function required for adding vectors")
290 status = MigrationStatus(start_time=datetime.utcnow())
292 try:
293 # Get records to migrate
294 if filter_query:
295 # Convert filter_query dict to Query object
296 query = Query()
297 for field, value in filter_query.items():
298 query = query.filter(field, "==", value)
299 records = await self.source_db.search(query)
300 else:
301 records = await self.source_db.all()
303 status.total_records = len(records)
304 logger.info(f"Starting migration of {status.total_records} records")
306 # Create synchronizer with wrapped embedding function
307 sync_config = SyncConfig(
308 batch_size=self.config.batch_size,
309 max_retries=self.config.max_retries,
310 )
312 # Track the last embedding exception
313 last_embedding_exception = None
315 # Create wrapper that captures exceptions
316 async def embedding_wrapper(text: str) -> np.ndarray:
317 nonlocal last_embedding_exception
318 try:
319 if asyncio.iscoroutinefunction(self.embedding_fn):
320 result = await self.embedding_fn(text)
321 else:
322 result = await asyncio.to_thread(self.embedding_fn, text)
323 return result
324 except Exception as e:
325 last_embedding_exception = e
326 raise
328 synchronizer = VectorTextSynchronizer(
329 database=self.target_db,
330 embedding_fn=embedding_wrapper,
331 config=sync_config,
332 model_name=self.model_name,
333 model_version=self.model_version,
334 )

            # Process in batches
            consecutive_batch_failures = 0
            for i in range(0, len(records), self.config.batch_size):
                batch = records[i:i + self.config.batch_size]

                # Process batch with workers
                tasks = []
                for record in batch:
                    # Store original data for rollback
                    if self.config.enable_rollback:
                        # Store original field values for rollback
                        self._rollback_data[record.id] = {
                            field_name: record.get_value(field_name)
                            for field_name in record.fields.keys()
                        }

                    # Add vector fields to record
                    for vector_field, source_field in vector_fields.items():
                        if record.get_value(source_field) is None:
                            continue

                        # Add vector field schema if needed
                        if vector_field not in self.target_db.schema.fields:
                            source_text = record.get_value(source_field)
                            if source_text:
                                # Get dimensions from first embedding
                                sample_embedding = await self._get_embedding(str(source_text))
                                if sample_embedding is not None:
                                    dimensions = len(sample_embedding)
                                    # Add schema for vector field
                                    field_schema = FieldSchema(
                                        name=vector_field,
                                        type=FieldType.VECTOR,
                                        metadata={
                                            "dimensions": dimensions,
                                            "source_field": source_field,
                                        }
                                    )
                                    self.target_db.add_field_schema(field_schema)

                    # Create migration task
                    task = self._migrate_record(
                        synchronizer,
                        record,
                        vector_fields,
                        status,
                    )
                    tasks.append(task)

                # Wait for batch to complete
                results = await asyncio.gather(*tasks, return_exceptions=False)

                # Check for batch failures and fail fast if needed
                batch_failed_count = sum(1 for r in results if r is False)
                if batch_failed_count == len(results) and len(results) > 0:
                    consecutive_batch_failures += 1
                    # If multiple consecutive batches completely fail, re-raise the last exception
                    if consecutive_batch_failures >= 2 and self.config.enable_rollback:
                        if last_embedding_exception:
                            raise last_embedding_exception
                        else:
                            raise Exception("Migration failed: consecutive batch failures")
                else:
                    consecutive_batch_failures = 0

                # Checkpoint if needed
                if status.migrated_records % self.config.checkpoint_interval == 0:
                    status.add_checkpoint(
                        f"Batch {i // self.config.batch_size + 1}",
                        batch[-1].id if batch else None,
                    )
                if progress_callback:
                    progress_callback(status)

            # Verify migration if enabled
            if self.config.verify_migration:
                await self._verify_migration(vector_fields, records, status)

        except Exception as e:
            logger.error(f"Migration failed: {e}")
            if self.config.enable_rollback:
                await self._rollback(status)
            raise

        finally:
            status.end_time = datetime.utcnow()

        logger.info(
            f"Migration completed: {status.migrated_records}/{status.total_records} "
            f"migrated, {status.failed_records} failed"
        )

        return status

    async def _get_embedding(self, text: str) -> np.ndarray | None:
        """Get embedding for text."""
        try:
            if asyncio.iscoroutinefunction(self.embedding_fn):
                result = await self.embedding_fn(text)
            else:
                result = await asyncio.to_thread(self.embedding_fn, text)

            if isinstance(result, np.ndarray):
                return result
            elif isinstance(result, list):
                return np.array(result)
            return None

        except Exception as e:
            logger.error(f"Failed to get embedding: {e}")
            return None

    async def _migrate_record(
        self,
        synchronizer: VectorTextSynchronizer,
        record: Record,
        vector_fields: dict[str, str],
        status: MigrationStatus,
    ) -> bool:
        """Migrate a single record.

        Returns:
            True if migration succeeded, False otherwise
        """
        try:
            # Sync vectors
            success, updated_fields = await synchronizer.sync_record(record, force=True)

            if success and updated_fields:
                # Update record in target database
                await self.target_db.update(record.id, record)
                status.migrated_records += 1
                return True
            else:
                status.failed_records += 1
                status.errors.append({
                    "record_id": record.id,
                    "error": "Failed to generate vectors",
                })
                return False

        except Exception as e:
            status.failed_records += 1
            status.errors.append({
                "record_id": record.id,
                "error": str(e),
            })
            logger.error(f"Failed to migrate record {record.id}: {e}")
            return False

    async def _verify_migration(
        self,
        vector_fields: dict[str, str],
        records: list[Record],
        status: MigrationStatus,
    ) -> None:
        """Verify that migration was successful."""
        logger.info("Verifying migration...")

        for record in records:
            try:
                # Get updated record
                migrated = await self.target_db.read(record.id)

                # Check vector fields
                all_present = True
                for vector_field, source_field in vector_fields.items():
                    source_value = record.get_value(source_field)
                    if source_value:
                        # Check if vector field exists (could be in fields or data)
                        vector_data = migrated.get_value(vector_field)
                        if vector_data is None:
                            all_present = False
                            break

                        # For VectorField objects, check the value
                        from ..fields import VectorField
                        if isinstance(migrated.fields.get(vector_field), VectorField):
                            vector_data = migrated.fields[vector_field].value

                        if not isinstance(vector_data, (list, np.ndarray)):
                            all_present = False
                            break

                if all_present:
                    status.verified_records += 1

            except Exception as e:
                logger.error(f"Failed to verify record {record.id}: {e}")

    async def _rollback(self, status: MigrationStatus) -> None:
        """Rollback migration on failure."""
        if not self._rollback_data:
            return

        logger.info(f"Rolling back {len(self._rollback_data)} records...")

        for record_id, original_data in self._rollback_data.items():
            try:
                # Restore original record
                original_record = Record(id=record_id, data=original_data)
                await self.target_db.update(record_id, original_record)
                status.rollback_records += 1
            except Exception as e:
                logger.error(f"Failed to rollback record {record_id}: {e}")

    async def migrate_between_backends(
        self,
        field_mapping: dict[str, str] | None = None,
        transform_fn: Callable[[Record], Record] | None = None,
        progress_callback: Callable[[MigrationStatus], None] | None = None,
    ) -> MigrationStatus:
        """Migrate vector data between different backends.

        Args:
            field_mapping: Optional field name mapping
            transform_fn: Optional record transformation function
            progress_callback: Callback for progress updates

        Returns:
            Migration status
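
        Example (illustrative; the mapping and transform are placeholders):

            def add_source_tag(record: Record) -> Record:
                record.set_value("migrated_from", "legacy")
                return record

            status = await migration.migrate_between_backends(
                field_mapping={"old_embedding": "embedding"},
                transform_fn=add_source_tag,
            )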
557 """
558 status = MigrationStatus(start_time=datetime.utcnow())
560 try:
561 # Get all records with vectors
562 records = await self.source_db.all()
563 status.total_records = len(records)
565 logger.info(
566 f"Migrating {status.total_records} records from "
567 f"{self.source_db.__class__.__name__} to "
568 f"{self.target_db.__class__.__name__}"
569 )
571 # Process in batches
572 for i in range(0, len(records), self.config.batch_size):
573 batch = records[i:i + self.config.batch_size]
575 for original_record in batch:
576 try:
577 record = original_record
578 # Apply field mapping
579 if field_mapping:
580 new_data = {}
581 for old_field, new_field in field_mapping.items():
582 old_value = record.get_value(old_field)
583 if old_value is not None:
584 new_data[new_field] = old_value
585 # Update record with new field mapping
586 for field_name, value in new_data.items():
587 record.set_value(field_name, value)
589 # Apply transformation
590 if transform_fn:
591 record = transform_fn(record)
593 # Create in target database
594 await self.target_db.create(record)
595 status.migrated_records += 1
597 except Exception as e:
598 status.failed_records += 1
599 status.errors.append({
600 "record_id": record.id,
601 "error": str(e),
602 })
603 logger.error(f"Failed to migrate record {record.id}: {e}")
605 # Progress update
606 if progress_callback:
607 progress_callback(status)
609 finally:
610 status.end_time = datetime.utcnow()
612 return status

    @classmethod
    def from_config(
        cls,
        source_db: Database,
        target_db: Database | None,
        embedding_fn: Callable[[str], np.ndarray] | Callable[[str], Coroutine[Any, Any, np.ndarray]],
        config: MigrationConfig,
        text_fields: list[str] | None = None,
        vector_field: str = "embedding",
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> VectorMigration:
        """Create migration from a config object for advanced use cases.

        Args:
            source_db: Source database
            target_db: Target database (None for in-place)
            embedding_fn: Function to generate embeddings
            config: Migration configuration
            text_fields: Text field names (optional)
            vector_field: Name of the vector field
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            Configured VectorMigration instance
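
        Example (illustrative; ``embed`` is an assumed embedding callable):

            config = MigrationConfig(batch_size=50, enable_rollback=False)
            migration = VectorMigration.from_config(
                source_db, None, embed, config, text_fields=["content"]
            )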
640 """
641 return cls(
642 source_db=source_db,
643 target_db=target_db,
644 embedding_fn=embedding_fn,
645 text_fields=text_fields,
646 vector_field=vector_field,
647 batch_size=config.batch_size,
648 model_name=model_name,
649 model_version=model_version,
650 config=config,
651 )


class IncrementalVectorizer:
    """Manages incremental vectorization of large datasets.

    Examples:
        import numpy as np
        from dataknobs_data import database_factory

        # Create database and embedding function
        db = database_factory.create(backend="memory")

        def embedding_fn(text):
            # In practice, use a real model like sentence-transformers
            return np.random.rand(384).astype(np.float32)

        # Simple usage with single field
        vectorizer = IncrementalVectorizer(
            db,
            embedding_fn=embedding_fn,
            text_fields="content"  # Can be string or list
        )
        result = await vectorizer.run()

        # Resume from checkpoint
        result = await vectorizer.run_with_checkpoint(resume_from=last_checkpoint)

        # Process limited batch
        result = await vectorizer.run_batch(limit=1000)
    """

    def __init__(
        self,
        database: Database,
        embedding_fn: Callable[[str], np.ndarray] | Callable[[str], Coroutine[Any, Any, np.ndarray]],
        text_fields: list[str] | str | None = None,  # Support multiple fields
        vector_field: str = "embedding",  # Sensible default
        field_separator: str = " ",
        batch_size: int = 100,
        checkpoint_interval: int = 1000,
        max_workers: int = 4,
        model_name: str | None = None,
        model_version: str | None = None,
    ):
        """Initialize the incremental vectorizer with simplified parameters.

        Args:
            database: Database to vectorize
            embedding_fn: Function to generate embeddings
            text_fields: Text field names to concatenate for embeddings
            vector_field: Name of the vector field to create
            field_separator: Separator for concatenating multiple text fields
            batch_size: Size of processing batches
            checkpoint_interval: Records between checkpoints
            max_workers: Maximum concurrent workers
            model_name: Name of the embedding model
            model_version: Version of the embedding model
        """
        self.database = database
        self.embedding_fn = embedding_fn
        self.embedding_function = embedding_fn  # Alias for compatibility

        # Handle text fields
        if isinstance(text_fields, str):
            text_fields = [text_fields]
        elif text_fields is None:
            # Try to auto-detect from database schema
            text_fields = self._detect_text_fields()
        self.text_fields = text_fields

        self.vector_field = vector_field
        self.field_separator = field_separator
        self.batch_size = batch_size
        self.checkpoint_interval = checkpoint_interval
        self.max_workers = max_workers
        self.model_name = model_name
        self.model_version = model_version

        # Processing state
        self._queue: asyncio.Queue[Record] = asyncio.Queue()
        self._processing_task: asyncio.Task | None = None
        self._workers: list[asyncio.Task] = []
        self._shutdown_event = asyncio.Event()
        self._stats = {
            "processed": 0,
            "failed": 0,
            "queued": 0,
        }
        self._last_checkpoint: str | None = None
        self._progress: VectorizationProgress | None = None

    def _detect_text_fields(self) -> list[str]:
        """Auto-detect text fields from database schema."""
        text_fields = []
        if hasattr(self.database, 'schema') and self.database.schema:
            for field_name, field_schema in self.database.schema.fields.items():
                if field_schema.type in (FieldType.STRING, FieldType.TEXT):
                    text_fields.append(field_name)

        # Default to common field names if no schema
        if not text_fields:
            text_fields = ["content", "text", "description"]

        return text_fields

    async def _worker(self, worker_id: int) -> None:
        """Worker task for processing records."""
        logger.info(f"Worker {worker_id} started")

        while not self._shutdown_event.is_set():
            try:
                # Get record from queue with timeout
                try:
                    record = await asyncio.wait_for(
                        self._queue.get(),
                        timeout=1.0
                    )
                except asyncio.TimeoutError:
                    continue

                # Process record
                await self._process_record(record)
                self._stats["processed"] += 1

            except Exception as e:
                logger.error(f"Worker {worker_id} error: {e}")
                self._stats["failed"] += 1

        logger.info(f"Worker {worker_id} stopped")

    async def _process_record(self, record: Record) -> None:
        """Process a single record to add vectors."""
        try:
            # Get source text from multiple fields
            text_parts = []
            for field in self.text_fields:
                value = record.get_value(field)
                if value:
                    text_parts.append(str(value))

            if not text_parts:
                return

            source_text = self.field_separator.join(text_parts)

            # Check if vector already exists
            vector_data = record.get_value(self.vector_field)
            if vector_data is not None:
                if vector_data and isinstance(vector_data, (list, np.ndarray)):
                    return

            # Generate embedding
            if asyncio.iscoroutinefunction(self.embedding_fn):
                embedding = await self.embedding_fn(str(source_text))
            else:
                embedding = await asyncio.to_thread(self.embedding_fn, str(source_text))

            if embedding is None:
                return

            # Update record
            update_data = {
                self.vector_field: embedding.tolist() if isinstance(embedding, np.ndarray) else embedding,
            }

            # Add metadata
            if self.model_name:
                metadata = VectorMetadata(
                    dimensions=len(embedding),
                    source_field=self.field_separator.join(self.text_fields),
                    model_name=self.model_name,
                    model_version=self.model_version,
                    updated_at=datetime.utcnow().isoformat(),
                )
                update_data[f"{self.vector_field}_metadata"] = metadata.to_dict()

            # Update the record with the new vector data
            for key, value in update_data.items():
                record.set_value(key, value)
            await self.database.update(record.id, record)

        except Exception as e:
            logger.error(f"Failed to process record {record.id}: {e}")
            raise

    async def start(self) -> None:
        """Start incremental vectorization."""
        if self._processing_task and not self._processing_task.done():
            logger.warning("Vectorization already running")
            return

        self._shutdown_event.clear()

        # Start workers
        self._workers = [
            asyncio.create_task(self._worker(i))
            for i in range(self.max_workers)
        ]

        # Start queue loader
        self._processing_task = asyncio.create_task(self._load_queue())

        logger.info(f"Started incremental vectorization with {self.max_workers} workers")

    async def _load_queue(self) -> None:
        """Load records into processing queue."""
        while not self._shutdown_event.is_set():
            try:
                # Get records without vectors that have at least one text field
                filter_query = {
                    self.vector_field: {"$exists": False},
                    "$or": [
                        {field: {"$exists": True, "$ne": ""}}
                        for field in self.text_fields
                    ],
                }

                records = await self.database.filter(filter_query, limit=self.batch_size)

                if not records:
                    # No more records to process
                    await asyncio.sleep(60)  # Check again in a minute
                    continue

                # Add to queue
                for record in records:
                    await self._queue.put(record)
                    self._stats["queued"] += 1

            except Exception as e:
                logger.error(f"Failed to load queue: {e}")
                await asyncio.sleep(10)

    async def stop(self, timeout: float = 30.0) -> None:
        """Stop incremental vectorization.

        Args:
            timeout: Maximum time to wait for graceful shutdown
        """
        if not self._processing_task:
            return

        logger.info("Stopping incremental vectorization...")
        self._shutdown_event.set()

        # Cancel queue loader
        self._processing_task.cancel()
        try:
            await self._processing_task
        except asyncio.CancelledError:
            pass

        # Wait for workers to finish
        try:
            await asyncio.wait_for(
                asyncio.gather(*self._workers),
                timeout=timeout
            )
        except asyncio.TimeoutError:
            logger.warning("Workers did not stop gracefully, cancelling")
            for worker in self._workers:
                worker.cancel()

        await asyncio.gather(*self._workers, return_exceptions=True)

        self._workers.clear()
        self._processing_task = None

    async def run(
        self,
        progress_callback: Callable[[int, int, list], None] | None = None,
        max_workers: int | None = None,
    ) -> dict[str, Any]:
        """Run the complete vectorization.

        Args:
            progress_callback: Optional callback (completed, total, current_batch)
            max_workers: Override default max_workers

        Returns:
            Results dictionary
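
        Example (illustrative; the callback body is a placeholder):

            def on_progress(completed: int, total: int, current_batch: list) -> None:
                print(f"{completed}/{total} records vectorized")

            results = await vectorizer.run(progress_callback=on_progress)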
933 """
934 if max_workers:
935 self.max_workers = max_workers
937 # Get all records that need vectors
938 from ..query import Query
939 all_records = await self.database.search(Query())
941 to_process = []
942 for record in all_records:
943 # Check if needs vectorization
944 if self.vector_field not in record.fields:
945 # Check if has text to vectorize
946 has_text = False
947 for field in self.text_fields:
948 if record.get_value(field):
949 has_text = True
950 break
951 if has_text:
952 to_process.append(record)
954 total = len(to_process)
955 processed = 0
956 failed = 0
958 # Process in batches
959 for i in range(0, total, self.batch_size):
960 batch = to_process[i:i + self.batch_size]
962 for record in batch:
963 try:
964 await self._process_record(record)
965 processed += 1
966 except Exception as e:
967 logger.error(f"Failed to process record {record.id}: {e}")
968 failed += 1
970 if progress_callback:
971 if asyncio.iscoroutinefunction(progress_callback):
972 await progress_callback(processed, total, batch)
973 else:
974 progress_callback(processed, total, batch)
976 return {
977 "processed": processed,
978 "failed": failed,
979 "total": total,
980 }

    async def get_status(self) -> dict[str, Any]:
        """Get current vectorization status.

        Returns:
            Status dictionary
        """
        # Count records with and without vectors
        from ..query import Query
        all_records = await self.database.search(Query())

        total = 0
        completed = 0

        for record in all_records:
            # Check if has text fields
            has_text = False
            for field_name in self.text_fields:
                if record.get_value(field_name):
                    has_text = True
                    break

            if has_text:
                total += 1
                if self.vector_field in record.fields:
                    completed += 1

        return {
            "total": total,
            "completed": completed,
            "remaining": total - completed,
            "percentage": (completed / total * 100) if total > 0 else 0,
        }

    def get_stats(self) -> dict[str, Any]:
        """Get vectorization statistics.

        Returns:
            Dictionary of statistics
        """
        return {
            **self._stats,
            "queue_size": self._queue.qsize(),
            "workers": len(self._workers),
            "is_running": bool(
                self._processing_task and not self._processing_task.done()
            ),
        }

    async def wait_for_completion(self, check_interval: float = 5.0) -> None:
        """Wait for all queued records to be processed.

        Args:
            check_interval: Seconds between queue checks
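
        Example (illustrative background-mode flow using this class's
        start/stop API):

            await vectorizer.start()
            await vectorizer.wait_for_completion(check_interval=2.0)
            await vectorizer.stop()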
1035 """
1036 while self._queue.qsize() > 0:
1037 await asyncio.sleep(check_interval)
1039 logger.info("All queued records processed")

    async def run_with_checkpoint(self, resume_from: str | None = None) -> VectorizationResult:
        """Run the complete vectorization with checkpoint support.

        Args:
            resume_from: Optional checkpoint ID to resume from

        Returns:
            Vectorization result with statistics
        """
        await self.start()
        await self.wait_for_completion()

        return VectorizationResult(
            processed=self._stats["processed"],
            failed=self._stats["failed"],
            checkpoint=self._last_checkpoint,
        )

    async def run_batch(self, limit: int | None = None) -> VectorizationResult:
        """Process a limited number of records.

        Args:
            limit: Maximum number of records to process

        Returns:
            Vectorization result with statistics
        """
        # Temporarily modify batch size if limit provided
        original_batch_size = self.batch_size
        if limit:
            self.batch_size = min(self.batch_size, limit)

        try:
            await self.start()

            # Wait for limited processing
            while self._stats["processed"] < (limit or float('inf')):
                if self._queue.empty() and self._processing_task.done():
                    break
                await asyncio.sleep(0.1)

            await self.stop()

            return VectorizationResult(
                processed=self._stats["processed"],
                failed=self._stats["failed"],
                checkpoint=self._last_checkpoint,
            )
        finally:
            self.batch_size = original_batch_size

    @property
    def progress(self) -> VectorizationProgress:
        """Get current progress."""
        return VectorizationProgress(
            total_records=self._stats.get("total", 0),
            processed_records=self._stats["processed"],
            failed_records=self._stats["failed"],
            queued_records=self._queue.qsize(),
            checkpoint=self._last_checkpoint,
        )

    async def get_checkpoint(self) -> str:
        """Get checkpoint ID for resuming."""
        # Save current progress as checkpoint
        self._last_checkpoint = f"checkpoint_{self._stats['processed']}"
        return self._last_checkpoint


@dataclass
class VectorizationResult:
    """Result of a vectorization operation."""
    processed: int
    failed: int
    checkpoint: str | None = None


@dataclass
class VectorizationProgress:
    """Current progress of vectorization."""
    total_records: int
    processed_records: int
    failed_records: int
    queued_records: int
    checkpoint: str | None = None