Metadata-Version: 2.4
Name: zmlx
Version: 0.4.2
Summary: ZMLX: Triton for Apple Silicon. Write custom Metal GPU kernels for MLX with one-line ergonomics, automatic gradients, and built-in autotuning.
Project-URL: Homepage, https://github.com/Hmbown/ZMLX
Project-URL: Repository, https://github.com/Hmbown/ZMLX
Project-URL: Documentation, https://github.com/Hmbown/ZMLX#readme
Project-URL: Issues, https://github.com/Hmbown/ZMLX/issues
Project-URL: Changelog, https://github.com/Hmbown/ZMLX/blob/main/CHANGELOG.md
Author: Hunter Bown
License: MIT
License-File: LICENSE
Keywords: apple-silicon,autograd,jit,kernels,metal,mlx,zmlx
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries
Requires-Python: >=3.10
Requires-Dist: mlx>=0.30.0
Provides-Extra: dev
Requires-Dist: mypy>=1.10.0; extra == 'dev'
Requires-Dist: numpy>=1.24.0; extra == 'dev'
Requires-Dist: pytest>=8.0.0; extra == 'dev'
Requires-Dist: pyyaml>=6.0; extra == 'dev'
Requires-Dist: ruff>=0.5.0; extra == 'dev'
Requires-Dist: types-pyyaml>=6.0; extra == 'dev'
Provides-Extra: train
Requires-Dist: huggingface-hub>=0.20.0; extra == 'train'
Requires-Dist: mlx-lm>=0.25.0; extra == 'train'
Requires-Dist: pyyaml>=6.0; extra == 'train'
Requires-Dist: transformers>=4.40.0; extra == 'train'
Description-Content-Type: text/markdown

# ZMLX — Triton for Apple Silicon

[![PyPI](https://img.shields.io/pypi/v/zmlx.svg)](https://pypi.org/project/zmlx/)
[![Python 3.10+](https://img.shields.io/badge/python-3.10%2B-blue.svg)](https://www.python.org/downloads/)
[![License: MIT](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE)
[![Platform: macOS Apple Silicon](https://img.shields.io/badge/platform-macOS%20Apple%20Silicon-lightgrey.svg)](https://github.com/ml-explore/mlx)

**The Triton-like toolkit for [MLX](https://github.com/ml-explore/mlx)** — write custom Metal GPU kernels from Python with one-line ergonomics, automatic gradients, and built-in autotuning. No raw Metal, no manual threadgroups, no boilerplate.

> **+33% decode** on Qwen3-32B (dense) and **+51% prompt / +36% decode** on Qwen3-30B-A3B (MoE) with fused kernel patches. [Benchmarks](#model-level-inference)

```bash
pip install zmlx
```

**Speed up your model in 3 lines:**

```python
import mlx_lm
from zmlx.patch import patch, FUSED_ACTIVATIONS

model, tokenizer = mlx_lm.load("mlx-community/Qwen3-30B-A3B-4bit")
patch(model, patterns=FUSED_ACTIVATIONS)  # +51% prompt, +36% decode (fused-activations preset)
```

**Or write custom GPU kernels in one line:**

```python
from zmlx.api import elementwise
import mlx.core as mx

# Math formula → compiled Metal kernel → runs on GPU
mish = elementwise("x * tanh(log(1 + exp(x)))", name="mish")
y = mish(mx.random.normal((1024,)))
```

---

## What's New in v0.4.0

- **1.33x decode on 32B dense, 1.51x prompt / 1.36x decode on 30B MoE** — fused residual+RMSNorm and fused MoE gating patches ([benchmarks](#model-level-inference))
- **MoE patch** — fused `top2_gating_softmax` + `moe_combine` eliminates multiple memory round-trips in expert routing
- **High-level API** — `elementwise()`, `reduce()`, `map_reduce()` for kernel authoring in one line
- **JIT compiler** — `@jit` decorator compiles Python scalar expressions to Metal
- **Testing & benchmarking** — `zmlx.testing.assert_matches()`, `zmlx.bench.compare()` for correctness verification and side-by-side timing
- **Profiling** — `zmlx.profile.time_kernel()`, `dump_msl()`, `kernel_stats()` for introspection
- **Training pipeline** — `zmlx train` CLI for LoRA fine-tuning with ZMLX patches
- **Smart patching** — `smart_patch()` auto-benchmarks each pattern and keeps only what helps
- **Fused AdamW** — single-kernel optimizer step reducing memory bandwidth
- **Paged attention** — `zmlx.nn.PagedAttention` for high-throughput serving
- **70+ kernel catalog** — activations, norms, RoPE, attention, MoE, quantization, loss, bit ops

---

## Why ZMLX?

When you need a custom GPU op on Apple Silicon, your options today are:
1. Write raw Metal source strings, manage caching, figure out threadgroups, wire up autodiff manually
2. Use ZMLX

ZMLX wraps `mx.fast.metal_kernel` and `mx.custom_function` to provide **Triton-like ergonomics**:

- **One-line kernel authoring** — define elementwise, reduction, and map-reduce ops from C expressions
- **Automatic gradients** — custom VJP backward passes (themselves Metal kernels) via `mx.custom_function`
- **Define-once caching** — kernels compile once, reused by source hash + config
- **Autotuning** — threadgroup size search with persistent caching
- **Testing & benchmarking** — verify against reference ops, compare timings side-by-side
- **Model patching** — swap MLX layers for fused ZMLX kernels with `patch(model)`

---

## Install

**Requirements**: macOS (Apple Silicon), Python >= 3.10, mlx >= 0.30.0

```bash
pip install zmlx
```

From source (development):

```bash
git clone https://github.com/Hmbown/ZMLX.git
cd ZMLX
pip install -e ".[dev]"
```

---

## Quick Start

### 1. Custom elementwise kernel

```python
from zmlx.api import elementwise
import mlx.core as mx

# Non-differentiable — just forward pass
fast_exp = elementwise("metal::exp(x)", name="fast_exp")
y = fast_exp(mx.random.normal((1024,)))

# Differentiable — with custom VJP
from zmlx import msl

silu = elementwise(
    "kk_silu(x)",
    name="my_silu",
    grad_expr="g * (s + x * s * ((T)1 - s))",
    grad_prelude="T s = kk_sigmoid(x);",
    use_output=False,
    header=msl.DEFAULT_HEADER,
)
gx = mx.grad(lambda z: silu(z).sum())(mx.random.normal((1024,)))
```
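
The `grad_expr` above is the analytic derivative of SiLU, `s + x * s * (1 - s)` with `s = sigmoid(x)`, scaled by the upstream gradient `g`. A quick way to convince yourself of the formula without a GPU is to compare it with a central finite difference. The sketch below is plain Python. Its helper names are illustrative and not part of ZMLX, and it assumes `kk_sigmoid` is the standard logistic sigmoid.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def silu(x: float) -> float:
    return x * sigmoid(x)


def silu_grad(x: float, g: float = 1.0) -> float:
    # Mirrors grad_prelude ("T s = kk_sigmoid(x);") followed by grad_expr.
    s = sigmoid(x)
    return g * (s + x * s * (1.0 - s))


def finite_diff(f, x: float, eps: float = 1e-5) -> float:
    # Central difference approximation of f'(x).
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


for x in (-3.0, -0.5, 0.0, 0.7, 4.0):
    assert abs(silu_grad(x) - finite_diff(silu, x)) < 1e-6
```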

### 2. Custom reduction

```python
from zmlx.api import reduce
import mlx.core as mx

my_sum = reduce(init="0.0f", update="acc + v", name="row_sum")
y = my_sum(mx.random.normal((8, 1024)))  # shape (8,)
```
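
One way to read the `init`/`update` pair is as a left fold over the last axis: the accumulator `acc` starts at `init` and is updated once per element `v`. The pure-Python model below assumes that reading. `rowwise_reduce` is a hypothetical helper, not a ZMLX function, and the real kernel reduces in parallel on the GPU rather than sequentially.

```python
from typing import Callable, List, Sequence


def rowwise_reduce(
    rows: Sequence[Sequence[float]],
    init: float,
    update: Callable[[float, float], float],
) -> List[float]:
    """Fold each row with `update`, starting from `init`."""
    out = []
    for row in rows:
        acc = init
        for v in row:
            acc = update(acc, v)  # same role as update="acc + v"
        out.append(acc)
    return out


rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
row_sums = rowwise_reduce(rows, init=0.0, update=lambda acc, v: acc + v)
print(row_sums)  # one value per row
```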

### 3. Two-pass map-reduce (softmax pattern)

```python
from zmlx.api import map_reduce
import mlx.core as mx

my_softmax = map_reduce(
    pass1={"init": "-INFINITY", "update": "max(acc1, x)", "reduce": "max(a, b)"},
    pass2={"init": "0.0f", "update": "acc2 + exp(x - s1)", "reduce": "a + b"},
    write="exp(x - s1) / s2",
    name="my_softmax",
)
y = my_softmax(mx.random.normal((8, 1024)))
```
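
The two passes can be modelled in plain Python to show what each expression contributes. This sketch assumes `s1` and `s2` name the reduced results of pass 1 and pass 2, as the expressions above suggest. It is a sequential reference, not the generated Metal code, and `two_pass_softmax` is a hypothetical name.

```python
import math
from typing import List, Sequence


def two_pass_softmax(row: Sequence[float]) -> List[float]:
    # Pass 1: running maximum (init "-INFINITY", update "max(acc1, x)").
    s1 = -math.inf
    for x in row:
        s1 = max(s1, x)

    # Pass 2: sum of shifted exponentials
    # (init "0.0f", update "acc2 + exp(x - s1)").
    s2 = 0.0
    for x in row:
        s2 += math.exp(x - s1)

    # Write: per-element output expression "exp(x - s1) / s2".
    return [math.exp(x - s1) / s2 for x in row]


probs = two_pass_softmax([1000.0, 1001.0, 1002.0])
```

Subtracting the row maximum before exponentiating is what keeps inputs like `1000.0` from overflowing, which is why the maximum needs its own pass.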

### 4. Test and benchmark your kernel

```python
import zmlx
import mlx.core as mx

# Verify correctness against MLX's softmax (my_softmax is the kernel from step 3)
zmlx.testing.assert_matches(
    my_softmax, lambda x: mx.softmax(x, axis=-1),
    shapes=[(8, 1024), (32, 4096)],
)

# Benchmark
zmlx.bench.compare(
    {"ZMLX": my_softmax, "MLX": lambda x: mx.softmax(x, axis=-1)},
    shapes=[(1024, 4096), (4096, 4096)],
)
```

### 5. Lower-level building blocks

```python
from zmlx import autograd, elementwise, msl
import mlx.core as mx

# Unary kernel (no gradient).
# compute_dtype is omitted: it is deprecated and kernels always compute in float32.
exp_kern = elementwise.unary(
    name="kk_exp", expr="metal::exp(x)",
    header=msl.DEFAULT_HEADER,
)

# Binary kernel with custom VJP
mul_op = autograd.binary_from_expr(
    name="safe_mul", fwd_expr="a * b",
    vjp_lhs_expr="g * b", vjp_rhs_expr="g * a",
)
```

---

## Kernel Catalog

ZMLX includes 70+ kernels organized by domain. Some are genuinely useful for custom workloads (loss, GLU fusions, bit ops, MoE gating). Others are **reference implementations** showing codegen patterns — correct but not faster than MLX built-ins for standard transformer shapes.

Full reference: [`docs/KERNELS.md`](docs/KERNELS.md).

| Module | Highlights |
|:---|:---|
| `loss` | `softmax_cross_entropy` — memory-efficient fused loss |
| `transformer` | `swiglu`, `geglu`, `rmsnorm_residual` (with full weight gradients), `dropout` — genuine fusions |
| `bits` | `pack_bits`, `unpack_bits` — no MLX equivalent |
| `moe` | `top2_gating_softmax`, `moe_dispatch`, `moe_combine` — fused expert routing (+36% decode on 30B MoE) |
| `quant` | FP8 (E4M3/E5M2), NF4, int8, int4 dequantization — real bit-manipulation kernels |
| `optimizers` | `adamw_step` — fused AdamW parameter update in a single kernel |
| `scan` | `cumsum_lastdim` — differentiable prefix sum |
| `norms` | `rmsnorm`, `layernorm` — parallel reduction. All norms compute in float32 internally |
| `softmax` | `softmax_lastdim` — map-reduce codegen showcase |
| `rope` | `apply_rope`, `apply_rope_interleaved`, `apply_gqa_rope` |
| `linear` | Reference fused-linear patterns (naive matmul, not for production) |

---

## Architecture

Three-layer design. Full details: [`docs/ARCHITECTURE.md`](docs/ARCHITECTURE.md).

1. **Metal kernel infrastructure** — `MetalKernel` wrapper, in-process cache, stats tracking
2. **Code generation & helpers** — MSL templates, elementwise/autograd/rowwise APIs, autotuning
3. **Kernel catalog** — domain modules built on layers 1 and 2

---

## Benchmarks

### Op-level (B=16, S=1024, D=1024, float32, M4 Max)

Run `python benchmarks/microbench.py` to reproduce on your hardware.

| Operation | MLX | ZMLX | Speedup |
|:--|--:|--:|:--|
| **SwiGLU** | 0.85 ms | **0.40 ms** | **2.1x** |
| **Dropout** | 3.12 ms | **0.38 ms** | **8.2x** |
| **Top-K** | 1.82 ms | **0.49 ms** | **3.7x** |
| **Gather-Add** | 0.54 ms | **0.41 ms** | **1.3x** |
| Softmax | 0.36 ms | 0.41 ms | 0.90x |
| RMSNorm | 0.37 ms | 0.41 ms | 0.90x |
| Sum | 0.19 ms | 0.36 ms | 0.53x |
| CumSum | 0.30 ms | 0.59 ms | 0.51x |

ZMLX is fastest on **fused operations** that MLX doesn't provide as single ops (SwiGLU, fused-RNG dropout, fused gather-add), with speedups of 1.3x to 8.2x in the table above. MLX's built-in operations (`mx.fast.rms_norm`, `mx.softmax`, reductions) are already highly optimized: the ZMLX versions run at 0.51x to 0.90x of their speed, so the built-ins should not be replaced.

### Model-level inference

LLM inference is **memory-bandwidth-bound**, so fused kernels pay off on large models, where each saved memory round-trip matters. In the measurements below the gain grows with model size: the 1B and 8B models are roughly neutral (0.95x to 1.01x), while the 32B dense and 30B MoE models decode 1.33x and 1.36x faster.

**Qwen3-32B-4bit (64 layers, ~18 GB)** — M4 Max, 36 GB

| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline |
|:--|--:|--:|:--|
| Baseline (MLX) | 107 | 13.5 | — |
| ZMLX fused activations | 108 | 14.2 | 1.01x / **1.05x** |
| ZMLX all patches | **127** | **18.0** | **1.19x / 1.33x** |

> **+33% decode throughput** on a 32B model — 64 layers of fused residual+RMSNorm, each saving a full memory round-trip.
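
As a rough worked example of what "bandwidth-bound" means here: if each decoded token has to stream roughly the full ~18 GB of weights from memory (a simplifying assumption that ignores the KV cache and activations), then tokens per second times model size gives an effective bandwidth estimate. The figures come from the table above, and the helper name is illustrative.

```python
def effective_bandwidth_gb_s(decode_tok_s: float, model_gb: float) -> float:
    # Assumption: every decoded token reads all model weights exactly once.
    return decode_tok_s * model_gb


MODEL_GB = 18.0  # approximate size of Qwen3-32B-4bit, from the heading above

baseline = effective_bandwidth_gb_s(13.5, MODEL_GB)  # MLX baseline decode
patched = effective_bandwidth_gb_s(18.0, MODEL_GB)   # ZMLX all patches decode
speedup = 18.0 / 13.5

print(f"baseline ~{baseline:.0f} GB/s, all patches ~{patched:.0f} GB/s, {speedup:.2f}x")
```

The estimate only illustrates scale. It is not a measurement of hardware bandwidth.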

**Qwen3-30B-A3B-4bit (MoE, 48 layers, 3B active/30B total)** — `python benchmarks/inference_benchmark.py --models qwen3-30b-a3b --selective`

| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline |
|:--|--:|--:|:--|
| Baseline (MLX) | 1,083 | 116 | — |
| ZMLX fused activations | **1,635** | **158** | **1.51x / 1.36x** |

> **+36% decode throughput** on MoE models — fused gating (`top2_gating_softmax`) and combine (`moe_combine`) kernels eliminate multiple memory round-trips in the expert routing path.

**Qwen3-8B-4bit (32 layers, ~5 GB)** — `python benchmarks/inference_benchmark.py --models qwen3-8b --selective`

| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline |
|:--|--:|--:|:--|
| Baseline (MLX) | 676 | 75 | — |
| ZMLX fused activations | 675 | 76 | 1.00x / 1.01x |

**Llama-3.2-1B-Instruct-4bit (16 layers, ~0.8 GB)** — `python benchmarks/llama_benchmark.py`

| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline |
|:--|--:|--:|:--|
| Baseline (MLX) | 3,913 | 377 | — |
| ZMLX fused activations | 3,804 | 378 | 0.97x / 1.00x |
| ZMLX all patches | 3,705 | 366 | 0.95x / 0.97x |

**When do patches help?**
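- **Large Dense Models (8B+)**: Use **all patches**. These models are bandwidth-bound, so fused residual+RMSNorm pays off (+33% decode on Qwen3-32B).
- **MoE Models**: Use **fused activations** (`--selective` in the benchmark script, `FUSED_ACTIVATIONS` in code). The `moe_mlp` patch delivers +36% decode (+51% prompt) on Qwen3-30B-A3B.
- **Small Models (< 3B)**: Neutral to slightly negative (0.95x to 1.00x on Llama-3.2-1B). Kernel overhead often outweighs fusion gains, so use `smart_patch` to keep only the patterns that help.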

```python
from zmlx.patch import smart_patch
import mlx.core as mx

# Auto-benchmark each pattern, keep only what helps
sample = mx.array([tokenizer.encode("Hello")])
model = smart_patch(model, sample)
```

Or use presets if you know your workload:

```python
from zmlx.patch import patch, FUSED_ACTIVATIONS, TRAINING_RECOMMENDED
patch(model)                                   # large models (8B+): all patches
patch(model, patterns=FUSED_ACTIVATIONS)       # MoE/small: activations + MoE gating
patch(model, patterns=TRAINING_RECOMMENDED)    # training: activations + norms
```

### Smart patching

`smart_patch` applies each candidate pattern, benchmarks the model's forward pass, and **automatically reverts patterns that make things slower**. It supports custom forward functions for realistic benchmarks:

```python
import mlx_lm
from zmlx.patch import smart_patch

# `model`, `tokenizer`, and `sample` come from the snippets above.

# Basic: benchmark raw forward pass
model = smart_patch(model, sample)

# Advanced: benchmark with actual generation
def gen_fn(model, sample):
    return mlx_lm.generate(model, tokenizer, prompt="Hello", max_tokens=20)

model = smart_patch(model, sample, forward_fn=gen_fn, threshold=0.99)

# Result includes per-pattern speedups
result = model._zmlx_patch_result
print(result.benchmarks)    # {'swiglu_mlp': 1.012, 'residual_norm': 0.971}
print(result.summary())     # what was kept and why
```
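
The apply, benchmark, revert loop described above can be sketched without a model. This is a simplified illustration, not `smart_patch` itself: it assumes patterns are tried one at a time and that a pattern is kept when its measured speedup is at least `threshold`. `select_patterns`, `fake_time`, and the timings are hypothetical.

```python
from typing import Callable, Dict, List, Tuple


def select_patterns(
    patterns: List[str],
    time_with: Callable[[List[str]], float],
    threshold: float = 1.0,
) -> Tuple[List[str], Dict[str, float]]:
    """Greedily keep each pattern whose measured speedup meets `threshold`."""
    kept: List[str] = []
    speedups: Dict[str, float] = {}
    current = time_with(kept)  # baseline time with nothing applied
    for name in patterns:
        trial = time_with(kept + [name])
        speedups[name] = current / trial
        if speedups[name] >= threshold:
            kept.append(name)  # keep the pattern
            current = trial
        # otherwise the pattern is dropped ("reverted")
    return kept, speedups


# Fake per-pattern time deltas (seconds) standing in for real benchmarks.
FAKE_COST = {"swiglu_mlp": -0.012, "residual_norm": +0.030}


def fake_time(applied: List[str]) -> float:
    return 1.0 + sum(FAKE_COST[p] for p in applied)


kept, speedups = select_patterns(
    ["swiglu_mlp", "residual_norm"], fake_time, threshold=0.99
)
```

With these made-up timings, `swiglu_mlp` is kept and `residual_norm` is dropped because it slows the forward pass by more than the 1% tolerance.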

### Autotuning

Replacement modules support `threadgroup="auto"` to search for the best threadgroup size on first invocation:

```python
from zmlx.patch import patch
patch(model, threadgroup="auto")  # autotunes each kernel on first call
```

The `map_reduce()` API also supports autotuning:

```python
from zmlx.api import map_reduce
my_softmax = map_reduce(..., threadgroup="auto")  # autotunes per-shape
```
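
Conceptually, a threadgroup-size search tries a handful of candidates, times each, and remembers the winner. The sketch below shows that idea with an in-process dictionary. It is not ZMLX's tuner: the candidate list, the `(kernel name, shape)` cache key, and all names are assumptions, and a persistent cache would also write the result to disk.

```python
from typing import Callable, Dict, Iterable, Tuple

# Hypothetical cache: (kernel name, input shape) -> best threadgroup size.
_TUNE_CACHE: Dict[Tuple[str, Tuple[int, ...]], int] = {}


def autotune_threadgroup(
    kernel_name: str,
    shape: Tuple[int, ...],
    measure: Callable[[int], float],
    candidates: Iterable[int] = (32, 64, 128, 256, 512, 1024),
) -> int:
    """Pick the candidate with the lowest measured time, once per (kernel, shape)."""
    key = (kernel_name, shape)
    if key not in _TUNE_CACHE:
        _TUNE_CACHE[key] = min(candidates, key=measure)
    return _TUNE_CACHE[key]


calls = []


def fake_measure(tg: int) -> float:
    calls.append(tg)  # record each timing run
    return abs(tg - 256) + 1.0  # pretend 256 threads is fastest


best = autotune_threadgroup("my_softmax", (8, 1024), fake_measure)
again = autotune_threadgroup("my_softmax", (8, 1024), fake_measure)  # cache hit
```

The second call returns immediately from the cache, which is why the search cost is paid only on first invocation.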

### Where ZMLX genuinely helps

- **Large model inference** — 1.33x decode on 32B dense (fused residual+norm), 1.51x prompt / 1.36x decode on 30B MoE (fused gating+combine)
- **Custom ops that MLX doesn't have** — SwiGLU, GeGLU, fused dropout, fused MoE gating, bit packing
- **Training** — fused `softmax_cross_entropy` loss, correct weight gradients for `rmsnorm_residual`
- **Authoring new kernels** — the `elementwise()`, `reduce()`, and `map_reduce()` APIs let you go from math formula to compiled Metal kernel in one line
- **Quantization** — FP8 (E4M3/E5M2), NF4, int8, int4 dequantization with real bit-manipulation kernels

---

## Precision

All ZMLX Metal kernels compute internally in **float32** regardless of input dtype. The `compute_dtype` parameter accepted by many kernel functions is **deprecated** and will be removed in a future release. Passing a non-None value will emit a `DeprecationWarning`.
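
To find leftover `compute_dtype` arguments before the parameter is removed, you can promote `DeprecationWarning` to an error with Python's standard `warnings` module. The sketch below uses a hypothetical stand-in function, `example_kernel`, rather than a real ZMLX kernel, so it runs anywhere.

```python
import warnings
from typing import Optional


def example_kernel(x: float, compute_dtype: Optional[str] = None) -> float:
    """Hypothetical stand-in for a kernel function with a deprecated argument."""
    if compute_dtype is not None:
        warnings.warn(
            "compute_dtype is deprecated; kernels always compute in float32",
            DeprecationWarning,
            stacklevel=2,
        )
    return x * 2.0


# Turn the warning into an error to locate leftover compute_dtype arguments.
with warnings.catch_warnings():
    warnings.simplefilter("error", DeprecationWarning)
    example_kernel(1.0)  # fine: no compute_dtype passed
    try:
        example_kernel(1.0, compute_dtype="float16")
        raised = False
    except DeprecationWarning:
        raised = True
```

With pytest, the equivalent is running the suite with `-W error::DeprecationWarning`.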

---

## Documentation

- [`docs/QUICKSTART.md`](docs/QUICKSTART.md) — 5-minute tutorial
- [`docs/COOKBOOK.md`](docs/COOKBOOK.md) — Recipes for common patterns
- [`docs/KERNELS.md`](docs/KERNELS.md) — Complete kernel catalog reference
- [`docs/ARCHITECTURE.md`](docs/ARCHITECTURE.md) — Design philosophy

---

## Contributing

See [`CONTRIBUTING.md`](CONTRIBUTING.md) for setup, testing, and conventions.

---

## License

MIT. See [`LICENSE`](LICENSE).
