Coverage for src/dataknobs_data/vector/migration.py: 18%

497 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-29 14:14 -0600

1"""Migration tools for adding vector support to existing data.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import logging 

7from dataclasses import dataclass, field 

8from datetime import datetime 

9from typing import TYPE_CHECKING, Any 

10 

11import numpy as np 

12 

13from ..fields import FieldType 

14from ..query import Query 

15from ..records import Record 

16from ..schema import FieldSchema 

17from .sync import SyncConfig, VectorTextSynchronizer 

18from .types import VectorMetadata 

19 

20if TYPE_CHECKING: 

21 from collections.abc import Callable, Coroutine 

22 

23 from ..database import Database 

24 

25logger = logging.getLogger(__name__) 

26 

27 

@dataclass
class MigrationConfig:
    """Configuration for vector migration.

    Controls batch sizing, worker concurrency, checkpointing, rollback,
    verification, and retry behavior.
    """

    batch_size: int = 100
    max_workers: int = 4
    checkpoint_interval: int = 1000
    enable_rollback: bool = True
    verify_migration: bool = True
    retry_failed: bool = True
    max_retries: int = 3
    max_consecutive_failures: int = 5  # Fail fast after this many consecutive failures

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            ValueError: If ``batch_size``, ``max_workers`` or
                ``checkpoint_interval`` is non-positive, or ``max_retries``
                is negative.
        """
        if self.batch_size <= 0:
            raise ValueError(f"Batch size must be positive, got {self.batch_size}")
        if self.max_workers <= 0:
            raise ValueError(f"Max workers must be positive, got {self.max_workers}")
        # checkpoint_interval is used as a modulus during migration; a value
        # of zero would raise ZeroDivisionError at checkpoint time.
        if self.checkpoint_interval <= 0:
            raise ValueError(
                f"Checkpoint interval must be positive, got {self.checkpoint_interval}"
            )
        # Zero retries is legitimate ("no retries"); negative is not.
        if self.max_retries < 0:
            raise ValueError(f"Max retries must be non-negative, got {self.max_retries}")

47 

48 

@dataclass
class MigrationStatus:
    """Status of a migration operation.

    Accumulates per-record counters, error details, and checkpoints while a
    migration runs; exposes derived metrics as read-only properties.
    """

    total_records: int = 0
    migrated_records: int = 0
    verified_records: int = 0
    failed_records: int = 0
    rollback_records: int = 0
    errors: list[dict[str, Any]] = field(default_factory=list)
    checkpoints: list[dict[str, Any]] = field(default_factory=list)
    start_time: datetime | None = None
    end_time: datetime | None = None

    @property
    def total_processed(self) -> int:
        """Total number of processed records (migrated + failed)."""
        return self.failed_records + self.migrated_records

    @property
    def failed_count(self) -> int:
        """Alias for failed_records for compatibility."""
        return self.failed_records

    @property
    def success_rate(self) -> float:
        """Fraction of records successfully migrated (0.0 when empty)."""
        if not self.total_records:
            return 0.0
        return self.migrated_records / self.total_records

    @property
    def duration(self) -> float | None:
        """Elapsed migration time in seconds, or None if not yet finished."""
        if self.start_time is None or self.end_time is None:
            return None
        return (self.end_time - self.start_time).total_seconds()

    @property
    def records_per_second(self) -> float:
        """Migration throughput; 0.0 when the duration is unknown or zero."""
        elapsed = self.duration
        if elapsed and elapsed > 0:
            return self.migrated_records / elapsed
        return 0.0

    def add_checkpoint(self, name: str, record_id: str | None = None) -> None:
        """Record a named checkpoint with the current progress counters."""
        entry = {
            "name": name,
            "record_id": record_id,
            "timestamp": datetime.utcnow().isoformat(),
            "migrated": self.migrated_records,
            "failed": self.failed_records,
        }
        self.checkpoints.append(entry)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary representation."""
        started = self.start_time.isoformat() if self.start_time else None
        ended = self.end_time.isoformat() if self.end_time else None
        return {
            "total_records": self.total_records,
            "migrated_records": self.migrated_records,
            "verified_records": self.verified_records,
            "failed_records": self.failed_records,
            "rollback_records": self.rollback_records,
            "success_rate": self.success_rate,
            "duration": self.duration,
            "records_per_second": self.records_per_second,
            "errors": self.errors,
            "checkpoints": self.checkpoints,
            "start_time": started,
            "end_time": ended,
        }

121 

122 

123class VectorMigration: 

124 """Manages migration of existing data to include vector embeddings.""" 

125 

    def __init__(
        self,
        source_db: Database,
        target_db: Database | None = None,
        embedding_fn: Callable[[str], np.ndarray] | Callable[[str], Coroutine[Any, Any, np.ndarray]] | None = None,
        text_fields: list[str] | None = None,
        vector_field: str = "embedding",
        field_separator: str = " ",
        batch_size: int = 100,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        model_name: str | None = None,
        model_version: str | None = None,
        config: MigrationConfig | None = None,
    ):
        """Initialize the migration manager with simplified API.

        Args:
            source_db: Source database to migrate from
            target_db: Target database (None to migrate in-place)
            embedding_fn: Function to generate embeddings (sync or async)
            text_fields: Fields to concatenate for embedding
            vector_field: Name of the vector field to create
            field_separator: Separator for concatenating text fields
            batch_size: Batch size for processing
            max_retries: Maximum retry attempts
            retry_delay: Delay between retries
            model_name: Name of the embedding model
            model_version: Version of the embedding model
            config: Advanced configuration (overrides other params)
        """
        self.source_db = source_db
        # In-place migration: default the target to the source database.
        self.target_db = target_db or source_db
        self.embedding_fn = embedding_fn
        self.embedding_function = embedding_fn  # Alias for compatibility
        self.text_fields = text_fields or []
        self.vector_field = vector_field
        self.field_separator = field_separator
        self.batch_size = batch_size
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.model_name = model_name
        self.model_version = model_version

        # Use config if provided, otherwise create from params
        if config:
            self.config = config
        else:
            self.config = MigrationConfig(
                batch_size=batch_size,
                max_retries=max_retries,
            )
        # Fail early on invalid configuration (raises ValueError).
        self.config.validate()

        # Migration status
        self.status = MigrationStatus()

        # Track rollback data if enabled: record id -> original field values.
        self._rollback_data: dict[str, dict[str, Any]] = {}

185 

186 async def run( 

187 self, 

188 progress_callback: Callable[[MigrationStatus], None] | None = None 

189 ) -> MigrationStatus: 

190 """Run the complete migration. 

