Metadata-Version: 2.4
Name: chemcalculations
Version: 0.1.1
Summary: PyTorch neural network emulator for FastChem gas-phase chemical equilibrium in planetary and exoplanet atmospheres.
Author: Yashnil Mohanty
License-Expression: MIT
Project-URL: Homepage, https://github.com/yashnil/chemCalculations
Project-URL: Repository, https://github.com/yashnil/chemCalculations
Project-URL: Documentation, https://github.com/yashnil/chemCalculations#readme
Project-URL: PyPI, https://pypi.org/project/chemcalculations/
Keywords: fastchem,astronomy,exoplanets,atmospheres,pytorch,machine-learning,emulator
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Astronomy
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0
Requires-Dist: numpy>=1.20
Requires-Dist: pandas>=1.3
Requires-Dist: scikit-learn>=1.0
Requires-Dist: matplotlib>=3.5
Requires-Dist: scipy>=1.7
Requires-Dist: tqdm>=4.60
Provides-Extra: fastchem
Requires-Dist: pyfastchem>=3.0; extra == "fastchem"
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"
Provides-Extra: release
Requires-Dist: build>=1.0; extra == "release"
Requires-Dist: twine>=4.0; extra == "release"
Dynamic: license-file

# FastChem Neural Network Emulator

A high-performance machine learning surrogate for FastChem gas-phase equilibrium in **planetary and exoplanet** atmospheres.

[![Python](https://img.shields.io/badge/python-3.9%2B-blue)](https://www.python.org/)
[![PyTorch](https://img.shields.io/badge/PyTorch-2.0%2B-red)](https://pytorch.org/)
[![PyPI](https://img.shields.io/pypi/v/chemcalculations.svg?label=PyPI)](https://pypi.org/project/chemcalculations/)
[![Status](https://img.shields.io/badge/status-study--complete-brightgreen)](https://github.com/yashnil/chemCalculations)
[![Log R²](https://img.shields.io/badge/Log_R²-0.9999-brightgreen)](https://github.com/yashnil/chemCalculations)

**Status: Study Complete** (6 baseline + 1 improved + 1 MLP comparison)  
**Best Model Performance (x4800_improved)**: Log R² ≈ 0.9999, Log MAE ≈ 0.00391 dex, **MFAE ≈ 0.0105** (winsorized mean fractional error; see [Performance Metrics](#performance-metrics)), **~1,500× faster** on GPU / ~250× on CPU (measured)  
**Best Model**: **x4800_improved** (4.8M samples, AdamW + train-only normalization, FlowMap architecture)  
**Study Design**: 800K-sample increments (800K, 1600K, 2400K, 3200K, 4000K, 4800K); gains flatten beyond ~3200K, consistent with an asymptotic learning curve

**Improvement vs baseline (x4800_optimal_retrained)**: ~51% lower test loss, ~71% lower Log MAE, ~41% lower MFAE (same FlowMap architecture, improved training recipe).

---

## At a glance (first visit)

| Topic | Where to look |
|--------|----------------|
| **What this is** | Neural network emulator: predicts **gas-phase number densities** (33 species, cm⁻³) from **T, P, and 5 elemental abundances (dex)**—trained to match [FastChem](https://github.com/exoclime/FastChem). |
| **Best reported model** | **`x4800_improved`** — metrics in [`plots/comparison_metrics.csv`](plots/comparison_metrics.csv), figures described in [`plots/README.md`](plots/README.md). |
| **Install** | **From PyPI:** [`pip install chemcalculations`](https://pypi.org/project/chemcalculations/) — [project page](https://pypi.org/project/chemcalculations/). **From source:** `pip install -e ".[fastchem]"` or `pip install -e .`. See [Quick Start](#quick-start). |
| **Trained checkpoints** | Checkpoints, `best.pt`, and `best_model.py` are written under **`results/runs/<run_name>/`** when you train. That folder is **gitignored**—a plain `git clone` does **not** include weights. Train locally ([Quick Start](#quick-start)) or supply your own run directory. |
| **Inference pattern** | Add `results/runs/runs_autoencoder_x4800_improved` to `sys.path`, then `from best_model import load_model, ...` — see [Quick Start](#quick-start). |

**Navigation:** [Problem statement](#problem-statement) · [Performance metrics](#performance-metrics) · [Project structure](#project-structure) · [Quick Start](#quick-start) · [Citation](#citation)

---

## Table of Contents

1. [At a glance (first visit)](#at-a-glance-first-visit)
2. [Problem Statement](#problem-statement)
3. [Project Goals](#project-goals)
4. [Solution Overview](#solution-overview)
5. [Performance Metrics](#performance-metrics)
6. [Project Structure](#project-structure)
7. [Quick Start](#quick-start)
8. [Model Architecture](#model-architecture)
9. [Methods](#methods)
10. [Diagnostics](#diagnostics)
11. [Usage Examples](#usage-examples)
12. [Technical Details](#technical-details)
13. [Citation](#citation)
14. [Contact](#contact)

---

## Problem Statement

[FastChem](https://github.com/exoclime/FastChem) computes **gas-phase chemical equilibrium** (species abundances vs. temperature, pressure, and elemental composition). It is accurate but **expensive per evaluation** (order **~1–2 ms** with engine reuse in this repo’s benchmarks; see `scripts/fast_inference.py`). That cost adds up for retrievals, large grids, or many forward calls.

**This repository** trains a **PyTorch emulator** that maps **(T, P, five elemental abundances in dex)** to **33 species number densities (cm⁻³)** to match FastChem on the training distribution. Scope here is **planetary / exoplanet atmosphere** use cases reflected in the dataset and validation (not a general stellar-atmosphere product).

---

## Project Goals

1. **Speed** — Large throughput gains vs FastChem on CPU/GPU (see benchmarks; numbers are hardware-dependent).
2. **Accuracy** — Strong log-space metrics on the held-out setup (Log R², Log MAE, MFAE; [Performance Metrics](#performance-metrics)).
3. **Coverage** — Train for the project’s physical range (e.g. ~750–3000 K, 10⁻¹⁰–10⁵ bar).
4. **Reproducibility** — Packaged training, diagnostics, and comparison tables (`plots/comparison_metrics.csv`).

---

## Solution Overview

A **FlowMapAutoencoder** (encoder → latent dynamics → decoder) replaces the iterative solver at inference time. Training uses log-ratio loss and config-driven normalization; see [Model Architecture](#model-architecture) and `configs/`. The best reported run (**x4800_improved**) reaches Log R² ≈ 0.9999 with large speedups vs FastChem in this repo’s GPU/CPU timings. Outputs use a **fixed 33-species ordering**; low-temperature samples are filtered in training data where noted in the methods sections below.

---

## Performance Metrics

### Best Model Performance

**🏆 Best Overall Performance (x4800_improved):**
- **Model**: x4800_improved (FlowMapAutoencoder with AdamW + train-only normalization)
- **Test Loss**: 7.27×10⁻³ (normalized log_ratio space)
- **Validation Loss**: 7.65×10⁻³
- **Log MAE**: 0.00391 dex (mean absolute error in log₁₀ space on test set)
- **Log R²**: 0.9999 (99.99% variance explained in log space)
- **MFAE**: **0.01047** — mean fractional absolute error over parity/scatter points (both true & pred > 10⁻¹⁰ cm⁻³); each pair’s \|pred−true\|/true is capped at **2.0** before averaging (winsorized mean). See `src/chemcalculations/mfae_metrics.py` and `plots/comparison_metrics.csv`.
- **Dataset Size**: 4,800,000 samples
- **Architecture**: latent_dim=192, width=512, layers=3, SiLU activation
- **Species**: 33 species (32 + e-) with static ordering (99.68% coverage)
- **Training**: AdamW optimizer, train-only normalization (statistics are fit on the training set and not recomputed at evaluation time)
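The three headline metrics can be illustrated with a minimal, standard-library sketch. The function names below are hypothetical, and the repository's own MFAE code lives in `src/chemcalculations/mfae_metrics.py`. The sketch only follows the definitions in the list above (a 10⁻¹⁰ cm⁻³ floor on both values and a per-pair cap of 2.0), and it assumes the log-space metrics are plain means over all pairs.

```python
import math

def log_mae(y_true, y_pred):
    """Mean absolute error of log10 values, in dex."""
    errs = [abs(math.log10(p) - math.log10(t)) for t, p in zip(y_true, y_pred)]
    return sum(errs) / len(errs)

def log_r2(y_true, y_pred):
    """Coefficient of determination computed on log10 values."""
    lt = [math.log10(t) for t in y_true]
    lp = [math.log10(p) for p in y_pred]
    mean_t = sum(lt) / len(lt)
    ss_res = sum((a - b) ** 2 for a, b in zip(lt, lp))
    ss_tot = sum((a - mean_t) ** 2 for a in lt)
    return 1.0 - ss_res / ss_tot

def mfae(y_true, y_pred, floor=1e-10, cap=2.0):
    """Winsorized mean fractional absolute error.

    Pairs are kept only if both values exceed `floor`; each fractional
    error |pred - true| / true is capped at `cap` before averaging.
    """
    fracs = [min(abs(p - t) / t, cap)
             for t, p in zip(y_true, y_pred)
             if t > floor and p > floor]
    return sum(fracs) / len(fracs)

# Made-up number densities (cm^-3), for illustration only.
true = [1e12, 1e8, 1e4, 1e-12]
pred = [1.1e12, 0.9e8, 5e4, 2e-12]
print(f"Log MAE = {log_mae(true, pred):.4f} dex")
print(f"Log R2  = {log_r2(true, pred):.6f}")
print(f"MFAE    = {mfae(true, pred):.4f}")
```

In this toy example the last pair falls below the floor and is dropped from MFAE, and the third pair's fractional error of 4.0 is capped at 2.0, which is what keeps a few large outliers from dominating the mean.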

**x4800 comparison (same data scale, test-set metrics)** — from `plots/comparison_metrics.csv` (regenerate via `python scripts/update_comparison_baseline_vs_improved.py`):

| Run | Test loss | Log MAE (dex) | Log R² | MFAE |
|-----|-----------|---------------|--------|------|
| x4800_optimal_retrained (baseline FlowMap) | 1.47×10⁻² | 0.0137 | 0.9996 | 0.0179 |
| **x4800_improved** (best) | **7.27×10⁻³** | **0.00391** | **0.9999** | **0.0105** |
| x4800_mlp (6×1024 MLP baseline) | 0.142 | 0.085 | 0.9992 | 0.326 |

**Performance vs Dataset Size** (800K increments, consistent architecture):

| Dataset Size | Test Loss | Val Loss | Log MAE (dex) | Log R² |
|--------------|-----------|----------|---------------|--------|
| 800K | 2.69×10⁻² | 2.56×10⁻² | 0.0248 | 0.9994 |
| 1600K | 2.46×10⁻² | 2.39×10⁻² | 0.0229 | 0.9995 |
| 2400K | 1.71×10⁻² | 1.68×10⁻² | 0.0157 | 0.9998 |
| 3200K | 1.52×10⁻² | 1.66×10⁻² | 0.0139 | 0.9999 |
| 4000K | 1.60×10⁻² | 1.60×10⁻² | 0.0149 | 0.9998 |
| 4800K (baseline) | 1.47×10⁻² | 1.60×10⁻² | 0.0137 | 0.9996 |
| **4800K (improved)** | **7.27×10⁻³** | **7.65×10⁻³** | **0.0039** | **0.9999** |

**Study Design**: 800K increments allow clearer visualization of the asymptotic learning curve, where performance gains diminish with increasing data (approaching the architecture's capacity limit).

**Key Observations**:
- Log MAE decreases from 0.0248 (800K) to 0.0137 (4800K baseline) — **45% improvement**
- **x4800_improved** achieves ~51% lower test loss and ~71% lower Log MAE than x4800_optimal_retrained (same FlowMap architecture)
- Log R² ≥ 0.9994 across all sizes, peaking at 0.9999 (x3200, x4800_improved)
- **Clear asymptotic plateau**: the jump from 800K→2400K yields a 37% Log MAE reduction, while 2400K→4800K yields only 13% more
- x3200, x4000, and x4800 baseline cluster tightly (Log MAE 0.0137–0.0149), confirming the architecture's capacity limit
- Diminishing returns beyond ~3200K suggest the optimal cost-performance tradeoff lies around 3200K–4000K
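The percentages quoted in these observations follow directly from the tabulated values. A short arithmetic check (the helper name `pct_reduction` is just for illustration):

```python
def pct_reduction(before, after):
    """Relative reduction from `before` to `after`, in percent."""
    return 100.0 * (before - after) / before

# Log MAE (dex) from the dataset-size table
print(round(pct_reduction(0.0248, 0.0137)))   # 800K -> 4800K baseline: 45
print(round(pct_reduction(0.0248, 0.0157)))   # 800K -> 2400K: 37
print(round(pct_reduction(0.0157, 0.0137)))   # 2400K -> 4800K baseline: 13

# x4800 baseline -> x4800_improved
print(round(pct_reduction(1.47e-2, 7.27e-3)))  # test loss: 51
print(round(pct_reduction(0.0137, 0.00391)))   # Log MAE: 71
print(round(pct_reduction(0.0179, 0.0105)))    # MFAE: 41
```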

**Training Configuration (Baseline)**:
- Dataset: up to 4800K samples (750–3000 K; only samples with T > 750 K are kept)
- Split: 85% train / 10% val / 5% test
- Architecture: FlowMapAutoencoder, 192-dim latent, 512-width layers (3 layers each)
- Activation: SiLU (Sigmoid Linear Unit)
- Loss: Log-ratio loss (L = |log₁₀(ŷ/y)|, computed in normalized space)
- Scheduler: ReduceLROnPlateau (factor=0.5, patience=10, min_lr=1e-6)
- Training time: ~30 minutes (200 epochs)
- Dropout: 0.0 (no overfitting observed)
- **Static Species Ordering**: Fixed species list ordered by mean abundance
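To make the log-ratio loss concrete, here is a minimal sketch on plain positive numbers. It assumes a mean reduction over the batch and works on linear values, whereas the training scripts compute the loss in normalized space, so treat it as an illustration of the formula rather than the training code.

```python
import math

def log_ratio_loss(y_pred, y_true):
    """Mean of |log10(y_pred / y_true)| over paired positive values."""
    terms = [abs(math.log10(p / t)) for p, t in zip(y_pred, y_true)]
    return sum(terms) / len(terms)

# A factor-of-2 error costs the same at any absolute scale,
# which is why the loss suits targets spanning many orders of magnitude.
print(log_ratio_loss([2e-20], [1e-20]))
print(log_ratio_loss([2e10], [1e10]))
```

Overestimates and underestimates by the same factor are also penalized equally, since |log₁₀(2)| = |log₁₀(1/2)|.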

**Training Configuration (x4800_improved)**:
- Same architecture as baseline; **AdamW** optimizer (vs Adam) in `train_autoencoder_improved.py`
- **Train-only normalization**: Target statistics are computed from the training set only and reused unchanged for validation and test data (they are not recomputed at evaluation time)
- Config: `configs/x4800_improved.json`
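The idea behind train-only normalization can be sketched as follows. This assumes a simple z-score on a single target column; the statistics actually used are defined by the training script and config, and the helper names here are hypothetical.

```python
import statistics

def fit_standardizer(train_values):
    """Compute mean and population std from the training split only."""
    return statistics.fmean(train_values), statistics.pstdev(train_values)

def apply_standardizer(values, mean, std):
    """Apply previously fitted statistics to any split."""
    return [(v - mean) / std for v in values]

# Made-up log10 target values for illustration.
train = [1.0, 2.0, 3.0]
test = [2.0, 4.0]

mean, std = fit_standardizer(train)            # fitted once, on train only
test_scaled = apply_standardizer(test, mean, std)  # reused, never refit on test
print(test_scaled)
```

Fitting on the training split alone keeps validation and test information out of the preprocessing step, so held-out metrics are not flattered by leakage.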

**Key Improvements Over Previous Architecture**:
- **Static ordering**: Consistent architecture across all 6 dataset sizes, enabling fair comparison
- **Log-ratio loss**: Direct log-space error minimization across 30 orders of magnitude
- **Optimal hyperparameters**: latent_dim=192, width=512, layers=3 determined via systematic studies
- **AdamW + train-only norm**: ~71% Log MAE reduction in x4800_improved vs x4800_optimal_retrained

### Independent Validation (Out-of-Distribution)

To verify generalization beyond the training distribution, the best model (x4800) was validated against FastChem on **320 completely independent conditions** across 5 physically motivated scenarios:

| Scenario | Conditions | Log MAE (dex) | Log R² | Max Error (dex) |
|----------|-----------|---------------|--------|-----------------|
| Hot Jupiter T-P profile | 60 | 0.274 | 0.9985 | 1.76 |
| Cool dwarf T-P profile | 50 | 0.063 | 0.9999 | 0.65 |
| Systematic T-P grid (solar) | 120 | 0.097 | 0.9996 | 2.08 |
| C/O ratio sweep (0.1–2.0) | 40 | 0.161 | 0.9992 | 2.99 |
| Metallicity sweep (0.01–100× solar) | 50 | 0.064 | 0.9998 | 2.24 |
| **All combined** | **320** | **0.127** | **0.9994** | **2.99** |

**Key findings**:
- Overall Log R² = 0.9994 on independent data — the model generalizes well
- Performance is best on conditions within the training parameter space (cool dwarf, metallicity)
- The hot Jupiter profile shows higher error (0.27 dex) due to low-pressure conditions (10⁻⁶ bar) that are sparse in training data
- Maximum errors occur at extreme compositions (C/O > 1.5, [M/H] > +1.5)
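For readers who want to build similar sweeps, the sketch below shows one way to turn a C/O ratio or a metallicity factor into the dex abundances the model expects. It uses the standard convention dex = log₁₀(n_X/n_H) + 12 and the solar values from the usage example in this README. Holding O fixed while varying C, and scaling O, C, N, and S together for metallicity, are assumptions; `scripts/independent_validation.py` may construct its sweeps differently.

```python
import math

# Solar abundances in dex, as used in the README usage example.
SOLAR = {"H": 12.0, "O": 8.69, "C": 8.43, "N": 7.83, "S": 7.12}

def with_c_to_o(c_to_o, base=SOLAR):
    """Set C so that n_C / n_O equals `c_to_o`, keeping O fixed (assumption)."""
    out = dict(base)
    out["C"] = base["O"] + math.log10(c_to_o)
    return out

def with_metallicity(factor, base=SOLAR):
    """Scale every element except H by `factor` (e.g. 10 means [M/H] = +1)."""
    shift = math.log10(factor)
    return {el: (dex if el == "H" else dex + shift) for el, dex in base.items()}

solar_c_to_o = 10 ** (SOLAR["C"] - SOLAR["O"])
print(f"solar C/O ~ {solar_c_to_o:.2f}")
print(with_c_to_o(2.0))
print(with_metallicity(100.0))
```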

**Speed on independent validation** (320 conditions, inference-only timing, model loaded once):

| Backend | Total Time | ms/eval | Speedup vs FastChem |
|---------|------------|---------|---------------------|
| FastChem (fresh engine/row) | 1.80 s | ~5.6 ms | 1× |
| ML CPU | 0.008 s | ~0.026 ms | **213×** |
| ML MPS GPU | 0.012 s | ~0.038 ms | **149×** |

*At small batch sizes (40–120), CPU can outperform GPU due to kernel launch overhead. For large batches (≥1K), use GPU — see speed benchmark table below for 10K-batch numbers (~1,500×).*

**Plots**: See `plots/independent_validation/` for parity plots, atmospheric profiles, and sweep comparisons. Run `python scripts/independent_validation.py` to regenerate.

### Hyperparameter Optimization Studies

We conducted four systematic studies (Tests #1–#4 below) to identify the optimal model configuration:

#### Test #1: Latent Dimension Study
**Objective**: Find optimal latent space dimensionality

**Tested values**: 64, 96, 128, 160, 192, 256, 320, 384, 448, 512  
**Results**: 
- **Best**: latent_dim=192 (test_loss=0.000339) at 50 epochs
- Performance degrades for both smaller and larger dimensions
- Clear minimum at 192, optimal for 33-species output space (static ordering)
- Confirmed optimal for full 200-epoch training with static ordering

**Plot**: `plots/latent_dim_study.png`

#### Test #2: Layer Width Study  
**Objective**: Find optimal layer width and depth

**Tested configurations**: 
- Widths: 256, 512, 768, 1024
- Layers: 3, 4
- Using latent_dim=192 from Test #1

**Results**:
- **Best overall**: width=512, layers=3 (test_loss=0.000339)
- **Best 4-layer**: width=768, layers=4 (test_loss=0.000348)
- Wider layers (1024) don't improve performance
- 3 layers perform better overall

**Plot**: `plots/layer_width_study.png`

#### Test #3: Dataset Size Study with Optimal Hyperparameters
**Objective**: Evaluate optimal hyperparameters across different dataset sizes

**Configuration**: latent_dim=192, width=512, layers=3, static ordering, log_ratio loss  
**Sizes**: 800K, 1600K, 2400K, 3200K, 4000K, 4800K (800K increments)

**Key Findings**:
- Performance improves with dataset size; gains diminish sharply after ~2400K
- Log R² ≥ 0.9994 across all sizes, peaking at 0.9999 (x3200)
- Log MAE decreases from 0.0248 dex (800K) to 0.0137 dex (4800K) — **45% improvement**
- x3200, x4000, x4800 form a tight cluster (Log MAE 0.0137–0.0149), confirming asymptotic plateau
- Architecture capacity limit reached; further data provides diminishing returns
- **Best baseline-only run (asymptote study)**: x4800_optimal_retrained (4800K samples, Log MAE ≈ 0.0137). **Production best overall**: **x4800_improved** (Log MAE ≈ 0.00391, MFAE ≈ 0.0105).

**Plots**: 
- `plots/performance_vs_size.png` - Main performance metrics
- `plots/performance_vs_size_comprehensive.png` - With trend analysis
- `plots/asymptote_analysis.png` - Asymptote behavior analysis
- `plots/baseline_vs_improved_bar.png` - x4800 baseline vs improved bar chart
- `plots/baseline_vs_improved_performance.png` - Performance curves with x4800_improved overlay

**Full results**: See `plots/comparison_metrics.csv`

#### Test #4: Static Species Ordering Study
**Objective**: Compare static vs dynamic species selection

**Tested configurations**:
- Dynamic top-20 (baseline): 21 species selected per dataset
- Static 24 species: 25 species (24 + e-), 98.77% coverage
- Static 32 species: 33 species (32 + e-), 99.68% coverage ← **Best**
- Static 36 species: 37 species (36 + e-), 99.86% coverage

**Results**:
- **Static 32**: Best performance (test_loss=0.000150, log_mae=0.0224)
- **11.7% improvement** in Log MAE over dynamic baseline
- Static ordering provides better consistency and reproducibility
- 32 species optimal balance between coverage and model complexity

**Full results**: See `plots/comparison_metrics.csv`
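As a rough illustration of static ordering, the sketch below ranks species once by mean abundance and reports the fraction of the total captured by the top N. Reading "coverage" as that fraction is an assumption, the abundances are made-up numbers, and the function name is hypothetical; the study's actual selection logic is not shown in this README.

```python
def static_ordering(mean_abundance, n_species):
    """Pick the top-N species by mean abundance and report their coverage.

    mean_abundance: dict mapping species name -> mean abundance over the dataset.
    Returns (ordered species list, fraction of total abundance covered).
    """
    ranked = sorted(mean_abundance, key=mean_abundance.get, reverse=True)
    chosen = ranked[:n_species]
    covered = sum(mean_abundance[s] for s in chosen)
    return chosen, covered / sum(mean_abundance.values())

# Made-up mean abundances, for illustration only.
toy = {"H2": 0.90, "H": 0.05, "H2O": 0.03, "CO": 0.015, "CH4": 0.005}
species, coverage = static_ordering(toy, 3)
print(species, f"{coverage:.2%}")
```

Because the list is fixed once, every dataset size trains against the same output columns, which is what makes the cross-size comparison fair.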

### Comparison: FastChem vs ML Emulator

| Aspect | FastChem | ML Emulator (x4800_improved) | Advantage |
|--------|----------|-----------------------------|-----------|
| **Accuracy** | Exact (ground truth) | Log R² = 0.9999, Test Loss = 7.27×10⁻³ | Excellent match (99.99% variance) |
| **Speed** | 1.3 ms/eval (engine reuse) | 0.0009 ms/eval (GPU batch) | **~1,500× faster (GPU)** |
| **Scalability** | Linear | Parallel batching | GPU-accelerated |
| **Deployment** | C++ binary | Python/PyTorch | Easy integration |
| **Use case** | Ground truth | Production inference | Complementary |

**Speed benchmark details** (measured on Apple M1 Max):

| Batch Size | CPU ms/sample | CPU Speedup | MPS GPU ms/sample | GPU Speedup | GPU/CPU |
|---|---|---|---|---|---|
| 1 | 0.202 | 6× | 1.473 | 1× | 0.1× |
| 10 | 0.062 | 21× | 0.177 | 7× | 0.4× |
| 100 | 0.026 | 49× | 0.021 | 60× | 1.2× |
| 1,000 | 0.009 | 149× | 0.003 | 476× | 3.2× |
| **10,000** | **0.005** | **251×** | **0.0009** | **1,505×** | **6.0×** |
| 100,000 | 0.006 | 231× | 0.001 | 1,299× | 5.6× |

*Speedups are relative to a FastChem baseline of 1.3 ms/sample (engine reuse). Timings were measured with `scripts/fast_inference.py` (median of 5 repeats). MPS is the Apple Metal GPU backend. GPU overhead dominates at small batches, so use CPU for batches below 100.*

**ONNX Runtime** (CPU): For small batch sizes (1–100), ONNX Runtime provides an additional ~2.4× speedup over PyTorch CPU by eliminating Python dispatch overhead. Export via `python scripts/export_onnx.py`. The ONNX model is 7.7 MB with negligible numerical differences (max diff ~10⁻⁵).

**Bottom line**: On Apple Silicon GPU (M1 Max, MPS), the ML emulator is **~1,300–1,500× faster** than FastChem at batch sizes ≥ 10K. On CPU alone, it is **~230–250× faster**. At that throughput, retrieval workloads that would take days of FastChem evaluations can finish in minutes.
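A back-of-the-envelope estimate shows what these per-sample timings mean for a large workload. The workload size below is hypothetical, and the per-sample figures are the 10K-batch numbers from the table above (Apple M1 Max), so results on other hardware will differ.

```python
# Per-sample timings (ms) taken from the benchmark table above.
FASTCHEM_MS = 1.3        # FastChem with engine reuse
EMULATOR_CPU_MS = 0.005  # PyTorch CPU, batch 10,000
EMULATOR_GPU_MS = 0.0009 # PyTorch MPS GPU, batch 10,000

def wall_time_seconds(n_evals, ms_per_eval):
    """Total wall time for n_evals evaluations at a fixed per-evaluation cost."""
    return n_evals * ms_per_eval / 1000.0

n_evals = 200_000_000  # hypothetical retrieval workload

print(f"FastChem:     {wall_time_seconds(n_evals, FASTCHEM_MS) / 3600:.1f} h")
print(f"Emulator CPU: {wall_time_seconds(n_evals, EMULATOR_CPU_MS) / 60:.1f} min")
print(f"Emulator GPU: {wall_time_seconds(n_evals, EMULATOR_GPU_MS) / 60:.1f} min")
```

For this assumed workload the FastChem total is about 72 hours (three days), against roughly 17 minutes on CPU and 3 minutes on GPU for the emulator.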

---

## Project Structure

```
chemCalculations/
│
├── pyproject.toml                  # Package metadata & dependencies (pip / PyPI)
├── LICENSE                         # MIT
├── src/
│   └── chemcalculations/           # Installable Python package (import chemcalculations)
│       ├── __init__.py
│       ├── _paths.py               # project_root()
│       ├── autoencoder_model.py    # FlowMapAutoencoder + SimpleMLP
│       ├── mfae_metrics.py         # MFAE for comparison_metrics / diagnostics
│       ├── train_autoencoder.py
│       ├── train_autoencoder_improved.py
│       ├── diagnostics.py
│       ├── make_comparison_metrics.py
│       ├── plot_full_suite.py
│       ├── plot_comprehensive_analysis.py
│       ├── plot_baseline_vs_improved.py
│       ├── plot_training_analysis.py
│       ├── plot_latent_dim_results.py
│       └── plot_layer_width_results.py
├── tests/                          # pytest smoke tests
│
├── results/runs/                   # Trained models (NOT in git — see At a glance)
│   ├── runs_autoencoder_x800_optimal_retrained/
│   ├── runs_autoencoder_x1600_optimal_retrained/
│   ├── runs_autoencoder_x2400_optimal_retrained/
│   ├── runs_autoencoder_x3200_optimal_retrained/
│   ├── runs_autoencoder_x4000_optimal_retrained/
│   ├── runs_autoencoder_x4800_optimal_retrained/
│   ├── runs_autoencoder_x4800_improved/   # Target dir for best model training
│   └── runs_autoencoder_x4800_mlp/         # Optional MLP baseline
│
├── plots/                          # Visualization outputs
│   ├── comparison_metrics.csv     # Performance metrics for all models
│   ├── baseline_vs_improved_bar.png    # x4800 baseline vs improved bar chart
│   ├── baseline_vs_improved_performance.png  # Performance curves with improved overlay
│   ├── independent_validation/    # Independent validation results
│   └── *.png                      # Study plots and visualizations
│
├── data/                           # Datasets and thermodynamic data
│   ├── datasets/                  # Training datasets (excluded from Git)
│   └── fastchem_data/             # FastChem logK and element abundance files
│
├── configs/                        # Training configurations (JSON)
│   ├── x4800_optimal_retrained.json   # Baseline config
│   ├── x4800_improved.json            # Best model config (AdamW)
│   └── mlp_config.json                # Optional SimpleMLP architecture
│
├── scripts/                        # Utility scripts
│   ├── data_generation/           # FastChem job generation and merging
│   ├── update_comparison_baseline_vs_improved.py  # Metrics CSV + regenerate key plots
│   ├── plot_compare_x4800_three.py # Baseline vs improved vs MLP comparison figures
│   ├── plot_fastchem_style.py     # Mixing-ratio vs T / T–P heatmaps
│   ├── independent_validation.py  # Independent validation vs FastChem
│   ├── benchmark_fastchem_speed.py # FastChem vs ML speed comparison
│   ├── fast_inference.py          # Optimized inference (CPU vs MPS GPU)
│   ├── export_onnx.py             # ONNX export and benchmark
│   └── setup_fastchem_env.sh      # FastChem environment setup
│
├── requirements.txt                # Legacy pin list; canonical deps in pyproject.toml
└── README.md                       # This file
```

**Why the clone can feel large:** Git tracks **PNG figures** under `plots/` (many are ~0.1–0.7 MB each) plus small **thermodynamic tables** in `data/fastchem_data/`. Training CSVs, checkpoints, and `results/` are **not** tracked (see `.gitignore`). The file `plots/worst100_samples.csv` is listed in `.gitignore` so it is not re-added after diagnostics runs.

---

## Quick Start

### Prerequisites

From the repository root, install the package in editable mode (recommended for development):

```bash
# Core emulator + training (PyTorch, numpy, pandas, sklearn, matplotlib, scipy, tqdm)
pip install -e .

# Optional: include pyfastchem for data generation and FastChem-backed validation scripts
pip install -e ".[fastchem]"

# Dev: tests
pip install -e ".[dev]"
```

You can still use `pip install -r requirements.txt` as a loose mirror of core versions; **`pyproject.toml`** is the canonical dependency list for packaging.

Core stack: **torch**, **numpy**, **pandas**, **scikit-learn**, **matplotlib**, **scipy**, **tqdm**. **`pyfastchem`** is only needed for generating training data and scripts that call FastChem; forward inference on an existing `best_model.py` / checkpoint does not require it.

After `pip install -e .`, console entry points are available: **`chemcalc-train`**, **`chemcalc-train-improved`**, **`chemcalc-diagnostics`** (same as `python -m chemcalculations...`).

### Running tests

```bash
pip install -e ".[dev]"
pytest
```

`tests/conftest.py` sets `KMP_DUPLICATE_LIB_OK` and thread limits before PyTorch loads, which avoids a rare **Fatal Python error: Aborted** on macOS when OpenMP/BLAS runtimes clash. If you still see an abort, try `export KMP_DUPLICATE_LIB_OK=TRUE` in your shell before `pytest`.
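The same workaround can be applied in your own scripts. The sketch below is a guess at the general pattern, not a copy of `tests/conftest.py`: the thread-limit variables and their values are assumptions. The key point is that the environment must be set before PyTorch is imported.

```python
import os

# Set these before `import torch`; OpenMP reads them when the runtime loads.
# setdefault leaves any value already exported in the shell untouched.
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")  # tolerate duplicate OpenMP runtimes
os.environ.setdefault("OMP_NUM_THREADS", "1")          # assumed thread limit
os.environ.setdefault("MKL_NUM_THREADS", "1")          # assumed thread limit

# import torch  # import PyTorch only after the environment is prepared
```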

### Publishing to PyPI

Maintainers: install tools with `pip install -e ".[release]"`, bump `version` in `pyproject.toml`, run `python -m build`, then `twine upload dist/*` (after [PyPI](https://pypi.org/) account + API token). Full checklist: **[`RELEASING.md`](RELEASING.md)**. Published package: **[chemcalculations on PyPI](https://pypi.org/project/chemcalculations/)**.

### Training a Model

```bash
# Train best model (x4800_improved) — AdamW, train-only normalization
python -m chemcalculations.train_autoencoder_improved \
    --config configs/x4800_improved.json \
    --run-dir results/runs/runs_autoencoder_x4800_improved

# Or train baseline with original script
python -m chemcalculations.train_autoencoder \
    --config configs/x4800_optimal_retrained.json \
    --loss-type log_ratio \
    --run-dir results/runs/runs_autoencoder_x4800_optimal_retrained

# Generate diagnostic plots for a trained model
CSV_PATH=data/datasets/all_gas_fastchem_x4800.csv \
BEST_MODULE=results/runs/runs_autoencoder_x4800_improved/best_model.py \
OUT_DIR=results/runs/runs_autoencoder_x4800_improved/diagnostics \
python -m chemcalculations.diagnostics
```

### Using the Trained Model

Requires a completed training run so that `results/runs/runs_autoencoder_x4800_improved/best_model.py` (and `best.pt`) exist.

```python
import sys
sys.path.append('results/runs/runs_autoencoder_x4800_improved')  # Best model (after training)

from best_model import load_model, normalize_inputs, forward_autoencoder, denormalize_targets, TARGET_COLS
import pandas as pd
import torch

# Load model
model = load_model(device='cpu')
model.eval()

# Prepare input (T, P, elemental abundances in dex scale)
df_input = pd.DataFrame({
    'T_K': [1500.0],           # Temperature in Kelvin
    'P_bar': [0.1],            # Pressure in bar
    'abund_H_dex': [12.0],     # H abundance (reference = 12.0)
    'abund_O_dex': [8.69],     # O abundance (solar)
    'abund_C_dex': [8.43],     # C abundance (solar)
    'abund_N_dex': [7.83],     # N abundance (solar)
    'abund_S_dex': [7.12],     # S abundance (solar)
})

# Normalize, predict, denormalize
X = normalize_inputs(df_input)
with torch.no_grad():
    y_scaled = forward_autoencoder(model, X).cpu().numpy()
y_linear = denormalize_targets(y_scaled)

# Results: 33 species number densities (cm⁻³)
results = pd.DataFrame(y_linear, columns=TARGET_COLS)
print("Top-5 most abundant species:")
print(results.iloc[0].sort_values(ascending=False).head(5))
```
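For intuition about what `normalize_inputs` does to a row like the one above, the scalings listed in the Input Features table (T / 4000, log₁₀(P) / 10, (dex − 12) / 10) can be written out by hand. This is a sketch under assumptions: the feature order is taken from the DataFrame columns in the example, sulfur is assumed to follow the same scaling as the other elements, and the generated `best_model.py` remains the source of truth.

```python
import math

ELEMENTS = ("H", "O", "C", "N", "S")  # assumed feature order after T and P

def normalize_row(T_K, P_bar, abund_dex):
    """Hand-written version of the documented input scalings (illustrative)."""
    feats = [T_K / 4000.0, math.log10(P_bar) / 10.0]
    feats += [(abund_dex[el] - 12.0) / 10.0 for el in ELEMENTS]
    return feats

row = normalize_row(
    1500.0, 0.1,
    {"H": 12.0, "O": 8.69, "C": 8.43, "N": 7.83, "S": 7.12},
)
print([round(v, 4) for v in row])
```

With these scalings the example row maps to a temperature feature of 0.375, a pressure feature of −0.1, and a hydrogen feature of exactly 0.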

### Batch size and 1D profiles

- **Training:** `batch_size` comes from your JSON config (typical range **256–4096**, limited by GPU memory).
- **Inference:** There is **no required batch size**. The network uses the first dimension as **N** independent conditions; `normalize_inputs` / `forward_autoencoder` accept **any N ≥ 1**. A **full 1D profile** (N altitude layers with one `(T, P, abundances…)` row each) is **one forward pass** with shape **(N, 7)** after normalization—**dynamic N** is the normal way to amortize work. The only limit is **available memory**; for huge N, split into chunks.
- **Benchmarks:** `scripts/fast_inference.py` uses batch sizes **1 … 100,000** to report throughput (per-sample time depends on device and N).
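When N is too large to fit in memory at once, a small chunking helper is enough. The helper name `iter_chunks` is hypothetical; the commented loop shows how it could wrap the `normalize_inputs` / `forward_autoencoder` / `denormalize_targets` calls from the usage example above.

```python
def iter_chunks(n_rows, chunk_size):
    """Yield (start, stop) index pairs that cover range(n_rows) in order."""
    for start in range(0, n_rows, chunk_size):
        yield start, min(start + chunk_size, n_rows)

print(list(iter_chunks(10, 4)))  # [(0, 4), (4, 8), (8, 10)]

# Sketch of use with a trained model (requires best_model.py and a DataFrame df):
# outputs = []
# for start, stop in iter_chunks(len(df), 100_000):
#     X = normalize_inputs(df.iloc[start:stop])
#     with torch.no_grad():
#         y_scaled = forward_autoencoder(model, X).cpu().numpy()
#     outputs.append(denormalize_targets(y_scaled))
```

A chunk size around 10,000 to 100,000 rows matches the batch sizes where the benchmark table reports the best per-sample throughput.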

---

## Model Architecture

### FlowMapAutoencoder (Production)

**Framework**: PyTorch 2.0+  
**Type**: FlowMapAutoencoder (Encoder-Dynamics-Decoder architecture)  
**Total Parameters**: ~2.01M  
**Model Size**: ~8 MB (≈2.01M float32 parameters at 4 bytes each)

### Architecture Overview

The FlowMapAutoencoder is a specialized architecture designed for learning mappings between high-dimensional state spaces. It consists of three main components:

```
┌─────────────────────────────────────────────────────────────┐
│                    FLOWMAPAUTOENCODER                       │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Input: [T, P, abundances] (7D)                             │
│         ↓                                                   │
│  ┌──────────────────────────────────────┐                  │
│  │ ENCODER                               │                  │
│  │ Input: [state(33) + global(7)] = 40  │                  │
│  │ → Dense(512) → SiLU                  │                  │
│  │ → Dense(512) → SiLU                  │                  │
│  │ → Dense(512) → SiLU                  │                  │
│  │ Output: latent(192)                  │                  │
│  └──────────────────────────────────────┘                  │
│         ↓                                                   │
│  ┌──────────────────────────────────────┐                  │
│  │ DYNAMICS                              │                  │
│  │ Input: [latent(192) + dt(1) + g(7)]  │                  │
│  │ → Dense(512) → SiLU                  │                  │
│  │ → Dense(512) → SiLU                  │                  │
│  │ → Dense(512) → SiLU                  │                  │
│  │ → Output: latent_delta(192)          │                  │
│  │ → Residual: latent + latent_delta    │                  │
│  └──────────────────────────────────────┘                  │
│         ↓                                                   │
│  ┌──────────────────────────────────────┐                  │
│  │ DECODER                               │                  │
│  │ Input: latent(192)                    │                  │
│  │ → Dense(512) → SiLU                  │                  │
│  │ → Dense(512) → SiLU                  │                  │
│  │ → Dense(512) → SiLU                  │                  │
│  │ Output: state(33)                    │                  │
│  └──────────────────────────────────────┘                  │
│         ↓                                                   │
│  Output: [33 species number densities, cm⁻³]               │
└─────────────────────────────────────────────────────────────┘
```

### Component Details

#### 1. Encoder (`Encoder` class)
**Purpose**: Maps input state + global context to latent representation

- **Input**: Concatenated [state(33) + global(7)] = 40 dimensions
  - State: Initial species abundances (zeros in our case)
  - Global: T, P, elemental abundances (normalized)
- **Architecture**: 3-layer MLP with 512 hidden units per layer
- **Output**: Latent vector of dimension 192
- **Activation**: SiLU (Sigmoid Linear Unit): smooth and self-gated, which helps avoid dead neurons
- **Rationale**: Compresses high-dimensional input to compact latent space while preserving essential information

#### 2. Dynamics (`LatentDynamics` class)
**Purpose**: Evolves latent state over time step dt

- **Input**: [latent(192) + dt(1) + global(7)] = 200 dimensions
  - dt: Time step (normalized to 1.0 for equilibrium)
  - Global: Same T, P, abundances as encoder
- **Architecture**: 3-layer MLP with 512 hidden units per layer
- **Output**: Latent delta (change in latent state)
- **Residual Connection**: `z_new = z_old + delta` (enables identity mapping, better gradient flow)
- **Initialization**: Output layer initialized near zero (small initial deltas)
- **Rationale**: Learns how to evolve from initial state to equilibrium in latent space

#### 3. Decoder (`Decoder` class)
**Purpose**: Maps latent representation back to species abundances

- **Input**: Latent vector (192 dimensions)
- **Architecture**: 3-layer MLP with 512 hidden units per layer
- **Output**: 33 species abundances (normalized)
- **Activation**: SiLU throughout
- **Rationale**: Reconstructs high-dimensional output from compact latent representation
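The dimensions above can be cross-checked against the stated parameter count. The sketch assumes each component is three hidden `Dense(512)` layers followed by a separate linear output projection, all with biases. That layout is an assumption (the actual definition is in `autoencoder_model.py`), but it reproduces the ~2.01M total quoted for the model.

```python
def mlp_params(d_in, d_out, width=512, hidden_layers=3):
    """Weights + biases of an MLP: d_in -> width x hidden_layers -> d_out."""
    dims = [d_in] + [width] * hidden_layers + [d_out]
    return sum(a * b + b for a, b in zip(dims[:-1], dims[1:]))

encoder = mlp_params(33 + 7, 192)       # state(33) + global(7) -> latent(192)
dynamics = mlp_params(192 + 1 + 7, 192) # latent(192) + dt(1) + global(7) -> delta(192)
decoder = mlp_params(192, 33)           # latent(192) -> state(33)

total = encoder + dynamics + decoder
print(encoder, dynamics, decoder, total)
print(f"float32 size ~ {total * 4 / 1e6:.2f} MB")
```

Under these assumptions the total is 2,012,577 parameters, or about 8.05 MB in float32.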

### Key Architectural Decisions

#### Latent Dimension: 192
- **Rationale**: Optimal balance between compression and information preservation
- **Determined by**: Systematic hyperparameter study (Test #1)
- **Finding**: Smaller dimensions lose information; larger dimensions also gave higher test loss in the study (Test #1)
- **Result**: 192 provides best test loss for 33-species output space

#### Layer Width: 512
- **Rationale**: Sufficient capacity without overfitting
- **Determined by**: Layer width study (Test #2)
- **Finding**: 512 optimal; 768/1024 don't improve performance
- **Result**: Consistent 512-width layers across all components

#### Depth: 3 Layers
- **Rationale**: Deep enough to learn complex mappings, shallow enough to train efficiently
- **Determined by**: Layer width study comparing 3 vs 4 layers
- **Finding**: 3 layers perform better than 4 layers
- **Result**: All components use 3 hidden layers

#### Activation: SiLU (Swish)
- **Formula**: `SiLU(x) = x · sigmoid(x)`
- **Advantages**:
  - Smooth, differentiable everywhere
  - Self-gated (can output negative values)
  - Avoids the "dying ReLU" problem (the gradient does not vanish for all negative inputs)
  - Better gradient flow than ReLU
- **Rationale**: Empirically performs better than ReLU, GELU, Tanh for this task
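The formula is easy to check numerically. This standard-library sketch evaluates SiLU at a few points to show the small negative outputs and the near-identity behaviour for large positive inputs.

```python
import math

def silu(x):
    """SiLU(x) = x * sigmoid(x), written as x / (1 + exp(-x))."""
    return x / (1.0 + math.exp(-x))

for x in (-5.0, -1.0, 0.0, 1.0, 5.0):
    print(f"silu({x:+.1f}) = {silu(x):+.4f}")
```

Unlike ReLU, the output for negative inputs is a small negative number rather than exactly zero, so those units still pass gradient.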

#### Residual Connections
- **Location**: Dynamics module only
- **Implementation**: `z_out = z_in + network(z_in, dt, g)`
- **Benefits**:
  - Enables identity mapping (if delta ≈ 0)
  - Better gradient flow through deep networks
  - Easier optimization
- **Initialization**: Output layer bias initialized to zero, weights scaled by 0.1

#### No Dropout
- **Rationale**: Model doesn't overfit with sufficient data (800K–4800K samples)
- **Evidence**: Validation loss tracks training loss closely
- **Result**: Dropout = 0.0 for all layers

### Forward Pass

```python
# Pseudocode for forward pass
def forward(y0, dt, g):
    # y0: initial state (zeros for equilibrium)
    # dt: time step (1.0 for equilibrium)
    # g: global context [T, P, abundances]
    
    # Encode: state + context → latent
    z = encoder(y0, g)  # [B, 192]
    
    # Dynamics: evolve latent state
    z_evolved = dynamics(z, dt, g)  # [B, 192]
    
    # Decode: latent → species abundances
    y_pred = decoder(z_evolved)  # [B, 33]
    
    return y_pred
```

### Why FlowMapAutoencoder?

1. **Handles High-Dimensional Outputs**: 33 species simultaneously (vs. separate models per species)
2. **Learns Correlations**: Captures relationships between species abundances
3. **Efficient**: Single forward pass predicts all species
4. **Interpretable Latent Space**: 192-dim latent captures essential chemistry
5. **Proven Architecture**: Used successfully in physics-informed ML
6. **Scalable**: Architecture works across dataset sizes (800K → 4800K)

### SimpleMLP (Optional Alternative)

An alternative architecture is available via `train_autoencoder_improved.py` with `configs/mlp_config.json`:

- **Architecture**: 6 fully connected layers of 1024 units each, with dropout (0.05)
- **Input**: Global features (T, P, abundances) → direct mapping to 33 species
- **Use case**: Simpler baseline for comparison; x4800_improved uses FlowMap (not MLP)

### Input Features (7 total)

| Feature | Description | Normalization | Typical Range |
|---------|-------------|---------------|---------------|
| T_K | Temperature (K) | T / 4000 | [0.19, 0.75] |
| P_bar | Pressure (bar) | log₁₀(P) / 10 | [-1.0, 0.5] |
| abund_H_dex | Hydrogen abundance | (dex - 12) / 10 = 0 | 0.0 |
| abund_O_dex | Oxygen abundance | (dex - 12) / 10 | [-0.9, 0.9] |
| abund_C_dex | Carbon abundance | (dex - 12) / 10 | [-0.9, 0.9] |
| abund_N_dex | Nitrogen abundance | (dex - 12) / 10 | [-0.9, 0.9] |
| abund_S_dex | Sulfur abundance | (dex - 12) / 10 | [-0.9, 0.9] |

**Dex scale**: `abund_X_dex = 12 + log₁₀(N_X / N_H)` (standard astrophysical notation)
- Solar values: H=12.0, O≈8.69, C≈8.43, N≈7.83, S≈7.12
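
To make the dex scale concrete, here is a small sketch that inverts the definition to recover number ratios relative to hydrogen, using the solar values listed above. The helper names `dex_to_ratio` and `c_to_o` are illustrative and are not part of the package.

```python
SOLAR_DEX = {"H": 12.0, "O": 8.69, "C": 8.43, "N": 7.83, "S": 7.12}

def dex_to_ratio(dex: float) -> float:
    """Invert abund_X_dex = 12 + log10(N_X / N_H) to get N_X / N_H."""
    return 10.0 ** (dex - 12.0)

def c_to_o(abund_c_dex: float, abund_o_dex: float) -> float:
    """C/O number ratio; the hydrogen reference cancels in the difference."""
    return 10.0 ** (abund_c_dex - abund_o_dex)

for element, dex in SOLAR_DEX.items():
    print(f"N_{element}/N_H = {dex_to_ratio(dex):.3e}")
print(f"solar C/O = {c_to_o(SOLAR_DEX['C'], SOLAR_DEX['O']):.2f}")
```

With these solar values the C/O ratio comes out near 0.55, so a carbon-rich case such as C/O > 1.5 corresponds to raising `abund_C_dex` above `abund_O_dex` by more than about 0.18 dex.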

### Output Species (33 total)

**Static species list**: Fixed 33-species list (32 + e⁻) ordered by mean abundance.

**Species list** (ordered by abundance): e⁻, N₂, O₂, C₅, H₂, S₂, C₁S₂, S₇, H₂S₁, S₈, C₄N₂, C₁O₁, O₂S₁, C₃H₁, C₂H₂, H₂O₁, S₃, S₄, C₁O₂, C₁H₄, S₆, O₃S₁, H₃N₁, O₁S₂, C₂H₄, S₅, C₁S₁, C₁O₁S₁, C₁H₁N₁_1, H₂O₄S₁, C₃O₂, N₁O₁, C₂N₂

**Coverage**: 99.68% of total mass abundance  
**Ordering**: Determined once from comprehensive dataset analysis, fixed for all training runs  
**Benefits**: Consistent architecture, reproducible outputs, better performance than dynamic selection

### Output Units (FastChem Default)

**Both FastChem and this emulator output species number densities in cm⁻³** (particles per cubic centimeter). This matches the FastChem default (see [FastChem documentation](https://newstrangeworlds.github.io/FastChem/sections/input_output.html)).

- **Training data**: Generated via `pyfastchem` → `output_data.number_densities` → number densities in cm⁻³
- **Model output**: Same units; `denormalize_targets()` returns values in cm⁻³
- **Typical range**: ~10⁻³⁰ (trace species) to ~10²³ (dominant species at high P, where the total number density n = P / (k_B·T) is largest), depending on conditions
- **Plots**: All parity, scatter, and error plots use consistent axis labels with units (cm⁻³)

*Note: FastChem can optionally output mixing ratios via config; this project uses the default number-density output.*

### Training Configuration

**Baseline** (`train_autoencoder.py`):
```
Optimizer:      Adam (lr=5×10⁻⁴, weight_decay=1×10⁻⁵)
Scheduler:      ReduceLROnPlateau (factor=0.5, patience=10, min_lr=1×10⁻⁶)
Loss:           Log-ratio loss (L = |log₁₀(ŷ/y)|, computed in normalized space)
Batch size:     512
Epochs:         200
Gradient clip:  5.0
Data split:     85% train / 10% val / 5% test
Dataset size:   800K–4800K samples (800K increments, asymptote study)
Species:        Static ordering (33 species from configs/static_species_list_32.json)
```

**Best model** (`train_autoencoder_improved.py`, x4800_improved):
```
Optimizer:      AdamW (lr=5×10⁻⁴, weight_decay=1×10⁻⁵)
Normalization:  Train-only (target stats from training set only)
Architecture:   FlowMap (same as baseline) or optional SimpleMLP (6×1024, dropout 0.05)
```

---

## Usage Examples

### Basic Inference

```python
import sys
sys.path.append('results/runs/runs_autoencoder_x4800_improved')  # Best model

from best_model import load_model, normalize_inputs, forward_autoencoder, denormalize_targets, TARGET_COLS
import pandas as pd
import numpy as np
import torch

# Load trained model
model = load_model(device='cpu')
model.eval()

# Hot Jupiter atmosphere: T=1500K, P=0.1 bar, solar composition
df = pd.DataFrame({
    'T_K': [1500.0],
    'P_bar': [0.1],
    'abund_H_dex': [12.0],
    'abund_O_dex': [8.69],
    'abund_C_dex': [8.43],
    'abund_N_dex': [7.83],
    'abund_S_dex': [7.12],
})

# Predict
X = normalize_inputs(df)
with torch.no_grad():
    y_scaled = forward_autoencoder(model, X).cpu().numpy()
y_linear = denormalize_targets(y_scaled)

# Results: 33 species number densities
results = pd.DataFrame(y_linear, columns=TARGET_COLS)
print("Top-5 most abundant species:")
print(results.iloc[0].sort_values(ascending=False).head(5))
```

### Batch Processing (High Throughput)

```python
# Create T-P grid for atmospheric model
T_grid = np.linspace(1000, 2500, 100)
P_grid = np.logspace(-2, 2, 100)
T_mesh, P_mesh = np.meshgrid(T_grid, P_grid)

# Solar composition across entire grid
df_grid = pd.DataFrame({
    'T_K': T_mesh.ravel(),
    'P_bar': P_mesh.ravel(),
    'abund_H_dex': 12.0,
    'abund_O_dex': 8.69,
    'abund_C_dex': 8.43,
    'abund_N_dex': 7.83,
    'abund_S_dex': 7.12,
})

# Predict for 10,000 conditions
# CPU: ~0.05 seconds | MPS GPU: ~0.009 seconds
X = normalize_inputs(df_grid)
with torch.no_grad():
    y_scaled = forward_autoencoder(model, X).cpu().numpy()
y_linear = denormalize_targets(y_scaled)

# Shape: (10000, 33) — abundances for 33 species at 10k T-P points
```

**GPU Acceleration** (Apple Silicon — ~6× faster than CPU at batch 10K):

```python
# Load model on MPS GPU (Apple Silicon)
model = load_model(device='mps')
model.eval()

X = normalize_inputs(df_grid).to('mps')
with torch.no_grad():
    y_scaled = forward_autoencoder(model, X).cpu().numpy()
y_linear = denormalize_targets(y_scaled)
```

### Integration with Atmospheric Models

```python
def chemistry_step(T, P, composition):
    """
    Drop-in replacement for a single FastChem call in atmospheric models.
    The ~1,500× GPU speed-up (1.3 ms → 0.0009 ms per evaluation, measured)
    is the per-sample cost at batch 10K, so batch layers together where possible.
    """
    df = pd.DataFrame({
        'T_K': [T],
        'P_bar': [P],
        'abund_H_dex': [composition['H']],
        'abund_O_dex': [composition['O']],
        'abund_C_dex': [composition['C']],
        'abund_N_dex': [composition['N']],
        'abund_S_dex': [composition['S']],
    })
    
    X = normalize_inputs(df)
    with torch.no_grad():
        y_scaled = forward_autoencoder(model, X).cpu().numpy()
    y_linear = denormalize_targets(y_scaled)
    
    return dict(zip(TARGET_COLS, y_linear[0]))

# Use in your atmospheric model
for layer in atmosphere:
    chem = chemistry_step(layer.T, layer.P, layer.composition)
```
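
Because the measured throughput comes from batched evaluation, it can help to collect all layers into one table and run a single forward pass rather than calling the model once per layer. The sketch below only builds the column layout; the `Layer` class and `layers_to_columns` helper are hypothetical stand-ins for whatever your atmospheric model provides.

```python
from dataclasses import dataclass

ELEMENTS = ("H", "O", "C", "N", "S")

@dataclass
class Layer:  # hypothetical stand-in for an atmospheric model's layer object
    T: float            # temperature in K
    P: float            # pressure in bar
    composition: dict   # element symbol -> abundance in dex

def layers_to_columns(layers):
    """Collect every layer into one column-oriented dict (one row per layer)."""
    columns = {"T_K": [], "P_bar": []}
    for element in ELEMENTS:
        columns[f"abund_{element}_dex"] = []
    for layer in layers:
        columns["T_K"].append(layer.T)
        columns["P_bar"].append(layer.P)
        for element in ELEMENTS:
            columns[f"abund_{element}_dex"].append(layer.composition[element])
    return columns

solar = {"H": 12.0, "O": 8.69, "C": 8.43, "N": 7.83, "S": 7.12}
atmosphere = [Layer(T=1200.0 + 100.0 * i, P=10.0 ** (i - 3), composition=solar)
              for i in range(5)]
columns = layers_to_columns(atmosphere)

# With the model loaded as in Basic Inference, one batched call could replace the loop:
#   X = normalize_inputs(pd.DataFrame(columns))
#   y_scaled = forward_autoencoder(model, X)   # one forward pass for all layers
print(len(columns["T_K"]), "layers collected into one batch")
```

The column order matches the seven input features listed earlier, so the resulting dict can be passed straight to `pd.DataFrame`.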

---

## Citation

If you use this emulator in your research, please cite:

```bibtex
@software{fastchem_ml_emulator_2025,
  author = {Mohanty, Yashnil and Malsky, Isaac and Zhang, Xi},
  title = {FastChem Neural Network Emulator: A PyTorch Surrogate Model 
           for Chemical Equilibrium in Planetary Atmospheres},
  year = {2025},
  version = {1.0},
  url = {https://github.com/yashnil/chemCalculations},
  note = {~1,500× speed-up over FastChem on GPU (measured) with Log R² = 0.9999}
}
```

### Acknowledging FastChem

This emulator is trained on data generated by [FastChem](https://github.com/exoclime/FastChem). Please also cite:

```bibtex
@article{Stock2018,
  author = {Stock, Joachim W. and Kitzmann, Daniel and Patzer, A. Beate C. and Sedlmayr, Erwin},
  title = {FastChem: A computer program for efficient complex chemical 
           equilibrium calculations in the neutral/ionized gas phase with 
           applications to stellar and planetary atmospheres},
  journal = {Monthly Notices of the Royal Astronomical Society},
  year = {2018},
  volume = {479},
  number = {1},
  pages = {865--874},
  doi = {10.1093/mnras/sty1531}
}
```

---

## Contact

### Project Maintainers

**Yashnil Mohanty** (Lead Developer)  
📧 Email: yashnilmohanty@gmail.com  
🏛 Affiliation: Westmont High School  
🔬 Research: Computer Science and Exoplanet Atmospheres

**Xi Zhang** (Lead Mentor)  
📧 Email: xiz@ucsc.edu  
🏛 Affiliation: University of California, Santa Cruz  
🔬 Research: Earth and Planetary Sciences

**Isaac Malsky** (PyTorch Implementation)  
🏛 Affiliation: Jet Propulsion Laboratory  
🔬 Research: Exoplanet Atmospheres and Simulation of Physical Processes

### Getting Help

- **📖 Documentation**: See this README for detailed usage
- **🐛 Issues**: Report bugs via GitHub Issues
- **💬 Questions**: Email yashnilmohanty@gmail.com
- **🤝 Collaborations**: Open to integration with atmospheric modeling codes

### Contributing

We welcome contributions! Areas of interest:

1. **Extended chemistry**: More elements (Fe, Ti, Mg, Si, Na, Ca, etc.) - expandable to 30+ inputs
2. **Condensed phases**: Cloud and haze formation predictions
3. **Uncertainty quantification**: Bayesian neural networks, ensembles
4. **Speed optimization**: TensorRT, ONNX export, quantization
5. **Integration**: Wrappers for petitRADTRANS, BART, PICASO, Exo-Transmit
6. **Validation**: Testing against JWST/HST retrievals

**To contribute**: Fork the repository, create a feature branch, add tests, and submit a pull request.

---

## Methods

### Data Generation Pipeline

#### Overview
We generate training datasets by running FastChem on millions of randomly sampled atmospheric conditions. The pipeline consists of three main steps:

1. **Prepare Job Shards**: Sample (T, P, abundances) conditions and split into parallelizable shards
2. **Run FastChem**: Execute FastChem for each shard using Python bindings
3. **Merge Results**: Combine shard outputs into single CSV matching reference schema

#### Sampling Strategy: Empirical Resampling with Jitter

**Method**: Empirical resampling from existing dataset with controlled jitter

**Process**:
1. **Base Dataset**: Start with reference dataset (e.g., x800K)
2. **Resample**: Randomly sample conditions from base dataset
3. **Apply Jitter**: Add controlled noise to increase diversity:
   - **Temperature**: ±50K jitter (Gaussian)
   - **Pressure**: ±0.1 dex jitter in log₁₀(P) space
   - **Abundances**: ±0.05 dex jitter for each element
4. **Validate**: Ensure jittered values remain in valid ranges
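
The resampling steps above can be sketched with the standard library as follows. This is a minimal illustration, not the project's sampler, and it rests on several assumptions: every ± value is treated as a Gaussian standard deviation, validation is done by redrawing out-of-range samples, the valid ranges are borrowed from the normalization section (750–3000 K, 10⁻¹⁰ to 10⁵ bar), and hydrogen stays fixed at 12.0 because it is the reference element. The function names are hypothetical.

```python
import math
import random

# Assumed valid ranges, taken from the normalization section of this README.
T_RANGE = (750.0, 3000.0)      # K
LOGP_RANGE = (-10.0, 5.0)      # log10(P / bar)

def jitter_condition(base, rng, max_tries=100):
    """Return a jittered copy of `base`, redrawing until T and P are in range."""
    for _ in range(max_tries):
        t_k = base["T_K"] + rng.gauss(0.0, 50.0)                  # 50 K jitter
        log_p = math.log10(base["P_bar"]) + rng.gauss(0.0, 0.1)   # 0.1 dex jitter
        if not (T_RANGE[0] < t_k <= T_RANGE[1]):
            continue
        if not (LOGP_RANGE[0] <= log_p <= LOGP_RANGE[1]):
            continue
        out = {"T_K": t_k, "P_bar": 10.0 ** log_p}
        for key, dex in base.items():
            if key == "abund_H_dex":
                out[key] = dex                                    # reference element, unchanged
            elif key.startswith("abund_"):
                out[key] = dex + rng.gauss(0.0, 0.05)             # 0.05 dex jitter
        return out
    raise ValueError("could not draw a valid jittered condition")

def resample_with_jitter(base_rows, n_samples, seed=42):
    """Draw rows from the base dataset with replacement and jitter each one."""
    rng = random.Random(seed)
    return [jitter_condition(rng.choice(base_rows), rng) for _ in range(n_samples)]

base_rows = [
    {"T_K": 1500.0, "P_bar": 0.1, "abund_H_dex": 12.0, "abund_O_dex": 8.69,
     "abund_C_dex": 8.43, "abund_N_dex": 7.83, "abund_S_dex": 7.12},
    {"T_K": 800.0, "P_bar": 1e-3, "abund_H_dex": 12.0, "abund_O_dex": 8.69,
     "abund_C_dex": 8.43, "abund_N_dex": 7.83, "abund_S_dex": 7.12},
]
samples = resample_with_jitter(base_rows, 1000)
print(len(samples), "jittered conditions")
```

Redrawing keeps the output size fixed; clipping to the range boundaries would be an alternative, at the cost of piling samples up at the edges.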

**Advantages**:
- Preserves distribution of realistic conditions
- Increases diversity through jitter
- Faster than uniform sampling (no wasted samples)
- Maintains coverage of parameter space

**Dataset Sizes**: 800K, 1600K, 2400K, 3200K, 4000K, 4800K (800K increments)

#### FastChem Execution

**Implementation**: Python bindings (`pyfastchem`)

**Process**:
1. Load FastChem with thermodynamic data (logK tables)
2. For each condition:
   - Set elemental abundances
   - Set temperature and pressure
   - Call `calcDensities()` to solve equilibrium
   - Extract species number densities
3. Handle failures: Mark NaN for failed calculations (filtered later)

**Performance**: ~1.3 ms per evaluation with engine reuse, ~6 ms cold start (measured)

**Validation**:
- Check for convergence flags
- Filter NaN/Inf values during merge
- Validate row counts match conditions

#### Data Quality Filters

**Low-Temperature Filter**: T > 750K
- **Rationale**: FastChem produces numerical artifacts (vertical striping) at low temperatures
- **Impact**: Removes ~5–10% of samples but eliminates prediction artifacts
- **Result**: Clean predictions across the retained temperature range (T > 750K)

**NaN/Inf Filtering**:
- FastChem failures marked as NaN
- Dropped during merge step
- Reported in merge logs

**Mass Conservation**:
- Validated against FastChem outputs
- Total mass should be conserved (within numerical precision)

**Coverage**: 99.68% of total mass with 33 species (static ordering)

### Training Methodology

#### Data Splitting
- **Train**: 85% (e.g., 4.08M samples for x4800K)
- **Validation**: 10% (e.g., 480K samples)
- **Test**: 5% (e.g., 240K samples)
- **Split Method**: Random split with fixed seed (42) for reproducibility
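
A seeded 85/10/5 split can be sketched as below. The helper `split_indices` is illustrative; it uses Python's `random` module, so the resulting indices will not match the ones produced by the project's own split even with the same seed.

```python
import random

def split_indices(n_samples, fractions=(0.85, 0.10, 0.05), seed=42):
    """Shuffle indices with a fixed seed and cut them into train/val/test."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_train = round(fractions[0] * n_samples)
    n_val = round(fractions[1] * n_samples)
    # The test set takes whatever remains, so no sample is lost to rounding.
    return (indices[:n_train],
            indices[n_train:n_train + n_val],
            indices[n_train + n_val:])

# 1/1000-scale version of the x4800K split (4.08M / 480K / 240K).
train, val, test = split_indices(4_800)
print(len(train), len(val), len(test))
```

Shuffling once and slicing guarantees the three subsets are disjoint and together cover the whole dataset.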

#### Loss Function: Log-Ratio Loss

**Formula**: `L = |log₁₀(ŷ/y)|` computed in normalized space

**Rationale**:
- Species abundances span 30 orders of magnitude (10⁻³⁰ to 1)
- Linear loss dominated by most abundant species
- Log-space loss treats all species equally
- Directly minimizes error in dex (astrophysical standard)

**Implementation**:
```python
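# pred_normalized and target_normalized are already in normalized log10 space,
# so their difference is the log-ratio (in units of each species' log_std).
# Applying log10 again would yield NaN wherever a normalized value is negative.
loss = torch.abs(pred_normalized - target_normalized).mean()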
```

**Weighting**: Optional per-species weighting (not used in final model)

#### Optimization

**Optimizer**: Adam (baseline); x4800_improved uses AdamW with the same learning rate and weight decay
- Learning rate: 5×10⁻⁴ (initial)
- Weight decay: 1×10⁻⁵ (L2 regularization)
- Beta1: 0.9, Beta2: 0.999 (default)

**Learning Rate Schedule**: ReduceLROnPlateau
- Mode: minimize validation loss
- Factor: 0.5 (halve LR)
- Patience: 10 epochs
- Min LR: 1×10⁻⁶
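
As a worked example of these settings, the sketch below computes the learning rate after a given number of plateau-triggered reductions and counts how many halvings it takes to hit the floor. It models only the factor and the minimum; when each reduction fires depends on the validation curve and is not simulated here. The helper name is illustrative.

```python
LR0 = 5e-4      # initial learning rate
FACTOR = 0.5    # multiplicative reduction on plateau
MIN_LR = 1e-6   # lower bound

def lr_after_reductions(n: int) -> float:
    """Learning rate after n reductions, clamped at MIN_LR."""
    return max(LR0 * FACTOR ** n, MIN_LR)

n_floor = 0
while lr_after_reductions(n_floor) > MIN_LR:
    n_floor += 1

for k in range(n_floor + 1):
    print(f"after {k} reductions: lr = {lr_after_reductions(k):.3e}")
print(f"floor reached after {n_floor} reductions")
```

Since 5×10⁻⁴ / 2⁸ ≈ 1.95×10⁻⁶ is still above the floor and 5×10⁻⁴ / 2⁹ ≈ 9.8×10⁻⁷ is below it, the ninth reduction is the one that clamps to the minimum.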

**Gradient Clipping**: 5.0
- Prevents exploding gradients
- Stabilizes training

**Batch Size**: 512
- Balance between memory and gradient stability
- Empirically optimal for this architecture

**Epochs**: 200
- Sufficient for convergence
- Early stopping not needed (validation loss continues improving)

#### Training Process

1. **Initialization**: Xavier uniform for linear layers
2. **Forward Pass**: Encode → Dynamics → Decode
3. **Loss Computation**: Log-ratio loss in normalized space
4. **Backward Pass**: Compute gradients
5. **Gradient Clipping**: Clip gradients to max norm 5.0
6. **Optimizer Step**: Update weights
7. **Validation**: Evaluate on validation set every epoch
8. **LR Scheduling**: Reduce LR if validation plateaus
9. **Checkpointing**: Save best model (lowest validation loss)

**Training Time**: ~2-3 hours for 4000K–4800K samples (CPU, 200 epochs)

### Normalization Strategy

#### Input Normalization

**Temperature**: `T_norm = T_K / 4000`
- Rationale: Upper bound of typical atmospheres
- Range: [0.19, 0.75] for 750-3000 K (750 / 4000 ≈ 0.19)

**Pressure**: `P_norm = log₁₀(P_bar) / 10`
- Rationale: Pressure spans many orders of magnitude
- Range: [-1.0, 0.5] for 10⁻¹⁰ to 10⁵ bar

**Abundances**: `abund_norm = (abund_dex - 12) / 10`
- Rationale: Solar hydrogen reference (12.0), typical variation span
- Range: [-0.9, 0.9] for typical element variations
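
The three formulas above can be applied to a single condition as in the following sketch. It is a plain-Python illustration rather than the package's `normalize_inputs`; the function name `normalize_row` is hypothetical, and the constant names follow the ones used in the Technical Details section.

```python
import math

TEMP_DIVISOR = 4000.0
INPUT_LOG_SCALE = 10.0
ABUND_OFFSET = 12.0
ABUND_SCALE = 10.0

def normalize_row(row):
    """Apply the fixed-constant input normalization to one condition."""
    out = {
        "T_K": row["T_K"] / TEMP_DIVISOR,
        "P_bar": math.log10(row["P_bar"]) / INPUT_LOG_SCALE,
    }
    for key, dex in row.items():
        if key.startswith("abund_"):
            out[key] = (dex - ABUND_OFFSET) / ABUND_SCALE
    return out

hot_jupiter = {"T_K": 1500.0, "P_bar": 0.1, "abund_H_dex": 12.0,
               "abund_O_dex": 8.69, "abund_C_dex": 8.43,
               "abund_N_dex": 7.83, "abund_S_dex": 7.12}
normalized = normalize_row(hot_jupiter)
for key, value in normalized.items():
    print(f"{key:12s} {value:+.3f}")
```

For this hot Jupiter case the temperature maps to 0.375, the pressure to -0.1, and hydrogen to exactly 0, since it is the reference element of the dex scale.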

#### Target Normalization

**Method**: Log-space normalization
- `y_log = log₁₀(y_linear + ε)` where ε = 10⁻³⁰
- `y_norm = (y_log - log_mean) / log_std` (computed from training data)
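
For a single species, the two steps above and their inverse can be sketched as follows. The function names are illustrative, and the use of the population standard deviation is an assumption, since the text does not say which estimator is used.

```python
import math

EPS = 1e-30  # floor added before taking the logarithm

def fit_log_stats(train_values):
    """Mean and standard deviation of log10(y + EPS) over training samples only."""
    logs = [math.log10(y + EPS) for y in train_values]
    mean = sum(logs) / len(logs)
    var = sum((v - mean) ** 2 for v in logs) / len(logs)  # population variance (assumed)
    return mean, math.sqrt(var)

def normalize_target(y, log_mean, log_std):
    """Linear abundance -> standardized log10 value."""
    return (math.log10(y + EPS) - log_mean) / log_std

def denormalize_target(y_norm, log_mean, log_std):
    """Standardized log10 value -> linear abundance."""
    return 10.0 ** (y_norm * log_std + log_mean) - EPS

# Toy training abundances for one species, in cm^-3.
log_mean, log_std = fit_log_stats([1e12, 1e14, 1e16])
y_norm = normalize_target(1e15, log_mean, log_std)
y_back = denormalize_target(y_norm, log_mean, log_std)
print(f"log_mean={log_mean:.2f} log_std={log_std:.3f} y_norm={y_norm:.3f}")
```

Fitting the statistics on the training split only, and reusing them for validation and test data, keeps information from the held-out sets out of the normalization.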

**Rationale**:
- Abundances span 30 orders of magnitude
- Log-space normalization centers and scales data
- Prevents numerical overflow/underflow

**Denormalization**:
```python
y_log = y_norm * log_std + log_mean   # undo the standardization
y_linear = 10.0 ** y_log - eps        # undo log10(y + eps), with eps = 1e-30
```

### Hyperparameter Optimization

See [Performance Metrics](#performance-metrics) section for detailed hyperparameter studies.

**Key Findings**:
- Latent dimension: 192 (optimal)
- Layer width: 512 (optimal)
- Depth: 3 layers (optimal)
- Activation: SiLU (best performance)
- Loss: Log-ratio (best for abundance prediction)

---

## Diagnostics

### Overview

The diagnostics suite (`chemcalculations.diagnostics`) provides comprehensive validation of model performance, generating publication-ready plots and detailed error analysis.

### Running Diagnostics

```bash
# Set environment variables (example for x4800_optimal_retrained, the baseline;
# use runs_autoencoder_x4800_improved in the paths below for the best model)
export CSV_PATH="data/datasets/all_gas_fastchem_x4800.csv"
export BEST_MODULE="results/runs/runs_autoencoder_x4800_optimal_retrained/best_model.py"
export OUT_DIR="results/runs/runs_autoencoder_x4800_optimal_retrained/diagnostics"

# Run diagnostics
python -m chemcalculations.diagnostics
# or: chemcalc-diagnostics
```

### Diagnostic Outputs

#### 1. Global Metrics (`global_metrics.txt`)

**Metrics Computed**:
- **Linear MAE**: Mean absolute error in linear space
- **Linear R²**: Variance explained in linear space
- **Log MAE**: Mean absolute error in log₁₀ space (dex)
- **Log R²**: Variance explained in log space (primary metric)
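
The two log-space metrics can be computed as in the sketch below. It pools all values into one list and uses the usual 1 − SS_res / SS_tot definition of R²; whether the diagnostics suite pools species this way or averages per-species scores is an assumption, and `log_metrics` is a hypothetical name.

```python
import math

def log_metrics(y_true, y_pred, eps=1e-30):
    """Return (Log MAE in dex, Log R²) for pooled abundance values."""
    t = [math.log10(v + eps) for v in y_true]
    p = [math.log10(v + eps) for v in y_pred]
    n = len(t)
    mae = sum(abs(a - b) for a, b in zip(t, p)) / n
    mean_t = sum(t) / n
    ss_res = sum((a - b) ** 2 for a, b in zip(t, p))
    ss_tot = sum((a - mean_t) ** 2 for a in t)
    return mae, 1.0 - ss_res / ss_tot

# Toy case: every prediction is 0.01 dex too high.
y_true = [1e10, 1e12, 1e14, 1e16]
y_pred = [v * 10.0 ** 0.01 for v in y_true]
log_mae, log_r2 = log_metrics(y_true, y_pred)
print(f"Log MAE = {log_mae:.4f} dex, Log R² = {log_r2:.6f}")
```

In this toy case a uniform 0.01 dex offset gives Log MAE = 0.01 and Log R² ≈ 0.99998, which shows how close to 1 the R² stays when the targets span many orders of magnitude.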

**Example Output** (values correspond to the x4800_optimal_retrained baseline):
```
Linear MAE:  8.635e+18
Linear R²:   0.957
Log MAE:     1.368e-02 dex
Log R²:      0.999565
Test samples: 240,000
Species:     33
```

#### 2. Per-Species Metrics (`per_species_errors.csv`)

**Columns**:
- `species`: Species name
- `MAE`: Mean absolute error (linear space)
- `R2`: R² score for this species
- `max_abundance`: Maximum abundance in test set
- `mean_abundance`: Mean abundance in test set

**Usage**: Identifies which species are predicted most/least accurately

#### 3. Parity Plots

**Top-10 Species Parity** (`parity_top10.png`):
- Individual parity plots for 10 most abundant species
- Log-log scale with 1:1 line
- ±10% error bands
- Density coloring (KDE) for large datasets

**Overall Parity** (`parity_overall.png`):
- All species pooled together
- Log-log scale
- Density heatmap or hexbin
- 1:1 line and ±10% error bands
- Publication-ready figure

#### 4. Error Analysis Plots

**MAE per Species** (`MAE_per_species.png`):
- Horizontal bar chart showing error for each species
- Red bars = above global average
- Blue bars = below global average
- Identifies problematic species

**Residuals vs Observed** (`residual_vs_observed.png`):
- Hexbin plot of residuals vs true abundance
- Identifies systematic biases
- Shows error distribution across abundance range

**Error Distribution** (`error_distribution.png`):
- Two histograms: linear and log space residuals
- Shows error distribution shape
- Identifies outliers and bias

#### 5. Worst Samples (`worst100_samples.csv`)

**Purpose**: Identify conditions where model fails

**Columns**: All input features + predicted/true abundances + errors

**Usage**: Analyze failure modes, identify edge cases

### Diagnostic Interpretation

#### Good Performance Indicators:
- ✅ Log R² > 0.999 (excellent)
- ✅ Log MAE < 0.02 dex (very good)
- ✅ Parity plots show tight 1:1 correlation
- ✅ Residuals centered around zero
- ✅ No systematic biases in error plots

#### Warning Signs:
- ⚠️ Log R² < 0.99 (may need more data)
- ⚠️ Log MAE > 0.05 dex (may need architecture changes)
- ⚠️ Systematic bias in residuals (check normalization)
- ⚠️ Certain species consistently wrong (may need more training data for those species)

### Best Model Diagnostics

**x4800_improved** (recommended):
- Log R²: 0.9999 (99.99% variance explained)
- Log MAE: 0.0039 dex
- All species: R² > 0.99
- No systematic biases observed
- Error distribution: Normal, centered at zero

**x4800_optimal_retrained** (baseline):
- Log R²: 0.9996, Log MAE: 0.0137 dex

**Diagnostic Files Location**: `results/runs/runs_autoencoder_x4800_improved/diagnostics/`

---

## Technical Details

### Normalization Philosophy

**Why these specific constants?**

| Constant | Value | Rationale |
|----------|-------|-----------|
| TEMP_DIVISOR | 4000 | Upper bound of typical atmospheres (3000K) + margin |
| INPUT_LOG_SCALE | 10 | Brings log₁₀(P) to ~[-1, 0.5] range for 10⁻¹⁰ to 10⁵ bar |
| ABUND_OFFSET | 12 | Solar hydrogen reference (standard astrophysical notation) |
| ABUND_SCALE | 10 | Typical element variation span (±1 dex around solar) |
| TARGET_LOG_SCALE | 30 | Abundance range spans ~30 orders of magnitude (10⁻³⁰ to 1) |

**Benefits**:
- **No data dependencies**: Unlike StandardScaler, doesn't require training data statistics
- **Physical meaning**: Based on astrophysical scales, not arbitrary
- **Reproducibility**: Same normalization across all datasets
- **Low variance**: Normalized features have similar scales
- **Interpretability**: Normalized values have physical meaning

### Error Handling

**Philosophy**: Transparency over hiding problems

**Approach**:
- Detects non-finite values (NaN, Inf) in inputs/targets during data loading
- Logs per-column counts and example row indices
- **Drops** problematic rows (doesn't sanitize to avoid hiding issues)
- Reports how many rows dropped and why
- FastChem failures filtered during merge step

**Result**: No silent failures, easier debugging, clean training data
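
A drop-and-report filter in this spirit could look like the sketch below. It is a standard-library illustration with hypothetical names, not the loader used by the training scripts; rows are represented as plain dicts.

```python
import math

def drop_non_finite(rows, columns, max_examples=3):
    """Drop rows with NaN/Inf in any listed column and report what was removed."""
    counts = {c: 0 for c in columns}
    examples = {c: [] for c in columns}
    kept = []
    for index, row in enumerate(rows):
        bad = [c for c in columns if not math.isfinite(row[c])]
        for c in bad:
            counts[c] += 1
            if len(examples[c]) < max_examples:
                examples[c].append(index)   # remember a few offending row indices
        if not bad:
            kept.append(row)
    report = {c: {"count": counts[c], "example_rows": examples[c]}
              for c in columns if counts[c]}
    return kept, report

rows = [
    {"T_K": 1500.0, "H2": 1e15},
    {"T_K": float("nan"), "H2": 1e14},
    {"T_K": 900.0, "H2": float("inf")},
]
kept, report = drop_non_finite(rows, ["T_K", "H2"])
print(f"dropped {len(rows) - len(kept)} of {len(rows)} rows")
for column, info in report.items():
    print(f"  {column}: {info['count']} non-finite, e.g. rows {info['example_rows']}")
```

Returning the report alongside the cleaned rows lets the caller log exactly which columns caused the drops instead of silently replacing values.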

### Reproducibility

**Random Seeds**: Fixed seed (42) for:
- Data splitting (train/val/test)
- Weight initialization
- Data shuffling

**Deterministic Training**: 
- `torch.use_deterministic_algorithms(True)` where possible
- Fixed CUDA seed if using GPU

**Config Files**: All hyperparameters stored in JSON configs
- Architecture parameters
- Training parameters
- Normalization constants

**Result**: Same config → same results (within numerical precision)

---

## Computational Requirements

### Minimum Requirements
- **CPU**: Modern multi-core (training takes ~24 minutes)
- **RAM**: 4 GB (dataset + model fit in memory)
- **Storage**: ~50 MB (data + model)
- **Python**: 3.9+

### Recommended for Production
- **CPU**: Recent Intel/AMD or Apple Silicon
- **GPU**: Optional — Apple Silicon MPS gives ~6× over CPU (measured: 1.17M samples/sec on M1 Max)
- **RAM**: 8 GB (comfortable for large batches)

### Dependencies
```
torch >= 2.0
numpy >= 1.20
pandas >= 1.3
scikit-learn >= 1.0
matplotlib >= 3.5 (for diagnostics)
scipy >= 1.7 (for KDE density plots)
tqdm >= 4.60
pyfastchem >= 3.0 (optional `fastchem` extra, for data generation)
```

---

## Acknowledgments

**Scientific Community**:
- FastChem team (Stock, Kitzmann, Patzer) for the original equilibrium solver
- UCSC Exoplanet Group for computational resources and scientific guidance
- Isaac Malsky for the PyTorch implementation and key architectural insights

**Technical Infrastructure**:
- PyTorch team for deep learning framework
- Python scientific stack (NumPy, Pandas, Scikit-learn, Matplotlib)

---

## License

MIT License

Copyright (c) 2025 Yashnil Mohanty, Xi Zhang, Isaac Malsky

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
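furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.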

---

## Future Work

**Near-term** (achievable with current architecture):
1. Expand to 30 input features (include Fe, Ti, Mg, Si, Na, Ca, etc.)
2. Increase to 50 output species (still focused, but more coverage)
3. GPU optimization for ultra-fast batch inference
4. Ensemble models for uncertainty estimation

**Long-term** (research directions):
1. Condensed-phase species (clouds, hazes, rainout)
2. Non-equilibrium chemistry (kinetics, photochemistry)
3. Physics-informed neural networks (enforce mass conservation)
4. Active learning (sample T-P-composition space adaptively)
5. Integration with radiative transfer (end-to-end differentiable atmospheres)

### Suggested Next Steps to Improve the Model

| Priority | Direction | Rationale |
|----------|-----------|-----------|
| **1** | **Targeted oversampling** | Hot Jupiter low-P (10⁻⁶ bar) and C/O > 1.5 show highest error; add more training samples in these regions to reduce Log MAE from 0.27 → <0.15 dex |
| **2** | **Physics-informed loss** | Add soft mass-conservation penalty: `L_total = L_log_ratio + λ·‖Σᵢ nᵢ − Σᵢ n̂ᵢ‖` to enforce physical consistency |
| **3** | **Uncertainty quantification** | Train a small ensemble (3–5 models) or use MC dropout; output mean + std for each species to support Bayesian retrievals |
| **4** | **Knowledge distillation** | Train a smaller student model (e.g., 1M params) to match x4800 outputs; faster inference, smaller footprint for deployment |
| **5** | **Curriculum learning** | Train first on easy conditions (solar, mid T-P), then fine-tune on hard regions (low P, extreme C/O); may improve generalization |
| **6** | **Quantization (INT8)** | Post-training quantization for 2–4× smaller model and faster CPU inference; validate accuracy drop is acceptable |

---

## Summary

**FastChem ML Emulator** solves the computational bottleneck in atmospheric modeling:

- **Problem**: FastChem too slow (~1.3 ms/call measured, engine reuse) for modern applications requiring millions to billions of evaluations
- **Solution**: Neural network emulator with FlowMapAutoencoder architecture, optimized for CPU and GPU inference
- **Result**: ~1,500× faster on GPU (~250× on CPU) with excellent accuracy (Log R² ≈ 0.9999, Log MAE ≈ 0.00391 dex, MFAE ≈ 0.0105)
- **Impact**: Enables JWST/HST retrievals, 3D GCMs, and population studies that were previously infeasible
- **Scalability**: Performance improves with dataset size (tested 800K–4800K, achieving 45% Log MAE improvement with clear asymptotic plateau at ~3200K)
- **Improvement**: x4800_improved (AdamW + train-only norm) achieves ~71% lower Log MAE and ~41% lower MFAE than x4800_optimal_retrained

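**Current status**: Study complete, validated, and production-ready for conditions inside the training range (T > 750K). Errors are highest at low pressure (~10⁻⁶ bar) and C/O > 1.5, as noted under Suggested Next Steps.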

**Best model**: **x4800_improved** (4.8M samples) — metrics in [Performance Metrics](#performance-metrics); artifacts in `results/runs/runs_autoencoder_x4800_improved/` and `plots/comparison_metrics.csv`

**Get started**: 
```bash
# Train best model (x4800_improved)
cd src && python train_autoencoder_improved.py --config ../configs/x4800_improved.json

# Or use pre-trained model
python -c "import sys; sys.path.insert(0, 'results/runs/runs_autoencoder_x4800_improved'); from best_model import load_model; model = load_model()"
```
