Metadata-Version: 2.4
Name: datarax
Version: 0.1.0.post1
Summary: Datarax: A high-performance, NNX-based data pipeline framework for JAX
Project-URL: Bug Tracker, https://github.com/avitai/datarax/issues
Project-URL: Documentation, https://datarax.readthedocs.io
Project-URL: Source, https://github.com/avitai/datarax
Author: Mahdi Shafiei
License: MIT License
        
        Copyright (c) 2025 Mahdi Shafiei
        
        Permission is hereby granted, free of charge, to any person obtaining a copy
        of this software and associated documentation files (the "Software"), to deal
        in the Software without restriction, including without limitation the rights
        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
        copies of the Software, and to permit persons to whom the Software is
        furnished to do so, subject to the following conditions:
        
        The above copyright notice and this permission notice shall be included in all
        copies or substantial portions of the Software.
        
        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
        SOFTWARE.
License-File: LICENSE
Keywords: data-loading,flax,jax,machine-learning,nnx,workshop
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Education
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: MacOS
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Operating System :: POSIX :: Linux
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Topic :: Software Development
Classifier: Topic :: Software Development :: Libraries
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.11
Requires-Dist: absl-py>=1
Requires-Dist: antlr4-python3-runtime==4.9.3
Requires-Dist: array-record>=0.4
Requires-Dist: astunparse>=1.6
Requires-Dist: beartype>=0.14.1
Requires-Dist: chex>=0.1.7
Requires-Dist: cloudpickle>=2
Requires-Dist: decorator>=5
Requires-Dist: dm-tree>=0.1.8
Requires-Dist: flax>=0.12.0
Requires-Dist: grain>=0.1
Requires-Dist: jax>=0.6.1
Requires-Dist: jaxtyping>=0.2.20
Requires-Dist: matplotlib
Requires-Dist: optax>=0.1.4
Requires-Dist: orbax-checkpoint>=0.11.10
Requires-Dist: pre-commit>=4.3.0
Requires-Dist: protobuf<5,>=4.21.12
Requires-Dist: psutil>=5.9.0
Requires-Dist: pyarrow>=12
Requires-Dist: ruff>=0.1.5
Requires-Dist: scipy>=1.10
Requires-Dist: tensorflow-datasets>=4.9.2
Requires-Dist: tensorflow-metadata>=1.15.0
Requires-Dist: tensorflow>=2.16.0
Requires-Dist: tomli-w>=1
Requires-Dist: tomli>=2; python_version < '3.11'
Provides-Extra: all
Requires-Dist: beartype>=0.14.1; extra == 'all'
Requires-Dist: build>=1.0.3; extra == 'all'
Requires-Dist: coverage>=7; extra == 'all'
Requires-Dist: datasets>=2.14.0; extra == 'all'
Requires-Dist: google-cloud-aiplatform>=1.38.0; extra == 'all'
Requires-Dist: google-cloud-storage>=2.14.0; extra == 'all'
Requires-Dist: grain-nightly; extra == 'all'
Requires-Dist: griffe>=1.7.3; extra == 'all'
Requires-Dist: hippogriffe>=0.2; extra == 'all'
Requires-Dist: ipykernel>=6.29.5; extra == 'all'
Requires-Dist: jax[cuda12]>=0.6.1; extra == 'all'
Requires-Dist: jaxlib>=0.6.1; extra == 'all'
Requires-Dist: jupytext>=1.16.0; extra == 'all'
Requires-Dist: librosa>=0.10.0; extra == 'all'
Requires-Dist: mkdocs-gen-files>=0.5.0; extra == 'all'
Requires-Dist: mkdocs-include-exclude-files>=0.1; extra == 'all'
Requires-Dist: mkdocs-ipynb>=0.1; extra == 'all'
Requires-Dist: mkdocs-jupyter>=0.24; extra == 'all'
Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'all'
Requires-Dist: mkdocs-material>=9.6.7; extra == 'all'
Requires-Dist: mkdocs>=1.6.1; extra == 'all'
Requires-Dist: mkdocstrings-python>=1.1.2; extra == 'all'
Requires-Dist: mkdocstrings>=0.28.3; extra == 'all'
Requires-Dist: pillow>=10.0.0; extra == 'all'
Requires-Dist: pydeps>=1.12.17; extra == 'all'
Requires-Dist: pymdown-extensions>=10.14.3; extra == 'all'
Requires-Dist: pyright>=1.1.336; extra == 'all'
Requires-Dist: pytest-asyncio>=0.23; extra == 'all'
Requires-Dist: pytest-benchmark>=4; extra == 'all'
Requires-Dist: pytest-cov>=6.1.1; extra == 'all'
Requires-Dist: pytest-env>=1.0.1; extra == 'all'
Requires-Dist: pytest-json-report>=1.5.0; extra == 'all'
Requires-Dist: pytest-timeout>=2.1; extra == 'all'
Requires-Dist: pytest-xdist>=3.6; extra == 'all'
Requires-Dist: pytest>=8.3.5; extra == 'all'
Requires-Dist: python-dotenv>=1; extra == 'all'
Requires-Dist: ruff>=0.1.5; extra == 'all'
Requires-Dist: shellcheck-py>=0.10.0.1; extra == 'all'
Requires-Dist: soundfile>=0.12.1; extra == 'all'
Requires-Dist: tensorflow-datasets>=4.9.2; extra == 'all'
Requires-Dist: twine>=4.0.2; extra == 'all'
Provides-Extra: all-cpu
Requires-Dist: beartype>=0.14.1; extra == 'all-cpu'
Requires-Dist: build>=1.0.3; extra == 'all-cpu'
Requires-Dist: coverage>=7; extra == 'all-cpu'
Requires-Dist: datasets>=2.14.0; extra == 'all-cpu'
Requires-Dist: google-cloud-aiplatform>=1.38.0; extra == 'all-cpu'
Requires-Dist: google-cloud-storage>=2.14.0; extra == 'all-cpu'
Requires-Dist: grain-nightly; extra == 'all-cpu'
Requires-Dist: griffe>=1.7.3; extra == 'all-cpu'
Requires-Dist: hippogriffe>=0.2; extra == 'all-cpu'
Requires-Dist: ipykernel>=6.29.5; extra == 'all-cpu'
Requires-Dist: jupytext>=1.16.0; extra == 'all-cpu'
Requires-Dist: librosa>=0.10.0; extra == 'all-cpu'
Requires-Dist: mkdocs-gen-files>=0.5.0; extra == 'all-cpu'
Requires-Dist: mkdocs-include-exclude-files>=0.1; extra == 'all-cpu'
Requires-Dist: mkdocs-ipynb>=0.1; extra == 'all-cpu'
Requires-Dist: mkdocs-jupyter>=0.24; extra == 'all-cpu'
Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'all-cpu'
Requires-Dist: mkdocs-material>=9.6.7; extra == 'all-cpu'
Requires-Dist: mkdocs>=1.6.1; extra == 'all-cpu'
Requires-Dist: mkdocstrings-python>=1.1.2; extra == 'all-cpu'
Requires-Dist: mkdocstrings>=0.28.3; extra == 'all-cpu'
Requires-Dist: pillow>=10.0.0; extra == 'all-cpu'
Requires-Dist: pydeps>=1.12.17; extra == 'all-cpu'
Requires-Dist: pymdown-extensions>=10.14.3; extra == 'all-cpu'
Requires-Dist: pyright>=1.1.336; extra == 'all-cpu'
Requires-Dist: pytest-asyncio>=0.23; extra == 'all-cpu'
Requires-Dist: pytest-benchmark>=4; extra == 'all-cpu'
Requires-Dist: pytest-cov>=6.1.1; extra == 'all-cpu'
Requires-Dist: pytest-env>=1.0.1; extra == 'all-cpu'
Requires-Dist: pytest-json-report>=1.5.0; extra == 'all-cpu'
Requires-Dist: pytest-timeout>=2.1; extra == 'all-cpu'
Requires-Dist: pytest-xdist>=3.6; extra == 'all-cpu'
Requires-Dist: pytest>=8.3.5; extra == 'all-cpu'
Requires-Dist: python-dotenv>=1; extra == 'all-cpu'
Requires-Dist: ruff>=0.1.5; extra == 'all-cpu'
Requires-Dist: shellcheck-py>=0.10.0.1; extra == 'all-cpu'
Requires-Dist: soundfile>=0.12.1; extra == 'all-cpu'
Requires-Dist: tensorflow-datasets>=4.9.2; extra == 'all-cpu'
Requires-Dist: twine>=4.0.2; extra == 'all-cpu'
Provides-Extra: all-macos
Requires-Dist: beartype>=0.14.1; extra == 'all-macos'
Requires-Dist: build>=1.0.3; extra == 'all-macos'
Requires-Dist: coverage>=7; extra == 'all-macos'
Requires-Dist: datasets>=2.14.0; extra == 'all-macos'
Requires-Dist: google-cloud-aiplatform>=1.38.0; extra == 'all-macos'
Requires-Dist: google-cloud-storage>=2.14.0; extra == 'all-macos'
Requires-Dist: grain-nightly; extra == 'all-macos'
Requires-Dist: griffe>=1.7.3; extra == 'all-macos'
Requires-Dist: hippogriffe>=0.2; extra == 'all-macos'
Requires-Dist: ipykernel>=6.29.5; extra == 'all-macos'
Requires-Dist: jax-metal>=0.1.0; extra == 'all-macos'
Requires-Dist: jupytext>=1.16.0; extra == 'all-macos'
Requires-Dist: librosa>=0.10.0; extra == 'all-macos'
Requires-Dist: mkdocs-gen-files>=0.5.0; extra == 'all-macos'
Requires-Dist: mkdocs-include-exclude-files>=0.1; extra == 'all-macos'
Requires-Dist: mkdocs-ipynb>=0.1; extra == 'all-macos'
Requires-Dist: mkdocs-jupyter>=0.24; extra == 'all-macos'
Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'all-macos'
Requires-Dist: mkdocs-material>=9.6.7; extra == 'all-macos'
Requires-Dist: mkdocs>=1.6.1; extra == 'all-macos'
Requires-Dist: mkdocstrings-python>=1.1.2; extra == 'all-macos'
Requires-Dist: mkdocstrings>=0.28.3; extra == 'all-macos'
Requires-Dist: pillow>=10.0.0; extra == 'all-macos'
Requires-Dist: pydeps>=1.12.17; extra == 'all-macos'
Requires-Dist: pymdown-extensions>=10.14.3; extra == 'all-macos'
Requires-Dist: pyright>=1.1.336; extra == 'all-macos'
Requires-Dist: pytest-asyncio>=0.23; extra == 'all-macos'
Requires-Dist: pytest-benchmark>=4; extra == 'all-macos'
Requires-Dist: pytest-cov>=6.1.1; extra == 'all-macos'
Requires-Dist: pytest-env>=1.0.1; extra == 'all-macos'
Requires-Dist: pytest-json-report>=1.5.0; extra == 'all-macos'
Requires-Dist: pytest-timeout>=2.1; extra == 'all-macos'
Requires-Dist: pytest-xdist>=3.6; extra == 'all-macos'
Requires-Dist: pytest>=8.3.5; extra == 'all-macos'
Requires-Dist: python-dotenv>=1; extra == 'all-macos'
Requires-Dist: ruff>=0.1.5; extra == 'all-macos'
Requires-Dist: shellcheck-py>=0.10.0.1; extra == 'all-macos'
Requires-Dist: soundfile>=0.12.1; extra == 'all-macos'
Requires-Dist: tensorflow-datasets>=4.9.2; extra == 'all-macos'
Requires-Dist: twine>=4.0.2; extra == 'all-macos'
Provides-Extra: cuda-dev
Requires-Dist: build>=1.0.3; extra == 'cuda-dev'
Requires-Dist: coverage>=7; extra == 'cuda-dev'
Requires-Dist: google-cloud-aiplatform>=1.38.0; extra == 'cuda-dev'
Requires-Dist: google-cloud-storage>=2.14.0; extra == 'cuda-dev'
Requires-Dist: ipykernel>=6.29.5; extra == 'cuda-dev'
Requires-Dist: jax[cuda12]>=0.6.1; extra == 'cuda-dev'
Requires-Dist: jaxlib>=0.6.1; extra == 'cuda-dev'
Requires-Dist: pydeps>=1.12.17; extra == 'cuda-dev'
Requires-Dist: pyright>=1.1.336; extra == 'cuda-dev'
Requires-Dist: pytest-asyncio>=0.23; extra == 'cuda-dev'
Requires-Dist: pytest-benchmark>=4; extra == 'cuda-dev'
Requires-Dist: pytest-cov>=6.1.1; extra == 'cuda-dev'
Requires-Dist: pytest-env>=1.0.1; extra == 'cuda-dev'
Requires-Dist: pytest-json-report>=1.5.0; extra == 'cuda-dev'
Requires-Dist: pytest-timeout>=2.1; extra == 'cuda-dev'
Requires-Dist: pytest-xdist>=3.6; extra == 'cuda-dev'
Requires-Dist: pytest>=8.3.5; extra == 'cuda-dev'
Requires-Dist: python-dotenv>=1; extra == 'cuda-dev'
Requires-Dist: ruff>=0.1.5; extra == 'cuda-dev'
Requires-Dist: shellcheck-py>=0.10.0.1; extra == 'cuda-dev'
Requires-Dist: twine>=4.0.2; extra == 'cuda-dev'
Provides-Extra: data
Requires-Dist: datasets>=2.14.0; extra == 'data'
Requires-Dist: grain-nightly; extra == 'data'
Requires-Dist: librosa>=0.10.0; extra == 'data'
Requires-Dist: pillow>=10.0.0; extra == 'data'
Requires-Dist: soundfile>=0.12.1; extra == 'data'
Requires-Dist: tensorflow-datasets>=4.9.2; extra == 'data'
Provides-Extra: dev
Requires-Dist: build>=1.0.3; extra == 'dev'
Requires-Dist: coverage>=7; extra == 'dev'
Requires-Dist: google-cloud-aiplatform>=1.38.0; extra == 'dev'
Requires-Dist: google-cloud-storage>=2.14.0; extra == 'dev'
Requires-Dist: ipykernel>=6.29.5; extra == 'dev'
Requires-Dist: pydeps>=1.12.17; extra == 'dev'
Requires-Dist: pyright>=1.1.336; extra == 'dev'
Requires-Dist: pytest-asyncio>=0.23; extra == 'dev'
Requires-Dist: pytest-benchmark>=4; extra == 'dev'
Requires-Dist: pytest-cov>=6.1.1; extra == 'dev'
Requires-Dist: pytest-env>=1.0.1; extra == 'dev'
Requires-Dist: pytest-json-report>=1.5.0; extra == 'dev'
Requires-Dist: pytest-timeout>=2.1; extra == 'dev'
Requires-Dist: pytest-xdist>=3.6; extra == 'dev'
Requires-Dist: pytest>=8.3.5; extra == 'dev'
Requires-Dist: python-dotenv>=1; extra == 'dev'
Requires-Dist: ruff>=0.1.5; extra == 'dev'
Requires-Dist: shellcheck-py>=0.10.0.1; extra == 'dev'
Requires-Dist: twine>=4.0.2; extra == 'dev'
Provides-Extra: docs
Requires-Dist: griffe>=1.7.3; extra == 'docs'
Requires-Dist: hippogriffe>=0.2; extra == 'docs'
Requires-Dist: jupytext>=1.16.0; extra == 'docs'
Requires-Dist: mkdocs-gen-files>=0.5.0; extra == 'docs'
Requires-Dist: mkdocs-include-exclude-files>=0.1; extra == 'docs'
Requires-Dist: mkdocs-ipynb>=0.1; extra == 'docs'
Requires-Dist: mkdocs-jupyter>=0.24; extra == 'docs'
Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
Requires-Dist: mkdocs-material>=9.6.7; extra == 'docs'
Requires-Dist: mkdocs>=1.6.1; extra == 'docs'
Requires-Dist: mkdocstrings-python>=1.1.2; extra == 'docs'
Requires-Dist: mkdocstrings>=0.28.3; extra == 'docs'
Requires-Dist: pymdown-extensions>=10.14.3; extra == 'docs'
Provides-Extra: gpu
Requires-Dist: jax[cuda12]>=0.6.1; extra == 'gpu'
Requires-Dist: jaxlib>=0.6.1; extra == 'gpu'
Provides-Extra: metal
Requires-Dist: jax-metal>=0.1.0; extra == 'metal'
Provides-Extra: test
Requires-Dist: beartype>=0.14.1; extra == 'test'
Requires-Dist: coverage>=7; extra == 'test'
Requires-Dist: pytest-asyncio>=0.23; extra == 'test'
Requires-Dist: pytest-benchmark>=4; extra == 'test'
Requires-Dist: pytest-cov>=6.1.1; extra == 'test'
Requires-Dist: pytest-env>=1.0.1; extra == 'test'
Requires-Dist: pytest-timeout>=2.1; extra == 'test'
Requires-Dist: pytest-xdist>=3.6; extra == 'test'
Requires-Dist: pytest>=8.3.5; extra == 'test'
Description-Content-Type: text/markdown

# Datarax: A Data Pipeline Framework for JAX

[![CI](https://github.com/avitai/datarax/actions/workflows/ci.yml/badge.svg)](https://github.com/avitai/datarax/actions/workflows/ci.yml)
[![Test Coverage](https://github.com/avitai/datarax/actions/workflows/test-coverage.yml/badge.svg)](https://github.com/avitai/datarax/actions/workflows/test-coverage.yml)
[![codecov](https://codecov.io/gh/avitai/datarax/branch/main/graph/badge.svg)](https://codecov.io/gh/avitai/datarax)
[![Build](https://github.com/avitai/datarax/actions/workflows/build-verification.yml/badge.svg)](https://github.com/avitai/datarax/actions/workflows/build-verification.yml)
[![Summary](https://github.com/avitai/datarax/actions/workflows/summary.yml/badge.svg)](https://github.com/avitai/datarax/actions/workflows/summary.yml)

[![Project Status: Active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)

> **Note:** This project is in early development; the API is unstable and breaking changes are expected.

**Datarax** (*Data + Array/JAX*) is a high-performance, extensible data pipeline framework engineered for JAX-based machine learning workflows. It simplifies and accelerates building efficient, scalable data loading, preprocessing, and augmentation pipelines, leveraging JAX's just-in-time (JIT) compilation, automatic differentiation, and hardware acceleration.

## Key Features

- **High Performance:** Leverages JAX's JIT compilation and XLA backend for high-throughput data processing on CPUs, GPUs, and TPUs.
- **JAX-Native Design:** All core components and operations are designed with JAX's functional programming paradigm and immutable data structures (PyTrees) in mind.
- **Scalability:** Supports efficient data loading and processing for large datasets and distributed training scenarios, including multi-host and multi-device setups.
- **Extensibility:** Easily define and integrate custom data sources, transformations, and augmentation operations.
- **Usability:** Provides a clear, intuitive Python API and a flexible configuration system (TOML-based) for defining and managing pipelines.
- **Determinism:** Pipeline runs are deterministic by default, crucial for reproducibility in research and production.
- **Caching Optimization:** Multiple caching strategies for performance improvement, including function caching, transformer caching, and checkpointing.
- **Broad Feature Set:** Covers common data pipeline operations, including diverse data source handling, advanced transformations, data augmentation, batching, sharding, checkpointing, and caching.
- **Ecosystem Integration:** Facilitates smooth integration with other JAX libraries like Flax, Optax, and Orbax.
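The determinism guarantee above typically comes from threading explicit seeds (or JAX PRNG keys) through every stochastic pipeline stage. A minimal sketch of the idea in plain NumPy (illustrative only, not Datarax's API):

```python
import numpy as np


def shuffled_indices(num_samples: int, seed: int) -> np.ndarray:
    # A seeded Generator makes the shuffle order a pure function of the seed,
    # so two independent runs with the same seed visit samples in the same order.
    rng = np.random.default_rng(seed)
    idx = np.arange(num_samples)
    rng.shuffle(idx)
    return idx


run_a = shuffled_indices(10, seed=42)
run_b = shuffled_indices(10, seed=42)
assert np.array_equal(run_a, run_b)  # reproducible across runs
```

In JAX-based pipelines the same principle applies with `jax.random` keys instead of a mutable global RNG state.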

## Installation

Install Datarax using pip:

```bash
# Basic installation
pip install datarax

# With data loading support (HuggingFace, TFDS, Audio/Image libs)
pip install datarax[data]

# With GPU support (CUDA 12)
pip install datarax[gpu]

# Full installation with all optional dependencies
pip install datarax[all]
```

### macOS / Apple Silicon

Datarax supports macOS on both Intel and Apple Silicon (M1/M2/M3) processors.

```bash
# Install for macOS (CPU mode - recommended)
pip install datarax[all-cpu]

# Run with explicit CPU backend
JAX_PLATFORMS=cpu python your_script.py
```

**Metal GPU Acceleration (Experimental):** JAX supports Apple's Metal backend for GPU acceleration on M1+ chips:

```bash
pip install jax-metal
JAX_PLATFORMS=metal python your_script.py
```

> **Note:** Metal GPU acceleration is **community-tested**. CI runs on macOS with CPU only. If you can help validate Metal support, please open an issue with your results.

**Known Limitations:**
- **TensorFlow Datasets**: May hang on import on macOS ARM64. Datarax uses lazy imports to handle this.
- **JAX Profiler**: Trace collection is automatically disabled on macOS due to TensorBoard compatibility issues.
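The lazy-import technique mentioned above defers a module's import until it is first needed, so a dependency that is slow (or hangs) at import time costs nothing unless a feature actually uses it. A generic sketch of the pattern (Datarax's internal implementation may differ), demonstrated here with a stdlib module:

```python
import importlib
import sys


def lazy_import(name: str):
    """Import a module only on first use, reusing the sys.modules cache.

    Deferring the import avoids paying the import cost (or triggering an
    import-time hang) at package load time.
    """
    if name in sys.modules:
        return sys.modules[name]
    return importlib.import_module(name)


# Demonstrated with a stdlib module; in a real pipeline the argument would
# be a heavy optional dependency such as "tensorflow_datasets".
json_mod = lazy_import("json")
print(json_mod.dumps({"ok": True}))  # → {"ok": true}
```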

## Development Setup

For contributors and developers, Datarax uses `uv` as its package manager:

```bash
# Install uv
pip install uv

# Clone and setup
git clone https://github.com/avitai/datarax.git
cd datarax

# Run automatic setup (creates venv & installs dependencies)
./setup.sh

# Activate the environment
source activate.sh

# Or install manually with uv
uv pip install -e ".[dev]"
```

See the [Development Environment Guide](docs/dev_guide.md) for detailed instructions on:

- Creating a virtual environment with `uv venv`
- Installing dependencies through `pyproject.toml`
- Running tests through pytest
- Building and packaging Datarax

## Current Status

Datarax is currently in active development with a **Flax NNX-based architecture** that provides robust state management and checkpointing:

### Core Architecture

- **NNX Module System**: All components built on `flax.nnx.Module` for robust state management
- **Integrated Checkpointing**: Seamless Orbax integration for stateful pipeline persistence
- **Type Safety**: Comprehensive type annotations and runtime validation
- **Composability**: Modular design enabling flexible pipeline construction

### Implemented Components

- **Core NNX Modules**: DataraxModule, OperatorModule, StructuralModule
- **Data Sources**: MemorySource, TFDSSource, HFSource (inheriting from StructuralModule/DataSourceModule)
- **Operators**: Element-wise operators, MapOperator (inheriting from OperatorModule)
- **DAG Execution Engine**: `DAGExecutor` and `pipeline` API for constructing flexible, graph-based data processing flows.
- **Node System**: A rich set of nodes for building pipelines:
  - `DataSourceNode`: entry point for data
  - `OperatorNode`: for transformations
  - `BatchNode` & `RebatchNode`: for batching control
  - `ShuffleNode`, `CacheNode`: for data management
  - Control flow nodes: `Sequential`, `Parallel`, `Branch`, `Merge`
- **Stateful Components**:
  - MemorySourceModule/ArrayRecordSourceModule/HFSourceModule for diverse data ingestion
  - Range/ShuffleSamplerModule with reproducible random state
  - DefaultBatcherModule with buffer state management
  - ArraySharderModule for device-aware data distribution
- **External Integrations**:
  - HuggingFace Datasets with stateful iteration
  - TensorFlow Datasets with checkpoint support
- **Advanced Features**:
  - Caching strategies with state preservation
  - Differentiable rebatching (`DifferentiableRebatchImpl`)
  - NNXCheckpointHandler for production-grade checkpointing

Upcoming features include:

- Image transformation library with JAX-native operations
- Advanced sharding strategies for multi-device and multi-host scenarios
- Performance optimization suite with benchmarking tools
- Extended monitoring and metrics capabilities
- Additional external data source integrations

## Quick Start

Here's a simple example of using Datarax's DAG-based architecture to create a data pipeline for image classification:

```python
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx

from datarax import from_source
from datarax.dag.nodes import OperatorNode
from datarax.operators import ElementOperator, ElementOperatorConfig
from datarax.sources import MemorySource, MemorySourceConfig
from datarax.typing import Element


def normalize(element: Element, key: jax.Array | None = None) -> Element:
    """Normalize the image in the element."""
    # element.data is a dict-like PyTree containing the actual data arrays
    # Return a new Element with updated data (immutable update)
    # update_data provides a clean API for partial updates
    return element.update_data({"image": element.data["image"] / 255.0})


def apply_augmentation(element: Element, key: jax.Array) -> Element:
    """Apply a simple augmentation (flipping) to the image."""
    key1, _ = jax.random.split(key)
    flip_horizontal = jax.random.bernoulli(key1, 0.5)

    # Use jax.lax.cond for JAX-compatible conditional execution
    def flip_image(img):
        return jnp.flip(img, axis=1)

    def no_flip(img):
        return img

    new_image = jax.lax.cond(
        flip_horizontal,
        flip_image,
        no_flip,
        element.data["image"]
    )
    return element.update_data({"image": new_image})


# Create some dummy data (28x28 images)
num_samples = 1000
image_shape = (28, 28, 1)
data = {
    "image": np.random.randint(0, 255, (num_samples, *image_shape)).astype(np.float32),
    "label": np.random.randint(0, 10, (num_samples,)).astype(np.int32),
}

# Create the data source using config-based API
# MemorySource is a standard DataSource implementation for in-memory data
source_config = MemorySourceConfig()
source = MemorySource(source_config, data=data, rngs=nnx.Rngs(0))

# Create operators using the unified ElementOperator API
# Normalizer is deterministic (no random key needed)
normalizer_config = ElementOperatorConfig(stochastic=False)
normalizer = ElementOperator(normalizer_config, fn=normalize, rngs=nnx.Rngs(0))

# Augmenter is stochastic: requires stream_name for proper RNG management
augmenter_config = ElementOperatorConfig(stochastic=True, stream_name="augmentations")
augmenter = ElementOperator(augmenter_config, fn=apply_augmentation, rngs=nnx.Rngs(42))

# Create the data pipeline using the DAG-based API
# from_source() initializes the pipeline with:
# 1. The DataSourceNode (data loading)
# 2. A BatchNode (automatic batching)
# The >> operator chains transformation operators
pipeline = (
    from_source(source, batch_size=32)
    >> OperatorNode(normalizer)
    >> OperatorNode(augmenter)
)

# Alternative: Method Chaining
# You can also build the pipeline using the fluent .add() method:
# pipeline = (
#     from_source(source, batch_size=32)
#     .add(OperatorNode(normalizer))
#     .add(OperatorNode(augmenter))
# )

# Create an iterator and process batches
# The pipeline handles data streaming, batching, state management, and execution
for i, batch in enumerate(pipeline):
    if i >= 3:
        break

    # Get the shape and stats for each component in the batch
    # batch['key'] provides direct access to data arrays
    image_batch = batch["image"]
    label_batch = batch["label"]

    print(f"Batch {i}:")
    print(f"  Image shape: {image_batch.shape}")
    print(f"  Label batch size: {label_batch.shape[0]}")
    print(f"  Image min/max: {image_batch.min():.3f}/{image_batch.max():.3f}")

print("Pipeline processing completed!")
```
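The `update_data` calls above follow the immutable-update pattern: rather than mutating an element in place, each operator returns a new element with merged data, matching the functional style JAX transformations expect. A minimal stand-in illustrating the pattern (not Datarax's actual `Element` class):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SimpleElement:
    data: dict

    def update_data(self, updates: dict) -> "SimpleElement":
        # Merge updates into a copy; the original element is never mutated.
        return SimpleElement(data={**self.data, **updates})


e = SimpleElement(data={"image": [255.0, 128.0], "label": 3})
e2 = e.update_data({"image": [1.0, 0.5]})

assert e.data["image"] == [255.0, 128.0]  # original unchanged
assert e2.data["image"] == [1.0, 0.5]     # new element carries the update
assert e2.data["label"] == 3              # untouched keys are preserved
```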

### Advanced Pipeline

For more complex workflows, Datarax supports branching and parallel execution:

```python
from datarax.dag.nodes import Parallel, Merge, Branch

# Continues from the Quick Start example above (source, normalizer, augmenter).

# Define an additional operator and wrap it, as in the Quick Start
def invert(element, key=None):
    return element.update_data({"image": 1.0 - element.data["image"]})

inverter = ElementOperator(
    ElementOperatorConfig(stochastic=False), fn=invert, rngs=nnx.Rngs(0)
)

def is_high_contrast(element):
    # Condition: check if image variance is high
    return jnp.var(element.data["image"]) > 0.1

# Build a complex DAG:
# 1. Source -> batching
# 2. Parallel (`|`): normalized version AND inverted version
# 3. Merge: average them (a simple ensemble)
# 4. Branch: augment ONLY if high contrast, otherwise normalize again
complex_pipeline = (
    from_source(source, batch_size=32)
    >> (OperatorNode(normalizer) | OperatorNode(inverter))
    >> Merge("mean")
    >> Branch(
           condition=is_high_contrast,
           true_path=OperatorNode(augmenter),
           false_path=OperatorNode(normalizer),
       )
)
```
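The `>>` and `|` syntax above relies on Python operator overloading (`__rshift__` and `__or__`). A toy sketch of how such composition can work in principle (illustrative only, not Datarax's actual node classes):

```python
class Node:
    """Toy pipeline node illustrating `>>` (sequential) and `|` (parallel)."""

    def __init__(self, fn):
        self.fn = fn

    def __rshift__(self, other):
        # a >> b: feed a's output into b
        return Node(lambda x: other.fn(self.fn(x)))

    def __or__(self, other):
        # a | b: run both on the same input, yielding a tuple of results
        return Node(lambda x: (self.fn(x), other.fn(x)))


double = Node(lambda x: x * 2)
increment = Node(lambda x: x + 1)

assert (double >> increment).fn(3) == 7       # (3 * 2) + 1
assert (double | increment).fn(3) == (6, 4)   # both applied to the same input
```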

## Documentation

For complete documentation, please visit [datarax.readthedocs.io](https://datarax.readthedocs.io).

- [Installation Guide](https://datarax.readthedocs.io/en/latest/installation/)
- [User Guide](https://datarax.readthedocs.io/en/latest/user_guide/)
- [API Reference](https://datarax.readthedocs.io/en/latest/api_reference/)
- [Examples](https://datarax.readthedocs.io/en/latest/examples/)
- [Contributing](https://datarax.readthedocs.io/en/latest/contributing/)


## Testing

Datarax ships a comprehensive test suite with support for both CPU and GPU testing:

- Tests are organized to mirror the `src/datarax` package structure for easier navigation
- All GitHub CI workflows run tests exclusively on CPU
- Local test runs automatically use GPU when available
- The testing infrastructure handles environment configuration to ensure consistency

To run tests:

```bash
# Run all tests using CPU only (most stable)
JAX_PLATFORMS=cpu python -m pytest

# Run specific test module
JAX_PLATFORMS=cpu python -m pytest tests/sources/test_memory_source.py
```

For more information on the test organization and how to run tests, see the [Testing Guide](tests/README.md).

## License

Datarax is licensed under the [MIT License](LICENSE).
