Coverage for src/dataknobs_data/vector/sync.py: 18%

276 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-31 07:20 -0600

1"""Synchronization tools for keeping vectors up to date with text changes.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import hashlib 

7import logging 

8from collections import defaultdict 

9from dataclasses import dataclass, field 

10from datetime import datetime 

11from typing import TYPE_CHECKING, Any 

12 

13import numpy as np 

14 

15from ..fields import VectorField 

16from ..records import Record 

17 

18if TYPE_CHECKING: 

19 from collections.abc import Callable, Coroutine 

20 

21 from ..database import Database 

22 

23logger = logging.getLogger(__name__) 

24 

25 

@dataclass
class SyncConfig:
    """Configuration for vector synchronization.

    Plain dataclass of tuning knobs; call :meth:`validate` after
    construction to check numeric invariants.
    """

    # Generate embeddings automatically when a record is created.
    auto_embed_on_create: bool = True
    # Re-embed automatically when a tracked source text field changes.
    auto_update_on_text_change: bool = True
    # Number of records processed per batch in bulk operations.
    batch_size: int = 100
    # Compare stored model versions against the current one to detect stale vectors.
    track_model_version: bool = True
    # Per-call timeout (seconds) applied to async embedding functions.
    embedding_timeout: float = 30.0
    # Number of attempts per embedding call (0 means never attempt).
    max_retries: int = 3
    # Delay (seconds) between retry attempts.
    retry_delay: float = 1.0

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            ValueError: If any numeric parameter is out of range.
        """
        if self.batch_size <= 0:
            raise ValueError(f"Batch size must be positive, got {self.batch_size}")
        if self.embedding_timeout <= 0:
            raise ValueError(f"Embedding timeout must be positive, got {self.embedding_timeout}")
        if self.max_retries < 0:
            raise ValueError(f"Max retries cannot be negative, got {self.max_retries}")
        # Previously unchecked: a negative delay would only fail later inside
        # asyncio.sleep() during a retry; fail fast at configuration time instead.
        if self.retry_delay < 0:
            raise ValueError(f"Retry delay cannot be negative, got {self.retry_delay}")

46 

47 

