Metadata-Version: 2.4
Name: mamba3-jax
Version: 1.0.1
Summary: Pure JAX/Flax NNX implementation of the Mamba3 state-space model with state caching, weight loading, and time-series forecasting.
Author: Cosmo Santoni
License: Apache-2.0
Project-URL: Homepage, https://github.com/jax-state-spaces/mamba3-jax
Project-URL: Repository, https://github.com/jax-state-spaces/mamba3-jax
Project-URL: Issues, https://github.com/jax-state-spaces/mamba3-jax/issues
Keywords: mamba3,jax,flax,nnx,state-space-model,ssm,language-modeling,time-series,sequence-modeling,caching
Classifier: Development Status :: 5 - Production/Stable
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Intended Audience :: Science/Research
Classifier: Intended Audience :: Developers
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.11
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: flax>=0.12.0
Requires-Dist: jax>=0.8.0
Requires-Dist: jaxlib>=0.8.0
Requires-Dist: optax>=0.2.6
Provides-Extra: dev
Requires-Dist: pytest>=9.0.1; extra == "dev"
Requires-Dist: absl-py>=2.0.0; extra == "dev"
Requires-Dist: numpy>=1.26.0; extra == "dev"
Provides-Extra: pretrained
Requires-Dist: safetensors>=0.4.0; extra == "pretrained"
Requires-Dist: huggingface_hub>=0.20.0; extra == "pretrained"
Requires-Dist: transformers>=4.40.0; extra == "pretrained"
Provides-Extra: reference
Requires-Dist: torch>=2.4.0; extra == "reference"
Requires-Dist: mamba-ssm>=2.3.2; extra == "reference"
Dynamic: license-file

# Mamba3-JAX

[![PyPI](https://img.shields.io/pypi/v/mamba3-jax)](https://pypi.org/project/mamba3-jax/)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE)

A pure JAX/Flax NNX implementation of the Mamba3 state-space model with SSM state caching, weight loading from the canonical state-spaces layout, causal language modeling, and time-series forecasting.

This is the standalone PyPI package for the Mamba3 implementation authored by [Cosmo Santoni](https://cosmosantoni.com), following the same pure-JAX, kernel-free approach as its sibling [mamba2-jax](https://github.com/jax-state-spaces/mamba2-jax).

*Independent community reimplementation — not affiliated with, endorsed by, or connected to the original Mamba authors or [state-spaces/mamba](https://github.com/state-spaces/mamba).*

## Supported Models

`Mamba3ForCausalLM` and `Mamba3Forecaster` are supported across every canonical configuration: SISO and MIMO (`is_mimo`, `mimo_rank`), `ngroups`, `is_outproj_norm`, `rope_fraction ∈ {0.5, 1.0}`, and packed variable-length sequences (`cu_seqlens`). No pretrained Mamba-3 checkpoints are publicly available; the architecture, config, and weight loader match the canonical `state-spaces/mamba` (Mamba-3) layout.

## Features

- **Flax NNX** modules (no legacy `init`/`apply` ceremony)
- **SSM + rotary + trapezoidal state caching** for O(n) autoregressive generation (no convolution state)
- **MIMO** rank-`R` recurrence, plus `ngroups`, `is_outproj_norm`, `rope_fraction` — every canonical config flag
- **Variable-length / packed sequences** (`cu_seqlens`) for padding-free batched training
- **Causal language modeling** (`Mamba3ForCausalLM`) with tied or untied embeddings
- **Time-series forecasting** (`Mamba3Forecaster`)
- **Weight loading** from the canonical state-spaces/mamba (Mamba-3) layout
- **Kernel-verified parity** against the reference Triton (SISO) and TileLang (MIMO) implementations
- Fully compatible with `jax.jit`, `jax.grad`, `jax.vmap`
- Runs on **CPU, GPU (CUDA), and TPU**
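
To illustrate the packed-sequence feature listed above, here is a minimal, library-free sketch of how a `cu_seqlens` vector is typically built. It assumes the common varlen-kernel convention of cumulative offsets with a leading zero. The helper names `pack` and `unpack` are hypothetical and are not part of `mamba3_jax`; check the package for the exact dtype and shape it expects.

```python
from itertools import accumulate

def pack(sequences):
    """Concatenate sequences and return (packed, cu_seqlens).

    cu_seqlens[i] is the offset where sequence i starts in the packed buffer;
    the final entry is the total number of tokens (assumed convention).
    """
    packed = [tok for seq in sequences for tok in seq]
    cu_seqlens = [0, *accumulate(len(seq) for seq in sequences)]
    return packed, cu_seqlens

def unpack(packed, cu_seqlens):
    """Recover the original sequences from the packed buffer."""
    return [packed[s:e] for s, e in zip(cu_seqlens, cu_seqlens[1:])]

seqs = [[11, 12, 13], [21, 22, 23, 24, 25], [31, 32]]
packed, cu_seqlens = pack(seqs)

print(cu_seqlens)   # [0, 3, 8, 10]
print(len(packed))  # 10 token slots, with no padding

# A padded batch would instead need batch_size * max_len slots.
padded_slots = len(seqs) * max(len(s) for s in seqs)
print(padded_slots)  # 15
```

In this toy case packing uses 10 token slots where padding to the longest sequence would use 15.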

## State Space Caching

The SSM state cache makes autoregressive generation O(n) in the number of generated tokens, instead of the O(n²) cost of re-running the full prefix for every new token. Mamba-3 carries a fixed-size per-layer state (the cumulative rotary angle, the SSM state, and the previous `B`/`x` required by the trapezoidal recurrence; there is no convolution state), so each decode step is O(1) in sequence length. Cached and non-cached generation produce identical tokens.
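
The cost argument can be made concrete with a toy example. The sketch below uses a scalar linear recurrence in plain Python; it is not the Mamba-3 update (no rotary angle, no trapezoidal term), and the coefficients and function names are illustrative assumptions. It only shows why carrying a fixed-size state gives the same outputs with far fewer steps.

```python
# Toy scalar recurrence (NOT the Mamba-3 update):
#   h_t = A * h_{t-1} + B * x_t,   y_t = C * h_t
A, B, C = 0.9, 0.5, 2.0  # arbitrary illustrative coefficients

def step(h, x):
    """Advance the fixed-size state by one token and emit one output."""
    h = A * h + B * x
    return h, C * h

def generate_recompute(xs):
    """No cache: re-run the whole prefix for every new token."""
    ys, n_steps = [], 0
    for t in range(1, len(xs) + 1):
        h = 0.0
        for x in xs[:t]:
            h, y = step(h, x)
            n_steps += 1
        ys.append(y)  # output at the last position of the prefix
    return ys, n_steps

def generate_cached(xs):
    """Cache: carry h forward, exactly one step per new token."""
    ys, n_steps, h = [], 0, 0.0
    for x in xs:
        h, y = step(h, x)
        n_steps += 1
        ys.append(y)
    return ys, n_steps

xs = [1.0, -2.0, 0.5, 3.0, 0.25, -1.0]
ys_slow, steps_slow = generate_recompute(xs)
ys_fast, steps_fast = generate_cached(xs)

print(ys_slow == ys_fast)      # True
print(steps_slow, steps_fast)  # 21 6
```

For 6 tokens the uncached path performs 1 + 2 + ... + 6 = 21 recurrence steps, while the cached path performs 6.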

## Installation

### From PyPI

```bash
pip install mamba3-jax
```

### From source

```bash
git clone https://github.com/jax-state-spaces/mamba3-jax.git
cd mamba3-jax
pip install -e ".[dev]"
```

For GPU or TPU support, install the appropriate JAX backend as described in the [JAX installation guide](https://jax.readthedocs.io/en/latest/installation.html).

## Usage

### Causal Language Model

```python
import jax.numpy as jnp
from flax import nnx
from mamba3_jax import Mamba3Config, Mamba3ForCausalLM

cfg = Mamba3Config(vocab_size=1024, hidden_size=256, num_hidden_layers=4,
                   state_size=64, head_dim=32, chunk_size=64)
model = Mamba3ForCausalLM(cfg, rngs=nnx.Rngs(0))

input_ids = jnp.ones((2, 64), dtype=jnp.int32)
outputs = model(input_ids, labels=input_ids)

print(outputs["logits"].shape)  # (2, 64, 1024)
print(float(outputs["loss"]))
```

### Cached generation

```python
import jax.numpy as jnp
from flax import nnx
from mamba3_jax import Mamba3ForCausalLM, Mamba3Config

cfg = Mamba3Config.tiny()
model = Mamba3ForCausalLM(cfg, rngs=nnx.Rngs(0))

# Prefill
prompt = jnp.array([[1, 2, 3]], dtype=jnp.int32)
out = model(prompt)
cache = out["cache"]

# Decode with cache (O(1) per step)
next_token = jnp.array([[4]], dtype=jnp.int32)
out = model(next_token, cache=cache)
print(out["logits"].shape)  # (1, 1, vocab_size)
```

### Time-series forecasting

```python
import jax.numpy as jnp
from mamba3_jax import Mamba3Forecaster, create_random_forecaster

model = create_random_forecaster(input_dim=10, d_model=256, n_layers=4,
                                 output_dim=1, forecast_horizon=24)

x = jnp.ones((8, 100, 10))  # (batch, seq_len, features)
y = model(x)
print(y.shape)  # (8, 24, 1)
```

## Numerical Parity

Outputs are verified directly against the canonical `state-spaces/mamba` kernels (Triton for SISO, TileLang for MIMO) using identical random weights, across `ngroups ∈ {1, 2}`, `is_outproj_norm ∈ {False, True}`, and `rope_fraction ∈ {0.5, 1.0}`. Results match to ~1–2% mean relative error (the Triton SISO kernel casts its SSD core to bf16; MIMO runs fp32 throughout). The varlen path is verified against the kernel and via the identity `varlen(packed) == concat(per-sequence)`. Reproduce on any CUDA GPU with [`scripts/verify_mimo_kernel.py`](scripts/verify_mimo_kernel.py).
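
The `varlen(packed) == concat(per-sequence)` identity can be sketched on a toy model. The example below is an assumption-laden illustration in plain Python: it uses a scalar linear recurrence rather than the Mamba-3 layer, and `scan` and `scan_varlen` are hypothetical names, not functions from this package. The point is that the packed path must reset its state at every offset in `cu_seqlens`, otherwise state leaks across sequence boundaries.

```python
A, B = 0.9, 0.5  # arbitrary toy coefficients, not Mamba-3 parameters

def scan(xs):
    """Run the toy recurrence h_t = A * h_{t-1} + B * x_t over one sequence."""
    h, ys = 0.0, []
    for x in xs:
        h = A * h + B * x
        ys.append(h)
    return ys

def scan_varlen(packed, cu_seqlens):
    """Same recurrence over a packed buffer, resetting state at each sequence start."""
    starts = set(cu_seqlens[:-1])
    h, ys = 0.0, []
    for i, x in enumerate(packed):
        if i in starts:
            h = 0.0  # a new sequence begins here
        h = A * h + B * x
        ys.append(h)
    return ys

seqs = [[1.0, 2.0, 3.0], [-1.0, 0.5], [4.0]]
packed = [x for s in seqs for x in s]
cu_seqlens = [0, 3, 5, 6]  # cumulative offsets with a leading zero (assumed convention)

expected = [y for s in seqs for y in scan(s)]

print(scan_varlen(packed, cu_seqlens) == expected)  # True
print(scan(packed) == expected)                     # False: no reset, state leaks
```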

## Project Structure

```
mamba3-jax/
├── mamba3_jax/
│   ├── __init__.py          # Public API
│   ├── modeling.py          # Config, SSD kernels, all model classes
│   └── params.py            # Weight loading & parameter utilities
├── tests/
│   ├── test_mamba3.py       # Comprehensive test suite
│   ├── run_model.py         # Generation demo script
│   └── artifacts/           # Golden parity data & generator
├── scripts/
│   └── verify_mimo_kernel.py    # GPU parity check against the canonical kernels
├── LICENSE                  # Apache 2.0
├── pyproject.toml
└── README.md
```

## Contributing

Contributions are welcome! Areas where help is particularly valuable:

- Performance optimization and profiling
- Test coverage expansion
- Bug reports and feature requests

Please open an issue or submit a pull request on GitHub.

## Acknowledgments

**Original Mamba Authors:**
- Aakash Lahoti, Kevin Y. Li, Berlin Chen, Caitlin Wang, Aviv Bick, J. Zico Kolter, Tri Dao, and Albert Gu for the Mamba3 architecture and the original `mamba_ssm` implementation
- The State Spaces team for advancing SSM research

**JAX Ecosystem:**
- The JAX, Flax, and Optax teams at Google for the excellent frameworks

## License

This project is licensed under the [Apache License 2.0](LICENSE).

## Citation

If you use this implementation in your research, please cite the original Mamba3 paper and this JAX implementation:

```bibtex
@misc{lahoti2026mamba3,
  title={Mamba-3: Improved Sequence Modeling using State Space Principles},
  author={Aakash Lahoti and Kevin Y. Li and Berlin Chen and Caitlin Wang and Aviv Bick and J. Zico Kolter and Tri Dao and Albert Gu},
  year={2026},
  eprint={2603.15569},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2603.15569}
}

@software{mamba3jax,
  author  = {Cosmo Santoni},
  title   = {mamba3-jax: Pure JAX Implementation of Mamba3},
  year    = {2026},
  url     = {https://github.com/jax-state-spaces/mamba3-jax}
}
```
