Coverage for src/dataknobs_data/vector/tracker.py: 18%
259 statements
« prev ^ index » next — coverage.py v7.11.0, created at 2025-10-29 14:14 -0600
1"""Change tracking for automatic vector updates."""
from __future__ import annotations

import asyncio
import logging
from collections import defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
12if TYPE_CHECKING:
13 from collections.abc import Callable, Coroutine
15 from ..database import Database
16 from ..records import Record
18logger = logging.getLogger(__name__)
@dataclass
class ChangeEvent:
    """Represents a change event for a record field.

    Attributes:
        record_id: ID of the record that changed.
        field_name: Name of the field that changed.
        old_value: Value before the change.
        new_value: Value after the change.
        timestamp: When the change was observed (timezone-aware UTC).
        event_type: One of "create", "update", or "delete".
    """

    record_id: str
    field_name: str
    old_value: Any
    new_value: Any
    # datetime.utcnow() is deprecated since Python 3.12 and returns a naive
    # datetime; use an aware UTC timestamp instead.
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    event_type: str = "update"  # create, update, delete

    def __repr__(self) -> str:
        """String representation of the event."""
        return (
            f"ChangeEvent(record={self.record_id}, field={self.field_name}, "
            f"type={self.event_type}, time={self.timestamp.isoformat()})"
        )
@dataclass
class UpdateTask:
    """Represents a pending vector update task.

    Attributes:
        record_id: ID of the record whose vectors need recomputation.
        vector_fields: Vector fields that must be regenerated.
        source_fields: Mapping of source field -> new value.
        priority: Higher values sort first (bumped on retries).
        created_at: Creation time (timezone-aware UTC).
        attempts: Number of failed processing attempts so far.
        last_error: Message from the most recent failure, if any.
    """

    record_id: str
    vector_fields: set[str]
    source_fields: dict[str, Any]  # source field -> new value
    priority: int = 0
    # datetime.utcnow() is deprecated since Python 3.12; use an aware UTC time.
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    attempts: int = 0
    last_error: str | None = None

    def __lt__(self, other: UpdateTask) -> bool:
        """Enable priority queue sorting (higher priority first)."""
        if self.priority != other.priority:
            return self.priority > other.priority  # Higher priority comes first
        return self.created_at > other.created_at  # Newer tasks come first for same priority
class ChangeTracker:
    """Tracks field changes and manages automatic vector updates.

    Field changes are recorded as ChangeEvent entries; changes to fields that
    feed a vector field are queued as UpdateTask entries, which a background
    loop processes in batches via registered callbacks.
    """

    def __init__(
        self,
        database: Database,
        tracked_fields: list[str] | None = None,
        vector_field: str = "embedding",
        max_queue_size: int = 10000,
        batch_size: int = 100,
        process_interval: float = 5.0,
    ) -> None:
        """Initialize the change tracker with simplified API.

        Args:
            database: The database to track changes for
            tracked_fields: Fields to track for changes (if None, only
                schema-declared vector dependencies are tracked)
            vector_field: Vector field that depends on tracked fields
            max_queue_size: Maximum number of pending updates
            batch_size: Number of updates to process in a batch
            process_interval: Seconds between batch processing
        """
        self.database = database
        self.tracked_fields = tracked_fields or []
        self.vector_field = vector_field
        self.max_queue_size = max_queue_size
        self.batch_size = batch_size
        self.process_interval = process_interval

        # Field dependency mapping: source_field -> [vector_fields]
        self._dependencies: dict[str, list[str]] = defaultdict(list)
        # vector_field -> metadata describing how it is derived
        self._vector_fields: dict[str, dict[str, Any]] = {}

        # Set up dependencies for tracked fields
        for field_name in self.tracked_fields:
            self._dependencies[field_name].append(self.vector_field)
        self._vector_fields[self.vector_field] = {"source_fields": self.tracked_fields}

        # Update queue and history (both bounded; history keeps last 1000 events)
        self._update_queue: deque[UpdateTask] = deque(maxlen=max_queue_size)
        self._pending_updates: dict[str, UpdateTask] = {}  # record_id -> task
        self._change_history: deque[ChangeEvent] = deque(maxlen=1000)

        # Processing state for the background batch loop.
        # NOTE(review): asyncio.Event is created outside a running loop here;
        # fine on Python 3.10+, but binds a loop at creation on older versions
        # — confirm the supported Python range.
        self._processing_task: asyncio.Task | None = None
        self._shutdown_event = asyncio.Event()
        self._update_callbacks: list[Callable] = []

        # Merge in any schema-declared vector-field dependencies
        self._initialize_dependencies()