@dataclass
class SyncStatus:
    """Status of a synchronization operation.

    Accumulates per-record counters and errors while a sync runs; the
    derived properties summarize the outcome once counters are filled in.
    """

    total_records: int = 0
    processed_records: int = 0
    updated_records: int = 0
    failed_records: int = 0
    skipped_records: int = 0
    errors: list[dict[str, Any]] = field(default_factory=list)
    start_time: datetime | None = None
    end_time: datetime | None = None

    @property
    def success_rate(self) -> float:
        """Fraction of processed records that did not fail (0.0 if none processed)."""
        processed = self.processed_records
        if processed == 0:
            return 0.0
        return (processed - self.failed_records) / processed

    @property
    def duration(self) -> float | None:
        """Elapsed seconds between start and end, or None if either is unset."""
        if self.start_time is None or self.end_time is None:
            return None
        elapsed = self.end_time - self.start_time
        return elapsed.total_seconds()

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary representation."""

        def _iso(ts: datetime | None) -> str | None:
            # Timestamps serialize to ISO-8601 strings; missing ones stay None.
            return ts.isoformat() if ts is not None else None

        return {
            "total_records": self.total_records,
            "processed_records": self.processed_records,
            "updated_records": self.updated_records,
            "failed_records": self.failed_records,
            "skipped_records": self.skipped_records,
            "success_rate": self.success_rate,
            "duration": self.duration,
            "errors": self.errors,
            "start_time": _iso(self.start_time),
            "end_time": _iso(self.end_time),
        }

89 

90 

class VectorTextSynchronizer:
    """Synchronizes vector embeddings with their source text fields.

    Two sources of vector fields are handled:
      * a default vector field built by concatenating ``text_fields``, and
      * schema-declared vector fields that name a source text field.
    Currency is tracked via model version and, for plain (non-VectorField)
    values, companion ``<field>_metadata`` / ``<field>_content_hash`` fields.
    """

    def __init__(
        self,
        database: Database,
        embedding_fn: Callable[[str], np.ndarray] | Callable[[str], Coroutine[Any, Any, np.ndarray]],
        text_fields: list[str] | str | None = None,
        vector_field: str = "embedding",
        field_separator: str = " ",
        auto_sync: bool = True,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
        config: SyncConfig | None = None,
    ):
        """Initialize the synchronizer with simplified API.

        Args:
            database: The database to synchronize
            embedding_fn: Function to generate embeddings from text
            text_fields: Fields to concatenate for embedding (if None, uses all text fields)
            vector_field: Name of the vector field to store embeddings
            field_separator: Separator for concatenating text fields
            auto_sync: Whether to auto-sync on create/update
            batch_size: Batch size for bulk operations
            model_name: Name of the embedding model
            model_version: Version of the embedding model
            config: Advanced configuration object (overrides other params)
        """
        self.database = database
        self.embedding_fn = embedding_fn
        self.embedding_function = embedding_fn  # Alias for compatibility

        # Handle text_fields: a bare string is treated as a one-element list.
        if isinstance(text_fields, str):
            text_fields = [text_fields]
        self.text_fields = text_fields or []

        self.vector_field = vector_field
        self.field_separator = field_separator
        self.auto_sync = auto_sync
        self.batch_size = batch_size
        self.model_name = model_name
        self.model_version = model_version

        # Use config if provided, otherwise create from params.
        # Note: when a config is given, the auto_sync/batch_size arguments are
        # ignored in favor of the config's values.
        if config:
            self.config = config
        else:
            self.config = SyncConfig(
                auto_embed_on_create=auto_sync,
                auto_update_on_text_change=auto_sync,
                batch_size=batch_size,
            )
        self.config.validate()

        # Track vector fields and their source fields
        # _vector_fields: vector field name -> {"dimensions", "source_field"}
        # _source_fields: source field name -> [vector field names fed by it]
        self._vector_fields: dict[str, dict[str, Any]] = {}
        self._source_fields: dict[str, list[str]] = defaultdict(list)
        self._initialize_field_mappings()

    def _initialize_field_mappings(self) -> None:
        """Initialize mappings between vector fields and source fields.

        Walks the database schema and records every vector field together
        with its declared dimensionality and source text field (if any).
        """
        # Use schema if available
        for field_name, field_schema in self.database.schema.fields.items():
            if field_schema.is_vector_field():
                self._vector_fields[field_name] = {
                    # 384 is the fallback dimension when the schema omits one.
                    "dimensions": field_schema.get_dimensions() or 384,
                    "source_field": field_schema.get_source_field(),
                }
                source = field_schema.get_source_field()
                if source:
                    self._source_fields[source].append(field_name)

    def _compute_content_hash(self, content: str) -> str:
        """Compute a hash of the content for change detection.

        MD5 is used purely as a change-detection fingerprint, not for
        security purposes.
        """
        return hashlib.md5(content.encode()).hexdigest()

    def _has_current_vector(self, record: Record, vector_field: str) -> bool:
        """Check if a record has a current vector for the given field.

        A vector is "current" when it exists, its model version matches
        (if version tracking is on), and — for plain values — the stored
        content hash matches the current source text.

        Args:
            record: The record to check
            vector_field: Name of the vector field

        Returns:
            True if the vector is current, False otherwise
        """
        # Check if field exists
        field_obj = record.fields.get(vector_field)
        if not field_obj:
            return False

        # Get the vector value
        vector_value = None
        if isinstance(field_obj, VectorField):
            vector_value = field_obj.value
            if vector_value is None:
                return False

            # For VectorField, check model version if tracking is enabled
            if self.config.track_model_version and self.model_version:
                stored_version = field_obj.model_version
                if stored_version != self.model_version:
                    return False
        else:
            # Plain value (list or array)
            vector_value = field_obj.value
            if vector_value is None:
                return False
            if not isinstance(vector_value, (list, np.ndarray)):
                return False

            # For plain values, check metadata and content hash separately:
            # the model version lives in a sibling "<field>_metadata" dict.
            if self.config.track_model_version and self.model_version:
                metadata_field = f"{vector_field}_metadata"
                metadata = record.get_value(metadata_field)
                if not metadata or not isinstance(metadata, dict):
                    return False
                stored_version = metadata.get("model_version")
                if stored_version != self.model_version:
                    return False

        # Check content hash if source field exists
        field_info = self._vector_fields.get(vector_field)
        if field_info and field_info.get("source_field"):
            source_content = record.get_value(field_info["source_field"], "")
            if source_content:
                # For VectorField objects, we don't check content hash
                # as they're considered immutable once created
                if isinstance(field_obj, VectorField):
                    # VectorField with matching version is considered current
                    return True

                # For plain values, check the content hash field
                hash_field = f"{vector_field}_content_hash"
                stored_hash = record.get_value(hash_field)
                current_hash = self._compute_content_hash(str(source_content))
                if stored_hash != current_hash:
                    return False

        return True

    def _needs_update(self, record: Record, vector_field: str) -> bool:
        """Check if a vector field needs to be updated.

        Args:
            record: The record to check
            vector_field: Name of the vector field

        Returns:
            True if the vector needs updating, False otherwise
        """
        return not self._has_current_vector(record, vector_field)

    async def _embed_text(self, text: str) -> np.ndarray | None:
        """Generate embedding for text with error handling.

        Retries up to ``config.max_retries`` times with
        ``config.retry_delay`` seconds between attempts.
        NOTE(review): max_retries == 0 skips the loop entirely, so valid
        text silently yields None — confirm that is intended.

        Args:
            text: Text to embed

        Returns:
            Embedding vector or None if failed
        """
        if not text:
            return None

        for attempt in range(self.config.max_retries):
            try:
                if asyncio.iscoroutinefunction(self.embedding_fn):
                    result = await asyncio.wait_for(
                        self.embedding_fn(text),
                        timeout=self.config.embedding_timeout
                    )
                else:
                    # NOTE(review): the sync path via to_thread has no
                    # timeout applied — confirm whether one is needed.
                    result = await asyncio.to_thread(self.embedding_fn, text)

                if isinstance(result, np.ndarray):
                    return result
                elif isinstance(result, list):
                    return np.array(result)
                else:
                    logger.error(f"Embedding function returned unexpected type: {type(result)}")
                    return None

            except asyncio.TimeoutError:
                logger.warning(f"Embedding timeout on attempt {attempt + 1}")
                if attempt < self.config.max_retries - 1:
                    await asyncio.sleep(self.config.retry_delay)
            except Exception as e:
                logger.error(f"Embedding error on attempt {attempt + 1}: {e}")
                if attempt < self.config.max_retries - 1:
                    await asyncio.sleep(self.config.retry_delay)

        return None

    async def sync_record(
        self,
        record_or_id: Record | str,
        force: bool = False
    ) -> tuple[bool, list[str]]:
        """Synchronize vectors for a single record.

        Handles both the default ``text_fields``-driven vector field and
        any schema-declared vector fields with source fields. Persists the
        record if at least one field was updated.

        Args:
            record_or_id: The record or record ID to synchronize
            force: Force update even if vectors appear current

        Returns:
            Tuple of (success, list of updated fields)
        """
        # Get record if ID provided
        if isinstance(record_or_id, str):
            record = await self.database.read(record_or_id)
            if not record:
                return False, []
            record_id = record_or_id
        else:
            record = record_or_id
            record_id = record.id

        updated_fields = []
        failed_fields = []

        # If text_fields are specified, use them for the default vector field.
        # NOTE(review): this path does not honor `force` / currency checks —
        # it always re-embeds when text is present; confirm intended.
        if self.text_fields:
            text_parts = []
            for field in self.text_fields:
                value = record.get_value(field)
                if value:
                    text_parts.append(str(value))

            if text_parts:
                text = self.field_separator.join(text_parts)
                embedding = await self._embed_text(text)
                if embedding is not None:
                    from ..fields import VectorField
                    # Compute content hash for change tracking
                    content_hash = self._compute_content_hash(text)
                    vector_field_obj = VectorField(
                        value=embedding,
                        name=self.vector_field,
                        # source_field only meaningful for a single source.
                        source_field=self.text_fields[0] if len(self.text_fields) == 1 else None,
                        model_name=self.model_name,
                        model_version=self.model_version,
                        metadata={"content_hash": content_hash}
                    )
                    record.fields[self.vector_field] = vector_field_obj
                    updated_fields.append(self.vector_field)
                else:
                    # Embedding generation failed
                    failed_fields.append(self.vector_field)

        # Also process vector fields defined in schema with source fields
        for vector_field_name, field_info in self._vector_fields.items():
            source_field = field_info.get("source_field")
            if source_field and (force or self._needs_update(record, vector_field_name)):
                source_value = record.get_value(source_field)
                if source_value:
                    source_text = str(source_value)
                    embedding = await self._embed_text(source_text)
                    if embedding is not None:
                        from ..fields import VectorField
                        # Compute content hash for change tracking
                        content_hash = self._compute_content_hash(source_text)
                        vector_field_obj = VectorField(
                            value=embedding,
                            name=vector_field_name,
                            source_field=source_field,
                            model_name=self.model_name,
                            model_version=self.model_version,
                            metadata={"content_hash": content_hash}
                        )
                        record.fields[vector_field_name] = vector_field_obj
                        updated_fields.append(vector_field_name)
                    else:
                        # Embedding generation failed
                        failed_fields.append(vector_field_name)

        # Save to database if any fields were updated
        if updated_fields:
            # Use storage_id if available, otherwise fall back to record.id
            update_id = record.storage_id if record.has_storage_id() else record_id
            await self.database.update(update_id, record)

        # Return success=False if there were failures and no successes
        success = len(failed_fields) == 0 or len(updated_fields) > 0
        return success, updated_fields

    async def sync_all(
        self,
        batch_size: int | None = None,
        force: bool = False,
        progress_callback: Callable[[int, int], None] | None = None,
    ) -> dict[str, Any]:
        """Synchronize all records in the database.

        Args:
            batch_size: Batch size for processing (uses self.batch_size if None)
            force: Force update even if vectors appear current
            progress_callback: Callback for progress updates (done, total)

        Returns:
            Dictionary with sync results
        """
        # Imported locally to avoid a circular import at module load time.
        from ..query import Query

        batch_size = batch_size or self.batch_size

        # Get all records (empty Query matches everything).
        all_records = await self.database.search(Query())
        total = len(all_records)

        processed = 0
        updated = 0
        failed = 0

        # Process in batches
        for i in range(0, total, batch_size):
            batch = all_records[i:i + batch_size]

            for record in batch:
                success, fields = await self.sync_record(record, force=force)

                processed += 1
                if success and fields:
                    updated += 1
                elif not success:
                    failed += 1

                # Progress is reported per record here (cf. bulk_sync,
                # which reports per batch).
                if progress_callback:
                    progress_callback(processed, total)

        return {
            "processed": processed,
            "updated": updated,
            "failed": failed,
            "total": total,
        }

    async def bulk_sync(
        self,
        records: list[Record] | None = None,
        force: bool = False,
        progress_callback: Callable[[SyncStatus], None] | None = None,
    ) -> SyncStatus:
        """Synchronize vectors for multiple records in batches.

        Args:
            records: Records to sync (None for all records in database)
            force: Force update even if vectors appear current
            progress_callback: Callback for progress updates

        Returns:
            Synchronization status
        """
        # NOTE(review): datetime.utcnow() is naive and deprecated since
        # Python 3.12 — consider datetime.now(timezone.utc).
        status = SyncStatus(start_time=datetime.utcnow())

        try:
            # Get records if not provided
            if records is None:
                records = await self.database.all()

            status.total_records = len(records)

            # Process in batches
            for i in range(0, len(records), self.config.batch_size):
                batch = records[i:i + self.config.batch_size]

                for record in batch:
                    try:
                        success, updated_fields = await self.sync_record(record, force)
                        status.processed_records += 1

                        if updated_fields:
                            # sync_record already updates the database
                            status.updated_records += 1
                        elif success:
                            status.skipped_records += 1
                        else:
                            status.failed_records += 1

                    except Exception as e:
                        # Record the failure but keep processing the batch.
                        status.failed_records += 1
                        status.errors.append({
                            "record_id": record.id,
                            "error": str(e),
                        })
                        logger.error(f"Failed to sync record {record.id}: {e}")

                # Call progress callback (once per batch)
                if progress_callback:
                    progress_callback(status)

        finally:
            # Always stamp the end time, even if record fetching failed.
            status.end_time = datetime.utcnow()

        logger.info(
            f"Sync completed: {status.updated_records} updated, "
            f"{status.skipped_records} skipped, {status.failed_records} failed"
        )

        return status

    async def sync_on_update(
        self,
        record_id: str,
        old_data: dict[str, Any],
        new_data: dict[str, Any],
    ) -> bool:
        """Handle record updates and sync vectors if needed.

        Only re-embeds when a tracked source field's value actually changed
        between old and new data, and only writes back the vector (plus its
        metadata/hash companions), preserving other fields on the stored
        record.

        Args:
            record_id: ID of the updated record
            old_data: Previous data
            new_data: New data

        Returns:
            True if sync was performed, False otherwise
        """
        if not self.config.auto_update_on_text_change:
            return False

        # Check if any source fields changed
        fields_to_update = set()
        for source_field, vector_fields in self._source_fields.items():
            old_value = old_data.get(source_field)
            new_value = new_data.get(source_field)

            if old_value != new_value:
                fields_to_update.update(vector_fields)

        if not fields_to_update:
            return False

        # Create record and sync (force=True: we already know text changed)
        record = Record(id=record_id, data=new_data)
        success, updated_fields = await self.sync_record(record, force=True)

        if updated_fields:
            # Update only the vector fields
            update_data = {
                field: record.get_value(field)
                for field in updated_fields
                if record.get_value(field) is not None
            }

            # Include metadata fields
            for field in updated_fields:
                metadata_field = f"{field}_metadata"
                metadata_value = record.get_value(metadata_field)
                if metadata_value is not None:
                    update_data[metadata_field] = metadata_value

                hash_field = f"{field}_content_hash"
                hash_value = record.get_value(hash_field)
                if hash_value is not None:
                    update_data[hash_field] = hash_value

            # Get the existing record and update it
            existing_record = await self.database.read(record_id)
            if existing_record:
                for key, value in update_data.items():
                    existing_record.set_value(key, value)
                await self.database.update(record_id, existing_record)
            return True

        return False

    async def sync_on_create(self, record: Record) -> bool:
        """Handle record creation and sync vectors if needed.

        Args:
            record: The newly created record

        Returns:
            True if sync was performed, False otherwise
        """
        if not self.config.auto_embed_on_create:
            return False

        success, updated_fields = await self.sync_record(record)

        if updated_fields:
            # Update the record with vector data
            # NOTE(review): sync_record already persists the record, so this
            # appears to be a second write — confirm whether it is required.
            await self.database.update(record.id, record)
            return True

        return False

    @classmethod
    def from_config(
        cls,
        database: Database,
        embedding_fn: Callable[[str], np.ndarray] | Callable[[str], Coroutine[Any, Any, np.ndarray]],
        config: SyncConfig,
        text_fields: list[str] | None = None,
        vector_field: str = "embedding",
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> VectorTextSynchronizer:
        """Create synchronizer from a config object for advanced use cases.

        Args:
            database: The database to synchronize
            embedding_fn: Function to generate embeddings from text
            config: Synchronization configuration
            text_fields: Text field names (optional)
            vector_field: Name of the vector field
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            Configured VectorTextSynchronizer instance
        """
        return cls(
            database=database,
            embedding_fn=embedding_fn,
            text_fields=text_fields,
            vector_field=vector_field,
            auto_sync=config.auto_embed_on_create,
            batch_size=config.batch_size,
            model_name=model_name,
            model_version=model_version,
            config=config,
        )