Coverage for src/dataknobs_data/streaming.py: 25%

218 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-13 11:34 -0700

1"""Streaming support for database operations.""" 

2 

3from __future__ import annotations 

4 

5import time 

6from dataclasses import dataclass, field 

7from typing import TYPE_CHECKING, Any 

8 

9 

10if TYPE_CHECKING: 

11 from collections.abc import AsyncIterator, Callable, Iterator 

12 from .query import Query 

13 from .records import Record 

14 

15 

@dataclass
class StreamConfig:
    """Configuration for streaming operations."""

    batch_size: int = 1000
    prefetch: int = 2  # Number of batches to prefetch
    timeout: float | None = None
    on_error: Callable[[Exception, Record], bool] | None = None  # Return True to continue

    def __post_init__(self):
        """Reject invalid settings as soon as the dataclass is built."""
        # Each entry pairs an "is invalid" condition with the error message
        # raised when it holds; checks run in declaration order.
        checks = (
            (self.batch_size <= 0, "batch_size must be positive"),
            (self.prefetch < 0, "prefetch must be non-negative"),
            (self.timeout is not None and self.timeout <= 0,
             "timeout must be positive if specified"),
        )
        for is_invalid, message in checks:
            if is_invalid:
                raise ValueError(message)

33 

34 

@dataclass
class StreamResult:
    """Aggregated statistics produced by a streaming operation."""

    total_processed: int = 0
    successful: int = 0
    failed: int = 0
    errors: list[dict[str, Any]] = field(default_factory=list)
    duration: float = 0.0
    total_batches: int = 0  # Number of batches processed
    failed_indices: list[int] = field(default_factory=list)  # Indices of failed records

    @property
    def success_rate(self) -> float:
        """Percentage of processed records that succeeded (0.0 when nothing ran)."""
        if not self.total_processed:
            return 0.0
        return (self.successful / self.total_processed) * 100

    def add_error(self, record_id: str | None, error: Exception, index: int | None = None) -> None:
        """Record a single failure.

        Args:
            record_id: ID of the record that failed, if known
            error: The exception that occurred
            index: Optional position of the failed record in the original batch
        """
        entry = {
            "record_id": record_id,
            "error": str(error),
            "type": type(error).__name__,
            "index": index,
        }
        self.errors.append(entry)
        if index is not None:
            self.failed_indices.append(index)

    def merge(self, other: StreamResult) -> None:
        """Fold the counters and error details of *other* into this result."""
        self.total_processed += other.total_processed
        self.successful += other.successful
        self.failed += other.failed
        self.duration += other.duration
        self.total_batches += other.total_batches
        self.errors.extend(other.errors)
        self.failed_indices.extend(other.failed_indices)

    def __str__(self) -> str:
        """Compact human-readable summary of the run."""
        return (
            f"StreamResult(processed={self.total_processed}, "
            f"successful={self.successful}, failed={self.failed}, "
            f"success_rate={self.success_rate:.1f}%, "
            f"duration={self.duration:.2f}s)"
        )

89 

90 

def process_batch_with_fallback(
    batch: list[Record],
    batch_create_func: Callable[[list[Record]], list[str]],
    single_create_func: Callable[[Record], str],
    result: StreamResult,
    config: StreamConfig,
    on_quit_signal: Callable[[], None] | None = None,
    batch_index: int = 0
) -> bool:
    """Create a batch of records, falling back to per-record creation on failure.

    When the batch operation fails, every record is retried individually so
    that only the genuinely problematic records are reported as failures and
    the rest are still created.

    Args:
        batch: List of records to process
        batch_create_func: Function that creates a batch of records
        single_create_func: Function that creates a single record
        result: StreamResult updated in place with statistics
        config: Stream configuration (batch_size, on_error handler)
        on_quit_signal: Optional callback invoked when streaming should stop
        batch_index: Zero-based index of this batch, used to compute each
            record's absolute index in the stream

    Returns:
        True to continue processing, False to quit streaming
    """
    try:
        created_ids = batch_create_func(batch)
    except Exception:
        pass  # Whole batch failed; fall through to the per-record path.
    else:
        result.successful += len(created_ids)
        result.total_processed += len(batch)
        result.total_batches += 1
        return True

    result.total_batches += 1
    for offset, record in enumerate(batch):
        result.total_processed += 1
        absolute_index = batch_index * config.batch_size + offset
        try:
            single_create_func(record)
        except Exception as record_error:
            result.failed += 1
            # Safely get the record ID if one is available.
            failed_id = record.id if record and hasattr(record, 'id') else None
            result.add_error(failed_id, record_error, absolute_index)

            handler = config.on_error
            # No handler, or handler returned False: signal quit and stop.
            if handler is None or not handler(record_error, record):
                if on_quit_signal is not None:
                    on_quit_signal()
                return False
        else:
            result.successful += 1

    return True

