Metadata-Version: 2.4
Name: optimuon
Version: 0.1.1
Summary: A performance-optimized Muon optimizer with foreach support, auto-routing, and composite optimizer patterns.
Project-URL: Repository, https://github.com/emaballarin/optimuon
Project-URL: Issues, https://github.com/emaballarin/optimuon/issues
Author-email: Emanuele Ballarin <emanuele@ballarin.cc>
License-Expression: MIT
License-File: LICENSE
Keywords: deep-learning,muon,newton-schulz,optimizer,pytorch
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3.14
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Typing :: Typed
Requires-Python: >=3.11
Requires-Dist: torch>=2.7
Description-Content-Type: text/markdown

# optimuon

A performance-optimized [Muon](https://kellerjordan.github.io/posts/muon/) optimizer for PyTorch.

**Features:**

- **Foreach-native**: uses `torch._foreach_*` ops for momentum, weight decay, and parameter updates.
- **Batched Newton-Schulz**: groups matrices by shape for parallel orthogonalization.
- **Auto-parameter routing**: automatically partitions model parameters into Muon-eligible (≥2D hidden weights) and auxiliary (embeddings, heads, norms, biases).
- **Composite optimizer**: `CompositeMuon` combines Muon with any auxiliary optimizer (not just AdamW).
- **Three LR modes**: Keller Jordan's `"original"` (with aspect-ratio scaling), Moonshot AI's `"match_rms_adamw"`, and `"none"` (no scaling).
- **Momentum conventions**: `"ema"` (`m = beta*m + (1-beta)*g`, default) and `"classical"` (`m = beta*m + g`).
- **Corrections**: MARS, cautious updates, cautious weight decay, NorMuon, gradient/update clipping (all toggleable).
- **Weight normalization**: optional Frobenius-norm clamping to `sqrt(fan_out)` (from Keller Jordan's original Muon).
- **Half-precision momentum**: optional lower-precision momentum buffers for memory savings.
- **Polar Express**: optimal per-step Newton-Schulz coefficients (default).
- **Distributed**: `torch.distributed` gradient sharding via `all_gather`.
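The batched Newton-Schulz idea (group matrices by shape, stack each group, orthogonalize the stack in one call) can be sketched as follows. This is a minimal sketch, not optimuon's code: `newton_schulz5` and `batched_orthogonalize` are hypothetical names, and it assumes Keller Jordan's fixed quintic coefficients rather than the Polar Express per-step coefficients that optimuon uses by default.

```python
from collections import defaultdict

import torch

# Keller Jordan's fixed quintic coefficients. Polar Express (optimuon's default)
# replaces these with per-step coefficients, which are not reproduced here.
NS_COEFFS = (3.4445, -4.7750, 2.0315)


def newton_schulz5(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Approximately orthogonalize a batch of matrices of shape (..., m, n)."""
    a, b, c = NS_COEFFS
    X = G.float()
    transposed = X.size(-2) > X.size(-1)
    if transposed:  # iterate on the wide orientation so X @ X.mT is the smaller Gram matrix
        X = X.mT
    # Per-matrix Frobenius normalization puts every singular value in [0, 1].
    X = X / (torch.linalg.matrix_norm(X, keepdim=True) + eps)
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X.mT if transposed else X


def batched_orthogonalize(grads: list[torch.Tensor], steps: int = 5) -> list[torch.Tensor]:
    """Group same-shape matrices, stack each group, and run one batched NS call per group."""
    groups: dict[tuple[int, ...], list[int]] = defaultdict(list)
    for i, g in enumerate(grads):
        groups[tuple(g.shape)].append(i)
    out: list[torch.Tensor | None] = [None] * len(grads)
    for idxs in groups.values():
        stacked = torch.stack([grads[i] for i in idxs])  # shape (k, m, n)
        for i, o in zip(idxs, newton_schulz5(stacked, steps).unbind(0)):
            out[i] = o.to(grads[i].dtype)
    return out
```

The point of grouping is to replace many small matmuls with one batched `@` per shape group. With these fixed coefficients the singular values of each output land roughly in [0.5, 1.5] rather than exactly at 1; the iteration trades exactness for speed.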

## Installation

```bash
uv pip install git+https://github.com/emaballarin/optimuon
```

## Quick start

### Standalone Muon (manual parameter selection)

```python
import torch
from optimuon import Muon

# `model` and `dataloader` are your own. Split the parameters yourself:
# `muon_params`: ≥2D hidden weight matrices only;
# `other_params`: embeddings, output head, norms, and biases.
muon = Muon(muon_params, lr=0.02, momentum=0.95, weight_decay=0.01)
adamw = torch.optim.AdamW(other_params, lr=3e-4)

# Training loop
for batch in dataloader:
    loss = model(batch).loss
    loss.backward()
    muon.step()
    adamw.step()
    muon.zero_grad()
    adamw.zero_grad()
```
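The `lr` and `momentum` arguments interact with the LR mode and momentum convention listed under Features. The following pure-Python sketch shows the published scaling rules and the two momentum updates; `muon_lr_scale` and `momentum_step` are hypothetical helpers, and optimuon's internal implementation may differ in detail.

```python
import math


def muon_lr_scale(fan_out: int, fan_in: int, mode: str = "original") -> float:
    """Hypothetical helper: LR multiplier for a (fan_out, fan_in) weight matrix."""
    if mode == "original":
        # Keller Jordan: scale by sqrt(fan_out / fan_in) for tall matrices, never below 1.
        return math.sqrt(max(1.0, fan_out / fan_in))
    if mode == "match_rms_adamw":
        # Moonshot AI: an orthogonalized A x B update has RMS ~ 1/sqrt(max(A, B)),
        # so 0.2 * sqrt(max(A, B)) targets an update RMS of about 0.2, as with AdamW.
        return 0.2 * math.sqrt(max(fan_out, fan_in))
    if mode == "none":
        return 1.0
    raise ValueError(f"unknown LR mode: {mode!r}")


def momentum_step(m: float, g: float, beta: float, convention: str = "ema") -> float:
    """Hypothetical helper: one momentum-buffer update for a single scalar entry."""
    if convention == "ema":
        return beta * m + (1 - beta) * g
    if convention == "classical":
        return beta * m + g
    raise ValueError(f"unknown momentum convention: {convention!r}")


# Feed a constant gradient of 1.0 for 200 steps with beta = 0.95.
m_ema = m_classical = 0.0
for _ in range(200):
    m_ema = momentum_step(m_ema, 1.0, 0.95, "ema")
    m_classical = momentum_step(m_classical, 1.0, 0.95, "classical")
```

For a 4096 x 1024 weight, `"original"` gives a multiplier of 2.0 and `"match_rms_adamw"` gives 12.8. Under a constant gradient the EMA buffer settles near 1.0, while the classical buffer approaches 1 / (1 - beta) = 20.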

### CompositeMuon with auto-routing (recommended)

```python
import torch
from optimuon import CompositeMuon

optimizer = CompositeMuon(
    model,
    muon_lr=0.02,
    muon_kwargs={"weight_decay": 0.01, "foreach": True},
    aux_optimizer_class=torch.optim.AdamW,
    aux_optimizer_kwargs={"lr": 3e-4, "betas": (0.9, 0.95), "weight_decay": 0.01},
    verbose=True,
)

for batch in dataloader:
    loss = model(batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

### With corrections

```python
import torch
from optimuon import CompositeMuon

optimizer = CompositeMuon(
    model,
    muon_lr=0.02,
    muon_kwargs={
        "weight_decay": 0.01,
        "mars": True,           # MARS gradient correction
        "cautious": True,       # cautious update masking
        "grad_clip": 1.0,       # gradient norm clipping
        "weight_norm": True,    # Frobenius-norm clamping
    },
    aux_optimizer_class=torch.optim.AdamW,
    aux_optimizer_kwargs={"lr": 3e-4},
)
```
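The `cautious` flag follows the Cautious Optimizers idea: an update component is kept only where its sign agrees with the current gradient. The following pure-Python sketch on flat lists assumes that survivors are rescaled by the fraction of entries kept. This is one common normalization; optimuon's exact rescaling may differ, and `cautious_mask` is a hypothetical name.

```python
def cautious_mask(update: list[float], grad: list[float], min_frac: float = 1e-3) -> list[float]:
    """Hypothetical helper: zero update entries whose sign disagrees with the gradient.

    `update` is the step direction that will be subtracted from the
    parameters (p -= lr * update), so agreement means update * grad > 0.
    """
    mask = [1.0 if u * g > 0 else 0.0 for u, g in zip(update, grad)]
    # Rescale kept entries by 1 / (fraction kept) so the average step size is
    # roughly preserved; min_frac guards against dividing by zero.
    kept_frac = max(sum(mask) / len(mask), min_frac)
    return [u * m / kept_frac for u, m in zip(update, mask)]


masked = cautious_mask([0.5, -0.2, 0.3, 0.1], [1.0, 0.4, -0.1, 0.2])
```

Here the second and third entries disagree in sign with the gradient and are zeroed. Half of the entries survive, so the kept ones are doubled to 1.0 and 0.2.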

### With a custom auxiliary optimizer

```python
from optimuon import CompositeMuon

optimizer = CompositeMuon(
    model,
    muon_lr=0.02,
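    # `SomeExoticOptimizer` is a placeholder for any optimizer class
    # that takes the parameter groups as its first argument.
    aux_optimizer_factory=lambda param_groups: SomeExoticOptimizer(param_groups, lr=1e-3),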
)
```
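Judging from the lambda above, any callable that takes the auxiliary parameter groups and returns an optimizer should fit `aux_optimizer_factory`. One alternative is a module-level function bound with `functools.partial`. Unlike a lambda, it can be pickled and shows a readable name in logs. The sketch below uses a real `torch.optim` class inside a hypothetical `make_aux_optimizer`. It assumes optimuon passes either a list of tensors or a list of group dicts; `torch.optim` optimizers accept both.

```python
import functools

import torch


def make_aux_optimizer(param_groups, lr: float = 1e-3, momentum: float = 0.9) -> torch.optim.Optimizer:
    """Hypothetical factory: SGD with Nesterov momentum for the auxiliary parameters."""
    return torch.optim.SGD(param_groups, lr=lr, momentum=momentum, nesterov=True)


# Bind hyperparameters now; CompositeMuon would supply param_groups later.
aux_factory = functools.partial(make_aux_optimizer, lr=5e-4)
```

It would then be passed as `aux_optimizer_factory=aux_factory`.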

### Manual routing utilities

```python
from optimuon import partition_params

result = partition_params(model)
print(f"Muon: {result.muon_names}")
print(f"Aux:  {result.aux_names}")
```
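The routing rule described under Features sends ≥2D hidden weights to Muon and embeddings, heads, norms, and biases to the auxiliary optimizer. A name-and-shape heuristic can approximate this rule. The sketch below is hypothetical and is not the actual logic of `partition_params`, which may inspect module types instead. It works on `(name, shape)` pairs such as those you could build from `model.named_parameters()`.

```python
def route_params(
    named_shapes: list[tuple[str, tuple[int, ...]]],
    aux_keywords: tuple[str, ...] = ("embed", "head"),
) -> tuple[list[str], list[str]]:
    """Hypothetical heuristic: split parameter names into (muon, aux)."""
    muon: list[str] = []
    aux: list[str] = []
    for name, shape in named_shapes:
        is_matrix = len(shape) >= 2  # 1D tensors (norm scales, biases) go to aux
        is_io_layer = any(k in name.lower() for k in aux_keywords)  # embeddings / output heads
        (muon if is_matrix and not is_io_layer else aux).append(name)
    return muon, aux


shapes = [
    ("tok_embeddings.weight", (50304, 768)),
    ("layers.0.attn.wq.weight", (768, 768)),
    ("layers.0.attn.wq.bias", (768,)),
    ("layers.0.norm.weight", (768,)),
    ("lm_head.weight", (50304, 768)),
]
muon_names, aux_names = route_params(shapes)
```

Only `layers.0.attn.wq.weight` is routed to Muon here. Name matching is brittle: a hidden layer whose name happens to contain `head` would be misrouted. This is one reason a type-based rule is likely preferable in practice.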

## References

- Keller Jordan et al., [Muon: An optimizer for hidden layers in neural networks](https://kellerjordan.github.io/posts/muon/) (2024)
- Huizhuo Yuan et al., [MARS: Unleashing the Power of Variance Reduction for Training Large Models](https://arxiv.org/abs/2411.10438) (2024)
- Kaizhao Liang et al., [Cautious Optimizers: Improving Training with One Line of Code](https://arxiv.org/abs/2411.16085) (2024)
- Moonshot AI, [Muon is Scalable for LLM Training](https://arxiv.org/abs/2502.16982) (2025)
- Essential AI, [Practical Efficiency of Muon for Pretraining](https://arxiv.org/abs/2505.02222) (2025)
- Noah Amsel et al., [The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm](https://arxiv.org/abs/2505.16932) (2025)
- Zichong Li et al., [NorMuon: Making Muon more efficient and scalable](https://arxiv.org/abs/2510.05491) (2025)
- Lizhang Chen et al., [Cautious Weight Decay](https://arxiv.org/abs/2510.12402) (2025)

## License

MIT
