Metadata-Version: 2.4
Name: mps-spectro
Version: 0.3.0
Summary: Fast torch-compatible STFT and ISTFT on Apple MPS via custom Metal kernels
Author-email: ssmall256 <ssmall256@users.noreply.github.com>
License-Expression: MIT
Project-URL: Homepage, https://github.com/ssmall256/mps-spectro
Project-URL: Repository, https://github.com/ssmall256/mps-spectro
Project-URL: Issues, https://github.com/ssmall256/mps-spectro/issues
Keywords: pytorch,stft,istft,spectral,audio,dsp,apple-silicon,mps,metal
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: MacOS
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Multimedia :: Sound/Audio :: Analysis
Classifier: Topic :: Scientific/Engineering
Requires-Python: >=3.12
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.10.0
Provides-Extra: dev
Requires-Dist: pytest>=9.0; extra == "dev"
Requires-Dist: torchaudio>=2.10.0; extra == "dev"
Dynamic: license-file

# mps-spectro

Drop-in `torch.stft` / `torch.istft` replacements for Apple Silicon, plus mel frontends built on top of them — **1.4–3x faster** on MPS via custom Metal kernels.

```python
import torch

# before
spec = torch.stft(x, n_fft=2048, hop_length=512, window=window, center=True, return_complex=True)
y = torch.istft(spec, n_fft=2048, hop_length=512, window=window, center=True, length=T)

# after
from mps_spectro import stft, istft

spec = stft(x, n_fft=2048, hop_length=512)
y = istft(spec, n_fft=2048, hop_length=512, center=True, length=T)
```

Log-mel frontend for AMT / ASR-style models:

```python
from mps_spectro import LogMelSpectrogramTransform

frontend = LogMelSpectrogramTransform(
    sample_rate=16000,
    n_fft=2048,
    hop_length=512,
    n_mels=256,
    f_min=30.0,
    f_max=8000.0,
    pad_mode="constant",
    power=1.0,
    norm="slaney",
    mel_scale="htk",
    log_amin=1e-5,
    log_mode="clamp",
)

mel = frontend(x)
```
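To make the `mel_scale="htk"`, `norm="slaney"`, and `log_amin` settings concrete, here is a minimal NumPy sketch of an HTK-scale triangular filterbank with Slaney area normalization, followed by a clamped log. It follows the widely used torchaudio-style construction and is not taken from the `mps-spectro` source. The helper names (`hz_to_mel_htk`, `mel_to_hz_htk`, `mel_filterbank`, `log_clamp`) are hypothetical, and reading `log_mode="clamp"` as `log(max(x, log_amin))` is an assumption.

```python
import numpy as np

def hz_to_mel_htk(f):
    # HTK mel scale: m = 2595 * log10(1 + f / 700)
    return 2595.0 * np.log10(1.0 + np.asarray(f, dtype=np.float64) / 700.0)

def mel_to_hz_htk(m):
    # Inverse of the HTK mel scale
    return 700.0 * (10.0 ** (np.asarray(m, dtype=np.float64) / 2595.0) - 1.0)

def mel_filterbank(sample_rate, n_fft, n_mels, f_min, f_max, slaney_norm=True):
    """Return a [n_mels, n_freqs] triangular filterbank (hypothetical helper)."""
    n_freqs = n_fft // 2 + 1
    fft_freqs = np.linspace(0.0, sample_rate / 2.0, n_freqs)
    # n_mels + 2 band edges, equally spaced on the mel axis
    mel_edges = np.linspace(hz_to_mel_htk(f_min), hz_to_mel_htk(f_max), n_mels + 2)
    edges = mel_to_hz_htk(mel_edges)
    widths = np.diff(edges)                        # [n_mels + 1]
    slopes = edges[:, None] - fft_freqs[None, :]   # [n_mels + 2, n_freqs]
    rising = -slopes[:-2] / widths[:-1, None]      # ramp up from the left edge
    falling = slopes[2:] / widths[1:, None]        # ramp down to the right edge
    fb = np.maximum(0.0, np.minimum(rising, falling))
    if slaney_norm:
        # Slaney normalization: scale each triangle by 2 / bandwidth
        fb = fb * (2.0 / (edges[2:] - edges[:-2]))[:, None]
    return fb

def log_clamp(mel, amin=1e-5):
    # Assumed meaning of "clamp" mode: floor the input, then take the log
    return np.log(np.maximum(mel, amin))

# Usage with a random magnitude spectrogram (power=1.0 means magnitudes)
fb = mel_filterbank(16000, 2048, 64, 30.0, 8000.0)
magnitude = np.abs(np.random.default_rng(0).standard_normal((1025, 10)))
log_mel = log_clamp(fb @ magnitude)                # [n_mels, n_frames]
```

A matrix built this way has the `[n_mels, n_freqs]` layout, so the same idea applies wherever a precomputed filterbank is supplied from outside.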

Dynamic frontends for pitch models, with per-call `keyshift` / `speed` and support for externally supplied mel filterbanks:

```python
from mps_spectro import DynamicMelSpectrogramTransform

frontend = DynamicMelSpectrogramTransform(
    sample_rate=16000,
    n_fft=1024,
    hop_length=160,
    win_length=1024,
    output_scale="log",
    log_amin=1e-5,
    mel_basis=external_mel_basis,  # optional [n_mels, n_freqs]
)

mel = frontend(x, keyshift=3, speed=1.2)
```
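In pitch-model mel extractors that expose `keyshift` and `speed`, a common convention is that `keyshift` (in semitones) rescales the FFT and window sizes by `2 ** (keyshift / 12)`, while `speed` rescales the hop length. The sketch below shows only that arithmetic. The helper name `dynamic_stft_sizes` is hypothetical, and it is an assumption that `DynamicMelSpectrogramTransform` derives its per-call sizes in exactly this way.

```python
def dynamic_stft_sizes(n_fft, win_length, hop_length, keyshift=0.0, speed=1.0):
    """Per-call STFT sizes under the assumed keyshift/speed convention."""
    factor = 2.0 ** (keyshift / 12.0)   # one semitone = a factor of 2^(1/12)
    return {
        "n_fft": int(round(n_fft * factor)),
        "win_length": int(round(win_length * factor)),
        "hop_length": int(round(hop_length * speed)),
    }

sizes = dynamic_stft_sizes(1024, 1024, 160, keyshift=3, speed=1.2)
print(sizes)  # {'n_fft': 1218, 'win_length': 1218, 'hop_length': 192}
```

Because the sizes change per call, a frontend of this kind cannot bake a single window or FFT length into its kernels, which is why it is treated separately from the fixed mel frontends.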

The `0.3.0` release line expands `mps-spectro` from fast STFT/iSTFT plus fixed mel frontends into a broader shared spectral frontend package:

- standard mel frontends for log, linear, dB, and compat-style outputs
- dynamic frontends for pitch models with per-call `keyshift` / `speed`
- optional external mel filterbank injection for exact project parity
- parity-oriented dynamic STFT mode when exact legacy wrapper behavior matters more than the lowest-level fast path

Drop-in compatible with [python-audio-separator](https://github.com/karaokenerds/python-audio-separator) (MDX, Roformer, Demucs) — **1.4x faster STFT** and **2x faster iSTFT** on stereo 44.1 kHz audio with no model changes. See [benchmarks](#stftistft-in-audio-separator-workloads) below.

## Install

```bash
pip install mps-spectro
```

## Features

- PyTorch-compatible STFT/ISTFT semantics (same parameters as `torch.stft` / `torch.istft`)
- PyTorch-native mel frontends built on top of the same spectral core
- Dynamic mel/spectrogram frontends for pitch models with `keyshift`, `speed`, and optional external mel bases
- Fused overlap-add with optimized Metal compute shaders
- Autograd support with custom Metal backward kernels
- `torch.compile` compatible (`aot_eager` backend) via `torch.library` custom ops
- Pure Python — no C++ build step, no Xcode CLI tools

## Validated downstream use cases

The current package surface has been benchmarked and parity-checked in several real consumer projects:

- `mamba_amt`: log-mel frontend replacement on MPS
- `python-audio-separator`: shared STFT/iSTFT compatibility layer
- `LinkSeg`: compat mel frontend replacing project-local frontend code
- `SongFormer-mps`: shared dB mel frontends for MusicFM and MuQ
- `RVMPE`: dynamic mel frontend with per-call `keyshift` / `speed`
- `torchfcpe`: dynamic spectrogram path for the MPS mel frontend patch

The most important takeaway is that `mps-spectro` now covers both:

- fixed frontend replacements for torchaudio-style mel paths
- dynamic frontend building blocks for pitch models that previously needed project-local MPS STFT patches

### Autograd

Both `stft` and `istft` support PyTorch autograd when inputs have `requires_grad=True`:

```python
x = torch.randn(4, 16000, device="mps", requires_grad=True)

spec = stft(x, n_fft=1024, hop_length=256)
y = istft(spec, n_fft=1024, hop_length=256, center=True, length=16000)

loss = y.pow(2).sum()
loss.backward()
print(x.grad.shape)  # torch.Size([4, 16000])
```

When `requires_grad=False` (the default), autograd adds zero overhead: the original Metal kernel path is used directly. Backward passes use custom Metal kernels for GPU-accelerated gradient computation. Window gradients are not computed (the backward pass returns `None` for the window), since windows are almost always frozen in practice.

### torch.compile

Custom ops are registered via `torch.library` with Meta (FakeTensor) kernels, so `torch.compile` can trace through both forward and backward:

```python
@torch.compile(backend="aot_eager")
def f(x):
    return stft(x, n_fft=2048, hop_length=512)

f(torch.randn(4, 160000, device="mps"))  # works
```

### ISTFT extras

`istft` also supports:

- `torch_like=True` -- raise on NOLA (nonzero overlap-add) violations, as `torch.istft` does
- `safety="auto"|"always"|"off"` -- NOLA envelope safety checking
- `kernel_dtype="float32"|"float16"|"mixed"` -- Metal kernel precision
- `kernel_layout="auto"|"native"|"transposed"` -- memory layout selection
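For context on what a NOLA check looks at: the overlap-added squared window (the envelope that iSTFT divides by) must stay away from zero over the output range. The NumPy sketch below illustrates that condition. It is not the library's implementation; `nola_envelope` and `satisfies_nola` are hypothetical names, and the `eps` threshold is an illustrative choice.

```python
import numpy as np

def nola_envelope(window, hop_length, n_frames):
    """Overlap-add window**2 across n_frames frames."""
    win_length = len(window)
    env = np.zeros(win_length + hop_length * (n_frames - 1))
    for m in range(n_frames):
        start = m * hop_length
        env[start:start + win_length] += window ** 2
    return env

def satisfies_nola(window, hop_length, n_frames, center=True, eps=1e-11):
    env = nola_envelope(window, hop_length, n_frames)
    if center:
        # center=True trims half a window of padding from each end
        pad = len(window) // 2
        env = env[pad:len(env) - pad]
    return bool(env.min() > eps)

n = 1024
hann = 0.5 - 0.5 * np.cos(2.0 * np.pi * np.arange(n) / n)  # periodic Hann

ok_overlap = satisfies_nola(hann, hop_length=256, n_frames=10)    # 4x overlap
gap_between = satisfies_nola(hann, hop_length=2048, n_frames=10)  # hop > window
```

With a hop larger than the window, some output samples are covered by no frame at all, so the envelope is zero there and the division in iSTFT is undefined.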

## Benchmarks

Apple M4 Max, macOS 26.3, PyTorch 2.10.0, 20 iterations (5 warmup).
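A warmup-then-measure loop matching the stated methodology (5 warmup iterations, 20 timed iterations) can be sketched with the standard library alone. This is an illustration, not the contents of `scripts/benchmark.py`; the `benchmark` helper is hypothetical, and reporting the median rather than the mean is an assumption.

```python
import statistics
import time

def benchmark(fn, iterations=20, warmup=5, sync=None):
    """Return the median wall time of fn() in milliseconds."""
    for _ in range(warmup):
        fn()                      # untimed: lets caches and compiled shaders settle
    if sync is not None:
        sync()                    # drain warmup work before timing starts
    times_ms = []
    for _ in range(iterations):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()                # wait for queued GPU work before stopping the clock
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)

# CPU-only usage example; on MPS one would pass sync=torch.mps.synchronize
median_ms = benchmark(lambda: sum(range(1000)))
```

The `sync` hook matters on GPU backends because kernel launches are asynchronous: without it the timer measures dispatch, not execution.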

### STFT Forward

| Config | torch MPS | mps_spectro | Speedup |
|---|--:|--:|--:|
| B=4 T=160k nfft=1024 | 0.51 ms | 0.35 ms | 1.5x |
| B=4 T=160k nfft=2048 | 0.53 ms | 0.31 ms | 1.7x |
| B=8 T=160k nfft=1024 | 0.78 ms | 0.46 ms | 1.7x |
| B=4 T=1.3M nfft=1024 | 1.93 ms | 1.38 ms | 1.4x |

### ISTFT Forward

| Config | torch MPS | mps_spectro | Speedup |
|---|--:|--:|--:|
| B=4 T=160k nfft=1024 | 1.10 ms | 0.34 ms | 3.2x |
| B=8 T=160k nfft=1024 | 1.70 ms | 0.63 ms | 2.7x |
| B=4 T=1.3M nfft=1024 | 6.01 ms | 2.30 ms | 2.6x |
| B=1 T=1.3M nfft=1024 | 1.76 ms | 0.61 ms | 2.9x |

### STFT Forward + Backward

| Config | torch MPS | mps_spectro | Speedup |
|---|--:|--:|--:|
| B=4 T=160k nfft=1024 | 1.51 ms | 1.05 ms | 1.4x |
| B=8 T=160k nfft=1024 | 2.96 ms | 2.08 ms | 1.4x |
| B=4 T=1.3M nfft=1024 | 12.75 ms | 9.73 ms | 1.3x |
| B=1 T=1.3M nfft=1024 | 2.95 ms | 2.16 ms | 1.4x |

### ISTFT Forward + Backward

| Config | torch MPS | mps_spectro | Speedup |
|---|--:|--:|--:|
| B=4 T=160k nfft=1024 | 1.91 ms | 0.98 ms | 1.9x |
| B=8 T=160k nfft=1024 | 2.95 ms | 1.62 ms | 1.8x |
| B=4 T=1.3M nfft=1024 | 9.95 ms | 5.71 ms | 1.7x |
| B=1 T=1.3M nfft=1024 | 2.95 ms | 1.56 ms | 1.9x |

### Roundtrip (STFT -> ISTFT) Forward + Backward

| Config | torch MPS | mps_spectro | Speedup |
|---|--:|--:|--:|
| B=4 T=160k nfft=1024 | 2.52 ms | 1.47 ms | 1.7x |
| B=8 T=160k nfft=1024 | 4.71 ms | 2.55 ms | 1.8x |
| B=4 T=1.3M nfft=1024 | 18.42 ms | 11.07 ms | 1.7x |
| B=1 T=1.3M nfft=1024 | 4.60 ms | 2.39 ms | 1.9x |

To reproduce:
- Full suite: `python scripts/benchmark.py`
- Dispatch overhead profile: `python scripts/benchmark.py --dispatch-profile`

### STFT/iSTFT in audio-separator workloads

[python-audio-separator](https://github.com/karaokenerds/python-audio-separator) uses `torch.stft`/`torch.istft` in its MDX, Roformer, and Demucs model pipelines. We swapped in `mps_spectro` via a [compatibility layer](https://github.com/karaokenerds/python-audio-separator/blob/main/audio_separator/separator/stft_compat.py) and measured the STFT/iSTFT portion of each pipeline with two real stereo 44.1 kHz tracks (267s and 195s). Apple M4 Max, PyTorch 2.10.0, 20 iterations, 5 warmup, 5s cooldown.

| Model config | STFT speedup | iSTFT speedup |
|---|--:|--:|
| MDX (n_fft=2048, hop=512) | **1.40x** | **2.03x** |
| Roformer (n_fft=2048, hop=512) | **1.40x** | **2.01x** |
| Demucs (n_fft=4096, hop=1024) | **1.28x** | **1.87x** |

Note: total separation wall time is dominated by model inference, so the end-to-end speedup is modest. The gains above apply to the STFT/iSTFT calls themselves.

To reproduce: `python scripts/benchmark_audio_separator.py`
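A compatibility layer of the kind mentioned above needs, at minimum, a rule for choosing a backend per device. The sketch below shows one such rule using only the standard library. It is not the code in `stft_compat.py`; `pick_spectral_backend` is a hypothetical helper, and the policy (use `mps_spectro` only for MPS tensors when it is installed) is an assumption.

```python
import importlib.util

def pick_spectral_backend(device_type, accelerated="mps_spectro", fallback="torch"):
    """Name the module whose stft/istft should serve tensors on device_type."""
    on_mps = device_type == "mps"
    # find_spec returns None for a top-level module that is not installed
    installed = importlib.util.find_spec(accelerated) is not None
    return accelerated if (on_mps and installed) else fallback

backend_for_cpu = pick_spectral_backend("cpu")   # always the fallback off-MPS
```

A caller would then import the chosen module with `importlib.import_module` and call its `stft` / `istft`. A real shim would also have to reconcile any argument differences between the two backends.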

**Numerical parity.** Output stems are perceptually identical. The maximum float32 difference per sample is ≤ 1.83 × 10⁻⁴ (≤ 6 int16 LSBs) for the Roformer and MDX models and 4.27 × 10⁻⁴ (about 14 int16 LSBs) for `hdemucs_mmi`:

| Model | Max abs diff (float32) | SNR (dB) | Int16 sample match |
|---|--:|--:|--:|
| BS-Roformer-SW (6-stem) | 3.05e-05 | 91 – 100 | ≥ 99.98% |
| Mel-Roformer Karaoke | 3.05e-05 | 89 – 91 | ≥ 99.84% |
| MDX-NET Inst HQ 5 | 1.83e-04 | 55 – 64 | ≥ 99%\* |
| hdemucs_mmi (shifts=0) | 4.27e-04 | 44 – 52 | ≥ 71% |

\* MDX int16 diffs are mostly ±1 LSB rounding noise, symmetric with zero bias, and never exceed ±6 LSBs.
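The three columns in the table above (max abs diff, SNR, int16 sample match) can be computed along the following lines. This is a sketch under stated assumptions, not the script that produced the numbers: `parity_report` is a hypothetical helper, and scaling by 32768 with clipping is an assumed int16 conversion (chosen because 3.05e-05 is one LSB at that scale).

```python
import numpy as np

def to_int16(x):
    # Assumed conversion: scale by 2**15, round, clip to the int16 range
    return np.clip(np.round(x * 32768.0), -32768, 32767).astype(np.int16)

def parity_report(reference, candidate):
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    error = candidate - reference
    noise_power = float(np.sum(error ** 2))
    signal_power = float(np.sum(reference ** 2))
    snr_db = float("inf") if noise_power == 0.0 else 10.0 * np.log10(signal_power / noise_power)
    return {
        "max_abs_diff": float(np.max(np.abs(error))),
        "snr_db": float(snr_db),
        "int16_match": float(np.mean(to_int16(reference) == to_int16(candidate))),
    }

# Synthetic usage: a sine wave plus small bounded noise
rng = np.random.default_rng(0)
t = np.arange(44100) / 44100.0
ref = 0.5 * np.sin(2.0 * np.pi * 440.0 * t)
cand = ref + rng.uniform(-1e-5, 1e-5, size=ref.shape)
report = parity_report(ref, cand)
```

Int16 sample match is the strictest of the three, since any difference that crosses a rounding boundary counts as a mismatch even when it is far below audibility.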

### Log-mel frontend in `mamba_amt`

On the `mamba_amt` log-mel frontend configuration (`16 kHz`, `n_fft=2048`, `hop=512`, `n_mels=256`, `pad_mode="constant"`, `power=1.0`, `norm="slaney"`, `mel_scale="htk"`), the new `LogMelSpectrogramTransform` was about `2.44x` faster than `torchaudio.transforms.MelSpectrogram` on MPS while staying numerically tight:

- torchaudio median: `0.00397 s`
- `mps-spectro` median: `0.00163 s`
- speedup: `2.44x`
- max abs diff: `1.14e-4`
- mean abs diff: `6.33e-6`

### Dynamic frontends in `RVMPE` and `torchfcpe`

The new dynamic frontend APIs were validated against the prior project-local MPS paths:

- `RVMPE` dynamic mel frontend:
  - old median: `1.436 ms`
  - new shared path: `1.182 ms`
  - speedup: `1.21x`
  - parity: max abs `4.77e-07`, mean abs `1.90e-08`

- `torchfcpe` dynamic spectrogram path:
  - old median: `3.347 ms`
  - new shared path: `3.249 ms`
  - speedup: `1.03x`
  - parity on realistic mel-style positive filterbanks stayed effectively exact, with max abs `3.58e-07` and mean abs `1.79e-08` after log compression

## Using MLX instead of PyTorch?

See [mlx-spectro](https://github.com/ssmall256/mlx-spectro) — same idea, built natively on MLX with even faster kernels (2–8x vs torch).

## How it works

Metal shader source is compiled at runtime via `torch.mps.compile_shader` (pure Python, no C++ build step).

1. **STFT**: a tiled Metal kernel loads overlapping signal chunks into threadgroup shared memory (~3x data reuse for typical n_fft/hop ratios) and applies reflect-padding and windowing in one pass; `torch.fft.rfft` then computes the FFT
2. **ISTFT**: `torch.fft.irfft` on MPS, then a fused Metal kernel for synthesis-window multiply + overlap-add + envelope normalization
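The two steps above can be written as an unfused NumPy reference, which is useful for seeing exactly what the Metal kernels fuse. This is a plain reference sketch for a single 1-D signal with `center=True`, not the kernel code; `stft_ref` and `istft_ref` are hypothetical names.

```python
import numpy as np

def stft_ref(x, n_fft, hop, window):
    pad = n_fft // 2
    x = np.pad(x, (pad, pad), mode="reflect")            # reflect-padding (center=True)
    n_frames = 1 + (len(x) - n_fft) // hop
    idx = np.arange(n_fft)[None, :] + hop * np.arange(n_frames)[:, None]
    frames = x[idx] * window                              # framing + analysis window
    return np.fft.rfft(frames, axis=1).T                  # [n_freqs, n_frames]

def istft_ref(spec, n_fft, hop, window, length):
    frames = np.fft.irfft(spec.T, n=n_fft, axis=1) * window  # synthesis window
    n_frames = frames.shape[0]
    out = np.zeros(n_fft + hop * (n_frames - 1))
    env = np.zeros_like(out)
    for m in range(n_frames):
        out[m * hop:m * hop + n_fft] += frames[m]         # overlap-add
        env[m * hop:m * hop + n_fft] += window ** 2       # window envelope
    pad = n_fft // 2
    out = out[pad:pad + length]                           # undo center padding
    env = env[pad:pad + length]
    return out / env                                      # envelope normalization

n_fft, hop, T = 512, 128, 4000
window = 0.5 - 0.5 * np.cos(2.0 * np.pi * np.arange(n_fft) / n_fft)  # periodic Hann
x = np.random.default_rng(0).standard_normal(T)
spec = stft_ref(x, n_fft, hop, window)
y = istft_ref(spec, n_fft, hop, window, length=T)
roundtrip_error = float(np.max(np.abs(x - y)))
```

In this reference, padding, framing, and windowing each create an intermediate array, and overlap-add runs as a Python loop. Those are the passes the README describes as being merged into single kernels.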

## Requirements

- macOS with Apple Silicon (MPS)
- Python 3.12+
- PyTorch 2.10+

## Tests

```bash
pip install -e ".[dev]"
pytest
```

## License

MIT
