1"""Batch operations for DataKnobs-Pandas integration.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import logging 

7from dataclasses import dataclass 

8from typing import Any, cast, TYPE_CHECKING 

9 

10import pandas as pd 

11 

12from .converter import ConversionOptions, DataFrameConverter 

13 

14if TYPE_CHECKING: 

15 from collections.abc import Callable, Generator 

16 from dataknobs_data.database import AsyncDatabase, SyncDatabase 

17 from dataknobs_data.query import Query 

18 

19 

20logger = logging.getLogger(__name__) 

21 

22 

23@dataclass 

24class BatchConfig: 

25 """Configuration for batch operations.""" 

26 chunk_size: int = 1000 

27 parallel: bool = False 

28 max_workers: int = 4 

29 progress_callback: Callable[[int, int], None] | None = None 

30 error_handling: str = "raise" # "raise", "skip", "log" 

31 memory_efficient: bool = True 

32 

33 def __post_init__(self): 

34 """Validate configuration parameters.""" 

35 if self.chunk_size <= 0: 

36 raise ValueError("chunk_size must be greater than 0") 

37 

38 if self.error_handling not in ("raise", "skip", "log"): 

39 raise ValueError("error_handling must be one of: 'raise', 'skip', 'log'") 

40 

41 

42class ChunkedProcessor: 

43 """Process DataFrames in chunks for memory efficiency.""" 

44 

45 def __init__(self, chunk_size: int = 1000): 

46 """Initialize chunked processor. 

47  

48 Args: 

49 chunk_size: Size of each chunk 

50 """ 

51 self.chunk_size = chunk_size 

52 

53 def process_dataframe( 

54 self, 

55 df: pd.DataFrame, 

56 processor: Callable[[pd.DataFrame], Any], 

57 combine: Callable[[list[Any]], Any] | None = None 

58 ) -> Any: 

59 """Process DataFrame in chunks. 

60  

61 Args: 

62 df: DataFrame to process 

63 processor: Function to process each chunk 

64 combine: Function to combine results 

65  

66 Returns: 

67 Combined results or list of chunk results 
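
        Example (illustrative sketch; assumes ``df`` has a numeric
        ``"value"`` column)::

            processor = ChunkedProcessor(chunk_size=500)
            total = processor.process_dataframe(
                df,
                processor=lambda chunk: chunk["value"].sum(),
                combine=sum,
            )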

68 """ 

69 results = [] 

70 

71 for chunk in self.iter_chunks(df): 

72 result = processor(chunk) 

73 results.append(result) 

74 

75 if combine: 

76 return combine(results) 

77 return results 

78 

79 def iter_chunks(self, df: pd.DataFrame) -> Generator[pd.DataFrame, None, None]: 

80 """Iterate over DataFrame in chunks. 

81  

82 Args: 

83 df: DataFrame to chunk 

84  

85 Yields: 

86 DataFrame chunks 
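
        Example (illustrative)::

            for chunk in ChunkedProcessor(chunk_size=100).iter_chunks(df):
                print(len(chunk))  # at most 100 rows per chunk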

87 """ 

88 for start_idx in range(0, len(df), self.chunk_size): 

89 end_idx = min(start_idx + self.chunk_size, len(df)) 

90 yield df.iloc[start_idx:end_idx] 

91 

92 def read_csv_chunked( 

93 self, 

94 filepath: str, 

95 processor: Callable[[pd.DataFrame], Any], 

96 **read_kwargs 

97 ) -> list[Any]: 

98 """Read CSV file in chunks and process. 

99  

100 Args: 

101 filepath: Path to CSV file 

102 processor: Function to process each chunk 

103 **read_kwargs: Additional arguments for pd.read_csv 

104  

105 Returns: 

106 List of processed results 
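
        Example (illustrative; ``"data.csv"`` is a hypothetical path)::

            processor = ChunkedProcessor(chunk_size=10_000)
            row_counts = processor.read_csv_chunked("data.csv", processor=len)
            total_rows = sum(row_counts)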

107 """ 

108 results = [] 

109 

110 for chunk in pd.read_csv(filepath, chunksize=self.chunk_size, **read_kwargs): 

111 result = processor(chunk) 

112 results.append(result) 

113 

114 return results 

115 

116 

117class BatchOperations: 

118 """Batch operations for DataKnobs databases using DataFrames.""" 

119 

120 def __init__( 

121 self, 

122 database: AsyncDatabase | SyncDatabase, 

123 converter: DataFrameConverter | None = None 

124 ): 

125 """Initialize batch operations. 

126  

127 Args: 

128 database: Target database 

129 converter: DataFrame converter 
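
        Example (illustrative; assumes an existing ``SyncDatabase``
        instance ``db``)::

            ops = BatchOperations(db)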

130 """ 

131 self.database = database 

132 self.converter = converter or DataFrameConverter() 

133 self.is_async = hasattr(database, 'create') and asyncio.iscoroutinefunction(database.create) 

134 

135 def bulk_insert_dataframe( 

136 self, 

137 df: pd.DataFrame, 

138 config: BatchConfig | None = None, 

139 conversion_options: ConversionOptions | None = None 

140 ) -> dict[str, Any]: 

141 """Bulk insert DataFrame rows into database. 

142  

143 Args: 

144 df: DataFrame to insert 

145 config: Batch configuration 

146 conversion_options: Options for DataFrame conversion 

147  

148 Returns: 

149 Insert statistics 
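
        Example (illustrative; ``ops`` is a ``BatchOperations`` instance)::

            config = BatchConfig(
                chunk_size=500,
                error_handling="log",
                progress_callback=lambda done, total: print(f"{done}/{total}"),
            )
            stats = ops.bulk_insert_dataframe(df, config=config)
            print(stats["inserted"], stats["failed"])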

150 """ 

151 config = config or BatchConfig() 

152 conversion_options = conversion_options or ConversionOptions() 

153 # These are now guaranteed to be non-None 

154 assert config is not None 

155 assert conversion_options is not None 

156 

157 stats: dict[str, Any] = { 

158 "total_rows": len(df), 

159 "inserted": 0, 

160 "failed": 0, 

161 "errors": [] 

162 } 

163 

164 # Process in chunks if memory efficient mode 

165 if config.memory_efficient and len(df) > config.chunk_size: 

166 processor = ChunkedProcessor(config.chunk_size) 

167 # Create local references that are guaranteed non-None 

168 final_config = config 

169 final_conversion_options = conversion_options 

170 

171 def process_chunk(chunk_df: pd.DataFrame) -> dict[str, int]: 

172 return self._insert_chunk(chunk_df, final_config, final_conversion_options) 

173 

174 chunk_results = processor.process_dataframe(df, process_chunk) 

175 

176 # Aggregate results 

177 for result in chunk_results: 

178 stats["inserted"] += result["inserted"] 

179 stats["failed"] += result["failed"] 

180 if "errors" in result: 

181 stats["errors"].extend(result["errors"]) 

182 else: 

183 # Process entire DataFrame at once 

184 stats = self._insert_chunk(df, config, conversion_options) 

185 

186 return stats 

187 

188 def query_as_dataframe( 

189 self, 

190 query: Query, 

191 conversion_options: ConversionOptions | None = None 

192 ) -> pd.DataFrame: 

193 """Execute query and return results as DataFrame. 

194  

195 Args: 

196 query: Query to execute 

197 conversion_options: Options for conversion 

198  

199 Returns: 

200 Query results as DataFrame 
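
        Example (illustrative)::

            df = ops.query_as_dataframe(query)
            print(df.head())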

201 """ 

202 conversion_options = conversion_options or ConversionOptions() 

203 

204 # Execute query 

205 if self.is_async: 

206 records = asyncio.run(cast("AsyncDatabase", self.database).search(query)) 

207 else: 

208 records = cast("SyncDatabase", self.database).search(query) 

209 

210 # Convert to DataFrame 

211 return self.converter.records_to_dataframe(records, conversion_options) 

212 

213 def update_from_dataframe( 

214 self, 

215 df: pd.DataFrame, 

216 id_column: str | None, 

217 config: BatchConfig | None = None, 

218 conversion_options: ConversionOptions | None = None 

219 ) -> dict[str, Any]: 

220 """Update records from DataFrame using ID column. 

221  

222 Args: 

223 df: DataFrame with updates 

224 id_column: Column containing record IDs 

225 config: Batch configuration 

226 conversion_options: Conversion options 

227  

228 Returns: 

229 Update statistics 
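
        Example (illustrative; assumes ``df`` carries record IDs in an
        ``"id"`` column)::

            stats = ops.update_from_dataframe(df, id_column="id")
            print(stats["updated"], stats["not_found"])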

230 """ 

231 config = config or BatchConfig() 

232 conversion_options = conversion_options or ConversionOptions() 

233 

234 stats: dict[str, Any] = { 

235 "total_rows": len(df), 

236 "updated": 0, 

237 "failed": 0, 

238 "not_found": 0, 

239 "errors": [] 

240 } 

241 

242 # Convert DataFrame to records 

243 records = self.converter.dataframe_to_records(df, conversion_options) 

244 

245 # Prepare updates as (id, record) tuples 

246 updates = [] 

247 if id_column is None: 

248 # Use index as ID source 

249 for idx, record in zip(df.index, records, strict=True): 

250 record_id = str(idx) 

251 updates.append((record_id, record)) 

252 else: 

253 # Ensure ID column exists 

254 if id_column not in df.columns: 

255 raise ValueError(f"ID column '{id_column}' not found in DataFrame") 

256 # Use specified column as ID source 

257 for i, record in enumerate(records): 

258 record_id = str(df.iloc[i][id_column]) 

259 updates.append((record_id, record)) 

260 

261 # Process updates in chunks 

262 for i in range(0, len(updates), config.chunk_size): 

263 chunk = updates[i:i + config.chunk_size] 

264 

265 try: 

266 # Use batch update for better performance 

267 if self.is_async: 

268 results = asyncio.run(cast("AsyncDatabase", self.database).update_batch(chunk)) 

269 else: 

270 results = cast("SyncDatabase", self.database).update_batch(chunk) 

271 

272 # Count successes and failures 

273 for success in results: 

274 if success: 

275 stats["updated"] += 1 

276 else: 

277 stats["not_found"] += 1 

278 

279 except Exception: 

280 # If batch fails, try individual updates 

281 if config.error_handling == "raise": 

282 raise 

283 

284 for record_id, record in chunk: 

285 try: 

286 if self.is_async: 

287 success = asyncio.run(cast("AsyncDatabase", self.database).update(record_id, record)) 

288 else: 

289 success = cast("SyncDatabase", self.database).update(record_id, record) 

290 

291 if success: 

292 stats["updated"] += 1 

293 else: 

294 stats["not_found"] += 1 

295 

296 except Exception as e: 

297 stats["failed"] += 1 

298 if config.error_handling == "log": 

299 logger.error(f"Failed to update record {record_id}: {e}") 

300 stats["errors"].append(str(e)) 

301 # else "skip" 

302 

303 # Progress callback 

304 if config.progress_callback: 

305 processed = stats["updated"] + stats["failed"] + stats["not_found"] 

306 config.progress_callback(processed, len(updates)) 

307 

308 return stats 

309 

310 def aggregate( 

311 self, 

312 query: Query, 

313 aggregations: dict[str, str | Callable], 

314 group_by: list[str] | None = None 

315 ) -> pd.DataFrame: 

316 """Perform aggregations on query results. 

317  

318 Args: 

319 query: Query to execute 

320 aggregations: Dictionary of column: aggregation function 

321 group_by: Columns to group by 

322  

323 Returns: 

324 Aggregated DataFrame 
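
        Example (illustrative; assumes the results include ``"region"``
        and ``"amount"`` columns)::

            summary = ops.aggregate(
                query,
                aggregations={"amount": "sum"},
                group_by=["region"],
            )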

325 """ 

326 # Get data as DataFrame 

327 df = self.query_as_dataframe(query) 

328 

329 if df.empty: 

330 return pd.DataFrame() 

331 

332 # Perform aggregations 

333 if group_by: 

334 grouped = df.groupby(group_by) 

335 return grouped.agg(aggregations) 

336 else: 

337 # Single row with aggregations 

338 result = {} 

339 for col, agg_func in aggregations.items(): 

340 if col in df.columns: 

341 if isinstance(agg_func, str): 

342 result[f"{col}_{agg_func}"] = df[col].agg(agg_func) 

343 else: 

344 result[f"{col}_agg"] = agg_func(df[col]) 

345 return pd.DataFrame([result]) 

346 

347 def transform_and_save( 

348 self, 

349 query: Query, 

350 transformer: Callable[[pd.DataFrame], pd.DataFrame], 

351 config: BatchConfig | None = None 

352 ) -> dict[str, Any]: 

353 """Query, transform with pandas, and save back. 

354  

355 Args: 

356 query: Query to get records 

357 transformer: Function to transform DataFrame 

358 config: Batch configuration 

359  

360 Returns: 

361 Operation statistics 
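
        Example (illustrative; doubles a hypothetical ``"price"`` column)::

            def double_price(df: pd.DataFrame) -> pd.DataFrame:
                df = df.copy()
                df["price"] = df["price"] * 2
                return df

            stats = ops.transform_and_save(query, transformer=double_price)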

362 """ 

363 config = config or BatchConfig() 

364 

365 # Get data 

366 df = self.query_as_dataframe(query) 

367 

368 if df.empty: 

369 return {"total_rows": 0, "transformed": 0} 

370 

371 # Apply transformation 

372 transformed_df = transformer(df) 

373 

374 # Save back if index preserved (has record IDs) 

375 if df.index.name == "record_id" and transformed_df.index.name == "record_id": 

376 return self.update_from_dataframe( 

377 transformed_df, 

378 id_column=None, # Use index 

379 config=config 

380 ) 

381 else: 

382 # Insert as new records 

383 return self.bulk_insert_dataframe(transformed_df, config) 

384 

385 def _insert_chunk( 

386 self, 

387 df: pd.DataFrame, 

388 config: BatchConfig, 

389 conversion_options: ConversionOptions 

390 ) -> dict[str, Any]: 

391 """Insert a chunk of DataFrame rows. 

392  

393 Args: 

394 df: DataFrame chunk 

395 config: Batch configuration 

396 conversion_options: Conversion options 

397  

398 Returns: 

399 Insert statistics for chunk 

400 """ 

401 stats: dict[str, Any] = { 

402 "total_rows": len(df), 

403 "inserted": 0, 

404 "failed": 0, 

405 "errors": [] 

406 } 

407 

408 # Convert to records 

409 records = self.converter.dataframe_to_records(df, conversion_options) 

410 

411 # Use batch creation for better performance with graceful fallback 

412 if hasattr(self.database, 'create_batch'): 

413 try: 

414 if self.is_async: 

415 ids = asyncio.run(cast("AsyncDatabase", self.database).create_batch(records)) 

416 else: 

417 ids = cast("SyncDatabase", self.database).create_batch(records) 

418 stats["inserted"] = len(ids) 

419 

420 # Progress callback for successful batch 

421 if config.progress_callback: 

422 config.progress_callback(len(records), len(records)) 

423 

424 except Exception: 

425 # Batch failed, try individual records to identify failures 

426 for i, record in enumerate(records): 

427 try: 

428 if self.is_async: 

429 asyncio.run(cast("AsyncDatabase", self.database).create(record)) 

430 else: 

431 cast("SyncDatabase", self.database).create(record) 

432 stats["inserted"] += 1 

433 

434 except Exception as record_error: 

435 stats["failed"] += 1 

436 

437 # Handle error based on config 

438 if config.error_handling == "raise": 

439 raise 

440 elif config.error_handling == "log": 

441 logger.error(f"Failed to insert row {i}: {record_error}") 

442 stats["errors"].append(str(record_error)) 

443 # else "skip" - just continue 

444 

445 # Progress callback for each record 

446 if config.progress_callback: 

447 config.progress_callback(i + 1, len(records)) 

448 else: 

449 # Fallback to individual inserts if create_batch not available 

450 for i, record in enumerate(records): 

451 try: 

452 if self.is_async: 

453 asyncio.run(cast("AsyncDatabase", self.database).create(record)) 

454 else: 

455 cast("SyncDatabase", self.database).create(record) 

456 stats["inserted"] += 1 

457 

458 except Exception as e: 

459 stats["failed"] += 1 

460 if config.error_handling == "raise": 

461 raise 

462 elif config.error_handling == "log": 

463 logger.error(f"Failed to insert row {i}: {e}") 

464 stats["errors"].append(str(e)) 

465 # else "skip" 

466 

467 # Progress callback 

468 if config.progress_callback: 

469 config.progress_callback(i + 1, len(records)) 

470 

471 return stats 

472 

473 def export_to_csv( 

474 self, 

475 query: Query, 

476 filepath: str, 

477 conversion_options: ConversionOptions | None = None, 

478 **to_csv_kwargs 

479 ) -> None: 

480 """Export query results to CSV file. 

481  

482 Args: 

483 query: Query to execute 

484 filepath: Output file path 

485 conversion_options: Conversion options 

486 **to_csv_kwargs: Additional arguments for DataFrame.to_csv 
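
        Example (illustrative; ``"out.csv"`` is a hypothetical path)::

            ops.export_to_csv(query, "out.csv", index=False)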

487 """ 

488 df = self.query_as_dataframe(query, conversion_options) 

489 df.to_csv(filepath, **to_csv_kwargs) 

490 

491 def export_to_parquet( 

492 self, 

493 query: Query, 

494 filepath: str, 

495 conversion_options: ConversionOptions | None = None, 

496 **to_parquet_kwargs 

497 ) -> None: 

498 """Export query results to Parquet file. 

499  

500 Args: 

501 query: Query to execute 

502 filepath: Output file path 

503 conversion_options: Conversion options 

504 **to_parquet_kwargs: Additional arguments for DataFrame.to_parquet 
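
        Example (illustrative; requires a parquet engine such as pyarrow)::

            ops.export_to_parquet(query, "out.parquet", compression="snappy")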

505 """ 

506 df = self.query_as_dataframe(query, conversion_options) 

507 df.to_parquet(filepath, **to_parquet_kwargs) 
