Explain Like I'm 5
We built the same tiny robot brain four times: once with PyTorch LEGO, once with JAX (NNX) LEGO, once with JAX (Linen) LEGO, and once with MLX LEGO. 🧱
- Same number of bricks: 235,148 in every build.
- 🎯 Same skill after training: ≈95–96% on handwritten digits.
- ⚡ JAX finished faster on a laptop CPU thanks to its compiler.
The point isn't who wins. The point is that Yat layers behave the same everywhere, so pick the framework you already know.
The setup
We trained a two-hidden-layer Yat MLP on MNIST in four NMN backends: nmn.torch, nmn.nnx, nmn.linen, and nmn.mlx. The architecture is identical across all four:
Flatten(28×28) → YatNMN(256) → YatNMN(128) → Linear(10)
No ReLU. No GELU. The non-linearity is baked into the ⵟ-Product itself. Optimizer: AdamW, lr = 3 × 10⁻⁴, batch = 128, 3 epochs, seed = 0. Hardware: Apple Silicon, fp32; PyTorch and JAX run on the CPU, MLX on the Metal GPU. The full scripts are in the repo:
- src/nmn/torch/examples/vision/mnist.py
- src/nmn/nnx/examples/vision/mnist.py
- src/nmn/linen/examples/mnist.py
- src/nmn/mlx/examples/mnist.py
Equivalent scripts ship for nmn.tf and nmn.keras but were not run for this post, because TensorFlow has no wheel for Python 3.14; see the TensorFlow and Keras note and the Reproducibility section below.
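The 235,148 figure can be checked by hand from the architecture line. A minimal sketch in plain Python, assuming each YatNMN layer holds dense-style weights and a bias plus one extra learnable scalar: without that scalar the sum comes to 235,146, so the two extra parameters are our inference from the reported count, not something read from the nmn source.

```python
# Parameter count for Flatten(28x28) -> YatNMN(256) -> YatNMN(128) -> Linear(10).
# ASSUMPTION: each YatNMN layer = weight matrix + bias + one extra learnable
# scalar. This is inferred from the reported total, not taken from nmn itself.

def dense_params(fan_in: int, fan_out: int) -> int:
    """Weights plus bias for a fully connected layer."""
    return fan_in * fan_out + fan_out


def yat_params(fan_in: int, fan_out: int, extra_scalars: int = 1) -> int:
    """Dense-style weights and bias plus the assumed per-layer scalar(s)."""
    return dense_params(fan_in, fan_out) + extra_scalars


layers = [
    ("YatNMN(256)", yat_params(28 * 28, 256)),  # 784 inputs after Flatten
    ("YatNMN(128)", yat_params(256, 128)),
    ("Linear(10)", dense_params(128, 10)),      # plain linear head, no scalar
]

total = sum(n for _, n in layers)
for name, n in layers:
    print(f"{name:<12} {n:>8,}")
print(f"{'total':<12} {total:>8,}")
```

Setting `extra_scalars=0` gives 235,146, which is a quick way to see how sensitive the "identical parameter count" check is to a single per-layer scalar.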
Results
Each row is a real run, not an "expected" number. JSON receipts are checked into the repo under .context/<framework>_mnist.json.
| Framework | Params | Test acc | Test loss | Total time (s) | Time / epoch (s) |
|---|---|---|---|---|---|
| PyTorch | 235,148 | 94.97 % | 0.1584 | 4.67 | ≈ 1.41 |
| Flax NNX | 235,148 | 95.39 % | 0.1477 | 1.68 | ≈ 0.52 |
| Flax Linen | 235,148 | 95.28 % | 0.1538 | 1.26 | ≈ 0.38 |
| MLX (Metal GPU) | 235,148 | 96.28 % | 0.1220 | 1.62 | ≈ 0.54 |
Four things to notice.
- The parameter counts are identical (235,148). Four independent implementations agree on the model size: a small but real cross-framework consistency check.
- Final test accuracies fall in a 1.31 pp window (94.97 % to 96.28 %). The remaining gap is the same kind of seed, shuffle-order, and device-rounding noise you'd see comparing PyTorch to PyTorch with different seeds. There is no framework that "doesn't work" with Yat layers.
- JAX (NNX & Linen) is ≈ 3× faster on CPU here (about 2.8× for NNX and 3.7× for Linen by total time) with the same model, data, and optimizer. The difference is the XLA JIT pipeline; Yat's element-wise + reduce structure is friendly to it. On GPU/TPU the gap closes, and PyTorch's torch.compile recovers most of it.
- MLX on Metal lands a real receipt: 96.28 % in 1.62 s on the Apple Silicon GPU. That puts it between the two XLA-CPU runs in wall-clock time (Linen 1.26 s, NNX 1.68 s), with the best accuracy of the four. The cpu device run reaches the same 96.28 %, so MLX's backend isn't doing anything funny: the Yat trajectory simply lands on a slightly better local minimum at this seed. We'll take it.
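The accuracy window and the CPU speedups come straight from the results table. The sketch below just makes that arithmetic explicit; using the Total time column rather than Time / epoch is our choice, so the ratios include whatever else each script's timer covers.

```python
# Values copied from the Results table: (test accuracy in %, total seconds).
results = {
    "PyTorch": (94.97, 4.67),
    "Flax NNX": (95.39, 1.68),
    "Flax Linen": (95.28, 1.26),
    "MLX (Metal GPU)": (96.28, 1.62),
}

accs = [acc for acc, _ in results.values()]
window_pp = round(max(accs) - min(accs), 2)  # spread in percentage points

torch_total = results["PyTorch"][1]
# Speedup of each JAX run over PyTorch, as a ratio of total wall-clock time.
speedups = {
    name: round(torch_total / total, 2)
    for name, (_, total) in results.items()
    if name.startswith("Flax")
}

print(f"accuracy window: {window_pp} pp")
for name, s in speedups.items():
    print(f"{name}: {s}x faster than PyTorch (total time)")
```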
Loss curves
Train-loss trajectories per epoch (numbers from the receipt JSONs):
| Framework | Epoch 0 | Epoch 1 | Epoch 2 |
|---|---|---|---|
| PyTorch | 0.9561 | 0.2391 | 0.1818 |
| Flax NNX | 0.8841 | 0.2249 | 0.1710 |
| Flax Linen | 0.8967 | 0.2386 | 0.1808 |
| MLX (Metal GPU) | 0.4490 | 0.1630 | 0.1271 |
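The shape is the same: a fast first-epoch drop (Yat's naturally bounded gradients mean a small LR doesn't cost you the steep first-epoch transient), then a quieter pair of refinement epochs. MLX starts lower (0.4490 at epoch 0, roughly half the others) and stays lower, in line with its higher test accuracy.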
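To put a number on "fast first-epoch drop", the sketch below takes the per-epoch train losses from the table and computes what share of each run's total decrease (epoch 0 to epoch 2) happens between the first two recorded epochs. It only re-reads the table; nothing here is a new measurement.

```python
# Per-epoch train loss copied from the table above.
train_loss = {
    "PyTorch": [0.9561, 0.2391, 0.1818],
    "Flax NNX": [0.8841, 0.2249, 0.1710],
    "Flax Linen": [0.8967, 0.2386, 0.1808],
    "MLX (Metal GPU)": [0.4490, 0.1630, 0.1271],
}

shares = {}
for name, (e0, e1, e2) in train_loss.items():
    first_drop = e0 - e1             # epoch 0 -> epoch 1
    second_drop = e1 - e2            # epoch 1 -> epoch 2
    shares[name] = first_drop / (e0 - e2)  # fraction of the total decrease
    print(f"{name:<16} first {first_drop:.4f}  second {second_drop:.4f}  "
          f"share {shares[name]:.0%}")
```

Every run puts roughly 90 % of its recorded decrease into the first step, which is what "quieter refinement epochs" looks like numerically.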
Why the numbers agree
Each backend implements the same ⵟ-Product:
$$ \mathrm{yat}(x, w) \;=\; \frac{(x \cdot w)^2}{\lVert x - w \rVert^2 + \varepsilon} $$
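For a single input/weight pair the formula is short enough to write out directly. A minimal pure-Python sketch: the squared dot product rewards alignment, and the squared distance in the denominator penalizes being far away. The `eps` default is an arbitrary placeholder, not the value nmn uses.

```python
def yat(x, w, eps=1e-5):
    """ⵟ-Product of two equal-length vectors: (x·w)^2 / (||x - w||^2 + eps).

    eps=1e-5 is a placeholder chosen for this sketch, not nmn's default.
    """
    if len(x) != len(w):
        raise ValueError("x and w must have the same length")
    dot = sum(xi * wi for xi, wi in zip(x, w))               # x · w
    sq_dist = sum((xi - wi) ** 2 for xi, wi in zip(x, w))    # ||x - w||^2
    return dot ** 2 / (sq_dist + eps)


print(yat([1.0, 0.0], [0.0, 1.0]))           # orthogonal: numerator is 0
print(yat([1.0, 2.0], [2.0, 1.0], eps=0.0))  # 4^2 / 2
```

Orthogonal vectors score exactly 0, and when `x == w` the denominator shrinks to `eps` alone, which is why ε has to be there at all.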
The frameworks differ, but in all of them the op decomposes into the same primitives: a matmul for x · w, an element-wise square of the result, a squared distance ‖x − w‖² with ε broadcast-added, and an element-wise divide. Our cross-framework parity tests (tests/test_cross_framework_parity*.py) show < 10⁻⁶ max element-wise difference in fp32, so identical initialization plus identical mini-batches would give near-identical losses, up to floating-point reduction order. The 0.4-pp accuracy spread among the PyTorch and JAX runs (1.31 pp once MLX is included) is dominated by per-framework PRNG seeding and shuffle order, not by any Yat-side difference.
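The decomposition can be sketched for a whole layer in plain Python. The sketch uses the standard identity ‖x − w‖² = ‖x‖² + ‖w‖² − 2 x · w so the squared distance reuses the matmul's dot products; whether each nmn backend uses exactly this expansion is an assumption here. Comparing it with the direct formula is a toy version of a parity check, run in float64 rather than the fp32 the real tests use, and the layer sizes are arbitrary.

```python
import random


def yat_direct(x, w, eps):
    """Reference: compute the formula literally for one weight row."""
    dot = sum(a * b for a, b in zip(x, w))
    dist = sum((a - b) ** 2 for a, b in zip(x, w))
    return dot ** 2 / (dist + eps)


def yat_layer_expanded(x, W, eps):
    """One output per weight row, computed the matmul-friendly way."""
    x_sq = sum(a * a for a in x)                  # ||x||^2, shared by all rows
    out = []
    for w in W:
        dot = sum(a * b for a, b in zip(x, w))    # one entry of x @ W.T
        w_sq = sum(b * b for b in w)              # ||w||^2
        dist = x_sq + w_sq - 2.0 * dot            # ||x - w||^2 via expansion
        out.append(dot ** 2 / (dist + eps))       # square, add eps, divide
    return out


rng = random.Random(0)
x = [rng.gauss(0.0, 1.0) for _ in range(16)]
W = [[rng.gauss(0.0, 1.0) for _ in range(16)] for _ in range(8)]
eps = 1e-5

direct = [yat_direct(x, w, eps) for w in W]
expanded = yat_layer_expanded(x, W, eps)
max_rel_diff = max(abs(a - b) / max(abs(a), 1e-12)
                   for a, b in zip(direct, expanded))
print(f"max relative difference: {max_rel_diff:.2e}")
```

The two paths differ only in how the denominator is summed, so any mismatch is floating-point reduction order, the same caveat the parity tests carry.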
What about TensorFlow and Keras?
We ship the same MNIST script for nmn.tf and nmn.keras at src/nmn/tf/examples/mnist.py and src/nmn/keras/examples/mnist.py. They were not run for this post because TensorFlow has no wheel for the CPython 3.14 we tested on; see tensorflow.org/install for supported Python versions. The scripts mirror the PyTorch / NNX / Linen ones exactly, so on Python 3.11/3.12 you should see numbers in the same ballpark; please open an issue with your run if it lands meaningfully outside it.
Reproducibility
From a fresh checkout:
# PyTorch
PYTHONPATH=src python -m nmn.torch.examples.vision.mnist \
--epochs 3 --report .context/torch_mnist.json
# Flax NNX
PYTHONPATH=src python -m nmn.nnx.examples.vision.mnist \
--epochs 3 --report .context/nnx_mnist.json
# Flax Linen
PYTHONPATH=src python -m nmn.linen.examples.mnist \
--epochs 3 --report .context/linen_mnist.json
# MLX (Apple Silicon, --device gpu or cpu)
PYTHONPATH=src python -m nmn.mlx.examples.mnist \
--epochs 3 --device gpu --report .context/mlx_mnist.json
# TensorFlow (needs tensorflow + tensorflow-datasets)
PYTHONPATH=src python -m nmn.tf.examples.mnist \
--epochs 3 --report .context/tf_mnist.json
# Keras 3 (set KERAS_BACKEND first)
KERAS_BACKEND=jax PYTHONPATH=src python -m nmn.keras.examples.mnist \
--epochs 3 --report .context/keras_mnist.json
Every run drops a JSON receipt with the full per-epoch history, so the numbers above are auditable, not narrated.
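Because the receipts are plain JSON, a few lines of Python can collect them side by side. This sketch assumes only the `.context/<framework>_mnist.json` naming used above; since the receipt schema isn't spelled out here, it prints each receipt's top-level keys rather than guessing at field names.

```python
import json
from pathlib import Path


def load_receipts(directory=".context"):
    """Map framework name -> parsed receipt for every *_mnist.json file."""
    receipts = {}
    for path in sorted(Path(directory).glob("*_mnist.json")):
        framework = path.name[: -len("_mnist.json")]  # "torch_mnist.json" -> "torch"
        with path.open() as f:
            receipts[framework] = json.load(f)
    return receipts


if __name__ == "__main__":
    for name, receipt in load_receipts().items():
        # Show the structure instead of assuming specific field names.
        keys = sorted(receipt) if isinstance(receipt, dict) else type(receipt).__name__
        print(f"{name}: {keys}")
```

Once you've seen the keys, it's a small step to pull out the per-epoch history and rebuild the tables above yourself.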
Takeaway
Yat is framework-portable in practice, not just in theory. The implementation parity tests said it should be; this post is the experimental receipt. Pick the framework that fits your stack: the model you build will train the same way.