Coverage for src/dataknobs_data/vector/tracker.py: 18%

259 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-13 11:23 -0700

1"""Change tracking for automatic vector updates.""" 

2 

from __future__ import annotations

import asyncio
import logging
from collections import defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from collections.abc import Callable, Coroutine

    from ..database import Database
    from ..records import Record

17 

18logger = logging.getLogger(__name__) 

19 

20 

@dataclass
class ChangeEvent:
    """Represents a change event for a record field.

    Captures one field mutation: which record, which field, the before/after
    values, when it happened, and the kind of mutation.
    """

    record_id: str
    field_name: str
    old_value: Any
    new_value: Any
    # Timezone-aware UTC timestamp. datetime.utcnow() is deprecated since
    # Python 3.12 and returned a naive datetime; now(timezone.utc) is the
    # recommended replacement.
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    event_type: str = "update"  # one of: create, update, delete

    def __repr__(self) -> str:
        """String representation of the event."""
        return (
            f"ChangeEvent(record={self.record_id}, field={self.field_name}, "
            f"type={self.event_type}, time={self.timestamp.isoformat()})"
        )

38 

39 

@dataclass
class UpdateTask:
    """Represents a pending vector update task.

    One task per record; repeated changes to the same record are coalesced
    into the existing task by the tracker.
    """

    record_id: str
    vector_fields: set[str]
    source_fields: dict[str, Any]  # source field -> new value
    priority: int = 0
    # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated in 3.12+).
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    attempts: int = 0  # delivery attempts so far
    last_error: str | None = None  # message from the most recent failure

    def __lt__(self, other: UpdateTask) -> bool:
        """Enable priority queue sorting (higher priority first).

        Ties on priority are broken by recency: newer tasks sort first.
        """
        if self.priority != other.priority:
            return self.priority > other.priority  # higher priority comes first
        return self.created_at > other.created_at  # newer first at equal priority

57 

58 