109 def _initialize_dependencies(self) -> None:
110 """Initialize field dependency mappings."""
111 # Use schema if available
112 if hasattr(self.database, 'schema') and self.database.schema:
113 for field_name, field_schema in self.database.schema.fields.items():
114 if field_schema.is_vector_field():
115 self._vector_fields[field_name] = field_schema.metadata
116 source_field = field_schema.get_source_field()
117 if source_field:
118 self._dependencies[source_field].append(field_name)
120 def add_update_callback(
121 self,
122 callback: Callable[[UpdateTask], None] | Callable[[UpdateTask], Coroutine[Any, Any, None]]
123 ) -> None:
124 """Add a callback to be called when updates are processed.
126 Args:
127 callback: Function to call with update tasks
128 """
129 self._update_callbacks.append(callback)
131 def track_change(
132 self,
133 record_id: str,
134 field_name: str,
135 old_value: Any,
136 new_value: Any,
137 event_type: str = "update",
138 ) -> bool:
139 """Track a field change event.
141 Args:
142 record_id: ID of the changed record
143 field_name: Name of the changed field
144 old_value: Previous value
145 new_value: New value
146 event_type: Type of event (create, update, delete)
148 Returns:
149 True if change affects vectors and was queued
150 """
151 # Record the change event
152 event = ChangeEvent(
153 record_id=record_id,
154 field_name=field_name,
155 old_value=old_value,
156 new_value=new_value,
157 event_type=event_type,
158 )
159 self._change_history.append(event)
161 # Check if this field affects any vectors
162 affected_vectors = self._dependencies.get(field_name, [])
163 if not affected_vectors:
164 return False
166 # Create or update task
167 if record_id in self._pending_updates:
168 task = self._pending_updates[record_id]
169 task.vector_fields.update(affected_vectors)
170 task.source_fields[field_name] = new_value
171 else:
172 task = UpdateTask(
173 record_id=record_id,
174 vector_fields=set(affected_vectors),
175 source_fields={field_name: new_value},
176 )
177 self._pending_updates[record_id] = task
179 # Add to queue if not full
180 if len(self._update_queue) < self.max_queue_size:
181 self._update_queue.append(task)
182 else:
183 logger.warning(f"Update queue full, dropping task for record {record_id}")
184 del self._pending_updates[record_id]
185 return False
187 return True
189 async def on_create(self, record: Record) -> None:
190 """Handle record creation.
192 Args:
193 record: The created record
194 """
195 # Skip if record has no ID
196 if record.id is None:
197 return
199 for field_name in record.fields.keys():
200 value = record.get_value(field_name)
201 if field_name in self._dependencies:
202 self.track_change(
203 record_id=record.id,
204 field_name=field_name,
205 old_value=None,
206 new_value=value,
207 event_type="create",
208 )
210 async def on_update(
211 self,
212 record_id: str,
213 old_data: dict[str, Any],
214 new_data: dict[str, Any],
215 ) -> None:
216 """Handle record update.
218 Args:
219 record_id: ID of the updated record
220 old_data: Previous data
221 new_data: New data
222 """
223 for field_name in self._dependencies:
224 old_value = old_data.get(field_name)
225 new_value = new_data.get(field_name)
227 if old_value != new_value:
228 self.track_change(
229 record_id=record_id,
230 field_name=field_name,
231 old_value=old_value,
232 new_value=new_value,
233 event_type="update",
234 )
236 async def on_delete(self, record_id: str) -> None:
237 """Handle record deletion.
239 Args:
240 record_id: ID of the deleted record
241 """
242 # Remove from pending updates
243 if record_id in self._pending_updates:
244 task = self._pending_updates[record_id]
245 if task in self._update_queue:
246 self._update_queue.remove(task)
247 del self._pending_updates[record_id]
249 def get_pending_updates(self) -> list[UpdateTask]:
250 """Get list of pending update tasks.
252 Returns:
253 List of pending tasks
254 """
255 return list(self._update_queue)
    async def start_processing(self) -> None:
        """Start background processing of changes.

        Idempotent: returns immediately if the processing loop is already
        running; otherwise spawns _process_loop as a background task.
        """
        if self._processing_task and not self._processing_task.done():
            return  # Already running

        # Initialize content hashes for existing vector fields if we have tracked fields
        if self.tracked_fields:
            await self._initialize_content_hashes()

        # Re-arm the shutdown flag so a previously-stopped loop can restart.
        self._shutdown_event.clear()
        self._processing_task = asyncio.create_task(self._process_loop())