154 

155 

async def async_process_batch_with_fallback(
    batch: list[Record],
    batch_create_func: Callable,  # Async callable
    single_create_func: Callable,  # Async callable
    result: StreamResult,
    config: StreamConfig,
    on_quit_signal: Callable[[], None] | None = None,
    batch_index: int = 0
) -> bool:
    """Async version of process_batch_with_fallback.

    When the batch operation fails, every record is retried individually so
    that only the genuinely problematic records are reported as failures and
    the rest are still created.

    Args:
        batch: List of records to process
        batch_create_func: Async function that creates a batch of records
        single_create_func: Async function that creates a single record
        result: StreamResult updated in place with statistics
        config: Stream configuration (batch_size, on_error handler)
        on_quit_signal: Optional callback invoked when streaming should stop
        batch_index: Zero-based index of this batch, used to compute each
            record's absolute index in the stream

    Returns:
        True to continue processing, False to quit streaming
    """
    try:
        created_ids = await batch_create_func(batch)
    except Exception:
        pass  # Whole batch failed; fall through to the per-record path.
    else:
        result.successful += len(created_ids)
        result.total_processed += len(batch)
        result.total_batches += 1
        return True

    result.total_batches += 1
    for offset, record in enumerate(batch):
        result.total_processed += 1
        absolute_index = batch_index * config.batch_size + offset
        try:
            await single_create_func(record)
        except Exception as record_error:
            result.failed += 1
            # Safely get the record ID if one is available.
            failed_id = record.id if record and hasattr(record, 'id') else None
            result.add_error(failed_id, record_error, absolute_index)

            handler = config.on_error
            # No handler, or handler returned False: signal quit and stop.
            if handler is None or not handler(record_error, record):
                if on_quit_signal is not None:
                    on_quit_signal()
                return False
        else:
            result.successful += 1

    return True

219 

220 

class StreamProcessor:
    """Static helpers for batching, filtering, and transforming record streams."""

    @staticmethod
    def batch_iterator(
        iterator: Iterator[Record],
        batch_size: int
    ) -> Iterator[list[Record]]:
        """Group a record iterator into lists of at most *batch_size* records."""
        pending: list[Record] = []
        for item in iterator:
            pending.append(item)
            if len(pending) >= batch_size:
                yield pending
                pending = []
        # Emit the final, possibly short, batch.
        if pending:
            yield pending

    @staticmethod
    def list_to_iterator(records: list[Record]) -> Iterator[Record]:
        """Convert a list of records to an iterator.

        Args:
            records: List of records

        Yields:
            Individual records from the list
        """
        yield from records

    @staticmethod
    async def list_to_async_iterator(records: list[Record]) -> AsyncIterator[Record]:
        """Convert a list of records to an async iterator.

        This adapter lets synchronous lists feed async streaming APIs.

        Args:
            records: List of records

        Yields:
            Individual records from the list
        """
        for item in records:
            yield item

    @staticmethod
    async def iterator_to_async_iterator(iterator: Iterator[Record]) -> AsyncIterator[Record]:
        """Convert a synchronous iterator to an async iterator.

        This adapter lets synchronous iterators feed async streaming APIs.

        Args:
            iterator: Synchronous iterator of records

        Yields:
            Individual records from the iterator
        """
        for item in iterator:
            yield item

    @staticmethod
    async def async_batch_iterator(
        iterator: AsyncIterator[Record],
        batch_size: int
    ) -> AsyncIterator[list[Record]]:
        """Group an async record iterator into lists of at most *batch_size* records."""
        pending: list[Record] = []
        async for item in iterator:
            pending.append(item)
            if len(pending) >= batch_size:
                yield pending
                pending = []
        # Emit the final, possibly short, batch.
        if pending:
            yield pending

    @staticmethod
    def filter_stream(
        iterator: Iterator[Record],
        predicate: Callable[[Record], bool]
    ) -> Iterator[Record]:
        """Yield only the records for which *predicate* is truthy."""
        return (item for item in iterator if predicate(item))

    @staticmethod
    async def async_filter_stream(
        iterator: AsyncIterator[Record],
        predicate: Callable[[Record], bool]
    ) -> AsyncIterator[Record]:
        """Yield only the records for which *predicate* is truthy (async)."""
        async for item in iterator:
            if predicate(item):
                yield item

    @staticmethod
    def transform_stream(
        iterator: Iterator[Record],
        transform: Callable[[Record], Record | None]
    ) -> Iterator[Record]:
        """Apply *transform* to each record, dropping None results."""
        for item in iterator:
            if (mapped := transform(item)) is not None:
                yield mapped

    @staticmethod
    async def async_transform_stream(
        iterator: AsyncIterator[Record],
        transform: Callable[[Record], Record | None]
    ) -> AsyncIterator[Record]:
        """Apply *transform* to each record, dropping None results (async)."""
        async for item in iterator:
            if (mapped := transform(item)) is not None:
                yield mapped