191  

192 Args: 

193 progress_callback: Optional callback for progress updates 

194  

195 Returns: 

196 Migration status 

197 """ 

198 self.status = MigrationStatus(start_time=datetime.utcnow()) 

199 

200 try: 

201 # Get all records from source 

202 all_records = await self.source_db.search(Query()) 

203 self.status.total_records = len(all_records) 

204 

205 # Process in batches 

206 for i in range(0, len(all_records), self.batch_size): 

207 batch = all_records[i:i + self.batch_size] 

208 

209 for record in batch: 

210 try: 

211 # Concatenate text fields 

212 text_parts = [] 

213 for field in self.text_fields: 

214 value = record.get_value(field) 

215 if value: 

216 text_parts.append(str(value)) 

217 

218 if text_parts: 

219 text = self.field_separator.join(text_parts) 

220 

221 # Generate embedding 

222 if asyncio.iscoroutinefunction(self.embedding_fn): 

223 embedding = await self.embedding_fn(text) 

224 else: 

225 embedding = await asyncio.to_thread(self.embedding_fn, text) 

226 

227 # Create VectorField 

228 from ..fields import VectorField 

229 vector_field_obj = VectorField( 

230 value=embedding, 

231 name=self.vector_field, 

232 source_field=self.text_fields[0] if len(self.text_fields) == 1 else None, 

233 model_name=self.model_name, 

234 model_version=self.model_version, 

235 ) 

236 

237 # Add to record 

238 record.fields[self.vector_field] = vector_field_obj 

239 

240 # Create in target database 

241 await self.target_db.create(record) 

242 self.status.migrated_records += 1 

243 

244 except Exception as e: 

245 logger.error(f"Failed to migrate record {record.id}: {e}") 

246 self.status.failed_records += 1 

247 self.status.errors.append({"record_id": record.id, "error": str(e)}) 

248 

249 if progress_callback: 

250 progress_callback(self.status) 

251 

252 self.status.end_time = datetime.utcnow() 

253 return self.status 

254 

255 except Exception as e: 

256 logger.error(f"Migration failed: {e}") 

257 self.status.failed_records = self.status.total_records - self.status.migrated_records 

258 self.status.end_time = datetime.utcnow() 

259 return self.status 

260 

261 async def start(self) -> None: 

262 """Start migration (for compatibility).""" 

263 # Migration runs synchronously in run() method 

264 pass 

265 

266 async def wait_for_completion(self, progress_callback: Callable[[MigrationStatus], None] | None = None) -> MigrationStatus: 

267 """Wait for migration completion (for compatibility).""" 

268 # Since run() is synchronous, just return current status 

269 return self.status 

270 

    async def add_vectors_to_existing(
        self,
        vector_fields: dict[str, str],  # vector_field -> source_field mapping
        filter_query: dict[str, Any] | None = None,
        progress_callback: Callable[[MigrationStatus], None] | None = None,
    ) -> MigrationStatus:
        """Add vector fields to existing records.

        Args:
            vector_fields: Mapping of vector field names to source text fields
            filter_query: Optional filter to select records to migrate
            progress_callback: Callback for progress updates

        Returns:
            Migration status

        Raises:
            ValueError: If no embedding function was configured.
            Exception: Re-raised migration/embedding failure after rollback
                when consecutive whole-batch failures trigger fail-fast.
        """
        if not self.embedding_fn:
            raise ValueError("Embedding function required for adding vectors")

        status = MigrationStatus(start_time=datetime.utcnow())

        try:
            # Get records to migrate
            if filter_query:
                # Convert filter_query dict to Query object (equality filters only)
                query = Query()
                for field, value in filter_query.items():
                    query = query.filter(field, "==", value)
                records = await self.source_db.search(query)
            else:
                records = await self.source_db.all()

            status.total_records = len(records)
            logger.info(f"Starting migration of {status.total_records} records")

            # Create synchronizer with wrapped embedding function
            sync_config = SyncConfig(
                batch_size=self.config.batch_size,
                max_retries=self.config.max_retries,
            )

            # Track the last embedding exception so fail-fast can re-raise it
            last_embedding_exception = None

            # Create wrapper that captures exceptions; sync embedding functions
            # are run in a worker thread to avoid blocking the event loop.
            async def embedding_wrapper(text: str) -> np.ndarray:
                nonlocal last_embedding_exception
                try:
                    if asyncio.iscoroutinefunction(self.embedding_fn):
                        result = await self.embedding_fn(text)
                    else:
                        result = await asyncio.to_thread(self.embedding_fn, text)
                    return result
                except Exception as e:
                    last_embedding_exception = e
                    raise

            synchronizer = VectorTextSynchronizer(
                database=self.target_db,
                embedding_fn=embedding_wrapper,
                config=sync_config,
                model_name=self.model_name,
                model_version=self.model_version,
            )

            # Process in batches
            consecutive_batch_failures = 0
            for i in range(0, len(records), self.config.batch_size):
                batch = records[i:i + self.config.batch_size]

                # Process batch with workers
                tasks = []
                for record in batch:
                    # Store original data for rollback
                    if self.config.enable_rollback:
                        # Store original field values for rollback
                        self._rollback_data[record.id] = {
                            field_name: record.get_value(field_name)
                            for field_name in record.fields.keys()
                        }

                    # Add vector fields to record
                    for vector_field, source_field in vector_fields.items():
                        if record.get_value(source_field) is None:
                            continue

                        # Add vector field schema if needed (first record with
                        # text pays one extra embedding call to learn dimensions)
                        if vector_field not in self.target_db.schema.fields:
                            source_text = record.get_value(source_field)
                            if source_text:
                                # Get dimensions from first embedding
                                sample_embedding = await self._get_embedding(str(source_text))
                                if sample_embedding is not None:
                                    dimensions = len(sample_embedding)
                                    # Add schema for vector field
                                    field_schema = FieldSchema(
                                        name=vector_field,
                                        type=FieldType.VECTOR,
                                        metadata={
                                            "dimensions": dimensions,
                                            "source_field": source_field,
                                        }
                                    )
                                    self.target_db.add_field_schema(field_schema)

                    # Create migration task (coroutine; awaited via gather below)
                    task = self._migrate_record(
                        synchronizer,
                        record,
                        vector_fields,
                        status,
                    )
                    tasks.append(task)

                # Wait for batch to complete; _migrate_record catches its own
                # exceptions and returns a bool, so no exceptions surface here.
                results = await asyncio.gather(*tasks, return_exceptions=False)

                # Check for batch failures and fail fast if needed
                batch_failed_count = sum(1 for r in results if r is False)
                if batch_failed_count == len(results) and len(results) > 0:
                    consecutive_batch_failures += 1
                    # If multiple consecutive batches completely fail, re-raise the last exception
                    if consecutive_batch_failures >= 2 and self.config.enable_rollback:
                        if last_embedding_exception:
                            raise last_embedding_exception
                        else:
                            raise Exception("Migration failed: consecutive batch failures")
                else:
                    consecutive_batch_failures = 0

                # Checkpoint if needed
                if status.migrated_records % self.config.checkpoint_interval == 0:
                    status.add_checkpoint(
                        f"Batch {i // self.config.batch_size + 1}",
                        batch[-1].id if batch else None,
                    )
                if progress_callback:
                    progress_callback(status)

            # Verify migration if enabled
            if self.config.verify_migration:
                await self._verify_migration(vector_fields, records, status)

        except Exception as e:
            logger.error(f"Migration failed: {e}")
            if self.config.enable_rollback:
                await self._rollback(status)
            raise

        finally:
            # Always stamp the end time, even on the raising path.
            status.end_time = datetime.utcnow()

        logger.info(
            f"Migration completed: {status.migrated_records}/{status.total_records} "
            f"migrated, {status.failed_records} failed"
        )

        return status

429 

430 async def _get_embedding(self, text: str) -> np.ndarray | None: 

431 """Get embedding for text.""" 

432 try: 

433 if asyncio.iscoroutinefunction(self.embedding_fn): 

434 result = await self.embedding_fn(text) 

435 else: 

436 result = await asyncio.to_thread(self.embedding_fn, text) 

437 

438 if isinstance(result, np.ndarray): 

439 return result 

440 elif isinstance(result, list): 

441 return np.array(result) 

442 return None 

443 

444 except Exception as e: 

445 logger.error(f"Failed to get embedding: {e}") 

446 return None 

447 

448 async def _migrate_record( 

449 self, 

450 synchronizer: VectorTextSynchronizer, 

451 record: Record, 

452 vector_fields: dict[str, str], 

453 status: MigrationStatus, 

454 ) -> bool: 

455 """Migrate a single record. 

