# batchata

Batchata - Unified Python API for AI batch requests with cost tracking, Pydantic responses, and parallel execution.

**Why AI batching?**

AI providers offer batch APIs that process requests asynchronously at 50% reduced cost compared to their real-time APIs. This is ideal for workloads like document processing, data analysis, and content generation, where immediate responses aren't required.
## Quick Start

### Installation

```bash
pip install batchata
```

### Basic Usage

```python
from batchata import Batch

# Simple batch processing
batch = (
    Batch(results_dir="./output")
    .set_default_params(model="claude-sonnet-4-20250514")
    .add_cost_limit(usd=5.0)
)

# Add jobs
for file in files:
    batch.add_job(file=file, prompt="Summarize this document")

# Execute
run = batch.run()
results = run.results()
```
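`run.results()` groups results by status. A minimal sketch of consuming them (field names taken from `JobResult` as shown in the source further below):

```python
results = run.results()  # {"completed": [...], "failed": [...], "cancelled": [...]}

for r in results["completed"]:
    print(r.job_id, r.raw_response)

# failed and cancelled entries are JobResult objects with .error set
for r in results["failed"]:
    print(f"{r.job_id} failed: {r.error}")

total = sum(r.cost_usd for r in results["completed"])
print(f"actual spend: ${total:.4f}")
```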
### Structured Output with Pydantic

```python
from batchata import Batch
from pydantic import BaseModel

class DocumentAnalysis(BaseModel):
    title: str
    summary: str
    key_points: list[str]

batch = (
    Batch(results_dir="./results")
    .set_default_params(model="claude-sonnet-4-20250514")
)

batch.add_job(
    file="document.pdf",
    prompt="Analyze this document",
    response_model=DocumentAnalysis,
    enable_citations=True  # Anthropic only
)

run = batch.run()
for result in run.results()["completed"]:
    analysis = result.parsed_response     # DocumentAnalysis object
    citations = result.citation_mappings  # Field -> Citation mapping
```
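Citation mappings are keyed by response-model field name, and the library validates that citation-enabled response models are flat (no nested models). A hedged sketch of walking the mappings, assuming `citation_mappings` is a plain dict (the exact `Citation` shape is provider-specific):

```python
for result in run.results()["completed"]:
    if result.citation_mappings:
        for field_name, citation in result.citation_mappings.items():
            # e.g. field_name == "summary"; the citation points back into document.pdf
            print(field_name, citation)
```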
## Key Features

- **50% Cost Savings**: Native batch processing via provider APIs
- **Cost Limits**: Cap batch spend with `.add_cost_limit(usd=...)`
- **Time Limits**: Control execution time with `.add_time_limit()` (both limits are combined in the sketch after this list)
- **State Persistence**: Resume interrupted batches automatically
- **Structured Output**: Pydantic models with automatic validation
- **Citations**: Extract and map citations to response fields (Anthropic)
- **Multiple Providers**: Anthropic Claude and OpenAI GPT models
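A minimal sketch combining the limit and persistence features above (paths and values are illustrative):

```python
from batchata import Batch

batch = (
    Batch(results_dir="./output")
    .set_default_params(model="claude-sonnet-4-20250514")
    .add_cost_limit(usd=10.0)             # jobs beyond the budget are marked failed
    .add_time_limit(hours=1, minutes=30)  # active batches are cancelled at the deadline
    .set_state(file="./state.json")       # enables automatic resume
)
```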
## Supported Providers

| Feature | Anthropic | OpenAI |
|---------|-----------|--------|
| Models | [All Claude models](https://github.com/agamm/batchata/blob/main/batchata/providers/anthropic/models.py) | [All GPT models](https://github.com/agamm/batchata/blob/main/batchata/providers/openai/models.py) |
| Citations | ✅ | ❌ |
| Structured Output | ✅ | ✅ |
| File Types | PDF, TXT, DOCX, Images | PDF, Images |
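Providers can be mixed in a single batch; jobs are grouped and dispatched per provider under the hood (see `_group_jobs_by_provider` in the source below). A minimal sketch, with illustrative model names:

```python
batch = Batch(results_dir="./output")
batch.add_job(file="report.pdf", prompt="Summarize this report",
              model="claude-sonnet-4-20250514")
batch.add_job(messages=[{"role": "user", "content": "Write a haiku about batching"}],
              model="gpt-4")
run = batch.run()
```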
## Configuration

Set API keys as environment variables:

```bash
export ANTHROPIC_API_KEY="your-key"
export OPENAI_API_KEY="your-key"
```

Or use a `.env` file with python-dotenv.
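For example, with python-dotenv installed, a short sketch of loading the keys at startup:

```python
from dotenv import load_dotenv  # pip install python-dotenv

load_dotenv()  # reads ANTHROPIC_API_KEY / OPENAI_API_KEY from a local .env into os.environ
```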
1"""Batchata - Unified Python API for AI Batch requests with cost tracking, Pydantic responses, and parallel execution. 2 3**Why AI-batching?** 4 5AI providers offer batch APIs that process requests asynchronously at 50% reduced cost compared to real-time APIs. 6This is ideal for workloads like document processing, data analysis, and content generation where immediate 7responses aren't required. 8 9## Quick Start 10 11### Installation 12 13```bash 14pip install batchata 15``` 16 17### Basic Usage 18 19```python 20from batchata import Batch 21 22# Simple batch processing 23batch = Batch(results_dir="./output") 24 .set_default_params(model="claude-sonnet-4-20250514") 25 .add_cost_limit(usd=5.0) 26 27# Add jobs 28for file in files: 29 batch.add_job(file=file, prompt="Summarize this document") 30 31# Execute 32run = batch.run() 33results = run.results() 34``` 35 36### Structured Output with Pydantic 37 38```python 39from batchata import Batch 40from pydantic import BaseModel 41 42class DocumentAnalysis(BaseModel): 43 title: str 44 summary: str 45 key_points: list[str] 46 47batch = Batch(results_dir="./results") 48 .set_default_params(model="claude-sonnet-4-20250514") 49 50batch.add_job( 51 file="document.pdf", 52 prompt="Analyze this document", 53 response_model=DocumentAnalysis, 54 enable_citations=True # Anthropic only 55) 56 57run = batch.run() 58for result in run.results()["completed"]: 59 analysis = result.parsed_response # DocumentAnalysis object 60 citations = result.citation_mappings # Field -> Citation mapping 61``` 62 63## Key Features 64 65- **50% Cost Savings**: Native batch processing via provider APIs 66- **Cost Limits**: Set `max_cost_usd` limits for batch requests 67- **Time Limits**: Control execution time with `.add_time_limit()` 68- **State Persistence**: Resume interrupted batches automatically 69- **Structured Output**: Pydantic models with automatic validation 70- **Citations**: Extract and map citations to response fields (Anthropic) 71- **Multiple Providers**: Anthropic Claude and OpenAI GPT models 72 73## Supported Providers 74 75| Feature | Anthropic | OpenAI | 76|---------|-----------|--------| 77| Models | [All Claude models](https://github.com/agamm/batchata/blob/main/batchata/providers/anthropic/models.py) | [All GPT models](https://github.com/agamm/batchata/blob/main/batchata/providers/openai/models.py) | 78| Citations | ✅ | ❌ | 79| Structured Output | ✅ | ✅ | 80| File Types | PDF, TXT, DOCX, Images | PDF, Images | 81 82## Configuration 83 84Set API keys as environment variables: 85 86```bash 87export ANTHROPIC_API_KEY="your-key" 88export OPENAI_API_KEY="your-key" 89``` 90 91Or use a `.env` file with python-dotenv. 92""" 93 94from .core import Batch, BatchRun, Job, JobResult 95from .exceptions import ( 96 BatchataError, 97 CostLimitExceededError, 98 ProviderError, 99 ProviderNotFoundError, 100 ValidationError, 101) 102from .types import Citation 103 104__version__ = "0.3.0" 105 106__all__ = [ 107 "Batch", 108 "BatchRun", 109 "Job", 110 "JobResult", 111 "Citation", 112 "BatchataError", 113 "CostLimitExceededError", 114 "ProviderError", 115 "ProviderNotFoundError", 116 "ValidationError", 117]
### `Batch`

The fluent builder that collects jobs and configuration:

````python
class Batch:
    """Builder for batch job configuration.

    Provides a fluent interface for configuring batch jobs with sensible defaults
    and validation. The batch can be configured with cost limits, default parameters,
    and progress callbacks.

    Example:
        ```python
        batch = (
            Batch("./results", max_parallel_batches=10, items_per_batch=10)
            .set_state(file="./state.json", reuse_state=True)
            .set_default_params(model="claude-sonnet-4-20250514", temperature=0.7)
            .add_cost_limit(usd=15.0)
            .add_job(messages=[{"role": "user", "content": "Hello"}])
            .add_job(file="./path/to/file.pdf", prompt="Generate summary of file")
        )

        run = batch.run()
        ```
    """

    def __init__(self, results_dir: str, max_parallel_batches: int = 10, items_per_batch: int = 10, raw_files: Optional[bool] = None):
        """Initialize batch configuration.

        Args:
            results_dir: Directory to store results
            max_parallel_batches: Maximum parallel batch requests
            items_per_batch: Number of jobs per provider batch
            raw_files: Whether to save debug files (raw responses, JSONL files) from
                providers (default: True if results_dir is set, False otherwise)
        """
        # Auto-determine raw_files based on results_dir if not explicitly set
        if raw_files is None:
            raw_files = bool(results_dir and results_dir.strip())

        self.config = BatchParams(
            state_file=None,
            results_dir=results_dir,
            max_parallel_batches=max_parallel_batches,
            items_per_batch=items_per_batch,
            reuse_state=True,
            raw_files=raw_files
        )
        self.jobs: List[Job] = []

    def set_default_params(self, **kwargs) -> 'Batch':
        """Set default parameters for all jobs.

        These defaults will be applied to all jobs unless overridden
        by job-specific parameters.

        Args:
            **kwargs: Default parameters (model, temperature, max_tokens, etc.)

        Returns:
            Self for chaining

        Example:
            ```python
            batch.set_default_params(model="claude-3-sonnet", temperature=0.7)
            ```
        """
        # Validate if model is provided
        if "model" in kwargs:
            self.config.validate_default_params(kwargs["model"])

        self.config.default_params.update(kwargs)
        return self

    def set_state(self, file: Optional[str] = None, reuse_state: bool = True) -> 'Batch':
        """Set state file configuration.

        Args:
            file: Path to state file for persistence (default: None)
            reuse_state: Whether to resume from existing state file (default: True)

        Returns:
            Self for chaining

        Example:
            ```python
            batch.set_state(file="./state.json", reuse_state=True)
            ```
        """
        self.config.state_file = file
        self.config.reuse_state = reuse_state
        return self

    def add_cost_limit(self, usd: float) -> 'Batch':
        """Add cost limit for the batch.

        The batch will stop accepting new jobs once the cost limit is reached.
        Active jobs will be allowed to complete.

        Args:
            usd: Cost limit in USD

        Returns:
            Self for chaining

        Example:
            ```python
            batch.add_cost_limit(usd=50.0)
            ```
        """
        if usd <= 0:
            raise ValueError("Cost limit must be positive")
        self.config.cost_limit_usd = usd
        return self

    def raw_files(self, enabled: bool = True) -> 'Batch':
        """Enable or disable saving debug files from providers.

        When enabled, debug files (raw API responses, JSONL files) will be saved
        in a 'raw_files' subdirectory within the results directory.
        This is useful for debugging, auditing, or accessing provider-specific metadata.

        Args:
            enabled: Whether to save debug files (default: True)

        Returns:
            Self for chaining

        Example:
            ```python
            batch.raw_files(True)
            ```
        """
        self.config.raw_files = enabled
        return self

    def set_verbosity(self, level: str) -> 'Batch':
        """Set logging verbosity level.

        Args:
            level: Verbosity level ("debug", "info", "warn", "error")

        Returns:
            Self for chaining

        Example:
            ```python
            batch.set_verbosity("error")  # For production
            batch.set_verbosity("debug")  # For debugging
            ```
        """
        valid_levels = {"debug", "info", "warn", "error"}
        if level.lower() not in valid_levels:
            raise ValueError(f"Invalid verbosity level: {level}. Must be one of {valid_levels}")
        self.config.verbosity = level.lower()
        return self

    def add_time_limit(self, seconds: Optional[float] = None, minutes: Optional[float] = None, hours: Optional[float] = None) -> 'Batch':
        """Add time limit for the entire batch execution.

        When time limit is reached, all active provider batches are cancelled and
        remaining unprocessed jobs are marked as failed. The batch execution
        completes normally without throwing exceptions.

        Args:
            seconds: Time limit in seconds (optional)
            minutes: Time limit in minutes (optional)
            hours: Time limit in hours (optional)

        Returns:
            Self for chaining

        Raises:
            ValueError: If no time units specified, or if total time is outside
                valid range (min: 10 seconds, max: 24 hours)

        Note:
            - Can combine multiple time units
            - Time limit is checked every second by a background watchdog thread
            - Jobs that exceed time limit appear in results()["failed"] with time limit error message
            - No exceptions are thrown when time limit is reached

        Example:
            ```python
            batch.add_time_limit(seconds=30)  # 30 seconds
            batch.add_time_limit(minutes=5)   # 5 minutes
            batch.add_time_limit(hours=2)     # 2 hours
            batch.add_time_limit(hours=1, minutes=30, seconds=15)  # 5415 seconds total
            ```
        """
        time_limit_seconds = 0.0

        if seconds is not None:
            time_limit_seconds += seconds
        if minutes is not None:
            time_limit_seconds += minutes * 60
        if hours is not None:
            time_limit_seconds += hours * 3600

        if time_limit_seconds == 0:
            raise ValueError("Must specify at least one of seconds, minutes, or hours")

        self.config.time_limit_seconds = time_limit_seconds
        return self

    def add_job(
        self,
        messages: Optional[List[Message]] = None,
        file: Optional[Union[str, Path]] = None,
        prompt: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        response_model: Optional[Type[BaseModel]] = None,
        enable_citations: bool = False,
        **kwargs
    ) -> 'Batch':
        """Add a job to the batch.

        Either provide messages OR file+prompt, not both. Parameters not provided
        will use the defaults set via the set_default_params() method.

        Args:
            messages: Chat messages for direct input
            file: File path for file-based input
            prompt: Prompt to use with file input
            model: Model to use (overrides default)
            temperature: Sampling temperature (overrides default)
            max_tokens: Max tokens to generate (overrides default)
            response_model: Pydantic model for structured output
            enable_citations: Whether to extract citations
            **kwargs: Additional parameters

        Returns:
            Self for chaining

        Example:
            ```python
            batch.add_job(
                messages=[{"role": "user", "content": "Hello"}],
                model="gpt-4"
            )
            ```
        """
        # Generate unique job ID
        job_id = f"job-{uuid.uuid4().hex[:8]}"

        # Merge with defaults
        params = self.config.default_params.copy()

        # Update with provided parameters
        if model is not None:
            params["model"] = model
        if temperature is not None:
            params["temperature"] = temperature
        if max_tokens is not None:
            params["max_tokens"] = max_tokens

        # Add other kwargs
        params.update(kwargs)

        # Ensure model is provided
        if "model" not in params:
            raise ValueError("Model must be provided either in defaults or job parameters")

        # Validate parameters
        provider = get_provider(params["model"])
        # Extract params without model to avoid duplicate
        param_subset = {k: v for k, v in params.items() if k != "model"}
        provider.validate_params(params["model"], **param_subset)

        # Convert file path if string
        if isinstance(file, str):
            file = Path(file)

        # Warn about temporary file paths that may not persist
        if file:
            file_str = str(file)
            if "/tmp/" in file_str or "/var/folders/" in file_str or "temp" in file_str.lower():
                logger = logging.getLogger("batchata")
                logger.debug(f"File path appears to be in a temporary directory: {file}")
                logger.debug("This may cause issues when resuming from state if temp files are cleaned up")

        # Create job
        job = Job(
            id=job_id,
            messages=messages,
            file=file,
            prompt=prompt,
            response_model=response_model,
            enable_citations=enable_citations,
            **params
        )

        # Validate citation compatibility
        if response_model and enable_citations:
            from ..utils.validation import validate_flat_model
            validate_flat_model(response_model)

        # Validate job with provider (includes PDF validation for Anthropic)
        provider.validate_job(job)

        self.jobs.append(job)
        return self

    def run(self, on_progress: Optional[Callable[[Dict, float, Dict], None]] = None, progress_interval: float = 1.0, print_status: bool = False) -> 'BatchRun':
        """Execute the batch.

        Creates a BatchRun instance and executes the jobs synchronously.

        Args:
            on_progress: Optional progress callback function that receives
                (stats_dict, elapsed_time_seconds, batch_data)
            progress_interval: Interval in seconds between progress updates (default: 1.0)
            print_status: Whether to show rich progress display (default: False)

        Returns:
            BatchRun instance with completed results

        Raises:
            ValueError: If no jobs have been added
        """
        if not self.jobs:
            raise ValueError("No jobs added to batch")

        # Import here to avoid circular dependency
        from .batch_run import BatchRun

        # Create and start the run
        run = BatchRun(self.config, self.jobs)

        # Set progress callback - either rich display or custom callback
        if print_status:
            return self._run_with_rich_display(run, progress_interval)
        else:
            return self._run_with_custom_callback(run, on_progress, progress_interval)

    def _run_with_rich_display(self, run: 'BatchRun', progress_interval: float) -> 'BatchRun':
        """Execute batch run with rich progress display.

        Args:
            run: BatchRun instance to execute
            progress_interval: Interval between progress updates

        Returns:
            Completed BatchRun instance
        """
        from ..utils.rich_progress import RichBatchProgressDisplay
        display = RichBatchProgressDisplay()

        def rich_progress_callback(stats, elapsed_time, batch_data):
            # Start display on first call
            if not hasattr(rich_progress_callback, '_started'):
                config_dict = {
                    'results_dir': self.config.results_dir,
                    'state_file': self.config.state_file,
                    'items_per_batch': self.config.items_per_batch,
                    'max_parallel_batches': self.config.max_parallel_batches
                }
                display.start(stats, config_dict)
                rich_progress_callback._started = True

            # Update display
            display.update(stats, batch_data, elapsed_time)

        run.set_on_progress(rich_progress_callback, interval=progress_interval)

        # Execute with proper cleanup
        try:
            run.execute()

            # Show final status with all batches completed
            stats = run.status()
            display.update(stats, run.batch_tracking, (datetime.now() - run._start_time).total_seconds())

            # Small delay to ensure display updates
            import time
            time.sleep(0.2)

        except KeyboardInterrupt:
            # Update batch tracking to show cancelled status for pending/running batches
            with run._state_lock:
                for batch_id, batch_info in run.batch_tracking.items():
                    if batch_info['status'] == 'running':
                        batch_info['status'] = 'cancelled'
                    elif batch_info['status'] == 'pending':
                        batch_info['status'] = 'cancelled'

            # Show final status with cancelled batches
            stats = run.status()
            display.update(stats, run.batch_tracking, 0.0)

            # Add a small delay to ensure the display updates
            import time
            time.sleep(0.1)

            display.stop()
            raise
        finally:
            if display.live:  # Only stop if not already stopped
                display.stop()

        return run

    def _run_with_custom_callback(self, run: 'BatchRun', on_progress: Optional[Callable[[Dict, float, Dict], None]], progress_interval: float) -> 'BatchRun':
        """Execute batch run with custom progress callback.

        Args:
            run: BatchRun instance to execute
            on_progress: Optional custom progress callback
            progress_interval: Interval between progress updates

        Returns:
            Completed BatchRun instance
        """
        # Use custom progress callback if provided
        if on_progress:
            run.set_on_progress(on_progress, interval=progress_interval)

        run.execute()
        return run

    def __len__(self) -> int:
        """Get the number of jobs in the batch."""
        return len(self.jobs)

    def __repr__(self) -> str:
        """String representation of the batch."""
        return (
            f"Batch(jobs={len(self.jobs)}, "
            f"max_parallel_batches={self.config.max_parallel_batches}, "
            f"cost_limit=${self.config.cost_limit_usd or 'None'})"
        )
````
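For reference, a minimal sketch of a custom progress callback wired through `run(on_progress=...)`; the stats keys follow `BatchRun.status()` shown below:

```python
def log_progress(stats: dict, elapsed: float, batch_data: dict) -> None:
    # stats keys come from BatchRun.status(); batch_data maps batch_id -> info
    print(f"[{elapsed:6.1f}s] {stats['completed']}/{stats['total']} completed, "
          f"{stats['failed']} failed, cost ${stats['cost_usd']:.4f}")

run = batch.run(on_progress=log_progress, progress_interval=2.0)
```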
25class BatchRun: 26 """Manages the execution of a batch job synchronously. 27 28 Processes jobs in batches based on items_per_batch configuration. 29 Simpler synchronous execution for clear logging and debugging. 30 31 Example: 32 ```python 33 config = BatchParams(...) 34 run = BatchRun(config, jobs) 35 run.execute() 36 results = run.results() 37 ``` 38 """ 39 40 def __init__(self, config: BatchParams, jobs: List[Job]): 41 """Initialize batch run. 42 43 Args: 44 config: Batch configuration 45 jobs: List of jobs to execute 46 """ 47 self.config = config 48 self.jobs = {job.id: job for job in jobs} 49 50 # Set logging level based on config 51 set_log_level(level=config.verbosity.upper()) 52 53 # Initialize components 54 self.cost_tracker = CostTracker(limit_usd=config.cost_limit_usd) 55 56 # Use temp file for state if not provided 57 state_file = config.state_file 58 if not state_file: 59 state_file = create_temp_state_file(config) 60 config.reuse_state = False 61 logger.info(f"Created temporary state file: {state_file}") 62 63 self.state_manager = StateManager(state_file) 64 65 # State tracking 66 self.pending_jobs: List[Job] = [] 67 self.completed_results: Dict[str, JobResult] = {} # job_id -> result 68 self.failed_jobs: Dict[str, str] = {} # job_id -> error 69 self.cancelled_jobs: Dict[str, str] = {} # job_id -> reason 70 71 # Batch tracking 72 self.total_batches = 0 73 self.completed_batches = 0 74 self.current_batch_index = 0 75 self.current_batch_size = 0 76 77 # Execution control 78 self._started = False 79 self._start_time: Optional[datetime] = None 80 self._time_limit_exceeded = False 81 self._progress_callback: Optional[Callable[[Dict, float], None]] = None 82 self._progress_interval: float = 1.0 # Default to 1 second 83 84 # Threading primitives 85 self._state_lock = threading.Lock() 86 self._shutdown_event = threading.Event() 87 88 # Batch tracking for progress display 89 self.batch_tracking: Dict[str, Dict] = {} # batch_id -> batch_info 90 91 # Active batch tracking for cancellation 92 self._active_batches: Dict[str, object] = {} # batch_id -> provider 93 self._active_batches_lock = threading.Lock() 94 95 # Results directory 96 self.results_dir = Path(config.results_dir) 97 98 # If not reusing state, clear the results directory 99 if not config.reuse_state and self.results_dir.exists(): 100 import shutil 101 shutil.rmtree(self.results_dir) 102 103 self.results_dir.mkdir(parents=True, exist_ok=True) 104 105 # Raw files directory (if enabled) 106 self.raw_files_dir = None 107 if config.raw_files: 108 self.raw_files_dir = self.results_dir / "raw_files" 109 self.raw_files_dir.mkdir(parents=True, exist_ok=True) 110 111 # Try to resume from saved state 112 self._resume_from_state() 113 114 115 def _resume_from_state(self): 116 """Resume from saved state if available.""" 117 # Check if we should reuse state 118 if not self.config.reuse_state: 119 # Clear any existing state and start fresh 120 self.state_manager.clear() 121 self.pending_jobs = list(self.jobs.values()) 122 return 123 124 state = self.state_manager.load() 125 if state is None: 126 # No saved state, use jobs passed to constructor 127 self.pending_jobs = list(self.jobs.values()) 128 return 129 130 logger.info("Resuming batch run from saved state") 131 132 # Restore pending jobs 133 self.pending_jobs = [] 134 for job_data in state.pending_jobs: 135 job = Job.from_dict(job_data) 136 # Check if file exists (if job has a file) 137 if job.file and not job.file.exists(): 138 logger.error(f"File not found for job {job.id}: 
{job.file}") 139 logger.error("This may happen if files were in temporary directories that were cleaned up") 140 self.failed_jobs[job.id] = f"File not found: {job.file}" 141 else: 142 self.pending_jobs.append(job) 143 144 # Restore completed results from file references 145 for result_ref in state.completed_results: 146 job_id = result_ref["job_id"] 147 file_path = result_ref["file_path"] 148 try: 149 with open(file_path, 'r') as f: 150 result_data = json.load(f) 151 result = JobResult.from_dict(result_data) 152 self.completed_results[job_id] = result 153 except Exception as e: 154 logger.error(f"Failed to load result for {job_id} from {file_path}: {e}") 155 # Move to failed jobs if we can't load the result 156 self.failed_jobs[job_id] = f"Failed to load result file: {e}" 157 158 # Restore failed jobs 159 for job_data in state.failed_jobs: 160 self.failed_jobs[job_data["id"]] = job_data.get("error", "Unknown error") 161 162 # Restore cancelled jobs (if they exist in state) 163 for job_data in getattr(state, 'cancelled_jobs', []): 164 self.cancelled_jobs[job_data["id"]] = job_data.get("reason", "Cancelled") 165 166 # Restore cost tracker 167 self.cost_tracker.used_usd = state.total_cost_usd 168 169 logger.info( 170 f"Resumed with {len(self.pending_jobs)} pending, " 171 f"{len(self.completed_results)} completed, " 172 f"{len(self.failed_jobs)} failed, " 173 f"{len(self.cancelled_jobs)} cancelled" 174 ) 175 176 def to_json(self) -> Dict: 177 """Convert current state to JSON-serializable dict.""" 178 return { 179 "created_at": datetime.now().isoformat(), 180 "pending_jobs": [job.to_dict() for job in self.pending_jobs], 181 "completed_results": [ 182 {"job_id": job_id, "file_path": str(self.results_dir / f"{job_id}.json")} 183 for job_id in self.completed_results.keys() 184 ], 185 "failed_jobs": [ 186 { 187 "id": job_id, 188 "error": error, 189 "timestamp": datetime.now().isoformat() 190 } for job_id, error in self.failed_jobs.items() 191 ], 192 "cancelled_jobs": [ 193 { 194 "id": job_id, 195 "reason": reason, 196 "timestamp": datetime.now().isoformat() 197 } for job_id, reason in self.cancelled_jobs.items() 198 ], 199 "total_cost_usd": self.cost_tracker.used_usd, 200 "config": { 201 "state_file": self.config.state_file, 202 "results_dir": self.config.results_dir, 203 "max_parallel_batches": self.config.max_parallel_batches, 204 "items_per_batch": self.config.items_per_batch, 205 "cost_limit_usd": self.config.cost_limit_usd, 206 "default_params": self.config.default_params, 207 "raw_files": self.config.raw_files 208 } 209 } 210 211 def execute(self): 212 """Execute synchronous batch run and wait for completion.""" 213 if self._started: 214 raise RuntimeError("Batch run already started") 215 216 self._started = True 217 self._start_time = datetime.now() 218 219 # Register signal handler for graceful shutdown 220 def signal_handler(signum, frame): 221 logger.warning("Received interrupt signal, shutting down gracefully...") 222 self._shutdown_event.set() 223 224 # Store original handler to restore later 225 original_handler = signal.signal(signal.SIGINT, signal_handler) 226 227 try: 228 logger.info("Starting batch run") 229 230 # Start time limit watchdog if configured 231 self._start_time_limit_watchdog() 232 233 # Call initial progress 234 if self._progress_callback: 235 stats = self.status() 236 batch_data = dict(self.batch_tracking) 237 self._progress_callback(stats, 0.0, batch_data) 238 239 # Process all jobs synchronously 240 self._process_all_jobs() 241 242 logger.info("Batch run 
completed") 243 finally: 244 # Restore original signal handler 245 signal.signal(signal.SIGINT, original_handler) 246 247 def set_on_progress(self, callback: Callable[[Dict, float, Dict], None], interval: float = 1.0) -> 'BatchRun': 248 """Set progress callback for execution monitoring. 249 250 The callback will be called periodically with progress statistics 251 including completed jobs, total jobs, current cost, etc. 252 253 Args: 254 callback: Function that receives (stats_dict, elapsed_time_seconds, batch_data) 255 - stats_dict: Progress statistics dictionary 256 - elapsed_time_seconds: Time elapsed since batch started (float) 257 - batch_data: Dictionary mapping batch_id to batch information 258 interval: Interval in seconds between progress updates (default: 1.0) 259 260 Returns: 261 Self for chaining 262 263 Example: 264 ```python 265 run.set_on_progress( 266 lambda stats, time, batch_data: print( 267 f"Progress: {stats['completed']}/{stats['total']}, {time:.1f}s" 268 ) 269 ) 270 ``` 271 """ 272 self._progress_callback = callback 273 self._progress_interval = interval 274 return self 275 276 def _start_time_limit_watchdog(self): 277 """Start a background thread to check for time limit every second.""" 278 if not self.config.time_limit_seconds: 279 return 280 281 def time_limit_watchdog(): 282 """Check for time limit every second and trigger shutdown if exceeded.""" 283 while not self._shutdown_event.is_set(): 284 if self._check_time_limit(): 285 logger.warning("Batch execution time limit exceeded") 286 with self._state_lock: 287 self._time_limit_exceeded = True 288 self._shutdown_event.set() 289 break 290 time.sleep(1.0) 291 292 # Start watchdog as daemon thread 293 watchdog_thread = threading.Thread(target=time_limit_watchdog, daemon=True) 294 watchdog_thread.start() 295 logger.debug(f"Started time limit watchdog thread (time limit: {self.config.time_limit_seconds}s)") 296 297 def _check_time_limit(self) -> bool: 298 """Check if batch execution has exceeded time limit.""" 299 if not self.config.time_limit_seconds or not self._start_time: 300 return False 301 302 elapsed = (datetime.now() - self._start_time).total_seconds() 303 return elapsed >= self.config.time_limit_seconds 304 305 def _process_all_jobs(self): 306 """Process all jobs with parallel execution.""" 307 # Prepare all batches 308 batches = self._prepare_batches() 309 self.total_batches = len(batches) 310 311 # Process batches in parallel 312 with ThreadPoolExecutor(max_workers=self.config.max_parallel_batches) as executor: 313 futures = [executor.submit(self._execute_batch_wrapped, provider, batch_jobs) 314 for _, provider, batch_jobs in batches] 315 316 try: 317 for future in as_completed(futures): 318 # Stop if shutdown event detected (includes time limit) 319 if self._shutdown_event.is_set(): 320 break 321 future.result() # Re-raise any exceptions 322 except KeyboardInterrupt: 323 self._shutdown_event.set() 324 # Cancel remaining futures 325 for future in futures: 326 future.cancel() 327 raise 328 finally: 329 # Handle time limit or cancellation - mark remaining jobs appropriately 330 with self._state_lock: 331 if self._shutdown_event.is_set(): 332 # If time limit exceeded, cancel all active batches 333 if self._time_limit_exceeded: 334 self._cancel_all_active_batches() 335 336 # Mark any unprocessed jobs based on reason for shutdown 337 for _, _, batch_jobs in batches: 338 for job in batch_jobs: 339 # Skip jobs already processed 340 if (job.id in self.completed_results or 341 job.id in self.failed_jobs or 342 job.id 
in self.cancelled_jobs): 343 continue 344 345 # Mark based on shutdown reason 346 if self._time_limit_exceeded: 347 self.failed_jobs[job.id] = "Time limit exceeded: batch execution time limit exceeded" 348 else: 349 self.cancelled_jobs[job.id] = "Cancelled by user" 350 351 if job in self.pending_jobs: 352 self.pending_jobs.remove(job) 353 354 # Save state 355 self.state_manager.save(self) 356 357 def _cancel_all_active_batches(self): 358 """Cancel all active batches at the provider level.""" 359 with self._active_batches_lock: 360 active_batch_items = list(self._active_batches.items()) 361 362 logger.info(f"Cancelling {len(active_batch_items)} active batches due to time limit exceeded") 363 364 # Cancel outside the lock to avoid blocking 365 for batch_id, provider in active_batch_items: 366 try: 367 provider.cancel_batch(batch_id) 368 logger.info(f"Cancelled batch {batch_id} due to time limit exceeded") 369 except Exception as e: 370 logger.warning(f"Failed to cancel batch {batch_id}: {e}") 371 372 # Clear the tracking after cancellation attempts 373 with self._active_batches_lock: 374 self._active_batches.clear() 375 376 def _execute_batch_wrapped(self, provider, batch_jobs): 377 """Thread-safe wrapper for _execute_batch.""" 378 try: 379 result = self._execute_batch(provider, batch_jobs) 380 with self._state_lock: 381 self._update_batch_results(result) 382 # Remove jobs from pending_jobs if specified 383 jobs_to_remove = result.get("jobs_to_remove", []) 384 for job in jobs_to_remove: 385 if job in self.pending_jobs: 386 self.pending_jobs.remove(job) 387 except TimeoutError: 388 # Handle time limit exceeded - mark jobs as failed 389 with self._state_lock: 390 for job in batch_jobs: 391 self.failed_jobs[job.id] = "Time limit exceeded: batch execution time limit exceeded" 392 if job in self.pending_jobs: 393 self.pending_jobs.remove(job) 394 self.state_manager.save(self) 395 # Don't re-raise, just return the result 396 return 397 except KeyboardInterrupt: 398 self._shutdown_event.set() 399 # Handle user cancellation 400 with self._state_lock: 401 for job in batch_jobs: 402 self.cancelled_jobs[job.id] = "Cancelled by user" 403 if job in self.pending_jobs: 404 self.pending_jobs.remove(job) 405 self.state_manager.save(self) 406 raise 407 408 def _group_jobs_by_provider(self) -> Dict[str, List[Job]]: 409 """Group jobs by provider.""" 410 jobs_by_provider = {} 411 412 for job in self.pending_jobs[:]: # Copy to avoid modification during iteration 413 try: 414 provider = get_provider(job.model) 415 provider_name = provider.__class__.__name__ 416 417 if provider_name not in jobs_by_provider: 418 jobs_by_provider[provider_name] = [] 419 420 jobs_by_provider[provider_name].append(job) 421 422 except Exception as e: 423 logger.error(f"Failed to get provider for job {job.id}: {e}") 424 with self._state_lock: 425 self.failed_jobs[job.id] = str(e) 426 self.pending_jobs.remove(job) 427 428 return jobs_by_provider 429 430 def _split_into_batches(self, jobs: List[Job]) -> List[List[Job]]: 431 """Split jobs into batches based on items_per_batch.""" 432 batches = [] 433 batch_size = self.config.items_per_batch 434 435 for i in range(0, len(jobs), batch_size): 436 batch = jobs[i:i + batch_size] 437 batches.append(batch) 438 439 return batches 440 441 def _prepare_batches(self) -> List[Tuple[str, object, List[Job]]]: 442 """Prepare all batches as simple list of (provider_name, provider, jobs).""" 443 batches = [] 444 jobs_by_provider = self._group_jobs_by_provider() 445 446 for provider_name, provider_jobs in 
jobs_by_provider.items(): 447 provider = get_provider(provider_jobs[0].model) 448 job_batches = self._split_into_batches(provider_jobs) 449 450 for batch_jobs in job_batches: 451 batches.append((provider_name, provider, batch_jobs)) 452 453 # Pre-populate batch tracking for pending batches 454 batch_id = f"pending_{len(self.batch_tracking)}" 455 estimated_cost = provider.estimate_cost(batch_jobs) 456 self.batch_tracking[batch_id] = { 457 'start_time': None, 458 'status': 'pending', 459 'total': len(batch_jobs), 460 'completed': 0, 461 'cost': 0.0, 462 'estimated_cost': estimated_cost, 463 'provider': provider_name, 464 'jobs': batch_jobs 465 } 466 467 return batches 468 469 def _poll_batch_status(self, provider, batch_id: str) -> Tuple[str, Optional[Dict]]: 470 """Poll until batch completes.""" 471 status, error_details = provider.get_batch_status(batch_id) 472 logger.info(f"Initial batch status: {status}") 473 poll_count = 0 474 475 # Use provider-specific polling interval 476 provider_polling_interval = provider.get_polling_interval() 477 logger.debug(f"Using {provider_polling_interval}s polling interval for {provider.__class__.__name__}") 478 479 while status not in ["complete", "failed"]: 480 poll_count += 1 481 logger.debug(f"Polling attempt {poll_count}, current status: {status}") 482 483 # Interruptible wait - will wake up immediately if shutdown event is set (includes time limit) 484 if self._shutdown_event.wait(provider_polling_interval): 485 # Check if it's time limit exceeded or user cancellation 486 with self._state_lock: 487 is_time_limit_exceeded = self._time_limit_exceeded 488 489 if is_time_limit_exceeded: 490 logger.info(f"Batch {batch_id} polling interrupted by time limit exceeded") 491 raise TimeoutError("Batch cancelled due to time limit exceeded") 492 else: 493 logger.info(f"Batch {batch_id} polling interrupted by user") 494 raise KeyboardInterrupt("Batch cancelled by user") 495 496 status, error_details = provider.get_batch_status(batch_id) 497 498 if self._progress_callback: 499 with self._state_lock: 500 stats = self.status() 501 elapsed_time = (datetime.now() - self._start_time).total_seconds() 502 batch_data = dict(self.batch_tracking) 503 self._progress_callback(stats, elapsed_time, batch_data) 504 505 elapsed_seconds = poll_count * provider_polling_interval 506 logger.info(f"Batch {batch_id} status: {status} (polling for {elapsed_seconds:.1f}s)") 507 508 return status, error_details 509 510 511 def _update_batch_results(self, batch_result: Dict): 512 """Update state from batch results.""" 513 results = batch_result.get("results", []) 514 failed = batch_result.get("failed", {}) 515 516 # Update completed results 517 for result in results: 518 if result.is_success: 519 self.completed_results[result.job_id] = result 520 self._save_result_to_file(result) 521 logger.info(f"✓ Job {result.job_id} completed successfully") 522 else: 523 error_message = result.error or "Unknown error" 524 self.failed_jobs[result.job_id] = error_message 525 self._save_result_to_file(result) 526 logger.error(f"✗ Job {result.job_id} failed: {result.error}") 527 528 # Remove completed/failed job from pending 529 self.pending_jobs = [job for job in self.pending_jobs if job.id != result.job_id] 530 531 # Update failed jobs 532 for job_id, error in failed.items(): 533 self.failed_jobs[job_id] = error 534 # Remove failed job from pending 535 self.pending_jobs = [job for job in self.pending_jobs if job.id != job_id] 536 logger.error(f"✗ Job {job_id} failed: {error}") 537 538 # Update batch 
tracking 539 self.completed_batches += 1 540 541 # Save state 542 self.state_manager.save(self) 543 544 def _execute_batch(self, provider, batch_jobs: List[Job]) -> Dict: 545 """Execute one batch, return results dict with jobs/costs/errors.""" 546 if not batch_jobs: 547 return {"results": [], "failed": {}, "cost": 0.0} 548 549 # Reserve cost limit 550 logger.info(f"Estimating cost for batch of {len(batch_jobs)} jobs...") 551 estimated_cost = provider.estimate_cost(batch_jobs) 552 remaining = self.cost_tracker.remaining() 553 remaining_str = f"${remaining:.4f}" if remaining is not None else "unlimited" 554 logger.info(f"Total estimated cost: ${estimated_cost:.4f}, remaining budget: {remaining_str}") 555 556 if not self.cost_tracker.reserve_cost(estimated_cost): 557 logger.warning(f"Cost limit would be exceeded, skipping batch of {len(batch_jobs)} jobs") 558 failed = {} 559 for job in batch_jobs: 560 failed[job.id] = "Cost limit exceeded" 561 return {"results": [], "failed": failed, "cost": 0.0, "jobs_to_remove": list(batch_jobs)} 562 563 batch_id = None 564 job_mapping = None 565 try: 566 # Create batch 567 logger.info(f"Creating batch with {len(batch_jobs)} jobs...") 568 raw_files_path = str(self.raw_files_dir) if self.raw_files_dir else None 569 batch_id, job_mapping = provider.create_batch(batch_jobs, raw_files_path) 570 571 # Track active batch for cancellation 572 with self._active_batches_lock: 573 self._active_batches[batch_id] = provider 574 575 # Track batch creation 576 with self._state_lock: 577 # Remove pending entry if it exists 578 pending_keys = [k for k in self.batch_tracking.keys() if k.startswith('pending_')] 579 for pending_key in pending_keys: 580 if self.batch_tracking[pending_key]['jobs'] == batch_jobs: 581 del self.batch_tracking[pending_key] 582 break 583 584 # Add actual batch tracking 585 self.batch_tracking[batch_id] = { 586 'start_time': datetime.now(), 587 'status': 'running', 588 'total': len(batch_jobs), 589 'completed': 0, 590 'cost': 0.0, 591 'estimated_cost': estimated_cost, 592 'provider': provider.__class__.__name__, 593 'jobs': batch_jobs 594 } 595 596 # Poll for completion 597 logger.info(f"Polling for batch {batch_id} completion...") 598 status, error_details = self._poll_batch_status(provider, batch_id) 599 600 if status == "failed": 601 if error_details: 602 logger.error(f"Batch {batch_id} failed with details: {error_details}") 603 else: 604 logger.error(f"Batch {batch_id} failed") 605 606 # Save error details if configured 607 if self.raw_files_dir and error_details: 608 self._save_batch_error_details(batch_id, error_details) 609 610 # Continue to get individual results - some jobs might have succeeded 611 612 # Get results 613 logger.info(f"Getting results for batch {batch_id}") 614 raw_files_path = str(self.raw_files_dir) if self.raw_files_dir else None 615 results = provider.get_batch_results(batch_id, job_mapping, raw_files_path) 616 617 # Calculate actual cost and adjust reservation 618 actual_cost = sum(r.cost_usd for r in results) 619 self.cost_tracker.adjust_reserved_cost(estimated_cost, actual_cost) 620 621 # Update batch tracking for completion 622 success_count = len([r for r in results if r.is_success]) 623 failed_count = len([r for r in results if not r.is_success]) 624 batch_status = 'complete' if failed_count == 0 else 'failed' 625 626 with self._state_lock: 627 if batch_id in self.batch_tracking: 628 self.batch_tracking[batch_id]['status'] = batch_status 629 self.batch_tracking[batch_id]['completed'] = success_count 630 
```python
                    self.batch_tracking[batch_id]['failed'] = failed_count
                    self.batch_tracking[batch_id]['cost'] = actual_cost
                    self.batch_tracking[batch_id]['completion_time'] = datetime.now()
                    if batch_status == 'failed' and failed_count > 0:
                        # Use the first job's error as the batch error summary
                        first_error = next((r.error for r in results if not r.is_success), 'Some jobs failed')
                        self.batch_tracking[batch_id]['error'] = first_error

            # Remove from active batches tracking
            with self._active_batches_lock:
                self._active_batches.pop(batch_id, None)

            status_symbol = "✓" if batch_status == 'complete' else "⚠"
            logger.info(
                f"{status_symbol} Batch {batch_id} completed: "
                f"{success_count} success, "
                f"{failed_count} failed, "
                f"cost: ${actual_cost:.6f}"
            )

            return {"results": results, "failed": {}, "cost": actual_cost, "jobs_to_remove": list(batch_jobs)}

        except TimeoutError:
            logger.info(f"Time limit exceeded for batch{f' {batch_id}' if batch_id else ''}")
            if batch_id:
                # Update batch tracking for time limit exceeded
                with self._state_lock:
                    if batch_id in self.batch_tracking:
                        self.batch_tracking[batch_id]['status'] = 'failed'
                        self.batch_tracking[batch_id]['error'] = 'Time limit exceeded: batch execution time limit exceeded'
                        self.batch_tracking[batch_id]['completion_time'] = datetime.now()
                # NOTE: Don't remove from _active_batches - let centralized cancellation handle it
            # Release the reservation since batch exceeded time limit
            self.cost_tracker.adjust_reserved_cost(estimated_cost, 0.0)
            # Re-raise to be handled by wrapper
            raise

        except KeyboardInterrupt:
            logger.warning(f"\nCancelling batch{f' {batch_id}' if batch_id else ''}...")
            if batch_id:
                # Update batch tracking for cancellation
                with self._state_lock:
                    if batch_id in self.batch_tracking:
                        self.batch_tracking[batch_id]['status'] = 'cancelled'
                        self.batch_tracking[batch_id]['error'] = 'Cancelled by user'
                        self.batch_tracking[batch_id]['completion_time'] = datetime.now()
                # Remove from active batches tracking
                with self._active_batches_lock:
                    self._active_batches.pop(batch_id, None)
            # Release the reservation since batch was cancelled
            self.cost_tracker.adjust_reserved_cost(estimated_cost, 0.0)
            # Handle cancellation in the wrapper with proper locking
            raise

        except Exception as e:
            logger.error(f"✗ Batch execution failed: {e}")
            # Update batch tracking for exception
            if batch_id:
                with self._state_lock:
                    if batch_id in self.batch_tracking:
                        self.batch_tracking[batch_id]['status'] = 'failed'
                        self.batch_tracking[batch_id]['error'] = str(e)
                        self.batch_tracking[batch_id]['completion_time'] = datetime.now()
                # Remove from active batches tracking
                with self._active_batches_lock:
                    self._active_batches.pop(batch_id, None)
            # Release the reservation since batch failed
            self.cost_tracker.adjust_reserved_cost(estimated_cost, 0.0)
            failed = {}
            for job in batch_jobs:
                failed[job.id] = str(e)
            return {"results": [], "failed": failed, "cost": 0.0, "jobs_to_remove": list(batch_jobs)}

    def _save_result_to_file(self, result: JobResult):
        """Save individual result to file."""
        result_file = self.results_dir / f"{result.job_id}.json"

        try:
            with open(result_file, 'w') as f:
                json.dump(result.to_dict(), f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save result for {result.job_id}: {e}")

    def _save_batch_error_details(self, batch_id: str, error_details: Dict):
        """Save batch error details to debug files directory."""
        try:
            error_file = self.raw_files_dir / f"batch_{batch_id}_error.json"
            with open(error_file, 'w') as f:
                json.dump({
                    "batch_id": batch_id,
                    "timestamp": datetime.now().isoformat(),
                    "error_details": error_details
                }, f, indent=2)
            logger.info(f"Saved batch error details to {error_file}")
        except Exception as e:
            logger.error(f"Failed to save batch error details: {e}")

    @property
    def is_complete(self) -> bool:
        """Whether all jobs are complete."""
        total_jobs = len(self.jobs)
        completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
        return len(self.pending_jobs) == 0 and completed_count == total_jobs

    def status(self, print_status: bool = False) -> Dict:
        """Get current execution statistics."""
        total_jobs = len(self.jobs)
        completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
        remaining_count = total_jobs - completed_count

        stats = {
            "total": total_jobs,
            "pending": remaining_count,
            "active": 0,  # Always 0 for synchronous execution
            "completed": len(self.completed_results),
            "failed": len(self.failed_jobs),
            "cancelled": len(self.cancelled_jobs),
            "cost_usd": self.cost_tracker.used_usd,
            "cost_limit_usd": self.cost_tracker.limit_usd,
            "is_complete": self.is_complete,
            "batches_total": self.total_batches,
            "batches_completed": self.completed_batches,
            "batches_pending": self.total_batches - self.completed_batches,
            "current_batch_index": self.current_batch_index,
            "current_batch_size": self.current_batch_size,
            "items_per_batch": self.config.items_per_batch
        }

        if print_status:
            logger.info("\nBatch Run Status:")
            logger.info(f"  Total jobs: {stats['total']}")
            logger.info(f"  Pending: {stats['pending']}")
            logger.info(f"  Active: {stats['active']}")
            logger.info(f"  Completed: {stats['completed']}")
            logger.info(f"  Failed: {stats['failed']}")
            logger.info(f"  Cancelled: {stats['cancelled']}")
            logger.info(f"  Cost: ${stats['cost_usd']:.6f}")
            if stats['cost_limit_usd']:
                logger.info(f"  Cost limit: ${stats['cost_limit_usd']:.2f}")
            logger.info(f"  Complete: {stats['is_complete']}")

        return stats

    def results(self) -> Dict[str, List[JobResult]]:
        """Get all results organized by status.

        Returns:
            {
                "completed": [JobResult],
                "failed": [JobResult],
                "cancelled": [JobResult]
            }
        """
        return {
            "completed": list(self.completed_results.values()),
            "failed": self._create_failed_results(),
            "cancelled": self._create_cancelled_results()
        }

    def get_failed_jobs(self) -> Dict[str, str]:
        """Get failed jobs with error messages.

        Note: This method is deprecated. Use results()['failed'] instead.
        """
        return dict(self.failed_jobs)

    def _create_failed_results(self) -> List[JobResult]:
        """Convert failed jobs to JobResult objects."""
        failed_results = []
        for job_id, error_msg in self.failed_jobs.items():
            failed_results.append(JobResult(
                job_id=job_id,
                raw_response=None,
                parsed_response=None,
                error=error_msg,
                cost_usd=0.0,
                input_tokens=0,
                output_tokens=0
            ))
        return failed_results

    def _create_cancelled_results(self) -> List[JobResult]:
        """Convert cancelled jobs to JobResult objects."""
        cancelled_results = []
        for job_id, reason in self.cancelled_jobs.items():
            cancelled_results.append(JobResult(
                job_id=job_id,
                raw_response=None,
                parsed_response=None,
                error=reason,
                cost_usd=0.0,
                input_tokens=0,
                output_tokens=0
            ))
        return cancelled_results

    def shutdown(self):
        """Shutdown (no-op for synchronous execution)."""
        pass
```
### BatchRun

Manages the execution of a batch of jobs synchronously. Jobs are processed in batches sized by the `items_per_batch` configuration; execution is kept synchronous so that logging and debugging stay straightforward.

Example:

```python
config = BatchParams(...)
run = BatchRun(config, jobs)
run.execute()
results = run.results()
```
```python
def __init__(self, config: BatchParams, jobs: List[Job]):
    """Initialize batch run.

    Args:
        config: Batch configuration
        jobs: List of jobs to execute
    """
    self.config = config
    self.jobs = {job.id: job for job in jobs}

    # Set logging level based on config
    set_log_level(level=config.verbosity.upper())

    # Initialize components
    self.cost_tracker = CostTracker(limit_usd=config.cost_limit_usd)

    # Use temp file for state if not provided
    state_file = config.state_file
    if not state_file:
        state_file = create_temp_state_file(config)
        config.reuse_state = False
        logger.info(f"Created temporary state file: {state_file}")

    self.state_manager = StateManager(state_file)

    # State tracking
    self.pending_jobs: List[Job] = []
    self.completed_results: Dict[str, JobResult] = {}  # job_id -> result
    self.failed_jobs: Dict[str, str] = {}  # job_id -> error
    self.cancelled_jobs: Dict[str, str] = {}  # job_id -> reason

    # Batch tracking
    self.total_batches = 0
    self.completed_batches = 0
    self.current_batch_index = 0
    self.current_batch_size = 0

    # Execution control
    self._started = False
    self._start_time: Optional[datetime] = None
    self._time_limit_exceeded = False
    # Callback receives (stats, elapsed_seconds, batch_data); see set_on_progress()
    self._progress_callback: Optional[Callable[[Dict, float, Dict], None]] = None
    self._progress_interval: float = 1.0  # Default to 1 second

    # Threading primitives
    self._state_lock = threading.Lock()
    self._shutdown_event = threading.Event()

    # Batch tracking for progress display
    self.batch_tracking: Dict[str, Dict] = {}  # batch_id -> batch_info

    # Active batch tracking for cancellation
    self._active_batches: Dict[str, object] = {}  # batch_id -> provider
    self._active_batches_lock = threading.Lock()

    # Results directory
    self.results_dir = Path(config.results_dir)

    # If not reusing state, clear the results directory
    if not config.reuse_state and self.results_dir.exists():
        import shutil
        shutil.rmtree(self.results_dir)

    self.results_dir.mkdir(parents=True, exist_ok=True)

    # Raw files directory (if enabled)
    self.raw_files_dir = None
    if config.raw_files:
        self.raw_files_dir = self.results_dir / "raw_files"
        self.raw_files_dir.mkdir(parents=True, exist_ok=True)

    # Try to resume from saved state
    self._resume_from_state()
```
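A consequence of the constructor above: whether a run can be resumed depends on `state_file` and `reuse_state`. With no `state_file`, a temporary one is created and `reuse_state` is forced to `False`, so the results directory is cleared on every run. A minimal sketch, assuming `BatchParams` accepts as constructor keywords the fields it exposes here and in `to_json()` below (the exact signature and import path are not shown in this listing):

```python
from batchata.core import BatchParams, BatchRun  # import path assumed

# Persistent state: an interrupted run can be picked up again by
# constructing a new BatchRun with the same state_file and reuse_state=True.
config = BatchParams(
    state_file="./state/run.json",
    results_dir="./output",
    items_per_batch=10,
    cost_limit_usd=5.0,
    reuse_state=True,   # keep ./output and resume from ./state/run.json
    verbosity="info",   # passed to set_log_level() after .upper()
)
run = BatchRun(config, jobs)
```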
```python
def to_json(self) -> Dict:
    """Convert current state to JSON-serializable dict."""
    return {
        "created_at": datetime.now().isoformat(),
        "pending_jobs": [job.to_dict() for job in self.pending_jobs],
        "completed_results": [
            {"job_id": job_id, "file_path": str(self.results_dir / f"{job_id}.json")}
            for job_id in self.completed_results.keys()
        ],
        "failed_jobs": [
            {
                "id": job_id,
                "error": error,
                "timestamp": datetime.now().isoformat()
            } for job_id, error in self.failed_jobs.items()
        ],
        "cancelled_jobs": [
            {
                "id": job_id,
                "reason": reason,
                "timestamp": datetime.now().isoformat()
            } for job_id, reason in self.cancelled_jobs.items()
        ],
        "total_cost_usd": self.cost_tracker.used_usd,
        "config": {
            "state_file": self.config.state_file,
            "results_dir": self.config.results_dir,
            "max_parallel_batches": self.config.max_parallel_batches,
            "items_per_batch": self.config.items_per_batch,
            "cost_limit_usd": self.config.cost_limit_usd,
            "default_params": self.config.default_params,
            "raw_files": self.config.raw_files
        }
    }
```
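Because `to_json()` produces plain JSON, the state file can be inspected from outside the process, for example by a monitoring script. A minimal sketch, assuming the `StateManager` persists this dict as-is to the configured `state_file` path:

```python
import json

# Inspect a BatchRun state file written from to_json()
with open("./state/run.json") as f:
    state = json.load(f)

print(f"pending:     {len(state['pending_jobs'])}")
print(f"completed:   {len(state['completed_results'])}")
print(f"failed:      {len(state['failed_jobs'])}")
print(f"cost so far: ${state['total_cost_usd']:.4f}")
```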
```python
def execute(self):
    """Execute synchronous batch run and wait for completion."""
    if self._started:
        raise RuntimeError("Batch run already started")

    self._started = True
    self._start_time = datetime.now()

    # Register signal handler for graceful shutdown
    def signal_handler(signum, frame):
        logger.warning("Received interrupt signal, shutting down gracefully...")
        self._shutdown_event.set()

    # Store original handler to restore later
    original_handler = signal.signal(signal.SIGINT, signal_handler)

    try:
        logger.info("Starting batch run")

        # Start time limit watchdog if configured
        self._start_time_limit_watchdog()

        # Call initial progress
        if self._progress_callback:
            stats = self.status()
            batch_data = dict(self.batch_tracking)
            self._progress_callback(stats, 0.0, batch_data)

        # Process all jobs synchronously
        self._process_all_jobs()

        logger.info("Batch run completed")
    finally:
        # Restore original signal handler
        signal.signal(signal.SIGINT, original_handler)
```
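The SIGINT handler above means a Ctrl-C does not kill the process outright: in-flight batches are marked cancelled and surface through `results()["cancelled"]`. A minimal sketch of observing that from the caller's side; whether `execute()` re-raises `KeyboardInterrupt` or returns after draining is not shown in this listing, so the `try/except` here is defensive:

```python
run = BatchRun(config, jobs)
try:
    run.execute()  # Ctrl-C here triggers the graceful-shutdown path
except KeyboardInterrupt:
    pass  # execute() restores the original SIGINT handler in its finally block

for result in run.results()["cancelled"]:
    # Cancelled jobs become JobResult objects with `error` set to the reason
    print(result.job_id, result.error)  # e.g. "Cancelled by user"
```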
```python
def set_on_progress(self, callback: Callable[[Dict, float, Dict], None], interval: float = 1.0) -> 'BatchRun':
    """Set progress callback for execution monitoring.

    The callback will be called periodically with progress statistics
    including completed jobs, total jobs, current cost, etc.

    Args:
        callback: Function that receives (stats_dict, elapsed_time_seconds, batch_data)
            - stats_dict: Progress statistics dictionary
            - elapsed_time_seconds: Time elapsed since batch started (float)
            - batch_data: Dictionary mapping batch_id to batch information
        interval: Interval in seconds between progress updates (default: 1.0)

    Returns:
        Self for chaining

    Example:
        run.set_on_progress(
            lambda stats, time, batch_data: print(
                f"Progress: {stats['completed']}/{stats['total']}, {time:.1f}s"
            )
        )
    """
    self._progress_callback = callback
    self._progress_interval = interval
    return self
```
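The `batch_data` argument mirrors the internal `batch_tracking` dict, whose entries carry keys such as `'status'`, `'cost'`, `'failed'`, and `'error'` (see the execution code earlier in this listing; entries may hold more fields than shown). A sketch of a callback that uses it, under that assumption:

```python
def on_progress(stats, elapsed, batch_data):
    line = f"[{elapsed:6.1f}s] {stats['completed']}/{stats['total']} jobs, ${stats['cost_usd']:.4f}"
    for batch_id, info in batch_data.items():
        # Keys observed in batch_tracking: 'status', 'cost', 'failed', 'error', ...
        line += f" | {batch_id}: {info.get('status', '?')}"
    print(line)

run = BatchRun(config, jobs).set_on_progress(on_progress, interval=2.0)
run.execute()
```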
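Since `results()` (shown in the class listing above) buckets everything into completed/failed/cancelled lists of `JobResult`, and `get_failed_jobs()` is deprecated in its favor, post-run handling can stay uniform. A minimal sketch, assuming `run.execute()` has finished:

```python
buckets = run.results()

for result in buckets["completed"]:
    print(result.job_id, f"${result.cost_usd:.4f}")

# Failed and cancelled jobs are also JobResult objects, with `error` set
# and zero cost/token counts.
for result in buckets["failed"] + buckets["cancelled"]:
    print(f"{result.job_id}: {result.error}")
```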
### Job

```python
@dataclass
class Job:
    """Configuration for a single AI job.

    Either provide messages OR prompt (with optional file), not both.

    Attributes:
        id: Unique identifier for the job
        messages: Chat messages for direct message input
        file: Optional file path for file-based input
        prompt: Prompt text (can be used alone or with file)
        model: Model name (e.g., "claude-3-sonnet")
        temperature: Sampling temperature (0.0-1.0)
        max_tokens: Maximum tokens to generate
        response_model: Pydantic model for structured output
        enable_citations: Whether to extract citations from response
    """

    id: str  # Unique identifier
    model: str  # Model name (e.g., "claude-3-sonnet")
    messages: Optional[List[Message]] = None  # Chat messages
    file: Optional[Path] = None  # File input
    prompt: Optional[str] = None  # Prompt for file
    temperature: float = 0.7
    max_tokens: int = 1000
    response_model: Optional[Type[BaseModel]] = None  # For structured output
    enable_citations: bool = False

    def __post_init__(self):
        """Validate job configuration."""
        if self.messages and (self.file or self.prompt):
            raise ValueError("Provide either messages OR file+prompt, not both")

        if self.file and not self.prompt:
            raise ValueError("File input requires a prompt")

        if not self.messages and not self.prompt:
            raise ValueError("Must provide either messages or prompt")

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for state persistence."""
        return {
            "id": self.id,
            "model": self.model,
            "messages": self.messages,
            "file": str(self.file) if self.file else None,
            "prompt": self.prompt,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "response_model": self.response_model.__name__ if self.response_model else None,
            "enable_citations": self.enable_citations
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Job':
        """Deserialize from state."""
        # Convert file string back to Path if present
        file_path = None
        if data.get("file"):
            file_path = Path(data["file"])

        # Note: response_model reconstruction would need additional logic
        # For now, we'll set it to None during deserialization
        return cls(
            id=data["id"],
            model=data["model"],
            messages=data.get("messages"),
            file=file_path,
            prompt=data.get("prompt"),
            temperature=data.get("temperature", 0.7),
            max_tokens=data.get("max_tokens", 1000),
            response_model=None,  # Cannot reconstruct from string
            enable_citations=data.get("enable_citations", False)
        )
```
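A short illustration of the validation rules in `__post_init__`: exactly one input mode is allowed, and file input always needs a prompt. (The `Message` shape is assumed here to be the usual role/content dict.)

```python
from pathlib import Path

# Mode 1: direct chat messages
job_a = Job(id="job-a", model="claude-sonnet-4-20250514",
            messages=[{"role": "user", "content": "Hello"}])

# Mode 2: file + prompt
job_b = Job(id="job-b", model="claude-sonnet-4-20250514",
            file=Path("report.pdf"), prompt="Summarize this document")

# Invalid: mixing modes raises
#   ValueError("Provide either messages OR file+prompt, not both")
# Invalid: a file without a prompt raises
#   ValueError("File input requires a prompt")
```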
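One consequence of `to_dict`/`from_dict` worth calling out: `response_model` is persisted only as a class name and comes back as `None`, so structured-output jobs resumed from state lose their Pydantic model. A minimal round-trip sketch (the `Summary` model is hypothetical, defined just for illustration):

```python
from pathlib import Path
from pydantic import BaseModel

class Summary(BaseModel):  # hypothetical model for illustration
    text: str

job = Job(id="doc-1", model="claude-sonnet-4-20250514",
          file=Path("document.pdf"), prompt="Summarize this document",
          response_model=Summary)

data = job.to_dict()
assert data["response_model"] == "Summary"  # only the class name is stored

restored = Job.from_dict(data)
assert restored.response_model is None  # the model class is not reconstructed
```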
### JobResult

```python
@dataclass
class JobResult:
    """Result from a completed AI job.

    Attributes:
        job_id: ID of the job this result is for
        raw_response: Raw text response from the model (None for failed jobs)
        parsed_response: Structured output (if response_model was used)
        citations: Extracted citations (if enable_citations was True)
        citation_mappings: Maps field names to relevant citations (if response_model used)
        input_tokens: Number of input tokens used
        output_tokens: Number of output tokens generated
        cost_usd: Total cost in USD
        error: Error message if job failed
        batch_id: ID of the batch this job was part of (for mapping to raw files)
    """

    job_id: str
    raw_response: Optional[str] = None  # Raw text response (None for failed jobs)
    parsed_response: Optional[Union[BaseModel, Dict]] = None  # Structured output or error dict
    citations: Optional[List[Citation]] = None  # Extracted citations
    citation_mappings: Optional[Dict[str, List[Citation]]] = None  # Field -> citations mapping
    input_tokens: int = 0
    output_tokens: int = 0
    cost_usd: float = 0.0
    error: Optional[str] = None  # Error message if failed
    batch_id: Optional[str] = None  # Batch ID for mapping to raw files

    @property
    def is_success(self) -> bool:
        """Whether the job completed successfully."""
        return self.error is None

    @property
    def total_tokens(self) -> int:
        """Total tokens used (input + output)."""
        return self.input_tokens + self.output_tokens

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for state persistence."""
        # Handle parsed_response serialization
        parsed_response = None
        if self.parsed_response is not None:
            if isinstance(self.parsed_response, dict):
                parsed_response = self.parsed_response
            elif isinstance(self.parsed_response, BaseModel):
                parsed_response = self.parsed_response.model_dump()
            else:
                parsed_response = str(self.parsed_response)

        # Handle citation_mappings serialization
        citation_mappings = None
        if self.citation_mappings:
            citation_mappings = {
                field: [asdict(c) for c in field_citations]
                for field, field_citations in self.citation_mappings.items()
            }

        return {
            "job_id": self.job_id,
            "raw_response": self.raw_response,
            "parsed_response": parsed_response,
            "citations": [asdict(c) for c in self.citations] if self.citations else None,
            "citation_mappings": citation_mappings,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cost_usd": self.cost_usd,
            "error": self.error,
            "batch_id": self.batch_id
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'JobResult':
        """Deserialize from state."""
        # Reconstruct citations if present
        citations = None
        if data.get("citations"):
            citations = [Citation(**c) for c in data["citations"]]

        # Reconstruct citation_mappings if present
        citation_mappings = None
        if data.get("citation_mappings"):
            citation_mappings = {
                field: [Citation(**c) for c in field_citations]
                for field, field_citations in data["citation_mappings"].items()
            }

        return cls(
            job_id=data["job_id"],
            raw_response=data["raw_response"],
            parsed_response=data.get("parsed_response"),
            citations=citations,
            citation_mappings=citation_mappings,
            input_tokens=data.get("input_tokens", 0),
            output_tokens=data.get("output_tokens", 0),
            cost_usd=data.get("cost_usd", 0.0),
            error=data.get("error"),
            batch_id=data.get("batch_id")
        )
```
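With `is_success` and `total_tokens` available on every `JobResult`, run-level accounting is a simple fold. A minimal sketch over a finished run (the `Citation` dataclass referenced here is shown next):

```python
buckets = run.results()
all_results = buckets["completed"] + buckets["failed"] + buckets["cancelled"]

total_cost = sum(r.cost_usd for r in all_results)
total_tokens = sum(r.total_tokens for r in all_results)
succeeded = sum(1 for r in all_results if r.is_success)

print(f"{succeeded}/{len(all_results)} succeeded, {total_tokens} tokens, ${total_cost:.4f}")

# Structured-output jobs also expose citation_mappings (field -> [Citation]):
for r in buckets["completed"]:
    for field, cites in (r.citation_mappings or {}).items():
        print(field, [c.source for c in cites])
```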
### Citation

```python
@dataclass
class Citation:
    """Represents a citation extracted from an AI response."""

    text: str  # The cited text
    source: str  # Source identifier (e.g., page number, section)
    page: Optional[int] = None  # Page number if applicable
    metadata: Optional[Dict[str, Any]] = None  # Additional metadata
```
### Exceptions

`BatchataError` is the base exception for all Batchata errors.
```python
class CostLimitExceededError(BatchataError):
    """Raised when cost limit would be exceeded."""
    pass
```
`ProviderError` is the base exception for provider-related errors.
```python
class ProviderNotFoundError(ProviderError):
    """Raised when no provider is found for a model."""
    pass
```
```python
class ValidationError(BatchataError):
    """Raised when job or configuration validation fails."""
    pass
```
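Callers can catch these broadly or narrowly. A hedged sketch of defensive handling around a run; where exactly each exception is raised (for example, whether `run()` raises `CostLimitExceededError` or instead marks jobs failed) is not shown in this listing:

```python
from batchata import BatchataError, CostLimitExceededError, ValidationError

try:
    run = batch.run()  # `batch` built as in the Quick Start above
except CostLimitExceededError:
    print("Stopped: the configured cost limit would be exceeded")
except ValidationError as e:
    print(f"Invalid job or configuration: {e}")
except BatchataError as e:
    print(f"Batch failed: {e}")  # catch-all for errors derived from the base class
```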