batchata
Batchata - Unified Python API for AI Batch requests with cost tracking, Pydantic responses, and parallel execution.
Why AI-batching?
AI providers offer batch APIs that process requests asynchronously at 50% reduced cost compared to real-time APIs. This is ideal for workloads like document processing, data analysis, and content generation where immediate responses aren't required.
Quick Start
Installation
pip install batchata
Basic Usage
from batchata import Batch

# Simple batch processing; the parentheses let the chained calls span several lines
batch = (
    Batch(results_dir="./output")
    .set_default_params(model="claude-sonnet-4-20250514")
    .add_cost_limit(usd=5.0)
)

# Add jobs (files is an iterable of document paths that you supply)
for file in files:
    batch.add_job(file=file, prompt="Summarize this document")

# Execute
run = batch.run()
results = run.results()
Structured Output with Pydantic
from batchata import Batch
from pydantic import BaseModel


class DocumentAnalysis(BaseModel):
    title: str
    summary: str
    key_points: list[str]


# The parentheses let the chained call continue on the next line
batch = (
    Batch(results_dir="./results")
    .set_default_params(model="claude-sonnet-4-20250514")
)

batch.add_job(
    file="document.pdf",
    prompt="Analyze this document",
    response_model=DocumentAnalysis,
    enable_citations=True,  # Anthropic only
)

run = batch.run()
for result in run.results()["completed"]:
    analysis = result.parsed_response  # DocumentAnalysis object
    citations = result.citation_mappings  # Field -> Citation mapping
Key Features
- 50% Cost Savings: Native batch processing via provider APIs
- Cost Limits: Set max_cost_usd limits for batch requests
- Time Limits: Control execution time with .add_time_limit()
- State Persistence: Resume interrupted batches automatically
- Structured Output: Pydantic models with automatic validation
- Citations: Extract and map citations to response fields (Anthropic)
- Multiple Providers: Anthropic Claude and OpenAI GPT models
Supported Providers
| Feature | Anthropic | OpenAI |
|---|---|---|
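| Models | [All Claude models](https://github.com/agamm/batchata/blob/main/batchata/providers/anthropic/models.py) | [All GPT models](https://github.com/agamm/batchata/blob/main/batchata/providers/openai/models.py) |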
| Citations | ✅ | ❌ |
| Structured Output | ✅ | ✅ |
| File Types | PDF, TXT, DOCX, Images | PDF, Images |
Configuration
Set API keys as environment variables:
export ANTHROPIC_API_KEY="your-key"
export OPENAI_API_KEY="your-key"
Or use a .env file with python-dotenv.
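Before submitting a batch, it can help to confirm that the key for the provider you plan to use is actually present in the environment. The following is a minimal sketch using only the standard library; `missing_api_keys` is a hypothetical helper written for illustration and is not part of batchata.

```python
import os

# Environment variable names taken from the configuration section above.
REQUIRED_KEYS = ("ANTHROPIC_API_KEY", "OPENAI_API_KEY")


def missing_api_keys(env=None, required=REQUIRED_KEYS):
    """Return the names of required keys that are unset or blank.

    Hypothetical helper for illustration; not part of the batchata API.
    """
    env = os.environ if env is None else env
    return [name for name in required if not env.get(name, "").strip()]


if __name__ == "__main__":
    missing = missing_api_keys()
    if missing:
        print("Missing API keys:", ", ".join(missing))
    else:
        print("All API keys are set")
```

If you only use one provider, pass just that provider's variable name via `required`.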
from .core import Batch, BatchRun, Job, JobResult
from .exceptions import (
    BatchataError,
    CostLimitExceededError,
    ProviderError,
    ProviderNotFoundError,
    ValidationError,
)
from .types import Citation

__version__ = "0.3.0"

__all__ = [
    "Batch",
    "BatchRun",
    "Job",
    "JobResult",
    "Citation",
    "BatchataError",
    "CostLimitExceededError",
    "ProviderError",
    "ProviderNotFoundError",
    "ValidationError",
]
class Batch:
    """Builder for batch job configuration.

    Provides a fluent interface for configuring batch jobs with sensible defaults
    and validation. The batch can be configured with cost limits, default parameters,
    and progress callbacks.

    Example:
        ```python
        batch = (
            Batch("./results", max_parallel_batches=10, items_per_batch=10)
            .set_state(file="./state.json", reuse_state=True)
            .set_default_params(model="claude-sonnet-4-20250514", temperature=0.7)
            .add_cost_limit(usd=15.0)
            .add_job(messages=[{"role": "user", "content": "Hello"}])
            .add_job(file="./path/to/file.pdf", prompt="Generate summary of file")
        )

        run = batch.run()
        ```
    """

    def __init__(self, results_dir: str, max_parallel_batches: int = 10, items_per_batch: int = 10, raw_files: Optional[bool] = None):
        """Initialize batch configuration.

        Args:
            results_dir: Directory to store results
            max_parallel_batches: Maximum parallel batch requests
            items_per_batch: Number of jobs per provider batch
            raw_files: Whether to save debug files (raw responses, JSONL files) from providers (default: True if results_dir is set, False otherwise)
        """
        # Auto-determine raw_files based on results_dir if not explicitly set
        if raw_files is None:
            raw_files = bool(results_dir and results_dir.strip())

        self.config = BatchParams(
            state_file=None,
            results_dir=results_dir,
            max_parallel_batches=max_parallel_batches,
            items_per_batch=items_per_batch,
            reuse_state=True,
            raw_files=raw_files
        )
        self.jobs: List[Job] = []

    def set_default_params(self, **kwargs) -> 'Batch':
        """Set default parameters for all jobs.

        These defaults will be applied to all jobs unless overridden
        by job-specific parameters.

        Args:
            **kwargs: Default parameters (model, temperature, max_tokens, etc.)

        Returns:
            Self for chaining

        Example:
            ```python
            batch.set_default_params(model="claude-3-sonnet", temperature=0.7)
            ```
        """
        # Validate if model is provided
        if "model" in kwargs:
            self.config.validate_default_params(kwargs["model"])

        self.config.default_params.update(kwargs)
        return self

    def set_state(self, file: Optional[str] = None, reuse_state: bool = True) -> 'Batch':
        """Set state file configuration.

        Args:
            file: Path to state file for persistence (default: None)
            reuse_state: Whether to resume from existing state file (default: True)

        Returns:
            Self for chaining

        Example:
            ```python
            batch.set_state(file="./state.json", reuse_state=True)
            ```
        """
        self.config.state_file = file
        self.config.reuse_state = reuse_state
        return self

    def add_cost_limit(self, usd: float) -> 'Batch':
        """Add cost limit for the batch.

        The batch will stop accepting new jobs once the cost limit is reached.
        Active jobs will be allowed to complete.

        Args:
            usd: Cost limit in USD

        Returns:
            Self for chaining

        Example:
            ```python
            batch.add_cost_limit(usd=50.0)
            ```
        """
        if usd <= 0:
            raise ValueError("Cost limit must be positive")
        self.config.cost_limit_usd = usd
        return self

    def raw_files(self, enabled: bool = True) -> 'Batch':
        """Enable or disable saving debug files from providers.

        When enabled, debug files (raw API responses, JSONL files) will be saved
        in a 'raw_files' subdirectory within the results directory.
        This is useful for debugging, auditing, or accessing provider-specific metadata.

        Args:
            enabled: Whether to save debug files (default: True)

        Returns:
            Self for chaining

        Example:
            ```python
            batch.raw_files(True)
            ```
        """
        self.config.raw_files = enabled
        return self

    def set_verbosity(self, level: str) -> 'Batch':
        """Set logging verbosity level.

        Args:
            level: Verbosity level ("debug", "info", "warn", "error")

        Returns:
            Self for chaining

        Example:
            ```python
            batch.set_verbosity("error")  # For production
            batch.set_verbosity("debug")  # For debugging
            ```
        """
        valid_levels = {"debug", "info", "warn", "error"}
        if level.lower() not in valid_levels:
            raise ValueError(f"Invalid verbosity level: {level}. Must be one of {valid_levels}")
        self.config.verbosity = level.lower()
        return self

    def add_time_limit(self, seconds: Optional[float] = None, minutes: Optional[float] = None, hours: Optional[float] = None) -> 'Batch':
        """Add time limit for the entire batch execution.

        When time limit is reached, all active provider batches are cancelled and
        remaining unprocessed jobs are marked as failed. The batch execution
        completes normally without throwing exceptions.

        Args:
            seconds: Time limit in seconds (optional)
            minutes: Time limit in minutes (optional)
            hours: Time limit in hours (optional)

        Returns:
            Self for chaining

        Raises:
            ValueError: If no time units specified, or if total time is outside
                valid range (min: 10 seconds, max: 24 hours)

        Note:
            - Can combine multiple time units
            - Time limit is checked every second by a background watchdog thread
            - Jobs that exceed time limit appear in results()["failed"] with time limit error message
            - No exceptions are thrown when time limit is reached

        Example:
            ```python
            batch.add_time_limit(seconds=30)  # 30 seconds
            batch.add_time_limit(minutes=5)  # 5 minutes
            batch.add_time_limit(hours=2)  # 2 hours
            batch.add_time_limit(hours=1, minutes=30, seconds=15)  # 5415 seconds total
            ```
        """
        time_limit_seconds = 0.0

        if seconds is not None:
            time_limit_seconds += seconds
        if minutes is not None:
            time_limit_seconds += minutes * 60
        if hours is not None:
            time_limit_seconds += hours * 3600

        if time_limit_seconds == 0:
            raise ValueError("Must specify at least one of seconds, minutes, or hours")

        self.config.time_limit_seconds = time_limit_seconds
        return self

    def add_job(
        self,
        messages: Optional[List[Message]] = None,
        file: Optional[Union[str, Path]] = None,
        prompt: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        response_model: Optional[Type[BaseModel]] = None,
        enable_citations: bool = False,
        **kwargs
    ) -> 'Batch':
        """Add a job to the batch.

        Either provide messages OR file+prompt, not both. Parameters not provided
        will use the defaults set via set_default_params().

        Args:
            messages: Chat messages for direct input
            file: File path for file-based input
            prompt: Prompt to use with file input
            model: Model to use (overrides default)
            temperature: Sampling temperature (overrides default)
            max_tokens: Max tokens to generate (overrides default)
            response_model: Pydantic model for structured output
            enable_citations: Whether to extract citations
            **kwargs: Additional parameters

        Returns:
            Self for chaining

        Example:
            ```python
            batch.add_job(
                messages=[{"role": "user", "content": "Hello"}],
                model="gpt-4"
            )
            ```
        """
        # Generate unique job ID
        job_id = f"job-{uuid.uuid4().hex[:8]}"

        # Merge with defaults
        params = self.config.default_params.copy()

        # Update with provided parameters
        if model is not None:
            params["model"] = model
        if temperature is not None:
            params["temperature"] = temperature
        if max_tokens is not None:
            params["max_tokens"] = max_tokens

        # Add other kwargs
        params.update(kwargs)

        # Ensure model is provided
        if "model" not in params:
            raise ValueError("Model must be provided either in defaults or job parameters")

        # Validate parameters
        provider = get_provider(params["model"])
        # Extract params without model to avoid duplicate
        param_subset = {k: v for k, v in params.items() if k != "model"}
        provider.validate_params(params["model"], **param_subset)

        # Convert file path if string
        if isinstance(file, str):
            file = Path(file)

        # Warn about temporary file paths that may not persist
        if file:
            file_str = str(file)
            if "/tmp/" in file_str or "/var/folders/" in file_str or "temp" in file_str.lower():
                logger = logging.getLogger("batchata")
                logger.debug(f"File path appears to be in a temporary directory: {file}")
                logger.debug("This may cause issues when resuming from state if temp files are cleaned up")

        # Create job
        job = Job(
            id=job_id,
            messages=messages,
            file=file,
            prompt=prompt,
            response_model=response_model,
            enable_citations=enable_citations,
            **params
        )

        # Validate citation compatibility
        if response_model and enable_citations:
            from ..utils.validation import validate_flat_model
            validate_flat_model(response_model)

        # Validate job with provider (includes PDF validation for Anthropic)
        provider.validate_job(job)

        self.jobs.append(job)
        return self

    def run(self, on_progress: Optional[Callable[[Dict, float, Dict], None]] = None, progress_interval: float = 1.0, print_status: bool = False, dry_run: bool = False) -> 'BatchRun':
        """Execute the batch.

        Creates a BatchRun instance and executes the jobs synchronously.

        Args:
            on_progress: Optional progress callback function that receives
                (stats_dict, elapsed_time_seconds, batch_data)
            progress_interval: Interval in seconds between progress updates (default: 1.0)
            print_status: Whether to show rich progress display (default: False)
            dry_run: If True, only show cost estimation without executing (default: False)

        Returns:
            BatchRun instance with completed results

        Raises:
            ValueError: If no jobs have been added
        """
        if not self.jobs:
            raise ValueError("No jobs added to batch")

        # Import here to avoid circular dependency
        from .batch_run import BatchRun

        # Create and start the run
        run = BatchRun(self.config, self.jobs)

        # Handle dry run mode
        if dry_run:
            return run.dry_run()

        # Set progress callback - either rich display or custom callback
        if print_status:
            return self._run_with_rich_display(run, progress_interval)
        else:
            return self._run_with_custom_callback(run, on_progress, progress_interval)

    def _run_with_rich_display(self, run: 'BatchRun', progress_interval: float) -> 'BatchRun':
        """Execute batch run with rich progress display.

        Args:
            run: BatchRun instance to execute
            progress_interval: Interval between progress updates

        Returns:
            Completed BatchRun instance
        """
        from ..utils.rich_progress import RichBatchProgressDisplay
        display = RichBatchProgressDisplay()

        def rich_progress_callback(stats, elapsed_time, batch_data):
            # Start display on first call
            if not hasattr(rich_progress_callback, '_started'):
                config_dict = {
                    'results_dir': self.config.results_dir,
                    'state_file': self.config.state_file,
                    'items_per_batch': self.config.items_per_batch,
                    'max_parallel_batches': self.config.max_parallel_batches
                }
                display.start(stats, config_dict)
                rich_progress_callback._started = True

            # Update display
            display.update(stats, batch_data, elapsed_time)

        run.set_on_progress(rich_progress_callback, interval=progress_interval)

        # Execute with proper cleanup
        try:
            run.execute()

            # Show final status with all batches completed
            stats = run.status()
            display.update(stats, run.batch_tracking, (datetime.now() - run._start_time).total_seconds())

            # Small delay to ensure display updates
            import time
            time.sleep(0.2)

        except KeyboardInterrupt:
            # Update batch tracking to show cancelled status for pending/running batches
            with run._state_lock:
                for batch_id, batch_info in run.batch_tracking.items():
                    if batch_info['status'] == 'running':
                        batch_info['status'] = 'cancelled'
                    elif batch_info['status'] == 'pending':
                        batch_info['status'] = 'cancelled'

            # Show final status with cancelled batches
            stats = run.status()
            display.update(stats, run.batch_tracking, 0.0)

            # Add a small delay to ensure the display updates
            import time
            time.sleep(0.1)

            display.stop()
            raise
        finally:
            if display.live:  # Only stop if not already stopped
                display.stop()

        return run

    def _run_with_custom_callback(self, run: 'BatchRun', on_progress: Optional[Callable[[Dict, float, Dict], None]], progress_interval: float) -> 'BatchRun':
        """Execute batch run with custom progress callback.

        Args:
            run: BatchRun instance to execute
            on_progress: Optional custom progress callback
            progress_interval: Interval between progress updates

        Returns:
            Completed BatchRun instance
        """
        # Use custom progress callback if provided
        if on_progress:
            run.set_on_progress(on_progress, interval=progress_interval)

        run.execute()
        return run

    def __len__(self) -> int:
        """Get the number of jobs in the batch."""
        return len(self.jobs)

    def __repr__(self) -> str:
        """String representation of the batch."""
        return (
            f"Batch(jobs={len(self.jobs)}, "
            f"max_parallel_batches={self.config.max_parallel_batches}, "
            f"cost_limit=${self.config.cost_limit_usd or 'None'})"
        )
Builder for batch job configuration.
Provides a fluent interface for configuring batch jobs with sensible defaults and validation. The batch can be configured with cost limits, default parameters, and progress callbacks.
Example:
batch = (
    Batch("./results", max_parallel_batches=10, items_per_batch=10)
    .set_state(file="./state.json", reuse_state=True)
    .set_default_params(model="claude-sonnet-4-20250514", temperature=0.7)
    .add_cost_limit(usd=15.0)
    .add_job(messages=[{"role": "user", "content": "Hello"}])
    .add_job(file="./path/to/file.pdf", prompt="Generate summary of file")
)

run = batch.run()
def __init__(self, results_dir: str, max_parallel_batches: int = 10, items_per_batch: int = 10, raw_files: Optional[bool] = None):
Initialize batch configuration.
Args:
    results_dir: Directory to store results
    max_parallel_batches: Maximum parallel batch requests
    items_per_batch: Number of jobs per provider batch
    raw_files: Whether to save debug files (raw responses, JSONL files) from providers (default: True if results_dir is set, False otherwise)
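One way to picture `items_per_batch` is as a chunk size: the job list is cut into consecutive groups, and each group is submitted as one provider batch. The sketch below shows only that idea; `chunk_jobs` is a hypothetical function, and the real grouping inside batchata may take other factors into account (for example provider or model), which this source does not show.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def chunk_jobs(jobs: Sequence[T], items_per_batch: int) -> List[List[T]]:
    """Split jobs into consecutive groups of at most items_per_batch items.

    Hypothetical illustration; not the batchata implementation.
    """
    if items_per_batch <= 0:
        raise ValueError("items_per_batch must be positive")
    return [
        list(jobs[start:start + items_per_batch])
        for start in range(0, len(jobs), items_per_batch)
    ]


if __name__ == "__main__":
    groups = chunk_jobs([f"job-{n}" for n in range(25)], items_per_batch=10)
    print([len(group) for group in groups])  # [10, 10, 5]
```

With 25 jobs and `items_per_batch=10` this yields three groups, so a `max_parallel_batches` of 10 would leave room for all of them to be in flight at once.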
def set_default_params(self, **kwargs) -> 'Batch':
Set default parameters for all jobs.
These defaults will be applied to all jobs unless overridden by job-specific parameters.
Args: **kwargs: Default parameters (model, temperature, max_tokens, etc.)
Returns: Self for chaining
Example:
batch.set_default_params(model="claude-3-sonnet", temperature=0.7)
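The precedence rule (job-specific values win over defaults) can be sketched with plain dictionaries. This follows the merging order visible in the `add_job` source, but `merge_job_params` itself is a hypothetical standalone function, not part of the batchata API.

```python
from typing import Any, Dict, Optional


def merge_job_params(
    defaults: Dict[str, Any],
    model: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Combine batch-level defaults with per-job overrides.

    Hypothetical illustration of the precedence rule.
    """
    params = dict(defaults)  # copy so the shared defaults are never mutated
    overrides = (("model", model), ("temperature", temperature), ("max_tokens", max_tokens))
    for name, value in overrides:
        if value is not None:  # None means "not given", so the default stays
            params[name] = value
    params.update(kwargs)  # any extra keyword arguments are applied last
    if "model" not in params:
        raise ValueError("Model must be provided either in defaults or job parameters")
    return params


if __name__ == "__main__":
    defaults = {"model": "claude-sonnet-4-20250514", "temperature": 0.7}
    print(merge_job_params(defaults, temperature=0.2))
```

Note that passing `temperature=None` cannot be used to clear a default in this scheme, since `None` is treated as "not provided".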
def set_state(self, file: Optional[str] = None, reuse_state: bool = True) -> 'Batch':
Set state file configuration.
Args:
    file: Path to state file for persistence (default: None)
    reuse_state: Whether to resume from existing state file (default: True)
Returns: Self for chaining
Example:
batch.set_state(file="./state.json", reuse_state=True)
def add_cost_limit(self, usd: float) -> 'Batch':
Add cost limit for the batch.
The batch will stop accepting new jobs once the cost limit is reached. Active jobs will be allowed to complete.
Args: usd: Cost limit in USD
Returns: Self for chaining
Example:
batch.add_cost_limit(usd=50.0)
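The behaviour described above (no new work once the limit is reached, while work already in flight still finishes and is still billed) can be pictured as a simple admission gate. `CostGate` below is a hypothetical toy class for illustration; batchata's own cost tracker is not shown in this source and may work differently, for instance by reserving estimated costs up front.

```python
from typing import Optional


class CostGate:
    """Toy admission check: refuse new work once spending reaches the limit.

    Hypothetical illustration; not the batchata cost tracker.
    """

    def __init__(self, limit_usd: Optional[float] = None):
        if limit_usd is not None and limit_usd <= 0:
            raise ValueError("Cost limit must be positive")
        self.limit_usd = limit_usd
        self.spent_usd = 0.0

    def can_start_new_work(self) -> bool:
        # With no limit configured, new work is always admitted.
        return self.limit_usd is None or self.spent_usd < self.limit_usd

    def record(self, cost_usd: float) -> None:
        # Costs of work that was already running are recorded even past the limit.
        self.spent_usd += cost_usd


if __name__ == "__main__":
    gate = CostGate(limit_usd=5.0)
    gate.record(3.0)
    print(gate.can_start_new_work())  # True: 3.0 < 5.0
    gate.record(2.5)
    print(gate.can_start_new_work())  # False: 5.5 >= 5.0
```

A consequence of this design is that total spend can overshoot the limit by the cost of whatever was active when the limit was crossed.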
def raw_files(self, enabled: bool = True) -> 'Batch':
Enable or disable saving debug files from providers.
When enabled, debug files (raw API responses, JSONL files) will be saved in a 'raw_files' subdirectory within the results directory. This is useful for debugging, auditing, or accessing provider-specific metadata.
Args: enabled: Whether to save debug files (default: True)
Returns: Self for chaining
Example:
batch.raw_files(True)
def set_verbosity(self, level: str) -> 'Batch':
Set logging verbosity level.
Args: level: Verbosity level ("debug", "info", "warn", "error")
Returns: Self for chaining
Example:
batch.set_verbosity("error") # For production
batch.set_verbosity("debug") # For debugging
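For readers who want to line these names up with Python's standard `logging` module, the mapping below is one plausible correspondence. It is a sketch under that assumption: `to_logging_level` is a hypothetical helper, and how batchata configures its logger internally is not shown here.

```python
import logging

# Assumed correspondence between the documented names and stdlib levels.
_LEVELS = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warn": logging.WARNING,
    "error": logging.ERROR,
}


def to_logging_level(level: str) -> int:
    """Translate a verbosity name (case-insensitive) to a logging level number."""
    key = level.lower()
    if key not in _LEVELS:
        raise ValueError(f"Invalid verbosity level: {level}. Must be one of {sorted(_LEVELS)}")
    return _LEVELS[key]


if __name__ == "__main__":
    logging.basicConfig(level=to_logging_level("warn"))
    logging.getLogger("example").warning("shown at warn verbosity")
    logging.getLogger("example").info("hidden at warn verbosity")
```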
def add_time_limit(self, seconds: Optional[float] = None, minutes: Optional[float] = None, hours: Optional[float] = None) -> 'Batch':
Add time limit for the entire batch execution.
When time limit is reached, all active provider batches are cancelled and remaining unprocessed jobs are marked as failed. The batch execution completes normally without throwing exceptions.
Args:
    seconds: Time limit in seconds (optional)
    minutes: Time limit in minutes (optional)
    hours: Time limit in hours (optional)
Returns: Self for chaining
Raises: ValueError: If no time units specified, or if total time is outside valid range (min: 10 seconds, max: 24 hours)
Note:
- Can combine multiple time units
- Time limit is checked every second by a background watchdog thread
- Jobs that exceed time limit appear in results()["failed"] with time limit error message
- No exceptions are thrown when time limit is reached
Example:
batch.add_time_limit(seconds=30) # 30 seconds
batch.add_time_limit(minutes=5) # 5 minutes
batch.add_time_limit(hours=2) # 2 hours
batch.add_time_limit(hours=1, minutes=30, seconds=15) # 5415 seconds total
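The unit arithmetic and the documented valid range (10 seconds to 24 hours) can be combined in one small function. This is a sketch: `total_time_limit_seconds` is a hypothetical name, and the range check is placed here only for illustration, since the source shown does not reveal where batchata enforces it.

```python
from typing import Optional

MIN_SECONDS = 10.0          # documented minimum
MAX_SECONDS = 24 * 3600.0   # documented maximum (24 hours)


def total_time_limit_seconds(
    seconds: Optional[float] = None,
    minutes: Optional[float] = None,
    hours: Optional[float] = None,
) -> float:
    """Sum the given units into seconds and check the documented range."""
    total = 0.0
    if seconds is not None:
        total += seconds
    if minutes is not None:
        total += minutes * 60
    if hours is not None:
        total += hours * 3600
    if total == 0:
        raise ValueError("Must specify at least one of seconds, minutes, or hours")
    if not MIN_SECONDS <= total <= MAX_SECONDS:
        raise ValueError(f"Time limit must be between {MIN_SECONDS:.0f} and {MAX_SECONDS:.0f} seconds")
    return total


if __name__ == "__main__":
    # 1 * 3600 + 30 * 60 + 15 = 5415
    print(total_time_limit_seconds(hours=1, minutes=30, seconds=15))
```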
def add_job(
    self,
    messages: Optional[List[Message]] = None,
    file: Optional[Union[str, Path]] = None,
    prompt: Optional[str] = None,
    model: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    response_model: Optional[Type[BaseModel]] = None,
    enable_citations: bool = False,
    **kwargs
) -> 'Batch':
Add a job to the batch.
Either provide messages OR file+prompt, not both. Parameters not provided will use the defaults set via set_default_params().
Args:
    messages: Chat messages for direct input
    file: File path for file-based input
    prompt: Prompt to use with file input
    model: Model to use (overrides default)
    temperature: Sampling temperature (overrides default)
    max_tokens: Max tokens to generate (overrides default)
    response_model: Pydantic model for structured output
    enable_citations: Whether to extract citations
    **kwargs: Additional parameters
Returns: Self for chaining
Example:
batch.add_job(
    messages=[{"role": "user", "content": "Hello"}],
    model="gpt-4"
)
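The "messages OR file+prompt, not both" rule can be expressed as a small input check. The sketch below is hypothetical: `check_job_input` is not a batchata function, and it assumes that file input requires both `file` and `prompt`, which the docstring implies but the source shown does not confirm.

```python
from pathlib import Path
from typing import Dict, List, Optional, Union


def check_job_input(
    messages: Optional[List[Dict[str, str]]] = None,
    file: Optional[Union[str, Path]] = None,
    prompt: Optional[str] = None,
) -> str:
    """Return "messages" or "file" depending on which input style was used.

    Hypothetical illustration of the documented rule.
    """
    has_messages = messages is not None
    has_file_input = file is not None or prompt is not None
    if has_messages and has_file_input:
        raise ValueError("Provide either messages or file+prompt, not both")
    if has_messages:
        return "messages"
    if file is not None and prompt is not None:
        return "file"
    raise ValueError("Provide messages, or both file and prompt")


if __name__ == "__main__":
    print(check_job_input(messages=[{"role": "user", "content": "Hello"}]))
    print(check_job_input(file="document.pdf", prompt="Summarize this document"))
```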
def run(self, on_progress: Optional[Callable[[Dict, float, Dict], None]] = None, progress_interval: float = 1.0, print_status: bool = False, dry_run: bool = False) -> 'BatchRun':
Execute the batch.
Creates a BatchRun instance and executes the jobs synchronously.
Args:
    on_progress: Optional progress callback function that receives (stats_dict, elapsed_time_seconds, batch_data)
    progress_interval: Interval in seconds between progress updates (default: 1.0)
    print_status: Whether to show rich progress display (default: False)
    dry_run: If True, only show cost estimation without executing (default: False)
Returns: BatchRun instance with completed results
Raises: ValueError: If no jobs have been added
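A progress callback matching the documented `(stats_dict, elapsed_time_seconds, batch_data)` signature might look like the sketch below. The function names are illustrative, and the stats keys used here (`total`, `completed`, `failed`, `cancelled`, `cost_usd`) are taken from the dictionary built by `BatchRun.status()`.

```python
from typing import Dict


def format_progress(stats: Dict, elapsed_seconds: float, batch_data: Dict) -> str:
    """Build a one-line progress summary from the callback arguments."""
    # A job is finished once it has completed, failed, or been cancelled.
    finished = stats["completed"] + stats["failed"] + stats["cancelled"]
    return (
        f"{finished}/{stats['total']} jobs finished | "
        f"cost ${stats['cost_usd']:.4f} | "
        f"batches tracked: {len(batch_data)} | "
        f"{elapsed_seconds:.1f}s"
    )


def on_progress(stats: Dict, elapsed_seconds: float, batch_data: Dict) -> None:
    """Callback with the three-argument shape expected by run(on_progress=...)."""
    print(format_progress(stats, elapsed_seconds, batch_data))
```

It would then be passed as `batch.run(on_progress=on_progress, progress_interval=5.0)`. Keeping the formatting in a separate pure function makes the callback easy to test without running a batch.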
class BatchRun:
    """Manages the execution of a batch job synchronously.

    Processes jobs in batches based on items_per_batch configuration.
    Simpler synchronous execution for clear logging and debugging.

    Example:
        ```python
        config = BatchParams(...)
        run = BatchRun(config, jobs)
        run.execute()
        results = run.results()
        ```
    """

    def __init__(self, config: BatchParams, jobs: List[Job]):
        """Initialize batch run.

        Args:
            config: Batch configuration
            jobs: List of jobs to execute
        """
        self.config = config
        self.jobs = {job.id: job for job in jobs}

        # Set logging level based on config
        set_log_level(level=config.verbosity.upper())

        # Initialize components
        self.cost_tracker = CostTracker(limit_usd=config.cost_limit_usd)

        # Use temp file for state if not provided
        state_file = config.state_file
        if not state_file:
            state_file = create_temp_state_file(config)
            config.reuse_state = False
            logger.info(f"Created temporary state file: {state_file}")

        self.state_manager = StateManager(state_file)

        # State tracking
        self.pending_jobs: List[Job] = []
        self.completed_results: Dict[str, JobResult] = {}  # job_id -> result
        self.failed_jobs: Dict[str, str] = {}  # job_id -> error
        self.cancelled_jobs: Dict[str, str] = {}  # job_id -> reason

        # Batch tracking
        self.total_batches = 0
        self.completed_batches = 0
        self.current_batch_index = 0
        self.current_batch_size = 0

        # Execution control
        self._started = False
        self._start_time: Optional[datetime] = None
        self._time_limit_exceeded = False
        self._progress_callback: Optional[Callable[[Dict, float], None]] = None
        self._progress_interval: float = 1.0  # Default to 1 second

        # Threading primitives
        self._state_lock = threading.Lock()
        self._shutdown_event = threading.Event()

        # Batch tracking for progress display
        self.batch_tracking: Dict[str, Dict] = {}  # batch_id -> batch_info

        # Active batch tracking for cancellation
        self._active_batches: Dict[str, object] = {}  # batch_id -> provider
        self._active_batches_lock = threading.Lock()

        # Results directory
        self.results_dir = Path(config.results_dir)

        # If not reusing state, clear the results directory
        if not config.reuse_state and self.results_dir.exists():
            import shutil
            shutil.rmtree(self.results_dir)

        self.results_dir.mkdir(parents=True, exist_ok=True)

        # Raw files directory (if enabled)
        self.raw_files_dir = None
        if config.raw_files:
            self.raw_files_dir = self.results_dir / "raw_files"
            self.raw_files_dir.mkdir(parents=True, exist_ok=True)

        # Try to resume from saved state
        self._resume_from_state()


    def _resume_from_state(self):
        """Resume from saved state if available."""
        # Check if we should reuse state
        if not self.config.reuse_state:
            # Clear any existing state and start fresh
            self.state_manager.clear()
            self.pending_jobs = list(self.jobs.values())
            return

        state = self.state_manager.load()
        if state is None:
            # No saved state, use jobs passed to constructor
            self.pending_jobs = list(self.jobs.values())
            return

        logger.info("Resuming batch run from saved state")

        # Restore pending jobs
        self.pending_jobs = []
        for job_data in state.pending_jobs:
            job = Job.from_dict(job_data)
            # Check if file exists (if job has a file)
            if job.file and not job.file.exists():
                logger.error(f"File not found for job {job.id}: {job.file}")
                logger.error("This may happen if files were in temporary directories that were cleaned up")
                self.failed_jobs[job.id] = f"File not found: {job.file}"
            else:
                self.pending_jobs.append(job)

        # Restore completed results from file references
        for result_ref in state.completed_results:
            job_id = result_ref["job_id"]
            file_path = result_ref["file_path"]
            try:
                with open(file_path, 'r') as f:
                    result_data = json.load(f)
                result = JobResult.from_dict(result_data)
                self.completed_results[job_id] = result
            except Exception as e:
                logger.error(f"Failed to load result for {job_id} from {file_path}: {e}")
                # Move to failed jobs if we can't load the result
                self.failed_jobs[job_id] = f"Failed to load result file: {e}"

        # Restore failed jobs
        for job_data in state.failed_jobs:
            self.failed_jobs[job_data["id"]] = job_data.get("error", "Unknown error")

        # Restore cancelled jobs (if they exist in state)
        for job_data in getattr(state, 'cancelled_jobs', []):
            self.cancelled_jobs[job_data["id"]] = job_data.get("reason", "Cancelled")

        # Restore cost tracker
        self.cost_tracker.used_usd = state.total_cost_usd

        logger.info(
            f"Resumed with {len(self.pending_jobs)} pending, "
            f"{len(self.completed_results)} completed, "
            f"{len(self.failed_jobs)} failed, "
            f"{len(self.cancelled_jobs)} cancelled"
        )

    def to_json(self) -> Dict:
        """Convert current state to JSON-serializable dict."""
        return {
            "created_at": datetime.now().isoformat(),
            "pending_jobs": [job.to_dict() for job in self.pending_jobs],
            "completed_results": [
                {"job_id": job_id, "file_path": str(self.results_dir / f"{job_id}.json")}
                for job_id in self.completed_results.keys()
            ],
            "failed_jobs": [
                {
                    "id": job_id,
                    "error": error,
                    "timestamp": datetime.now().isoformat()
                } for job_id, error in self.failed_jobs.items()
            ],
            "cancelled_jobs": [
                {
                    "id": job_id,
                    "reason": reason,
                    "timestamp": datetime.now().isoformat()
                } for job_id, reason in self.cancelled_jobs.items()
            ],
            "total_cost_usd": self.cost_tracker.used_usd,
            "config": {
                "state_file": self.config.state_file,
                "results_dir": self.config.results_dir,
                "max_parallel_batches": self.config.max_parallel_batches,
                "items_per_batch": self.config.items_per_batch,
                "cost_limit_usd": self.config.cost_limit_usd,
                "default_params": self.config.default_params,
                "raw_files": self.config.raw_files
            }
        }

    def execute(self):
        """Execute synchronous batch run and wait for completion."""
        if self._started:
            raise RuntimeError("Batch run already started")

        self._started = True
        self._start_time = datetime.now()

        # Register signal handler for graceful shutdown
        def signal_handler(signum, frame):
            logger.warning("Received interrupt signal, shutting down gracefully...")
            self._shutdown_event.set()

        # Store original handler to restore later
        original_handler = signal.signal(signal.SIGINT, signal_handler)

        try:
            logger.info("Starting batch run")

            # Start time limit watchdog if configured
            self._start_time_limit_watchdog()

            # Call initial progress
            if self._progress_callback:
                stats = self.status()
                batch_data = dict(self.batch_tracking)
                self._progress_callback(stats, 0.0, batch_data)

            # Process all jobs synchronously
            self._process_all_jobs()

            logger.info("Batch run completed")
        finally:
            # Restore original signal handler
            signal.signal(signal.SIGINT, original_handler)

    def set_on_progress(self, callback: Callable[[Dict, float, Dict], None], interval: float = 1.0) -> 'BatchRun':
        """Set progress callback for execution monitoring.

        The callback will be called periodically with progress statistics
        including completed jobs, total jobs, current cost, etc.

        Args:
            callback: Function that receives (stats_dict, elapsed_time_seconds, batch_data)
                - stats_dict: Progress statistics dictionary
                - elapsed_time_seconds: Time elapsed since batch started (float)
                - batch_data: Dictionary mapping batch_id to batch information
            interval: Interval in seconds between progress updates (default: 1.0)

        Returns:
            Self for chaining

        Example:
            ```python
            run.set_on_progress(
                lambda stats, time, batch_data: print(
                    f"Progress: {stats['completed']}/{stats['total']}, {time:.1f}s"
                )
            )
            ```
        """
        self._progress_callback = callback
        self._progress_interval = interval
        return self

    def _start_time_limit_watchdog(self):
        """Start a background thread to check for time limit every second."""
        if not self.config.time_limit_seconds:
            return

        def time_limit_watchdog():
            """Check for time limit every second and trigger shutdown if exceeded."""
            while not self._shutdown_event.is_set():
                if self._check_time_limit():
                    logger.warning("Batch execution time limit exceeded")
                    with self._state_lock:
                        self._time_limit_exceeded = True
                    self._shutdown_event.set()
                    break
                time.sleep(1.0)

        # Start watchdog as daemon thread
        watchdog_thread = threading.Thread(target=time_limit_watchdog, daemon=True)
        watchdog_thread.start()
        logger.debug(f"Started time limit watchdog thread (time limit: {self.config.time_limit_seconds}s)")

    def _check_time_limit(self) -> bool:
        """Check if batch execution has exceeded time limit."""
        if not self.config.time_limit_seconds or not self._start_time:
            return False

        elapsed = (datetime.now() - self._start_time).total_seconds()
        return elapsed >= self.config.time_limit_seconds

    def _process_all_jobs(self):
        """Process all jobs with parallel execution."""
        # Prepare all batches
        batches = self._prepare_batches()
        self.total_batches = len(batches)

        # Process batches in parallel
        with ThreadPoolExecutor(max_workers=self.config.max_parallel_batches) as executor:
            futures = [executor.submit(self._execute_batch_wrapped, provider, batch_jobs)
                       for _, provider, batch_jobs in batches]

            try:
                for future in as_completed(futures):
                    # Stop if shutdown event detected (includes time limit)
                    if self._shutdown_event.is_set():
                        break
                    future.result()  # Re-raise any exceptions
            except KeyboardInterrupt:
                self._shutdown_event.set()
                # Cancel remaining futures
                for future in futures:
                    future.cancel()
                raise
            finally:
                # Handle time limit or cancellation - mark remaining jobs appropriately
                with self._state_lock:
                    if self._shutdown_event.is_set():
                        # If time limit exceeded, cancel all active batches
                        if self._time_limit_exceeded:
                            self._cancel_all_active_batches()

                        # Mark any unprocessed jobs based on reason for shutdown
                        for _, _, batch_jobs in batches:
                            for job in batch_jobs:
                                # Skip jobs already processed
                                if (job.id in self.completed_results or
                                    job.id in self.failed_jobs or
                                    job.id in self.cancelled_jobs):
                                    continue

                                # Mark based on shutdown reason
                                if self._time_limit_exceeded:
                                    self.failed_jobs[job.id] = "Time limit exceeded: batch execution time limit exceeded"
                                else:
                                    self.cancelled_jobs[job.id] = "Cancelled by user"

                                if job in self.pending_jobs:
                                    self.pending_jobs.remove(job)

                        # Save state
                        self.state_manager.save(self)

    def _cancel_all_active_batches(self):
        """Cancel all active batches at the provider level."""
        with self._active_batches_lock:
            active_batch_items = list(self._active_batches.items())

        logger.info(f"Cancelling {len(active_batch_items)} active batches due to time limit exceeded")

        # Cancel outside the lock to avoid blocking
        for batch_id, provider in active_batch_items:
            try:
                provider.cancel_batch(batch_id)
                logger.info(f"Cancelled batch {batch_id} due to time limit exceeded")
            except Exception as e:
                logger.warning(f"Failed to cancel batch {batch_id}: {e}")

        # Clear the tracking after cancellation attempts
        with self._active_batches_lock:
            self._active_batches.clear()

    def _execute_batch_wrapped(self, provider, batch_jobs):
        """Thread-safe wrapper for _execute_batch."""
        try:
            result = self._execute_batch(provider, batch_jobs)
            with self._state_lock:
                self._update_batch_results(result)
                # Remove jobs from pending_jobs if specified
                jobs_to_remove = result.get("jobs_to_remove", [])
                for job in jobs_to_remove:
                    if job in self.pending_jobs:
                        self.pending_jobs.remove(job)
        except TimeoutError:
            # Handle time limit exceeded - mark jobs as failed
            with self._state_lock:
                for job in batch_jobs:
                    self.failed_jobs[job.id] = "Time limit exceeded: batch execution time limit exceeded"
                    if job in self.pending_jobs:
                        self.pending_jobs.remove(job)
                self.state_manager.save(self)
            # Don't re-raise, just return the result
            return
        except KeyboardInterrupt:
            self._shutdown_event.set()
            # Handle user cancellation
            with self._state_lock:
                for job in batch_jobs:
                    self.cancelled_jobs[job.id] = "Cancelled by user"
                    if job in self.pending_jobs:
                        self.pending_jobs.remove(job)
                self.state_manager.save(self)
            raise

    def _group_jobs_by_provider(self) -> Dict[str, List[Job]]:
        """Group jobs by provider."""
        jobs_by_provider = {}

        for job in self.pending_jobs[:]:  # Copy to avoid modification during iteration
            try:
                provider = get_provider(job.model)
                provider_name = provider.__class__.__name__

                if provider_name not in jobs_by_provider:
                    jobs_by_provider[provider_name] = []

                jobs_by_provider[provider_name].append(job)

            except Exception as e:
                logger.error(f"Failed to get provider for job {job.id}: {e}")
                with self._state_lock:
                    self.failed_jobs[job.id] = str(e)
                    self.pending_jobs.remove(job)

        return jobs_by_provider

    def _split_into_batches(self, jobs: List[Job]) -> List[List[Job]]:
        """Split jobs into batches based on items_per_batch."""
        batches = []
        batch_size = self.config.items_per_batch

        for i in range(0, len(jobs), batch_size):
            batch = jobs[i:i + batch_size]
            batches.append(batch)

        return batches

    def _prepare_batches(self) -> List[Tuple[str, object, List[Job]]]:
        """Prepare all batches as simple list of (provider_name, provider, jobs)."""
        batches = []
        jobs_by_provider = self._group_jobs_by_provider()

        for provider_name, provider_jobs in jobs_by_provider.items():
            provider = get_provider(provider_jobs[0].model)
            job_batches = self._split_into_batches(provider_jobs)

            for batch_jobs in job_batches:
                batches.append((provider_name, provider, batch_jobs))

                # Pre-populate batch tracking for pending batches
                batch_id = f"pending_{len(self.batch_tracking)}"
                estimated_cost = provider.estimate_cost(batch_jobs)
                self.batch_tracking[batch_id] = {
                    'start_time': None,
                    'status': 'pending',
                    'total': len(batch_jobs),
                    'completed': 0,
                    'cost': 0.0,
                    'estimated_cost': estimated_cost,
                    'provider': provider_name,
                    'jobs': batch_jobs
                }

        return batches

    def _poll_batch_status(self, provider, batch_id: str) -> Tuple[str, Optional[Dict]]:
        """Poll until batch completes."""
        status, error_details = provider.get_batch_status(batch_id)
        logger.info(f"Initial batch status: {status}")
        poll_count = 0

        # Use provider-specific polling interval
        provider_polling_interval = provider.get_polling_interval()
        logger.debug(f"Using {provider_polling_interval}s polling interval for {provider.__class__.__name__}")

        while status not in ["complete", "failed"]:
            poll_count += 1
            logger.debug(f"Polling attempt {poll_count}, current status: {status}")

            # Interruptible wait - will wake up immediately if shutdown event is set (includes time limit)
            if self._shutdown_event.wait(provider_polling_interval):
                # Check if it's time limit exceeded or user cancellation
                with self._state_lock:
                    is_time_limit_exceeded = self._time_limit_exceeded

                if is_time_limit_exceeded:
                    logger.info(f"Batch {batch_id} polling interrupted by time limit exceeded")
                    raise TimeoutError("Batch cancelled due to time limit exceeded")
                else:
                    logger.info(f"Batch {batch_id} polling interrupted by user")
                    raise KeyboardInterrupt("Batch cancelled by user")

            status, error_details = provider.get_batch_status(batch_id)

            if self._progress_callback:
                with self._state_lock:
                    stats = self.status()
                    elapsed_time = (datetime.now() - self._start_time).total_seconds()
                    batch_data = dict(self.batch_tracking)
                self._progress_callback(stats, elapsed_time, batch_data)

            elapsed_seconds = poll_count * provider_polling_interval
            logger.info(f"Batch {batch_id} status: {status} (polling for {elapsed_seconds:.1f}s)")

        return status, error_details


    def _update_batch_results(self, batch_result: Dict):
        """Update state from batch results."""
        results = batch_result.get("results", [])
        failed = batch_result.get("failed", {})

        # Update completed results
        for result in results:
            if result.is_success:
                self.completed_results[result.job_id] = result
                self._save_result_to_file(result)
                logger.info(f"✓ Job {result.job_id} completed successfully")
            else:
                error_message = result.error or "Unknown error"
                self.failed_jobs[result.job_id] = error_message
                self._save_result_to_file(result)
                logger.error(f"✗ Job {result.job_id} failed: {result.error}")

            # Remove completed/failed job from pending
            self.pending_jobs = [job for job in self.pending_jobs if job.id != result.job_id]

        # Update failed jobs
        for job_id, error in failed.items():
            self.failed_jobs[job_id] = error
            # Remove failed job from pending
            self.pending_jobs = [job for job in self.pending_jobs if job.id != job_id]
            logger.error(f"✗ Job {job_id} failed: {error}")

        # Update batch tracking
        self.completed_batches += 1

        # Save state
        self.state_manager.save(self)

    def _execute_batch(self, provider, batch_jobs: List[Job]) -> Dict:
        """Execute one batch, return results dict with jobs/costs/errors."""
        if not batch_jobs:
            return {"results": [], "failed": {}, "cost": 0.0}

        # Reserve cost limit
        logger.info(f"Estimating cost for batch of {len(batch_jobs)} jobs...")
        estimated_cost = provider.estimate_cost(batch_jobs)
        remaining = self.cost_tracker.remaining()
        remaining_str = f"${remaining:.4f}" if remaining is not None else "unlimited"
        logger.info(f"Total estimated cost: ${estimated_cost:.4f}, remaining budget: {remaining_str}")

        if not self.cost_tracker.reserve_cost(estimated_cost):
            logger.warning(f"Cost limit would be exceeded, skipping batch of {len(batch_jobs)} jobs")
            failed = {}
            for job in batch_jobs:
                failed[job.id] = "Cost limit exceeded"
            return {"results": [], "failed": failed, "cost": 0.0, "jobs_to_remove": list(batch_jobs)}

        batch_id = None
        job_mapping = None
        try:
            # Create batch
            logger.info(f"Creating batch with {len(batch_jobs)} jobs...")
            raw_files_path = str(self.raw_files_dir) if self.raw_files_dir else None
            batch_id, job_mapping = provider.create_batch(batch_jobs, raw_files_path)

            # Track active batch for cancellation
            with self._active_batches_lock:
                self._active_batches[batch_id] = provider

            # Track batch creation
            with self._state_lock:
                # Remove pending entry if it exists
                pending_keys = [k for k in self.batch_tracking.keys() if k.startswith('pending_')]
                for pending_key in pending_keys:
                    if self.batch_tracking[pending_key]['jobs'] == batch_jobs:
                        del self.batch_tracking[pending_key]
                        break

                # Add actual batch tracking
                self.batch_tracking[batch_id] = {
                    'start_time': datetime.now(),
                    'status': 'running',
                    'total': len(batch_jobs),
                    'completed': 0,
                    'cost': 0.0,
                    'estimated_cost': estimated_cost,
                    'provider': provider.__class__.__name__,
                    'jobs': batch_jobs
                }

            # Poll for completion
            logger.info(f"Polling for batch {batch_id} completion...")
            status, error_details = self._poll_batch_status(provider, batch_id)

            if status == "failed":
                if error_details:
                    logger.error(f"Batch {batch_id} failed with details: {error_details}")
                else:
                    logger.error(f"Batch {batch_id} failed")

                # Save error details if configured
                if self.raw_files_dir and error_details:
                    self._save_batch_error_details(batch_id, error_details)

                # Continue to get individual results - some jobs might have succeeded

            # Get results
            logger.info(f"Getting results for batch {batch_id}")
            raw_files_path = str(self.raw_files_dir) if self.raw_files_dir else None
            results = provider.get_batch_results(batch_id, job_mapping, raw_files_path)

            # Calculate actual cost and adjust reservation
            actual_cost = sum(r.cost_usd for r in results)
            self.cost_tracker.adjust_reserved_cost(estimated_cost, actual_cost)

            # Update batch tracking for completion
            success_count = len([r for r in results if r.is_success])
            failed_count = len([r for r in results if not r.is_success])
            batch_status = 'complete' if failed_count == 0 else 'failed'

            with self._state_lock:
                if batch_id in self.batch_tracking:
                    self.batch_tracking[batch_id]['status'] = batch_status
                    self.batch_tracking[batch_id]['completed'] = success_count
                    self.batch_tracking[batch_id]['failed'] = failed_count
                    self.batch_tracking[batch_id]['cost'] = actual_cost
                    self.batch_tracking[batch_id]['completion_time'] = datetime.now()
                    if batch_status == 'failed' and failed_count > 0:
                        # Use the first job's error as the batch error summary
                        first_error = next((r.error for r in results if not r.is_success), 'Some jobs failed')
                        self.batch_tracking[batch_id]['error'] = first_error

            # Remove from active batches tracking
            with self._active_batches_lock:
                self._active_batches.pop(batch_id, None)

            status_symbol = "✓" if batch_status == 'complete' else "⚠"
            logger.info(
                f"{status_symbol} Batch {batch_id} completed: "
                f"{success_count} success, "
                f"{failed_count} failed, "
                f"cost: ${actual_cost:.6f}"
            )

            return {"results": results, "failed": {}, "cost": actual_cost, "jobs_to_remove": list(batch_jobs)}

        except TimeoutError:
            logger.info(f"Time limit exceeded for batch{f' {batch_id}' if batch_id else ''}")
            if batch_id:
                # Update batch tracking for time limit exceeded
                with self._state_lock:
                    if batch_id in self.batch_tracking:
                        self.batch_tracking[batch_id]['status'] = 'failed'
                        self.batch_tracking[batch_id]['error'] = 'Time limit exceeded: batch execution time limit exceeded'
                        self.batch_tracking[batch_id]['completion_time'] = datetime.now()
                # NOTE: Don't remove from _active_batches - let centralized cancellation handle it
            # Release the reservation since batch exceeded time limit
            self.cost_tracker.adjust_reserved_cost(estimated_cost, 0.0)
            # Re-raise to be handled by wrapper
            raise

        except KeyboardInterrupt:
            logger.warning(f"\nCancelling batch{f' {batch_id}' if batch_id else ''}...")
            if batch_id:
                # Update batch tracking for cancellation
                with self._state_lock:
                    if batch_id in self.batch_tracking:
                        self.batch_tracking[batch_id]['status'] = 'cancelled'
                        self.batch_tracking[batch_id]['error'] = 'Cancelled by user'
                        self.batch_tracking[batch_id]['completion_time'] = datetime.now()
                # Remove from active batches tracking
                with self._active_batches_lock:
                    self._active_batches.pop(batch_id, None)
            # Release the reservation since batch was cancelled
            self.cost_tracker.adjust_reserved_cost(estimated_cost, 0.0)
            # Handle cancellation in the wrapper with proper locking
            raise

        except Exception as e:
            logger.error(f"✗ Batch execution failed: {e}")
            # Update batch tracking for exception
            if batch_id:
                with self._state_lock:
                    if batch_id in self.batch_tracking:
                        self.batch_tracking[batch_id]['status'] = 'failed'
                        self.batch_tracking[batch_id]['error'] = str(e)
                        self.batch_tracking[batch_id]['completion_time'] = datetime.now()
                # Remove from active batches tracking
                with self._active_batches_lock:
                    self._active_batches.pop(batch_id, None)
            # Release the reservation since batch failed
            self.cost_tracker.adjust_reserved_cost(estimated_cost, 0.0)
            failed = {}
            for job in batch_jobs:
                failed[job.id] = str(e)
            return {"results": [], "failed": failed, "cost": 0.0, "jobs_to_remove": list(batch_jobs)}


    def _save_result_to_file(self, result: JobResult):
        """Save individual result to file."""
        result_file = self.results_dir / f"{result.job_id}.json"

        try:
            with open(result_file, 'w') as f:
                json.dump(result.to_dict(), f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save result for {result.job_id}: {e}")

    def _save_batch_error_details(self, batch_id: str, error_details: Dict):
        """Save batch error details to debug files directory."""
        try:
            error_file = self.raw_files_dir / f"batch_{batch_id}_error.json"
            with open(error_file, 'w') as f:
                json.dump({
                    "batch_id": batch_id,
                    "timestamp": datetime.now().isoformat(),
                    "error_details": error_details
                }, f, indent=2)
            logger.info(f"Saved batch error details to {error_file}")
        except Exception as e:
            logger.error(f"Failed to save batch error details: {e}")

    @property
    def is_complete(self) -> bool:
        """Whether all jobs are complete."""
        total_jobs = len(self.jobs)
        completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
        return len(self.pending_jobs) == 0 and completed_count == total_jobs


    def status(self, print_status: bool = False) -> Dict:
        """Get current execution statistics."""
        total_jobs = len(self.jobs)
        completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
        remaining_count = total_jobs - completed_count

        stats = {
            "total": total_jobs,
            "pending": remaining_count,
            "active": 0,  # Always 0 for synchronous execution
            "completed": len(self.completed_results),
            "failed": len(self.failed_jobs),
            "cancelled": len(self.cancelled_jobs),
            "cost_usd": self.cost_tracker.used_usd,
            "cost_limit_usd": self.cost_tracker.limit_usd,
            "is_complete": self.is_complete,
            "batches_total": self.total_batches,
            "batches_completed": self.completed_batches,
            "batches_pending": self.total_batches - self.completed_batches,
            "current_batch_index": self.current_batch_index,
            "current_batch_size": self.current_batch_size,
            "items_per_batch": self.config.items_per_batch
        }

        if print_status:
            logger.info("\nBatch Run Status:")
            logger.info(f" Total jobs: {stats['total']}")
            logger.info(f" Pending: {stats['pending']}")
            logger.info(f" Active: {stats['active']}")
            logger.info(f" Completed: {stats['completed']}")
            logger.info(f" Failed: {stats['failed']}")
            logger.info(f" Cancelled: {stats['cancelled']}")
            logger.info(f" Cost: ${stats['cost_usd']:.6f}")
            if stats['cost_limit_usd']:
                logger.info(f" Cost limit: ${stats['cost_limit_usd']:.2f}")
            logger.info(f" Complete: {stats['is_complete']}")

        return stats

    def results(self) -> Dict[str, List[JobResult]]:
        """Get all results organized by status.

        Returns:
            {
                "completed": [JobResult],
                "failed": [JobResult],
                "cancelled": [JobResult]
            }
        """
        return {
            "completed": list(self.completed_results.values()),
            "failed": self._create_failed_results(),
            "cancelled": self._create_cancelled_results()
        }

    def get_failed_jobs(self) -> Dict[str, str]:
        """Get failed jobs with error messages.

        Note: This method is deprecated. Use results()['failed'] instead.
        """
        return dict(self.failed_jobs)

    def _create_failed_results(self) -> List[JobResult]:
        """Convert failed jobs to JobResult objects."""
        failed_results = []
        for job_id, error_msg in self.failed_jobs.items():
            failed_results.append(JobResult(
                job_id=job_id,
                raw_response=None,
                parsed_response=None,
                error=error_msg,
                cost_usd=0.0,
                input_tokens=0,
                output_tokens=0
            ))
        return failed_results

    def _create_cancelled_results(self) -> List[JobResult]:
        """Convert cancelled jobs to JobResult objects."""
        cancelled_results = []
        for job_id, reason in self.cancelled_jobs.items():
            cancelled_results.append(JobResult(
                job_id=job_id,
                raw_response=None,
                parsed_response=None,
                error=reason,
                cost_usd=0.0,
                input_tokens=0,
                output_tokens=0
            ))
        return cancelled_results

    def shutdown(self):
        """Shutdown (no-op for synchronous execution)."""
        pass

    def dry_run(self) -> 'BatchRun':
        """Perform a dry run - show cost estimation and job details without executing.

        Returns:
            Self for chaining (doesn't actually execute jobs)
        """
        logger.info("=== DRY RUN MODE ===")
        logger.info("This will show cost estimates without executing jobs")

        # Load existing state if reuse_state=True
        if self.config.reuse_state:
            self.state_manager.load_state(self)

        # Filter out completed jobs from previous runs
        self.pending_jobs = [job for job in self.jobs.values() if job.id not in self.completed_results]

        if not self.pending_jobs:
            logger.info("No pending jobs to analyze (all jobs already completed)")
            return self

        logger.info(f"Analyzing {len(self.pending_jobs)} pending jobs...")

        # Group jobs by provider and analyze costs
        provider_groups = self._group_jobs_by_provider()
        total_estimated_cost = 0.0

        logger.info(f"\nJob breakdown:")
        for provider_name, jobs in provider_groups.items():
            provider = get_provider(jobs[0].model)
            logger.info(f"\n{provider_name} ({len(jobs)} jobs):")

            job_batches = [jobs[i:i + self.config.items_per_batch]
                           for i in range(0, len(jobs), self.config.items_per_batch)]

            for batch_idx, batch_jobs in enumerate(job_batches, 1):
                estimated_cost = provider.estimate_cost(batch_jobs)
                total_estimated_cost += estimated_cost

                logger.info(f" Batch {batch_idx}: {len(batch_jobs)} jobs, estimated cost: ${estimated_cost:.4f}")
                for job in batch_jobs:
                    if job.file:
                        logger.info(f" - {job.id}: {job.file.name} (citations: {job.enable_citations})")
                    else:
                        logger.info(f" - {job.id}: direct messages (citations: {job.enable_citations})")

        # Show cost summary
        logger.info(f"\n=== COST SUMMARY ===")
        logger.info(f"Total estimated cost: ${total_estimated_cost:.4f}")

        if self.config.cost_limit_usd:
            logger.info(f"Cost limit: ${self.config.cost_limit_usd:.2f}")
            if total_estimated_cost > self.config.cost_limit_usd:
                excess = total_estimated_cost - self.config.cost_limit_usd
                logger.warning(f"⚠️ Estimated cost exceeds limit by ${excess:.4f}")
            else:
                remaining = self.config.cost_limit_usd - total_estimated_cost
                logger.info(f"✅ Within cost limit (${remaining:.4f} remaining)")
        else:
            logger.info("No cost limit set")

        # Show execution plan
        logger.info(f"\n=== EXECUTION PLAN ===")
        total_batches = sum(
            len(jobs) // self.config.items_per_batch + (1 if len(jobs) % self.config.items_per_batch else 0)
            for jobs in provider_groups.values()
        )
        logger.info(f"Total batches to process: {total_batches}")
        logger.info(f"Max parallel batches: {self.config.max_parallel_batches}")
        logger.info(f"Items per batch: {self.config.items_per_batch}")
        logger.info(f"Results directory: {self.config.results_dir}")

        logger.info("\n=== DRY RUN COMPLETE ===")
        logger.info("To execute for real, call run() without dry_run=True")

        return self
Manages the execution of a batch job synchronously.
Processes jobs in batches based on items_per_batch configuration. Simpler synchronous execution for clear logging and debugging.
Example:
config = BatchParams(...)
run = BatchRun(config, jobs)
run.execute()
results = run.results()
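The `items_per_batch` grouping mentioned above can be sketched in isolation. This is a minimal illustration rather than the library's internal code: `chunk_jobs` is a hypothetical helper name, and plain strings stand in for `Job` objects.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def chunk_jobs(jobs: Sequence[T], items_per_batch: int) -> List[List[T]]:
    """Split jobs into consecutive batches of at most items_per_batch items."""
    if items_per_batch < 1:
        raise ValueError("items_per_batch must be at least 1")
    return [
        list(jobs[i:i + items_per_batch])
        for i in range(0, len(jobs), items_per_batch)
    ]


# Strings stand in for Job objects (assumption for illustration).
batches = chunk_jobs(["job-1", "job-2", "job-3", "job-4", "job-5"], items_per_batch=2)
for index, batch in enumerate(batches, start=1):
    print(f"Batch {index}: {batch}")
```

The last batch is simply shorter when the job count is not a multiple of `items_per_batch`.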
    def __init__(self, config: BatchParams, jobs: List[Job]):
        """Initialize batch run.

        Args:
            config: Batch configuration
            jobs: List of jobs to execute
        """
        self.config = config
        self.jobs = {job.id: job for job in jobs}

        # Set logging level based on config
        set_log_level(level=config.verbosity.upper())

        # Initialize components
        self.cost_tracker = CostTracker(limit_usd=config.cost_limit_usd)

        # Use temp file for state if not provided
        state_file = config.state_file
        if not state_file:
            state_file = create_temp_state_file(config)
            config.reuse_state = False
            logger.info(f"Created temporary state file: {state_file}")

        self.state_manager = StateManager(state_file)

        # State tracking
        self.pending_jobs: List[Job] = []
        self.completed_results: Dict[str, JobResult] = {}  # job_id -> result
        self.failed_jobs: Dict[str, str] = {}  # job_id -> error
        self.cancelled_jobs: Dict[str, str] = {}  # job_id -> reason

        # Batch tracking
        self.total_batches = 0
        self.completed_batches = 0
        self.current_batch_index = 0
        self.current_batch_size = 0

        # Execution control
        self._started = False
        self._start_time: Optional[datetime] = None
        self._time_limit_exceeded = False
        # Callback receives (stats, elapsed_seconds, batch_data), matching set_on_progress()
        self._progress_callback: Optional[Callable[[Dict, float, Dict], None]] = None
        self._progress_interval: float = 1.0  # Default to 1 second

        # Threading primitives
        self._state_lock = threading.Lock()
        self._shutdown_event = threading.Event()

        # Batch tracking for progress display
        self.batch_tracking: Dict[str, Dict] = {}  # batch_id -> batch_info

        # Active batch tracking for cancellation
        self._active_batches: Dict[str, object] = {}  # batch_id -> provider
        self._active_batches_lock = threading.Lock()

        # Results directory
        self.results_dir = Path(config.results_dir)

        # If not reusing state, clear the results directory
        if not config.reuse_state and self.results_dir.exists():
            import shutil
            shutil.rmtree(self.results_dir)

        self.results_dir.mkdir(parents=True, exist_ok=True)

        # Raw files directory (if enabled)
        self.raw_files_dir = None
        if config.raw_files:
            self.raw_files_dir = self.results_dir / "raw_files"
            self.raw_files_dir.mkdir(parents=True, exist_ok=True)

        # Try to resume from saved state
        self._resume_from_state()
Initialize batch run.
Args:
- config: Batch configuration
- jobs: List of jobs to execute
    def to_json(self) -> Dict:
        """Convert current state to JSON-serializable dict."""
        return {
            "created_at": datetime.now().isoformat(),
            "pending_jobs": [job.to_dict() for job in self.pending_jobs],
            "completed_results": [
                {"job_id": job_id, "file_path": str(self.results_dir / f"{job_id}.json")}
                for job_id in self.completed_results.keys()
            ],
            "failed_jobs": [
                {
                    "id": job_id,
                    "error": error,
                    "timestamp": datetime.now().isoformat()
                } for job_id, error in self.failed_jobs.items()
            ],
            "cancelled_jobs": [
                {
                    "id": job_id,
                    "reason": reason,
                    "timestamp": datetime.now().isoformat()
                } for job_id, reason in self.cancelled_jobs.items()
            ],
            "total_cost_usd": self.cost_tracker.used_usd,
            "config": {
                "state_file": self.config.state_file,
                "results_dir": self.config.results_dir,
                "max_parallel_batches": self.config.max_parallel_batches,
                "items_per_batch": self.config.items_per_batch,
                "cost_limit_usd": self.config.cost_limit_usd,
                "default_params": self.config.default_params,
                "raw_files": self.config.raw_files
            }
        }
Convert current state to JSON-serializable dict.
    def execute(self):
        """Execute synchronous batch run and wait for completion."""
        if self._started:
            raise RuntimeError("Batch run already started")

        self._started = True
        self._start_time = datetime.now()

        # Register signal handler for graceful shutdown
        def signal_handler(signum, frame):
            logger.warning("Received interrupt signal, shutting down gracefully...")
            self._shutdown_event.set()

        # Store original handler to restore later
        original_handler = signal.signal(signal.SIGINT, signal_handler)

        try:
            logger.info("Starting batch run")

            # Start time limit watchdog if configured
            self._start_time_limit_watchdog()

            # Call initial progress
            if self._progress_callback:
                stats = self.status()
                batch_data = dict(self.batch_tracking)
                self._progress_callback(stats, 0.0, batch_data)

            # Process all jobs synchronously
            self._process_all_jobs()

            logger.info("Batch run completed")
        finally:
            # Restore original signal handler
            signal.signal(signal.SIGINT, original_handler)
Execute synchronous batch run and wait for completion.
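The interrupt handling in `execute()` follows a common standard-library pattern: install a SIGINT handler that only sets a `threading.Event`, let the work loop poll that event, and restore the previous handler in a `finally` block. The sketch below shows that pattern on its own, assuming it runs in the main thread (a requirement of `signal.signal`); `run_with_graceful_interrupt` and the list of work items are hypothetical.

```python
import signal
import threading

shutdown_event = threading.Event()


def run_with_graceful_interrupt(work_items):
    """Process items until finished or until SIGINT sets the shutdown event."""
    def handle_sigint(signum, frame):
        # Only record the request; the loop below decides when to stop.
        shutdown_event.set()

    # signal.signal returns the previously installed handler.
    original_handler = signal.signal(signal.SIGINT, handle_sigint)
    processed = []
    try:
        for item in work_items:
            if shutdown_event.is_set():
                break
            processed.append(item)
    finally:
        # Always restore the caller's handler, even if the loop raised.
        signal.signal(signal.SIGINT, original_handler)
    return processed


handler_before = signal.getsignal(signal.SIGINT)
processed = run_with_graceful_interrupt(["a", "b", "c"])
print(processed)
```

Keeping the handler this small avoids doing real work inside a signal context; the cleanup happens in ordinary code once the loop notices the event.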
    def set_on_progress(self, callback: Callable[[Dict, float, Dict], None], interval: float = 1.0) -> 'BatchRun':
        """Set progress callback for execution monitoring.

        The callback will be called periodically with progress statistics
        including completed jobs, total jobs, current cost, etc.

        Args:
            callback: Function that receives (stats_dict, elapsed_time_seconds, batch_data)
                - stats_dict: Progress statistics dictionary
                - elapsed_time_seconds: Time elapsed since batch started (float)
                - batch_data: Dictionary mapping batch_id to batch information
            interval: Interval in seconds between progress updates (default: 1.0)

        Returns:
            Self for chaining

        Example:
            ```python
            run.set_on_progress(
                lambda stats, time, batch_data: print(
                    f"Progress: {stats['completed']}/{stats['total']}, {time:.1f}s"
                )
            )
            ```
        """
        self._progress_callback = callback
        self._progress_interval = interval
        return self
Set progress callback for execution monitoring.
The callback will be called periodically with progress statistics including completed jobs, total jobs, current cost, etc.
Args:
- callback: Function that receives (stats_dict, elapsed_time_seconds, batch_data)
  - stats_dict: Progress statistics dictionary
  - elapsed_time_seconds: Time elapsed since batch started (float)
  - batch_data: Dictionary mapping batch_id to batch information
- interval: Interval in seconds between progress updates (default: 1.0)
Returns: Self for chaining
Example:
run.set_on_progress(
    lambda stats, time, batch_data: print(
        f"Progress: {stats['completed']}/{stats['total']}, {time:.1f}s"
    )
)
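For anything longer than a one-line lambda, a named callback is easier to read and to test. The sketch below assumes only the three-argument signature documented above and the `stats` keys that `status()` returns; `format_progress`, `on_progress`, and the sample values are hypothetical, and the per-batch dictionaries are left empty because their contents are not specified here.

```python
from typing import Dict


def format_progress(stats: Dict, elapsed_seconds: float, batch_data: Dict) -> str:
    """Build a one-line summary from the three callback arguments."""
    finished = stats["completed"] + stats["failed"] + stats["cancelled"]
    return (
        f"{finished}/{stats['total']} jobs finished, "
        f"${stats['cost_usd']:.4f} spent, "
        f"{len(batch_data)} batches tracked, {elapsed_seconds:.1f}s elapsed"
    )


def on_progress(stats: Dict, elapsed_seconds: float, batch_data: Dict) -> None:
    """Callback with the (stats, elapsed, batch_data) signature."""
    print(format_progress(stats, elapsed_seconds, batch_data))


# Sample values for illustration only; a real run supplies these.
sample_stats = {"total": 10, "completed": 6, "failed": 1, "cancelled": 0, "cost_usd": 0.1234}
sample_batches = {"batch-1": {}, "batch-2": {}}
line = format_progress(sample_stats, 12.34, sample_batches)
on_progress(sample_stats, 12.34, sample_batches)
```

Such a function would be registered with `run.set_on_progress(on_progress, interval=5.0)`.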
    @property
    def is_complete(self) -> bool:
        """Whether all jobs are complete."""
        total_jobs = len(self.jobs)
        completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
        return len(self.pending_jobs) == 0 and completed_count == total_jobs
Whether all jobs are complete.
    def status(self, print_status: bool = False) -> Dict:
        """Get current execution statistics."""
        total_jobs = len(self.jobs)
        completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
        remaining_count = total_jobs - completed_count

        stats = {
            "total": total_jobs,
            "pending": remaining_count,
            "active": 0,  # Always 0 for synchronous execution
            "completed": len(self.completed_results),
            "failed": len(self.failed_jobs),
            "cancelled": len(self.cancelled_jobs),
            "cost_usd": self.cost_tracker.used_usd,
            "cost_limit_usd": self.cost_tracker.limit_usd,
            "is_complete": self.is_complete,
            "batches_total": self.total_batches,
            "batches_completed": self.completed_batches,
            "batches_pending": self.total_batches - self.completed_batches,
            "current_batch_index": self.current_batch_index,
            "current_batch_size": self.current_batch_size,
            "items_per_batch": self.config.items_per_batch
        }

        if print_status:
            logger.info("\nBatch Run Status:")
            logger.info(f" Total jobs: {stats['total']}")
            logger.info(f" Pending: {stats['pending']}")
            logger.info(f" Active: {stats['active']}")
            logger.info(f" Completed: {stats['completed']}")
            logger.info(f" Failed: {stats['failed']}")
            logger.info(f" Cancelled: {stats['cancelled']}")
            logger.info(f" Cost: ${stats['cost_usd']:.6f}")
            if stats['cost_limit_usd']:
                logger.info(f" Cost limit: ${stats['cost_limit_usd']:.2f}")
            logger.info(f" Complete: {stats['is_complete']}")

        return stats
Get current execution statistics.
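The counting rules behind `status()` and `is_complete` can be restated as a small standalone function: `pending` is whatever has no outcome yet, and failed and cancelled jobs count toward completion just like successful ones. `summarize_counts` is a hypothetical helper that works on plain integers instead of a `BatchRun`.

```python
from typing import Dict


def summarize_counts(total: int, completed: int, failed: int, cancelled: int,
                     pending_queue_len: int) -> Dict[str, object]:
    """Derive the pending count and completion flag from per-status counts."""
    finished = completed + failed + cancelled
    return {
        "total": total,
        "pending": total - finished,
        "completed": completed,
        "failed": failed,
        "cancelled": cancelled,
        # Complete only when nothing is queued and every job has an outcome.
        "is_complete": pending_queue_len == 0 and finished == total,
    }


in_progress = summarize_counts(total=10, completed=6, failed=1, cancelled=0, pending_queue_len=3)
done = summarize_counts(total=10, completed=8, failed=1, cancelled=1, pending_queue_len=0)
print(in_progress["pending"], in_progress["is_complete"])
print(done["pending"], done["is_complete"])
```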
    def results(self) -> Dict[str, List[JobResult]]:
        """Get all results organized by status.

        Returns:
            {
                "completed": [JobResult],
                "failed": [JobResult],
                "cancelled": [JobResult]
            }
        """
        return {
            "completed": list(self.completed_results.values()),
            "failed": self._create_failed_results(),
            "cancelled": self._create_cancelled_results()
        }
Get all results organized by status.
Returns:
{
    "completed": [JobResult],
    "failed": [JobResult],
    "cancelled": [JobResult]
}
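Because `results()` returns a dictionary keyed by status, post-processing usually means walking each list in turn. The sketch below assumes only that shape; `ResultStub` is a hypothetical stand-in carrying three of the `JobResult` fields so the example runs without the library.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ResultStub:
    """Hypothetical stand-in exposing only job_id, cost_usd, and error."""
    job_id: str
    cost_usd: float = 0.0
    error: Optional[str] = None


def summarize_results(results: Dict[str, List[ResultStub]]) -> Dict[str, object]:
    """Count results per status, total the cost, and collect failure messages."""
    counts = {status: len(items) for status, items in results.items()}
    total_cost = sum(r.cost_usd for items in results.values() for r in items)
    errors = {r.job_id: r.error for r in results["failed"]}
    return {"counts": counts, "total_cost_usd": total_cost, "errors": errors}


# Sample data in the same shape as the documented return value.
sample = {
    "completed": [ResultStub("a", 0.01), ResultStub("b", 0.02)],
    "failed": [ResultStub("c", error="timeout")],
    "cancelled": [],
}
summary = summarize_results(sample)
print(summary["counts"], summary["errors"])
```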
    def get_failed_jobs(self) -> Dict[str, str]:
        """Get failed jobs with error messages.

        Note: This method is deprecated. Use results()['failed'] instead.
        """
        return dict(self.failed_jobs)
Get failed jobs with error messages.
Note: This method is deprecated. Use results()['failed'] instead.
    def dry_run(self) -> 'BatchRun':
        """Perform a dry run - show cost estimation and job details without executing.

        Returns:
            Self for chaining (doesn't actually execute jobs)
        """
        logger.info("=== DRY RUN MODE ===")
        logger.info("This will show cost estimates without executing jobs")

        # Load existing state if reuse_state=True
        if self.config.reuse_state:
            self.state_manager.load_state(self)

        # Filter out completed jobs from previous runs
        self.pending_jobs = [job for job in self.jobs.values() if job.id not in self.completed_results]

        if not self.pending_jobs:
            logger.info("No pending jobs to analyze (all jobs already completed)")
            return self

        logger.info(f"Analyzing {len(self.pending_jobs)} pending jobs...")

        # Group jobs by provider and analyze costs
        provider_groups = self._group_jobs_by_provider()
        total_estimated_cost = 0.0

        logger.info(f"\nJob breakdown:")
        for provider_name, jobs in provider_groups.items():
            provider = get_provider(jobs[0].model)
            logger.info(f"\n{provider_name} ({len(jobs)} jobs):")

            job_batches = [jobs[i:i + self.config.items_per_batch]
                           for i in range(0, len(jobs), self.config.items_per_batch)]

            for batch_idx, batch_jobs in enumerate(job_batches, 1):
                estimated_cost = provider.estimate_cost(batch_jobs)
                total_estimated_cost += estimated_cost

                logger.info(f" Batch {batch_idx}: {len(batch_jobs)} jobs, estimated cost: ${estimated_cost:.4f}")
                for job in batch_jobs:
                    if job.file:
                        logger.info(f" - {job.id}: {job.file.name} (citations: {job.enable_citations})")
                    else:
                        logger.info(f" - {job.id}: direct messages (citations: {job.enable_citations})")

        # Show cost summary
        logger.info(f"\n=== COST SUMMARY ===")
        logger.info(f"Total estimated cost: ${total_estimated_cost:.4f}")

        if self.config.cost_limit_usd:
            logger.info(f"Cost limit: ${self.config.cost_limit_usd:.2f}")
            if total_estimated_cost > self.config.cost_limit_usd:
                excess = total_estimated_cost - self.config.cost_limit_usd
                logger.warning(f"⚠️ Estimated cost exceeds limit by ${excess:.4f}")
            else:
                remaining = self.config.cost_limit_usd - total_estimated_cost
                logger.info(f"✅ Within cost limit (${remaining:.4f} remaining)")
        else:
            logger.info("No cost limit set")

        # Show execution plan
        logger.info(f"\n=== EXECUTION PLAN ===")
        total_batches = sum(
            len(jobs) // self.config.items_per_batch + (1 if len(jobs) % self.config.items_per_batch else 0)
            for jobs in provider_groups.values()
        )
        logger.info(f"Total batches to process: {total_batches}")
        logger.info(f"Max parallel batches: {self.config.max_parallel_batches}")
        logger.info(f"Items per batch: {self.config.items_per_batch}")
        logger.info(f"Results directory: {self.config.results_dir}")

        logger.info("\n=== DRY RUN COMPLETE ===")
        logger.info("To execute for real, call run() without dry_run=True")

        return self
Perform a dry run - show cost estimation and job details without executing.
Returns: Self for chaining (doesn't actually execute jobs)
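The arithmetic a dry run reports can be reproduced by hand. The sketch below, with the hypothetical helper `plan_summary`, computes the batch count per provider as a ceiling division (equivalent to the floor-plus-remainder expression in the listing) and compares an already known cost estimate against the limit. It does not call any provider; the job counts and dollar figures are made-up inputs.

```python
import math
from typing import Dict, Optional


def plan_summary(jobs_per_provider: Dict[str, int], items_per_batch: int,
                 estimated_cost_usd: float,
                 cost_limit_usd: Optional[float]) -> Dict[str, object]:
    """Summarise batch count and cost-limit status for a planned run."""
    total_batches = sum(
        math.ceil(count / items_per_batch) for count in jobs_per_provider.values()
    )
    if not cost_limit_usd:
        # Mirrors the truthiness check in the listing: None or 0 means no limit.
        verdict = "no limit set"
    elif estimated_cost_usd > cost_limit_usd:
        verdict = f"exceeds limit by ${estimated_cost_usd - cost_limit_usd:.4f}"
    else:
        verdict = f"within limit (${cost_limit_usd - estimated_cost_usd:.4f} remaining)"
    return {"total_batches": total_batches, "verdict": verdict}


plan = plan_summary({"anthropic": 5, "openai": 4}, items_per_batch=2,
                    estimated_cost_usd=1.5, cost_limit_usd=1.0)
print(plan)
```

Batches are counted per provider, so five and four jobs at two per batch give 3 + 2 = 5 batches rather than ceil(9 / 2).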
@dataclass
class Job:
    """Configuration for a single AI job.

    Either provide messages OR prompt (with optional file), not both.

    Attributes:
        id: Unique identifier for the job
        messages: Chat messages for direct message input
        file: Optional file path for file-based input
        prompt: Prompt text (can be used alone or with file)
        model: Model name (e.g., "claude-3-sonnet")
        temperature: Sampling temperature (0.0-1.0)
        max_tokens: Maximum tokens to generate
        response_model: Pydantic model for structured output
        enable_citations: Whether to extract citations from response
    """

    id: str  # Unique identifier
    model: str  # Model name (e.g., "claude-3-sonnet")
    messages: Optional[List[Message]] = None  # Chat messages
    file: Optional[Path] = None  # File input
    prompt: Optional[str] = None  # Prompt for file
    temperature: float = 0.7
    max_tokens: int = 1000
    response_model: Optional[Type[BaseModel]] = None  # For structured output
    enable_citations: bool = False

    def __post_init__(self):
        """Validate job configuration."""
        if self.messages and (self.file or self.prompt):
            raise ValueError("Provide either messages OR file+prompt, not both")

        if self.file and not self.prompt:
            raise ValueError("File input requires a prompt")

        if not self.messages and not self.prompt:
            raise ValueError("Must provide either messages or prompt")

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for state persistence."""
        return {
            "id": self.id,
            "model": self.model,
            "messages": self.messages,
            "file": str(self.file) if self.file else None,
            "prompt": self.prompt,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "response_model": self.response_model.__name__ if self.response_model else None,
            "enable_citations": self.enable_citations
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Job':
        """Deserialize from state."""
        # Convert file string back to Path if present
        file_path = None
        if data.get("file"):
            file_path = Path(data["file"])

        # Note: response_model reconstruction would need additional logic
        # For now, we'll set it to None during deserialization
        return cls(
            id=data["id"],
            model=data["model"],
            messages=data.get("messages"),
            file=file_path,
            prompt=data.get("prompt"),
            temperature=data.get("temperature", 0.7),
            max_tokens=data.get("max_tokens", 1000),
            response_model=None,  # Cannot reconstruct from string
            enable_citations=data.get("enable_citations", False)
        )
Configuration for a single AI job.
Either provide messages OR prompt (with optional file), not both.
Attributes:
- id: Unique identifier for the job
- messages: Chat messages for direct message input
- file: Optional file path for file-based input
- prompt: Prompt text (can be used alone or with file)
- model: Model name (e.g., "claude-3-sonnet")
- temperature: Sampling temperature (0.0-1.0)
- max_tokens: Maximum tokens to generate
- response_model: Pydantic model for structured output
- enable_citations: Whether to extract citations from response
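The three input rules enforced in `__post_init__` are easiest to see side by side. The sketch below copies those checks into a standalone function so each combination can be tried without constructing a real `Job`; `validate_job_inputs` and `is_valid` are hypothetical names, and plain dictionaries stand in for the `Message` type, whose definition is not shown here.

```python
from pathlib import Path
from typing import List, Optional


def validate_job_inputs(messages: Optional[List[dict]] = None,
                        file: Optional[Path] = None,
                        prompt: Optional[str] = None) -> None:
    """Apply the same three checks as Job.__post_init__."""
    if messages and (file or prompt):
        raise ValueError("Provide either messages OR file+prompt, not both")
    if file and not prompt:
        raise ValueError("File input requires a prompt")
    if not messages and not prompt:
        raise ValueError("Must provide either messages or prompt")


def is_valid(**kwargs) -> bool:
    """Return True when the combination passes validation."""
    try:
        validate_job_inputs(**kwargs)
        return True
    except ValueError:
        return False


chat = [{"role": "user", "content": "Hello"}]  # dict stands in for Message
print(is_valid(prompt="Summarize"))                        # prompt only
print(is_valid(file=Path("doc.pdf"), prompt="Summarize"))  # file with prompt
print(is_valid(messages=chat))                             # messages only
print(is_valid(file=Path("doc.pdf")))                      # file without prompt
print(is_valid(messages=chat, prompt="Summarize"))         # both input styles
```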
    def to_dict(self) -> Dict[str, Any]:
        """Serialize for state persistence."""
        return {
            "id": self.id,
            "model": self.model,
            "messages": self.messages,
            "file": str(self.file) if self.file else None,
            "prompt": self.prompt,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "response_model": self.response_model.__name__ if self.response_model else None,
            "enable_citations": self.enable_citations
        }
Serialize for state persistence.
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Job':
        """Deserialize from state."""
        # Convert file string back to Path if present
        file_path = None
        if data.get("file"):
            file_path = Path(data["file"])

        # Note: response_model reconstruction would need additional logic
        # For now, we'll set it to None during deserialization
        return cls(
            id=data["id"],
            model=data["model"],
            messages=data.get("messages"),
            file=file_path,
            prompt=data.get("prompt"),
            temperature=data.get("temperature", 0.7),
            max_tokens=data.get("max_tokens", 1000),
            response_model=None,  # Cannot reconstruct from string
            enable_citations=data.get("enable_citations", False)
        )
Deserialize from state.
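Since `to_dict()` stores only the class name of `response_model` and `from_dict()` always sets it to `None`, a job restored from state loses its structured-output model. One way a caller might reattach it is a name-to-class lookup kept in application code. The sketch below is not part of the library: `MODEL_REGISTRY` and `restore_response_model` are hypothetical, and a plain class stands in for a Pydantic model.

```python
from typing import Dict, Optional, Type


class DocumentAnalysis:
    """Stand-in for a Pydantic model class; only its name matters here."""


# Application-level lookup from saved class name to class (hypothetical).
MODEL_REGISTRY: Dict[str, Type] = {DocumentAnalysis.__name__: DocumentAnalysis}


def restore_response_model(name: Optional[str]) -> Optional[Type]:
    """Return the registered class for a saved name, or None if unknown."""
    if name is None:
        return None
    return MODEL_REGISTRY.get(name)


# The saved value has the shape Job.to_dict() produces for response_model.
saved = {"id": "job-1", "response_model": DocumentAnalysis.__name__}
restored = restore_response_model(saved["response_model"])
print(restored)
```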
@dataclass
class JobResult:
    """Result from a completed AI job.

    Attributes:
        job_id: ID of the job this result is for
        raw_response: Raw text response from the model (None for failed jobs)
        parsed_response: Structured output (if response_model was used)
        citations: Extracted citations (if enable_citations was True)
        citation_mappings: Maps field names to relevant citations (if response_model used)
        input_tokens: Number of input tokens used
        output_tokens: Number of output tokens generated
        cost_usd: Total cost in USD
        error: Error message if job failed
        batch_id: ID of the batch this job was part of (for mapping to raw files)
    """

    job_id: str
    raw_response: Optional[str] = None  # Raw text response (None for failed jobs)
    parsed_response: Optional[Union[BaseModel, Dict]] = None  # Structured output or error dict
    citations: Optional[List[Citation]] = None  # Extracted citations
    citation_mappings: Optional[Dict[str, List[Citation]]] = None  # Field -> citations mapping
    input_tokens: int = 0
    output_tokens: int = 0
    cost_usd: float = 0.0
    error: Optional[str] = None  # Error message if failed
    batch_id: Optional[str] = None  # Batch ID for mapping to raw files

    @property
    def is_success(self) -> bool:
        """Whether the job completed successfully."""
        return self.error is None

    @property
    def total_tokens(self) -> int:
        """Total tokens used (input + output)."""
        return self.input_tokens + self.output_tokens

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for state persistence."""
        # Handle parsed_response serialization
        parsed_response = None
        if self.parsed_response is not None:
            if isinstance(self.parsed_response, dict):
                parsed_response = self.parsed_response
            elif isinstance(self.parsed_response, BaseModel):
                parsed_response = self.parsed_response.model_dump()
            else:
                parsed_response = str(self.parsed_response)

        # Handle citation_mappings serialization
        citation_mappings = None
        if self.citation_mappings:
            citation_mappings = {
                field: [asdict(c) for c in citations]
                for field, citations in self.citation_mappings.items()
            }

        return {
            "job_id": self.job_id,
            "raw_response": self.raw_response,
            "parsed_response": parsed_response,
            "citations": [asdict(c) for c in self.citations] if self.citations else None,
            "citation_mappings": citation_mappings,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cost_usd": self.cost_usd,
            "error": self.error,
            "batch_id": self.batch_id
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'JobResult':
        """Deserialize from state."""
        # Reconstruct citations if present
        citations = None
        if data.get("citations"):
            citations = [Citation(**c) for c in data["citations"]]

        # Reconstruct citation_mappings if present
        citation_mappings = None
        if data.get("citation_mappings"):
            citation_mappings = {
                field: [Citation(**c) for c in citations]
                for field, citations in data["citation_mappings"].items()
            }

        return cls(
            job_id=data["job_id"],
            raw_response=data["raw_response"],
            parsed_response=data.get("parsed_response"),
            citations=citations,
            citation_mappings=citation_mappings,
            input_tokens=data.get("input_tokens", 0),
            output_tokens=data.get("output_tokens", 0),
            cost_usd=data.get("cost_usd", 0.0),
            error=data.get("error"),
            batch_id=data.get("batch_id")
        )
Result from a completed AI job.
Attributes:
- job_id: ID of the job this result is for
- raw_response: Raw text response from the model (None for failed jobs)
- parsed_response: Structured output (if response_model was used)
- citations: Extracted citations (if enable_citations was True)
- citation_mappings: Maps field names to relevant citations (if response_model used)
- input_tokens: Number of input tokens used
- output_tokens: Number of output tokens generated
- cost_usd: Total cost in USD
- error: Error message if job failed
- batch_id: ID of the batch this job was part of (for mapping to raw files)
    @property
    def is_success(self) -> bool:
        """Whether the job completed successfully."""
        return self.error is None
Whether the job completed successfully.
    @property
    def total_tokens(self) -> int:
        """Total tokens used (input + output)."""
        return self.input_tokens + self.output_tokens
Total tokens used (input + output).
    def to_dict(self) -> Dict[str, Any]:
        """Serialize for state persistence."""
        # Handle parsed_response serialization
        parsed_response = None
        if self.parsed_response is not None:
            if isinstance(self.parsed_response, dict):
                parsed_response = self.parsed_response
            elif isinstance(self.parsed_response, BaseModel):
                parsed_response = self.parsed_response.model_dump()
            else:
                parsed_response = str(self.parsed_response)

        # Handle citation_mappings serialization
        citation_mappings = None
        if self.citation_mappings:
            citation_mappings = {
                field: [asdict(c) for c in citations]
                for field, citations in self.citation_mappings.items()
            }

        return {
            "job_id": self.job_id,
            "raw_response": self.raw_response,
            "parsed_response": parsed_response,
            "citations": [asdict(c) for c in self.citations] if self.citations else None,
            "citation_mappings": citation_mappings,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cost_usd": self.cost_usd,
            "error": self.error,
            "batch_id": self.batch_id
        }
Serialize for state persistence.
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'JobResult':
        """Deserialize from state."""
        # Reconstruct citations if present
        citations = None
        if data.get("citations"):
            citations = [Citation(**c) for c in data["citations"]]

        # Reconstruct citation_mappings if present
        citation_mappings = None
        if data.get("citation_mappings"):
            citation_mappings = {
                field: [Citation(**c) for c in citations]
                for field, citations in data["citation_mappings"].items()
            }

        return cls(
            job_id=data["job_id"],
            raw_response=data["raw_response"],
            parsed_response=data.get("parsed_response"),
            citations=citations,
            citation_mappings=citation_mappings,
            input_tokens=data.get("input_tokens", 0),
            output_tokens=data.get("output_tokens", 0),
            cost_usd=data.get("cost_usd", 0.0),
            error=data.get("error"),
            batch_id=data.get("batch_id")
        )
Deserialize from state.
@dataclass
class Citation:
    """Represents a citation extracted from an AI response."""

    text: str  # The cited text
    source: str  # Source identifier (e.g., page number, section)
    page: Optional[int] = None  # Page number if applicable
    metadata: Optional[Dict[str, Any]] = None  # Additional metadata
Represents a citation extracted from an AI response.
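`JobResult` serializes citations with `dataclasses.asdict` and rebuilds them with `Citation(**data)`. The round trip for a field-to-citations mapping can be sketched as follows. The `Citation` class here is a local copy of the four fields listed above so the example runs standalone; `mappings_to_dict`, `mappings_from_dict`, and the sample citation are hypothetical.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional


@dataclass
class Citation:
    """Local copy of the documented fields, for a standalone example."""
    text: str
    source: str
    page: Optional[int] = None
    metadata: Optional[Dict[str, Any]] = None


def mappings_to_dict(mappings: Dict[str, List[Citation]]) -> Dict[str, List[Dict[str, Any]]]:
    """Convert each Citation to a plain dict, keyed by response field."""
    return {field: [asdict(c) for c in cites] for field, cites in mappings.items()}


def mappings_from_dict(data: Dict[str, List[Dict[str, Any]]]) -> Dict[str, List[Citation]]:
    """Rebuild Citation objects from their dict form."""
    return {field: [Citation(**c) for c in cites] for field, cites in data.items()}


original = {"summary": [Citation(text="Revenue grew 12%", source="report.pdf", page=3)]}
serialized = mappings_to_dict(original)
restored = mappings_from_dict(serialized)
print(serialized)
```

Dataclass equality compares field values, so `restored == original` holds after the round trip.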
Base exception for all Batchata errors.
class CostLimitExceededError(BatchataError):
    """Raised when cost limit would be exceeded."""
    pass
Raised when cost limit would be exceeded.
Base exception for provider-related errors.
class ProviderNotFoundError(ProviderError):
    """Raised when no provider is found for a model."""
    pass
Raised when no provider is found for a model.
class ValidationError(BatchataError):
    """Raised when job or configuration validation fails."""
    pass
Raised when job or configuration validation fails.
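The exception classes form a small hierarchy, so callers can catch narrowly or broadly. The sketch below redefines stand-in classes with the same names so it runs without the library. The parent of `ProviderError` is not shown in the listings above; deriving it from `BatchataError` is an assumption, and `describe_failure` is a hypothetical helper.

```python
class BatchataError(Exception):
    """Stand-in for the documented base exception."""


class CostLimitExceededError(BatchataError):
    """Stand-in: raised when cost limit would be exceeded."""


class ProviderError(BatchataError):
    """Stand-in; the BatchataError parent is assumed, not confirmed."""


class ProviderNotFoundError(ProviderError):
    """Stand-in: raised when no provider is found for a model."""


class ValidationError(BatchataError):
    """Stand-in: raised when job or configuration validation fails."""


def describe_failure(exc: Exception) -> str:
    """Map an exception to a coarse category, most specific checks first."""
    if isinstance(exc, CostLimitExceededError):
        return "budget"
    if isinstance(exc, ProviderError):
        return "provider"
    if isinstance(exc, ValidationError):
        return "validation"
    if isinstance(exc, BatchataError):
        return "batchata"
    return "unexpected"


for error in (ProviderNotFoundError("no provider"), ValidationError("bad job"), KeyError("x")):
    print(type(error).__name__, "->", describe_failure(error))
```

Checking the subclasses before the base class matters: a `ProviderNotFoundError` is also a `BatchataError`, so the broad check has to come last.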