456  

457 Returns: 

458 True if migration succeeded, False otherwise 

459 """ 

460 try: 

461 # Sync vectors 

462 success, updated_fields = await synchronizer.sync_record(record, force=True) 

463 

464 if success and updated_fields: 

465 # Update record in target database 

466 await self.target_db.update(record.id, record) 

467 status.migrated_records += 1 

468 return True 

469 else: 

470 status.failed_records += 1 

471 status.errors.append({ 

472 "record_id": record.id, 

473 "error": "Failed to generate vectors", 

474 }) 

475 return False 

476 

477 except Exception as e: 

478 status.failed_records += 1 

479 status.errors.append({ 

480 "record_id": record.id, 

481 "error": str(e), 

482 }) 

483 logger.error(f"Failed to migrate record {record.id}: {e}") 

484 return False 

485 

486 async def _verify_migration( 

487 self, 

488 vector_fields: dict[str, str], 

489 records: list[Record], 

490 status: MigrationStatus, 

491 ) -> None: 

492 """Verify that migration was successful.""" 

493 logger.info("Verifying migration...") 

494 

495 for record in records: 

496 try: 

497 # Get updated record 

498 migrated = await self.target_db.read(record.id) 

499 

500 # Check vector fields 

501 all_present = True 

502 for vector_field, source_field in vector_fields.items(): 

503 source_value = record.get_value(source_field) 

504 if source_value: 

505 # Check if vector field exists (could be in fields or data) 

506 vector_data = migrated.get_value(vector_field) 

507 if vector_data is None: 

508 all_present = False 

509 break 

510 

511 # For VectorField objects, check the value 

512 from ..fields import VectorField 

513 if isinstance(migrated.fields.get(vector_field), VectorField): 

514 vector_data = migrated.fields[vector_field].value 

515 

516 if not isinstance(vector_data, (list, np.ndarray)): 

517 all_present = False 

518 break 

519 

520 if all_present: 

521 status.verified_records += 1 

522 

523 except Exception as e: 

524 logger.error(f"Failed to verify record {record.id}: {e}") 

525 

526 async def _rollback(self, status: MigrationStatus) -> None: 

527 """Rollback migration on failure.""" 

528 if not self._rollback_data: 

529 return 

530 

531 logger.info(f"Rolling back {len(self._rollback_data)} records...") 

532 

533 for record_id, original_data in self._rollback_data.items(): 

534 try: 

535 # Restore original record 

536 original_record = Record(id=record_id, data=original_data) 

537 await self.target_db.update(record_id, original_record) 

538 status.rollback_records += 1 

539 except Exception as e: 

540 logger.error(f"Failed to rollback record {record_id}: {e}") 

541 

    async def migrate_between_backends(
        self,
        field_mapping: dict[str, str] | None = None,
        transform_fn: Callable[[Record], Record] | None = None,
        progress_callback: Callable[[MigrationStatus], None] | None = None,
    ) -> MigrationStatus:
        """Migrate vector data between different backends.

        Copies every record from the source database to the target database,
        optionally remapping field names and applying a transform first.

        Args:
            field_mapping: Optional field name mapping (old name -> new name).
                Mapped values are copied onto the new names; the old fields
                are not removed from the record.
            transform_fn: Optional record transformation function
            progress_callback: Callback for progress updates

        Returns:
            Migration status
        """
        status = MigrationStatus(start_time=datetime.utcnow())

        try:
            # Get all records with vectors
            records = await self.source_db.all()
            status.total_records = len(records)

            logger.info(
                f"Migrating {status.total_records} records from "
                f"{self.source_db.__class__.__name__} to "
                f"{self.target_db.__class__.__name__}"
            )

            # Process in batches
            for i in range(0, len(records), self.config.batch_size):
                batch = records[i:i + self.config.batch_size]

                for original_record in batch:
                    try:
                        record = original_record
                        # Apply field mapping
                        if field_mapping:
                            new_data = {}
                            for old_field, new_field in field_mapping.items():
                                old_value = record.get_value(old_field)
                                if old_value is not None:
                                    new_data[new_field] = old_value
                            # Update record with new field mapping (in place on
                            # the source record object)
                            for field_name, value in new_data.items():
                                record.set_value(field_name, value)

                        # Apply transformation
                        if transform_fn:
                            record = transform_fn(record)

                        # Create in target database
                        await self.target_db.create(record)
                        status.migrated_records += 1

                    except Exception as e:
                        status.failed_records += 1
                        status.errors.append({
                            "record_id": record.id,
                            "error": str(e),
                        })
                        logger.error(f"Failed to migrate record {record.id}: {e}")

                # Progress update
                if progress_callback:
                    progress_callback(status)

        finally:
            # Always stamp the end time, even if the initial fetch raised.
            status.end_time = datetime.utcnow()

        return status

613 

614 @classmethod 

615 def from_config( 

616 cls, 

617 source_db: Database, 

618 target_db: Database | None, 

619 embedding_fn: Callable[[str], np.ndarray] | Callable[[str], Coroutine[Any, Any, np.ndarray]], 

620 config: MigrationConfig, 

621 text_fields: list[str] | None = None, 

622 vector_field: str = "embedding", 

623 model_name: str | None = None, 

624 model_version: str | None = None, 

625 ) -> VectorMigration: 

626 """Create migration from a config object for advanced use cases. 