class ChangeTracker:
    """Tracks field changes and manages automatic vector updates.

    Field mutations are recorded as ``ChangeEvent``s; mutations that affect a
    vector field are coalesced into one ``UpdateTask`` per record and queued.
    A background loop periodically drains the queue in batches and hands each
    task to the registered update callbacks.
    """

    # Maximum delivery attempts before a failing task is dropped for good.
    MAX_ATTEMPTS = 3

    def __init__(
        self,
        database: Database,
        tracked_fields: list[str] | None = None,
        vector_field: str = "embedding",
        max_queue_size: int = 10000,
        batch_size: int = 100,
        process_interval: float = 5.0,
    ):
        """Initialize the change tracker with simplified API.

        Args:
            database: The database to track changes for
            tracked_fields: Fields to track for changes (if None, tracks all)
            vector_field: Vector field that depends on tracked fields
            max_queue_size: Maximum number of pending updates
            batch_size: Number of updates to process in a batch
            process_interval: Seconds between batch processing
        """
        self.database = database
        self.tracked_fields = tracked_fields or []
        self.vector_field = vector_field
        self.max_queue_size = max_queue_size
        self.batch_size = batch_size
        self.process_interval = process_interval

        # Field dependency mapping: source_field -> [vector_fields]
        self._dependencies: dict[str, list[str]] = defaultdict(list)
        self._vector_fields: dict[str, dict[str, Any]] = {}

        # Every tracked field feeds the single configured vector field.
        for field_name in self.tracked_fields:
            self._dependencies[field_name].append(self.vector_field)
        self._vector_fields[self.vector_field] = {"source_fields": self.tracked_fields}

        # Bounded FIFO of live tasks plus a per-record index for coalescing,
        # and a rolling window of recent change events.
        self._update_queue: deque[UpdateTask] = deque(maxlen=max_queue_size)
        self._pending_updates: dict[str, UpdateTask] = {}  # record_id -> task
        self._change_history: deque[ChangeEvent] = deque(maxlen=1000)

        # Background processing state.
        self._processing_task: asyncio.Task | None = None
        self._shutdown_event = asyncio.Event()
        self._update_callbacks: list[Callable] = []

        self._initialize_dependencies()

    def _initialize_dependencies(self) -> None:
        """Augment dependency mappings from the database schema, if present."""
        if hasattr(self.database, 'schema') and self.database.schema:
            for field_name, field_schema in self.database.schema.fields.items():
                if field_schema.is_vector_field():
                    self._vector_fields[field_name] = field_schema.metadata
                    source_field = field_schema.get_source_field()
                    if source_field:
                        self._dependencies[source_field].append(field_name)

    def add_update_callback(
        self,
        callback: Callable[[UpdateTask], None] | Callable[[UpdateTask], Coroutine[Any, Any, None]]
    ) -> None:
        """Add a callback to be called when updates are processed.

        Args:
            callback: Sync or async function invoked with each UpdateTask
        """
        self._update_callbacks.append(callback)

    def track_change(
        self,
        record_id: str,
        field_name: str,
        old_value: Any,
        new_value: Any,
        event_type: str = "update",
    ) -> bool:
        """Track a field change event.

        Args:
            record_id: ID of the changed record
            field_name: Name of the changed field
            old_value: Previous value
            new_value: New value
            event_type: Type of event (create, update, delete)

        Returns:
            True if change affects vectors and was queued
        """
        # Every change is recorded in history, even ones that do not touch
        # a vector dependency.
        self._change_history.append(
            ChangeEvent(
                record_id=record_id,
                field_name=field_name,
                old_value=old_value,
                new_value=new_value,
                event_type=event_type,
            )
        )

        affected_vectors = self._dependencies.get(field_name, [])
        if not affected_vectors:
            return False

        # Coalesce into an already-queued task instead of re-appending it.
        # (Previously the existing task was appended to the queue a second
        # time, creating duplicate entries that inflated queue size and could
        # spuriously trip the queue-full guard below.)
        existing = self._pending_updates.get(record_id)
        if existing is not None:
            existing.vector_fields.update(affected_vectors)
            existing.source_fields[field_name] = new_value
            return True

        if len(self._update_queue) >= self.max_queue_size:
            logger.warning("Update queue full, dropping task for record %s", record_id)
            return False

        task = UpdateTask(
            record_id=record_id,
            vector_fields=set(affected_vectors),
            source_fields={field_name: new_value},
        )
        self._pending_updates[record_id] = task
        self._update_queue.append(task)
        return True

    async def on_create(self, record: Record) -> None:
        """Handle record creation.

        Args:
            record: The created record
        """
        if record.id is None:
            return  # cannot queue work without a record ID

        for field_name in record.fields.keys():
            # Only fields that feed a vector matter for tracking.
            if field_name in self._dependencies:
                self.track_change(
                    record_id=record.id,
                    field_name=field_name,
                    old_value=None,
                    new_value=record.get_value(field_name),
                    event_type="create",
                )

    async def on_update(
        self,
        record_id: str,
        old_data: dict[str, Any],
        new_data: dict[str, Any],
    ) -> None:
        """Handle record update.

        Args:
            record_id: ID of the updated record
            old_data: Previous data
            new_data: New data
        """
        # Only dependency fields can make a vector stale; skip the rest.
        for field_name in self._dependencies:
            old_value = old_data.get(field_name)
            new_value = new_data.get(field_name)
            if old_value != new_value:
                self.track_change(
                    record_id=record_id,
                    field_name=field_name,
                    old_value=old_value,
                    new_value=new_value,
                    event_type="update",
                )

    async def on_delete(self, record_id: str) -> None:
        """Handle record deletion.

        Args:
            record_id: ID of the deleted record
        """
        # A deleted record no longer needs its vectors refreshed.
        self._discard_pending(record_id)

    def _discard_pending(self, record_id: str) -> None:
        """Remove any queued update task for *record_id* (no-op if absent)."""
        task = self._pending_updates.pop(record_id, None)
        if task is not None and task in self._update_queue:
            self._update_queue.remove(task)

    def get_pending_updates(self) -> list[UpdateTask]:
        """Get list of pending update tasks.

        Returns:
            List of pending tasks, in queue order
        """
        return list(self._update_queue)

    async def start_processing(self) -> None:
        """Start background processing of changes (idempotent)."""
        if self._processing_task and not self._processing_task.done():
            return  # already running

        # Backfill content hashes for existing vector fields so staleness
        # detection has a baseline to compare against.
        if self.tracked_fields:
            await self._initialize_content_hashes()

        self._shutdown_event.clear()
        self._processing_task = asyncio.create_task(self._process_loop())

    async def start_tracking(self, tracked_fields: list[str] | None = None, vector_field: str | None = None) -> None:
        """Legacy method for compatibility - redirects to start_processing."""
        if tracked_fields:
            self.tracked_fields = tracked_fields
        if vector_field:
            self.vector_field = vector_field

        # Rebuild dependencies from the (possibly new) configuration.
        self._dependencies.clear()
        for field_name in self.tracked_fields:
            self._dependencies[field_name].append(self.vector_field)

        # Initialize content hashes for existing vector fields lacking them.
        await self._initialize_content_hashes()

        await self.start_processing()

    def _compute_content_hash(self, record: Record) -> str | None:
        """Hash the record's tracked-field values for change detection.

        Returns:
            Hex MD5 digest of the space-joined truthy tracked values, or
            None when no tracked field has a truthy value.

        MD5 is used purely as a fast content fingerprint, not for security.
        """
        import hashlib

        parts = [
            str(value)
            for field_name in self.tracked_fields
            if (value := record.get_value(field_name))
        ]
        if not parts:
            return None
        return hashlib.md5(" ".join(parts).encode()).hexdigest()

    async def get_outdated_records(self) -> list[Record]:
        """Get records with outdated vector fields.

        Returns:
            List of records that need vector updates
        """
        from ..query import Query

        all_records = await self.database.search(Query())
        outdated: list[Record] = []

        for record in all_records:
            # A record missing the vector field entirely always needs one.
            if self.vector_field not in record.fields:
                outdated.append(record)
                continue

            vector_field = record.fields.get(self.vector_field)
            if not (vector_field and hasattr(vector_field, 'metadata')):
                continue  # cannot judge staleness without field metadata

            stored_hash = vector_field.metadata.get('content_hash')
            current_hash = self._compute_content_hash(record)

            if stored_hash is None:
                # Legacy record without a hash: initialize it and treat the
                # record as up to date. If persisting fails, report it as
                # outdated so the caller errs on the side of refreshing.
                if current_hash is not None:
                    vector_field.metadata['content_hash'] = current_hash
                    try:
                        await self.database.update(record.id, record)
                        logger.debug("Auto-initialized content hash for record %s", record.id)
                    except Exception as e:
                        logger.warning(
                            "Failed to auto-initialize content hash for record %s: %s",
                            record.id, e,
                        )
                        outdated.append(record)
                continue

            # Mismatched hashes mean the tracked content changed after the
            # vector was computed.
            if current_hash is not None and stored_hash != current_hash:
                outdated.append(record)

        return outdated

    async def mark_updated(self, record_id: str) -> None:
        """Mark a record as having updated vectors.

        Args:
            record_id: ID of the updated record
        """
        self._discard_pending(record_id)

    def get_change_history(
        self,
        record_id: str | None = None,
        field_name: str | None = None,
        limit: int = 100,
    ) -> list[ChangeEvent]:
        """Get change history with optional filters.

        Args:
            record_id: Filter by record ID
            field_name: Filter by field name
            limit: Maximum events to return (most recent kept)

        Returns:
            List of change events, oldest first
        """
        events = [
            e for e in self._change_history
            if (not record_id or e.record_id == record_id)
            and (not field_name or e.field_name == field_name)
        ]
        return events[-limit:]

    async def process_batch(self) -> int:
        """Process a batch of pending updates.

        Returns:
            Number of tasks processed
        """
        batch: list[UpdateTask] = []

        # Drain up to batch_size live tasks; entries whose record is no
        # longer pending (e.g. deleted or already marked updated) are
        # silently discarded.
        while self._update_queue and len(batch) < self.batch_size:
            task = self._update_queue.popleft()
            if task.record_id in self._pending_updates:
                batch.append(task)
                del self._pending_updates[task.record_id]

        if not batch:
            return 0

        processed = 0
        for task in batch:
            try:
                for callback in self._update_callbacks:
                    if asyncio.iscoroutinefunction(callback):
                        await callback(task)
                    else:
                        # Run sync callbacks off the event loop so they
                        # cannot block batch processing.
                        await asyncio.to_thread(callback, task)
                processed += 1
            except Exception as e:
                logger.error("Failed to process update for record %s: %s", task.record_id, e)
                task.attempts += 1
                task.last_error = str(e)

                # Re-queue with a priority bump until the retry budget runs out.
                if task.attempts < self.MAX_ATTEMPTS:
                    task.priority += 1
                    if len(self._update_queue) < self.max_queue_size:
                        self._update_queue.append(task)
                        self._pending_updates[task.record_id] = task

        logger.info("Processed %d vector update tasks", processed)
        return processed

    async def _process_loop(self) -> None:
        """Background loop: process a batch, then sleep until the next
        interval elapses or shutdown is signalled (whichever comes first)."""
        logger.info("Started change tracker processing loop")

        while not self._shutdown_event.is_set():
            try:
                await self.process_batch()

                try:
                    # Wakes early when stop_processing() sets the event.
                    await asyncio.wait_for(
                        self._shutdown_event.wait(),
                        timeout=self.process_interval,
                    )
                except asyncio.TimeoutError:
                    continue
            except Exception as e:
                logger.error("Error in processing loop: %s", e)
                await asyncio.sleep(1)  # back off so a persistent error cannot spin

        logger.info("Stopped change tracker processing loop")

    async def stop_processing(self, timeout: float = 10.0) -> None:
        """Stop background processing.

        Args:
            timeout: Maximum time to wait for graceful shutdown
        """
        if not self._processing_task:
            return

        self._shutdown_event.set()

        try:
            await asyncio.wait_for(self._processing_task, timeout=timeout)
        except asyncio.TimeoutError:
            # Graceful shutdown timed out; cancel and swallow the expected
            # CancelledError so callers see a clean stop.
            logger.warning("Processing task did not stop gracefully, cancelling")
            self._processing_task.cancel()
            try:
                await self._processing_task
            except asyncio.CancelledError:
                pass

        self._processing_task = None

    def get_stats(self) -> dict[str, Any]:
        """Get tracker statistics.

        Returns:
            Dictionary of statistics
        """
        return {
            "pending_updates": len(self._pending_updates),
            "queue_size": len(self._update_queue),
            "max_queue_size": self.max_queue_size,
            "history_size": len(self._change_history),
            "dependencies": {
                source: len(vectors)
                for source, vectors in self._dependencies.items()
            },
            "is_processing": bool(
                self._processing_task and not self._processing_task.done()
            ),
        }

    async def flush(self) -> int:
        """Process all pending updates immediately.

        Returns:
            Number of tasks processed
        """
        total_processed = 0

        while self._update_queue:
            processed = await self.process_batch()
            total_processed += processed

            if processed == 0:
                break  # remaining entries are stale duplicates; stop

        return total_processed

    async def _initialize_content_hashes(self) -> None:
        """Backfill ``content_hash`` metadata on vector fields lacking one."""
        if not self.tracked_fields:
            return

        from ..query import Query

        all_records = await self.database.search(Query())

        for record in all_records:
            vector_field = record.fields.get(self.vector_field)
            if not (vector_field and hasattr(vector_field, 'metadata')):
                continue
            if vector_field.metadata.get('content_hash') is not None:
                continue  # already initialized

            content_hash = self._compute_content_hash(record)
            if content_hash is None:
                continue  # nothing to hash for this record

            vector_field.metadata['content_hash'] = content_hash
            try:
                await self.database.update(record.id, record)
                logger.debug("Initialized content hash for record %s", record.id)
            except Exception as e:
                logger.warning(
                    "Failed to initialize content hash for record %s: %s", record.id, e
                )