batchata
Batchata - Unified Python API for AI Batch requests with cost tracking, Pydantic responses, and parallel execution.
Why AI-batching?
AI providers offer batch APIs that process requests asynchronously at 50% reduced cost compared to real-time APIs. This is ideal for workloads like document processing, data analysis, and content generation where immediate responses aren't required.
Quick Start
Installation
pip install batchata
Basic Usage
from batchata import Batch

# Simple batch processing
batch = (
    Batch(results_dir="./output")
    .set_default_params(model="claude-sonnet-4-20250514")
    .add_cost_limit(usd=5.0)
)

# Add jobs (`files` is assumed to be an iterable of file paths defined elsewhere)
for file in files:
    batch.add_job(file=file, prompt="Summarize this document")

# Execute
run = batch.run()
results = run.results()
Structured Output with Pydantic
from batchata import Batch
from pydantic import BaseModel


class DocumentAnalysis(BaseModel):
    title: str
    summary: str
    key_points: list[str]


batch = (
    Batch(results_dir="./results")
    .set_default_params(model="claude-sonnet-4-20250514")
)

batch.add_job(
    file="document.pdf",
    prompt="Analyze this document",
    response_model=DocumentAnalysis,
    enable_citations=True,  # Anthropic only
)

run = batch.run()
for result in run.results()["completed"]:
    analysis = result.parsed_response  # DocumentAnalysis object
    citations = result.citation_mappings  # Field -> Citation mapping
Key Features
- 50% Cost Savings: Native batch processing via provider APIs
- Cost Limits: Set spending limits for batch requests with .add_cost_limit()
- Time Limits: Control execution time with .add_time_limit()
- State Persistence: Resume interrupted batches automatically
- Structured Output: Pydantic models with automatic validation
- Citations: Extract and map citations to response fields (Anthropic)
- Multiple Providers: Anthropic Claude and OpenAI GPT models
Supported Providers
| Feature | Anthropic | OpenAI |
|---|---|---|
| Models | All Claude models | All GPT models |
| Citations | ✅ | ❌ |
| Structured Output | ✅ | ✅ |
| File Types | PDF, TXT, DOCX, Images | PDF, Images |
Configuration
Set API keys as environment variables:
export ANTHROPIC_API_KEY="your-key"
export OPENAI_API_KEY="your-key"
Or use a .env file with python-dotenv.
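Before submitting a batch, it can help to fail fast when a key is missing. The following is a minimal sketch using only the standard library; the helper name `missing_api_keys` is hypothetical and not part of batchata, and the assumption that both keys are wanted can be relaxed by passing a narrower `required` tuple.

```python
import os

# Environment variable names taken from the configuration section above.
REQUIRED_KEYS = ("ANTHROPIC_API_KEY", "OPENAI_API_KEY")


def missing_api_keys(env=None, required=REQUIRED_KEYS):
    """Return the names of required keys that are unset or empty.

    Hypothetical helper for illustration; not a batchata function.
    """
    env = os.environ if env is None else env
    return [name for name in required if not env.get(name)]


if __name__ == "__main__":
    missing = missing_api_keys()
    if missing:
        print("Missing API keys:", ", ".join(missing))
```

If you only use one provider, call it as `missing_api_keys(required=("ANTHROPIC_API_KEY",))`.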
"""Batchata - Unified Python API for AI Batch requests with cost tracking, Pydantic responses, and parallel execution.

**Why AI-batching?**

AI providers offer batch APIs that process requests asynchronously at 50% reduced cost compared to real-time APIs.
This is ideal for workloads like document processing, data analysis, and content generation where immediate
responses aren't required.

## Quick Start

### Installation

```bash
pip install batchata
```

### Basic Usage

```python
from batchata import Batch

# Simple batch processing
batch = (
    Batch(results_dir="./output")
    .set_default_params(model="claude-sonnet-4-20250514")
    .add_cost_limit(usd=5.0)
)

# Add jobs
for file in files:
    batch.add_job(file=file, prompt="Summarize this document")

# Execute
run = batch.run()
results = run.results()
```

### Structured Output with Pydantic

```python
from batchata import Batch
from pydantic import BaseModel

class DocumentAnalysis(BaseModel):
    title: str
    summary: str
    key_points: list[str]

batch = (
    Batch(results_dir="./results")
    .set_default_params(model="claude-sonnet-4-20250514")
)

batch.add_job(
    file="document.pdf",
    prompt="Analyze this document",
    response_model=DocumentAnalysis,
    enable_citations=True  # Anthropic only
)

run = batch.run()
for result in run.results()["completed"]:
    analysis = result.parsed_response  # DocumentAnalysis object
    citations = result.citation_mappings  # Field -> Citation mapping
```

## Key Features

- **50% Cost Savings**: Native batch processing via provider APIs
- **Cost Limits**: Set spending limits for batch requests with `.add_cost_limit()`
- **Time Limits**: Control execution time with `.add_time_limit()`
- **State Persistence**: Resume interrupted batches automatically
- **Structured Output**: Pydantic models with automatic validation
- **Citations**: Extract and map citations to response fields (Anthropic)
- **Multiple Providers**: Anthropic Claude and OpenAI GPT models

## Supported Providers

| Feature | Anthropic | OpenAI |
|---------|-----------|--------|
| Models | [All Claude models](https://github.com/agamm/batchata/blob/main/batchata/providers/anthropic/models.py) | [All GPT models](https://github.com/agamm/batchata/blob/main/batchata/providers/openai/models.py) |
| Citations | ✅ | ❌ |
| Structured Output | ✅ | ✅ |
| File Types | PDF, TXT, DOCX, Images | PDF, Images |

## Configuration

Set API keys as environment variables:

```bash
export ANTHROPIC_API_KEY="your-key"
export OPENAI_API_KEY="your-key"
```

Or use a `.env` file with python-dotenv.
"""

from .core import Batch, BatchRun, Job, JobResult
from .exceptions import (
    BatchataError,
    CostLimitExceededError,
    ProviderError,
    ProviderNotFoundError,
    ValidationError,
)
from .types import Citation

__version__ = "0.3.0"

__all__ = [
    "Batch",
    "BatchRun",
    "Job",
    "JobResult",
    "Citation",
    "BatchataError",
    "CostLimitExceededError",
    "ProviderError",
    "ProviderNotFoundError",
    "ValidationError",
]
class Batch:
    """Builder for batch job configuration.

    Provides a fluent interface for configuring batch jobs with sensible defaults
    and validation. The batch can be configured with cost limits, default parameters,
    and progress callbacks.

    Example:
        ```python
        batch = (
            Batch("./results", max_parallel_batches=10, items_per_batch=10)
            .set_state(file="./state.json", reuse_state=True)
            .set_default_params(model="claude-sonnet-4-20250514", temperature=0.7)
            .add_cost_limit(usd=15.0)
            .add_job(messages=[{"role": "user", "content": "Hello"}])
            .add_job(file="./path/to/file.pdf", prompt="Generate summary of file")
        )

        run = batch.run()
        ```
    """

    def __init__(self, results_dir: str, max_parallel_batches: int = 10, items_per_batch: int = 10, raw_files: Optional[bool] = None):
        """Initialize batch configuration.

        Args:
            results_dir: Directory to store results
            max_parallel_batches: Maximum parallel batch requests
            items_per_batch: Number of jobs per provider batch
            raw_files: Whether to save debug files (raw responses, JSONL files) from providers (default: True if results_dir is set, False otherwise)
        """
        # Auto-determine raw_files based on results_dir if not explicitly set
        if raw_files is None:
            raw_files = bool(results_dir and results_dir.strip())

        self.config = BatchParams(
            state_file=None,
            results_dir=results_dir,
            max_parallel_batches=max_parallel_batches,
            items_per_batch=items_per_batch,
            reuse_state=True,
            raw_files=raw_files
        )
        self.jobs: List[Job] = []

    def set_default_params(self, **kwargs) -> 'Batch':
        """Set default parameters for all jobs.

        These defaults will be applied to all jobs unless overridden
        by job-specific parameters.

        Args:
            **kwargs: Default parameters (model, temperature, max_tokens, etc.)

        Returns:
            Self for chaining

        Example:
            ```python
            batch.set_default_params(model="claude-3-sonnet", temperature=0.7)
            ```
        """
        # Validate if model is provided
        if "model" in kwargs:
            self.config.validate_default_params(kwargs["model"])

        self.config.default_params.update(kwargs)
        return self

    def set_state(self, file: Optional[str] = None, reuse_state: bool = True) -> 'Batch':
        """Set state file configuration.

        Args:
            file: Path to state file for persistence (default: None)
            reuse_state: Whether to resume from existing state file (default: True)

        Returns:
            Self for chaining

        Example:
            ```python
            batch.set_state(file="./state.json", reuse_state=True)
            ```
        """
        self.config.state_file = file
        self.config.reuse_state = reuse_state
        return self

    def add_cost_limit(self, usd: float) -> 'Batch':
        """Add cost limit for the batch.

        The batch will stop accepting new jobs once the cost limit is reached.
        Active jobs will be allowed to complete.

        Args:
            usd: Cost limit in USD

        Returns:
            Self for chaining

        Example:
            ```python
            batch.add_cost_limit(usd=50.0)
            ```
        """
        if usd <= 0:
            raise ValueError("Cost limit must be positive")
        self.config.cost_limit_usd = usd
        return self

    def raw_files(self, enabled: bool = True) -> 'Batch':
        """Enable or disable saving debug files from providers.

        When enabled, debug files (raw API responses, JSONL files) will be saved
        in a 'raw_files' subdirectory within the results directory.
        This is useful for debugging, auditing, or accessing provider-specific metadata.

        Args:
            enabled: Whether to save debug files (default: True)

        Returns:
            Self for chaining

        Example:
            ```python
            batch.raw_files(True)
            ```
        """
        self.config.raw_files = enabled
        return self

    def set_verbosity(self, level: str) -> 'Batch':
        """Set logging verbosity level.

        Args:
            level: Verbosity level ("debug", "info", "warn", "error")

        Returns:
            Self for chaining

        Example:
            ```python
            batch.set_verbosity("error")  # For production
            batch.set_verbosity("debug")  # For debugging
            ```
        """
        valid_levels = {"debug", "info", "warn", "error"}
        if level.lower() not in valid_levels:
            raise ValueError(f"Invalid verbosity level: {level}. Must be one of {valid_levels}")
        self.config.verbosity = level.lower()
        return self

    def add_time_limit(self, seconds: Optional[float] = None, minutes: Optional[float] = None, hours: Optional[float] = None) -> 'Batch':
        """Add time limit for the entire batch execution.

        When time limit is reached, all active provider batches are cancelled and
        remaining unprocessed jobs are marked as failed. The batch execution
        completes normally without throwing exceptions.

        Args:
            seconds: Time limit in seconds (optional)
            minutes: Time limit in minutes (optional)
            hours: Time limit in hours (optional)

        Returns:
            Self for chaining

        Raises:
            ValueError: If no time units specified, or if total time is outside
                valid range (min: 10 seconds, max: 24 hours)

        Note:
            - Can combine multiple time units
            - Time limit is checked every second by a background watchdog thread
            - Jobs that exceed time limit appear in results()["failed"] with time limit error message
            - No exceptions are thrown when time limit is reached

        Example:
            ```python
            batch.add_time_limit(seconds=30)  # 30 seconds
            batch.add_time_limit(minutes=5)  # 5 minutes
            batch.add_time_limit(hours=2)  # 2 hours
            batch.add_time_limit(hours=1, minutes=30, seconds=15)  # 5415 seconds total
            ```
        """
        time_limit_seconds = 0.0

        if seconds is not None:
            time_limit_seconds += seconds
        if minutes is not None:
            time_limit_seconds += minutes * 60
        if hours is not None:
            time_limit_seconds += hours * 3600

        if time_limit_seconds == 0:
            raise ValueError("Must specify at least one of seconds, minutes, or hours")

        self.config.time_limit_seconds = time_limit_seconds
        return self

    def add_job(
        self,
        messages: Optional[List[Message]] = None,
        file: Optional[Union[str, Path]] = None,
        prompt: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        response_model: Optional[Type[BaseModel]] = None,
        enable_citations: bool = False,
        **kwargs
    ) -> 'Batch':
        """Add a job to the batch.

        Either provide messages OR file+prompt, not both. Parameters not provided
        will use the defaults set via the set_default_params() method.

        Args:
            messages: Chat messages for direct input
            file: File path for file-based input
            prompt: Prompt to use with file input
            model: Model to use (overrides default)
            temperature: Sampling temperature (overrides default)
            max_tokens: Max tokens to generate (overrides default)
            response_model: Pydantic model for structured output
            enable_citations: Whether to extract citations
            **kwargs: Additional parameters

        Returns:
            Self for chaining

        Example:
            ```python
            batch.add_job(
                messages=[{"role": "user", "content": "Hello"}],
                model="gpt-4"
            )
            ```
        """
        # Generate unique job ID
        job_id = f"job-{uuid.uuid4().hex[:8]}"

        # Merge with defaults
        params = self.config.default_params.copy()

        # Update with provided parameters
        if model is not None:
            params["model"] = model
        if temperature is not None:
            params["temperature"] = temperature
        if max_tokens is not None:
            params["max_tokens"] = max_tokens

        # Add other kwargs
        params.update(kwargs)

        # Ensure model is provided
        if "model" not in params:
            raise ValueError("Model must be provided either in defaults or job parameters")

        # Validate parameters
        provider = get_provider(params["model"])
        # Extract params without model to avoid duplicate
        param_subset = {k: v for k, v in params.items() if k != "model"}
        provider.validate_params(params["model"], **param_subset)

        # Convert file path if string
        if isinstance(file, str):
            file = Path(file)

        # Create job
        job = Job(
            id=job_id,
            messages=messages,
            file=file,
            prompt=prompt,
            response_model=response_model,
            enable_citations=enable_citations,
            **params
        )

        # Validate citation compatibility
        if response_model and enable_citations:
            from ..utils.validation import validate_flat_model
            validate_flat_model(response_model)

        # Validate job with provider (includes PDF validation for Anthropic)
        provider.validate_job(job)

        self.jobs.append(job)
        return self

    def run(self, on_progress: Optional[Callable[[Dict, float, Dict], None]] = None, progress_interval: float = 1.0, print_status: bool = False) -> 'BatchRun':
        """Execute the batch.

        Creates a BatchRun instance and executes the jobs synchronously.

        Args:
            on_progress: Optional progress callback function that receives
                (stats_dict, elapsed_time_seconds, batch_data)
            progress_interval: Interval in seconds between progress updates (default: 1.0)
            print_status: Whether to show rich progress display (default: False)

        Returns:
            BatchRun instance with completed results

        Raises:
            ValueError: If no jobs have been added
        """
        if not self.jobs:
            raise ValueError("No jobs added to batch")

        # Import here to avoid circular dependency
        from .batch_run import BatchRun

        # Create and start the run
        run = BatchRun(self.config, self.jobs)

        # Set progress callback - either rich display or custom callback
        if print_status:
            return self._run_with_rich_display(run, progress_interval)
        else:
            return self._run_with_custom_callback(run, on_progress, progress_interval)

    def _run_with_rich_display(self, run: 'BatchRun', progress_interval: float) -> 'BatchRun':
        """Execute batch run with rich progress display.

        Args:
            run: BatchRun instance to execute
            progress_interval: Interval between progress updates

        Returns:
            Completed BatchRun instance
        """
        from ..utils.rich_progress import RichBatchProgressDisplay
        display = RichBatchProgressDisplay()

        def rich_progress_callback(stats, elapsed_time, batch_data):
            # Start display on first call
            if not hasattr(rich_progress_callback, '_started'):
                config_dict = {
                    'results_dir': self.config.results_dir,
                    'state_file': self.config.state_file,
                    'items_per_batch': self.config.items_per_batch,
                    'max_parallel_batches': self.config.max_parallel_batches
                }
                display.start(stats, config_dict)
                rich_progress_callback._started = True

            # Update display
            display.update(stats, batch_data, elapsed_time)

        run.set_on_progress(rich_progress_callback, interval=progress_interval)

        # Execute with proper cleanup
        try:
            run.execute()

            # Show final status with all batches completed
            stats = run.status()
            display.update(stats, run.batch_tracking, (datetime.now() - run._start_time).total_seconds())

            # Small delay to ensure display updates
            import time
            time.sleep(0.2)

        except KeyboardInterrupt:
            # Update batch tracking to show cancelled status for pending/running batches
            with run._state_lock:
                for batch_id, batch_info in run.batch_tracking.items():
                    if batch_info['status'] == 'running':
                        batch_info['status'] = 'cancelled'
                    elif batch_info['status'] == 'pending':
                        batch_info['status'] = 'cancelled'

            # Show final status with cancelled batches
            stats = run.status()
            display.update(stats, run.batch_tracking, 0.0)

            # Add a small delay to ensure the display updates
            import time
            time.sleep(0.1)

            display.stop()
            raise
        finally:
            if display.live:  # Only stop if not already stopped
                display.stop()

        return run

    def _run_with_custom_callback(self, run: 'BatchRun', on_progress: Optional[Callable[[Dict, float, Dict], None]], progress_interval: float) -> 'BatchRun':
        """Execute batch run with custom progress callback.

        Args:
            run: BatchRun instance to execute
            on_progress: Optional custom progress callback
            progress_interval: Interval between progress updates

        Returns:
            Completed BatchRun instance
        """
        # Use custom progress callback if provided
        if on_progress:
            run.set_on_progress(on_progress, interval=progress_interval)

        run.execute()
        return run

    def __len__(self) -> int:
        """Get the number of jobs in the batch."""
        return len(self.jobs)

    def __repr__(self) -> str:
        """String representation of the batch."""
        return (
            f"Batch(jobs={len(self.jobs)}, "
            f"max_parallel_batches={self.config.max_parallel_batches}, "
            f"cost_limit=${self.config.cost_limit_usd or 'None'})"
        )
Builder for batch job configuration.
Provides a fluent interface for configuring batch jobs with sensible defaults and validation. The batch can be configured with cost limits, default parameters, and progress callbacks.
Example:
batch = (
    Batch("./results", max_parallel_batches=10, items_per_batch=10)
    .set_state(file="./state.json", reuse_state=True)
    .set_default_params(model="claude-sonnet-4-20250514", temperature=0.7)
    .add_cost_limit(usd=15.0)
    .add_job(messages=[{"role": "user", "content": "Hello"}])
    .add_job(file="./path/to/file.pdf", prompt="Generate summary of file")
)

run = batch.run()
def __init__(self, results_dir: str, max_parallel_batches: int = 10, items_per_batch: int = 10, raw_files: Optional[bool] = None)
Initialize batch configuration.
Args:
    results_dir: Directory to store results
    max_parallel_batches: Maximum parallel batch requests
    items_per_batch: Number of jobs per provider batch
    raw_files: Whether to save debug files (raw responses, JSONL files) from providers (default: True if results_dir is set, False otherwise)
def set_default_params(self, **kwargs) -> 'Batch'
Set default parameters for all jobs.
These defaults will be applied to all jobs unless overridden by job-specific parameters.
Args: **kwargs: Default parameters (model, temperature, max_tokens, etc.)
Returns: Self for chaining
Example:
batch.set_default_params(model="claude-3-sonnet", temperature=0.7)
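The override order can be illustrated without the library. The sketch below mirrors the merging steps visible in the `add_job` source (copy the defaults, apply explicit non-None arguments, then apply extra keyword arguments); `merge_params` itself is a hypothetical stand-in, not a batchata function.

```python
def merge_params(defaults, model=None, temperature=None, max_tokens=None, **kwargs):
    """Layer job-specific parameters over shared defaults (illustrative only)."""
    params = dict(defaults)  # copy so the shared defaults are never mutated

    # Explicit arguments win over defaults, but only when actually provided.
    overrides = {"model": model, "temperature": temperature, "max_tokens": max_tokens}
    params.update({key: value for key, value in overrides.items() if value is not None})

    # Any additional keyword arguments are applied last.
    params.update(kwargs)

    if "model" not in params:
        raise ValueError("Model must be provided either in defaults or job parameters")
    return params


defaults = {"model": "claude-3-sonnet", "temperature": 0.7}
job_params = merge_params(defaults, temperature=0.2, max_tokens=512)
print(job_params)
```

The job keeps the default model, replaces the default temperature, and gains a `max_tokens` value, while `defaults` stays unchanged for later jobs.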
def set_state(self, file: Optional[str] = None, reuse_state: bool = True) -> 'Batch'
Set state file configuration.
Args:
    file: Path to state file for persistence (default: None)
    reuse_state: Whether to resume from existing state file (default: True)
Returns: Self for chaining
Example:
batch.set_state(file="./state.json", reuse_state=True)
def add_cost_limit(self, usd: float) -> 'Batch'
Add cost limit for the batch.
The batch will stop accepting new jobs once the cost limit is reached. Active jobs will be allowed to complete.
Args: usd: Cost limit in USD
Returns: Self for chaining
Example:
batch.add_cost_limit(usd=50.0)
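The documented semantics (no new jobs once the limit is reached, but active jobs finish) imply that total spending can end slightly above the limit. The toy model below illustrates that behavior; `BudgetGate` and the per-job costs are invented for the example and say nothing about how batchata tracks or estimates cost internally.

```python
class BudgetGate:
    """Toy cost gate: refuses new work once spending reaches the limit."""

    def __init__(self, limit_usd):
        if limit_usd <= 0:
            raise ValueError("Cost limit must be positive")
        self.limit_usd = limit_usd
        self.spent_usd = 0.0

    def can_start(self):
        # New work is accepted only while spending is below the limit.
        return self.spent_usd < self.limit_usd

    def record(self, cost_usd):
        # Work that already started is always recorded, even if it overshoots.
        self.spent_usd += cost_usd


gate = BudgetGate(limit_usd=1.0)
job_costs = [0.4, 0.5, 0.3, 0.2]  # made-up per-job costs in USD
started = []
for cost in job_costs:
    if not gate.can_start():
        break  # limit reached: remaining jobs are never started
    started.append(cost)
    gate.record(cost)

print(f"started {len(started)} of {len(job_costs)} jobs, spent ${gate.spent_usd:.2f}")
```

The third job starts while spending is still below the limit and pushes the total past it; the fourth job is never started.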
def raw_files(self, enabled: bool = True) -> 'Batch'
Enable or disable saving debug files from providers.
When enabled, debug files (raw API responses, JSONL files) will be saved in a 'raw_files' subdirectory within the results directory. This is useful for debugging, auditing, or accessing provider-specific metadata.
Args: enabled: Whether to save debug files (default: True)
Returns: Self for chaining
Example:
batch.raw_files(True)
def set_verbosity(self, level: str) -> 'Batch'
Set logging verbosity level.
Args: level: Verbosity level ("debug", "info", "warn", "error")
Returns: Self for chaining
Example:
batch.set_verbosity("error") # For production
batch.set_verbosity("debug") # For debugging
def add_time_limit(self, seconds: Optional[float] = None, minutes: Optional[float] = None, hours: Optional[float] = None) -> 'Batch'
Add time limit for the entire batch execution.
When time limit is reached, all active provider batches are cancelled and remaining unprocessed jobs are marked as failed. The batch execution completes normally without throwing exceptions.
Args: seconds: Time limit in seconds (optional) minutes: Time limit in minutes (optional) hours: Time limit in hours (optional)
Returns: Self for chaining
Raises: ValueError: If no time units specified, or if total time is outside valid range (min: 10 seconds, max: 24 hours)
Note:
- Can combine multiple time units
- Time limit is checked every second by a background watchdog thread
- Jobs that exceed time limit appear in results()["failed"] with time limit error message
- No exceptions are thrown when time limit is reached
Example:
batch.add_time_limit(seconds=30) # 30 seconds
batch.add_time_limit(minutes=5) # 5 minutes
batch.add_time_limit(hours=2) # 2 hours
batch.add_time_limit(hours=1, minutes=30, seconds=15) # 5415 seconds total
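The unit arithmetic can be checked in isolation. The sketch below sums the units the same way the method's source does and adds the documented 10-second and 24-hour bounds; the method body shown on this page only rejects a zero total, so where the range check actually happens is an assumption, and `total_time_limit` is a hypothetical helper.

```python
MIN_SECONDS = 10         # documented minimum
MAX_SECONDS = 24 * 3600  # documented maximum (24 hours)


def total_time_limit(seconds=None, minutes=None, hours=None):
    """Combine time units into seconds and validate the documented range."""
    total = 0.0
    if seconds is not None:
        total += seconds
    if minutes is not None:
        total += minutes * 60
    if hours is not None:
        total += hours * 3600

    if total == 0:
        raise ValueError("Must specify at least one of seconds, minutes, or hours")
    # Assumed placement of the range check; bounds taken from the docstring.
    if not (MIN_SECONDS <= total <= MAX_SECONDS):
        raise ValueError(f"Time limit must be between {MIN_SECONDS} and {MAX_SECONDS} seconds")
    return total


print(total_time_limit(hours=1, minutes=30, seconds=15))
```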
def add_job(
    self,
    messages: Optional[List[Message]] = None,
    file: Optional[Union[str, Path]] = None,
    prompt: Optional[str] = None,
    model: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    response_model: Optional[Type[BaseModel]] = None,
    enable_citations: bool = False,
    **kwargs
) -> 'Batch'
Add a job to the batch.
Either provide messages OR file+prompt, not both. Parameters not provided will use the defaults set via the set_default_params() method.
Args:
    messages: Chat messages for direct input
    file: File path for file-based input
    prompt: Prompt to use with file input
    model: Model to use (overrides default)
    temperature: Sampling temperature (overrides default)
    max_tokens: Max tokens to generate (overrides default)
    response_model: Pydantic model for structured output
    enable_citations: Whether to extract citations
    **kwargs: Additional parameters
Returns: Self for chaining
Example:
batch.add_job(
    messages=[{"role": "user", "content": "Hello"}],
    model="gpt-4"
)
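The "messages OR file+prompt" rule can be expressed as a small pre-flight check. This is a sketch under assumptions: `check_job_input` is hypothetical, and the source shown on this page does not reveal where (or how strictly) batchata enforces the rule.

```python
def check_job_input(messages=None, file=None, prompt=None):
    """Classify a job's input mode, rejecting mixed or incomplete inputs."""
    has_messages = messages is not None
    has_file = file is not None

    if has_messages and (has_file or prompt is not None):
        raise ValueError("Provide messages OR file+prompt, not both")
    if has_file and prompt is None:
        raise ValueError("File input requires a prompt")
    if not has_messages and not has_file:
        raise ValueError("Provide either messages or file+prompt")
    return "messages" if has_messages else "file"


print(check_job_input(messages=[{"role": "user", "content": "Hello"}]))
print(check_job_input(file="document.pdf", prompt="Summarize this document"))
```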
def run(self, on_progress: Optional[Callable[[Dict, float, Dict], None]] = None, progress_interval: float = 1.0, print_status: bool = False) -> 'BatchRun'
Execute the batch.
Creates a BatchRun instance and executes the jobs synchronously.
Args:
    on_progress: Optional progress callback function that receives (stats_dict, elapsed_time_seconds, batch_data)
    progress_interval: Interval in seconds between progress updates (default: 1.0)
    print_status: Whether to show rich progress display (default: False)
Returns: BatchRun instance with completed results
Raises: ValueError: If no jobs have been added
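A progress callback only needs to accept the three documented positional arguments. The sketch below builds one that records a line per update; the contents of the stats and batch dictionaries are not documented here, so the sample values are made up and the callback treats both as opaque.

```python
def make_progress_logger(sink):
    """Return a callback with the documented (stats, elapsed, batch_data) shape."""

    def on_progress(stats, elapsed_time, batch_data):
        # Avoid relying on specific keys; just report what was received.
        sink.append(
            f"{elapsed_time:.1f}s elapsed, {len(batch_data)} batch(es) tracked, stats={stats}"
        )

    return on_progress


lines = []
callback = make_progress_logger(lines)

# Simulated update with made-up values, standing in for a call from the library.
callback({"completed": 1}, 2.5, {"batch-1": {"status": "running"}})
print(lines[0])
```

With a configured batch, the callback would be passed as `batch.run(on_progress=callback, progress_interval=5.0)`.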
class BatchRun:
    """Manages the execution of a batch job synchronously.

    Processes jobs in batches based on items_per_batch configuration.
    Simpler synchronous execution for clear logging and debugging.

    Example:
        ```python
        config = BatchParams(...)
        run = BatchRun(config, jobs)
        run.execute()
        results = run.results()
        ```
    """

    def __init__(self, config: BatchParams, jobs: List[Job]):
        """Initialize batch run.

        Args:
            config: Batch configuration
            jobs: List of jobs to execute
        """
        self.config = config
        self.jobs = {job.id: job for job in jobs}

        # Set logging level based on config
        set_log_level(level=config.verbosity.upper())

        # Initialize components
        self.cost_tracker = CostTracker(limit_usd=config.cost_limit_usd)

        # Use temp file for state if not provided
        state_file = config.state_file
        if not state_file:
            state_file = create_temp_state_file(config)
            config.reuse_state = False
            logger.info(f"Created temporary state file: {state_file}")

        self.state_manager = StateManager(state_file)

        # State tracking
        self.pending_jobs: List[Job] = []
        self.completed_results: Dict[str, JobResult] = {}  # job_id -> result
        self.failed_jobs: Dict[str, str] = {}  # job_id -> error
        self.cancelled_jobs: Dict[str, str] = {}  # job_id -> reason

        # Batch tracking
        self.total_batches = 0
        self.completed_batches = 0
        self.current_batch_index = 0
        self.current_batch_size = 0

        # Execution control
        self._started = False
        self._start_time: Optional[datetime] = None
        self._time_limit_exceeded = False
        self._progress_callback: Optional[Callable[[Dict, float], None]] = None
        self._progress_interval: float = 1.0  # Default to 1 second

        # Threading primitives
        self._state_lock = threading.Lock()
        self._shutdown_event = threading.Event()

        # Batch tracking for progress display
        self.batch_tracking: Dict[str, Dict] = {}  # batch_id -> batch_info

        # Active batch tracking for cancellation
        self._active_batches: Dict[str, object] = {}  # batch_id -> provider
        self._active_batches_lock = threading.Lock()

        # Results directory
        self.results_dir = Path(config.results_dir)

        # If not reusing state, clear the results directory
        if not config.reuse_state and self.results_dir.exists():
            import shutil
            shutil.rmtree(self.results_dir)

        self.results_dir.mkdir(parents=True, exist_ok=True)

        # Raw files directory (if enabled)
        self.raw_files_dir = None
        if config.raw_files:
            self.raw_files_dir = self.results_dir / "raw_files"
            self.raw_files_dir.mkdir(parents=True, exist_ok=True)

        # Try to resume from saved state
        self._resume_from_state()


    def _resume_from_state(self):
        """Resume from saved state if available."""
        # Check if we should reuse state
        if not self.config.reuse_state:
            # Clear any existing state and start fresh
            self.state_manager.clear()
            self.pending_jobs = list(self.jobs.values())
            return

        state = self.state_manager.load()
        if state is None:
            # No saved state, use jobs passed to constructor
            self.pending_jobs = list(self.jobs.values())
            return

        logger.info("Resuming batch run from saved state")

        # Restore pending jobs
        self.pending_jobs = []
        for job_data in state.pending_jobs:
            job = Job.from_dict(job_data)
            self.pending_jobs.append(job)

        # Restore completed results
        for result_data in state.completed_results:
            result = JobResult.from_dict(result_data)
            self.completed_results[result.job_id] = result

        # Restore failed jobs
        for job_data in state.failed_jobs:
            self.failed_jobs[job_data["id"]] = job_data.get("error", "Unknown error")

        # Restore cancelled jobs (if they exist in state)
        for job_data in getattr(state, 'cancelled_jobs', []):
            self.cancelled_jobs[job_data["id"]] = job_data.get("reason", "Cancelled")

        # Restore cost tracker
        self.cost_tracker.used_usd = state.total_cost_usd

        logger.info(
            f"Resumed with {len(self.pending_jobs)} pending, "
            f"{len(self.completed_results)} completed, "
            f"{len(self.failed_jobs)} failed, "
            f"{len(self.cancelled_jobs)} cancelled"
        )

    def to_json(self) -> Dict:
        """Convert current state to JSON-serializable dict."""
        return {
            "created_at": datetime.now().isoformat(),
            "pending_jobs": [job.to_dict() for job in self.pending_jobs],
            "completed_results": [result.to_dict() for result in self.completed_results.values()],
            "failed_jobs": [
                {
                    "id": job_id,
                    "error": error,
                    "timestamp": datetime.now().isoformat()
                } for job_id, error in self.failed_jobs.items()
            ],
            "cancelled_jobs": [
                {
                    "id": job_id,
                    "reason": reason,
                    "timestamp": datetime.now().isoformat()
                } for job_id, reason in self.cancelled_jobs.items()
            ],
            "total_cost_usd": self.cost_tracker.used_usd,
            "config": {
                "state_file": self.config.state_file,
                "results_dir": self.config.results_dir,
                "max_parallel_batches": self.config.max_parallel_batches,
                "items_per_batch": self.config.items_per_batch,
                "cost_limit_usd": self.config.cost_limit_usd,
                "default_params": self.config.default_params,
                "raw_files": self.config.raw_files
            }
        }

    def execute(self):
        """Execute synchronous batch run and wait for completion."""
        if self._started:
            raise RuntimeError("Batch run already started")

        self._started = True
        self._start_time = datetime.now()

        # Register signal handler for graceful shutdown
        def signal_handler(signum, frame):
            logger.warning("Received interrupt signal, shutting down gracefully...")
            self._shutdown_event.set()

        # Store original handler to restore later
        original_handler = signal.signal(signal.SIGINT, signal_handler)

        try:
            logger.info("Starting batch run")

            # Start time limit watchdog if configured
            self._start_time_limit_watchdog()

            # Call initial progress
            if self._progress_callback:
                stats = self.status()
                batch_data = dict(self.batch_tracking)
                self._progress_callback(stats, 0.0, batch_data)

            # Process all jobs synchronously
            self._process_all_jobs()

            logger.info("Batch run completed")
        finally:
            # Restore original signal handler
            signal.signal(signal.SIGINT, original_handler)

    def set_on_progress(self, callback: Callable[[Dict, float, Dict], None], interval: float = 1.0) -> 'BatchRun':
        """Set progress callback for execution monitoring.

        The callback will be called periodically with progress statistics
        including completed jobs, total jobs, current cost, etc.

        Args:
            callback: Function that receives (stats_dict, elapsed_time_seconds, batch_data)
                - stats_dict: Progress statistics dictionary
                - elapsed_time_seconds: Time elapsed since batch started (float)
                - batch_data: Dictionary mapping batch_id to batch information
            interval: Interval in seconds between progress updates (default: 1.0)

        Returns:
            Self for chaining

        Example:
            ```python
            run.set_on_progress(
                lambda stats, time, batch_data: print(
                    f"Progress: {stats['completed']}/{stats['total']}, {time:.1f}s"
                )
            )
            ```
        """
        self._progress_callback = callback
        self._progress_interval = interval
        return self

    def _start_time_limit_watchdog(self):
        """Start a background thread to check for time limit every second."""
        if not self.config.time_limit_seconds:
            return

        def time_limit_watchdog():
            """Check for time limit every second and trigger shutdown if exceeded."""
            while not self._shutdown_event.is_set():
                if self._check_time_limit():
                    logger.warning("Batch execution time limit exceeded")
                    with self._state_lock:
                        self._time_limit_exceeded = True
                    self._shutdown_event.set()
                    break
                time.sleep(1.0)

        # Start watchdog as daemon thread
        watchdog_thread = threading.Thread(target=time_limit_watchdog, daemon=True)
        watchdog_thread.start()
        logger.debug(f"Started time limit watchdog thread (time limit: {self.config.time_limit_seconds}s)")

    def _check_time_limit(self) -> bool:
        """Check if batch execution has exceeded time limit."""
        if not self.config.time_limit_seconds or not self._start_time:
            return False

        elapsed = (datetime.now() - self._start_time).total_seconds()
        return elapsed >= self.config.time_limit_seconds

    def _process_all_jobs(self):
        """Process all jobs with parallel execution."""
        # Prepare all batches
        batches = self._prepare_batches()
        self.total_batches = len(batches)

        # Process batches in parallel
        with ThreadPoolExecutor(max_workers=self.config.max_parallel_batches) as executor:
            futures = [executor.submit(self._execute_batch_wrapped, provider, batch_jobs)
                       for _, provider, batch_jobs in batches]

            try:
                for future in as_completed(futures):
                    # Stop if shutdown event detected (includes time limit)
                    if self._shutdown_event.is_set():
                        break
                    future.result()  # Re-raise any exceptions
            except KeyboardInterrupt:
                self._shutdown_event.set()
                # Cancel remaining futures
                for future in futures:
                    future.cancel()
                raise
            finally:
                # Handle time limit or cancellation - mark remaining jobs appropriately
                with self._state_lock:
                    if self._shutdown_event.is_set():
                        # If time limit exceeded, cancel all active batches
                        if self._time_limit_exceeded:
                            self._cancel_all_active_batches()

                        # Mark any unprocessed jobs based on reason for shutdown
                        for _, _, batch_jobs in batches:
                            for job in batch_jobs:
                                # Skip jobs already processed
                                if (job.id in self.completed_results or
                                    job.id in self.failed_jobs or
                                    job.id in self.cancelled_jobs):
                                    continue

                                # Mark based on shutdown reason
                                if self._time_limit_exceeded:
                                    self.failed_jobs[job.id] = "Time limit exceeded: batch execution time limit exceeded"
                                else:
                                    self.cancelled_jobs[job.id] = "Cancelled by user"

                                if job in self.pending_jobs:
                                    self.pending_jobs.remove(job)

                        # Save state
                        self.state_manager.save(self)

    def _cancel_all_active_batches(self):
        """Cancel all active batches at the provider level."""
        with self._active_batches_lock:
            active_batch_items = list(self._active_batches.items())

        logger.info(f"Cancelling {len(active_batch_items)} active batches due to time limit exceeded")

        # Cancel outside the lock to avoid blocking
        for batch_id, provider in active_batch_items:
            try:
                provider.cancel_batch(batch_id)
                logger.info(f"Cancelled batch {batch_id} due to time limit exceeded")
            except Exception as e:
                logger.warning(f"Failed to cancel batch {batch_id}: {e}")

        # Clear the tracking after cancellation attempts
        with self._active_batches_lock:
            self._active_batches.clear()

    def _execute_batch_wrapped(self, provider, batch_jobs):
        """Thread-safe wrapper for _execute_batch."""
        try:
            result = self._execute_batch(provider, batch_jobs)
            with self._state_lock:
                self._update_batch_results(result)
                # Remove jobs from pending_jobs if specified
                jobs_to_remove = result.get("jobs_to_remove", [])
                for job in jobs_to_remove:
                    if job in self.pending_jobs:
                        self.pending_jobs.remove(job)
        except TimeoutError:
            # Handle time limit exceeded - mark jobs as failed
            with self._state_lock:
                for job in batch_jobs:
                    self.failed_jobs[job.id] = "Time limit exceeded: batch execution time limit exceeded"
                    if job in self.pending_jobs:
                        self.pending_jobs.remove(job)
                self.state_manager.save(self)
            # Don't re-raise, just return the result
            return
        except KeyboardInterrupt:
            self._shutdown_event.set()
            # Handle user cancellation
            with self._state_lock:
                for job in batch_jobs:
                    self.cancelled_jobs[job.id] = "Cancelled by user"
                    if job in self.pending_jobs:
                        self.pending_jobs.remove(job)
                self.state_manager.save(self)
            raise

    def _group_jobs_by_provider(self) -> Dict[str, List[Job]]:
        """Group jobs by provider."""
        jobs_by_provider = {}

        for job in self.pending_jobs[:]:  # Copy to avoid modification during iteration
            try:
                provider = get_provider(job.model)
                provider_name = provider.__class__.__name__

                if provider_name not in jobs_by_provider:
                    jobs_by_provider[provider_name] = []

                jobs_by_provider[provider_name].append(job)

            except Exception as e:
                logger.error(f"Failed to get provider for job {job.id}: {e}")
                with self._state_lock:
                    self.failed_jobs[job.id] = str(e)
                    self.pending_jobs.remove(job)

        return jobs_by_provider

    def _split_into_batches(self, jobs: List[Job]) -> List[List[Job]]:
        """Split jobs into batches based on items_per_batch."""
        batches = []
        batch_size = self.config.items_per_batch

        for i in range(0, len(jobs), batch_size):
            batch = jobs[i:i + batch_size]
            batches.append(batch)

        return batches

    def _prepare_batches(self) -> List[Tuple[str, object, List[Job]]]:
        """Prepare all batches as simple list of (provider_name, provider, jobs)."""
        batches = []
        jobs_by_provider = self._group_jobs_by_provider()

        for provider_name, provider_jobs in jobs_by_provider.items():
            provider = get_provider(provider_jobs[0].model)
            job_batches = self._split_into_batches(provider_jobs)

            for batch_jobs in job_batches:
                batches.append((provider_name, provider, batch_jobs))

                # Pre-populate batch tracking for pending batches
                batch_id = f"pending_{len(self.batch_tracking)}"
                estimated_cost = provider.estimate_cost(batch_jobs)
                self.batch_tracking[batch_id] = {
                    'start_time': None,
                    'status': 'pending',
                    'total': len(batch_jobs),
                    'completed': 0,
                    'cost': 0.0,
                    'estimated_cost': estimated_cost,
                    'provider': provider_name,
                    'jobs': batch_jobs
                }

        return batches

    def _poll_batch_status(self, provider, batch_id: str) -> Tuple[str, Optional[Dict]]:
        """Poll until batch completes."""
        status, error_details = provider.get_batch_status(batch_id)
        logger.info(f"Initial batch status: {status}")
        poll_count = 0

        # Use provider-specific polling interval
        provider_polling_interval = provider.get_polling_interval()
        logger.debug(f"Using {provider_polling_interval}s polling interval for {provider.__class__.__name__}")

        while status not in ["complete", "failed"]:
            poll_count += 1
            logger.debug(f"Polling attempt {poll_count}, current status: {status}")

            # Interruptible wait - will wake up immediately if shutdown event is set (includes time limit)
            if self._shutdown_event.wait(provider_polling_interval):
                # Check if it's time limit exceeded or user cancellation
                with self._state_lock:
                    is_time_limit_exceeded = self._time_limit_exceeded

                if is_time_limit_exceeded:
                    logger.info(f"Batch {batch_id} polling interrupted by time limit exceeded")
                    raise TimeoutError("Batch cancelled due to time limit exceeded")
                else:
                    logger.info(f"Batch {batch_id} polling interrupted by user")
                    raise KeyboardInterrupt("Batch cancelled by user")

            status, error_details = provider.get_batch_status(batch_id)

            if self._progress_callback:
                with self._state_lock:
                    stats = self.status()
                    elapsed_time = (datetime.now() - self._start_time).total_seconds()
                    batch_data = dict(self.batch_tracking)
                self._progress_callback(stats, elapsed_time, batch_data)

            elapsed_seconds = poll_count * provider_polling_interval
            logger.info(f"Batch {batch_id} status: {status} (polling for {elapsed_seconds:.1f}s)")

        return status, error_details


    def _update_batch_results(self, batch_result: Dict):
        """Update state from batch results."""
        results = batch_result.get("results", [])
        failed = batch_result.get("failed", {})

        # Update completed results
        for result in results:
            if result.is_success:
                self.completed_results[result.job_id] = result
                self._save_result_to_file(result)
                logger.info(f"✓ Job {result.job_id} completed successfully")
            else:
                error_message = result.error or "Unknown error"
                self.failed_jobs[result.job_id] = error_message
                self._save_result_to_file(result)
                logger.error(f"✗ Job {result.job_id} failed: {result.error}")

        # Update failed jobs
        for job_id, error in failed.items():
            self.failed_jobs[job_id] = error
            logger.error(f"✗ Job {job_id} failed: {error}")

        # Update batch tracking
        self.completed_batches += 1

        # Save state
        self.state_manager.save(self)

    def _execute_batch(self, provider, batch_jobs: List[Job]) -> Dict:
        """Execute one batch, return results dict with jobs/costs/errors."""
        if not batch_jobs:
            return {"results": [], "failed": {}, "cost": 0.0}

        # Reserve cost limit
        logger.info(f"Estimating cost for batch of {len(batch_jobs)} jobs...")
        estimated_cost = provider.estimate_cost(batch_jobs)
        remaining = self.cost_tracker.remaining()
        remaining_str = f"${remaining:.4f}" if remaining is not None else "unlimited"
        logger.info(f"Total estimated cost: ${estimated_cost:.4f}, remaining budget: {remaining_str}")

        if not self.cost_tracker.reserve_cost(estimated_cost):
            logger.warning(f"Cost limit would be exceeded, skipping batch of {len(batch_jobs)} jobs")
            failed = {}
            for job in batch_jobs:
                failed[job.id] = "Cost limit exceeded"
            return {"results": [], "failed": failed, "cost": 0.0, "jobs_to_remove": list(batch_jobs)}

        batch_id = None
        job_mapping = None
        try:
            # Create batch
            logger.info(f"Creating batch with {len(batch_jobs)} jobs...")
            raw_files_path = str(self.raw_files_dir) if self.raw_files_dir else None
            batch_id, job_mapping = provider.create_batch(batch_jobs, raw_files_path)

            # Track active batch for cancellation
            with self._active_batches_lock:
                self._active_batches[batch_id] = provider

            # Track batch creation
            with self._state_lock:
                # Remove pending entry if it exists
                pending_keys = [k for k in self.batch_tracking.keys() if k.startswith('pending_')]
                for pending_key in pending_keys:
                    if self.batch_tracking[pending_key]['jobs'] == batch_jobs:
                        del self.batch_tracking[pending_key]
                        break

                # Add actual batch tracking
                self.batch_tracking[batch_id] = {
                    'start_time': datetime.now(),
                    'status': 'running',
                    'total': len(batch_jobs),
                    'completed': 0,
                    'cost': 0.0,
                    'estimated_cost': estimated_cost,
                    'provider': provider.__class__.__name__,
                    'jobs': batch_jobs
                }

            # Poll for completion
            logger.info(f"Polling for batch {batch_id} completion...")
            status, error_details = self._poll_batch_status(provider, batch_id)

            if status == "failed":
                if error_details:
                    logger.error(f"Batch {batch_id} failed with details: {error_details}")
                else:
                    logger.error(f"Batch {batch_id} failed")

                # Save error details if configured
                if self.raw_files_dir and error_details:
                    self._save_batch_error_details(batch_id, error_details)

                # Continue to get individual results - some jobs might have succeeded

            # Get results
            logger.info(f"Getting results for batch {batch_id}")
            raw_files_path = str(self.raw_files_dir) if self.raw_files_dir else None
            results = provider.get_batch_results(batch_id, job_mapping, raw_files_path)

            # Calculate actual cost and adjust reservation
            actual_cost = sum(r.cost_usd for r in results)
            self.cost_tracker.adjust_reserved_cost(estimated_cost, actual_cost)

            # Update batch tracking for completion
            success_count = len([r for r in results if r.is_success])
            failed_count = len([r for r in results if not r.is_success])
            batch_status = 'complete' if failed_count == 0 else 'failed'

            with self._state_lock:
                if batch_id in self.batch_tracking:
                    self.batch_tracking[batch_id]['status'] = batch_status
                    self.batch_tracking[batch_id]['completed'] = success_count
                    self.batch_tracking[batch_id]['failed'] = failed_count
                    self.batch_tracking[batch_id]['cost'] = actual_cost
                    self.batch_tracking[batch_id]['completion_time'] = datetime.now()
                    if batch_status == 'failed' and failed_count > 0:
                        # Use the first job's error as the batch error summary
                        first_error = next((r.error for r in results if not r.is_success), 'Some jobs failed')
                        self.batch_tracking[batch_id]['error'] = first_error

            # Remove from active batches tracking
            with self._active_batches_lock:
                self._active_batches.pop(batch_id, None)

            status_symbol = "✓" if batch_status == 'complete' else "⚠"
            logger.info(
                f"{status_symbol} Batch {batch_id} completed: "
                f"{success_count} success, "
                f"{failed_count} failed, "
                f"cost: ${actual_cost:.6f}"
            )

            return {"results": results, "failed": {}, "cost": actual_cost, "jobs_to_remove": list(batch_jobs)}

        except TimeoutError:
            logger.info(f"Time limit exceeded for batch{f' {batch_id}' if batch_id else ''}")
            if batch_id:
                # Update batch tracking for time limit exceeded
                with self._state_lock:
                    if batch_id in self.batch_tracking:
                        self.batch_tracking[batch_id]['status'] = 'failed'
                        self.batch_tracking[batch_id]['error'] = 'Time limit exceeded: batch execution time limit exceeded'
                        self.batch_tracking[batch_id]['completion_time'] = datetime.now()
                # NOTE: Don't remove from _active_batches - let centralized cancellation handle it
            # Release the reservation since batch exceeded time limit
            self.cost_tracker.adjust_reserved_cost(estimated_cost, 0.0)
            # Re-raise to be handled by wrapper
            raise

        except KeyboardInterrupt:
            logger.warning(f"\nCancelling batch{f' {batch_id}' if batch_id else ''}...")
            if batch_id:
                # Update batch tracking for cancellation
                with self._state_lock:
                    if batch_id in self.batch_tracking:
                        self.batch_tracking[batch_id]['status'] = 'cancelled'
                        self.batch_tracking[batch_id]['error'] = 'Cancelled by user'
                        self.batch_tracking[batch_id]['completion_time'] = datetime.now()
                # Remove from active batches tracking
                with self._active_batches_lock:
                    self._active_batches.pop(batch_id, None)
            # Release the reservation since batch was cancelled
            self.cost_tracker.adjust_reserved_cost(estimated_cost, 0.0)
            # Handle cancellation in the wrapper with proper locking
            raise

        except Exception as e:
            logger.error(f"✗ Batch execution failed: {e}")
            # Update batch tracking for exception
            if batch_id:
                with self._state_lock:
                    if batch_id in self.batch_tracking:
                        self.batch_tracking[batch_id]['status'] = 'failed'
                        self.batch_tracking[batch_id]['error'] = str(e)
                        self.batch_tracking[batch_id]['completion_time'] = datetime.now()
                # Remove from active batches tracking
                with self._active_batches_lock:
                    self._active_batches.pop(batch_id, None)
            # Release the reservation since batch failed
            self.cost_tracker.adjust_reserved_cost(estimated_cost, 0.0)
            failed = {}
            for job in batch_jobs:
                failed[job.id] = str(e)
            return {"results": [], "failed": failed, "cost": 0.0, "jobs_to_remove": list(batch_jobs)}


    def _save_result_to_file(self, result: JobResult):
        """Save individual result to file."""
        result_file = self.results_dir / f"{result.job_id}.json"

        try:
            with open(result_file, 'w') as f:
                json.dump(result.to_dict(), f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save result for {result.job_id}: {e}")

    def _save_batch_error_details(self, batch_id: str, error_details: Dict):
        """Save batch error details to debug files directory."""
        try:
            error_file = self.raw_files_dir / f"batch_{batch_id}_error.json"
            with open(error_file, 'w') as f:
                json.dump({
                    "batch_id": batch_id,
                    "timestamp": datetime.now().isoformat(),
                    "error_details": error_details
                }, f, indent=2)
            logger.info(f"Saved batch error details to {error_file}")
        except Exception as e:
            logger.error(f"Failed to save batch error details: {e}")

    @property
    def is_complete(self) -> bool:
        """Whether all jobs are complete."""
        total_jobs = len(self.jobs)
        completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
        return len(self.pending_jobs) == 0 and completed_count == total_jobs


    def status(self, print_status: bool = False) -> Dict:
        """Get current execution statistics."""
        total_jobs = len(self.jobs)
        completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
        remaining_count = total_jobs - completed_count

        stats = {
            "total": total_jobs,
            "pending": remaining_count,
            "active": 0,  # Always 0 for synchronous execution
            "completed": len(self.completed_results),
            "failed": len(self.failed_jobs),
            "cancelled": len(self.cancelled_jobs),
            "cost_usd": self.cost_tracker.used_usd,
            "cost_limit_usd": self.cost_tracker.limit_usd,
            "is_complete": self.is_complete,
            "batches_total": self.total_batches,
            "batches_completed": self.completed_batches,
            "batches_pending": self.total_batches - self.completed_batches,
            "current_batch_index": self.current_batch_index,
            "current_batch_size": self.current_batch_size,
            "items_per_batch": self.config.items_per_batch
        }

        if print_status:
            logger.info("\nBatch Run Status:")
            logger.info(f" Total jobs: {stats['total']}")
            logger.info(f" Pending: {stats['pending']}")
            logger.info(f" Active: {stats['active']}")
            logger.info(f" Completed: {stats['completed']}")
            logger.info(f" Failed: {stats['failed']}")
            logger.info(f" Cancelled: {stats['cancelled']}")
            logger.info(f" Cost: ${stats['cost_usd']:.6f}")
            if stats['cost_limit_usd']:
                logger.info(f" Cost limit: ${stats['cost_limit_usd']:.2f}")
            logger.info(f" Complete: {stats['is_complete']}")

        return stats

    def results(self) -> Dict[str, List[JobResult]]:
        """Get all results organized by status.

        Returns:
            {
                "completed": [JobResult],
                "failed": [JobResult],
                "cancelled": [JobResult]
            }
        """
        return {
            "completed": list(self.completed_results.values()),
            "failed": self._create_failed_results(),
            "cancelled": self._create_cancelled_results()
        }

    def get_failed_jobs(self) -> Dict[str, str]:
        """Get failed jobs with error messages.

        Note: This method is deprecated. Use results()['failed'] instead.
        """
        return dict(self.failed_jobs)

    def _create_failed_results(self) -> List[JobResult]:
        """Convert failed jobs to JobResult objects."""
        failed_results = []
        for job_id, error_msg in self.failed_jobs.items():
            failed_results.append(JobResult(
                job_id=job_id,
                raw_response=None,
                parsed_response=None,
                error=error_msg,
                cost_usd=0.0,
                input_tokens=0,
                output_tokens=0
            ))
        return failed_results

    def _create_cancelled_results(self) -> List[JobResult]:
        """Convert cancelled jobs to JobResult objects."""
        cancelled_results = []
        for job_id, reason in self.cancelled_jobs.items():
            cancelled_results.append(JobResult(
                job_id=job_id,
                raw_response=None,
                parsed_response=None,
                error=reason,
                cost_usd=0.0,
                input_tokens=0,
                output_tokens=0
            ))
        return cancelled_results

    def shutdown(self):
        """Shutdown (no-op for synchronous execution)."""
        pass
Manages the execution of a batch job synchronously.
Processes jobs in batches based on the items_per_batch configuration. Execution is synchronous (execute() blocks until the run finishes), which keeps logging and debugging simple.
Example:
config = BatchParams(...)
run = BatchRun(config, jobs)
run.execute()
results = run.results()
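To make the role of items_per_batch concrete, here is a minimal standalone sketch of the chunking step. The helper name split_into_batches and the job identifiers are illustrative assumptions; the sketch does not call batchata itself.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_into_batches(jobs: Sequence[T], items_per_batch: int) -> List[List[T]]:
    """Chunk jobs into consecutive lists of at most items_per_batch items.

    Hypothetical helper for illustration only; not part of the batchata API.
    """
    if items_per_batch < 1:
        raise ValueError("items_per_batch must be at least 1")
    return [
        list(jobs[i:i + items_per_batch])
        for i in range(0, len(jobs), items_per_batch)
    ]


# Five placeholder job ids split with items_per_batch=2 give three batches,
# the last one holding the single leftover job.
job_ids = [f"job-{n}" for n in range(5)]
batches = split_into_batches(job_ids, items_per_batch=2)
print(batches)
```

Each resulting list would then be submitted to the provider as one batch request.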
    def __init__(self, config: BatchParams, jobs: List[Job]):
        """Initialize batch run.

        Args:
            config: Batch configuration
            jobs: List of jobs to execute
        """
        self.config = config
        self.jobs = {job.id: job for job in jobs}

        # Set logging level based on config
        set_log_level(level=config.verbosity.upper())

        # Initialize components
        self.cost_tracker = CostTracker(limit_usd=config.cost_limit_usd)

        # Use temp file for state if not provided
        state_file = config.state_file
        if not state_file:
            state_file = create_temp_state_file(config)
            config.reuse_state = False
            logger.info(f"Created temporary state file: {state_file}")

        self.state_manager = StateManager(state_file)

        # State tracking
        self.pending_jobs: List[Job] = []
        self.completed_results: Dict[str, JobResult] = {}  # job_id -> result
        self.failed_jobs: Dict[str, str] = {}  # job_id -> error
        self.cancelled_jobs: Dict[str, str] = {}  # job_id -> reason

        # Batch tracking
        self.total_batches = 0
        self.completed_batches = 0
        self.current_batch_index = 0
        self.current_batch_size = 0

        # Execution control
        self._started = False
        self._start_time: Optional[datetime] = None
        self._time_limit_exceeded = False
        self._progress_callback: Optional[Callable[[Dict, float], None]] = None
        self._progress_interval: float = 1.0  # Default to 1 second

        # Threading primitives
        self._state_lock = threading.Lock()
        self._shutdown_event = threading.Event()

        # Batch tracking for progress display
        self.batch_tracking: Dict[str, Dict] = {}  # batch_id -> batch_info

        # Active batch tracking for cancellation
        self._active_batches: Dict[str, object] = {}  # batch_id -> provider
        self._active_batches_lock = threading.Lock()

        # Results directory
        self.results_dir = Path(config.results_dir)

        # If not reusing state, clear the results directory
        if not config.reuse_state and self.results_dir.exists():
            import shutil
            shutil.rmtree(self.results_dir)

        self.results_dir.mkdir(parents=True, exist_ok=True)

        # Raw files directory (if enabled)
        self.raw_files_dir = None
        if config.raw_files:
            self.raw_files_dir = self.results_dir / "raw_files"
            self.raw_files_dir.mkdir(parents=True, exist_ok=True)

        # Try to resume from saved state
        self._resume_from_state()
Initialize batch run.
Args:
    config: Batch configuration
    jobs: List of jobs to execute
    def to_json(self) -> Dict:
        """Convert current state to JSON-serializable dict."""
        return {
            "created_at": datetime.now().isoformat(),
            "pending_jobs": [job.to_dict() for job in self.pending_jobs],
            "completed_results": [result.to_dict() for result in self.completed_results.values()],
            "failed_jobs": [
                {
                    "id": job_id,
                    "error": error,
                    "timestamp": datetime.now().isoformat()
                } for job_id, error in self.failed_jobs.items()
            ],
            "cancelled_jobs": [
                {
                    "id": job_id,
                    "reason": reason,
                    "timestamp": datetime.now().isoformat()
                } for job_id, reason in self.cancelled_jobs.items()
            ],
            "total_cost_usd": self.cost_tracker.used_usd,
            "config": {
                "state_file": self.config.state_file,
                "results_dir": self.config.results_dir,
                "max_parallel_batches": self.config.max_parallel_batches,
                "items_per_batch": self.config.items_per_batch,
                "cost_limit_usd": self.config.cost_limit_usd,
                "default_params": self.config.default_params,
                "raw_files": self.config.raw_files
            }
        }
Convert current state to JSON-serializable dict.
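As a rough illustration of what such a snapshot looks like once serialized, the sketch below builds a reduced dictionary using the same top-level keys and round-trips it through the standard json module. The function snapshot_state and the simplified {"id": ...} job entries are assumptions made for this example; the real method serializes full Job and JobResult objects and also includes "completed_results" and "config".

```python
import json
from datetime import datetime
from typing import Dict, List


def snapshot_state(pending_ids: List[str], failed: Dict[str, str],
                   cancelled: Dict[str, str], total_cost_usd: float) -> Dict:
    """Build a reduced, JSON-serializable state snapshot (hypothetical helper)."""
    now = datetime.now().isoformat()
    return {
        "created_at": now,
        # Assumption: jobs are reduced to their ids for brevity.
        "pending_jobs": [{"id": job_id} for job_id in pending_ids],
        "failed_jobs": [
            {"id": job_id, "error": error, "timestamp": now}
            for job_id, error in failed.items()
        ],
        "cancelled_jobs": [
            {"id": job_id, "reason": reason, "timestamp": now}
            for job_id, reason in cancelled.items()
        ],
        "total_cost_usd": total_cost_usd,
    }


state = snapshot_state(
    pending_ids=["job-3"],
    failed={"job-1": "Cost limit exceeded"},
    cancelled={"job-2": "Cancelled by user"},
    total_cost_usd=0.42,
)

# Everything is plain str/float/list/dict, so a JSON round trip is lossless.
restored = json.loads(json.dumps(state, indent=2))
print(restored["failed_jobs"][0]["error"])
```

Storing timestamps as ISO 8601 strings is what keeps the snapshot serializable without a custom encoder.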
    def execute(self):
        """Execute synchronous batch run and wait for completion."""
        if self._started:
            raise RuntimeError("Batch run already started")

        self._started = True
        self._start_time = datetime.now()

        # Register signal handler for graceful shutdown
        def signal_handler(signum, frame):
            logger.warning("Received interrupt signal, shutting down gracefully...")
            self._shutdown_event.set()

        # Store original handler to restore later
        original_handler = signal.signal(signal.SIGINT, signal_handler)

        try:
            logger.info("Starting batch run")

            # Start time limit watchdog if configured
            self._start_time_limit_watchdog()

            # Call initial progress
            if self._progress_callback:
                stats = self.status()
                batch_data = dict(self.batch_tracking)
                self._progress_callback(stats, 0.0, batch_data)

            # Process all jobs synchronously
            self._process_all_jobs()

            logger.info("Batch run completed")
        finally:
            # Restore original signal handler
            signal.signal(signal.SIGINT, original_handler)
Execute synchronous batch run and wait for completion.
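The graceful-shutdown pattern used here (install a temporary SIGINT handler that sets a threading.Event, then restore the previous handler) can be sketched in isolation as follows. The names run_with_graceful_interrupt and work are illustrative assumptions, and signal.signal must be called from the main thread.

```python
import signal
import threading
from typing import Callable


def run_with_graceful_interrupt(work: Callable[[threading.Event], None]) -> None:
    """Run work(event) with Ctrl+C mapped to event.set() instead of an exception.

    Hypothetical helper for illustration; must be called from the main thread.
    """
    shutdown_event = threading.Event()

    def handle_sigint(signum, frame):
        # Only flag the shutdown; the worker decides when to stop.
        shutdown_event.set()

    # signal.signal returns the previously installed handler.
    original_handler = signal.signal(signal.SIGINT, handle_sigint)
    try:
        work(shutdown_event)
    finally:
        # Always restore, even if work raises.
        signal.signal(signal.SIGINT, original_handler)


def poll_three_times(event: threading.Event) -> None:
    for attempt in range(3):
        # Event.wait returns True as soon as the event is set, so the
        # loop wakes immediately on Ctrl+C instead of sleeping it out.
        if event.wait(timeout=0.01):
            print("interrupted, stopping early")
            return
    print("finished without interruption")


run_with_graceful_interrupt(poll_three_times)
```

Restoring the handler in a finally block means the process returns to its previous Ctrl+C behaviour whether the work finishes or fails.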
    def set_on_progress(self, callback: Callable[[Dict, float, Dict], None], interval: float = 1.0) -> 'BatchRun':
        """Set progress callback for execution monitoring.

        The callback will be called periodically with progress statistics
        including completed jobs, total jobs, current cost, etc.

        Args:
            callback: Function that receives (stats_dict, elapsed_time_seconds, batch_data)
                - stats_dict: Progress statistics dictionary
                - elapsed_time_seconds: Time elapsed since batch started (float)
                - batch_data: Dictionary mapping batch_id to batch information
            interval: Interval in seconds between progress updates (default: 1.0)

        Returns:
            Self for chaining

        Example:
            ```python
            run.set_on_progress(
                lambda stats, time, batch_data: print(
                    f"Progress: {stats['completed']}/{stats['total']}, {time:.1f}s"
                )
            )
            ```
        """
        self._progress_callback = callback
        self._progress_interval = interval
        return self
Set progress callback for execution monitoring.
The callback will be called periodically with progress statistics including completed jobs, total jobs, current cost, etc.
Args:
    callback: Function that receives (stats_dict, elapsed_time_seconds, batch_data)
        - stats_dict: Progress statistics dictionary
        - elapsed_time_seconds: Time elapsed since batch started (float)
        - batch_data: Dictionary mapping batch_id to batch information
    interval: Interval in seconds between progress updates (default: 1.0)
Returns: Self for chaining
Example:
    run.set_on_progress(
        lambda stats, time, batch_data: print(
            f"Progress: {stats['completed']}/{stats['total']}, {time:.1f}s"
        )
    )
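For anything longer than a one-line lambda, a named callback is easier to read and test. The sketch below uses only stats keys that `status()` is shown to return (`total`, `completed`, `failed`, `cancelled`, `cost_usd`); the function names `format_progress` and `on_progress` are hypothetical:

```python
def format_progress(stats, elapsed_seconds, batch_data):
    """Build a one-line progress summary from the callback arguments."""
    total = stats["total"]
    # Jobs that have reached any terminal state
    done = stats["completed"] + stats["failed"] + stats["cancelled"]
    percent = 100.0 * done / total if total else 0.0
    return (
        f"{done}/{total} jobs ({percent:.0f}%), "
        f"${stats['cost_usd']:.4f}, {elapsed_seconds:.1f}s, "
        f"{len(batch_data)} batches"
    )


def on_progress(stats, elapsed_seconds, batch_data):
    """Callback with the (stats, elapsed, batch_data) signature."""
    print(format_progress(stats, elapsed_seconds, batch_data))
```

It could then be registered with `run.set_on_progress(on_progress, interval=5.0)`.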
    @property
    def is_complete(self) -> bool:
        """Whether all jobs are complete."""
        total_jobs = len(self.jobs)
        completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
        return len(self.pending_jobs) == 0 and completed_count == total_jobs
Whether all jobs are complete.
    def status(self, print_status: bool = False) -> Dict:
        """Get current execution statistics."""
        total_jobs = len(self.jobs)
        completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
        remaining_count = total_jobs - completed_count

        stats = {
            "total": total_jobs,
            "pending": remaining_count,
            "active": 0,  # Always 0 for synchronous execution
            "completed": len(self.completed_results),
            "failed": len(self.failed_jobs),
            "cancelled": len(self.cancelled_jobs),
            "cost_usd": self.cost_tracker.used_usd,
            "cost_limit_usd": self.cost_tracker.limit_usd,
            "is_complete": self.is_complete,
            "batches_total": self.total_batches,
            "batches_completed": self.completed_batches,
            "batches_pending": self.total_batches - self.completed_batches,
            "current_batch_index": self.current_batch_index,
            "current_batch_size": self.current_batch_size,
            "items_per_batch": self.config.items_per_batch
        }

        if print_status:
            logger.info("\nBatch Run Status:")
            logger.info(f" Total jobs: {stats['total']}")
            logger.info(f" Pending: {stats['pending']}")
            logger.info(f" Active: {stats['active']}")
            logger.info(f" Completed: {stats['completed']}")
            logger.info(f" Failed: {stats['failed']}")
            logger.info(f" Cancelled: {stats['cancelled']}")
            logger.info(f" Cost: ${stats['cost_usd']:.6f}")
            if stats['cost_limit_usd']:
                logger.info(f" Cost limit: ${stats['cost_limit_usd']:.2f}")
            logger.info(f" Complete: {stats['is_complete']}")

        return stats
Get current execution statistics.
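The job counters in the statistics are related by simple arithmetic: `pending` is the total minus every job in a terminal state. A simplified sketch of that relationship follows; `derive_counts` is a hypothetical helper, and unlike the real `is_complete` it does not also check the pending-job list:

```python
def derive_counts(total_jobs, completed, failed, cancelled):
    """Recompute the pending count the way status() does."""
    finished = completed + failed + cancelled
    pending = total_jobs - finished
    return {
        "total": total_jobs,
        "pending": pending,
        "completed": completed,
        "failed": failed,
        "cancelled": cancelled,
        # Simplification: complete when nothing is left to finish
        "is_complete": pending == 0,
    }
```

Note that failed and cancelled jobs count toward completion, so `is_complete` being true does not imply that every job succeeded.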
    def results(self) -> Dict[str, List[JobResult]]:
        """Get all results organized by status.

        Returns:
            {
                "completed": [JobResult],
                "failed": [JobResult],
                "cancelled": [JobResult]
            }
        """
        return {
            "completed": list(self.completed_results.values()),
            "failed": self._create_failed_results(),
            "cancelled": self._create_cancelled_results()
        }
Get all results organized by status.
Returns:
    {
        "completed": [JobResult],
        "failed": [JobResult],
        "cancelled": [JobResult]
    }
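Because the return value is a plain dictionary of lists, summarizing a run takes a couple of comprehensions. In this sketch `SimpleNamespace` objects stand in for `JobResult` instances so the example runs without Batchata installed; `summarize_results` is a hypothetical helper:

```python
from types import SimpleNamespace


def summarize_results(results):
    """Return per-status counts and the total cost across all groups."""
    counts = {status: len(group) for status, group in results.items()}
    cost = sum(r.cost_usd for group in results.values() for r in group)
    return counts, cost


# Stand-in data shaped like the dictionary returned by results()
example = {
    "completed": [
        SimpleNamespace(job_id="a", cost_usd=0.25, error=None),
        SimpleNamespace(job_id="b", cost_usd=0.5, error=None),
    ],
    "failed": [SimpleNamespace(job_id="c", cost_usd=0.0, error="timeout")],
    "cancelled": [],
}
counts, cost = summarize_results(example)
```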
    def get_failed_jobs(self) -> Dict[str, str]:
        """Get failed jobs with error messages.

        Note: This method is deprecated. Use results()['failed'] instead.
        """
        return dict(self.failed_jobs)
Get failed jobs with error messages.
Note: This method is deprecated. Use results()['failed'] instead.
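Code that relied on the old `{job_id: error}` mapping can rebuild it from the failed results. A sketch, assuming each failed `JobResult` carries its message in `error` as the attribute list describes; `failed_jobs_as_dict` is a hypothetical helper and `SimpleNamespace` is a stand-in for `JobResult`:

```python
from types import SimpleNamespace


def failed_jobs_as_dict(results):
    """Map job_id to error message for every failed result."""
    return {r.job_id: r.error for r in results["failed"]}


# Stand-in for run.results()
results = {
    "completed": [],
    "failed": [SimpleNamespace(job_id="job-1", error="rate limited")],
    "cancelled": [],
}
failed = failed_jobs_as_dict(results)
```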
@dataclass
class Job:
    """Configuration for a single AI job.

    Either provide messages OR file+prompt, not both.

    Attributes:
        id: Unique identifier for the job
        messages: Chat messages for direct message input
        file: File path for file-based input
        prompt: Prompt to use with file input
        model: Model name (e.g., "claude-3-sonnet")
        temperature: Sampling temperature (0.0-1.0)
        max_tokens: Maximum tokens to generate
        response_model: Pydantic model for structured output
        enable_citations: Whether to extract citations from response
    """

    id: str  # Unique identifier
    model: str  # Model name (e.g., "claude-3-sonnet")
    messages: Optional[List[Message]] = None  # Chat messages
    file: Optional[Path] = None  # File input
    prompt: Optional[str] = None  # Prompt for file
    temperature: float = 0.7
    max_tokens: int = 1000
    response_model: Optional[Type[BaseModel]] = None  # For structured output
    enable_citations: bool = False

    def __post_init__(self):
        """Validate job configuration."""
        if self.messages and (self.file or self.prompt):
            raise ValueError("Provide either messages OR file+prompt, not both")

        if self.file and not self.prompt:
            raise ValueError("File input requires a prompt")

        if not self.messages and not (self.file and self.prompt):
            raise ValueError("Must provide either messages or file+prompt")

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for state persistence."""
        return {
            "id": self.id,
            "model": self.model,
            "messages": self.messages,
            "file": str(self.file) if self.file else None,
            "prompt": self.prompt,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "response_model": self.response_model.__name__ if self.response_model else None,
            "enable_citations": self.enable_citations
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Job':
        """Deserialize from state."""
        # Convert file string back to Path if present
        file_path = None
        if data.get("file"):
            file_path = Path(data["file"])

        # Note: response_model reconstruction would need additional logic
        # For now, we'll set it to None during deserialization
        return cls(
            id=data["id"],
            model=data["model"],
            messages=data.get("messages"),
            file=file_path,
            prompt=data.get("prompt"),
            temperature=data.get("temperature", 0.7),
            max_tokens=data.get("max_tokens", 1000),
            response_model=None,  # Cannot reconstruct from string
            enable_citations=data.get("enable_citations", False)
        )
Configuration for a single AI job.
Either provide messages OR file+prompt, not both.
Attributes:
    id: Unique identifier for the job
    messages: Chat messages for direct message input
    file: File path for file-based input
    prompt: Prompt to use with file input
    model: Model name (e.g., "claude-3-sonnet")
    temperature: Sampling temperature (0.0-1.0)
    max_tokens: Maximum tokens to generate
    response_model: Pydantic model for structured output
    enable_citations: Whether to extract citations from response
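The "messages OR file+prompt" rule allows exactly two input shapes. The sketch below mirrors the checks in `__post_init__` as a standalone function so the accepted and rejected combinations can be tried without constructing a `Job`; `validate_job_inputs` and `is_valid` are hypothetical names:

```python
def validate_job_inputs(messages=None, file=None, prompt=None):
    """Raise ValueError for input combinations that Job would reject."""
    if messages and (file or prompt):
        raise ValueError("Provide either messages OR file+prompt, not both")
    if file and not prompt:
        raise ValueError("File input requires a prompt")
    if not messages and not (file and prompt):
        raise ValueError("Must provide either messages or file+prompt")


def is_valid(**kwargs):
    """Return True when the combination passes validation."""
    try:
        validate_job_inputs(**kwargs)
        return True
    except ValueError:
        return False


# The two accepted shapes
messages_only = is_valid(messages=[{"role": "user", "content": "Hi"}])
file_and_prompt = is_valid(file="report.pdf", prompt="Summarize this document")
# A prompt without a file, or a file without a prompt, is rejected
prompt_only = is_valid(prompt="Summarize this document")
file_only = is_valid(file="report.pdf")
```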
    def to_dict(self) -> Dict[str, Any]:
        """Serialize for state persistence."""
        return {
            "id": self.id,
            "model": self.model,
            "messages": self.messages,
            "file": str(self.file) if self.file else None,
            "prompt": self.prompt,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "response_model": self.response_model.__name__ if self.response_model else None,
            "enable_citations": self.enable_citations
        }
Serialize for state persistence.
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Job':
        """Deserialize from state."""
        # Convert file string back to Path if present
        file_path = None
        if data.get("file"):
            file_path = Path(data["file"])

        # Note: response_model reconstruction would need additional logic
        # For now, we'll set it to None during deserialization
        return cls(
            id=data["id"],
            model=data["model"],
            messages=data.get("messages"),
            file=file_path,
            prompt=data.get("prompt"),
            temperature=data.get("temperature", 0.7),
            max_tokens=data.get("max_tokens", 1000),
            response_model=None,  # Cannot reconstruct from string
            enable_citations=data.get("enable_citations", False)
        )
Deserialize from state.
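Since `to_dict()` stores only the class name of `response_model` and `from_dict()` always sets it to `None`, a deserialized job loses its structured-output model. One possible workaround, which is not something Batchata provides, is a caller-maintained registry keyed by class name. `restore_response_model` and `registry` are hypothetical, and a plain class stands in for a Pydantic model:

```python
def restore_response_model(job_data, registry):
    """Look up the model class named in serialized job data, if any."""
    name = job_data.get("response_model")
    if name is None:
        return None
    # Returns None when the name was never registered
    return registry.get(name)


class DocumentAnalysis:
    """Stand-in for a Pydantic model used as response_model."""


# The caller registers every model class it may need to restore
registry = {DocumentAnalysis.__name__: DocumentAnalysis}
restored = restore_response_model({"response_model": "DocumentAnalysis"}, registry)
```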
@dataclass
class JobResult:
    """Result from a completed AI job.

    Attributes:
        job_id: ID of the job this result is for
        raw_response: Raw text response from the model (None for failed jobs)
        parsed_response: Structured output (if response_model was used)
        citations: Extracted citations (if enable_citations was True)
        citation_mappings: Maps field names to relevant citations (if response_model used)
        input_tokens: Number of input tokens used
        output_tokens: Number of output tokens generated
        cost_usd: Total cost in USD
        error: Error message if job failed
        batch_id: ID of the batch this job was part of (for mapping to raw files)
    """

    job_id: str
    raw_response: Optional[str] = None  # Raw text response (None for failed jobs)
    parsed_response: Optional[Union[BaseModel, Dict]] = None  # Structured output or error dict
    citations: Optional[List[Citation]] = None  # Extracted citations
    citation_mappings: Optional[Dict[str, List[Citation]]] = None  # Field -> citations mapping
    input_tokens: int = 0
    output_tokens: int = 0
    cost_usd: float = 0.0
    error: Optional[str] = None  # Error message if failed
    batch_id: Optional[str] = None  # Batch ID for mapping to raw files

    @property
    def is_success(self) -> bool:
        """Whether the job completed successfully."""
        return self.error is None

    @property
    def total_tokens(self) -> int:
        """Total tokens used (input + output)."""
        return self.input_tokens + self.output_tokens

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for state persistence."""
        # Handle parsed_response serialization
        parsed_response = None
        if self.parsed_response is not None:
            if isinstance(self.parsed_response, dict):
                parsed_response = self.parsed_response
            elif isinstance(self.parsed_response, BaseModel):
                parsed_response = self.parsed_response.model_dump()
            else:
                parsed_response = str(self.parsed_response)

        # Handle citation_mappings serialization
        citation_mappings = None
        if self.citation_mappings:
            citation_mappings = {
                field: [asdict(c) for c in citations]
                for field, citations in self.citation_mappings.items()
            }

        return {
            "job_id": self.job_id,
            "raw_response": self.raw_response,
            "parsed_response": parsed_response,
            "citations": [asdict(c) for c in self.citations] if self.citations else None,
            "citation_mappings": citation_mappings,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cost_usd": self.cost_usd,
            "error": self.error,
            "batch_id": self.batch_id
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'JobResult':
        """Deserialize from state."""
        # Reconstruct citations if present
        citations = None
        if data.get("citations"):
            citations = [Citation(**c) for c in data["citations"]]

        # Reconstruct citation_mappings if present
        citation_mappings = None
        if data.get("citation_mappings"):
            citation_mappings = {
                field: [Citation(**c) for c in citations]
                for field, citations in data["citation_mappings"].items()
            }

        return cls(
            job_id=data["job_id"],
            raw_response=data["raw_response"],
            parsed_response=data.get("parsed_response"),
            citations=citations,
            citation_mappings=citation_mappings,
            input_tokens=data.get("input_tokens", 0),
            output_tokens=data.get("output_tokens", 0),
            cost_usd=data.get("cost_usd", 0.0),
            error=data.get("error"),
            batch_id=data.get("batch_id")
        )
Result from a completed AI job.
Attributes:
    job_id: ID of the job this result is for
    raw_response: Raw text response from the model (None for failed jobs)
    parsed_response: Structured output (if response_model was used)
    citations: Extracted citations (if enable_citations was True)
    citation_mappings: Maps field names to relevant citations (if response_model used)
    input_tokens: Number of input tokens used
    output_tokens: Number of output tokens generated
    cost_usd: Total cost in USD
    error: Error message if job failed
    batch_id: ID of the batch this job was part of (for mapping to raw files)
    @property
    def is_success(self) -> bool:
        """Whether the job completed successfully."""
        return self.error is None
Whether the job completed successfully.
    @property
    def total_tokens(self) -> int:
        """Total tokens used (input + output)."""
        return self.input_tokens + self.output_tokens
Total tokens used (input + output).
    def to_dict(self) -> Dict[str, Any]:
        """Serialize for state persistence."""
        # Handle parsed_response serialization
        parsed_response = None
        if self.parsed_response is not None:
            if isinstance(self.parsed_response, dict):
                parsed_response = self.parsed_response
            elif isinstance(self.parsed_response, BaseModel):
                parsed_response = self.parsed_response.model_dump()
            else:
                parsed_response = str(self.parsed_response)

        # Handle citation_mappings serialization
        citation_mappings = None
        if self.citation_mappings:
            citation_mappings = {
                field: [asdict(c) for c in citations]
                for field, citations in self.citation_mappings.items()
            }

        return {
            "job_id": self.job_id,
            "raw_response": self.raw_response,
            "parsed_response": parsed_response,
            "citations": [asdict(c) for c in self.citations] if self.citations else None,
            "citation_mappings": citation_mappings,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cost_usd": self.cost_usd,
            "error": self.error,
            "batch_id": self.batch_id
        }
Serialize for state persistence.
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'JobResult':
        """Deserialize from state."""
        # Reconstruct citations if present
        citations = None
        if data.get("citations"):
            citations = [Citation(**c) for c in data["citations"]]

        # Reconstruct citation_mappings if present
        citation_mappings = None
        if data.get("citation_mappings"):
            citation_mappings = {
                field: [Citation(**c) for c in citations]
                for field, citations in data["citation_mappings"].items()
            }

        return cls(
            job_id=data["job_id"],
            raw_response=data["raw_response"],
            parsed_response=data.get("parsed_response"),
            citations=citations,
            citation_mappings=citation_mappings,
            input_tokens=data.get("input_tokens", 0),
            output_tokens=data.get("output_tokens", 0),
            cost_usd=data.get("cost_usd", 0.0),
            error=data.get("error"),
            batch_id=data.get("batch_id")
        )
Deserialize from state.
@dataclass
class Citation:
    """Represents a citation extracted from an AI response."""

    text: str  # The cited text
    source: str  # Source identifier (e.g., page number, section)
    page: Optional[int] = None  # Page number if applicable
    metadata: Optional[Dict[str, Any]] = None  # Additional metadata
Represents a citation extracted from an AI response.
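Because `Citation` is a plain dataclass, `dataclasses.asdict` and keyword expansion are enough to round-trip it through state, which is how `JobResult` serializes `citation_mappings`. The sketch below redeclares `Citation` locally with the fields listed above so it runs on its own; `dump_mappings` and `load_mappings` are hypothetical helpers:

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional


@dataclass
class Citation:
    """Local copy of the Citation fields, for illustration only."""
    text: str
    source: str
    page: Optional[int] = None
    metadata: Optional[Dict[str, Any]] = None


def dump_mappings(mappings: Dict[str, List[Citation]]) -> Dict[str, List[dict]]:
    """Convert field -> citations into JSON-friendly dictionaries."""
    return {field: [asdict(c) for c in cites] for field, cites in mappings.items()}


def load_mappings(data: Dict[str, List[dict]]) -> Dict[str, List[Citation]]:
    """Rebuild Citation objects from their dictionary form."""
    return {field: [Citation(**c) for c in cites] for field, cites in data.items()}


original = {"summary": [Citation(text="Revenue grew.", source="report.pdf", page=3)]}
restored = load_mappings(dump_mappings(original))
```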
Base exception for all Batchata errors.
class CostLimitExceededError(BatchataError):
    """Raised when cost limit would be exceeded."""
    pass
Raised when cost limit would be exceeded.
Base exception for provider-related errors.
class ProviderNotFoundError(ProviderError):
    """Raised when no provider is found for a model."""
    pass
Raised when no provider is found for a model.
class ValidationError(BatchataError):
    """Raised when job or configuration validation fails."""
    pass
Raised when job or configuration validation fails.
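The exception classes form a hierarchy, so callers can handle specific failures first and fall back to the base class. The sketch below redeclares the classes locally so it runs without Batchata; that `ProviderError` derives from `BatchataError` is an assumption, since its source is not shown here, and `describe_failure` is a hypothetical helper:

```python
class BatchataError(Exception):
    """Local stand-in for the base exception."""


class CostLimitExceededError(BatchataError):
    pass


class ProviderError(BatchataError):  # Assumed parent class
    pass


class ProviderNotFoundError(ProviderError):
    pass


class ValidationError(BatchataError):
    pass


def describe_failure(exc):
    """Classify an exception from most specific to most general."""
    if isinstance(exc, CostLimitExceededError):
        return "cost limit"
    if isinstance(exc, ProviderError):
        return "provider"
    if isinstance(exc, ValidationError):
        return "validation"
    if isinstance(exc, BatchataError):
        return "batchata"
    return "unexpected"
```

The same ordering applies to `except` clauses: list subclasses before `BatchataError`, or the base clause will catch everything first.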