627  

628 Args: 

629 source_db: Source database 

630 target_db: Target database (None for in-place) 

631 embedding_fn: Function to generate embeddings 

632 config: Migration configuration 

633 text_fields: Text field names (optional) 

634 vector_field: Name of the vector field 

635 model_name: Name of the embedding model 

636 model_version: Version of the embedding model 

637  

638 Returns: 

639 Configured VectorMigration instance 

640 """ 

641 return cls( 

642 source_db=source_db, 

643 target_db=target_db, 

644 embedding_fn=embedding_fn, 

645 text_fields=text_fields, 

646 vector_field=vector_field, 

647 batch_size=config.batch_size, 

648 model_name=model_name, 

649 model_version=model_version, 

650 config=config, 

651 ) 

652 

653 

654class IncrementalVectorizer: 

655 """Manages incremental vectorization of large datasets. 

656  

657 Examples: 

658 # Simple usage with single field 

659 vectorizer = IncrementalVectorizer( 

660 db, 

661 embedding_fn=model.encode, 

662 text_fields="content" # Can be string or list 

663 ) 

664 result = await vectorizer.run() 

665  

666 # Resume from checkpoint 

667 result = await vectorizer.run(resume_from=last_checkpoint) 

668  

669 # Process limited batch 

670 result = await vectorizer.run_batch(limit=1000) 

