Coverage for src/dataknobs_data/migration/migrator.py: 28%

163 statements  

coverage.py v7.10.3, created at 2025-08-31 15:06 -0600

1"""Enhanced data migrator with streaming support. 

2""" 

3 

4from __future__ import annotations 

5 

6import concurrent.futures 

7from typing import TYPE_CHECKING 

8 

9from dataknobs_data.query import Query 

10from dataknobs_data.streaming import StreamConfig 

11 

12from .migration import Migration 

13from .progress import MigrationProgress 

14from .transformer import Transformer 

15 

16if TYPE_CHECKING: 

17 from collections.abc import Callable, Iterator 

18 from dataknobs_data.database import AsyncDatabase, SyncDatabase 

19 from dataknobs_data.records import Record 

20 

21 

22class Migrator: 

23 """Data migration orchestrator with streaming support. 

24  

25 Provides memory-efficient migration between databases using streaming, 

26 with support for transformations, progress tracking, and parallel processing. 

27 """ 

28 

    def migrate(
        self,
        source: SyncDatabase,
        target: SyncDatabase,
        transform: Transformer | Migration | None = None,
        query: Query | None = None,
        batch_size: int = 1000,
        on_progress: Callable[[MigrationProgress], None] | None = None,
        on_error: Callable[[Exception, Record], bool] | None = None
    ) -> MigrationProgress:
        """Migrate data between databases with optional transformation.

        Args:
            source: Source database
            target: Target database
            transform: Optional transformer or migration to apply
            query: Optional query to filter source records
            batch_size: Number of records to process per batch
            on_progress: Optional callback for progress updates
            on_error: Optional error handler (return True to continue)

        Returns:
            MigrationProgress with final statistics
        """
        progress = MigrationProgress().start()

        # Load all matching records up front so the total is known for progress tracking
        all_records = source.search(query or Query())
        progress.total = len(all_records)

        batch = []
        for original_record in all_records:
            try:
                # Apply transformation if provided
                record = original_record
                if transform is not None:
                    if isinstance(transform, Transformer):
                        original_id = record.id  # Preserve ID before transformation
                        transformed = transform.transform(record)
                        if transformed is None:
                            # Record filtered out
                            progress.record_skip("Filtered by transformer", original_id)
                            continue
                        record = transformed
                    elif isinstance(transform, Migration):
                        record = transform.apply(record)

                batch.append(record)

                # Process batch when full
                if len(batch) >= batch_size:
                    self._write_batch(target, batch, progress, on_error)
                    batch = []

                if on_progress:
                    on_progress(progress)

            except Exception as e:
                progress.record_failure(str(e), record.id if hasattr(record, 'id') else None, e)
                if on_error:
                    if not on_error(e, record):
                        # Handler says stop - re-raise to stop processing immediately
                        raise
                    # Handler says continue - keep going
                else:
                    # No handler - stop processing immediately
                    raise

        # Process final batch
        if batch:
            self._write_batch(target, batch, progress, on_error)

        progress.finish()

        if on_progress:
            on_progress(progress)

        return progress
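
    # Usage sketch for migrate(). The databases and callbacks below are
    # hypothetical stand-ins, not part of this module:
    #
    #     def log_progress(p: MigrationProgress) -> None:
    #         print(f"{p.succeeded}/{p.total} migrated, {p.failed} failed")
    #
    #     def skip_bad_record(exc: Exception, record: Record) -> bool:
    #         return True  # True tells migrate() to log the failure and keep going
    #
    #     progress = Migrator().migrate(
    #         source_db, target_db,
    #         batch_size=500,
    #         on_progress=log_progress,
    #         on_error=skip_bad_record,
    #     )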

    def migrate_stream(
        self,
        source: SyncDatabase,
        target: SyncDatabase,
        transform: Transformer | Migration | None = None,
        query: Query | None = None,
        config: StreamConfig | None = None,
        on_progress: Callable[[MigrationProgress], None] | None = None
    ) -> MigrationProgress:
        """Stream-based migration for memory efficiency.

        Never loads the full dataset into memory.

        Args:
            source: Source database with streaming support
            target: Target database with streaming support
            transform: Optional transformer or migration to apply
            query: Optional query to filter source records
            config: Streaming configuration
            on_progress: Optional callback for progress updates

        Returns:
            MigrationProgress with final statistics
        """
        config = config or StreamConfig()
        progress = MigrationProgress().start()

        # Estimate total (if possible)
        try:
            progress.total = source.count(query)
        except Exception:
            # Count not available, will track as we go
            pass

        # Create streaming pipeline
        def transform_stream(records: Iterator[Record]) -> Iterator[Record]:
            """Apply transformation to streaming records."""
            for record in records:
                progress.processed += 1  # Track that we've processed this record
                try:
                    if transform is not None:
                        if isinstance(transform, Transformer):
                            original_id = record.id  # Preserve ID before transformation
                            transformed = transform.transform(record)
                            if transformed:
                                yield transformed
                            else:
                                progress.record_skip("Filtered by transformer", original_id)
                        elif isinstance(transform, Migration):
                            yield transform.apply(record)
                    else:
                        yield record
                except Exception as e:
                    if config.on_error and config.on_error(e, record):
                        progress.record_failure(str(e), record.id if hasattr(record, 'id') else None, e)
                        continue
                    else:
                        progress.record_failure(str(e), record.id if hasattr(record, 'id') else None, e)
                        raise

        # Stream from source through transformation to target
        source_stream = source.stream_read(query, config)
        transformed_stream = transform_stream(source_stream)

        # Write stream to target
        result = target.stream_write(transformed_stream, config)

        # Update progress from result
        # Note: processed was already tracked in transform_stream
        # Result contains only write successes/failures
        progress.succeeded += result.successful
        progress.failed += result.failed
        progress.errors.extend(result.errors)

        progress.finish()

        if on_progress:
            on_progress(progress)

        return progress
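
    # Usage sketch for migrate_stream(). source_db and target_db are
    # hypothetical; StreamConfig() is shown with defaults only, though the
    # code above also reads an optional config.on_error(exc, record) hook:
    #
    #     progress = Migrator().migrate_stream(
    #         source_db, target_db,
    #         transform=my_transformer,  # hypothetical Transformer; omit to copy records as-is
    #         config=StreamConfig(),
    #     )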

    def migrate_parallel(
        self,
        source: SyncDatabase,
        target: SyncDatabase,
        transform: Transformer | Migration | None = None,
        partitions: int = 4,
        partition_field: str = "partition_id",
        on_progress: Callable[[MigrationProgress], None] | None = None
    ) -> MigrationProgress:
        """Parallel streaming migration.

        Partition data and migrate in parallel streams.

        Args:
            source: Source database
            target: Target database
            transform: Optional transformer or migration
            partitions: Number of parallel partitions
            partition_field: Field to use for partitioning
            on_progress: Optional callback for progress updates

        Returns:
            Combined MigrationProgress
        """
        def migrate_partition(partition_id: int) -> MigrationProgress:
            """Migrate a single partition."""
            query = Query().filter(partition_field, "=", partition_id)
            return self.migrate_stream(source, target, transform, query)

        total_progress = MigrationProgress().start()

        with concurrent.futures.ThreadPoolExecutor(max_workers=partitions) as executor:
            futures = [
                executor.submit(migrate_partition, i)
                for i in range(partitions)
            ]

            for future in concurrent.futures.as_completed(futures):
                partition_progress = future.result()
                total_progress.merge(partition_progress)

                if on_progress:
                    on_progress(total_progress)

        total_progress.finish()
        return total_progress
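
    # Usage sketch for migrate_parallel(). Partition ids are generated as
    # range(partitions), so records are expected to carry integer values
    # 0..partitions-1 in the partition field. The databases and field name
    # below are hypothetical:
    #
    #     progress = Migrator().migrate_parallel(
    #         source_db, target_db,
    #         partitions=4,
    #         partition_field="shard",  # each worker streams one shard value
    #     )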

    async def migrate_async(
        self,
        source: AsyncDatabase,
        target: AsyncDatabase,
        transform: Transformer | Migration | None = None,
        query: Query | None = None,
        config: StreamConfig | None = None,
        on_progress: Callable[[MigrationProgress], None] | None = None
    ) -> MigrationProgress:
        """Async stream-based migration.

        Args:
            source: Async source database
            target: Async target database
            transform: Optional transformer or migration
            query: Optional query to filter source records
            config: Streaming configuration
            on_progress: Optional callback for progress updates

        Returns:
            MigrationProgress with final statistics
        """
        config = config or StreamConfig()
        progress = MigrationProgress().start()

        # Estimate total (if possible)
        try:
            progress.total = await source.count(query)
        except Exception:
            pass

        # Create async streaming pipeline
        async def transform_stream(records):
            """Apply transformation to async streaming records."""
            async for record in records:
                progress.processed += 1  # Track that we've processed this record
                try:
                    if transform is not None:
                        if isinstance(transform, Transformer):
                            original_id = record.id  # Preserve ID before transformation
                            transformed = transform.transform(record)
                            if transformed:
                                yield transformed
                            else:
                                progress.record_skip("Filtered by transformer", original_id)
                        elif isinstance(transform, Migration):
                            yield transform.apply(record)
                    else:
                        yield record
                except Exception as e:
                    if config.on_error and config.on_error(e, record):
                        progress.record_failure(str(e), record.id if hasattr(record, 'id') else None, e)
                        continue
                    else:
                        progress.record_failure(str(e), record.id if hasattr(record, 'id') else None, e)
                        raise

        # Stream from source through transformation to target
        source_stream = source.stream_read(query, config)
        transformed_stream = transform_stream(source_stream)

        # Write stream to target
        result = await target.stream_write(transformed_stream, config)

        # Update progress from result
        # Note: processed was already tracked in transform_stream
        # Result contains only write successes/failures
        progress.succeeded += result.successful
        progress.failed += result.failed
        progress.errors.extend(result.errors)

        progress.finish()

        if on_progress:
            on_progress(progress)

        return progress
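
    # Usage sketch for migrate_async() (async_source and async_target are
    # hypothetical AsyncDatabase instances):
    #
    #     import asyncio
    #
    #     async def run() -> None:
    #         progress = await Migrator().migrate_async(async_source, async_target)
    #         print(progress.succeeded, progress.failed)
    #
    #     asyncio.run(run())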

    def _write_batch(
        self,
        target: SyncDatabase,
        batch: list[Record],
        progress: MigrationProgress,
        on_error: Callable[[Exception, Record], bool] | None = None
    ) -> None:
        """Write a batch of records to the target database.

        Args:
            target: Target database
            batch: Batch of records to write
            progress: Progress tracker to update
            on_error: Optional error handler
        """
        for record in batch:
            try:
                # Ensure record has an ID
                if not record.id:
                    record.generate_id()

                target.create(record)
                progress.record_success(record.id)
            except Exception as e:
                progress.record_failure(str(e), record.id, e)
                if on_error:
                    if not on_error(e, record):
                        # Handler says stop - re-raise to stop processing immediately
                        raise
                    # Handler says continue - keep going
                else:
                    # No handler - stop processing immediately
                    raise

347 

348 def validate_migration( 

349 self, 

350 source: SyncDatabase, 

351 target: SyncDatabase, 

352 query: Query | None = None, 

353 sample_size: int | None = None 

354 ) -> tuple[bool, list[str]]: 

355 """Validate that migration was successful. 

356  

357 Args: 

358 source: Source database 

359 target: Target database 

360 query: Optional query used for migration 

361 sample_size: Optional number of records to sample for validation 

362  

363 Returns: 

364 Tuple of (is_valid, list_of_issues) 

365 """ 

366 issues = [] 

367 

368 # Get counts 

369 source_records = source.search(query or Query()) 

370 target_records = target.search(Query()) 

371 

372 source_count = len(source_records) 

373 target_count = len(target_records) 

374 

375 if source_count != target_count: 

376 issues.append( 

377 f"Record count mismatch: source={source_count}, target={target_count}" 

378 ) 

379 

380 # Sample validation 

381 if sample_size: 

382 sample = source_records[:sample_size] 

383 else: 

384 sample = source_records 

385 

386 for source_record in sample: 

387 if source_record.id: 

388 target_record = target.read(source_record.id) 

389 if not target_record: 

390 issues.append(f"Record {source_record.id} not found in target") 

391 

392 return len(issues) == 0, issues