Coverage for src/dataknobs_data/vector/migration.py: 18%

497 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-13 11:23 -0700

1"""Migration tools for adding vector support to existing data.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import logging 

7from dataclasses import dataclass, field 

8from datetime import datetime 

9from typing import TYPE_CHECKING, Any 

10 

11import numpy as np 

12 

13from ..fields import FieldType 

14from ..query import Query 

15from ..records import Record 

16from ..schema import FieldSchema 

17from .sync import SyncConfig, VectorTextSynchronizer 

18from .types import VectorMetadata 

19 

20if TYPE_CHECKING: 

21 from collections.abc import Callable, Coroutine 

22 

23 from ..database import Database 

24 

25logger = logging.getLogger(__name__) 

26 

27 

@dataclass
class MigrationConfig:
    """Configuration for vector migration.

    Holds the tunables consumed by ``VectorMigration``; call
    :meth:`validate` before use to fail fast on out-of-range values.
    """

    batch_size: int = 100
    max_workers: int = 4
    checkpoint_interval: int = 1000
    enable_rollback: bool = True
    verify_migration: bool = True
    retry_failed: bool = True
    max_retries: int = 3
    max_consecutive_failures: int = 5  # Fail fast after this many consecutive failures

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            ValueError: If any numeric parameter is out of its valid range.
        """
        if self.batch_size <= 0:
            raise ValueError(f"Batch size must be positive, got {self.batch_size}")
        if self.max_workers <= 0:
            raise ValueError(f"Max workers must be positive, got {self.max_workers}")
        # checkpoint_interval is used as a modulus (migrated % interval) during
        # migration; zero would raise ZeroDivisionError mid-run, so reject it here.
        if self.checkpoint_interval <= 0:
            raise ValueError(
                f"Checkpoint interval must be positive, got {self.checkpoint_interval}"
            )
        # Zero retries is valid (no retry at all); negative values are not.
        if self.max_retries < 0:
            raise ValueError(f"Max retries must be non-negative, got {self.max_retries}")
        if self.max_consecutive_failures <= 0:
            raise ValueError(
                f"Max consecutive failures must be positive, "
                f"got {self.max_consecutive_failures}"
            )

47 

48 

