Metadata-Version: 2.4
Name: datarax
Version: 0.1.1
Summary: Datarax: A high-performance, NNX-based data pipeline framework for JAX
Project-URL: Bug Tracker, https://github.com/avitai/datarax/issues
Project-URL: Documentation, https://datarax.readthedocs.io
Project-URL: Source, https://github.com/avitai/datarax
Author: Mahdi Shafiei
License: MIT License
        
        Copyright (c) 2025 Mahdi Shafiei
        
        Permission is hereby granted, free of charge, to any person obtaining a copy
        of this software and associated documentation files (the "Software"), to deal
        in the Software without restriction, including without limitation the rights
        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
        copies of the Software, and to permit persons to whom the Software is
        furnished to do so, subject to the following conditions:
        
        The above copyright notice and this permission notice shall be included in all
        copies or substantial portions of the Software.
        
        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
        SOFTWARE.
License-File: LICENSE
Keywords: data-loading,flax,jax,machine-learning,nnx,workshop
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Education
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: MacOS
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Operating System :: POSIX :: Linux
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Topic :: Software Development
Classifier: Topic :: Software Development :: Libraries
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.11
Requires-Dist: absl-py>=1
Requires-Dist: antlr4-python3-runtime==4.9.3
Requires-Dist: array-record>=0.4
Requires-Dist: astunparse>=1.6
Requires-Dist: beartype>=0.14.1
Requires-Dist: chex>=0.1.7
Requires-Dist: cloudpickle>=2
Requires-Dist: decorator>=5
Requires-Dist: dm-tree>=0.1.8
Requires-Dist: flax>=0.12.0
Requires-Dist: grain>=0.1
Requires-Dist: jax>=0.6.1
Requires-Dist: jaxtyping>=0.2.20
Requires-Dist: matplotlib
Requires-Dist: optax>=0.1.4
Requires-Dist: orbax-checkpoint>=0.11.10
Requires-Dist: pre-commit>=4.3.0
Requires-Dist: protobuf>=5.28.0
Requires-Dist: psutil>=5.9.0
Requires-Dist: pyarrow>=12
Requires-Dist: ruff>=0.1.5
Requires-Dist: scipy>=1.10
Requires-Dist: tensorflow-datasets>=4.9.2
Requires-Dist: tensorflow-metadata>=1.15.0
Requires-Dist: tensorflow>=2.20.0
Requires-Dist: tomli-w>=1
Requires-Dist: tomli>=2; python_version < '3.11'
Provides-Extra: all
Requires-Dist: beartype>=0.14.1; extra == 'all'
Requires-Dist: build>=1.0.3; extra == 'all'
Requires-Dist: coverage>=7; extra == 'all'
Requires-Dist: datasets>=2.14.0; extra == 'all'
Requires-Dist: google-cloud-aiplatform>=1.38.0; extra == 'all'
Requires-Dist: google-cloud-storage>=2.14.0; extra == 'all'
Requires-Dist: grain-nightly; (sys_platform == 'linux') and extra == 'all'
Requires-Dist: griffe>=1.7.3; extra == 'all'
Requires-Dist: hippogriffe>=0.2; extra == 'all'
Requires-Dist: ipykernel>=6.29.5; extra == 'all'
Requires-Dist: jax[cuda12]>=0.6.1; (sys_platform == 'linux') and extra == 'all'
Requires-Dist: jaxlib>=0.6.1; (sys_platform == 'linux') and extra == 'all'
Requires-Dist: jupytext>=1.16.0; extra == 'all'
Requires-Dist: librosa>=0.10.0; extra == 'all'
Requires-Dist: mkdocs-gen-files>=0.5.0; extra == 'all'
Requires-Dist: mkdocs-include-exclude-files>=0.1; extra == 'all'
Requires-Dist: mkdocs-ipynb>=0.1; extra == 'all'
Requires-Dist: mkdocs-jupyter>=0.24; extra == 'all'
Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'all'
Requires-Dist: mkdocs-material>=9.6.7; extra == 'all'
Requires-Dist: mkdocs>=1.6.1; extra == 'all'
Requires-Dist: mkdocstrings-python>=1.1.2; extra == 'all'
Requires-Dist: mkdocstrings>=0.28.3; extra == 'all'
Requires-Dist: pillow>=10.0.0; extra == 'all'
Requires-Dist: pydeps>=1.12.17; extra == 'all'
Requires-Dist: pymdown-extensions>=10.14.3; extra == 'all'
Requires-Dist: pyright>=1.1.336; extra == 'all'
Requires-Dist: pytest-asyncio>=0.23; extra == 'all'
Requires-Dist: pytest-benchmark>=4; extra == 'all'
Requires-Dist: pytest-cov>=6.1.1; extra == 'all'
Requires-Dist: pytest-env>=1.0.1; extra == 'all'
Requires-Dist: pytest-json-report>=1.5.0; extra == 'all'
Requires-Dist: pytest-timeout>=2.1; extra == 'all'
Requires-Dist: pytest-xdist>=3.6; extra == 'all'
Requires-Dist: pytest>=8.3.5; extra == 'all'
Requires-Dist: python-dotenv>=1; extra == 'all'
Requires-Dist: ruff>=0.1.5; extra == 'all'
Requires-Dist: shellcheck-py>=0.10.0.1; extra == 'all'
Requires-Dist: soundfile>=0.12.1; extra == 'all'
Requires-Dist: tensorflow-datasets>=4.9.2; extra == 'all'
Requires-Dist: twine>=4.0.2; extra == 'all'
Provides-Extra: all-cpu
Requires-Dist: beartype>=0.14.1; extra == 'all-cpu'
Requires-Dist: build>=1.0.3; extra == 'all-cpu'
Requires-Dist: coverage>=7; extra == 'all-cpu'
Requires-Dist: datasets>=2.14.0; extra == 'all-cpu'
Requires-Dist: google-cloud-aiplatform>=1.38.0; extra == 'all-cpu'
Requires-Dist: google-cloud-storage>=2.14.0; extra == 'all-cpu'
Requires-Dist: grain-nightly; (sys_platform == 'linux') and extra == 'all-cpu'
Requires-Dist: griffe>=1.7.3; extra == 'all-cpu'
Requires-Dist: hippogriffe>=0.2; extra == 'all-cpu'
Requires-Dist: ipykernel>=6.29.5; extra == 'all-cpu'
Requires-Dist: jupytext>=1.16.0; extra == 'all-cpu'
Requires-Dist: librosa>=0.10.0; extra == 'all-cpu'
Requires-Dist: mkdocs-gen-files>=0.5.0; extra == 'all-cpu'
Requires-Dist: mkdocs-include-exclude-files>=0.1; extra == 'all-cpu'
Requires-Dist: mkdocs-ipynb>=0.1; extra == 'all-cpu'
Requires-Dist: mkdocs-jupyter>=0.24; extra == 'all-cpu'
Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'all-cpu'
Requires-Dist: mkdocs-material>=9.6.7; extra == 'all-cpu'
Requires-Dist: mkdocs>=1.6.1; extra == 'all-cpu'
Requires-Dist: mkdocstrings-python>=1.1.2; extra == 'all-cpu'
Requires-Dist: mkdocstrings>=0.28.3; extra == 'all-cpu'
Requires-Dist: pillow>=10.0.0; extra == 'all-cpu'
Requires-Dist: pydeps>=1.12.17; extra == 'all-cpu'
Requires-Dist: pymdown-extensions>=10.14.3; extra == 'all-cpu'
Requires-Dist: pyright>=1.1.336; extra == 'all-cpu'
Requires-Dist: pytest-asyncio>=0.23; extra == 'all-cpu'
Requires-Dist: pytest-benchmark>=4; extra == 'all-cpu'
Requires-Dist: pytest-cov>=6.1.1; extra == 'all-cpu'
Requires-Dist: pytest-env>=1.0.1; extra == 'all-cpu'
Requires-Dist: pytest-json-report>=1.5.0; extra == 'all-cpu'
Requires-Dist: pytest-timeout>=2.1; extra == 'all-cpu'
Requires-Dist: pytest-xdist>=3.6; extra == 'all-cpu'
Requires-Dist: pytest>=8.3.5; extra == 'all-cpu'
Requires-Dist: python-dotenv>=1; extra == 'all-cpu'
Requires-Dist: ruff>=0.1.5; extra == 'all-cpu'
Requires-Dist: shellcheck-py>=0.10.0.1; extra == 'all-cpu'
Requires-Dist: soundfile>=0.12.1; extra == 'all-cpu'
Requires-Dist: tensorflow-datasets>=4.9.2; extra == 'all-cpu'
Requires-Dist: twine>=4.0.2; extra == 'all-cpu'
Provides-Extra: all-macos
Requires-Dist: beartype>=0.14.1; extra == 'all-macos'
Requires-Dist: build>=1.0.3; extra == 'all-macos'
Requires-Dist: coverage>=7; extra == 'all-macos'
Requires-Dist: datasets>=2.14.0; extra == 'all-macos'
Requires-Dist: google-cloud-aiplatform>=1.38.0; extra == 'all-macos'
Requires-Dist: google-cloud-storage>=2.14.0; extra == 'all-macos'
Requires-Dist: grain-nightly; (sys_platform == 'linux') and extra == 'all-macos'
Requires-Dist: griffe>=1.7.3; extra == 'all-macos'
Requires-Dist: hippogriffe>=0.2; extra == 'all-macos'
Requires-Dist: ipykernel>=6.29.5; extra == 'all-macos'
Requires-Dist: jax-metal>=0.1.0; (sys_platform == 'darwin') and extra == 'all-macos'
Requires-Dist: jupytext>=1.16.0; extra == 'all-macos'
Requires-Dist: librosa>=0.10.0; extra == 'all-macos'
Requires-Dist: mkdocs-gen-files>=0.5.0; extra == 'all-macos'
Requires-Dist: mkdocs-include-exclude-files>=0.1; extra == 'all-macos'
Requires-Dist: mkdocs-ipynb>=0.1; extra == 'all-macos'
Requires-Dist: mkdocs-jupyter>=0.24; extra == 'all-macos'
Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'all-macos'
Requires-Dist: mkdocs-material>=9.6.7; extra == 'all-macos'
Requires-Dist: mkdocs>=1.6.1; extra == 'all-macos'
Requires-Dist: mkdocstrings-python>=1.1.2; extra == 'all-macos'
Requires-Dist: mkdocstrings>=0.28.3; extra == 'all-macos'
Requires-Dist: pillow>=10.0.0; extra == 'all-macos'
Requires-Dist: pydeps>=1.12.17; extra == 'all-macos'
Requires-Dist: pymdown-extensions>=10.14.3; extra == 'all-macos'
Requires-Dist: pyright>=1.1.336; extra == 'all-macos'
Requires-Dist: pytest-asyncio>=0.23; extra == 'all-macos'
Requires-Dist: pytest-benchmark>=4; extra == 'all-macos'
Requires-Dist: pytest-cov>=6.1.1; extra == 'all-macos'
Requires-Dist: pytest-env>=1.0.1; extra == 'all-macos'
Requires-Dist: pytest-json-report>=1.5.0; extra == 'all-macos'
Requires-Dist: pytest-timeout>=2.1; extra == 'all-macos'
Requires-Dist: pytest-xdist>=3.6; extra == 'all-macos'
Requires-Dist: pytest>=8.3.5; extra == 'all-macos'
Requires-Dist: python-dotenv>=1; extra == 'all-macos'
Requires-Dist: ruff>=0.1.5; extra == 'all-macos'
Requires-Dist: shellcheck-py>=0.10.0.1; extra == 'all-macos'
Requires-Dist: soundfile>=0.12.1; extra == 'all-macos'
Requires-Dist: tensorflow-datasets>=4.9.2; extra == 'all-macos'
Requires-Dist: twine>=4.0.2; extra == 'all-macos'
Provides-Extra: benchmark
Requires-Dist: datasets>=2.14; extra == 'benchmark'
Requires-Dist: deeplake>=4.0; extra == 'benchmark'
Requires-Dist: jax-dataloader>=0.1; extra == 'benchmark'
Requires-Dist: litdata>=0.2; extra == 'benchmark'
Requires-Dist: mosaicml-streaming>=0.7; extra == 'benchmark'
Requires-Dist: nvidia-dali-cuda120>=1.0; (sys_platform == 'linux') and extra == 'benchmark'
Requires-Dist: ray[data]>=2.9; extra == 'benchmark'
Requires-Dist: rich>=13.0; extra == 'benchmark'
Requires-Dist: spdl>=0.1; extra == 'benchmark'
Requires-Dist: tabulate>=0.9; extra == 'benchmark'
Requires-Dist: tensorflow-datasets>=4.9; extra == 'benchmark'
Requires-Dist: torch>=2.0; extra == 'benchmark'
Requires-Dist: torchvision>=0.15; extra == 'benchmark'
Requires-Dist: webdataset>=0.2; extra == 'benchmark'
Provides-Extra: cuda-dev
Requires-Dist: build>=1.0.3; extra == 'cuda-dev'
Requires-Dist: coverage>=7; extra == 'cuda-dev'
Requires-Dist: google-cloud-aiplatform>=1.38.0; extra == 'cuda-dev'
Requires-Dist: google-cloud-storage>=2.14.0; extra == 'cuda-dev'
Requires-Dist: ipykernel>=6.29.5; extra == 'cuda-dev'
Requires-Dist: jax[cuda12]>=0.6.1; (sys_platform == 'linux') and extra == 'cuda-dev'
Requires-Dist: jaxlib>=0.6.1; (sys_platform == 'linux') and extra == 'cuda-dev'
Requires-Dist: pydeps>=1.12.17; extra == 'cuda-dev'
Requires-Dist: pyright>=1.1.336; extra == 'cuda-dev'
Requires-Dist: pytest-asyncio>=0.23; extra == 'cuda-dev'
Requires-Dist: pytest-benchmark>=4; extra == 'cuda-dev'
Requires-Dist: pytest-cov>=6.1.1; extra == 'cuda-dev'
Requires-Dist: pytest-env>=1.0.1; extra == 'cuda-dev'
Requires-Dist: pytest-json-report>=1.5.0; extra == 'cuda-dev'
Requires-Dist: pytest-timeout>=2.1; extra == 'cuda-dev'
Requires-Dist: pytest-xdist>=3.6; extra == 'cuda-dev'
Requires-Dist: pytest>=8.3.5; extra == 'cuda-dev'
Requires-Dist: python-dotenv>=1; extra == 'cuda-dev'
Requires-Dist: ruff>=0.1.5; extra == 'cuda-dev'
Requires-Dist: shellcheck-py>=0.10.0.1; extra == 'cuda-dev'
Requires-Dist: twine>=4.0.2; extra == 'cuda-dev'
Provides-Extra: data
Requires-Dist: datasets>=2.14.0; extra == 'data'
Requires-Dist: grain-nightly; (sys_platform == 'linux') and extra == 'data'
Requires-Dist: librosa>=0.10.0; extra == 'data'
Requires-Dist: pillow>=10.0.0; extra == 'data'
Requires-Dist: soundfile>=0.12.1; extra == 'data'
Requires-Dist: tensorflow-datasets>=4.9.2; extra == 'data'
Provides-Extra: dev
Requires-Dist: build>=1.0.3; extra == 'dev'
Requires-Dist: coverage>=7; extra == 'dev'
Requires-Dist: google-cloud-aiplatform>=1.38.0; extra == 'dev'
Requires-Dist: google-cloud-storage>=2.14.0; extra == 'dev'
Requires-Dist: ipykernel>=6.29.5; extra == 'dev'
Requires-Dist: pydeps>=1.12.17; extra == 'dev'
Requires-Dist: pyright>=1.1.336; extra == 'dev'
Requires-Dist: pytest-asyncio>=0.23; extra == 'dev'
Requires-Dist: pytest-benchmark>=4; extra == 'dev'
Requires-Dist: pytest-cov>=6.1.1; extra == 'dev'
Requires-Dist: pytest-env>=1.0.1; extra == 'dev'
Requires-Dist: pytest-json-report>=1.5.0; extra == 'dev'
Requires-Dist: pytest-timeout>=2.1; extra == 'dev'
Requires-Dist: pytest-xdist>=3.6; extra == 'dev'
Requires-Dist: pytest>=8.3.5; extra == 'dev'
Requires-Dist: python-dotenv>=1; extra == 'dev'
Requires-Dist: ruff>=0.1.5; extra == 'dev'
Requires-Dist: shellcheck-py>=0.10.0.1; extra == 'dev'
Requires-Dist: twine>=4.0.2; extra == 'dev'
Provides-Extra: docs
Requires-Dist: griffe>=1.7.3; extra == 'docs'
Requires-Dist: hippogriffe>=0.2; extra == 'docs'
Requires-Dist: jupytext>=1.16.0; extra == 'docs'
Requires-Dist: mkdocs-gen-files>=0.5.0; extra == 'docs'
Requires-Dist: mkdocs-include-exclude-files>=0.1; extra == 'docs'
Requires-Dist: mkdocs-ipynb>=0.1; extra == 'docs'
Requires-Dist: mkdocs-jupyter>=0.24; extra == 'docs'
Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
Requires-Dist: mkdocs-material>=9.6.7; extra == 'docs'
Requires-Dist: mkdocs>=1.6.1; extra == 'docs'
Requires-Dist: mkdocstrings-python>=1.1.2; extra == 'docs'
Requires-Dist: mkdocstrings>=0.28.3; extra == 'docs'
Requires-Dist: pymdown-extensions>=10.14.3; extra == 'docs'
Provides-Extra: examples
Requires-Dist: apache-beam<2.65,>=2.50; extra == 'examples'
Requires-Dist: h5py>=3.0; extra == 'examples'
Requires-Dist: keras>=3.0; extra == 'examples'
Requires-Dist: torchcrepe; extra == 'examples'
Provides-Extra: gpu
Requires-Dist: jax[cuda12]>=0.6.1; (sys_platform == 'linux') and extra == 'gpu'
Requires-Dist: jaxlib>=0.6.1; (sys_platform == 'linux') and extra == 'gpu'
Provides-Extra: metal
Requires-Dist: jax-metal>=0.1.0; (sys_platform == 'darwin') and extra == 'metal'
Provides-Extra: test
Requires-Dist: beartype>=0.14.1; extra == 'test'
Requires-Dist: coverage>=7; extra == 'test'
Requires-Dist: pytest-asyncio>=0.23; extra == 'test'
Requires-Dist: pytest-benchmark>=4; extra == 'test'
Requires-Dist: pytest-cov>=6.1.1; extra == 'test'
Requires-Dist: pytest-env>=1.0.1; extra == 'test'
Requires-Dist: pytest-timeout>=2.1; extra == 'test'
Requires-Dist: pytest-xdist>=3.6; extra == 'test'
Requires-Dist: pytest>=8.3.5; extra == 'test'
Description-Content-Type: text/markdown

# Datarax: A Data Pipeline Framework for JAX

[![CI](https://github.com/avitai/datarax/actions/workflows/ci.yml/badge.svg)](https://github.com/avitai/datarax/actions/workflows/ci.yml)
[![Test Coverage](https://github.com/avitai/datarax/actions/workflows/test-coverage.yml/badge.svg)](https://github.com/avitai/datarax/actions/workflows/test-coverage.yml)
[![codecov](https://codecov.io/gh/avitai/datarax/branch/main/graph/badge.svg)](https://codecov.io/gh/avitai/datarax)
[![Build](https://github.com/avitai/datarax/actions/workflows/build-verification.yml/badge.svg)](https://github.com/avitai/datarax/actions/workflows/build-verification.yml)
[![Summary](https://github.com/avitai/datarax/actions/workflows/summary.yml/badge.svg)](https://github.com/avitai/datarax/actions/workflows/summary.yml)

[![Project Status: Active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)

---

> **Early Development - API Unstable**
>
> Datarax is in early development and undergoing rapid iteration.
> Breaking changes are expected. Pin to specific commits if stability is required.
> We recommend waiting for a stable release (v1.0) before using Datarax in production.

---

**Datarax** (*Data + Array/JAX*) is a high-performance, extensible data pipeline framework specifically engineered for JAX-based machine learning workflows. It leverages JAX's JIT compilation, automatic differentiation, and hardware acceleration to build efficient, scalable data loading, preprocessing, and augmentation pipelines on CPUs, GPUs, and TPUs.

## Key Features

- **JAX-Native Design:** All core components are built on JAX's functional paradigm, with the Flax NNX module system for state management
- **High Performance:** JIT-compiled pipelines via XLA, with built-in profiling and roofline analysis
- **DAG Execution Engine:** Graph-based pipeline construction with branching, parallel execution, caching, and rebatching nodes
- **Scalability:** Multi-device and multi-host data distribution with device mesh sharding
- **Determinism:** Reproducible pipelines by default using Grain's Feistel cipher shuffling (O(1) memory)
- **Extensibility:** Custom data sources, operators, and augmentation strategies via composable NNX modules
- **Benchmarking Suite:** Comparative benchmarks against 12+ frameworks (Grain, tf.data, PyTorch DataLoader, DALI, Ray Data, and more)
- **Ecosystem Integration:** Works with Flax, Optax, Orbax, HuggingFace Datasets, and TensorFlow Datasets
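
The O(1)-memory shuffling mentioned in the list above rests on the fact that a Feistel network is a keyed bijection, so a shuffled position can be computed per index instead of materializing a permuted index array. The following is a minimal pure-Python sketch of that idea, not Grain's or Datarax's implementation; `feistel_index`, `_round_fn`, and the BLAKE2-based round function are illustrative assumptions.

```python
import hashlib


def _round_fn(value: int, round_key: int, half_bits: int) -> int:
    """Keyed pseudo-random function used inside each Feistel round (illustrative)."""
    digest = hashlib.blake2b(f"{round_key}:{value}".encode(), digest_size=8).digest()
    return int.from_bytes(digest, "big") & ((1 << half_bits) - 1)


def feistel_index(index: int, n: int, seed: int, rounds: int = 4) -> int:
    """Map `index` to its shuffled position in [0, n) without storing a permutation."""
    if not 0 <= index < n:
        raise IndexError(index)
    # Split the smallest even-width bit domain covering [0, n) into two halves.
    half_bits = max(1, ((n - 1).bit_length() + 1) // 2)
    mask = (1 << half_bits) - 1
    x = index
    while True:
        left, right = x >> half_bits, x & mask
        for r in range(rounds):
            # One Feistel round: swap halves and mix with the keyed round function.
            left, right = right, left ^ _round_fn(right, seed * rounds + r, half_bits)
        x = (left << half_bits) | right
        # Cycle-walk: re-apply the permutation until the result lands inside [0, n).
        if x < n:
            return x


order = [feistel_index(i, 10, seed=0) for i in range(10)]
assert sorted(order) == list(range(10))  # every index appears exactly once
```

Because the mapping depends only on `(index, n, seed)`, any worker can recompute the same order without shared state, which is what makes this style of shuffle reproducible.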

## Why Datarax?

Datarax's differentiable pipeline architecture enables optimization paradigms that are impossible with traditional data loaders. Here are three real-world examples:

### Learned Augmentation Policy (10,000x Faster Search)
Traditional augmentation search (AutoAugment) requires 15,000 GPU-hours of RL. With Datarax's differentiable operators, [DADA-style gradient-based search](examples/advanced/differentiable/01_dada_learned_augmentation_guide.py) achieves the same accuracy in **~0.1 GPU-hours**, because gradients flow through the augmentation pipeline.
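
To make "gradients flow through the augmentation pipeline" concrete, here is a minimal plain-JAX sketch in which an augmentation magnitude is learned by gradient descent. It is not DADA itself (which relaxes discrete policy choices) and uses no Datarax API; `brighten`, `loss_fn`, and the toy data are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def brighten(image: jax.Array, magnitude: jax.Array) -> jax.Array:
    """Differentiable brightness scaling; `magnitude` is the learnable knob."""
    return jnp.clip(image * (1.0 + magnitude), 0.0, 1.0)


def loss_fn(magnitude: jax.Array, image: jax.Array, target: jax.Array) -> jax.Array:
    # Stand-in for a validation loss computed on augmented data.
    return jnp.mean((brighten(image, magnitude) - target) ** 2)


image = jnp.full((4, 4), 0.25)
target = jnp.full((4, 4), 0.5)

magnitude = jnp.asarray(0.0)
grad_fn = jax.jit(jax.grad(loss_fn))  # gradient w.r.t. the first argument
for _ in range(20):
    # Plain gradient descent on the augmentation parameter.
    magnitude = magnitude - 4.0 * grad_fn(magnitude, image, target)

print(float(magnitude))  # approaches 1.0, the scaling that maps 0.25 to 0.5
```

The same loop with a reinforcement-learning search would need many full training runs per candidate policy; here each update costs one backward pass.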

### Task-Optimized Image Processing (+30% Detection Accuracy)
Camera ISPs are tuned for human perception, not AI tasks. Datarax's DAG executor lets you [build a differentiable ISP pipeline](examples/advanced/differentiable/02_learned_isp_guide.py) where detection loss backpropagates through every processing stage, automatically optimizing for **what the model actually needs**.
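
The following plain-JAX sketch illustrates how one task loss can produce gradients for every stage of a small image-processing chain. It is a toy two-stage ISP under stated assumptions (`isp`, `task_loss`, the `wb_gain` and `gamma` parameters, and the mean-squared-error stand-in for a detection loss are all hypothetical), not the pipeline from the linked guide.

```python
import jax
import jax.numpy as jnp


def isp(params: dict, raw: jax.Array) -> jax.Array:
    """Two-stage toy ISP: per-channel white balance, then gamma correction."""
    balanced = jnp.clip(raw * params["wb_gain"], 1e-6, 1.0)
    return balanced ** params["gamma"]


def task_loss(params: dict, raw: jax.Array, target: jax.Array) -> jax.Array:
    # Placeholder for a detection loss evaluated on the ISP output.
    return jnp.mean((isp(params, raw) - target) ** 2)


key = jax.random.PRNGKey(0)
raw = jax.random.uniform(key, (8, 8, 3), minval=0.1, maxval=0.9)
target = raw ** 0.5  # pretend the task prefers a brighter rendering

params = {"wb_gain": jnp.ones(3), "gamma": jnp.asarray(1.0)}
# One call returns gradients for both stages, in the same dict structure as `params`.
grads = jax.grad(task_loss)(params, raw, target)
print({name: g.shape for name, g in grads.items()})
```

Keeping stage parameters in a single pytree means an optimizer such as Optax can update all stages together.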

### Cross-Domain Extensibility (Audio Synthesis in 3 Operators)
Datarax isn't just for images. By implementing [3 custom operators for DDSP audio synthesis](examples/advanced/differentiable/03_ddsp_audio_synthesis_guide.py), you get a complete differentiable audio pipeline with **100x less training data** than neural audio models, which shows that the framework extends beyond vision.

> **Learn more**: [Differentiable Pipeline Examples](docs/examples/advanced/differentiable/)

## Installation

```bash
# Basic installation
pip install datarax

# With data loading support (HuggingFace, TFDS, audio/image libs)
# Extras are quoted so shells such as zsh do not expand the brackets
pip install "datarax[data]"

# With GPU support (CUDA 12)
pip install "datarax[gpu]"

# Full development installation
pip install "datarax[all]"
```

### macOS / Apple Silicon

```bash
# macOS CPU mode (recommended)
pip install "datarax[all-cpu]"
JAX_PLATFORMS=cpu python your_script.py

# Metal GPU acceleration (experimental, M1/M2/M3+)
pip install jax-metal
JAX_PLATFORMS=metal python your_script.py
```

> **Note:** Metal GPU acceleration is community-tested. CI runs on macOS with CPU only.
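
To confirm which backend a script actually picked up after setting `JAX_PLATFORMS`, a short check with standard JAX calls is usually enough. This is a generic JAX snippet, not a Datarax utility.

```python
import jax

# Report the backend JAX selected; with JAX_PLATFORMS=cpu this is "cpu".
backend = jax.default_backend()
devices = jax.devices()
print(f"backend={backend}, devices={devices}")
```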

## Quick Start

```python
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx

from datarax import from_source
from datarax.dag.nodes import OperatorNode
from datarax.operators import ElementOperator, ElementOperatorConfig
from datarax.sources import MemorySource, MemorySourceConfig
from datarax.typing import Element


def normalize(element: Element, key: jax.Array | None = None) -> Element:
    return element.update_data({"image": element.data["image"] / 255.0})


def augment(element: Element, key: jax.Array) -> Element:
    key1, _ = jax.random.split(key)
    flip = jax.random.bernoulli(key1, 0.5)
    new_image = jax.lax.cond(
        flip, lambda img: jnp.flip(img, axis=1), lambda img: img,
        element.data["image"],
    )
    return element.update_data({"image": new_image})


# Create in-memory data source
data = {
    "image": np.random.randint(0, 256, (1000, 28, 28, 1)).astype(np.float32),
    "label": np.random.randint(0, 10, (1000,)).astype(np.int32),
}
source = MemorySource(MemorySourceConfig(), data=data, rngs=nnx.Rngs(0))

# Build pipeline with DAG-based API
normalizer = ElementOperator(
    ElementOperatorConfig(stochastic=False), fn=normalize, rngs=nnx.Rngs(0),
)
augmenter = ElementOperator(
    ElementOperatorConfig(stochastic=True, stream_name="augmentations"),
    fn=augment, rngs=nnx.Rngs(42),
)

pipeline = (
    from_source(source, batch_size=32)
    >> OperatorNode(normalizer)
    >> OperatorNode(augmenter)
)

# Process batches
for i, batch in enumerate(pipeline):
    if i >= 3:
        break
    print(f"Batch {i}: images {batch['image'].shape}, labels {batch['label'].shape}")
```
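
The `augment` function above receives one PRNG key per call. In plain JAX, one common way to give every element of a batch its own key is to split a parent key and `vmap` over images and keys together. The sketch below shows only that general pattern; how Datarax distributes keys internally is not described in this README, and `random_flip` is an illustrative name.

```python
import jax
import jax.numpy as jnp


def random_flip(image: jax.Array, key: jax.Array) -> jax.Array:
    """Flip one image along its width axis with probability 0.5."""
    flip = jax.random.bernoulli(key, 0.5)
    return jax.lax.cond(
        flip, lambda img: jnp.flip(img, axis=1), lambda img: img, image
    )


# Tiny batch of 4 images with shape (height=2, width=3, channels=1).
images = jnp.arange(4 * 2 * 3 * 1, dtype=jnp.float32).reshape(4, 2, 3, 1)

# One independent key per element, derived from a single parent key.
keys = jax.random.split(jax.random.PRNGKey(0), images.shape[0])
flipped = jax.vmap(random_flip)(images, keys)
print(flipped.shape)
```

Re-running with the same parent key reproduces the same flips, which is the property a deterministic pipeline depends on.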

### Advanced: Branching and Parallel DAGs

```python
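# Continues from the Quick Start: reuses Element, ElementOperator, ElementOperatorConfig,
# from_source, source, normalizer, augmenter, jnp, and nnx defined there.
from datarax.dag.nodes import OperatorNode, Merge, Branch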

# Define additional operators
def invert(element: Element, key=None) -> Element:
    return element.update_data({"image": 1.0 - element.data["image"]})

inverter = ElementOperator(
    ElementOperatorConfig(stochastic=False), fn=invert, rngs=nnx.Rngs(0),
)

def is_high_contrast(element):
    return jnp.var(element.data["image"]) > 0.1

# Build a complex DAG:
# 1. Source -> Batching
# 2. Parallel: normalizer AND inverter (| creates a Parallel node)
# 3. Merge: average the two branches
# 4. Branch: conditional path based on image variance
complex_pipeline = (
    from_source(source, batch_size=32)
    >> (OperatorNode(normalizer) | OperatorNode(inverter))
    >> Merge("mean")
    >> Branch(
        condition=is_high_contrast,
        true_path=OperatorNode(augmenter),
        false_path=OperatorNode(normalizer),
    )
)
```

## Architecture

```
src/datarax/
  core/         # Base modules: DataSourceModule, OperatorModule, Element, Batcher, Sampler, Sharder
  dag/          # DAG executor and node system (source, operator, batch, cache, control flow)
  sources/      # MemorySource, TFDS (eager/streaming), HuggingFace (eager/streaming), ArrayRecord, MixedSource
  operators/    # ElementOperator, MapOperator, CompositeOperator, modality-specific (image, text)
    strategies/ # Sequential, Parallel, Branching, Ensemble, Merging execution strategies
  samplers/     # Sequential, Shuffle (Feistel cipher), Range, EpochAware samplers
  sharding/     # ArraySharder, JaxProcessSharder for multi-device distribution
  distributed/  # DeviceMesh, DataParallel for multi-host training
  batching/     # DefaultBatcher with buffer state management
  checkpoint/   # NNXCheckpointHandler with Orbax integration
  monitoring/   # Pipeline monitor, DAG monitor, reporters
  performance/  # Roofline analysis, XLA optimization utilities
  benchmarking/ # Profiler, comparative engine, regression guard, resource monitor
  control/      # Prefetcher for asynchronous data loading
  memory/       # Shared memory manager for multi-process data sharing
  config/       # TOML-based configuration system with schema validation
  cli/          # datarax and datarax-bench CLI entry points
  utils/        # PyTree utilities, external integration helpers
```
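
The `sharding/` and `distributed/` modules listed above target multi-device data distribution. The underlying JAX mechanism can be sketched with the public `jax.sharding` API alone; this is an illustration of batch-dimension sharding in plain JAX, not the `ArraySharder` or `DeviceMesh` implementation, and the batch shape is an arbitrary example.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over all visible devices, with a single "data" axis.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
batch_sharding = NamedSharding(mesh, PartitionSpec("data"))

# The leading (batch) dimension must divide evenly across devices.
per_device = 8
batch = jnp.zeros((per_device * devices.size, 28, 28, 1))
sharded = jax.device_put(batch, batch_sharding)

# Each device holds a contiguous slice of the batch.
for shard in sharded.addressable_shards:
    print(shard.device, shard.data.shape)
```

On a single-device machine the mesh has one entry and the whole batch lives on that device, so the same code runs unchanged on a laptop and on a multi-accelerator host.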

## Benchmarking

Datarax includes a benchmarking suite that compares it against 12 data loading frameworks across 25 scenarios spanning vision, NLP, tabular, multimodal, I/O, distributed, and pipeline-complexity workloads.

```bash
# Install benchmark dependencies (adds PyTorch, DALI, Ray, etc.)
pip install "datarax[benchmark]"

# Run benchmarks locally
datarax-bench run --platform cpu --profile ci_cpu --repetitions 5

# Run on cloud (SkyPilot)
sky launch benchmarks/sky/gpu-benchmark.yaml --env WANDB_API_KEY=$WANDB_API_KEY
```

Benchmark results are exported to W&B with charts, gap analysis, stability reports, and raw result artifacts. See [Benchmarking Guide](docs/benchmarks/index.md) for methodology and cloud deployment.
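
For a quick local sanity check outside the suite, loader throughput can be estimated with a few lines of standard-library Python. This is a rough sketch under stated assumptions, not the methodology `datarax-bench` uses; `measure_throughput` and its parameters are hypothetical names.

```python
import statistics
import time
from collections.abc import Callable, Iterable


def measure_throughput(
    make_iterator: Callable[[], Iterable],
    batch_size: int,
    repetitions: int = 5,
    warmup_batches: int = 2,
) -> float:
    """Return the median samples/second over several full passes."""
    rates = []
    for _ in range(repetitions):
        iterator = iter(make_iterator())
        # Skip warm-up batches so one-time costs (e.g. JIT compilation) are excluded.
        for _ in range(warmup_batches):
            next(iterator, None)
        start = time.perf_counter()
        n_batches = sum(1 for _ in iterator)
        elapsed = time.perf_counter() - start
        rates.append(n_batches * batch_size / max(elapsed, 1e-9))
    return statistics.median(rates)


# Dummy loader: 100 batches of 32 placeholder samples.
rate = measure_throughput(lambda: ([0] * 32 for _ in range(100)), batch_size=32)
print(f"{rate:.0f} samples/s")
```

When the batches are JAX arrays, call `jax.block_until_ready(batch)` inside the timed loop; JAX dispatches work asynchronously, so timing iteration alone can overstate throughput.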

## Development Setup

Datarax uses `uv` as its package manager:

```bash
# Clone and setup
git clone https://github.com/avitai/datarax.git
cd datarax
pip install uv

# Automatic setup
./setup.sh && source activate.sh

# Or manual install
uv pip install -e ".[dev]"
```

### Running Tests

```bash
# CPU-only (most stable)
JAX_PLATFORMS=cpu python -m pytest

# Specific module
JAX_PLATFORMS=cpu python -m pytest tests/sources/test_memory_source.py
```

### Docker

```bash
# Build and run
docker build -t datarax:latest .
docker run --rm --gpus all datarax:latest python -c "import datarax, jax; print(jax.devices())"

# Benchmark images
docker build -f benchmarks/docker/Dockerfile.gpu -t datarax-bench:gpu .
```

See [Docker Guide](docs/contributing/docker.md) for full details.

## Documentation

- [Installation Guide](docs/getting_started/installation.md)
- [Quick Start](docs/getting_started/quick_start.md)
- [Core Concepts](docs/getting_started/core_concepts.md)
- [User Guide](docs/user_guide/)
- [API Reference](docs/api_reference/index.md)
- [Examples](docs/examples/overview.md)
- [Benchmarking](docs/benchmarks/index.md)
- [Contributing](docs/contributing/contributing_guide.md)
- [Docker](docs/contributing/docker.md)

## License

Datarax is licensed under the [MIT License](LICENSE).