671 """ 

672 

    def __init__(
        self,
        database: Database,
        embedding_fn: Callable[[str], np.ndarray] | Callable[[str], Coroutine[Any, Any, np.ndarray]],
        text_fields: list[str] | str | None = None,  # Support multiple fields
        vector_field: str = "embedding",  # Sensible default
        field_separator: str = " ",
        batch_size: int = 100,
        checkpoint_interval: int = 1000,
        max_workers: int = 4,
        model_name: str | None = None,
        model_version: str | None = None,
    ):
        """Initialize the incremental vectorizer with simplified parameters.

        Args:
            database: Database to vectorize
            embedding_fn: Function to generate embeddings (sync or async)
            text_fields: Text field names to concatenate for embeddings;
                a single string, a list, or None to auto-detect from schema
            vector_field: Name of the vector field to create
            field_separator: Separator for concatenating multiple text fields
            batch_size: Size of processing batches
            checkpoint_interval: Records between checkpoints
            max_workers: Maximum concurrent workers
            model_name: Name of the embedding model
            model_version: Version of the embedding model
        """
        self.database = database
        self.embedding_fn = embedding_fn
        self.embedding_function = embedding_fn  # Alias for compatibility

        # Handle text fields: normalize a single name to a list, or detect.
        if isinstance(text_fields, str):
            text_fields = [text_fields]
        elif text_fields is None:
            # Try to auto-detect from database schema
            text_fields = self._detect_text_fields()
        self.text_fields = text_fields

        self.vector_field = vector_field
        self.field_separator = field_separator
        self.batch_size = batch_size
        self.checkpoint_interval = checkpoint_interval
        self.max_workers = max_workers
        self.model_name = model_name
        self.model_version = model_version

        # Processing state.
        # NOTE(review): Queue/Event are created at construction time, outside
        # any running event loop — fine on Python 3.10+; confirm minimum
        # supported Python version.
        self._queue: asyncio.Queue[Record] = asyncio.Queue()
        self._processing_task: asyncio.Task | None = None
        self._workers: list[asyncio.Task] = []
        self._shutdown_event = asyncio.Event()
        # Running counters shared by workers and the queue loader.
        self._stats = {
            "processed": 0,
            "failed": 0,
            "queued": 0,
        }
        self._last_checkpoint: str | None = None
        self._progress: VectorizationProgress | None = None

732 

733 def _detect_text_fields(self) -> list[str]: 

734 """Auto-detect text fields from database schema.""" 

735 text_fields = [] 

736 if hasattr(self.database, 'schema') and self.database.schema: 

737 for field_name, field_schema in self.database.schema.fields.items(): 

738 if field_schema.type in (FieldType.STRING, FieldType.TEXT): 

739 text_fields.append(field_name) 

740 

741 # Default to common field names if no schema 

742 if not text_fields: 

743 text_fields = ["content", "text", "description"] 

744 

745 return text_fields 

746 

747 async def _worker(self, worker_id: int) -> None: 

748 """Worker task for processing records.""" 

749 logger.info(f"Worker {worker_id} started") 

750 

751 while not self._shutdown_event.is_set(): 

752 try: 

753 # Get record from queue with timeout 

754 try: 

755 record = await asyncio.wait_for( 

756 self._queue.get(), 

757 timeout=1.0 

758 ) 

759 except asyncio.TimeoutError: 

760 continue 

761 

762 # Process record 

763 await self._process_record(record) 

764 self._stats["processed"] += 1 

765 

766 except Exception as e: 

767 logger.error(f"Worker {worker_id} error: {e}") 

768 self._stats["failed"] += 1 

769 

770 logger.info(f"Worker {worker_id} stopped") 

771 

    async def _process_record(self, record: Record) -> None:
        """Process a single record to add vectors.

        Concatenates the configured text fields, skips records that already
        hold a list/array vector, generates the embedding, optionally attaches
        model metadata, and writes the updated record back to the database.

        Raises:
            Exception: Re-raises any failure after logging it, so callers can
                count the record as failed.
        """
        try:
            # Get source text from multiple fields
            text_parts = []
            for field in self.text_fields:
                value = record.get_value(field)
                if value:
                    text_parts.append(str(value))

            if not text_parts:
                # Nothing to embed for this record.
                return

            source_text = self.field_separator.join(text_parts)

            # Check if vector already exists; only a non-empty list/ndarray
            # counts as "already vectorized".
            vector_data = record.get_value(self.vector_field)
            if vector_data is not None:
                if vector_data and isinstance(vector_data, (list, np.ndarray)):
                    return

            # Generate embedding; sync callables run in a worker thread.
            if asyncio.iscoroutinefunction(self.embedding_fn):
                embedding = await self.embedding_fn(str(source_text))
            else:
                embedding = await asyncio.to_thread(self.embedding_fn, str(source_text))

            if embedding is None:
                return

            # Update record (ndarrays are stored as plain lists)
            update_data = {
                self.vector_field: embedding.tolist() if isinstance(embedding, np.ndarray) else embedding,
            }

            # Add metadata only when a model name is configured
            if self.model_name:
                metadata = VectorMetadata(
                    dimensions=len(embedding),
                    source_field=self.field_separator.join(self.text_fields),
                    model_name=self.model_name,
                    model_version=self.model_version,
                    updated_at=datetime.utcnow().isoformat(),
                )
                update_data[f"{self.vector_field}_metadata"] = metadata.to_dict()

            # Update the record with the new vector data
            for key, value in update_data.items():
                record.set_value(key, value)
            await self.database.update(record.id, record)

        except Exception as e:
            logger.error(f"Failed to process record {record.id}: {e}")
            raise

826 

827 async def start(self) -> None: 

828 """Start incremental vectorization.""" 

829 if self._processing_task and not self._processing_task.done(): 

830 logger.warning("Vectorization already running") 

831 return 

832 

833 self._shutdown_event.clear() 

834 

835 # Start workers 

836 self._workers = [ 

837 asyncio.create_task(self._worker(i)) 

838 for i in range(self.max_workers) 

839 ] 

840 

841 # Start queue loader 

842 self._processing_task = asyncio.create_task(self._load_queue()) 

843 

844 logger.info(f"Started incremental vectorization with {self.max_workers} workers") 

845 

    async def _load_queue(self) -> None:
        """Continuously load unvectorized records into the processing queue.

        Runs until the shutdown event is set, polling the database for
        records that lack the vector field but have at least one non-empty
        text field. Sleeps between polls when idle and backs off on errors.
        """
        while not self._shutdown_event.is_set():
            try:
                # Get records without vectors that have at least one text field.
                # NOTE(review): Mongo-style operators ($exists / $or / $ne) —
                # assumes the backend's filter() understands this syntax;
                # confirm per backend.
                filter_query = {
                    self.vector_field: {"$exists": False},
                    "$or": [
                        {field: {"$exists": True, "$ne": ""}}
                        for field in self.text_fields
                    ],
                }

                records = await self.database.filter(filter_query, limit=self.batch_size)

                if not records:
                    # No more records to process
                    await asyncio.sleep(60)  # Check again in a minute
                    continue

                # Add to queue
                for record in records:
                    await self._queue.put(record)
                    self._stats["queued"] += 1

            except Exception as e:
                logger.error(f"Failed to load queue: {e}")
                await asyncio.sleep(10)  # Back off before retrying

874 

875 async def stop(self, timeout: float = 30.0) -> None: 

876 """Stop incremental vectorization. 

