Metadata-Version: 2.4
Name: mlx-scan
Version: 0.0.2
Summary: Associative scan implementation with MLX.
Project-URL: Homepage, https://github.com/jaketae/mlx-scan
Project-URL: Repository, https://github.com/jaketae/mlx-scan
Project-URL: Issues, https://github.com/jaketae/mlx-scan/issues
Author: Jake Tae
License-Expression: MIT
License-File: LICENSE
Keywords: associative-scan,mlx,parallel-scan,prefix-sum,scan
Classifier: Development Status :: 2 - Pre-Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Requires-Dist: mlx>=0.31
Provides-Extra: bench
Requires-Dist: jax>=0.6; extra == 'bench'
Provides-Extra: dev
Requires-Dist: pytest>=8; extra == 'dev'
Requires-Dist: ruff>=0.11; extra == 'dev'
Description-Content-Type: text/markdown

# mlx-scan

Associative scan implementation with MLX.

## Quickstart

To install:

```bash
uv add mlx-scan       # uv
pip install mlx-scan  # pip
```

For local development:

```bash
uv sync --extra dev --extra bench
```

To test and lint:

```bash
uv run pytest
uv run ruff check .
```

## Associative Scan

An inclusive scan combines every prefix of a sequence with a binary operator `(*)`:

```text
[a, b, c, d] -> [a, a (*) b, a (*) b (*) c, a (*) b (*) c (*) d]
```
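For reference, this is exactly what a sequential left fold produces if it records every intermediate result. A stdlib-only sketch using `itertools.accumulate` (for illustration; `sequential_scan` is a hypothetical helper, not part of `mlx_scan`). String concatenation is included because it is associative but not commutative, so it makes the operand order visible:

```python
from itertools import accumulate
from operator import add


def sequential_scan(op, xs):
    """Inclusive scan: out[i] = xs[0] (*) xs[1] (*) ... (*) xs[i], in order."""
    # accumulate calls op(running_result, next_element), so earlier items
    # are always the left operand.
    return list(accumulate(xs, op))


print(sequential_scan(add, [1, 2, 3, 4]))          # [1, 3, 6, 10]
print(sequential_scan(add, ["a", "b", "c", "d"]))  # ['a', 'ab', 'abc', 'abcd']
```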

If the binary operator `(*)` is associative, the prefixes can be computed with parallel-prefix algorithms such as [Hillis-Steele](https://rsim.cs.illinois.edu/arch/qual_papers/systems/3.pdf) or [Blelloch](https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf) scan, which need only O(log n) parallel steps instead of n - 1 sequential ones. Hillis-Steele performs O(n log n) operator applications; Blelloch is work-efficient, performing O(n) applications at the cost of roughly twice as many steps. This matters for more than sums: [matrix products](https://docs.jax.dev/en/latest/_autosummary/jax.lax.associative_scan.html), [affine recurrences](https://arxiv.org/abs/1709.04057), and [sequence-model operations](https://arxiv.org/abs/2312.00752) can be expressed with scan-like computations.
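A minimal, stdlib-only sketch of the Hillis-Steele scheme over Python lists (an illustration of the textbook algorithm, not the `mlx_scan` implementation; `hillis_steele_scan` is a hypothetical name). Each stage is written as a slice, an elementwise operator application, and a concatenation, which mirrors how an array version might be expressed:

```python
from operator import add


def hillis_steele_scan(op, xs):
    """Inclusive Hillis-Steele scan over a Python list.

    Each stage combines every element with the element `offset` positions to
    its left, then doubles `offset`, so there are ceil(log2(n)) stages. All
    applications within one stage are independent and could run in parallel.
    """
    xs = list(xs)
    offset = 1
    while offset < len(xs):
        # The first `offset` entries are already final. For the rest, the
        # earlier partial result is the left operand, preserving order.
        combined = [op(left, right) for left, right in zip(xs[:-offset], xs[offset:])]
        xs = xs[:offset] + combined
        offset *= 2
    return xs


print(hillis_steele_scan(add, [1, 2, 3, 4]))  # [1, 3, 6, 10]
print(hillis_steele_scan(add, list("abcde")))  # ['a', 'ab', 'abc', 'abcd', 'abcde']
```

Every stage touches almost every element, which is where the roughly n·log2(n) operator applications come from.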

## Example

The canonical example is a cumulative sum:

```python
import mlx.core as mx
from mlx_scan import associative_scan

x = mx.array([1, 2, 3, 4])
associative_scan(lambda a, b: a + b, x)
# array([1, 3, 6, 10], dtype=int32)
```

`associative_scan` accepts an MLX array or a pytree of MLX arrays.
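Conceptually, a pytree input is scanned with all leaves in lockstep: element `i` is the tuple of every leaf's `i`-th entry, and the operator receives and returns tuples of that shape. The stdlib-only sketch below assumes this JAX-style convention (the exact `mlx_scan` contract is not spelled out here) and uses a hypothetical `scan_pytree` helper to compute a running maximum together with its index, which needs two leaves scanned jointly:

```python
from itertools import accumulate


def max_with_index(left, right):
    """Combine two (value, index) pairs, keeping the earliest maximum."""
    # Ties keep `left`, which always comes earlier in the sequence. This
    # "leftmost maximum" rule is associative.
    return right if right[0] > left[0] else left


def scan_pytree(op, leaves):
    """Hypothetical helper: scan a tuple of equal-length sequences as one
    sequence of tuples (sequentially), then unzip back into one list per leaf."""
    scanned = list(accumulate(zip(*leaves), op))
    return tuple(list(leaf) for leaf in zip(*scanned))


values = [3, 1, 4, 1, 5, 9, 2, 6]
indices = list(range(len(values)))
running_max, running_argmax = scan_pytree(max_with_index, (values, indices))
print(running_max)     # [3, 3, 4, 4, 5, 9, 9, 9]
print(running_argmax)  # [0, 0, 2, 2, 4, 5, 5, 5]
```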

## Performance

This implementation is useful as a generic, MLX-native associative scan, but it is not a replacement for a compiler-native scan primitive. Benchmarks were run on macOS 14.2.1 (arm64) with Python 3.10.12, MLX 0.31.2, and JAX 0.6.2 (CPU-only).

For addition, the built-in `mx.cumsum` is roughly 3× to 16× faster than either generic MLX scan. Compiled CPU median times in milliseconds:

| shape           | MLX Hillis-Steele | MLX Blelloch | MLX cumsum | JAX associative_scan |
| --------------- | ----------------: | -----------: | ---------: | -------------------: |
| 1024            |             0.182 |        0.360 |      0.023 |                0.010 |
| 65536           |             0.536 |        0.481 |      0.080 |                0.149 |
| 512x1024 axis=1 |             3.145 |        1.481 |      0.483 |                0.823 |

For an SSM/linear-RNN style affine recurrence, there is no `mx.cumsum` equivalent. Each scan element is a pair `(a, b)` representing the map `h -> a*h + b`, and two such maps compose associatively, with the left map applied first:

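```python
def compose(left, right):
    # Apply `left` first, then `right`: h -> a_right * (a_left * h + b_left) + b_right
    a_left, b_left = left
    a_right, b_right = right
    return a_right * a_left, a_right * b_left + b_right
```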

Compiled CPU median times in milliseconds:

| shape          | MLX Hillis-Steele | MLX Blelloch | JAX associative_scan |
| -------------- | ----------------: | -----------: | -------------------: |
| 1024           |             0.255 |        0.541 |                0.035 |
| 8192           |             0.345 |        0.929 |                0.106 |
| 128x256 axis=1 |             0.471 |        0.534 |                0.228 |
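The `compose` rule can be checked against the sequential recurrence it replaces. In this stdlib-only sketch (scalars instead of MLX arrays; `affine_scan` and `affine_loop` are hypothetical helpers, and the initial state is assumed to be 0), scanning the pairs yields at position `t` a single map `(A_t, B_t)` equivalent to steps `0..t`, so its `b` component is the state:

```python
from itertools import accumulate


def compose(left, right):
    """Compose affine maps: apply `left` (h -> a*h + b) first, then `right`."""
    a_left, b_left = left
    a_right, b_right = right
    return a_right * a_left, a_right * b_left + b_right


def affine_scan(decays, updates):
    """h_t for h_t = decays[t] * h_{t-1} + updates[t], assuming h_{-1} = 0.

    The scanned map (A_t, B_t) applied to h_{-1} = 0 is just B_t.
    """
    return [b for _, b in accumulate(zip(decays, updates), compose)]


def affine_loop(decays, updates, h=0.0):
    """Sequential reference: the loop the scan replaces."""
    out = []
    for a, b in zip(decays, updates):
        h = a * h + b
        out.append(h)
    return out


decays = [0.5, 2.0, -1.0, 3.0]
updates = [1.0, 1.0, 2.0, -4.0]
print(affine_scan(decays, updates))  # [1.0, 3.0, -1.0, -7.0]
print(affine_loop(decays, updates))  # [1.0, 3.0, -1.0, -7.0]
```

For a nonzero initial state `h`, the result at position `t` is `A_t * h + B_t`, so the first component of the scanned pair is needed as well.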

Once compiled, the main bottleneck is not Python loop overhead but the graph produced by expressing the scan through repeated `slice`, `concat`, `stack`, `reshape`, and user-operator applications. MLX evaluates these composed graphs well enough to make the implementation usable, but it does not appear to fuse them as effectively as XLA fuses the comparable slicing-based graph that `jax.lax.associative_scan` builds.

For a Mamba-style prefill recurrence, adapted from the SSM loop in [`mlx-lm`'s `mamba.py`](https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/mamba.py), the isolated recurrence `state_t = decay_t * state_{t-1} + update_t` favors the work-efficient Blelloch scan as sequences grow: Hillis-Steele is slightly faster at 128 steps, but Blelloch is about 2× faster at 512 and about 5× faster at 2048, and both beat the sequential loop. With shape `(batch=1, sequence, channels=256, state=16)`, compiled GPU median times in milliseconds:

| sequence |   loop | Hillis-Steele | Blelloch |
| -------- | -----: | ------------: | -------: |
| 128      |  2.654 |         0.888 |    0.995 |
| 512      |  8.469 |         1.973 |    0.899 |
| 1024     | 18.481 |         5.152 |    1.665 |
| 2048     | 43.396 |        10.206 |    1.898 |
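One way to see why Blelloch pulls ahead as sequences grow is to count operator applications. The stdlib-only sketch below implements the textbook up-sweep/down-sweep (an illustration, not the `mlx_scan` code; `blelloch_scan` and `hillis_steele_calls` are hypothetical names). It assumes the operator has an identity element, which the textbook formulation needs for padding to a power of two and for seeding the down-sweep:

```python
from operator import add


def blelloch_scan(op, xs, identity):
    """Inclusive scan via Blelloch's up-sweep/down-sweep.

    Assumes `op` has an `identity` element; pads the input to a power of two.
    Returns (inclusive_scan, number_of_op_applications).
    """
    n = len(xs)
    size = 1
    while size < n:
        size *= 2
    tree = list(xs) + [identity] * (size - n)
    calls = 0

    # Up-sweep: tree[right] becomes the reduction of the 2*stride block ending there.
    stride = 1
    while stride < size:
        for right in range(2 * stride - 1, size, 2 * stride):
            tree[right] = op(tree[right - stride], tree[right])
            calls += 1
        stride *= 2

    # Down-sweep: replace block reductions with exclusive prefixes.
    total = tree[size - 1]
    tree[size - 1] = identity
    stride = size // 2
    while stride >= 1:
        for right in range(2 * stride - 1, size, 2 * stride):
            left = right - stride
            left_reduction = tree[left]
            tree[left] = tree[right]  # prefix before the whole block
            tree[right] = op(tree[right], left_reduction)  # prefix before right half
            calls += 1
        stride //= 2

    # tree now holds the exclusive scan; shifting left by one (and appending
    # the total) gives the inclusive scan without extra op applications.
    inclusive = (tree + [total])[1 : n + 1]
    return inclusive, calls


def hillis_steele_calls(n):
    """Operator applications Hillis-Steele makes on n elements."""
    calls, offset = 0, 1
    while offset < n:
        calls += n - offset
        offset *= 2
    return calls


scanned, blelloch_calls = blelloch_scan(add, list(range(1, 1025)), 0)
print(scanned[-1], blelloch_calls, hillis_steele_calls(1024))  # 524800 2046 9217
```

At n = 1024 this is 2046 applications for Blelloch versus 9217 for Hillis-Steele. Blelloch pays for this with about twice as many sequential stages, so on short sequences the extra depth can outweigh the saved work, which is consistent with Hillis-Steele winning at sequence length 128 above.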

The benchmark tables above were generated with:

```bash
uv run python benchmarks/bench_scan.py --compiled --op add --device cpu \
  --include-jax \
  --algorithm hillis_steele \
  --algorithm blelloch \
  --algorithm mlx_cumsum \
  --algorithm jax_lax_associative_scan \
  --runs 50 --warmup 10

uv run python benchmarks/bench_scan.py --compiled --op affine --device cpu \
  --include-jax \
  --algorithm hillis_steele \
  --algorithm blelloch \
  --algorithm jax_lax_associative_scan \
  --runs 50 --warmup 10

uv run python benchmarks/bench_mamba_prefill.py --compiled --device gpu \
  --runs 10 --warmup 3
```

## References

```bibtex
@article{HillisSteele1986,
  author = {Hillis, W. Daniel and Steele, Guy L., Jr.},
  title = {Data Parallel Algorithms},
  journal = {Communications of the ACM},
  volume = {29},
  number = {12},
  pages = {1170--1183},
  year = {1986},
  doi = {10.1145/7902.7903}
}
```

```bibtex
@incollection{Blelloch1993,
  author = {Blelloch, Guy E.},
  title = {Prefix Sums and Their Applications},
  booktitle = {Synthesis of Parallel Algorithms},
  publisher = {Morgan Kaufmann},
  year = {1993}
}
```

```bibtex
@inproceedings{MartinCundy2018,
  author = {Martin, Eric and Cundy, Chris},
  title = {Parallelizing Linear Recurrent Neural Nets Over Sequence Length},
  booktitle = {International Conference on Learning Representations},
  year = {2018}
}
```

```bibtex
@article{GuDao2023,
  author = {Gu, Albert and Dao, Tri},
  title = {Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
  journal = {arXiv preprint arXiv:2312.00752},
  year = {2023}
}
```
