Metadata-Version: 2.4
Name: zmlx
Version: 0.7.0
Summary: ZMLX: Metal-kernel toolkit and optimization lab for MLX on Apple Silicon. Fused MoE decode (+5-26%), custom GPU kernels in one line, 70+ kernel catalog.
Project-URL: Homepage, https://github.com/Hmbown/ZMLX
Project-URL: Repository, https://github.com/Hmbown/ZMLX
Project-URL: Documentation, https://github.com/Hmbown/ZMLX#readme
Project-URL: Issues, https://github.com/Hmbown/ZMLX/issues
Project-URL: Changelog, https://github.com/Hmbown/ZMLX/blob/main/CHANGELOG.md
Author: Hunter Bown
License: MIT
License-File: LICENSE
Keywords: apple-silicon,autograd,jit,kernels,metal,mlx,zmlx
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries
Requires-Python: >=3.10
Requires-Dist: mlx>=0.30.0
Provides-Extra: dev
Requires-Dist: mypy>=1.10.0; extra == 'dev'
Requires-Dist: numpy>=1.24.0; extra == 'dev'
Requires-Dist: pytest>=8.0.0; extra == 'dev'
Requires-Dist: pyyaml>=6.0; extra == 'dev'
Requires-Dist: ruff>=0.5.0; extra == 'dev'
Requires-Dist: types-pyyaml>=6.0; extra == 'dev'
Provides-Extra: train
Requires-Dist: huggingface-hub>=0.20.0; extra == 'train'
Requires-Dist: mlx-lm>=0.25.0; extra == 'train'
Requires-Dist: pyyaml>=6.0; extra == 'train'
Requires-Dist: transformers>=4.40.0; extra == 'train'
Description-Content-Type: text/markdown

# ZMLX - Triton-style kernels for Apple Silicon

