Coverage for src/dataknobs_data/migration_v2/migrator.py: 39% (158 statements)

1"""
2Enhanced data migrator with streaming support.
3"""
5import asyncio
6import concurrent.futures
7from typing import Callable, Iterator, Optional, Union, List
9from dataknobs_data.database import Database as AsyncDatabase, SyncDatabase as Database
10from dataknobs_data.query import Query
11from dataknobs_data.records import Record
12from dataknobs_data.streaming import StreamConfig, StreamResult
14from .transformer import Transformer
15from .migration import Migration
16from .progress import MigrationProgress


class Migrator:
    """
    Data migration orchestrator with streaming support.

    Provides memory-efficient migration between databases using streaming,
    with support for transformations, progress tracking, and parallel processing.
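
    Example (illustrative sketch; the concrete ``Database`` backends
    ``source_db`` and ``target_db``, and ``my_transform``, are assumed
    to be supplied by the caller)::

        migrator = Migrator()
        progress = migrator.migrate(source_db, target_db, transform=my_transform)
        print(f"Migrated {progress.succeeded} of {progress.total} records")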
25 """

    def migrate(
        self,
        source: Database,
        target: Database,
        transform: Optional[Union[Transformer, Migration]] = None,
        query: Optional[Query] = None,
        batch_size: int = 1000,
        on_progress: Optional[Callable[[MigrationProgress], None]] = None,
        on_error: Optional[Callable[[Exception, Record], bool]] = None
    ) -> MigrationProgress:
        """
        Migrate data between databases with optional transformation.

        Args:
            source: Source database
            target: Target database
            transform: Optional transformer or migration to apply
            query: Optional query to filter source records
            batch_size: Number of records to process per batch
            on_progress: Optional callback for progress updates
            on_error: Optional error handler (return True to continue)

        Returns:
            MigrationProgress with final statistics
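
        Example:
            An illustrative sketch; ``JsonFileDatabase`` is a hypothetical
            concrete ``Database`` backend, not part of this module::

                progress = Migrator().migrate(
                    source=JsonFileDatabase("old.json"),
                    target=JsonFileDatabase("new.json"),
                    batch_size=500,
                    on_progress=lambda p: print(f"{p.processed}/{p.total}"),
                    on_error=lambda exc, rec: True,  # True = skip and continue
                )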
51 """
        progress = MigrationProgress().start()

        # Get total count for progress tracking (this loads the full result
        # set into memory; prefer migrate_stream() for very large datasets)
        all_records = source.search(query or Query())
        progress.total = len(all_records)

        batch = []
        for record in all_records:
            try:
                # Apply transformation if provided
                if transform:
                    if isinstance(transform, Transformer):
                        transformed = transform.transform(record)
                        if transformed is None:
                            # Record filtered out by the transformer; keep the
                            # original record's id for the skip entry
                            progress.record_skip("Filtered by transformer", record.id)
                            continue
                        record = transformed
                    elif isinstance(transform, Migration):
                        record = transform.apply(record)

                batch.append(record)

                # Process batch when full
                if len(batch) >= batch_size:
                    self._write_batch(target, batch, progress, on_error)
                    batch = []

                    if on_progress:
                        on_progress(progress)

            except Exception as e:
                progress.record_failure(str(e), record.id, e)
                if not (on_error and on_error(e, record)):
                    # No handler, or the handler chose to stop processing
                    break

        # Process final batch
        if batch:
            self._write_batch(target, batch, progress, on_error)

        progress.finish()

        if on_progress:
            on_progress(progress)

        return progress

    def migrate_stream(
        self,
        source: Database,
        target: Database,
        transform: Optional[Union[Transformer, Migration]] = None,
        query: Optional[Query] = None,
        config: Optional[StreamConfig] = None,
        on_progress: Optional[Callable[[MigrationProgress], None]] = None
    ) -> MigrationProgress:
        """
        Stream-based migration for memory efficiency.

        Never loads the full dataset into memory.

        Args:
            source: Source database with streaming support
            target: Target database with streaming support
            transform: Optional transformer or migration to apply
            query: Optional query to filter source records
            config: Streaming configuration
            on_progress: Optional callback for progress updates

        Returns:
            MigrationProgress with final statistics
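
        Example:
            An illustrative sketch; ``source_db`` and ``target_db`` are
            hypothetical streaming-capable backends, and the ``batch_size``
            field on ``StreamConfig`` is an assumption::

                progress = Migrator().migrate_stream(
                    source=source_db,
                    target=target_db,
                    transform=my_transformer,
                    config=StreamConfig(batch_size=2000),
                )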
126 """
        config = config or StreamConfig()
        progress = MigrationProgress().start()

        # Estimate total (if possible)
        try:
            progress.total = source.count(query)
        except Exception:
            # Count not available; totals are tracked as records flow through
            pass

        # Create streaming pipeline
        def transform_stream(records: Iterator[Record]) -> Iterator[Record]:
            """Apply transformation to streaming records."""
            for record in records:
                try:
                    if transform:
                        if isinstance(transform, Transformer):
                            transformed = transform.transform(record)
                            if transformed is not None:
                                yield transformed
                            else:
                                progress.record_skip("Filtered by transformer", record.id)
                        elif isinstance(transform, Migration):
                            yield transform.apply(record)
                    else:
                        yield record
                except Exception as e:
                    progress.record_failure(str(e), record.id, e)
                    if config.on_error and config.on_error(e, record):
                        continue
                    raise

        # Stream from source through the transformation into the target
        source_stream = source.stream_read(query, config)
        transformed_stream = transform_stream(source_stream)

        # Write stream to target
        result = target.stream_write(transformed_stream, config)

        # Update progress from the stream result
        progress.processed = result.total_processed
        progress.succeeded = result.successful
        progress.failed = result.failed
        progress.errors.extend(result.errors)

        progress.finish()

        if on_progress:
            on_progress(progress)

        return progress

    def migrate_parallel(
        self,
        source: Database,
        target: Database,
        transform: Optional[Union[Transformer, Migration]] = None,
        partitions: int = 4,
        partition_field: str = "partition_id",
        on_progress: Optional[Callable[[MigrationProgress], None]] = None
    ) -> MigrationProgress:
        """
        Parallel streaming migration.

        Partitions the data and migrates each partition in its own stream.

        Args:
            source: Source database
            target: Target database
            transform: Optional transformer or migration
            partitions: Number of parallel partitions
            partition_field: Field to use for partitioning
            on_progress: Optional callback for progress updates

        Returns:
            Combined MigrationProgress
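
        Example:
            An illustrative sketch; it assumes source records carry a
            ``partition_id`` field with integer values ``0..partitions - 1``::

                progress = Migrator().migrate_parallel(
                    source_db,
                    target_db,
                    partitions=8,
                )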
205 """
        def migrate_partition(partition_id: int) -> MigrationProgress:
            """Migrate a single partition."""
            query = Query().filter(partition_field, "=", partition_id)
            return self.migrate_stream(source, target, transform, query)

        total_progress = MigrationProgress().start()

        with concurrent.futures.ThreadPoolExecutor(max_workers=partitions) as executor:
            futures = [
                executor.submit(migrate_partition, i)
                for i in range(partitions)
            ]

            for future in concurrent.futures.as_completed(futures):
                partition_progress = future.result()
                total_progress.merge(partition_progress)

                if on_progress:
                    on_progress(total_progress)

        total_progress.finish()
        return total_progress

    async def migrate_async(
        self,
        source: AsyncDatabase,
        target: AsyncDatabase,
        transform: Optional[Union[Transformer, Migration]] = None,
        query: Optional[Query] = None,
        config: Optional[StreamConfig] = None,
        on_progress: Optional[Callable[[MigrationProgress], None]] = None
    ) -> MigrationProgress:
        """
        Async stream-based migration.

        Args:
            source: Async source database
            target: Async target database
            transform: Optional transformer or migration
            query: Optional query to filter source records
            config: Streaming configuration
            on_progress: Optional callback for progress updates

        Returns:
            MigrationProgress with final statistics
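
        Example:
            An illustrative sketch; ``async_source`` and ``async_target``
            are hypothetical async-capable backends::

                import asyncio

                async def run() -> None:
                    progress = await Migrator().migrate_async(async_source, async_target)
                    print(f"{progress.succeeded} succeeded, {progress.failed} failed")

                asyncio.run(run())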
251 """
        config = config or StreamConfig()
        progress = MigrationProgress().start()

        # Estimate total (if possible)
        try:
            progress.total = await source.count(query)
        except Exception:
            # Count not available; totals are tracked as records flow through
            pass

        # Create async streaming pipeline
        async def transform_stream(
            records: AsyncIterator[Record],
        ) -> AsyncIterator[Record]:
            """Apply transformation to async streaming records."""
            async for record in records:
                try:
                    if transform:
                        if isinstance(transform, Transformer):
                            transformed = transform.transform(record)
                            if transformed is not None:
                                yield transformed
                            else:
                                progress.record_skip("Filtered by transformer", record.id)
                        elif isinstance(transform, Migration):
                            yield transform.apply(record)
                    else:
                        yield record
                except Exception as e:
                    progress.record_failure(str(e), record.id, e)
                    if config.on_error and config.on_error(e, record):
                        continue
                    raise

        # Stream from source through the transformation into the target
        source_stream = source.stream_read(query, config)
        transformed_stream = transform_stream(source_stream)

        # Write stream to target
        result = await target.stream_write(transformed_stream, config)

        # Update progress from the stream result
        progress.processed = result.total_processed
        progress.succeeded = result.successful
        progress.failed = result.failed
        progress.errors.extend(result.errors)

        progress.finish()

        if on_progress:
            on_progress(progress)

        return progress

    def _write_batch(
        self,
        target: Database,
        batch: List[Record],
        progress: MigrationProgress,
        on_error: Optional[Callable[[Exception, Record], bool]] = None
    ) -> None:
        """
        Write a batch of records to the target database.

        Args:
            target: Target database
            batch: Batch of records to write
            progress: Progress tracker to update
            on_error: Optional error handler
        """
        for record in batch:
            try:
                # Ensure record has an ID before writing
                if not record.id:
                    record.generate_id()

                target.create(record)
                progress.record_success(record.id)
            except Exception as e:
                progress.record_failure(str(e), record.id, e)
                if not (on_error and on_error(e, record)):
                    # No handler, or the handler chose to stop: re-raise
                    raise

    def validate_migration(
        self,
        source: Database,
        target: Database,
        query: Optional[Query] = None,
        sample_size: Optional[int] = None
    ) -> tuple[bool, List[str]]:
        """
        Validate that a migration was successful.

        Args:
            source: Source database
            target: Target database
            query: Optional query used for migration
            sample_size: Optional number of records to sample for validation

        Returns:
            Tuple of (is_valid, list_of_issues)
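
        Example:
            An illustrative sketch, reusing the databases from a prior
            ``migrate()`` call::

                is_valid, issues = Migrator().validate_migration(
                    source_db, target_db, sample_size=100
                )
                if not is_valid:
                    for issue in issues:
                        print(issue)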
354 """
        issues = []

        # Compare record counts (note: the target is searched unfiltered, so
        # this assumes it contains only the migrated records)
        source_records = source.search(query or Query())
        target_records = target.search(Query())

        source_count = len(source_records)
        target_count = len(target_records)

        if source_count != target_count:
            issues.append(
                f"Record count mismatch: source={source_count}, target={target_count}"
            )

        # Sample validation: confirm each sampled source record exists in the target
        if sample_size:
            sample = source_records[:sample_size]
        else:
            sample = source_records

        for source_record in sample:
            if source_record.id:
                target_record = target.read(source_record.id)
                if not target_record:
                    issues.append(f"Record {source_record.id} not found in target")

        return len(issues) == 0, issues