Metadata-Version: 2.3
Name: spectrux
Version: 0.0.1
Summary: A JAX-only neural-network library with a PyTorch-shaped eager surface and Flax-NNX-style graph/state underneath.
Keywords: Deep Learning,Machine Learning,Neural Networks,JAX,XLA
Author: Erfan Zare Chavoshi
Author-email: Erfan Zare Chavoshi <Erfanzare810@gmail.com>
License: Apache-2.0
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Requires-Dist: jax>=0.9.2
Requires-Dist: jaxlib>=0.9.2
Requires-Dist: numpy>=1.26
Requires-Dist: treescope>=0.1.7
Requires-Dist: optax>=0.2.8 ; extra == 'contrib'
Requires-Dist: jax[cuda13]>=0.9.2 ; extra == 'cuda'
Requires-Dist: ruff>=0.6 ; extra == 'dev'
Requires-Dist: pytest>=7 ; extra == 'dev'
Requires-Dist: sphinx>=7 ; extra == 'docs'
Requires-Dist: sphinx-book-theme>=1.1 ; extra == 'docs'
Requires-Dist: sphinx-design>=0.6 ; extra == 'docs'
Requires-Dist: sphinx-autodoc-typehints>=2 ; extra == 'docs'
Requires-Dist: myst-nb>=1 ; extra == 'docs'
Requires-Dist: pytest>=7 ; extra == 'test'
Requires-Dist: jax[tpu]>=0.9.2 ; extra == 'tpu'
Requires-Python: >=3.11, <3.14
Project-URL: Homepage, https://github.com/erfanzar/spectrux
Project-URL: Repository, https://github.com/erfanzar/spectrux
Provides-Extra: contrib
Provides-Extra: cuda
Provides-Extra: dev
Provides-Extra: docs
Provides-Extra: test
Provides-Extra: tpu
Description-Content-Type: text/markdown

# spectrux

**True MPMD pipeline parallelism for JAX — with an eager module API.**

[**Quick Start**](#quick-start) |
[**Installation**](#installation) |
[**MPMD Runtime**](#mpmd-runtime) |
[**Examples**](#examples) |
[**Docs**](docs/)

---

Spectrux is a JAX-native neural network library built around **true MPMD
pipeline parallelism**. Each physical rank compiles and runs its own XLA
program — no shared `shard_map` HLO, no SPMD-same-shape constraint.
It supports heterogeneous stages (embed → blocks → head), multiple
schedules (GPipe, 1F1B, ZeroBubble, Interleaved, DualPipe), and a unified
`spx.run()` entry point that dispatches to SPMD or MPMD from the same
training script.

The module API is eager and debuggable — subclass `Module`, override
`forward`, call `model(x)` — but every `Module` is a JAX pytree, so
`jax.jit`, `jax.grad`, and `jax.tree.map` work directly.

---

## Why spectrux?

| Capability               | spectrux                                | JAX + manual        | other JAX frameworks       |
| ------------------------ | --------------------------------------- | ------------------- | -------------------------- |
| **True MPMD**            | built-in (`sxcall`, `sxjit`)            | hand-rolled         | SPMD-only                  |
| **Heterogeneous stages** | native (different class/shape per rank) | fragile             | not supported              |
| **Pipeline schedules**   | 9 schedules (GPipe→DualPipeV)           | hand-rolled         | limited                    |
| **Unified runtime**      | `spx.run(mesh)` → SPMD or MPMD          | separate code paths | separate APIs              |
| **Eager modules**        | `model(x)` + pytree-native              | functional only     | functional or ref-tracking |
| **Dispatch overhead**    | ~150 µs                                 | N/A                 | ~300–2000 µs               |

*Dispatch overhead measured on a tiny 2-layer CPU transformer. See [benchmarks](#benchmarks).*

---

## Quick Start

```bash
pip install spectrux
```

### Single-device eager training

```python
import jax.numpy as jnp
import spectrux as spx
from spectrux import nn

class MLP(spx.Module):
    def __init__(self, d, h, o, *, rngs):
        super().__init__()
        self.fc1 = nn.Linear(d, h, rngs=rngs)
        self.fc2 = nn.Linear(h, o, rngs=rngs)

    def forward(self, x):
        return self.fc2(nn.gelu(self.fc1(x)))

model = MLP(16, 64, 4, rngs=spx.Rngs(0))

@spx.jit
@spx.value_and_grad
def loss_fn(m, x, y):
    return ((m(x) - y) ** 2).mean()

loss, grads = loss_fn(model, jnp.ones((8, 16)), jnp.zeros((8, 4)))
```
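
Because `grads` is itself a pytree of arrays, standard JAX tree utilities apply to it directly. A minimal illustration (the global-norm clipping below is a hypothetical snippet, not a spectrux API):

```python
import jax

# Global gradient norm over every leaf of the grads pytree.
gnorm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree.leaves(grads)))

# Simple global-norm clipping (max norm 1.0), applied leaf-by-leaf.
scale = jnp.minimum(1.0, 1.0 / (gnorm + 1e-6))
clipped = jax.tree.map(lambda g: g * scale, grads)
```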

### Marker-based MPMD — split a function into per-rank programs

```python
from spectrux.runtime.mpmd import sxjit, sxstage_iter
from spectrux.pipeline import Std1F1B

# Define a multi-stage forward with explicit stage boundaries
@sxjit(schedule=Std1F1B(microbatches=8))
def pipeline_fwd(model, x):
    x = model.embed(x)                # stage 0
    x = sxstage_iter(x)               # boundary — split here
    x = model.blocks[0](x)            # stage 1
    x = sxstage_iter(x)               # boundary — split here
    x = model.blocks[1](x)            # stage 2
    x = sxstage_iter(x)               # boundary — split here
    return model.head(x)              # stage 3

# sxjit traces the function, splits the jaxpr at sxstage_iter markers,
# and compiles one XLA executable per stage/rank. Each rank only sees
# its own sub-graph — true MPMD, not SPMD.
output = pipeline_fwd(model, x)
```

### MPMD pipeline training — one call, multiple devices

```python
from spectrux.sharding import logical_axis_rules

# Create a 4-stage pipeline mesh
mesh = spx.create_mesh(axis_dims=(2, 1, -1, 1, 1, 1), mpmd_axis="pp")

with logical_axis_rules(FSDP_TP_RULES), mesh:
    # MPMD: each rank gets its own compiled stage
    loss, grads = spx.run(
        model, inputs=x, targets=y,
        mesh=mesh, mode="train", loss_fn=ce_loss, microbatches=8,
    )

# Drop mpmd_axis → same code runs under pure SPMD pjit
# Same model, same script, different mesh — no code changes.
```

### Deferred initialization — infer shapes at runtime

```python
model = nn.Sequential(
    nn.Linear(None, 256, rngs=rngs),   # in_features inferred from first call
    nn.ReLU(),
    nn.Linear(256, 10, rngs=rngs),
)
y = model(jnp.zeros((8, 128)))         # weight shapes resolved here
```

---

## MPMD Runtime

Spectrux implements **true MPMD**: each physical rank compiles and executes
its own distinct JAX program. This is not SPMD-with-barriers — stages can
have different classes, different parameter shapes, and different I/O shapes.

### `spx.run` — unified entry point

`spx.run` routes to SPMD (`pjit`) or MPMD (`sxcall`) based on the mesh:

```python
spx.run(
    model,
    inputs=x,                # microbatched along leading axis
    targets=y,
    mesh=mesh,               # SpxMesh — mpmd_axis decides the path
    mode="train",            # "forward" | "train"
    loss_fn=ce_loss,
    microbatches=8,
)
```

| Mesh type        | What happens                                             |
| ---------------- | -------------------------------------------------------- |
| `mpmd_axis=None` | Pure SPMD — `pjit` with FSDP/TP via logical axis rules   |
| `mpmd_axis="pp"` | True MPMD — auto-split into stages, per-rank compilation |
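
A sketch of the two paths side by side, reusing the `create_mesh` call from the training example above (the exact `axis_dims` layout and the omission of `mpmd_axis` for the SPMD case are assumptions):

```python
# MPMD: naming a pipeline axis makes spx.run split the model into stages
# and compile one executable per rank.
mpmd_mesh = spx.create_mesh(axis_dims=(2, 1, -1, 1, 1, 1), mpmd_axis="pp")

# SPMD: no mpmd_axis, so the same call lowers to a single pjit program
# with FSDP/TP decided by the logical axis rules.
spmd_mesh = spx.create_mesh(axis_dims=(2, 1, -1, 1, 1, 1))

loss, grads = spx.run(
    model, inputs=x, targets=y,
    mesh=spmd_mesh, mode="train", loss_fn=ce_loss, microbatches=8,
)
```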

### Lower-level primitives

For full control, drop below `spx.run`:

| Primitive                     | Purpose                                                                                                  |
| ----------------------------- | -------------------------------------------------------------------------------------------------------- |
| `sxcall`                      | Execute a `PipelineSequential` under a schedule — Python dispatch loop over pre-built per-rank callables |
| `sxjit`                       | Decorator: trace a function, split at `sxstage_iter` markers, compile one XLA executable per stage/rank  |
| `sxgrad` / `sxvalue_and_grad` | Schedule-faithful gradients of an `sxjit` function                                                       |
| `treduce`                     | Schedule-driven microbatch reduction primitive — binds a body + schedule into the traced jaxpr           |
| `sxstage_iter`                | Marker primitive for stage boundaries inside `sxjit`                                                     |
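
For instance, a schedule-faithful gradient of an `sxjit` pipeline might look like the following hedged sketch (the import path and the assumption that `sxvalue_and_grad` mirrors `jax.value_and_grad` are not confirmed API; consult the runtime docs for the exact signature):

```python
from spectrux.runtime.mpmd import sxjit, sxstage_iter, sxvalue_and_grad
from spectrux.pipeline import Std1F1B

# Hypothetical two-stage pipeline that ends in a scalar loss, so it can be
# differentiated directly under the schedule.
@sxjit(schedule=Std1F1B(microbatches=8))
def pipeline_loss(model, x, y):
    h = model.embed(x)
    h = sxstage_iter(h)              # stage boundary
    logits = model.head(h)
    return ((logits - y) ** 2).mean()

loss, grads = sxvalue_and_grad(pipeline_loss)(model, x, y)
```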

### Supported schedules

| Schedule        | Type    | Key trait                                                                 |
| --------------- | ------- | ------------------------------------------------------------------------- |
| `GPipe`         | Flat    | All forwards, then all backwards. Simple, high memory.                    |
| `Std1F1B`       | Flat    | Standard 1-forward-1-backward. Peak memory `O(n_stages)`.                 |
| `ZeroBubbleH1`  | Flat    | Splits BWD into input-grad + weight-grad; weight-grad fills bubble slots. |
| `InterleavedH1` | Virtual | Each rank owns `v` non-contiguous stages; bubble shrinks by a factor of `v`. |
| `DualPipeV`     | Virtual | V-shaped bidirectional pipeline (DeepSeek-style).                         |
| `KimiK2`        | Virtual | Interleaved with extra warmup (Moonshot K2 design).                       |
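
Schedules are plain objects handed to `sxjit`, so swapping one is a one-line change — a hedged sketch, assuming every schedule class accepts `microbatches` the way `Std1F1B` does:

```python
from spectrux.pipeline import GPipe, ZeroBubbleH1
from spectrux.runtime.mpmd import sxjit, sxstage_iter

# One pipeline body, two schedules: only the decorator argument changes.
def two_stage(model, x):
    h = sxstage_iter(model.embed(x))   # stage 0 | stage 1 boundary
    return model.head(h)

gpipe_fwd = sxjit(schedule=GPipe(microbatches=4))(two_stage)
zb_fwd = sxjit(schedule=ZeroBubbleH1(microbatches=4))(two_stage)
```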

### Heterogeneous stages

No same-shape constraint. Stages can be completely different:

```python
model = PipelineSequential(
    EmbedStage(vocab, d, rngs=rngs),       # (B, S) int → (B, S, d)
    BlockStage(d, rngs=rngs),              # (B, S, d) → (B, S, d)
    BlockStage(d, rngs=rngs),              # (B, S, d) → (B, S, d)
    HeadStage(d, vocab, rngs=rngs),        # (B, S, d) → (B, S, vocab)
)
```

Auto-splitting is available for homogeneous stacks: `spx.run` detects
`model.blocks: ModuleList` and slices it evenly across pipeline stages.
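
A sketch of the model shape that auto-splitting expects — the `blocks` attribute is the part that matters; the layer choices and the `nn.ModuleList` constructor call here are illustrative:

```python
class StackedMLP(spx.Module):
    def __init__(self, d, n_layers, *, rngs):
        super().__init__()
        self.embed = nn.Linear(d, d, rngs=rngs)
        # A homogeneous stack: spx.run slices this ModuleList evenly across
        # pipeline stages when mpmd_axis is set on the mesh.
        self.blocks = nn.ModuleList([nn.Linear(d, d, rngs=rngs) for _ in range(n_layers)])
        self.head = nn.Linear(d, d, rngs=rngs)

    def forward(self, x):
        x = self.embed(x)
        for blk in self.blocks:
            x = nn.gelu(blk(x))
        return self.head(x)
```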

---

## Installation

```bash
pip install spectrux

# Optional extras
pip install "spectrux[contrib]"   # optax integration
pip install "spectrux[cuda]"      # CUDA jaxlib
pip install "spectrux[tpu]"       # TPU jaxlib
```

From source:

```bash
uv sync --extra dev --extra test --extra contrib
```

Requires Python 3.11–3.13 and JAX >= 0.9.2.

---

## Features

### MPMD Pipeline Parallelism

True multi-program multi-data execution with per-rank compilation,
schedule-faithful dispatch, and heterogeneous stage support.

### Module-aware JAX transforms

| Transform                                                      | What it does                                           |
| -------------------------------------------------------------- | ------------------------------------------------------ |
| `spx.jit`                                                      | `mutable=` selector declares writable collections      |
| `spx.grad` / `spx.value_and_grad`                              | `wrt=` selector picks the differentiated subset        |
| `spx.vmap`                                                     | module states passed with `in_axes=None` automatically |
| `spx.scan` / `spx.remat_scan`                                  | module-aware loops, optionally checkpointed            |
| `spx.remat`                                                    | gradient checkpointing                                 |
| `spx.cond` / `spx.switch` / `spx.while_loop` / `spx.fori_loop` | control flow over module state                         |
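
As a small illustration of the table above, `spx.vmap` can map a per-example loss over the batch while the module itself is broadcast (the explicit `in_axes` shown here is an assumption; per the table, spectrux handles module state automatically):

```python
def per_example_loss(m, x, y):
    return ((m(x) - y) ** 2).mean()

# Map over the leading axis of x and y; the module is broadcast (in_axes=None).
per_example = spx.vmap(per_example_loss, in_axes=(None, 0, 0))
losses = per_example(model, jnp.ones((32, 16)), jnp.zeros((32, 4)))   # shape (32,)
```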

### Selector DSL

One composable predicate for every "subset of the model" API:

```python
spx.grad(loss, wrt="params")                     # by collection name
spx.grad(loss, wrt=nn.LoraParameter)             # by Variable class
spx.grad(loss, wrt=spx.path_contains("attn"))    # by path glob

sel = spx.select().at_instances_of(nn.Linear).of_type(spx.Parameter) - spx.path_contains("head")
trainable, frozen = sel.partition_state(model, state)
```

### Sharding & SPMD

Annotate variables with logical axis names and let the mesh decide:

```python
w = spx.Parameter(
    jnp.zeros((256, 256)),
    sharding=spx.Sharding(("data", "model")),
    axis_names=("in", "out"),
)
```

### LoRA Fine-Tuning

```python
base = nn.Linear(768, 768, rngs=spx.Rngs(0))
model = nn.wrap_lora(base, rank=8, alpha=16, rngs=spx.Rngs(1))

@spx.jit
@spx.grad(wrt="lora")
def step(m, x, y):
    return ((m(x) - y) ** 2).mean()
```

### FP8 Training

Delayed-scaling per-tensor FP8 with rolling amax history:

```python
@spx.jit(mutable="fp8_meta")
def step(m, x, y):
    def loss(m, x, y):
        return ((m(x) - y) ** 2).mean()
    return spx.grad(loss)(m, x, y)
```

### Explicit graph / state seam

```python
gdef, state = spx.export(model)          # immutable GraphDef + State dict
model2 = spx.bind(gdef, state)           # reconstruct, skips __init__
spx.update(model, state)                 # in-place state patch
clone = spx.clone(model)                 # deep-copy via export+bind
```
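
Because the exported `state` is an ordinary pytree, it can be transformed before rebinding — e.g. a hedged dtype-cast sketch (assumes every leaf of the State is a JAX array):

```python
import jax
import jax.numpy as jnp

# Cast every array leaf to bfloat16, then rebuild an otherwise identical model.
bf16_state = jax.tree.map(lambda a: a.astype(jnp.bfloat16), state)
model_bf16 = spx.bind(gdef, bf16_state)
```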

### Inspection

```python
spx.inspect.summary(model, jnp.zeros((1, 128)))
spx.inspect.count_params(model)
spx.inspect.count_bytes(model)
spx.inspect.tabulate(model, example_input)
```

---

## Examples

See [`examples/`](examples/) for runnable scripts:

| Folder                                                          | Topic                                                               |
| --------------------------------------------------------------- | ------------------------------------------------------------------- |
| [`01_basics/`](examples/01_basics/)                             | Modules, training loops, export/bind, optimizers                    |
| [`02_implementation_guide/`](examples/02_implementation_guide/) | Llama 3, Qwen 2, GPT-2, ViT, custom transformer                     |
| [`03_transformations/`](examples/03_transformations/)           | jit, grad, vmap, remat, scan, fori_loop                             |
| [`04_surgery/`](examples/04_surgery/)                           | Selectors, LoRA injection, FP8, freezing, swapping                  |
| [`05_shardings/`](examples/05_shardings/)                       | FSDP, TP, hybrid sharding, logical axis rules                       |
| [`06_spmd_scheduled/`](examples/06_spmd_scheduled/)             | Pipeline runtime with all schedules                                 |
| [`07_mpmd/`](examples/07_mpmd/)                                 | Real MPMD pipeline via `spx.run` — train, forward, decode, 3-D mesh |

```bash
python -m examples.01_basics.02_training_loop
python -m examples.02_implementation_guide.01_llama3
python -m examples.07_mpmd.01_train_homogeneous
```

---

## Design

1. **True MPMD first** — `sxcall` compiles one XLA program per physical rank.
   Stages can differ in class, shape, and parameters. No SPMD-same-shape
   constraint.
2. **Unified runtime** — `spx.run` dispatches on the mesh. Same model, same
   script; change the mesh and you change the parallelism strategy.
3. **Schedule-faithful execution** — the dispatch loop walks the schedule grid
   exactly as planned. No hidden reordering, no implicit fusing.
4. **Modules are JAX pytrees** — flatten/unflatten through `export`/`bind`.
   `jax.jit`, `jax.tree.map`, `jax.value_and_grad` work directly.
5. **State lives in `Variable` cells** — `Parameter`, `Buffer`, custom
   subclasses. Each tags its collection (`params`, `buffers`, `lora`,
   `fp8_meta`, ...).
6. **One filter DSL everywhere** — `Selector` serves `grad(wrt=...)`,
   `jit(mutable=...)`, `Optimizer(wrt=...)`, `freeze(...)`,
   `iter_variables(select=...)`.

---

## Benchmarks

```bash
python -m benchmarks.bench --cases all --device cpu
python -m benchmarks.bench --cases large --device tpu     # 1.21B transformer
```

Results land in `benchmarks/results/latest.{json,md}` with per-case
`spectrux_ms / nnx_ms` ratios.

On a tiny, dispatch-bound CPU benchmark (a 2-layer, d=64, batch-2 transformer),
spectrux runs at **1.83×** the speed of flax.nnx; at d=48 it reaches **2.0×**.
On compute-bound workloads (an 8B model on TPU) the Python-dispatch gap shrinks
but remains in spectrux's favour.

---

## Testing

```bash
pytest tests/ -q
pytest tests/test_conformance.py
```

---

## Status

`v0.0.1` — alpha. API may still move; pin the version if you depend on
behavioural stability.

---

## License

Apache-2.0.
