
Cross-Framework MNIST

Same Yat MLP. Four backends. Real numbers, not vibes.

🧒

Explain Like I'm 5

We built the same tiny robot brain four times: once with PyTorch LEGO, once with JAX (NNX) LEGO, once with JAX (Linen) LEGO, and once with MLX LEGO. 🧱

  • 📏 Same number of bricks: 235,148 in every build.
  • 🎯 Same skill after training: ≈95% on handwritten digits.
  • ⚡ JAX finished faster on a laptop CPU thanks to its compiler.

The point isn't who wins. The point is that Yat layers behave the same everywhere, so pick the framework you already know.

The setup

We trained a two-hidden-layer Yat MLP on MNIST in four NMN backends: nmn.torch, nmn.nnx, nmn.linen, and nmn.mlx. The architecture is identical across all four:

Flatten(28×28) → YatNMN(256) → YatNMN(128) → Linear(10)

No ReLU. No GELU. The non-linearity is baked into the ⵟ-Product itself. Optimizer: AdamW, lr = 3 × 10⁻⁴, batch = 128, 3 epochs, seed = 0. Hardware: Apple Silicon, fp32; the PyTorch and JAX runs use the CPU, and the MLX run uses the Metal GPU. The full scripts are in the repo:

  • src/nmn/torch/examples/vision/mnist.py
  • src/nmn/nnx/examples/vision/mnist.py
  • src/nmn/linen/examples/mnist.py
  • src/nmn/mlx/examples/mnist.py

Equivalent scripts ship for nmn.tf and nmn.keras, but they were not run on Python 3.14; see "What about TensorFlow and Keras?" and the Reproducibility section below.
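The 235,148 parameter figure quoted in this post can be reproduced by hand from the architecture. The sketch below assumes each YatNMN layer holds a weight matrix, a bias vector, and one extra learnable scalar. That per-layer scalar is an assumption on our part (it is what makes the arithmetic land on the reported total), and the helper names are illustrative, not part of the nmn API.

```python
# Hand count of trainable parameters for
# Flatten(28x28) -> YatNMN(256) -> YatNMN(128) -> Linear(10).
# ASSUMPTION: each YatNMN layer = weights + bias + 1 learnable scalar.


def dense_params(fan_in: int, fan_out: int) -> int:
    """Weights plus bias for a plain affine layer."""
    return fan_in * fan_out + fan_out


def yat_params(fan_in: int, fan_out: int, extra_scalars: int = 1) -> int:
    """Hypothetical YatNMN count: affine-shaped parameters plus extra scalars."""
    return dense_params(fan_in, fan_out) + extra_scalars


layers = [
    ("YatNMN(256)", yat_params(28 * 28, 256)),
    ("YatNMN(128)", yat_params(256, 128)),
    ("Linear(10)", dense_params(128, 10)),
]

total = sum(count for _, count in layers)

for name, count in layers:
    print(f"{name:<12} {count:>8,}")
print(f"{'total':<12} {total:>8,}")
```

Without the two assumed scalars the count would be 235,146, so the reported total is a useful hint about what each Yat layer actually stores.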

Results

Each row is a real run (not "expected"). JSON receipts are checked into the repo under .context/<framework>_mnist.json.

Framework        Params   Test acc  Test loss  Total (s)  s / epoch
PyTorch          235,148  94.97 %   0.1584     4.67       ≈ 1.41
Flax NNX         235,148  95.39 %   0.1477     1.68       ≈ 0.52
Flax Linen       235,148  95.28 %   0.1538     1.26       ≈ 0.38
MLX (Metal GPU)  235,148  96.28 %   0.1220     1.62       ≈ 0.54
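The accuracy window and the speed ratios discussed in this section follow directly from the table. A minimal sketch, with the values copied from the rows above; the dictionary layout and variable names are illustrative and do not mirror the receipt JSON schema:

```python
# Accuracy (%) and total wall time (s), copied from the results table.
results = {
    "PyTorch": {"acc": 94.97, "total_s": 4.67},
    "Flax NNX": {"acc": 95.39, "total_s": 1.68},
    "Flax Linen": {"acc": 95.28, "total_s": 1.26},
    "MLX (Metal GPU)": {"acc": 96.28, "total_s": 1.62},
}

# Width of the accuracy window in percentage points.
accs = [row["acc"] for row in results.values()]
spread_pp = round(max(accs) - min(accs), 2)

# Wall-time ratio of each run relative to the PyTorch run.
baseline_s = results["PyTorch"]["total_s"]
speedup = {
    name: round(baseline_s / row["total_s"], 2)
    for name, row in results.items()
}

print(f"accuracy spread: {spread_pp} pp")
for name, ratio in speedup.items():
    print(f"{name:<16} {ratio:.2f}x vs PyTorch")
```

The ratios compare total wall time only, so they include whatever compile or warm-up cost each run paid.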

Four things to notice.

  1. The parameter counts are identical (235,148). Four independent implementations agree on the model size, a small but real cross-framework consistency check.
  2. Final test accuracies fall in a 1.31 pp window (94.97 – 96.28). The remaining gap is the same kind of seed / shuffle-order / device-rounding noise you'd see comparing PyTorch to PyTorch with different seeds. There is no framework that "doesn't work" with Yat layers.
  3. JAX (NNX & Linen) is ≈ 3× faster than PyTorch on CPU here (1.68 s and 1.26 s against 4.67 s, i.e. about 2.8× and 3.7×) with the same model, same data, and same optimizer. The difference is the XLA JIT pipeline; Yat's element-wise + reduce structure is friendly to it. On GPU/TPU we would expect the gap to narrow, and PyTorch's torch.compile to recover much of it, though this post does not measure either.
  4. MLX on Metal lands a real receipt: 96.28 % in 1.62 s on the Apple Silicon GPU, the best accuracy of the four and slightly faster than the NNX CPU run (1.68 s), though slower than Linen (1.26 s). The cpu device run matches the GPU run (96.28 % both ways), so MLX's backend isn't doing anything funny; the Yat trajectory simply lands on a slightly better local minimum at this seed. We'll take it.

Loss curves

Train-loss trajectories per epoch (numbers from the receipt JSONs):

Framework        Epoch 0  Epoch 1  Epoch 2
PyTorch          0.9561   0.2391   0.1818
Flax NNX         0.8841   0.2249   0.1710
Flax Linen       0.8967   0.2386   0.1808
MLX (Metal GPU)  0.4490   0.1630   0.1271

The shape is the same: a fast first-epoch drop (Yat's naturally bounded gradients mean you can use a small LR without losing the first-epoch transient), then a quieter pair of refinement epochs.

Why the numbers agree

Each backend implements the same ⵟ-Product:

$$ \mathrm{yat}(x, w) \;=\; \frac{(x \cdot w)^2}{\lVert x - w \rVert^2 + \varepsilon} $$

Different frameworks, but the op decomposes into matmul → square → broadcast-add ε → divide in all of them. Our cross-framework parity tests (tests/test_cross_framework_parity*.py) show < 10⁻⁶ max element-wise difference in fp32, so identical initialization + identical mini-batches would give the same losses up to floating-point reduction order. The 1.31-pp accuracy spread above (0.42 pp among the PyTorch, NNX, and Linen runs) is dominated by per-framework PRNG seeding and shuffle order, not by any Yat-side difference.
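To make the formula concrete, here is a minimal pure-Python sketch of the ⵟ-Product for a single input vector. It is not the nmn implementation: the function names and the ε value of 1e-6 are assumptions chosen for illustration. The second function uses the identity ‖x − w‖² = ‖x‖² + ‖w‖² − 2(x · w), which shows why a single dot product (a matmul, once batched) carries most of the work.

```python
EPS = 1e-6  # ASSUMPTION: illustrative epsilon, not taken from the nmn source


def yat(x, w, eps=EPS):
    """Direct form: (x . w)^2 / (||x - w||^2 + eps)."""
    dot = sum(xi * wi for xi, wi in zip(x, w))
    sq_dist = sum((xi - wi) ** 2 for xi, wi in zip(x, w))
    return dot * dot / (sq_dist + eps)


def yat_expanded(x, w, eps=EPS):
    """Same value, with the distance expanded around the dot product."""
    dot = sum(xi * wi for xi, wi in zip(x, w))
    sq_dist = sum(xi * xi for xi in x) + sum(wi * wi for wi in w) - 2.0 * dot
    return dot * dot / (sq_dist + eps)


def yat_layer(x, weights, eps=EPS):
    """Hypothetical layer: one Yat response per weight row (no bias or scaling)."""
    return [yat(x, w, eps) for w in weights]


x = [1.0, 2.0, 3.0]
weights = [[0.5, -1.0, 2.0], [1.0, 2.0, 2.5]]

print(yat_layer(x, weights))
print(abs(yat(x, weights[0]) - yat_expanded(x, weights[0])))
```

The response is never negative, and it grows both when x aligns with w (large numerator) and when x sits close to w (small denominator), which is the sense in which the non-linearity is built into the product.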

What about TensorFlow and Keras?

We ship the same MNIST script for nmn.tf and nmn.keras at src/nmn/tf/examples/mnist.py and src/nmn/keras/examples/mnist.py. They were not run for this post because TensorFlow has no wheel for the CPython 3.14 we tested on; see tensorflow.org/install for supported Python versions. The scripts mirror the PyTorch / NNX / Linen ones exactly, so on Python 3.11/3.12 you should see numbers in the same ballpark; please open an issue with your run if it lands meaningfully outside it.

Reproducibility

From a fresh checkout:

# PyTorch
PYTHONPATH=src python -m nmn.torch.examples.vision.mnist \
    --epochs 3 --report .context/torch_mnist.json

# Flax NNX
PYTHONPATH=src python -m nmn.nnx.examples.vision.mnist \
    --epochs 3 --report .context/nnx_mnist.json

# Flax Linen
PYTHONPATH=src python -m nmn.linen.examples.mnist \
    --epochs 3 --report .context/linen_mnist.json

# MLX (Apple Silicon, --device gpu or cpu)
PYTHONPATH=src python -m nmn.mlx.examples.mnist \
    --epochs 3 --device gpu --report .context/mlx_mnist.json

# TensorFlow (needs tensorflow + tensorflow-datasets)
PYTHONPATH=src python -m nmn.tf.examples.mnist \
    --epochs 3 --report .context/tf_mnist.json

# Keras 3 (set KERAS_BACKEND first)
KERAS_BACKEND=jax PYTHONPATH=src python -m nmn.keras.examples.mnist \
    --epochs 3 --report .context/keras_mnist.json

Every run drops a JSON receipt with the full per-epoch history, so the numbers above are auditable, not narrated.
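To audit several receipts at once, a small loader like the one below may help. It is a sketch, not a repo utility: the function name is hypothetical, and because the receipt schema is not described in this post, the code only parses each file and makes no assumption about the keys inside.

```python
import json
from pathlib import Path


def load_receipts(context_dir=".context"):
    """Map framework name -> parsed JSON for every <framework>_mnist.json file."""
    receipts = {}
    for path in sorted(Path(context_dir).glob("*_mnist.json")):
        # "torch_mnist.json" -> "torch"
        framework = path.stem.removesuffix("_mnist")
        with path.open(encoding="utf-8") as fh:
            receipts[framework] = json.load(fh)
    return receipts
```

Calling load_receipts() from the repo root after the runs above returns one entry per framework; inspecting the keys of each entry is the quickest way to find the per-epoch history.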

Takeaway

Yat is framework-portable in practice, not just in theory. The implementation parity tests said it should be; this post is the experimental receipt. Pick the framework that fits your stack; the model you build will train the same.