269 async def start_tracking(self, tracked_fields: list[str] | None = None, vector_field: str | None = None) -> None:
270 """Legacy method for compatibility - redirects to start_processing."""
271 if tracked_fields:
272 self.tracked_fields = tracked_fields
273 if vector_field:
274 self.vector_field = vector_field
276 # Update dependencies
277 self._dependencies.clear()
278 for field_name in self.tracked_fields:
279 self._dependencies[field_name].append(self.vector_field)
281 # Initialize content hashes for existing vector fields that don't have them
282 await self._initialize_content_hashes()
284 await self.start_processing()
286 async def get_outdated_records(self) -> list[Record]:
287 """Get records with outdated vector fields.
289 Returns:
290 List of records that need vector updates
291 """
292 import hashlib
294 from ..query import Query
296 # Get all records
297 all_records = await self.database.search(Query())
298 outdated = []
300 for record in all_records:
301 # Check if vector field exists
302 if self.vector_field not in record.fields:
303 outdated.append(record)
304 continue
306 # Check if any tracked field is newer than vector
307 # by comparing content hashes
308 vector_field = record.fields.get(self.vector_field)
309 if vector_field and hasattr(vector_field, 'metadata'):
310 stored_hash = vector_field.metadata.get('content_hash')
312 # If no content hash is stored, auto-generate it and consider record up-to-date
313 if stored_hash is None:
314 # Calculate and store content hash
315 content_parts = []
316 for field_name in self.tracked_fields:
317 field_value = record.get_value(field_name)
318 if field_value:
319 content_parts.append(str(field_value))
321 if content_parts:
322 current_content = " ".join(content_parts)
323 content_hash = hashlib.md5(current_content.encode()).hexdigest()
325 # Update the vector field metadata
326 vector_field.metadata['content_hash'] = content_hash
328 # Update the record in the database
329 try:
330 await self.database.update(record.id, record)
331 logger.debug(f"Auto-initialized content hash for record {record.id}")
332 except Exception as e:
333 logger.warning(f"Failed to auto-initialize content hash for record {record.id}: {e}")
334 # If we can't update, consider it outdated for safety
335 outdated.append(record)
336 continue
338 # Compute current content hash from tracked fields
339 content_parts = []
340 for field_name in self.tracked_fields:
341 field_value = record.get_value(field_name)
342 if field_value:
343 content_parts.append(str(field_value))
345 if content_parts:
346 current_content = " ".join(content_parts)
347 current_hash = hashlib.md5(current_content.encode()).hexdigest()
349 # If hashes don't match, the record is outdated
350 if stored_hash != current_hash:
351 outdated.append(record)
352 continue
354 return outdated
356 async def mark_updated(self, record_id: str) -> None:
357 """Mark a record as having updated vectors.
359 Args:
360 record_id: ID of the updated record
361 """
362 # Remove from pending updates if present
363 if record_id in self._pending_updates:
364 task = self._pending_updates[record_id]
365 if task in self._update_queue:
366 self._update_queue.remove(task)
367 del self._pending_updates[record_id]
369 def get_change_history(
370 self,
371 record_id: str | None = None,
372 field_name: str | None = None,
373 limit: int = 100,
374 ) -> list[ChangeEvent]:
375 """Get change history with optional filters.
377 Args:
378 record_id: Filter by record ID
379 field_name: Filter by field name
380 limit: Maximum events to return
382 Returns:
383 List of change events
384 """
385 events = list(self._change_history)
387 if record_id:
388 events = [e for e in events if e.record_id == record_id]
390 if field_name:
391 events = [e for e in events if e.field_name == field_name]
393 return events[-limit:]
395 async def process_batch(self) -> int:
396 """Process a batch of pending updates.
398 Returns:
399 Number of tasks processed
400 """
401 processed = 0
402 batch = []
404 # Get batch of tasks
405 while self._update_queue and len(batch) < self.batch_size:
406 task = self._update_queue.popleft()
407 if task.record_id in self._pending_updates:
408 batch.append(task)
409 del self._pending_updates[task.record_id]
411 if not batch:
412 return 0
414 # Process tasks
415 for task in batch:
416 try:
417 # Call update callbacks
418 for callback in self._update_callbacks:
419 if asyncio.iscoroutinefunction(callback):
420 await callback(task)
421 else:
422 await asyncio.to_thread(callback, task)
424 processed += 1
426 except Exception as e:
427 logger.error(f"Failed to process update for record {task.record_id}: {e}")
428 task.attempts += 1
429 task.last_error = str(e)
431 # Retry if under max attempts (3)
432 if task.attempts < 3:
433 task.priority += 1 # Increase priority for retries
434 if len(self._update_queue) < self.max_queue_size:
435 self._update_queue.append(task)
436 self._pending_updates[task.record_id] = task
438 logger.info(f"Processed {processed} vector update tasks")
439 return processed
    async def _process_loop(self) -> None:
        """Background processing loop for updates.

        Runs until the shutdown event is set: processes one batch per
        iteration, then waits up to ``process_interval`` seconds for a
        shutdown signal before the next batch.
        """
        logger.info("Started change tracker processing loop")

        while not self._shutdown_event.is_set():
            try:
                # Process batch
                await self.process_batch()

                # Wait for interval or shutdown. A timeout simply means no
                # shutdown was requested, so move on to the next batch.
                try:
                    await asyncio.wait_for(
                        self._shutdown_event.wait(),
                        timeout=self.process_interval
                    )
                except asyncio.TimeoutError:
                    continue

            except Exception as e:
                # Keep the loop alive on unexpected errors; back off briefly.
                logger.error(f"Error in processing loop: {e}")
                await asyncio.sleep(1)

        logger.info("Stopped change tracker processing loop")
    async def stop_processing(self, timeout: float = 10.0) -> None:
        """Stop background processing.

        Signals the loop to exit and waits up to ``timeout`` seconds; if the
        loop does not finish in time, the task is cancelled and the
        cancellation is awaited.

        Args:
            timeout: Maximum time to wait for graceful shutdown
        """
        if not self._processing_task:
            return

        self._shutdown_event.set()

        try:
            await asyncio.wait_for(self._processing_task, timeout=timeout)
        except asyncio.TimeoutError:
            logger.warning("Processing task did not stop gracefully, cancelling")
            self._processing_task.cancel()
            try:
                await self._processing_task
            except asyncio.CancelledError:
                # Expected when we cancel the task ourselves.
                pass

        self._processing_task = None