[![PyPI](https://img.shields.io/pypi/v/zmlx.svg)](https://pypi.org/project/zmlx/)
[![Python 3.10+](https://img.shields.io/badge/python-3.10%2B-blue.svg)](https://www.python.org/downloads/)
[![License: MIT](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE)
[![Platform: macOS Apple Silicon](https://img.shields.io/badge/platform-macOS%20Apple%20Silicon-lightgrey.svg)](https://github.com/ml-explore/mlx)

**A Metal kernel toolkit and upstream incubation lab for [MLX](https://github.com/ml-explore/mlx)** — author custom GPU kernels in Python, apply fused model patches, and prototype C++ Metal primitives for upstream MLX.

Toolkit highlights (available via `pip install zmlx` on stock MLX):
- SwiGLU 2.0x, Dropout 7.5x, Top-K 3.7x in op-level microbenchmarks
- 70+ kernel catalog, autograd, benchmarking utilities, and model patching

> **Validated benchmarks** (M4 Max 36 GB, MLX 0.30.4.dev, Jan 31, 2026):
>
> | Result | Measurement | Notes |
> |:--|:--|:--|
> | **+5-8% decode** on Qwen3-30B-A3B-4bit (MoE, E=128, K=8) | 119-122 tok/s vs 113-114 | Requires local MLX build with `gather_qmm_swiglu` |
> | **Neutral (1.01x)** on Qwen3-4B-4bit (dense) | 125.3 vs 124.4 tok/s | Safe on dense models |
> | **+4% per MoE layer** on LFM2-8B-A1B-4bit (MoE, E=32, K=4) | 293 us -> 282 us at M=1 decode | Kernel-level bench; E2E too noisy to report |
>
> Fused SwiGLU requires a [local MLX build](#optimization-lab) with `gather_qmm_swiglu`. On stock MLX, gating+combine-only can be neutral to slightly negative (Qwen3-30B measured 0.95–0.98x), so use `smart_patch` or a fused build for MoE models.

```bash
pip install zmlx
```

**MoE patching example:**

```python
import mlx_lm
from zmlx.patch import patch

model, tokenizer = mlx_lm.load("mlx-community/Qwen3-30B-A3B-4bit")
patch(model)  # best on MoE with fused MLX; use smart_patch on stock MLX
```

**Custom kernel example:**

```python
from zmlx.api import elementwise
import mlx.core as mx

# Math formula → compiled Metal kernel → runs on GPU
mish = elementwise("x * tanh(log(1 + exp(x)))", name="mish")
y = mish(mx.random.normal((1024,)))
```

---

## What's New in v0.7.0

### Fused expert SwiGLU (`gather_qmm_swiglu`)

ZMLX prototypes fused C++ Metal primitives in a local MLX fork. The first: `gather_qmm_swiglu` — fuses gate projection + up projection + SwiGLU activation into a single kernel launch, reading the input tensor once instead of twice.

- **+5-8% decode** on Qwen3-30B-A3B-4bit (119-122 vs 113-114 tok/s, E2E)
- **+4% per MoE layer** on LFM2-8B-A1B-4bit (293 us → 282 us at M=1 decode, kernel-level)
- **Neutral (1.01x)** on Qwen3-4B-4bit (dense, E2E)
- Auto-enabled by `patch(model)` when available; falls back to two-pass otherwise
- Threshold-guarded: fused kernel used for M<=32 (decode/small prefill), falls back at large M where it is slower
- Added kernel-level + E2E benchmark scripts plus correctness tests

Requires building MLX from the local fork (`mlx_local/`). See [Optimization Lab](#optimization-lab) for details. Upstream PR to MLX planned.

### Optimization lab

ZMLX is evolving into two things:

1. **A Metal kernel toolkit** (stable) — `elementwise()`, `reduce()`, `map_reduce()`, autograd, model patching. This is what `pip install zmlx` gives you.
2. **An MLX optimization lab** (experimental) — prototyping fused C++ Metal primitives that should eventually be upstreamed to MLX. These require a local MLX build and live in `mlx_local/`.

The lab exists because some fusions (quantized matmul + activation) need access to MLX's internal SIMD helpers (`qdot`, `load_vector`, `QuantizedBlockLoader`) which aren't exposed through the public `metal_kernel` API. The plan is: prototype here, validate with benchmarks, upstream to MLX, then ZMLX auto-detects the primitives when they land.

**Current lab work:**
- `gather_qmm_swiglu` — fused expert gate+up+SwiGLU (done, working, benchmarked)
- `add_rms_norm` — fused residual add + RMSNorm (planned, benefits all models)
- `gather_qmm_combine` — fused down projection + weighted expert sum (planned, MoE-specific)

### Previous highlights (v0.6.x)

- SIMD-group top-k gating, bias-aware fused gating, dynamic `num_experts_per_tok`
- `topk_gating_softmax(x, k)` kernel (fused Metal for k<=8, full-softmax + expert bias)
- LFM2, GPT-OSS, GLM-4, Mixtral model support
- `mode` parameter, validated benchmarks, `router` attribute support

### Previous highlights (v0.4-0.5)

- MoE patch (fused gating + combine), high-level API, JIT compiler
- Smart patching, training pipeline, 70+ kernel catalog

---

## Why ZMLX?

When you need a custom GPU op on Apple Silicon, your options today are:
1. Write raw Metal source strings, manage caching, figure out threadgroups, wire up autodiff manually
2. Use ZMLX

ZMLX wraps `mx.fast.metal_kernel` and `mx.custom_function` to provide **Triton-like ergonomics**:

- **One-line kernel authoring** - define elementwise, reduction, and map-reduce ops from C expressions
- **Automatic gradients** - custom VJP backward passes (themselves Metal kernels) via `mx.custom_function`
- **Define-once caching** - kernels compile once, reused by source hash + config
- **Autotuning** - threadgroup size search with persistent caching
- **Testing & benchmarking** - verify against reference ops, compare timings side-by-side
- **Model patching** - swap MLX layers for fused ZMLX kernels with `patch(model)`

---

## Install

**Requirements**: macOS (Apple Silicon), Python >= 3.10, mlx >= 0.30.0

```bash
pip install zmlx
```

From source (development):

```bash
git clone https://github.com/Hmbown/ZMLX.git
cd ZMLX
pip install -e ".[dev]"
```

---

## Quick Start

### 1. Custom elementwise kernel

```python
from zmlx.api import elementwise
import mlx.core as mx

# Non-differentiable - just forward pass
fast_exp = elementwise("metal::exp(x)", name="fast_exp")
y = fast_exp(mx.random.normal((1024,)))

# Differentiable - with custom VJP
from zmlx import msl

silu = elementwise(
    "kk_silu(x)",
    name="my_silu",
    grad_expr="g * (s + x * s * ((T)1 - s))",
    grad_prelude="T s = kk_sigmoid(x);",
    use_output=False,
    header=msl.DEFAULT_HEADER,
)
gx = mx.grad(lambda z: silu(z).sum())(mx.random.normal((1024,)))
```

### 2. Custom reduction

```python
from zmlx.api import reduce
import mlx.core as mx

my_sum = reduce(init="0.0f", update="acc + v", name="row_sum")
y = my_sum(mx.random.normal((8, 1024)))  # shape (8,)
```

### 3. Two-pass map-reduce (softmax pattern)

```python
from zmlx.api import map_reduce
import mlx.core as mx

my_softmax = map_reduce(
    pass1={"init": "-INFINITY", "update": "max(acc1, x)", "reduce": "max(a, b)"},
    pass2={"init": "0.0f", "update": "acc2 + exp(x - s1)", "reduce": "a + b"},
    write="exp(x - s1) / s2",
    name="my_softmax",
)
y = my_softmax(mx.random.normal((8, 1024)))
```

### 4. Test and benchmark your kernel

```python
import zmlx
import mlx.core as mx

# Verify correctness
zmlx.testing.assert_matches(
    my_softmax, lambda x: mx.softmax(x, axis=-1),
    shapes=[(8, 1024), (32, 4096)],
)

# Benchmark
zmlx.bench.compare(
    {"ZMLX": my_softmax, "MLX": lambda x: mx.softmax(x, axis=-1)},
    shapes=[(1024, 4096), (4096, 4096)],
)
```

### 5. Lower-level building blocks

```python
from zmlx import autograd, elementwise, msl
import mlx.core as mx

# Unary kernel (no gradient)
exp_kern = elementwise.unary(
    name="kk_exp", expr="metal::exp(x)",
    compute_dtype=mx.float32, header=msl.DEFAULT_HEADER,
)

# Binary kernel with custom VJP
mul_op = autograd.binary_from_expr(
    name="safe_mul", fwd_expr="a * b",
    vjp_lhs_expr="g * b", vjp_rhs_expr="g * a",
    compute_dtype=mx.float32,
)
```

---

## Kernel Catalog

ZMLX includes 70+ kernels organized by domain. Some are genuinely useful for custom workloads (loss, GLU fusions, bit ops, MoE gating). Others are **reference implementations** showing codegen patterns - correct but not faster than MLX built-ins for standard transformer shapes.

Full reference: [`docs/KERNELS.md`](docs/KERNELS.md).

| Module | Highlights |
|:---|:---|
| `loss` | `softmax_cross_entropy` - memory-efficient fused loss |
| `transformer` | `swiglu`, `geglu`, `rmsnorm_residual` (with full weight gradients), `dropout` - genuine fusions |
| `bits` | `pack_bits`, `unpack_bits` - no MLX equivalent |
| `moe` | `topk_gating_softmax`, `moe_dispatch`, `moe_combine` - fused expert routing (k ≤ 8 fused, bias-aware) |
| `quant` | FP8 (E4M3/E5M2), NF4, int8, int4 dequantization - real bit-manipulation kernels |
| `optimizers` | `adamw_step` - fused AdamW parameter update in a single kernel |
| `scan` | `cumsum_lastdim` - differentiable prefix sum |
| `norms` | `rmsnorm`, `layernorm` - parallel reduction. All norms compute in float32 internally |
| `softmax` | `softmax_lastdim` - map-reduce codegen showcase |
| `rope` | `apply_rope`, `apply_rope_interleaved`, `apply_gqa_rope` |
| `linear` | Reference fused-linear patterns (naive matmul, not for production) |

---

## Architecture

Three-layer design. Full details: [`docs/ARCHITECTURE.md`](docs/ARCHITECTURE.md).

1. **Metal kernel infrastructure** - `MetalKernel` wrapper, in-process cache, stats tracking
2. **Code generation & helpers** - MSL templates, elementwise/autograd/rowwise APIs, autotuning
3. **Kernel catalog** - domain modules built on layers 1 and 2

---

## Benchmarks

Three levels of measurement:
1. **Isolated primitives** — `bench_gather_qmm_swiglu.py`
2. **Single MoE layer** — `bench_moe_layer.py` (500 iters, p50 median, `mx.synchronize()`)
3. **E2E model decode** — `bench_moe_e2e.py`

### Op-level (B=16, S=1024, D=1024, float16, M4 Max)

Run `python benchmarks/microbench.py` to reproduce on your hardware.

| Operation | MLX | ZMLX | Speedup |
|:--|--:|--:|:--|
| **SwiGLU** | 0.87 ms | **0.43 ms** | **2.0x** |
| **Dropout** | 3.08 ms | **0.41 ms** | **7.5x** |
| **Top-K** | 1.81 ms | **0.49 ms** | **3.7x** |
| **Gather-Add** | 0.55 ms | **0.42 ms** | **1.3x** |
| Softmax | 0.45 ms | 0.44 ms | ~1.0x |
| RMSNorm | 0.51 ms | 0.54 ms | 0.95x |
| MoE gating | 0.35 ms | 0.37 ms | 0.94x |
| Sum | 0.22 ms | 0.37 ms | 0.58x |
| CumSum | 0.32 ms | 0.62 ms | 0.52x |

ZMLX is most effective for **fused operations** that MLX does not provide as single ops (SwiGLU, fused-RNG dropout, fused gather-add). MLX built-ins (`mx.fast.rms_norm`, `mx.softmax`, reductions) are already highly optimized and remain the preferred choice for standard transformer shapes.

### Model-level inference (E2E)

All baselines are **unmodified `mlx_lm`**. ZMLX rows add `patch(model)`. Same model weights, same quantization, same prompt. Benchmarks on M4 Max 36 GB, MLX 0.30.4.dev, Jan 31, 2026 unless noted.

#### MoE models

**Qwen3-30B-A3B-4bit** (MoE, 48 layers, E=128, K=8) — `python benchmarks/bench_moe_e2e.py`  
3 runs x 200 tokens, repeated 3 times.

| Config | Decode (tok/s) | vs Baseline |
|:--|--:|:--|
| Baseline (`mlx_lm`) | 113-114 | — |
| `patch(model)` — gating + combine only | 109-111 | 0.95-0.98x |
| `patch(model)` — with fused SwiGLU | 119-122 | **+5-8%** |

> Fused SwiGLU requires a local MLX build. Gating+combine alone measured below baseline on this model.

**LFM2-8B-A1B-4bit** (MoE, 24 layers, E=32, K=4) — E2E is too noisy at ~20k tok/s.  
Use the kernel-level MoE layer benchmark below as the authoritative measurement.

#### Dense models (neutral — expected)

**Qwen3-4B-4bit** (dense, 36 layers) — `python benchmarks/bench_moe_e2e.py`  
3 runs x 1000 tokens.

| Config | Decode (tok/s) | vs Baseline |
|:--|--:|:--|
| Baseline (`mlx_lm`) | 124.4 | — |
| `patch(model)` | 125.3 | **1.01x (neutral)** |

> Dense decode is bandwidth-bound; patches are safe but not expected to help.

**Overnight suites (sequential, cache on external drive):**

```bash
python benchmarks/bench_moe_suite.py \
  --model-list benchmarks/moe_models.txt \
  --cache-dir /Volumes/VIXinSSD/TEST \
  --runs 3 --max-tokens 200 --resume
```

### Kernel-level MoE layer timing (authoritative for fast models)

`python benchmarks/bench_moe_layer.py` isolates a single MoE layer forward pass, timed with `mx.synchronize()` brackets. 500 iterations, p50 median. Per-measurement stdev: 4-10%.

**LFM2-8B-A1B-4bit** (E=32, K=4, hidden=2048, intermediate=1792):

| seq_len | Baseline | Gating+combine | Fused SwiGLU | Speedup |
|:--|--:|--:|--:|:--|
| 1 (decode) | 293 us | 288 us (1.02x) | 282 us | **1.04x** |
| 4 | 531 us | 500 us (1.06x) | 498 us | **1.07x** |
| 16 | 998 us | 943 us (1.06x) | 949 us | 1.05x |
| 64 | 2291 us | 2480 us (0.92x) | 2452 us | 0.93x |

**Qwen3-30B-A3B-4bit** (E=128, K=8, hidden=2048, intermediate=1024):

| seq_len | Baseline | Gating+combine | Fused SwiGLU | Speedup |
|:--|--:|--:|--:|:--|
| 1 (decode) | 613 us | 621 us (0.99x) | 613 us | **1.00x** |
| 4 | 808 us | 724 us (1.12x) | 756 us | **1.07x** |
| 16 | 1269 us | 1264 us (1.00x) | 1284 us | 0.99x |
| 64 | 3565 us | 3557 us (1.00x) | 3508 us | 1.02x |

**Key finding:** fused SwiGLU saves ~4% per MoE layer on LFM2 at decode, but is neutral on Qwen3-30B at the single-layer level. The +5-8% E2E gain on Qwen3-30B comes from compound system effects: 48 MoE layers x one eliminated kernel dispatch is ~240 us saved per token (~2.7% of 8.85 ms/token), plus reduced intermediate memory pressure and fewer graph nodes.

### Kernel-level primitive microbenchmarks

`python benchmarks/bench_gather_qmm_swiglu.py` — fused vs naive (2x `gather_qmm` + SwiGLU):

| Config | Naive | Fused | Speedup |
|:--|--:|--:|:--|
| M=1, K=512, N=512 | 170 us | 133 us | **1.28x** |
| M=1, K=2048, N=1024 | 144 us | 129 us | **1.11x** |
| M=1, K=2048, N=2048 | 148 us | 134 us | **1.10x** |
| M=16, K=2048, N=1024 | 243 us | 229 us | 1.06x |
| M=64, K=2048, N=1024 | 240 us | 534 us | 0.45x |

> Fused kernel improves small M (decode) and is slower at large M (prefill). The patch auto-selects: fused for M<=32, two-pass fallback otherwise.

### Next steps (roadmap)

- **`add_rms_norm`** — fused residual add + RMSNorm in one kernel (benefits all models, ~20 line Metal diff)
- **`gather_qmm_combine`** — fused down projection + weighted expert sum (MoE-specific, eliminates intermediate tensor)
- **Upstream `gather_qmm_swiglu` to MLX** — PR with benchmarks so everyone benefits via `pip install mlx`
- **Per-device autotune profiles** (better defaults by chip family)

**When do patches help?**
- **MoE Models (4-bit)**: Best case with fused SwiGLU on a local MLX build. Qwen3-30B gets **+5-8%** E2E; LFM2 shows **+4% per MoE layer** at decode (E2E too noisy to report). On stock MLX, gating+combine can be slower — use `smart_patch` or exclude `moe_mlp`.
- **MoE Models (8-bit)**: Fused SwiGLU is less impactful; benchmark before enabling.
- **MoE Models (pre-computed gating)**: GLM-4, DeepSeek-V3 — neutral. Gate is already `@mx.compile`-optimized.
- **Dense Models (any size)**: Neutral. Decode is bandwidth-bound; Qwen3-4B is 1.01x.

```python
from zmlx.patch import smart_patch
import mlx.core as mx

# Auto-benchmark each pattern, keep only what helps
sample = mx.array([tokenizer.encode("Hello")])
model = smart_patch(model, sample)
```

Or use mode/presets if you know your workload:

```python
from zmlx.patch import patch

patch(model)                    # inference default; use smart_patch on stock MLX for MoE
patch(model, mode="training")   # training: adds norm fusions for backward pass savings

# Or explicit presets for full control:
from zmlx.patch import ALL_PATTERNS, FUSED_ACTIVATIONS, TRAINING_RECOMMENDED
patch(model, patterns=FUSED_ACTIVATIONS)       # same as default
patch(model, patterns=TRAINING_RECOMMENDED)    # same as mode="training"
patch(model, patterns=ALL_PATTERNS)            # WARNING: can be slower on inference
```

### Smart patching

`smart_patch` applies each candidate pattern, benchmarks the model's forward pass, and **automatically reverts patterns that make things slower**. It supports custom forward functions for realistic benchmarks:

```python
from zmlx.patch import smart_patch

# Basic: benchmark raw forward pass
model = smart_patch(model, sample_input)

# Advanced: benchmark with actual generation
def gen_fn(model, sample):
    return mlx_lm.generate(model, tokenizer, prompt="Hello", max_tokens=20)

model = smart_patch(model, sample, forward_fn=gen_fn, threshold=0.99)

# Result includes per-pattern speedups
result = model._zmlx_patch_result
print(result.benchmarks)    # {'swiglu_mlp': 1.012, 'residual_norm': 0.971}
print(result.summary())     # what was kept and why
```

### Autotuning

Replacement modules support `threadgroup="auto"` to search for the best threadgroup size on first invocation:

```python
from zmlx.patch import patch
patch(model, threadgroup="auto")  # autotunes each kernel on first call
```

The `map_reduce()` API also supports autotuning:

```python
from zmlx.api import map_reduce
my_softmax = map_reduce(..., threadgroup="auto")  # autotunes per-shape
```

### Where ZMLX genuinely helps

- **MoE model inference** — best results with fused expert SwiGLU (`gather_qmm_swiglu`, local MLX build). Qwen3-30B gets **+5-8%** E2E; LFM2 shows **+4% per layer** at decode. On stock MLX, use `smart_patch` to avoid gating+combine slowdowns. Supports Qwen3-MoE, LFM2, Mixtral, GPT-OSS.
- **Prototyping MLX-level optimizations** — ZMLX's optimization lab incubates C++ Metal primitives. Prove value with benchmarks here, then upstream to MLX for everyone.
- **Custom ops that MLX doesn't have** — SwiGLU, GeGLU, fused dropout, fused MoE gating, bit packing
- **Training** — fused `softmax_cross_entropy` loss, correct weight gradients for `rmsnorm_residual`
- **Authoring new kernels** — `elementwise()`, `reduce()`, `map_reduce()` APIs: math formula to compiled Metal kernel in one line
- **Quantization** — FP8 (E4M3/E5M2), NF4, int8, int4 dequantization with real bit-manipulation kernels

### MoE performance notes

- **Expert matmuls are the bottleneck** — for M=1 decode, MoE layers do 3 gather_qmm calls per layer (gate, up, down). Fusing gate+up+SwiGLU into one kernel (`gather_qmm_swiglu`) saves one full read of the input tensor.
- **E2E gains are a compound effect** — on Qwen3-30B, 48 MoE layers x one eliminated dispatch saves ~240 us per token (~2.7% of 8.85 ms/token), plus reduced intermediate memory pressure and fewer graph nodes.
- **Fast models need kernel-level timing** — LFM2 runs at ~20k tok/s, so E2E variance is 20-40%. The kernel-level MoE layer benchmark is the reliable measurement (+4% per layer at decode).

### Where ZMLX won't help

- **Dense model inference** — batch-1 decode is dominated by weight reads and bandwidth-bound. `patch(model)` is generally neutral (Qwen3-4B: 1.01x).
- **Replacing MLX built-in norms/softmax** — `mx.fast.rms_norm`, `mx.softmax`, `mx.fast.scaled_dot_product_attention` are Apple-optimized fast paths. Custom kernels add dispatch overhead.

---

## Optimization Lab

ZMLX includes a local MLX fork (`mlx_local/`) where we prototype fused C++ Metal primitives that need access to MLX's internal quantized matmul infrastructure. These are experimental and require building MLX from source.

### What's in the lab

| Primitive | Status | What it does |
|:--|:--|:--|
| `gather_qmm_swiglu` | Working, benchmarked | Fused gate+up+SwiGLU for MoE experts |
| `add_rms_norm` | Planned | Fused residual add + RMSNorm |
| `gather_qmm_combine` | Planned | Fused down projection + weighted expert sum |

### Building the local MLX fork

```bash
# From the ZMLX repo root
CMAKE_ARGS='-DMETAL_CPP_URL=file:///path/to/ZMLX/mlx_local/third_party/metal-cpp_26.zip \
  -DMLX_METAL_MODULE_CACHE=/tmp/clang_module_cache \
  -DFETCHCONTENT_SOURCE_DIR_JSON=/path/to/ZMLX/mlx_local/third_party/json \
  -DFETCHCONTENT_SOURCE_DIR_FMT=/path/to/ZMLX/mlx_local/third_party/fmt \
  -DFETCHCONTENT_SOURCE_DIR_NANOBIND=/path/to/ZMLX/mlx_local/third_party/nanobind \
  -DMLX_BUILD_GGUF=OFF' \
pip install -e mlx_local --no-build-isolation
```

ZMLX auto-detects the fused primitives at runtime. If you're on stock MLX, `patch(model)` uses the standard two-pass path — for MoE models, prefer `smart_patch` to avoid gating+combine slowdowns.

### The plan

Prototype here, validate with benchmarks, upstream to MLX via PR. Once primitives land in stock MLX, ZMLX auto-detects them and everyone benefits via `pip install mlx`. ZMLX is the incubator, MLX is the distribution.

---

## Precision

All ZMLX Metal kernels compute internally in **float32** regardless of input dtype. The `compute_dtype` parameter accepted by many kernel functions is **deprecated** and will be removed in a future release. Passing a non-None value will emit a `DeprecationWarning`.

---

## Documentation

- [`docs/QUICKSTART.md`](docs/QUICKSTART.md) - 5-minute tutorial
- [`docs/COOKBOOK.md`](docs/COOKBOOK.md) - Recipes for common patterns
- [`docs/KERNELS.md`](docs/KERNELS.md) - Complete kernel catalog reference
- [`docs/ARCHITECTURE.md`](docs/ARCHITECTURE.md) - Design philosophy

---

## Contributing

See [`CONTRIBUTING.md`](CONTRIBUTING.md) for setup, testing, and conventions.

---

## License

MIT. See [`LICENSE`](LICENSE).
