Metadata-Version: 2.4
Name: rope-conformer
Version: 0.1.0
Summary: A clean, plug-and-play Conformer encoder with rotary positional embeddings.
Project-URL: Homepage, https://github.com/crlandsc/rope-conformer
Project-URL: Issues, https://github.com/crlandsc/rope-conformer/issues
Author: Christopher Landschoot
License-Expression: MIT
License-File: LICENSE
Keywords: attention,audio,conformer,pytorch,rope,rotary,speech,transformer
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Requires-Dist: einops>=0.8
Requires-Dist: rotary-embedding-torch>=0.6
Requires-Dist: torch>=2.3
Provides-Extra: dev
Requires-Dist: build>=1.2; extra == 'dev'
Requires-Dist: numpy>=1.24; extra == 'dev'
Requires-Dist: pytest>=7.4; extra == 'dev'
Provides-Extra: mps-sdpa
Requires-Dist: mps-sdpa>=0.2.0; extra == 'mps-sdpa'
Description-Content-Type: text/markdown

# rope-conformer

A clean, plug-and-play PyTorch Conformer encoder with rotary positional embeddings (RoPE).

```python
import torch
from rope_conformer import RoPEConformer

model = RoPEConformer(dim=256, depth=6, heads=8, dim_head=32)
x = torch.randn(2, 100, 256)        # [B, N, dim]
y = model(x)                        # [B, N, dim]
```

## Install

```bash
pip install rope-conformer
```

Optional Apple-silicon acceleration (experimental):

```bash
pip install "rope-conformer[mps-sdpa]"
```

## Usage with a padding mask

```python
mask = torch.zeros(2, 100, dtype=torch.bool)
mask[:, 80:] = True                 # last 20 positions padded
y = model(x, key_padding_mask=mask) # masked positions are ignored in attention
```

## API

| Argument | Default | Description |
|---|---|---|
| `dim` | — | Model channels (input and output, unless `output_dim` is set). |
| `depth` | — | Number of stacked Conformer blocks. |
| `dim_head` | `64` | Per-head dimension. |
| `heads` | `8` | Number of attention heads. |
| `ff_mult` | `4` | Feedforward expansion factor. |
| `conv_expansion_factor` | `2` | Pointwise expansion in the conv module. |
| `conv_kernel_size` | `31` | Depthwise conv kernel (Conformer paper default). |
| `attn_dropout` | `0.0` | Dropout inside attention (off by default). |
| `proj_dropout` | `0.1` | Dropout after attention output projection. |
| `ff_dropout` | `0.1` | Dropout in feedforward sublayers. |
| `conv_dropout` | `0.1` | Dropout at the end of the conv module. |
| `conv_causal` | `False` | Left-pad the depthwise conv (no future-frame leakage). |
| `conv_norm_type` | `"rms"` | `"rms"` (per-token, causal-safe), `"group"` or `"batch"` (cross-time stats; both fall back to Identity in causal mode). |
| `use_attn_gates` | `False` | Add per-head sigmoid output gates on attention. |
| `flash_attn` | `True` | Use `F.scaled_dot_product_attention`; `False` falls back to einsum. |
| `use_mps_sdpa` | `False` | Route attention through `mps-sdpa` (experimental, MPS only). |
| `norm_output` | `True` | Apply final RMSNorm before projection. |
| `output_dim` | `None` | Optional output projection to a different dimension. |

The forward signature is `model(x, key_padding_mask=None)`:

- `x`: `[B, N, dim]`
- `key_padding_mask`: `[B, N]` bool, `True` for padded positions (optional).
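
As an illustrative sketch combining several of the arguments above (the values are arbitrary; per the table, `output_dim` projects the encoder output to a different size):

```python
import torch
from rope_conformer import RoPEConformer

# Illustrative configuration: non-default head size, gated attention,
# and a projection to a smaller output dimension.
model = RoPEConformer(
    dim=256,
    depth=4,
    heads=8,
    dim_head=32,
    use_attn_gates=True,
    output_dim=128,
)

x = torch.randn(2, 100, 256)   # [B, N, dim]
y = model(x)                   # expected [B, N, output_dim] = [2, 100, 128]
```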

## Granular dropout

The four dropout knobs (`attn_dropout`, `proj_dropout`, `ff_dropout`, `conv_dropout`) target different stages so each can be tuned independently. Attention dropout is off by default because zeroing attention weights after the softmax leaves rows that no longer sum to one, distorting the attention distribution and creating a training/inference mismatch in the attention pattern; the other three knobs act on unnormalized intermediate features and don't have this issue, so they default to a modest `0.1`.
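
A minimal sketch (arbitrary rates) showing each stage set independently at construction time:

```python
model = RoPEConformer(
    dim=256,
    depth=6,
    attn_dropout=0.0,   # keep attention dropout off (the default)
    proj_dropout=0.1,   # after the attention output projection
    ff_dropout=0.2,     # inside the feedforward sublayers
    conv_dropout=0.1,   # at the end of the conv module
)
```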

## Causal use

Set `conv_causal=True` for a depthwise conv that only sees past frames. This handles the conv path; the self-attention path is *not* causally masked by this flag — pass your own causal `attn_mask` (or use a causal-aware downstream stack) if you need full causal behavior.
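
For example (a sketch; only the conv path becomes causal, as described above):

```python
# Depthwise conv is left-padded, so it never sees future frames;
# self-attention is unaffected by this flag and still attends over the full sequence.
model = RoPEConformer(dim=256, depth=6, conv_causal=True)
y = model(x)
```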

## Optional `mps-sdpa` integration (experimental)

PyTorch's MPS backend does not currently dispatch `scaled_dot_product_attention` to Apple's fused `MPSGraph.scaledDotProductAttention` op; it builds a naive matmul → softmax → matmul graph instead. The [`mps-sdpa`](https://github.com/crlandsc/mps-sdpa) package wraps the fused op directly, giving roughly 5–7× faster inference and 2–2.5× faster training on Apple silicon (M1+, macOS 15+). When installed via the `[mps-sdpa]` extra, attention can be routed through it per-instance:

```python
model = RoPEConformer(dim=256, depth=6, use_mps_sdpa=True)
```

If the `mps-sdpa` package is not installed, the flag silently no-ops and the standard SDPA path is used. The flag also has no effect on non-MPS devices.
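
One way to guard the flag at construction time, shown as a sketch using PyTorch's standard MPS availability check (not part of this package's API):

```python
import torch
from rope_conformer import RoPEConformer

use_mps = torch.backends.mps.is_available()
device = "mps" if use_mps else "cpu"

# use_mps_sdpa is harmless when the extra is missing or the device is not MPS,
# but gating it explicitly keeps the intent clear.
model = RoPEConformer(dim=256, depth=6, use_mps_sdpa=use_mps).to(device)
```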

## Acknowledgements

The architecture mirrors the modernized transformer conventions used in
[`lucidrains/BS-RoFormer`](https://github.com/lucidrains/BS-RoFormer) (RoPE,
RMSNorm, SDPA, GELU FFN, no biases on QKV/output projections), applied to the
Conformer block (Gulati et al., 2020).

## License

MIT.