877  

878 Args: 

879 timeout: Maximum time to wait for graceful shutdown 

880 """ 

881 if not self._processing_task: 

882 return 

883 

884 logger.info("Stopping incremental vectorization...") 

885 self._shutdown_event.set() 

886 

887 # Cancel queue loader 

888 self._processing_task.cancel() 

889 try: 

890 await self._processing_task 

891 except asyncio.CancelledError: 

892 pass 

893 

894 # Wait for workers to finish 

895 try: 

896 await asyncio.wait_for( 

897 asyncio.gather(*self._workers), 

898 timeout=timeout 

899 ) 

900 except asyncio.TimeoutError: 

901 logger.warning("Workers did not stop gracefully, cancelling") 

902 for worker in self._workers: 

903 worker.cancel() 

904 

905 await asyncio.gather(*self._workers, return_exceptions=True) 

906 

907 self._workers.clear() 

908 self._processing_task = None 

909 

910 async def run( 

911 self, 

912 progress_callback: Callable[[int, int, list], None] | None = None, 

913 max_workers: int | None = None, 

914 ) -> dict[str, Any]: 

915 """Run the complete vectorization. 

916  

917 Args: 

918 progress_callback: Optional callback (completed, total, current_batch) 

919 max_workers: Override default max_workers 

920  

