Directory structure:
└── distortion-residual/
    ├── README.md
    ├── examples/
    │   └── basic_usage.py
    ├── src/
    │   └── distortion_residual/
    │       ├── __init__.py
    │       └── drl.py
    └── tests/
        └── test_drl.py

================================================
FILE: README.md
================================================
# distortion-residual

A differentiable **Distortion Residual Level (DRL)** metric for PyTorch.

DRL measures the nonlinear distortion introduced by audio processors
(limiters, compressors, saturators, etc.) using the **nulling method**:

1. **Level-match** the reference to the processed signal via least-squares
   projection, cancelling any linear gain difference.
2. **Subtract** the matched reference from the processed signal to isolate
   the distortion residual.
3. **Measure** the power ratio of the residual to the signal:

$$
\text{DRL} = 10 \log_{10} \frac{\lVert d \rVert^2}{\lVert \hat{g}\,x \rVert^2},
\qquad
\hat{g} = \frac{\langle x, y \rangle}{\langle x, x \rangle},
\qquad
d = y - \hat{g}\,x
$$

Every operation is differentiable, so DRL can be used directly as a **loss
function** for gradient-based optimisation of audio processing parameters.
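
The three steps reduce to a handful of tensor operations. As a minimal
broadband sketch (a standalone illustration, not the packaged `DRL` class):

```python
import torch

def drl_db(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    """Broadband DRL in dB for a 1-D reference x and processed signal y."""
    g_hat = (x * y).sum() / ((x * x).sum() + eps)        # 1. least-squares gain
    d = y - g_hat * x                                    # 2. distortion residual
    ratio = d.pow(2).mean() / ((g_hat * x).pow(2).mean() + eps)
    return 10 * torch.log10(ratio + eps)                 # 3. power ratio in dB

x = torch.randn(44100)
print(drl_db(x, 2.0 * x))                    # pure gain: down at the eps floor
print(drl_db(x, torch.clamp(x, -0.5, 0.5)))  # hard clip: clearly measurable
```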

## Installation

```bash
pip install distortion-residual
```

Or from source:

```bash
git clone https://github.com/agrathwohl/distortion-residual.git
cd distortion-residual
pip install -e .
```

### Optional: audio file I/O

```bash
pip install "distortion-residual[audio]"
```

## Quick start

```python
import torch
from distortion_residual import DRL

drl = DRL(sample_rate=44100)

reference = torch.randn(44100)              # 1 s of audio
processed = torch.clamp(reference, -0.5, 0.5)  # hard-clip at -6 dBFS

result = drl(reference, processed)
print(result["total_drl_db"])       # e.g. tensor(-5.8)
print(result["total_drl_percent"])  # e.g. tensor(51.3)
```

### As a loss function

```python
gain = torch.tensor(1.0, requires_grad=True)
processed = torch.tanh(reference * gain)

result = drl(reference, processed)
loss = result["total_drl_db"]
loss.backward()
print(gain.grad)  # gradient flows through
```

### Band-wise analysis

By default, DRL is decomposed into three frequency bands (20-200 Hz,
200-2000 Hz, 2000-20000 Hz). You can customise or disable this:

```python
# Custom bands
drl = DRL(sample_rate=44100, frequency_bands=[(100, 1000), (1000, 10000)])

# Broadband only (faster, no FIR filtering)
drl = DRL(sample_rate=44100, frequency_bands=None)
```

## Output dictionary

`DRL.forward()` returns a dict with:

| Key                   | Type                | Description                                                             |
| --------------------- | ------------------- | ----------------------------------------------------------------------- |
| `total_drl_db`        | `Tensor` (scalar)   | Broadband DRL in dB, averaged over batch and channels                   |
| `total_drl_percent`   | `Tensor` (scalar)   | DRL as a percentage                                                     |
| `channel_drl_db`      | `Tensor`            | Per-channel DRL in dB (scalar, `(C,)`, or `(B, C)` to match input rank) |
| `channel_drl_percent` | `Tensor`            | Per-channel DRL as a percentage                                         |
| `band_drl_db`         | `dict[str, Tensor]` | Per-band DRL in dB, keyed `"{low}_{high}"` in Hz                        |
| `band_drl_percent`    | `dict[str, Tensor]` | Per-band DRL as a percentage                                            |
| `residual`            | `Tensor`            | The distortion residual signal                                          |
| `residual_rms`        | `Tensor`            | Per-channel RMS of the residual                                         |
| `signal_rms`          | `Tensor`            | Per-channel RMS of the level-matched reference                          |

## How it works

The level-matching step (`g_hat = <x,y>/<x,x>`) projects out any linear gain component, so DRL is **invariant to makeup gain**. Only the nonlinear distortion component remains in the residual.
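
Two properties follow directly from the projection and are easy to verify
numerically (the helper below is a standalone re-implementation for
illustration, not the packaged API):

```python
import torch

def drl_db(x, y, eps=1e-12):
    g = (x * y).sum() / ((x * x).sum() + eps)
    d = y - g * x
    return 10 * torch.log10(d.pow(2).mean() / ((g * x).pow(2).mean() + eps) + eps)

# Double precision keeps the orthogonality check near machine epsilon.
x = torch.randn(4096, dtype=torch.float64)
y = torch.tanh(1.5 * x)           # nonlinear processor

g = (x * y).sum() / (x * x).sum()
d = y - g * x
print((x * d).sum().item())       # ~0: the residual is orthogonal to x

a, b = drl_db(x, y), drl_db(x, 10.0 * y)
print(a.item(), b.item())         # equal: makeup gain on y drops out
```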

This makes DRL ideal for optimising dynamics processors: the loss function
measures what the processor _does to the waveform shape_, not how loud
it makes the output.

### Gradient properties

- **Through the residual**: linear subtraction, gradients pass directly.
- **Through level matching**: quotient rule on the inner-product ratio.
- **Through band filters**: FIR convolution is a linear operation.

All paths are fully differentiable. No straight-through estimators or
surrogate gradients required.
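
Differentiability can be checked end to end with `torch.autograd.gradcheck`
on a double-precision standalone re-implementation of the broadband metric
(illustrative; the packaged class uses the same gradient paths):

```python
import torch

def drl_db(x, y, eps=1e-12):
    g = (x * y).sum() / ((x * x).sum() + eps)
    d = y - g * x
    return 10 * torch.log10(d.pow(2).mean() / ((g * x).pow(2).mean() + eps) + eps)

torch.manual_seed(0)
x = torch.randn(64, dtype=torch.float64)
y = torch.tanh(x).requires_grad_()

# Compare analytic gradients against central finite differences.
ok = torch.autograd.gradcheck(lambda yy: drl_db(x, yy), (y,))
print(ok)  # True
```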

## Development

```bash
git clone https://github.com/agrathwohl/distortion-residual.git
cd distortion-residual
uv sync --extra dev
uv run pytest
```

## License

MIT



================================================
FILE: examples/basic_usage.py
================================================
#!/usr/bin/env python
"""Basic usage of distortion_residual.DRL."""

import math

import torch

from distortion_residual import DRL


def main():
    sr = 44100
    t = torch.arange(sr) / sr  # exactly sr samples at 1/sr spacing (one second)
    reference = torch.sin(2 * math.pi * 1000 * t)

    # --- Example 1: hard clipping -------------------------------------------
    clipped = torch.clamp(reference, -0.5, 0.5)

    drl = DRL(sample_rate=sr)
    result = drl(reference, clipped)

    print("=== Hard clip at -6 dBFS ===")
    print(f"  DRL:  {result['total_drl_db'].item():.2f} dB")
    print(f"  DRL:  {result['total_drl_percent'].item():.2f} %")
    for band, val in result["band_drl_db"].items():
        print(f"  Band {band} Hz: {val.item():.2f} dB")

    # --- Example 2: use as a loss function -----------------------------------
    gain = torch.tensor(1.0, requires_grad=True)
    processed = torch.tanh(reference * gain)  # soft saturator

    result = drl(reference, processed)
    loss = result["total_drl_db"]
    loss.backward()

    print("\n=== Gradient through soft-clip ===")
    print(f"  DRL:      {loss.item():.2f} dB")
    print(f"  dL/dgain: {gain.grad.item():.6f}")

    # --- Example 3: broadband only (no band decomposition) -------------------
    drl_fast = DRL(sample_rate=sr, frequency_bands=None)
    result = drl_fast(reference, clipped)
    print("\n=== Broadband only ===")
    print(f"  DRL: {result['total_drl_db'].item():.2f} dB")


if __name__ == "__main__":
    main()



================================================
FILE: src/distortion_residual/__init__.py
================================================
"""
Distortion Residual Level (DRL) -- a differentiable audio distortion metric.

>>> from distortion_residual import DRL
>>> drl = DRL(sample_rate=44100)
>>> result = drl(reference, processed)
>>> result['total_drl_db'].backward()
"""

from distortion_residual.drl import DRL, design_fir_bandpass

__all__ = ["DRL", "design_fir_bandpass"]
__version__ = "0.1.0"



================================================
FILE: src/distortion_residual/drl.py
================================================
"""
Differentiable Distortion Residual Level (DRL) metric for PyTorch.

The DRL measures distortion introduced by nonlinear audio processing
using the nulling method:

1. Level-match: scale reference to match processed level (least squares)
2. Subtract: residual = processed - scaled_reference
3. Measure: DRL = 10*log10(||residual||^2 / ||scaled_reference||^2)

All operations maintain gradient flow for use as a loss function.

Accepted input shapes:
    - ``(T,)``       — single mono signal
    - ``(C, T)``     — single multichannel signal (level-match per channel)
    - ``(B, C, T)``  — batch of multichannel signals

For a batch of mono signals, use ``(B, 1, T)``.
"""

import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def design_fir_bandpass(
    sample_rate: int,
    low_freq: float,
    high_freq: float,
    num_taps: int = 255,
) -> torch.Tensor:
    """
    Design a linear-phase FIR bandpass filter using windowed sinc method.

    Pure PyTorch implementation (no scipy dependency).

    Args:
        sample_rate: Sample rate in Hz.
        low_freq: Lower cutoff frequency in Hz.
        high_freq: Upper cutoff frequency in Hz.
        num_taps: Number of filter taps (odd number recommended).

    Returns:
        FIR filter coefficients as torch.Tensor.
    """
    nyquist = sample_rate / 2
    low_norm = max(0.001, min(low_freq / nyquist, 0.999))
    high_norm = max(low_norm + 0.001, min(high_freq / nyquist, 0.999))

    n = torch.arange(num_taps, dtype=torch.float32) - (num_taps - 1) / 2
    n_safe = torch.where(n == 0, torch.ones_like(n) * 1e-10, n)

    lp_high = high_norm * torch.sinc(high_norm * n_safe)
    lp_low = low_norm * torch.sinc(low_norm * n_safe)
    bp = lp_high - lp_low

    center = num_taps // 2
    bp[center] = high_norm - low_norm

    window = torch.blackman_window(num_taps, periodic=False, dtype=torch.float32)
    bp = bp * window

    # Normalise to unity gain at the band centre. A bandpass has near-zero
    # DC gain, so dividing by the coefficient sum (which *is* the DC gain)
    # is ill-conditioned; evaluate the real, linear-phase response at the
    # centre frequency instead.
    center_norm = (low_norm + high_norm) / 2
    response = (bp * torch.cos(math.pi * center_norm * n)).sum()
    bp = bp / (response.abs() + 1e-10)

    return bp


class DRL(nn.Module):
    """
    Differentiable Distortion Residual Level (DRL) metric.

    Measures distortion introduced by nonlinear audio processing (limiters,
    compressors, saturators, etc.) by comparing a reference signal to a
    processed signal using the nulling method.

    The level-matching step cancels any linear gain difference, so DRL
    measures only the nonlinear distortion component.

    Supports batch and multichannel inputs. Level-matching and DRL are
    computed **per channel**; ``total_drl_db`` is the mean across all
    batch elements and channels.

    Accepted shapes:
        - ``(T,)``       — single mono signal
        - ``(C, T)``     — multichannel (e.g. stereo), one item
        - ``(B, C, T)``  — batch of multichannel signals

    For a batch of mono signals use ``(B, 1, T)``.

    Args:
        sample_rate: Audio sample rate in Hz.
        frequency_bands: List of (low_hz, high_hz) tuples for band analysis.
            Default: ``[(20, 200), (200, 2000), (2000, 20000)]``.
            Pass ``None`` for broadband-only (no band decomposition).
        num_filter_taps: FIR filter length (higher = sharper cutoff).
        device: Torch device. Auto-detects CUDA if not specified.

    Example::

        drl = DRL(sample_rate=44100)
        result = drl(reference, processed)
        loss = result['total_drl_db']
        loss.backward()  # gradients flow through
    """

    DEFAULT_BANDS: List[Tuple[float, float]] = [
        (20, 200),
        (200, 2000),
        (2000, 20000),
    ]

    def __init__(
        self,
        sample_rate: int = 44100,
        frequency_bands: Optional[List[Tuple[float, float]]] = DEFAULT_BANDS,
        num_filter_taps: int = 255,
        device: Optional[torch.device] = None,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.num_filter_taps = num_filter_taps

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._device = device

        self.frequency_bands = frequency_bands if frequency_bands is not None else []

        self._filter_kernels: Dict[Tuple[float, float], str] = {}
        for low, high in self.frequency_bands:
            kernel = design_fir_bandpass(sample_rate, low, high, num_filter_taps)
            band_name = f"filter_{int(low)}_{int(high)}"
            self.register_buffer(band_name, kernel.unsqueeze(0).unsqueeze(0))
            self._filter_kernels[(low, high)] = band_name

        # Honour the documented device behaviour: move registered buffers.
        self.to(self._device)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_3d(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Normalise input to ``(B, C, T)`` and return original ndim."""
        ndim = x.dim()
        if ndim == 1:
            return x.unsqueeze(0).unsqueeze(0), ndim  # (1, 1, T)
        if ndim == 2:
            return x.unsqueeze(0), ndim  # (1, C, T)
        if ndim == 3:
            return x, ndim
        raise ValueError(f"Expected 1-3D input, got {ndim}D")

    def _apply_bandpass(
        self,
        signal: torch.Tensor,
        low: float,
        high: float,
    ) -> torch.Tensor:
        """Apply bandpass filter to ``(B, C, T)`` signal."""
        band_name = self._filter_kernels[(low, high)]
        kernel = getattr(self, band_name)  # (1, 1, num_taps)
        kernel = kernel.to(device=signal.device, dtype=signal.dtype)

        B, C, T = signal.shape
        x = signal.reshape(B * C, 1, T)
        padding = self.num_filter_taps // 2
        filtered = F.conv1d(x, kernel, padding=padding)
        return filtered.reshape(B, C, -1)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @staticmethod
    def match_levels(
        reference: torch.Tensor,
        processed: torch.Tensor,
        eps: float = 1e-10,
    ) -> torch.Tensor:
        """
        Scale reference to match processed level via least-squares projection.

        Computes ``g_hat = <x, y> / <x, x>`` **per channel** (last dim is
        time) and returns ``g_hat * x``.

        Works on any shape ``(..., T)``.

        Args:
            reference: Reference (unprocessed) signal.
            processed: Processed signal (same shape as *reference*).
            eps: Numerical stability constant.

        Returns:
            Level-matched reference signal (same shape as input).
        """
        numerator = (reference * processed).sum(dim=-1, keepdim=True)
        denominator = (reference * reference).sum(dim=-1, keepdim=True) + eps
        scale = numerator / denominator
        return reference * scale

    def compute_drl(
        self,
        reference: torch.Tensor,
        processed: torch.Tensor,
        eps: float = 1e-10,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute DRL using the nulling method.

        Args:
            reference: Original signal — ``(T,)``, ``(C, T)``, or ``(B, C, T)``.
            processed: Processed signal (same shape as *reference*).
            eps: Numerical stability constant.

        Returns:
            Dictionary with keys:

            - ``total_drl_db``: Mean DRL in dB (scalar).
            - ``total_drl_percent``: Mean DRL as percentage (scalar).
            - ``channel_drl_db``: Per-channel DRL in dB.
            - ``channel_drl_percent``: Per-channel DRL as percentage.
            - ``band_drl_db``: Per-band mean DRL in dB (dict).
            - ``band_drl_percent``: Per-band mean DRL as percentage (dict).
            - ``residual``: Distortion residual signal.
            - ``residual_rms``: Per-channel residual RMS.
            - ``signal_rms``: Per-channel level-matched reference RMS.

            Shapes of per-channel tensors match the input rank: scalar for
            1-D input, ``(C,)`` for 2-D, ``(B, C)`` for 3-D.
        """
        ref, ndim = self._to_3d(reference)
        proc, _ = self._to_3d(processed)

        # Truncate to common length
        min_len = min(ref.shape[-1], proc.shape[-1])
        ref = ref[..., :min_len]
        proc = proc[..., :min_len]

        # Level-match per channel  — ref_scaled is (B, C, T)
        ref_scaled = self.match_levels(ref, proc, eps)
        residual = proc - ref_scaled

        # Power per channel — (B, C)
        residual_power = torch.mean(residual**2, dim=-1)
        signal_power = torch.mean(ref_scaled**2, dim=-1) + eps

        drl_ratio = residual_power / signal_power  # (B, C)
        channel_drl_db = 10 * torch.log10(drl_ratio + eps)
        channel_drl_pct = 100 * torch.sqrt(drl_ratio)

        # Scalar totals (mean over batch and channels)
        total_drl_db = channel_drl_db.mean()
        total_drl_pct = channel_drl_pct.mean()

        # Band analysis — scalars (mean over B, C)
        band_drl_db: Dict[str, torch.Tensor] = {}
        band_drl_percent: Dict[str, torch.Tensor] = {}

        for low, high in self.frequency_bands:
            res_band = self._apply_bandpass(residual, low, high)
            ref_band = self._apply_bandpass(ref_scaled, low, high)

            b_res_pow = torch.mean(res_band**2, dim=-1)  # (B, C)
            b_sig_pow = torch.mean(ref_band**2, dim=-1) + eps
            b_ratio = b_res_pow / b_sig_pow

            bname = f"{int(low)}_{int(high)}"
            band_drl_db[bname] = (10 * torch.log10(b_ratio + eps)).mean()
            band_drl_percent[bname] = (100 * torch.sqrt(b_ratio)).mean()

        # Collapse per-channel tensors to match input rank
        def _squeeze_bc(t: torch.Tensor) -> torch.Tensor:
            """Squeeze a (B, C) tensor back to match input ndim."""
            if ndim == 1:
                return t.squeeze()  # scalar
            if ndim == 2:
                return t.squeeze(0)  # (C,)
            return t  # (B, C)

        def _squeeze_bct(t: torch.Tensor) -> torch.Tensor:
            """Squeeze a (B, C, T) tensor back to match input ndim."""
            if ndim == 1:
                return t.squeeze(0).squeeze(0)  # (T,)
            if ndim == 2:
                return t.squeeze(0)  # (C, T)
            return t  # (B, C, T)

        return {
            "total_drl_db": total_drl_db,
            "total_drl_percent": total_drl_pct,
            "channel_drl_db": _squeeze_bc(channel_drl_db),
            "channel_drl_percent": _squeeze_bc(channel_drl_pct),
            "band_drl_db": band_drl_db,
            "band_drl_percent": band_drl_percent,
            "residual": _squeeze_bct(residual),
            "residual_rms": _squeeze_bc(torch.sqrt(residual_power)),
            "signal_rms": _squeeze_bc(torch.sqrt(signal_power)),
        }

    def forward(
        self,
        reference: torch.Tensor,
        processed: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute DRL between reference and processed signals.

        Args:
            reference: Original signal before processing.
            processed: Signal after processing.

        Returns:
            Dictionary with DRL measurements (see :meth:`compute_drl`).
        """
        return self.compute_drl(reference, processed)



================================================
FILE: tests/test_drl.py
================================================
"""Tests for distortion_residual.DRL."""

import math

import pytest
import torch

from distortion_residual import DRL, design_fir_bandpass


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

@pytest.fixture
def drl():
    return DRL(sample_rate=44100)


@pytest.fixture
def drl_no_bands():
    return DRL(sample_rate=44100, frequency_bands=None)


@pytest.fixture
def sine_1k():
    """1 kHz sine at 0 dBFS, 1 second."""
    sr = 44100
    t = torch.arange(sr, dtype=torch.float32) / sr
    return torch.sin(2 * math.pi * 1000 * t)


# ---------------------------------------------------------------------------
# design_fir_bandpass
# ---------------------------------------------------------------------------

class TestDesignFirBandpass:
    def test_output_shape(self):
        h = design_fir_bandpass(44100, 200, 2000, num_taps=255)
        assert h.shape == (255,)

    def test_output_dtype(self):
        """Filter coefficients should be float32."""
        h = design_fir_bandpass(44100, 200, 2000, num_taps=255)
        assert h.dtype == torch.float32

    def test_clamps_frequencies(self):
        """Edge-case frequencies should not crash."""
        h = design_fir_bandpass(44100, 0, 22050, num_taps=127)
        assert h.shape == (127,)


# ---------------------------------------------------------------------------
# DRL — identity / gain-only
# ---------------------------------------------------------------------------

class TestDRLIdentity:
    def test_identity_gives_neg_inf(self, drl, sine_1k):
        """Identical signals -> DRL approaches -inf (no distortion)."""
        result = drl(sine_1k, sine_1k)
        assert result["total_drl_db"].item() < -80

    def test_linear_gain_invisible(self, drl, sine_1k):
        """A pure gain change should be cancelled by level matching."""
        gained = sine_1k * 2.0
        result = drl(sine_1k, gained)
        assert result["total_drl_db"].item() < -80

    def test_attenuation_invisible(self, drl, sine_1k):
        gained = sine_1k * 0.25
        result = drl(sine_1k, gained)
        assert result["total_drl_db"].item() < -80

    def test_polarity_flip_invisible(self, drl, sine_1k):
        """Polarity inversion is a linear gain of -1; DRL should cancel it."""
        flipped = sine_1k * -1.0
        result = drl(sine_1k, flipped)
        assert result["total_drl_db"].item() < -80


# ---------------------------------------------------------------------------
# DRL — nonlinear processing
# ---------------------------------------------------------------------------

class TestDRLNonlinear:
    def test_hard_clip_detected(self, drl, sine_1k):
        """Hard clipping should produce measurable DRL."""
        clipped = torch.clamp(sine_1k, -0.5, 0.5)
        result = drl(sine_1k, clipped)
        assert result["total_drl_db"].item() > -60  # clearly non-zero

    def test_soft_clip_less_than_hard(self, drl, sine_1k):
        """Soft clipping should produce less distortion than hard."""
        hard = torch.clamp(sine_1k, -0.5, 0.5)
        soft = torch.tanh(sine_1k * 2) / 2  # roughly same headroom
        r_hard = drl(sine_1k, hard)
        r_soft = drl(sine_1k, soft)
        assert r_soft["total_drl_db"].item() < r_hard["total_drl_db"].item()

    def test_more_clipping_higher_drl(self, drl, sine_1k):
        """Heavier clipping -> higher (less negative) DRL."""
        mild = torch.clamp(sine_1k, -0.8, 0.8)
        heavy = torch.clamp(sine_1k, -0.3, 0.3)
        r_mild = drl(sine_1k, mild)
        r_heavy = drl(sine_1k, heavy)
        assert r_heavy["total_drl_db"].item() > r_mild["total_drl_db"].item()


# ---------------------------------------------------------------------------
# DRL — gradient flow
# ---------------------------------------------------------------------------

class TestGradientFlow:
    def test_grad_through_processed(self, drl_no_bands):
        ref = torch.randn(4410)
        proc = torch.randn(4410, requires_grad=True)
        result = drl_no_bands(ref, proc)
        result["total_drl_db"].backward()
        assert proc.grad is not None
        assert torch.isfinite(proc.grad).all()

    def test_grad_through_gain_parameter(self, drl_no_bands):
        """Gradient should flow through a gain applied to the processed signal."""
        ref = torch.randn(4410)
        gain = torch.tensor(1.0, requires_grad=True)
        proc = ref * gain + torch.randn(4410) * 0.01  # tiny distortion
        result = drl_no_bands(ref, proc)
        result["total_drl_db"].backward()
        assert gain.grad is not None
        assert torch.isfinite(gain.grad).all()

    def test_grad_batched(self, drl_no_bands):
        """Gradients should flow through batched (B, C, T) input."""
        ref = torch.randn(2, 1, 4410)
        proc = torch.randn(2, 1, 4410, requires_grad=True)
        result = drl_no_bands(ref, proc)
        result["total_drl_db"].backward()
        assert proc.grad is not None
        assert proc.grad.shape == (2, 1, 4410)
        assert torch.isfinite(proc.grad).all()


# ---------------------------------------------------------------------------
# DRL — band analysis
# ---------------------------------------------------------------------------

class TestBandAnalysis:
    def test_bands_present(self, drl, sine_1k):
        clipped = torch.clamp(sine_1k, -0.5, 0.5)
        result = drl(sine_1k, clipped)
        assert "20_200" in result["band_drl_db"]
        assert "200_2000" in result["band_drl_db"]
        assert "2000_20000" in result["band_drl_db"]

    def test_no_bands_when_disabled(self, drl_no_bands, sine_1k):
        clipped = torch.clamp(sine_1k, -0.5, 0.5)
        result = drl_no_bands(sine_1k, clipped)
        assert result["band_drl_db"] == {}

    def test_custom_bands(self, sine_1k):
        custom = DRL(sample_rate=44100, frequency_bands=[(100, 1000)])
        clipped = torch.clamp(sine_1k, -0.5, 0.5)
        result = custom(sine_1k, clipped)
        assert "100_1000" in result["band_drl_db"]


# ---------------------------------------------------------------------------
# DRL — batch and multichannel
# ---------------------------------------------------------------------------

class TestBatchMultichannel:
    def test_stereo_input(self, drl_no_bands):
        """2D input (C, T) should return per-channel DRL."""
        ref = torch.randn(2, 4410)
        proc = torch.clamp(ref, -0.5, 0.5)
        result = drl_no_bands(ref, proc)
        assert result["total_drl_db"].dim() == 0  # scalar
        assert result["channel_drl_db"].shape == (2,)
        assert result["residual"].shape == (2, 4410)

    def test_batched_mono(self, drl_no_bands):
        """(B, 1, T) batch of mono signals."""
        ref = torch.randn(4, 1, 4410)
        proc = torch.clamp(ref, -0.5, 0.5)
        result = drl_no_bands(ref, proc)
        assert result["total_drl_db"].dim() == 0  # scalar
        assert result["channel_drl_db"].shape == (4, 1)
        assert result["residual"].shape == (4, 1, 4410)

    def test_batched_stereo(self, drl_no_bands):
        """(B, C, T) batch of stereo signals."""
        ref = torch.randn(3, 2, 4410)
        proc = torch.clamp(ref, -0.5, 0.5)
        result = drl_no_bands(ref, proc)
        assert result["total_drl_db"].dim() == 0
        assert result["channel_drl_db"].shape == (3, 2)
        assert result["residual"].shape == (3, 2, 4410)
        assert result["residual_rms"].shape == (3, 2)

    def test_mono_backward_compat(self, drl_no_bands):
        """1D input should produce scalar per-channel values (backward compat)."""
        ref = torch.randn(4410)
        proc = torch.clamp(ref, -0.5, 0.5)
        result = drl_no_bands(ref, proc)
        assert result["total_drl_db"].dim() == 0
        assert result["channel_drl_db"].dim() == 0
        assert result["residual"].dim() == 1

    def test_per_channel_level_matching(self, drl_no_bands):
        """Each channel should be level-matched independently."""
        ref = torch.randn(2, 4410)
        # Channel 0: gain of 2, channel 1: gain of 0.5 (both linear -> low DRL)
        proc = ref.clone()
        proc[0] *= 2.0
        proc[1] *= 0.5
        result = drl_no_bands(ref, proc)
        # Both channels are pure gain -> both should have very low DRL
        assert result["channel_drl_db"][0].item() < -80
        assert result["channel_drl_db"][1].item() < -80

    def test_batched_bands(self, drl):
        """Band analysis should work with batched input."""
        ref = torch.randn(2, 1, 44100)
        proc = torch.clamp(ref, -0.5, 0.5)
        result = drl(ref, proc)
        assert "20_200" in result["band_drl_db"]
        # Band values should be scalars (mean over B, C)
        assert result["band_drl_db"]["20_200"].dim() == 0


# ---------------------------------------------------------------------------
# DRL — output dict completeness
# ---------------------------------------------------------------------------

class TestOutputDict:
    def test_keys(self, drl, sine_1k):
        clipped = torch.clamp(sine_1k, -0.5, 0.5)
        result = drl(sine_1k, clipped)
        expected = {
            "total_drl_db", "total_drl_percent",
            "channel_drl_db", "channel_drl_percent",
            "band_drl_db", "band_drl_percent",
            "residual", "residual_rms", "signal_rms",
        }
        assert expected == set(result.keys())

    def test_percent_positive(self, drl, sine_1k):
        clipped = torch.clamp(sine_1k, -0.5, 0.5)
        result = drl(sine_1k, clipped)
        assert result["total_drl_percent"].item() > 0


