JAXSR: Interpretable Scientific Models from Data

Open-Source Symbolic Regression via Sparse Optimization with JAX

John R. Kitchin
Department of Chemical Engineering, Carnegie Mellon University

The Problem: Black-Box Models in Science

Modern ML models are powerful but opaque.

  • Neural networks fit data well, but scientists cannot inspect, interpret, or trust them
  • No closed-form expression means no mechanistic insight
  • Extrapolation outside training data is unreliable
  • Uncertainty quantification is ad hoc or absent
  • Physical constraints (non-negativity, monotonicity, thermodynamic consistency) cannot be enforced

Scientists need models they can read, trust, and reason about.

Why Not Existing Symbolic Regression?

Current tools leave critical gaps:

Need                          Genetic Programming (PySR)  Commercial (ALAMO)  JAXSR
Deterministic & reproducible  No                          Yes                 Yes
No commercial solver needed   Yes                         No                  Yes
Uncertainty quantification    No                          No                  5 methods
Physical constraints          No                          Yes                 8 types
Active learning               No                          Yes                 15 acq. functions
GPU acceleration              Yes                         No                  Yes
Scikit-learn compatible       No                          No                  Yes

JAXSR is the only open-source tool that combines all of these.

How JAXSR Works

Three-step approach: Build, Select, Validate

  1. Build a library of candidate basis functions

    • 15 families: polynomials, interactions, transcendentals, ratios, compositions, parametric...
    • Fluent API: .add_linear().add_polynomials(3).add_transcendental()
  2. Select the best sparse model via information criteria (AIC/BIC/AICc)

    • No hyperparameters to tune -- automatic complexity-accuracy tradeoff
    • Returns full Pareto front of models from simple to complex
  3. Validate with uncertainty quantification and constraint enforcement

    • OLS intervals, Bayesian model averaging, conformal prediction, bootstrap
    • Enforce physics: bounds, monotonicity, convexity, sign constraints
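The Build and Select steps can be illustrated with a small NumPy sketch. This is not JAXSR's implementation: the function names `bic` and `greedy_forward`, the dictionary-based library, and the synthetic data are all assumptions made for illustration. It shows the general idea of adding one basis function at a time and stopping when the information criterion no longer improves.

```python
import numpy as np

def bic(y, y_hat, k):
    """BIC for a least-squares fit with k terms (Gaussian errors assumed)."""
    n = len(y)
    rss = max(float(np.sum((y - y_hat) ** 2)), np.finfo(float).tiny)
    return n * np.log(rss / n) + k * np.log(n)

def greedy_forward(library, y):
    """Add the basis function that lowers BIC most; stop when none does."""
    selected, coefs, best = [], np.array([]), np.inf
    while len(selected) < len(library):
        trials = []
        for name in library:
            if name in selected:
                continue
            cols = selected + [name]
            A = np.column_stack([library[c] for c in cols])
            c, *_ = np.linalg.lstsq(A, y, rcond=None)
            trials.append((bic(y, A @ c, len(cols)), cols, c))
        score, cols, c = min(trials, key=lambda t: t[0])
        if score >= best:
            break  # no candidate improves BIC: keep the current model
        selected, coefs, best = cols, c, score
    return dict(zip(selected, coefs))

# Synthetic data (hypothetical): y = 3*x1 - 2*x1*x2 + small noise
rng = np.random.default_rng(0)
x1, x2 = rng.uniform(-1, 1, size=(2, 200))
y = 3.0 * x1 - 2.0 * x1 * x2 + rng.normal(0, 0.05, 200)

# Step 1 (Build): a small hand-written library of candidate terms
library = {"x1": x1, "x2": x2, "x1*x2": x1 * x2, "x1^2": x1**2, "x2^2": x2**2}

# Step 2 (Select): sparse model chosen by BIC
model = greedy_forward(library, y)
print(model)
```

The complexity penalty k ln(n) is what removes the need for a tunable sparsity hyperparameter: a term is kept only if it reduces the residual by more than the penalty costs.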

Example: Discovering Kinetic Rate Laws

Problem: Given reaction rate data, discover the rate expression.

True model: (power-law kinetics)

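# X: array of shape (n_samples, 2); y: measured reaction rates, shape (n_samples,)
from jaxsr import SymbolicRegressor, BasisLibrary

# Candidate terms: power-law basis functions in the two features (max_degree=2)
basis = BasisLibrary(n_features=2).add_power_laws(max_degree=2)
model = SymbolicRegressor(basis, strategy="greedy_forward")
model.fit(X, y)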

Discovered: | R^2 = 0.9999

Also recovers Arrhenius parameters to 0.13% error and Langmuir isotherms to 0.8% error -- from data alone.
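For intuition on why power-law kinetics are recoverable from data, here is a classical log-linearization sketch in NumPy. It is not how JAXSR fits the model, and the rate law form r = k * C_A^a * C_B^b along with the values of k, a, and b are hypothetical, chosen only for this illustration. Taking logs turns the exponents into ordinary linear regression coefficients.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical ground truth: r = k * C_A**a * C_B**b
k_true, a_true, b_true = 0.8, 1.0, 0.5

# Synthetic concentrations and rates with 1% multiplicative noise
C = rng.uniform(0.1, 2.0, size=(50, 2))
rate = k_true * C[:, 0] ** a_true * C[:, 1] ** b_true
rate = rate * np.exp(rng.normal(0.0, 0.01, size=50))

# log r = log k + a log C_A + b log C_B  ->  linear least squares
A = np.column_stack([np.ones(len(C)), np.log(C)])
coef, *_ = np.linalg.lstsq(A, np.log(rate), rcond=None)
k_fit, a_fit, b_fit = np.exp(coef[0]), coef[1], coef[2]

print(f"k = {k_fit:.3f}, a = {a_fit:.3f}, b = {b_fit:.3f}")
```

The log transform assumes positive rates and multiplicative noise; with additive noise or rates near zero, a direct nonlinear fit is usually the safer choice.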

Uncertainty Quantification: Five Methods, One API

Most SR tools give a point prediction. JAXSR provides calibrated uncertainty.

Method                    Assumption          Guarantee                       Use case
OLS intervals             Gaussian residuals  Exact if true                   Standard regression
Bootstrap                 None (resampling)   Approximate                     Non-Gaussian data
Conformal prediction      Exchangeability     Finite-sample                   Safety-critical
Bayesian Model Averaging  Model prior         Accounts for model uncertainty  Multiple plausible models
Ensemble disagreement     Pareto front        Structural uncertainty          Extrapolation detection

All accessible through a unified API: model.predict_interval(), model.predict_conformal(), model.predict_bma()
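To make the finite-sample guarantee of conformal prediction concrete, here is a minimal split-conformal sketch in NumPy. It does not use or reproduce model.predict_conformal(); the helper `conformal_halfwidth`, the straight-line model, and the heavy-tailed synthetic data are assumptions for illustration. The interval half-width is an order statistic of held-out absolute residuals, so no Gaussian assumption is needed.

```python
import math
import numpy as np

def conformal_halfwidth(abs_residuals, alpha=0.1):
    """Finite-sample-corrected (1 - alpha) quantile of calibration residuals."""
    n = len(abs_residuals)
    k = math.ceil((n + 1) * (1 - alpha))
    if k > n:
        return math.inf  # too few calibration points for this alpha
    return float(np.sort(abs_residuals)[k - 1])

rng = np.random.default_rng(1)

def make_data(n):
    # Hypothetical process: linear trend with heavy-tailed (Student-t) noise
    x = rng.uniform(0, 5, n)
    y = 1.0 + 2.0 * x + rng.standard_t(df=3, size=n)
    return x, y

x_fit, y_fit = make_data(200)    # used to fit the model
x_cal, y_cal = make_data(500)    # held out for calibration
x_new, y_new = make_data(2000)   # fresh data to check coverage

slope, intercept = np.polyfit(x_fit, y_fit, deg=1)

def predict(x):
    return intercept + slope * x

q = conformal_halfwidth(np.abs(y_cal - predict(x_cal)), alpha=0.1)
coverage = np.mean(np.abs(y_new - predict(x_new)) <= q)
print(f"interval: prediction +/- {q:.2f}, empirical coverage: {coverage:.3f}")
```

The target coverage here is 90%; the empirical value fluctuates around that level because the calibration set is finite.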

Enforcing Physical Laws

Unconstrained models can violate known physics. JAXSR enforces 8 constraint types.

Example: Gibbs-Duhem Thermodynamic Consistency

Constraint Level     Max Residual   Approach
Unconstrained        0.18           No enforcement
Refit only           0.16           Constrained refit
Iterative selection  0.007          Constraint-aware model search
Hard (null-space)    2.78 x 10^-11  Exact at machine precision

Physics-informed models extrapolate better and inspire greater confidence.
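The idea behind hard null-space enforcement can be sketched for a generic linear equality constraint C c = d on the model coefficients. This is an illustrative NumPy sketch, not JAXSR's code and not the Gibbs-Duhem case above; the function `constrained_lstsq` and the toy constraint f(1) = 2 are assumptions. Every feasible coefficient vector is written as a particular solution plus a combination of null-space directions of C, so the constraint holds to round-off regardless of the data.

```python
import numpy as np

def constrained_lstsq(A, y, C, d):
    """Minimize ||A c - y|| subject to C c = d via the null-space method."""
    # Particular (minimum-norm) solution of the constraint
    c_p, *_ = np.linalg.lstsq(C, d, rcond=None)
    # Orthonormal basis Z for the null space of C from the SVD
    _, s, Vt = np.linalg.svd(C)
    rank = int(np.sum(s > 1e-12 * s.max()))
    Z = Vt[rank:].T
    # Unconstrained least squares in the reduced coordinates z
    z, *_ = np.linalg.lstsq(A @ Z, y - A @ c_p, rcond=None)
    return c_p + Z @ z

rng = np.random.default_rng(2)
x = np.linspace(0, 2, 30)
A = np.column_stack([np.ones_like(x), x, x**2])   # model: c0 + c1 x + c2 x^2
y = 0.5 + 1.0 * x + 0.5 * x**2 + rng.normal(0, 0.1, x.size)

# Hypothetical physics: the model must pass exactly through f(1) = 2
C = np.array([[1.0, 1.0, 1.0]])
d = np.array([2.0])

c_free, *_ = np.linalg.lstsq(A, y, rcond=None)
c_hard = constrained_lstsq(A, y, C, d)

print("unconstrained violation:", abs(C @ c_free - d).max())
print("null-space violation:   ", abs(C @ c_hard - d).max())
```

Because the constraint is built into the parameterization instead of being penalized, there is no penalty weight to tune.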

Active Learning: Smarter Experiments

Experiments are expensive. JAXSR includes 15 acquisition functions for sequential experimental design -- no Gaussian process surrogate needed.

  • Exploration: Where is model uncertainty highest?
  • Exploitation: Where does the model predict optimal performance?
  • Design-based: A-optimal, D-optimal for parameter estimation
  • Model discrimination: Which experiment best distinguishes competing models?

Langmuir isotherm case: 8 initial points + 12 actively selected points recovered the two isotherm parameters to 0.8% and 2.6% error, with calibrated confidence intervals.

All computed in closed form from OLS posterior -- fast enough for real-time lab use.
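A closed-form acquisition function of this kind can be sketched in a few lines of NumPy. The sketch is an assumption-laden illustration, not one of JAXSR's 15 functions: `features`, `next_experiment`, and the quadratic model are hypothetical. For a linear-in-parameters model, the prediction variance at a candidate x is proportional to phi(x)^T (Phi^T Phi)^-1 phi(x), and by the matrix determinant lemma the candidate that maximizes it is also the one-step D-optimal choice.

```python
import numpy as np

def features(x):
    """Basis functions of an assumed quadratic model in one variable."""
    return np.column_stack([np.ones_like(x), x, x**2])

def next_experiment(x_done, x_candidates):
    """Pick the candidate with the largest OLS prediction variance."""
    Phi = features(x_done)
    M_inv = np.linalg.inv(Phi.T @ Phi)
    Phi_c = features(x_candidates)
    # score_i = phi_i^T M_inv phi_i, evaluated for every candidate at once
    score = np.einsum("ij,jk,ik->i", Phi_c, M_inv, Phi_c)
    return x_candidates[np.argmax(score)], score

x_done = np.array([0.0, 0.1, 0.2, 0.3])       # experiments already run
x_candidates = np.linspace(0.0, 1.0, 11)      # experiments we could run next
x_next, score = next_experiment(x_done, x_candidates)
print("next experiment at x =", x_next)
```

No surrogate model is trained here; the score needs only one small matrix inverse, which is why acquisition of this type is cheap enough to run between experiments.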

Beyond Regression

JAXSR extends to broader scientific modeling tasks:

Symbolic Classification

  • Discover interpretable decision boundaries as sparse logistic expressions
  • One-vs-rest multiclass with conformal prediction sets

ODE Discovery (SINDy-style)

  • Recover governing equations from time-series data
  • Example: Lotka-Volterra predator-prey dynamics from trajectory data
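A SINDy-style recovery of Lotka-Volterra dynamics can be sketched with NumPy alone. This is a generic illustration, not JAXSR's ODE interface: the parameter values, the RK4 integrator, the six-term polynomial library, and the threshold of 0.05 are all assumptions. Derivatives are estimated by central differences, then sequentially thresholded least squares removes small coefficients and refits the rest.

```python
import numpy as np

def lotka_volterra(s, a=1.0, b=0.5, c=1.0, d=0.25):
    """Hypothetical ground truth: x' = a x - b x y, y' = -c y + d x y."""
    x, y = s
    return np.array([a * x - b * x * y, -c * y + d * x * y])

def rk4(f, s0, dt, n_steps):
    """Fixed-step fourth-order Runge-Kutta integration."""
    traj = np.empty((n_steps + 1, len(s0)))
    traj[0] = s0
    for i in range(n_steps):
        s = traj[i]
        k1 = f(s)
        k2 = f(s + 0.5 * dt * k1)
        k3 = f(s + 0.5 * dt * k2)
        k4 = f(s + dt * k3)
        traj[i + 1] = s + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
    return traj

def library(S):
    """Candidate terms: 1, x, y, x*y, x^2, y^2."""
    x, y = S[:, 0], S[:, 1]
    return np.column_stack([np.ones_like(x), x, y, x * y, x**2, y**2])

def stlsq(Theta, dS, threshold=0.05, n_iter=10):
    """Sequentially thresholded least squares, one column per state variable."""
    Xi, *_ = np.linalg.lstsq(Theta, dS, rcond=None)
    for _ in range(n_iter):
        small = np.abs(Xi) < threshold
        Xi[small] = 0.0
        for j in range(dS.shape[1]):
            keep = ~small[:, j]
            if keep.any():
                Xi[keep, j] = np.linalg.lstsq(Theta[:, keep], dS[:, j], rcond=None)[0]
    return Xi

dt = 0.002
traj = rk4(lotka_volterra, np.array([4.0, 1.0]), dt, 10_000)

# Central-difference derivative estimates from the trajectory alone
S = traj[1:-1]
dS = (traj[2:] - traj[:-2]) / (2.0 * dt)

Xi = stlsq(library(S), dS)
terms = ["1", "x", "y", "x*y", "x^2", "y^2"]
for j, name in enumerate(["dx/dt", "dy/dt"]):
    rhs = " + ".join(f"{Xi[i, j]:.3f} {terms[i]}" for i in range(len(terms)) if Xi[i, j] != 0)
    print(name, "=", rhs)
```

The trajectory here is noise-free; with measured data the derivative estimate is usually the weakest link, and smoothing or a larger threshold may be needed.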

Response Surface Methodology

  • Full DOE integration: factorial, CCD, Box-Behnken designs
  • ANOVA decomposition, canonical analysis, optimization
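The response surface workflow can be sketched end to end in NumPy: build a central composite design, fit a full quadratic model, and locate the stationary point by canonical analysis. The helpers `central_composite` and `quadratic_features` and the noise-free response function are hypothetical, written for this illustration only; they are not JAXSR's DOE API.

```python
import numpy as np
from itertools import product

def central_composite(k=2, n_center=3):
    """Rotatable CCD in coded units: factorial, axial, and center points."""
    alpha = (2 ** k) ** 0.25
    factorial = np.array(list(product([-1.0, 1.0], repeat=k)))
    axial = np.vstack([alpha * np.eye(k), -alpha * np.eye(k)])
    center = np.zeros((n_center, k))
    return np.vstack([factorial, axial, center])

def quadratic_features(X):
    """Full second-order model in two factors."""
    x1, x2 = X[:, 0], X[:, 1]
    return np.column_stack([np.ones(len(X)), x1, x2, x1 * x2, x1**2, x2**2])

def response(X):
    # Hypothetical noise-free process with its maximum at (0.3, -0.2)
    return 10.0 - (X[:, 0] - 0.3) ** 2 - 2.0 * (X[:, 1] + 0.2) ** 2

X = central_composite(k=2)            # 4 factorial + 4 axial + 3 center runs
y = response(X)
beta, *_ = np.linalg.lstsq(quadratic_features(X), y, rcond=None)

# Canonical analysis: y = b0 + b^T x + x^T B x
b = beta[1:3]
B = np.array([[beta[4], beta[3] / 2.0],
              [beta[3] / 2.0, beta[5]]])
x_s = -0.5 * np.linalg.solve(B, b)    # stationary point
eigvals = np.linalg.eigvalsh(B)       # all negative -> maximum

print("stationary point:", x_s)
print("eigenvalues of B:", eigvals)
```

The signs of the eigenvalues classify the stationary point: all negative indicates a maximum, all positive a minimum, and mixed signs a saddle.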

Integrated Scientific Workflows

JAXSR is not just a library -- it is a workflow platform.

  • DOEStudy object: Design -> Experiment -> Fit -> Analyze -> Iterate, all serialized to JSON
  • CLI: jaxsr fit data.csv, jaxsr predict, jaxsr compare
  • Streamlit app: Interactive GUI for the full DOE cycle
  • Export: LaTeX, SymPy, NumPy callables, Excel/Word reports, JSON serialization
  • Scikit-learn integration: Works with cross_val_score, GridSearchCV, Pipeline

Designed for scientists who want to focus on science, not software.

Vision: Interpretable AI for Scientific Discovery

Where we are: A mature, open-source tool with 19,000 lines of code, 294 tests, and comprehensive documentation.

Where we want to go:

  • Scalability: Extend to high-dimensional problems (100+ features) with advanced screening
  • Multi-fidelity: Combine cheap computational data with expensive experimental data
  • Autonomous labs: Tight integration with robotic experiment platforms for closed-loop discovery
  • Community: Build an ecosystem of domain-specific basis libraries (catalysis, polymers, electrochemistry)
  • Education: Classroom-ready tools for teaching data-driven modeling with physical insight

Why Fund This Work?

Scientific impact:
Interpretable models accelerate discovery by revealing mechanisms, not just correlations.

Broader impacts:
Open-source, no commercial dependencies -- accessible to any research group worldwide.

Unique position:
No other open-source tool combines deterministic SR + UQ + constraints + active learning.

Proven foundation:
Working software with extensive test suite, documentation, and example applications in chemical kinetics, thermodynamics, heat transfer, and experimental design.

JAXSR makes the transition from "black-box prediction" to "interpretable scientific understanding" practical, rigorous, and accessible.

Thank You

GitHub: github.com/jkitchin/jaxsr

Questions?