921 Returns: 

922 Results dictionary 

923 """ 

924 if max_workers: 

925 self.max_workers = max_workers 

926 

927 # Get all records that need vectors 

928 from ..query import Query 

929 all_records = await self.database.search(Query()) 

930 

931 to_process = [] 

932 for record in all_records: 

933 # Check if needs vectorization 

934 if self.vector_field not in record.fields: 

935 # Check if has text to vectorize 

936 has_text = False 

937 for field in self.text_fields: 

938 if record.get_value(field): 

939 has_text = True 

940 break 

941 if has_text: 

942 to_process.append(record) 

943 

944 total = len(to_process) 

945 processed = 0 

946 failed = 0 

947 

948 # Process in batches 

949 for i in range(0, total, self.batch_size): 

950 batch = to_process[i:i + self.batch_size] 

951 

952 for record in batch: 

953 try: 

954 await self._process_record(record) 

955 processed += 1 

956 except Exception as e: 

957 logger.error(f"Failed to process record {record.id}: {e}") 

958 failed += 1 

959 

960 if progress_callback: 

961 if asyncio.iscoroutinefunction(progress_callback): 

962 await progress_callback(processed, total, batch) 

963 else: 

964 progress_callback(processed, total, batch) 

965 

966 return { 

967 "processed": processed, 

968 "failed": failed, 

969 "total": total, 

970 } 

971 

972 async def get_status(self) -> dict[str, Any]: 

973 """Get current vectorization status. 

974  

975 Returns: 