338 

339 

class StreamingMixin:
    """Mixin class providing common streaming functionality for sync databases."""

    def _default_stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> Iterator[Record]:
        """Default implementation of stream_read built on the search method.

        Provides a simple streaming wrapper around search() that most
        backends can use without modification.
        """
        config = config or StreamConfig()

        # Materialize all matching records via search; an absent query
        # falls back to an empty Query, i.e. everything.
        if query:
            matched = self.search(query)
        else:
            from .query import Query
            matched = self.search(Query())

        # Walk the results batch_size at a time for consistency with
        # genuinely streaming backends.
        step = config.batch_size
        for start in range(0, len(matched), step):
            yield from matched[start:start + step]

    def _default_stream_write(
        self,
        records: Iterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Default implementation of stream_write built on create_batch.

        Batches incoming records and writes each batch with graceful
        fallback to individual record creation when a batch fails.
        """
        config = config or StreamConfig()
        result = StreamResult()
        started = time.time()
        aborted = False
        batch_number = 0
        pending: list[Record] = []

        for record in records:
            pending.append(record)
            if len(pending) < config.batch_size:
                continue

            # Full batch accumulated: write it with graceful fallback.
            if not process_batch_with_fallback(
                pending,
                self.create_batch,
                self.create,
                result,
                config,
                batch_index=batch_number,
            ):
                aborted = True
                break

            pending = []
            batch_number += 1

        # Flush any partial final batch unless streaming was aborted.
        if pending and not aborted:
            process_batch_with_fallback(
                pending,
                self.create_batch,
                self.create,
                result,
                config,
                batch_index=batch_number,
            )

        result.duration = time.time() - started
        return result

420 

421 

class AsyncStreamingMixin:
    """Mixin class providing common streaming functionality for async databases."""

    async def _default_stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Default implementation of async stream_read built on the search method.

        Provides a simple streaming wrapper around search() that most
        backends can use without modification.
        """
        config = config or StreamConfig()

        # Materialize all matching records via search; an absent query
        # falls back to an empty Query, i.e. everything.
        if query:
            matched = await self.search(query)
        else:
            from .query import Query
            matched = await self.search(Query())

        # Walk the results batch_size at a time for consistency with
        # genuinely streaming backends.
        step = config.batch_size
        for start in range(0, len(matched), step):
            for item in matched[start:start + step]:
                yield item

    async def _default_stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Default implementation of async stream_write built on create_batch.

        Batches incoming records and writes each batch with graceful
        fallback to individual record creation when a batch fails.
        """
        config = config or StreamConfig()
        result = StreamResult()
        started = time.time()
        aborted = False
        batch_number = 0
        pending: list[Record] = []

        async for record in records:
            pending.append(record)
            if len(pending) < config.batch_size:
                continue

            # Full batch accumulated: write it with graceful fallback.
            if not await async_process_batch_with_fallback(
                pending,
                self.create_batch,
                self.create,
                result,
                config,
                batch_index=batch_number,
            ):
                aborted = True
                break

            pending = []
            batch_number += 1

        # Flush any partial final batch unless streaming was aborted.
        if pending and not aborted:
            await async_process_batch_with_fallback(
                pending,
                self.create_batch,
                self.create,
                result,
                config,
                batch_index=batch_number,
            )

        result.duration = time.time() - started
        return result