488 def get_stats(self) -> dict[str, Any]:
489 """Get tracker statistics.
491 Returns:
492 Dictionary of statistics
493 """
494 return {
495 "pending_updates": len(self._pending_updates),
496 "queue_size": len(self._update_queue),
497 "max_queue_size": self.max_queue_size,
498 "history_size": len(self._change_history),
499 "dependencies": {
500 field: len(vectors)
501 for field, vectors in self._dependencies.items()
502 },
503 "is_processing": bool(
504 self._processing_task and not self._processing_task.done()
505 ),
506 }
508 async def flush(self) -> int:
509 """Process all pending updates immediately.
511 Returns:
512 Number of tasks processed
513 """
514 total_processed = 0
516 while self._update_queue:
517 processed = await self.process_batch()
518 total_processed += processed
520 if processed == 0:
521 break
523 return total_processed
525 async def _initialize_content_hashes(self) -> None:
526 """Initialize content hashes for existing vector fields that don't have them."""
527 if not self.tracked_fields:
528 return
530 import hashlib
532 from ..query import Query
534 # Get all records
535 all_records = await self.database.search(Query())
537 for record in all_records:
538 # Check if record has vector field but no content hash
539 vector_field = record.fields.get(self.vector_field)
540 if vector_field and hasattr(vector_field, 'metadata'):
541 stored_hash = vector_field.metadata.get('content_hash')
543 if stored_hash is None:
544 # Calculate and store content hash
545 content_parts = []
546 for field_name in self.tracked_fields:
547 field_value = record.get_value(field_name)
548 if field_value:
549 content_parts.append(str(field_value))
551 if content_parts:
552 current_content = " ".join(content_parts)
553 content_hash = hashlib.md5(current_content.encode()).hexdigest()
555 # Update the vector field metadata
556 vector_field.metadata['content_hash'] = content_hash
558 # Update the record in the database
559 try:
560 await self.database.update(record.id, record)
561 logger.debug(f"Initialized content hash for record {record.id}")
562 except Exception as e:
563 logger.warning(f"Failed to initialize content hash for record {record.id}: {e}")