976 Status dictionary 

977 """ 

978 # Count records with and without vectors 

979 from ..query import Query 

980 all_records = await self.database.search(Query()) 

981 

982 total = 0 

983 completed = 0 

984 

985 for record in all_records: 

986 # Check if has text fields 

987 has_text = False 

988 for field_name in self.text_fields: 

989 if record.get_value(field_name): 

990 has_text = True 

991 break 

992 

993 if has_text: 

994 total += 1 

995 if self.vector_field in record.fields: 

996 completed += 1 

997 

998 return { 

999 "total": total, 

1000 "completed": completed, 

1001 "remaining": total - completed, 

1002 "percentage": (completed / total * 100) if total > 0 else 0, 

1003 } 

1004 

1005 def get_stats(self) -> dict[str, Any]: 

1006 """Get vectorization statistics. 

1007  

1008 Returns: 

1009 Dictionary of statistics 

1010 """ 

1011 return { 

1012 **self._stats, 

1013 "queue_size": self._queue.qsize(), 

1014 "workers": len(self._workers), 

1015 "is_running": bool( 

1016 self._processing_task and not self._processing_task.done() 

1017 ), 

1018 } 

1019 

1020 async def wait_for_completion(self, check_interval: float = 5.0) -> None: 

1021 """Wait for all queued records to be processed. 

1022  

1023 Args: 

1024 check_interval: Seconds between queue checks 

1025 """ 

1026 while self._queue.qsize() > 0: 

1027 await asyncio.sleep(check_interval) 

1028 

1029 logger.info("All queued records processed") 

1030 

1031 async def run_with_checkpoint(self, resume_from: str | None = None) -> VectorizationResult: 

1032 """Run the complete vectorization with checkpoint support. 

1033  

1034 Args: 

1035 resume_from: Optional checkpoint ID to resume from 

1036  

1037 Returns: 

1038 Vectorization result with statistics 

1039 """ 

1040 await self.start() 

1041 await self.wait_for_completion() 

1042 

1043 return VectorizationResult( 

1044 processed=self._stats["processed"], 

1045 failed=self._stats["failed"], 

1046 checkpoint=self._last_checkpoint, 

1047 ) 

1048 

    async def run_batch(self, limit: int | None = None) -> VectorizationResult:
        """Process a limited number of records.

        Starts the workers, polls until roughly ``limit`` records have been
        processed (or the pipeline goes idle), then stops the workers.

        Args:
            limit: Maximum number of records to process

        Returns:
            Vectorization result with statistics
        """
        # Temporarily modify batch size if limit provided
        original_batch_size = self.batch_size
        if limit:
            self.batch_size = min(self.batch_size, limit)

        try:
            await self.start()

            # Wait for limited processing.
            # NOTE(review): _load_queue loops until shutdown, so
            # _processing_task.done() is normally only True if the loader
            # crashed; the loop otherwise exits via the processed count.
            while self._stats["processed"] < (limit or float('inf')):
                if self._queue.empty() and self._processing_task.done():
                    break
                await asyncio.sleep(0.1)

            await self.stop()

            return VectorizationResult(
                processed=self._stats["processed"],
                failed=self._stats["failed"],
                checkpoint=self._last_checkpoint,
            )
        finally:
            # Always restore the caller-visible batch size.
            self.batch_size = original_batch_size

1081 

1082 @property 

1083 def progress(self) -> VectorizationProgress: 

1084 """Get current progress.""" 

1085 return VectorizationProgress( 

1086 total_records=self._stats.get("total", 0), 

1087 processed_records=self._stats["processed"], 

1088 failed_records=self._stats["failed"], 

1089 queued_records=self._queue.qsize(), 

1090 checkpoint=self._last_checkpoint, 

1091 ) 

1092 

1093 async def get_checkpoint(self) -> str: 

1094 """Get checkpoint ID for resuming.""" 

1095 # Save current progress as checkpoint 

1096 self._last_checkpoint = f"checkpoint_{self._stats['processed']}" 

1097 return self._last_checkpoint 

1098 

1099 

@dataclass
class VectorizationResult:
    """Result of a vectorization operation."""
    processed: int  # Records successfully vectorized
    failed: int  # Records that raised during processing
    checkpoint: str | None = None  # Last checkpoint ID, if one was taken

1106 

1107 

@dataclass
class VectorizationProgress:
    """Current progress of vectorization."""
    total_records: int  # Total records known to need vectorization
    processed_records: int  # Records processed so far
    failed_records: int  # Records that failed processing
    queued_records: int  # Records currently waiting in the queue
    checkpoint: str | None = None  # Most recent checkpoint ID, if any