@dataclass
class MigrationStatus:
    """Status of a migration operation."""

    total_records: int = 0
    migrated_records: int = 0
    verified_records: int = 0
    failed_records: int = 0
    rollback_records: int = 0
    errors: list[dict[str, Any]] = field(default_factory=list)
    checkpoints: list[dict[str, Any]] = field(default_factory=list)
    start_time: datetime | None = None
    end_time: datetime | None = None

    @property
    def total_processed(self) -> int:
        """Total number of processed records (migrated + failed)."""
        return self.failed_records + self.migrated_records

    @property
    def failed_count(self) -> int:
        """Alias for failed_records for compatibility."""
        return self.failed_records

    @property
    def success_rate(self) -> float:
        """Fraction of total records that migrated successfully (0.0 when empty)."""
        if self.total_records == 0:
            return 0.0
        return self.migrated_records / self.total_records

    @property
    def duration(self) -> float | None:
        """Elapsed wall-clock seconds, or None while either timestamp is unset."""
        if self.start_time is None or self.end_time is None:
            return None
        return (self.end_time - self.start_time).total_seconds()

    @property
    def records_per_second(self) -> float:
        """Migration throughput; 0.0 when duration is unknown or non-positive."""
        elapsed = self.duration
        if elapsed is None or elapsed <= 0:
            return 0.0
        return self.migrated_records / elapsed

    def add_checkpoint(self, name: str, record_id: str | None = None) -> None:
        """Record a named progress checkpoint with the current counters."""
        entry = {
            "name": name,
            "record_id": record_id,
            "timestamp": datetime.utcnow().isoformat(),
            "migrated": self.migrated_records,
            "failed": self.failed_records,
        }
        self.checkpoints.append(entry)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary representation."""
        started = self.start_time.isoformat() if self.start_time else None
        ended = self.end_time.isoformat() if self.end_time else None
        return {
            "total_records": self.total_records,
            "migrated_records": self.migrated_records,
            "verified_records": self.verified_records,
            "failed_records": self.failed_records,
            "rollback_records": self.rollback_records,
            "success_rate": self.success_rate,
            "duration": self.duration,
            "records_per_second": self.records_per_second,
            "errors": self.errors,
            "checkpoints": self.checkpoints,
            "start_time": started,
            "end_time": ended,
        }

121 

122 

class VectorMigration:
    """Manages migration of existing data to include vector embeddings.

    Supports three workflows: copying all records with freshly generated
    embeddings (`run`), adding vector fields to records in place
    (`add_vectors_to_existing`, with optional rollback and verification),
    and moving records between backends (`migrate_between_backends`).
    """

    def __init__(
        self,
        source_db: Database,
        target_db: Database | None = None,
        embedding_fn: Callable[[str], np.ndarray] | Callable[[str], Coroutine[Any, Any, np.ndarray]] | None = None,
        text_fields: list[str] | None = None,
        vector_field: str = "embedding",
        field_separator: str = " ",
        batch_size: int = 100,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        model_name: str | None = None,
        model_version: str | None = None,
        config: MigrationConfig | None = None,
    ):
        """Initialize the migration manager with simplified API.

        Args:
            source_db: Source database to migrate from
            target_db: Target database (None to migrate in-place)
            embedding_fn: Function to generate embeddings; may be sync or async
            text_fields: Fields to concatenate for embedding
            vector_field: Name of the vector field to create
            field_separator: Separator for concatenating text fields
            batch_size: Batch size for processing
            max_retries: Maximum retry attempts
            retry_delay: Delay between retries
            model_name: Name of the embedding model
            model_version: Version of the embedding model
            config: Advanced configuration (overrides other params)
        """
        self.source_db = source_db
        # In-place migration: fall back to the source database as the target.
        self.target_db = target_db or source_db
        self.embedding_fn = embedding_fn
        self.embedding_function = embedding_fn  # Alias for compatibility
        self.text_fields = text_fields or []
        self.vector_field = vector_field
        self.field_separator = field_separator
        self.batch_size = batch_size
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.model_name = model_name
        self.model_version = model_version

        # Use config if provided, otherwise create from params.
        # NOTE(review): when an explicit config is given, the loose batch_size/
        # max_retries arguments are silently ignored in favor of the config.
        if config:
            self.config = config
        else:
            self.config = MigrationConfig(
                batch_size=batch_size,
                max_retries=max_retries,
            )
        self.config.validate()

        # Migration status (replaced by a fresh instance each time run() starts).
        self.status = MigrationStatus()

        # Snapshot of original field values per record id, used by _rollback().
        self._rollback_data: dict[str, dict[str, Any]] = {}

    async def run(
        self,
        progress_callback: Callable[[MigrationStatus], None] | None = None
    ) -> MigrationStatus:
        """Run the complete migration.

        Reads every record from the source, generates one embedding from the
        concatenated text fields, attaches it as a VectorField, and creates
        the record in the target database.

        Args:
            progress_callback: Optional callback invoked after each batch

        Returns:
            Migration status (also stored on self.status)
        """
        self.status = MigrationStatus(start_time=datetime.utcnow())

        try:
            # Get all records from source
            all_records = await self.source_db.search(Query())
            self.status.total_records = len(all_records)

            # Process in batches
            for i in range(0, len(all_records), self.batch_size):
                batch = all_records[i:i + self.batch_size]

                for record in batch:
                    try:
                        # Concatenate the configured text fields; empty/falsy
                        # values are skipped entirely.
                        text_parts = []
                        for field in self.text_fields:
                            value = record.get_value(field)
                            if value:
                                text_parts.append(str(value))

                        if text_parts:
                            text = self.field_separator.join(text_parts)

                            # Generate embedding — supports both async and sync
                            # embedding functions (sync runs in a worker thread).
                            if asyncio.iscoroutinefunction(self.embedding_fn):
                                embedding = await self.embedding_fn(text)
                            else:
                                embedding = await asyncio.to_thread(self.embedding_fn, text)

                            # Create VectorField (imported lazily to avoid a
                            # module-level import cycle).
                            from ..fields import VectorField
                            vector_field_obj = VectorField(
                                value=embedding,
                                name=self.vector_field,
                                # source_field is only meaningful when there is
                                # exactly one source text field.
                                source_field=self.text_fields[0] if len(self.text_fields) == 1 else None,
                                model_name=self.model_name,
                                model_version=self.model_version,
                            )

                            # Add to record
                            record.fields[self.vector_field] = vector_field_obj

                        # Create in target database. Records with no text are
                        # still copied, just without an embedding.
                        await self.target_db.create(record)
                        self.status.migrated_records += 1

                    except Exception as e:
                        logger.error(f"Failed to migrate record {record.id}: {e}")
                        self.status.failed_records += 1
                        self.status.errors.append({"record_id": record.id, "error": str(e)})

                if progress_callback:
                    progress_callback(self.status)

            self.status.end_time = datetime.utcnow()
            return self.status

        except Exception as e:
            # Top-level failure (e.g. the initial search): mark everything not
            # yet migrated as failed.
            # NOTE(review): this overwrites per-record failure counts
            # accumulated above — confirm that is intended.
            logger.error(f"Migration failed: {e}")
            self.status.failed_records = self.status.total_records - self.status.migrated_records
            self.status.end_time = datetime.utcnow()
            return self.status

    async def start(self) -> None:
        """Start migration (for compatibility)."""
        # Intentionally a no-op: the migration runs to completion inside run().
        pass

    async def wait_for_completion(self, progress_callback: Callable[[MigrationStatus], None] | None = None) -> MigrationStatus:
        """Wait for migration completion (for compatibility).

        NOTE(review): does not actually wait — run() completes before
        returning, so this simply reports the current status. The
        progress_callback argument is accepted but never invoked.
        """
        return self.status

    async def add_vectors_to_existing(
        self,
        vector_fields: dict[str, str],  # vector_field -> source_field mapping
        filter_query: dict[str, Any] | None = None,
        progress_callback: Callable[[MigrationStatus], None] | None = None,
    ) -> MigrationStatus:
        """Add vector fields to existing records.

        Args:
            vector_fields: Mapping of vector field names to source text fields
            filter_query: Optional filter to select records to migrate
                (simple field -> value equality terms)
            progress_callback: Callback for progress updates

        Returns:
            Migration status

        Raises:
            ValueError: If no embedding function was configured.
            Exception: Re-raises the underlying failure after rollback when
                consecutive whole-batch failures trip the fail-fast path.
        """
        if not self.embedding_fn:
            raise ValueError("Embedding function required for adding vectors")

        status = MigrationStatus(start_time=datetime.utcnow())

        try:
            # Get records to migrate
            if filter_query:
                # Convert filter_query dict to Query object (equality only).
                query = Query()
                for field, value in filter_query.items():
                    query = query.filter(field, "==", value)
                records = await self.source_db.search(query)
            else:
                records = await self.source_db.all()

            status.total_records = len(records)
            logger.info(f"Starting migration of {status.total_records} records")

            # Create synchronizer with wrapped embedding function
            sync_config = SyncConfig(
                batch_size=self.config.batch_size,
                max_retries=self.config.max_retries,
            )

            # Track the last embedding exception so fail-fast can surface the
            # root cause instead of a generic error.
            last_embedding_exception = None

            # Create wrapper that captures exceptions while delegating to the
            # configured (sync or async) embedding function.
            async def embedding_wrapper(text: str) -> np.ndarray:
                nonlocal last_embedding_exception
                try:
                    if asyncio.iscoroutinefunction(self.embedding_fn):
                        result = await self.embedding_fn(text)
                    else:
                        result = await asyncio.to_thread(self.embedding_fn, text)
                    return result
                except Exception as e:
                    last_embedding_exception = e
                    raise

            synchronizer = VectorTextSynchronizer(
                database=self.target_db,
                embedding_fn=embedding_wrapper,
                config=sync_config,
                model_name=self.model_name,
                model_version=self.model_version,
            )

            # Process in batches
            consecutive_batch_failures = 0
            for i in range(0, len(records), self.config.batch_size):
                batch = records[i:i + self.config.batch_size]

                # Process batch with workers
                tasks = []
                for record in batch:
                    # Store original data for rollback
                    if self.config.enable_rollback:
                        # Snapshot every current field value so _rollback()
                        # can restore the record.
                        self._rollback_data[record.id] = {
                            field_name: record.get_value(field_name)
                            for field_name in record.fields.keys()
                        }

                    # Ensure a schema entry exists for each vector field that
                    # this record can populate.
                    for vector_field, source_field in vector_fields.items():
                        if record.get_value(source_field) is None:
                            continue

                        # Add vector field schema if needed
                        if vector_field not in self.target_db.schema.fields:
                            source_text = record.get_value(source_field)
                            if source_text:
                                # Probe dimensions with one sample embedding.
                                sample_embedding = await self._get_embedding(str(source_text))
                                if sample_embedding is not None:
                                    dimensions = len(sample_embedding)
                                    # Add schema for vector field
                                    field_schema = FieldSchema(
                                        name=vector_field,
                                        type=FieldType.VECTOR,
                                        metadata={
                                            "dimensions": dimensions,
                                            "source_field": source_field,
                                        }
                                    )
                                    self.target_db.add_field_schema(field_schema)

                    # Create migration task (coroutine; awaited below by gather)
                    task = self._migrate_record(
                        synchronizer,
                        record,
                        vector_fields,
                        status,
                    )
                    tasks.append(task)

                # Wait for batch to complete
                results = await asyncio.gather(*tasks, return_exceptions=False)

                # Check for batch failures and fail fast if needed
                batch_failed_count = sum(1 for r in results if r is False)
                if batch_failed_count == len(results) and len(results) > 0:
                    consecutive_batch_failures += 1
                    # If multiple consecutive batches completely fail, re-raise the last exception
                    if consecutive_batch_failures >= 2 and self.config.enable_rollback:
                        if last_embedding_exception:
                            raise last_embedding_exception
                        else:
                            raise Exception("Migration failed: consecutive batch failures")
                else:
                    consecutive_batch_failures = 0

                # Checkpoint if needed
                if status.migrated_records % self.config.checkpoint_interval == 0:
                    status.add_checkpoint(
                        f"Batch {i // self.config.batch_size + 1}",
                        batch[-1].id if batch else None,
                    )

                if progress_callback:
                    progress_callback(status)

            # Verify migration if enabled
            if self.config.verify_migration:
                await self._verify_migration(vector_fields, records, status)

        except Exception as e:
            logger.error(f"Migration failed: {e}")
            if self.config.enable_rollback:
                await self._rollback(status)
            raise

        finally:
            status.end_time = datetime.utcnow()

        logger.info(
            f"Migration completed: {status.migrated_records}/{status.total_records} "
            f"migrated, {status.failed_records} failed"
        )

        return status

    async def _get_embedding(self, text: str) -> np.ndarray | None:
        """Get embedding for text.

        Returns the embedding as an ndarray, or None on failure or when the
        embedding function returns an unsupported type.
        """
        try:
            if asyncio.iscoroutinefunction(self.embedding_fn):
                result = await self.embedding_fn(text)
            else:
                result = await asyncio.to_thread(self.embedding_fn, text)

            if isinstance(result, np.ndarray):
                return result
            elif isinstance(result, list):
                return np.array(result)
            return None

        except Exception as e:
            logger.error(f"Failed to get embedding: {e}")
            return None

    async def _migrate_record(
        self,
        synchronizer: VectorTextSynchronizer,
        record: Record,
        vector_fields: dict[str, str],
        status: MigrationStatus,
    ) -> bool:
        """Migrate a single record.

        Delegates embedding generation to the synchronizer, then persists the
        updated record. Failures are recorded on ``status`` rather than raised.

        Returns:
            True if migration succeeded, False otherwise
        """
        try:
            # Sync vectors (force=True regenerates even if vectors exist)
            success, updated_fields = await synchronizer.sync_record(record, force=True)

            if success and updated_fields:
                # Update record in target database
                await self.target_db.update(record.id, record)
                status.migrated_records += 1
                return True
            else:
                status.failed_records += 1
                status.errors.append({
                    "record_id": record.id,
                    "error": "Failed to generate vectors",
                })
                return False

        except Exception as e:
            status.failed_records += 1
            status.errors.append({
                "record_id": record.id,
                "error": str(e),
            })
            logger.error(f"Failed to migrate record {record.id}: {e}")
            return False

    async def _verify_migration(
        self,
        vector_fields: dict[str, str],
        records: list[Record],
        status: MigrationStatus,
    ) -> None:
        """Verify that migration was successful.

        Re-reads each record from the target and increments
        ``status.verified_records`` when every expected vector field holds a
        list/ndarray value. Verification errors are logged, not raised.
        """
        logger.info("Verifying migration...")

        for record in records:
            try:
                # Get updated record
                migrated = await self.target_db.read(record.id)

                # A record verifies only if every vector field whose source
                # text was present actually carries vector data.
                all_present = True
                for vector_field, source_field in vector_fields.items():
                    source_value = record.get_value(source_field)
                    if source_value:
                        # Check if vector field exists (could be in fields or data)
                        vector_data = migrated.get_value(vector_field)
                        if vector_data is None:
                            all_present = False
                            break

                        # For VectorField objects, check the wrapped value
                        from ..fields import VectorField
                        if isinstance(migrated.fields.get(vector_field), VectorField):
                            vector_data = migrated.fields[vector_field].value

                        if not isinstance(vector_data, (list, np.ndarray)):
                            all_present = False
                            break

                if all_present:
                    status.verified_records += 1

            except Exception as e:
                logger.error(f"Failed to verify record {record.id}: {e}")

    async def _rollback(self, status: MigrationStatus) -> None:
        """Rollback migration on failure.

        Restores each snapshotted record (see _rollback_data) in the target
        database. Individual rollback failures are logged and skipped.
        """
        if not self._rollback_data:
            return

        logger.info(f"Rolling back {len(self._rollback_data)} records...")

        for record_id, original_data in self._rollback_data.items():
            try:
                # Restore original record from the pre-migration snapshot.
                original_record = Record(id=record_id, data=original_data)
                await self.target_db.update(record_id, original_record)
                status.rollback_records += 1
            except Exception as e:
                logger.error(f"Failed to rollback record {record_id}: {e}")

    async def migrate_between_backends(
        self,
        field_mapping: dict[str, str] | None = None,
        transform_fn: Callable[[Record], Record] | None = None,
        progress_callback: Callable[[MigrationStatus], None] | None = None,
    ) -> MigrationStatus:
        """Migrate vector data between different backends.

        Args:
            field_mapping: Optional old-field -> new-field name mapping;
                mapped values are copied onto the record (originals are kept)
            transform_fn: Optional record transformation function, applied
                after field mapping
            progress_callback: Callback for progress updates (per batch)

        Returns:
            Migration status
        """
        status = MigrationStatus(start_time=datetime.utcnow())

        try:
            # Get all records with vectors
            records = await self.source_db.all()
            status.total_records = len(records)

            logger.info(
                f"Migrating {status.total_records} records from "
                f"{self.source_db.__class__.__name__} to "
                f"{self.target_db.__class__.__name__}"
            )

            # Process in batches
            for i in range(0, len(records), self.config.batch_size):
                batch = records[i:i + self.config.batch_size]

                for original_record in batch:
                    try:
                        record = original_record
                        # Apply field mapping
                        if field_mapping:
                            new_data = {}
                            for old_field, new_field in field_mapping.items():
                                old_value = record.get_value(old_field)
                                if old_value is not None:
                                    new_data[new_field] = old_value
                            # Update record with new field mapping
                            for field_name, value in new_data.items():
                                record.set_value(field_name, value)

                        # Apply transformation
                        if transform_fn:
                            record = transform_fn(record)

                        # Create in target database
                        await self.target_db.create(record)
                        status.migrated_records += 1

                    except Exception as e:
                        status.failed_records += 1
                        status.errors.append({
                            "record_id": record.id,
                            "error": str(e),
                        })
                        logger.error(f"Failed to migrate record {record.id}: {e}")

                # Progress update
                if progress_callback:
                    progress_callback(status)

        finally:
            status.end_time = datetime.utcnow()

        return status

    @classmethod
    def from_config(
        cls,
        source_db: Database,
        target_db: Database | None,
        embedding_fn: Callable[[str], np.ndarray] | Callable[[str], Coroutine[Any, Any, np.ndarray]],
        config: MigrationConfig,
        text_fields: list[str] | None = None,
        vector_field: str = "embedding",
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> VectorMigration:
        """Create migration from a config object for advanced use cases.

        Args:
            source_db: Source database
            target_db: Target database (None for in-place)
            embedding_fn: Function to generate embeddings
            config: Migration configuration
            text_fields: Text field names (optional)
            vector_field: Name of the vector field
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            Configured VectorMigration instance
        """
        return cls(
            source_db=source_db,
            target_db=target_db,
            embedding_fn=embedding_fn,
            text_fields=text_fields,
            vector_field=vector_field,
            batch_size=config.batch_size,
            model_name=model_name,
            model_version=model_version,
            config=config,
        )

652 

653 

class IncrementalVectorizer:
    """Manages incremental vectorization of large datasets.

    Two modes of operation:

    * Batch: :meth:`run` walks all records needing vectors and processes
      them inline.
    * Background: :meth:`start` launches worker tasks fed by a queue loader
      (:meth:`_load_queue`); :meth:`stop` shuts them down.

    Examples:
        import numpy as np
        from dataknobs_data import database_factory

        # Create database and embedding function
        db = database_factory.create(backend="memory")

        def embedding_fn(text):
            # In practice, use a real model like sentence-transformers
            return np.random.rand(384).astype(np.float32)

        # Simple usage with single field
        vectorizer = IncrementalVectorizer(
            db,
            embedding_fn=embedding_fn,
            text_fields="content"  # Can be string or list
        )
        result = await vectorizer.run()

        # Resume from checkpoint
        result = await vectorizer.run(resume_from=last_checkpoint)

        # Process limited batch
        result = await vectorizer.run_batch(limit=1000)
    """

    def __init__(
        self,
        database: Database,
        embedding_fn: Callable[[str], np.ndarray] | Callable[[str], Coroutine[Any, Any, np.ndarray]],
        text_fields: list[str] | str | None = None,  # Support multiple fields
        vector_field: str = "embedding",  # Sensible default
        field_separator: str = " ",
        batch_size: int = 100,
        checkpoint_interval: int = 1000,
        max_workers: int = 4,
        model_name: str | None = None,
        model_version: str | None = None,
    ):
        """Initialize the incremental vectorizer with simplified parameters.

        Args:
            database: Database to vectorize
            embedding_fn: Function to generate embeddings (sync or async)
            text_fields: Text field names to concatenate for embeddings;
                a single string, a list, or None to auto-detect from schema
            vector_field: Name of the vector field to create
            field_separator: Separator for concatenating multiple text fields
            batch_size: Size of processing batches
            checkpoint_interval: Records between checkpoints
            max_workers: Maximum concurrent workers
            model_name: Name of the embedding model
            model_version: Version of the embedding model
        """
        self.database = database
        self.embedding_fn = embedding_fn
        self.embedding_function = embedding_fn  # Alias for compatibility

        # Normalize text fields: accept a single name, a list, or None
        # (auto-detect from the database schema).
        if isinstance(text_fields, str):
            text_fields = [text_fields]
        elif text_fields is None:
            # Try to auto-detect from database schema
            text_fields = self._detect_text_fields()
        self.text_fields = text_fields

        self.vector_field = vector_field
        self.field_separator = field_separator
        self.batch_size = batch_size
        self.checkpoint_interval = checkpoint_interval
        self.max_workers = max_workers
        self.model_name = model_name
        self.model_version = model_version

        # Processing state for the background (start/stop) mode.
        self._queue: asyncio.Queue[Record] = asyncio.Queue()
        self._processing_task: asyncio.Task | None = None  # queue-loader task
        self._workers: list[asyncio.Task] = []
        self._shutdown_event = asyncio.Event()
        # Counters shared between workers and the queue loader.
        self._stats = {
            "processed": 0,
            "failed": 0,
            "queued": 0,
        }
        self._last_checkpoint: str | None = None
        self._progress: VectorizationProgress | None = None

    def _detect_text_fields(self) -> list[str]:
        """Auto-detect text fields from database schema.

        Falls back to common field names when the database has no schema or
        the schema declares no string/text fields.
        """
        text_fields = []
        if hasattr(self.database, 'schema') and self.database.schema:
            for field_name, field_schema in self.database.schema.fields.items():
                if field_schema.type in (FieldType.STRING, FieldType.TEXT):
                    text_fields.append(field_name)

        # Default to common field names if no schema
        if not text_fields:
            text_fields = ["content", "text", "description"]

        return text_fields

    async def _worker(self, worker_id: int) -> None:
        """Worker task: drain records from the queue until shutdown is set."""
        logger.info(f"Worker {worker_id} started")

        while not self._shutdown_event.is_set():
            try:
                # Get record from queue with a short timeout so the shutdown
                # event is re-checked at least once per second.
                try:
                    record = await asyncio.wait_for(
                        self._queue.get(),
                        timeout=1.0
                    )
                except asyncio.TimeoutError:
                    continue

                # Process record
                await self._process_record(record)
                self._stats["processed"] += 1

            except Exception as e:
                logger.error(f"Worker {worker_id} error: {e}")
                self._stats["failed"] += 1

        logger.info(f"Worker {worker_id} stopped")

    async def _process_record(self, record: Record) -> None:
        """Process a single record to add vectors.

        No-op when the record has no text in the configured fields, or when a
        vector (list/ndarray) is already present. Raises on failure so callers
        can count the error.
        """
        try:
            # Get source text from multiple fields
            text_parts = []
            for field in self.text_fields:
                value = record.get_value(field)
                if value:
                    text_parts.append(str(value))

            if not text_parts:
                return

            source_text = self.field_separator.join(text_parts)

            # Skip records that already carry vector data.
            vector_data = record.get_value(self.vector_field)
            if vector_data is not None:
                if vector_data and isinstance(vector_data, (list, np.ndarray)):
                    return

            # Generate embedding (sync functions run in a worker thread).
            if asyncio.iscoroutinefunction(self.embedding_fn):
                embedding = await self.embedding_fn(str(source_text))
            else:
                embedding = await asyncio.to_thread(self.embedding_fn, str(source_text))

            if embedding is None:
                return

            # Update record; ndarrays are stored as plain lists.
            update_data = {
                self.vector_field: embedding.tolist() if isinstance(embedding, np.ndarray) else embedding,
            }

            # Add provenance metadata alongside the vector when a model name
            # is configured.
            if self.model_name:
                metadata = VectorMetadata(
                    dimensions=len(embedding),
                    source_field=self.field_separator.join(self.text_fields),
                    model_name=self.model_name,
                    model_version=self.model_version,
                    updated_at=datetime.utcnow().isoformat(),
                )
                update_data[f"{self.vector_field}_metadata"] = metadata.to_dict()

            # Update the record with the new vector data
            for key, value in update_data.items():
                record.set_value(key, value)
            await self.database.update(record.id, record)

        except Exception as e:
            logger.error(f"Failed to process record {record.id}: {e}")
            raise

    async def start(self) -> None:
        """Start incremental vectorization (background mode).

        Idempotent: a second call while already running is a warning no-op.
        """
        if self._processing_task and not self._processing_task.done():
            logger.warning("Vectorization already running")
            return

        self._shutdown_event.clear()

        # Start workers
        self._workers = [
            asyncio.create_task(self._worker(i))
            for i in range(self.max_workers)
        ]

        # Start queue loader
        self._processing_task = asyncio.create_task(self._load_queue())

        logger.info(f"Started incremental vectorization with {self.max_workers} workers")

    async def _load_queue(self) -> None:
        """Load records into the processing queue until shutdown is set."""
        while not self._shutdown_event.is_set():
            try:
                # Get records without vectors that have at least one text field.
                # NOTE(review): this assumes the Database backend accepts
                # Mongo-style filter dicts ($exists/$or/$ne) — confirm for all
                # supported backends.
                filter_query = {
                    self.vector_field: {"$exists": False},
                    "$or": [
                        {field: {"$exists": True, "$ne": ""}}
                        for field in self.text_fields
                    ],
                }

                records = await self.database.filter(filter_query, limit=self.batch_size)

                if not records:
                    # No more records to process
                    await asyncio.sleep(60)  # Check again in a minute
                    continue

                # Add to queue
                for record in records:
                    await self._queue.put(record)
                    self._stats["queued"] += 1

            except Exception as e:
                logger.error(f"Failed to load queue: {e}")
                await asyncio.sleep(10)

    async def stop(self, timeout: float = 30.0) -> None:
        """Stop incremental vectorization.

        Signals shutdown, cancels the queue loader, then waits up to
        ``timeout`` seconds for workers before force-cancelling them.

        Args:
            timeout: Maximum time to wait for graceful shutdown
        """
        if not self._processing_task:
            return

        logger.info("Stopping incremental vectorization...")
        self._shutdown_event.set()

        # Cancel queue loader
        self._processing_task.cancel()
        try:
            await self._processing_task
        except asyncio.CancelledError:
            pass

        # Wait for workers to finish
        try:
            await asyncio.wait_for(
                asyncio.gather(*self._workers),
                timeout=timeout
            )
        except asyncio.TimeoutError:
            logger.warning("Workers did not stop gracefully, cancelling")
            for worker in self._workers:
                worker.cancel()

        # Reap cancelled workers without propagating their exceptions.
        await asyncio.gather(*self._workers, return_exceptions=True)

        self._workers.clear()
        self._processing_task = None

    async def run(
        self,
        progress_callback: Callable[[int, int, list], None] | None = None,
        max_workers: int | None = None,
    ) -> dict[str, Any]:
        """Run the complete vectorization (inline, not via background workers).

        Args:
            progress_callback: Optional callback (completed, total, current_batch);
                may be sync or async
            max_workers: Override default max_workers

        Returns:
            Results dictionary with "processed", "failed" and "total" counts
        """
        if max_workers:
            self.max_workers = max_workers

        # Get all records that need vectors
        from ..query import Query
        all_records = await self.database.search(Query())

        # Select records lacking the vector field but having at least one
        # non-empty text field.
        to_process = []
        for record in all_records:
            # Check if needs vectorization
            if self.vector_field not in record.fields:
                # Check if has text to vectorize
                has_text = False
                for field in self.text_fields:
                    if record.get_value(field):
                        has_text = True
                        break
                if has_text:
                    to_process.append(record)

        total = len(to_process)
        processed = 0
        failed = 0

        # Process in batches
        for i in range(0, total, self.batch_size):
            batch = to_process[i:i + self.batch_size]

            for record in batch:
                try:
                    await self._process_record(record)
                    processed += 1
                except Exception as e:
                    logger.error(f"Failed to process record {record.id}: {e}")
                    failed += 1

            if progress_callback:
                if asyncio.iscoroutinefunction(progress_callback):
                    await progress_callback(processed, total, batch)
                else:
                    progress_callback(processed, total, batch)

        return {
            "processed": processed,
            "failed": failed,
            "total": total,
        }

    async def get_status(self) -> dict[str, Any]:
        """Get current vectorization status.

        Counts only records that have text in the configured fields; records
        without text are excluded from both total and completed.

        Returns:
            Status dictionary with total/completed/remaining/percentage
        """
        # Count records with and without vectors
        from ..query import Query
        all_records = await self.database.search(Query())

        total = 0
        completed = 0

        for record in all_records:
            # Check if has text fields
            has_text = False
            for field_name in self.text_fields:
                if record.get_value(field_name):
                    has_text = True
                    break

            if has_text:
                total += 1
                if self.vector_field in record.fields:
                    completed += 1

        return {
            "total": total,
            "completed": completed,
            "remaining": total - completed,
            "percentage": (completed / total * 100) if total > 0 else 0,
        }

    def get_stats(self) -> dict[str, Any]:
        """Get vectorization statistics (background-mode counters).

        Returns:
            Dictionary of statistics
        """
        return {
            **self._stats,
            "queue_size": self._queue.qsize(),
            "workers": len(self._workers),
            "is_running": bool(
                self._processing_task and not self._processing_task.done()
            ),
        }

    async def wait_for_completion(self, check_interval: float = 5.0) -> None:
        """Wait for all queued records to be processed.

        NOTE(review): only waits for the queue to drain; records currently
        being processed by a worker are not tracked here.

        Args:
            check_interval: Seconds between queue checks
        """
        while self._queue.qsize() > 0:
            await asyncio.sleep(check_interval)

        logger.info("All queued records processed")

    async def run_with_checkpoint(self, resume_from: str | None = None) -> VectorizationResult:
        """Run the complete vectorization with checkpoint support.

        NOTE(review): resume_from is currently accepted but not consumed —
        confirm whether resume semantics are implemented elsewhere.

        Args:
            resume_from: Optional checkpoint ID to resume from

        Returns:
            Vectorization result with statistics
        """
        await self.start()
        await self.wait_for_completion()

        return VectorizationResult(
            processed=self._stats["processed"],
            failed=self._stats["failed"],
            checkpoint=self._last_checkpoint,
        )

    async def run_batch(self, limit: int | None = None) -> VectorizationResult:
        """Process a limited number of records.

        Args:
            limit: Maximum number of records to process

        Returns:
            Vectorization result with statistics
        """
        # Temporarily modify batch size if limit provided; restored in finally.
        original_batch_size = self.batch_size
        if limit:
            self.batch_size = min(self.batch_size, limit)

        try:
            await self.start()

            # Poll until the limit is reached or nothing is left to do.
            # NOTE(review): _load_queue loops forever unless it errors, so the
            # `_processing_task.done()` exit only fires after a loader failure.
            while self._stats["processed"] < (limit or float('inf')):
                if self._queue.empty() and self._processing_task.done():
                    break
                await asyncio.sleep(0.1)

            await self.stop()

            return VectorizationResult(
                processed=self._stats["processed"],
                failed=self._stats["failed"],
                checkpoint=self._last_checkpoint,
            )
        finally:
            self.batch_size = original_batch_size

    @property
    def progress(self) -> VectorizationProgress:
        """Get current progress as a snapshot of the background counters."""
        return VectorizationProgress(
            total_records=self._stats.get("total", 0),
            processed_records=self._stats["processed"],
            failed_records=self._stats["failed"],
            queued_records=self._queue.qsize(),
            checkpoint=self._last_checkpoint,
        )

    async def get_checkpoint(self) -> str:
        """Get checkpoint ID for resuming.

        The ID encodes only the processed count; it is also stored as
        ``_last_checkpoint`` for later result objects.
        """
        # Save current progress as checkpoint
        self._last_checkpoint = f"checkpoint_{self._stats['processed']}"
        return self._last_checkpoint

1108 

1109 

@dataclass
class VectorizationResult:
    """Result of a vectorization operation."""

    # Number of records successfully processed.
    processed: int
    # Number of records that raised during processing.
    failed: int
    # Checkpoint ID recorded at completion, if any (see get_checkpoint()).
    checkpoint: str | None = None

1116 

1117 

@dataclass
class VectorizationProgress:
    """Current progress of vectorization."""

    # Total records known to need vectorization (0 if not yet counted).
    total_records: int
    # Records processed so far.
    processed_records: int
    # Records that failed processing.
    failed_records: int
    # Records currently waiting in the processing queue.
    queued_records: int
    # Most recent checkpoint ID, if one was recorded.
    checkpoint: str | None = None