Metadata-Version: 2.3
Name: jaxace
Version: 0.3.0
Summary: JAX implementation of AbstractCosmologicalEmulators.jl interface
Author: marcobonici
Author-email: bonici.marco@gmail.com
Requires-Python: >=3.10,<4.0
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Requires-Dist: diffrax (>=0.7.0,<0.8.0)
Requires-Dist: flax (>=0.10.0,<0.11.0)
Requires-Dist: interpax (>=0.3.9,<0.4.0)
Requires-Dist: jax (>=0.4.30,<0.5.0)
Requires-Dist: jaxlib (>=0.4.30,<0.5.0)
Requires-Dist: numpy (>=2.2.6,<3.0.0)
Requires-Dist: quadax (>=0.2.9,<0.3.0)
Description-Content-Type: text/markdown

# jaxace

[![Tests](https://github.com/CosmologicalEmulators/jaxace/actions/workflows/tests.yml/badge.svg)](https://github.com/CosmologicalEmulators/jaxace/actions/workflows/tests.yml)
[![Documentation (stable)](https://img.shields.io/badge/docs-stable-blue)](https://cosmologicalemulators.github.io/jaxace/stable/)
[![Documentation (dev)](https://img.shields.io/badge/docs-dev-blue)](https://cosmologicalemulators.github.io/jaxace/dev/)
[![codecov](https://codecov.io/gh/CosmologicalEmulators/jaxace/graph/badge.svg?token=8DGPCJR8KX)](https://codecov.io/gh/CosmologicalEmulators/jaxace)

JAX/Flax implementation of the AbstractCosmologicalEmulators.jl interface: neural-network cosmological emulators and background-cosmology functions, with automatic JIT compilation.

## Installation

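From a local clone of the repository (editable install):

```bash
git clone https://github.com/CosmologicalEmulators/jaxace.git
cd jaxace
pip install -e .
```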

## Usage

```python
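import jaxace
import jax.numpy as jnp

# Define a w0waCDM cosmology (w0=-1, wa=0 reduces to ΛCDM)
cosmo = jaxace.w0waCDMCosmology(
    ln10As=3.044, ns=0.9649, h=0.6736,
    omega_b=0.02237, omega_c=0.1200,
    m_nu=0.06, w0=-1.0, wa=0.0
)

# Background functions evaluated on an array of redshifts
z = jnp.array([0.0, 0.5, 1.0])
growth = jaxace.D_z_from_cosmo(z, cosmo)    # linear growth factor D(z)
distance = jaxace.r_z_from_cosmo(z, cosmo)  # comoving distance r(z)

# Neural network emulator
# `nn_dict` (architecture description) and `weights` (trained parameters) come
# from a trained emulator and are not defined here; `input_data` holds the
# emulator's input parameters, either a single vector or a batch of vectors.
emulator = jaxace.init_emulator(nn_dict, weights, jaxace.FlaxEmulator)
output = emulator(input_data)  # JIT-compiled; single inputs and batches are auto-detected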
```
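The emulator call above is described as JIT-compiled and able to tell single inputs from batches. The following is a minimal sketch of how that kind of rank-based dispatch can be written in plain JAX; it is not jaxace's implementation. The toy tanh MLP, the list-of-`(weight, bias)` parameter layout, and the names `mlp` and `run_emulator` are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def mlp(params, x):
    """Hypothetical toy tanh MLP evaluated on one input vector of shape (n_in,)."""
    for w, b in params[:-1]:
        x = jnp.tanh(w @ x + b)
    w, b = params[-1]
    return w @ x + b  # linear output layer


# One compiled path for single inputs, one vmapped over axis 0 of x for batches.
_mlp_single = jax.jit(mlp)
_mlp_batch = jax.jit(jax.vmap(mlp, in_axes=(None, 0)))


def run_emulator(params, x):
    """Hypothetical wrapper: (n_in,) -> single path, (batch, n_in) -> batched path."""
    x = jnp.asarray(x)
    if x.ndim == 1:
        return _mlp_single(params, x)
    if x.ndim == 2:
        return _mlp_batch(params, x)
    raise ValueError(f"expected a 1D or 2D input, got shape {x.shape}")


# Usage: random weights for an 8 -> 16 -> 4 network.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = [
    (jax.random.normal(k1, (16, 8)), jnp.zeros(16)),
    (jax.random.normal(k2, (4, 16)), jnp.zeros(4)),
]
single = run_emulator(params, jnp.ones(8))       # shape (4,)
batch = run_emulator(params, jnp.ones((32, 8)))  # shape (32, 4)
```

Both paths are compiled on first use and then reused; `jax.vmap` maps the single-input function over the leading batch axis instead of looping in Python. A new batch size triggers one recompilation of the batched path.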

## Features

- Background cosmology: linear growth factor, distances, and Hubble parameter
- Neural-network emulators with automatic JIT compilation
- Support for massive neutrinos and dynamical (w0wa) dark energy
- Compatible with JAX transformations (`jax.grad`, `jax.vmap`, `jax.jit`)
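To illustrate the JAX integration listed above, the sketch below writes a simplified background quantity in plain JAX and applies `jax.jit`, `jax.vmap`, and `jax.grad` to it. It is a stand-in, not jaxace's code: it assumes a spatially flat universe containing only matter and CPL (w0, wa) dark energy, neglecting radiation and the massive neutrinos that jaxace supports. The names `E_of_z`, `comoving_distance`, `C_OVER_100`, and `N_GRID`, and the fixed-grid trapezoidal rule, are illustrative choices.

```python
import jax
import jax.numpy as jnp

C_OVER_100 = 2997.92458  # c / (100 km/s/Mpc): the Hubble distance c/H0 in Mpc/h
N_GRID = 1024            # number of redshift nodes for the trapezoidal rule


def E_of_z(z, omega_m, w0, wa):
    """H(z)/H0 for flat w0waCDM with matter + CPL dark energy only.

    E^2 = Om (1+z)^3 + (1-Om) a^(-3(1+w0+wa)) exp(-3 wa (1-a)),  a = 1/(1+z).
    Radiation and massive neutrinos are neglected in this sketch.
    """
    a = 1.0 / (1.0 + z)
    de = a ** (-3.0 * (1.0 + w0 + wa)) * jnp.exp(-3.0 * wa * (1.0 - a))
    return jnp.sqrt(omega_m * (1.0 + z) ** 3 + (1.0 - omega_m) * de)


def comoving_distance(z, omega_m, w0, wa):
    """r(z) = (c/H0) * integral_0^z dz'/E(z'), in Mpc/h, via the trapezoidal rule."""
    zz = jnp.linspace(0.0, z, N_GRID)
    f = 1.0 / E_of_z(zz, omega_m, w0, wa)
    return C_OVER_100 * jnp.sum(0.5 * (f[1:] + f[:-1]) * jnp.diff(zz))


# jit: compile the scalar function once.
r_jit = jax.jit(comoving_distance)
# vmap: evaluate on an array of redshifts with the cosmology held fixed.
r_vec = jax.jit(jax.vmap(comoving_distance, in_axes=(0, None, None, None)))
# grad: derivative of r(z) with respect to w0 (argument index 2).
dr_dw0 = jax.jit(jax.grad(comoving_distance, argnums=2))

zs = jnp.array([0.0, 0.5, 1.0])
r_lcdm = r_vec(zs, 0.3, -1.0, 0.0)   # Omega_m = 0.3, w0 = -1, wa = 0 (LCDM)
slope = dr_dw0(1.0, 0.3, -1.0, 0.0)  # d r(z=1) / d w0
```

The fixed-size grid keeps every array shape static, which is what lets `jax.jit` and `jax.vmap` apply without changes. Raising w0 above -1 makes the dark-energy density grow into the past, increasing E(z), so `slope` comes out negative.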

