==============================================================
Phase 6 — test_edge_cases.py pytest output artifact
UTC timestamp: 2026-05-15T06:36:53Z
CUDA available: True
Device: NVIDIA RTX 2000 Ada Generation Laptop GPU
triton: 3.6.0
torch_structured: installed
==============================================================

### Run 1: uv run pytest tests/test_edge_cases.py -v  (default selection, no -m filter — all 85 tests, including the 14 slow-marked test_long_t_drift cases)

============================= test session starts ==============================
platform linux -- Python 3.13.12, pytest-9.0.3, pluggy-1.5.0 -- /home/claroche/miniconda3/bin/python3
cachedir: .pytest_cache
rootdir: /home/claroche/gru-triton/.claude/worktrees/agent-a70ebacd575f69c84
configfile: pyproject.toml
plugins: xdist-3.8.0, anyio-4.10.0
collecting ... collected 85 items

tests/test_edge_cases.py::test_t0_b0_raises_valueerror[T-shape0-reference] PASSED [  1%]
tests/test_edge_cases.py::test_t0_b0_raises_valueerror[T-shape0-dense_triton] PASSED [  2%]
tests/test_edge_cases.py::test_t0_b0_raises_valueerror[T-shape0-diagonal_triton] PASSED [  3%]
tests/test_edge_cases.py::test_t0_b0_raises_valueerror[T-shape0-monarch_triton] PASSED [  4%]
tests/test_edge_cases.py::test_t0_b0_raises_valueerror[T-shape0-butterfly_triton] PASSED [  5%]
tests/test_edge_cases.py::test_t0_b0_raises_valueerror[T-shape0-circulant] PASSED [  7%]
tests/test_edge_cases.py::test_t0_b0_raises_valueerror[T-shape0-ldr] PASSED [  8%]
tests/test_edge_cases.py::test_t0_b0_raises_valueerror[B-shape1-reference] PASSED [  9%]
tests/test_edge_cases.py::test_t0_b0_raises_valueerror[B-shape1-dense_triton] PASSED [ 10%]
tests/test_edge_cases.py::test_t0_b0_raises_valueerror[B-shape1-diagonal_triton] PASSED [ 11%]
tests/test_edge_cases.py::test_t0_b0_raises_valueerror[B-shape1-monarch_triton] PASSED [ 12%]
tests/test_edge_cases.py::test_t0_b0_raises_valueerror[B-shape1-butterfly_triton] PASSED [ 14%]
tests/test_edge_cases.py::test_t0_b0_raises_valueerror[B-shape1-circulant] PASSED [ 15%]
tests/test_edge_cases.py::test_t0_b0_raises_valueerror[B-shape1-ldr] PASSED [ 16%]
tests/test_edge_cases.py::test_t1_forward_parity[1-4-8-reference] PASSED [ 17%]
tests/test_edge_cases.py::test_t1_forward_parity[1-4-8-dense_triton] PASSED [ 18%]
tests/test_edge_cases.py::test_t1_forward_parity[1-4-8-diagonal_triton] PASSED [ 20%]
tests/test_edge_cases.py::test_t1_forward_parity[1-4-8-monarch_triton] PASSED [ 21%]
tests/test_edge_cases.py::test_t1_forward_parity[1-4-8-butterfly_triton] PASSED [ 22%]
tests/test_edge_cases.py::test_t1_forward_parity[1-4-8-circulant] PASSED [ 23%]
tests/test_edge_cases.py::test_t1_forward_parity[1-4-8-ldr] PASSED       [ 24%]
tests/test_edge_cases.py::test_t1_backward_parity[1-4-8-reference] PASSED [ 25%]
tests/test_edge_cases.py::test_t1_backward_parity[1-4-8-dense_triton] PASSED [ 27%]
tests/test_edge_cases.py::test_t1_backward_parity[1-4-8-diagonal_triton] PASSED [ 28%]
tests/test_edge_cases.py::test_t1_backward_parity[1-4-8-monarch_triton] SKIPPED [ 29%]
tests/test_edge_cases.py::test_t1_backward_parity[1-4-8-butterfly_triton] PASSED [ 30%]
tests/test_edge_cases.py::test_t1_backward_parity[1-4-8-circulant] PASSED [ 31%]
tests/test_edge_cases.py::test_t1_backward_parity[1-4-8-ldr] PASSED      [ 32%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-1-reference] PASSED [ 34%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-1-dense_triton] PASSED [ 35%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-1-diagonal_triton] PASSED [ 36%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-1-monarch_triton] PASSED [ 37%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-1-butterfly_triton] SKIPPED [ 38%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-1-circulant] PASSED [ 40%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-1-ldr] PASSED       [ 41%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-2-reference] PASSED [ 42%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-2-dense_triton] PASSED [ 43%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-2-diagonal_triton] PASSED [ 44%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-2-monarch_triton] PASSED [ 45%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-2-butterfly_triton] PASSED [ 47%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-2-circulant] PASSED [ 48%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-2-ldr] PASSED       [ 49%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-8-reference] PASSED [ 50%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-8-dense_triton] PASSED [ 51%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-8-diagonal_triton] PASSED [ 52%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-8-monarch_triton] PASSED [ 54%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-8-butterfly_triton] PASSED [ 55%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-8-circulant] PASSED [ 56%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-1-8-ldr] PASSED       [ 57%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-4-1-reference] PASSED [ 58%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-4-1-dense_triton] PASSED [ 60%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-4-1-diagonal_triton] PASSED [ 61%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-4-1-monarch_triton] PASSED [ 62%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-4-1-butterfly_triton] SKIPPED [ 63%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-4-1-circulant] PASSED [ 64%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-4-1-ldr] PASSED       [ 65%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-4-2-reference] PASSED [ 67%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-4-2-dense_triton] PASSED [ 68%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-4-2-diagonal_triton] PASSED [ 69%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-4-2-monarch_triton] PASSED [ 70%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-4-2-butterfly_triton] PASSED [ 71%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-4-2-circulant] PASSED [ 72%]
tests/test_edge_cases.py::test_b1_small_h_parity[8-4-2-ldr] PASSED       [ 74%]
tests/test_edge_cases.py::test_butterfly_h1_raises_valueerror PASSED     [ 75%]
tests/test_edge_cases.py::test_butterfly_partial_batch_tile[1] PASSED    [ 76%]
tests/test_edge_cases.py::test_butterfly_partial_batch_tile[3] PASSED    [ 77%]
tests/test_edge_cases.py::test_butterfly_partial_batch_tile[5] PASSED    [ 78%]
tests/test_edge_cases.py::test_butterfly_partial_batch_tile[7] FAILED    [ 80%]
tests/test_edge_cases.py::test_butterfly_partial_batch_tile[9] FAILED    [ 81%]
tests/test_edge_cases.py::test_butterfly_partial_batch_tile[17] FAILED   [ 82%]
tests/test_edge_cases.py::test_butterfly_partial_batch_tile[33] FAILED   [ 83%]
tests/test_edge_cases.py::test_long_t_drift[512-reference] PASSED        [ 84%]
tests/test_edge_cases.py::test_long_t_drift[512-dense_triton] PASSED     [ 85%]
tests/test_edge_cases.py::test_long_t_drift[512-diagonal_triton] PASSED  [ 87%]
tests/test_edge_cases.py::test_long_t_drift[512-monarch_triton] PASSED   [ 88%]
tests/test_edge_cases.py::test_long_t_drift[512-butterfly_triton] PASSED [ 89%]
tests/test_edge_cases.py::test_long_t_drift[512-circulant] PASSED        [ 90%]
tests/test_edge_cases.py::test_long_t_drift[512-ldr] PASSED              [ 91%]
tests/test_edge_cases.py::test_long_t_drift[1024-reference] PASSED       [ 92%]
tests/test_edge_cases.py::test_long_t_drift[1024-dense_triton] PASSED    [ 94%]
tests/test_edge_cases.py::test_long_t_drift[1024-diagonal_triton] PASSED [ 95%]
tests/test_edge_cases.py::test_long_t_drift[1024-monarch_triton] PASSED  [ 96%]
tests/test_edge_cases.py::test_long_t_drift[1024-butterfly_triton] PASSED [ 97%]
tests/test_edge_cases.py::test_long_t_drift[1024-circulant] PASSED       [ 98%]
tests/test_edge_cases.py::test_long_t_drift[1024-ldr] PASSED             [100%]

=================================== FAILURES ===================================
_____________________ test_butterfly_partial_batch_tile[7] _____________________

B = 7

    @cuda_only
    @pytest.mark.parametrize("B", [1, 3, 5, 7, 9, 17, 33])
    def test_butterfly_partial_batch_tile(B: int) -> None:
        """Butterfly partial-last-batch-tile sweep at (T=16, H=512)
        (EDG-02, ROADMAP SC#2).
    
        The CONCERNS.md-suggested butterfly sweep covering the
        ``B % BLOCK_B != 0`` partial-last-tile corner — the butterfly OOB fix
        ``d8218d4`` shipped WITHOUT a regression test. B=1 is the extreme;
        odd B values exercise non-aligned final tiles.
    
        BATCH-INVARIANCE CONTRACT (the binding correctness statement): a
        correct kernel produces bit-identical output for identical per-batch
        inputs, REGARDLESS of the total batch count B. This test replicates
        a single B=1 input across all B batch slots and asserts the Triton
        kernel returns the same per-batch output everywhere. This assertion
        is TF32-INDEPENDENT — every batch runs the exact same arithmetic, so
        any divergence is a genuine batch-tiling / partial-tile correctness
        bug, not numerical drift.
    
        FINDING (bd gru-triton-c2a, D-04): the butterfly Triton kernel FAILS
        this contract at H=512 — replicated-input batches diverge by up to
        ~3e-2 on a B-index-dependent subset of batch slots, while the
        per-step reference path is batch-invariant to ~5e-4 (TF32 only). The
        root cause is in the butterfly persistent kernel's batch-tiling /
        twiddle-stage indexing (``scan_butterfly.py``). A deep kernel fix is
        required (D-06 — accepted phase enlargement); per the Task-3
        CONTEXT-BUDGET handoff this failing test + the bd issue land as the
        recoverable Commit-A artifact, and the kernel fix is handed off.
        NO ``@pytest.mark.xfail`` — this test stays RED until the kernel fix
        (Commit B) lands.
        """
        pytest.importorskip("triton")
        pytest.importorskip("torch_structured")
        torch.manual_seed(0)
        torch.set_float32_matmul_precision("high")
        device = torch.device("cuda")
        T, H = 16, 512
        in_size = H
    
        layer = _make_layer("butterfly_triton", in_size, H).to(device)
    
        # One B=1 input, then the SAME input replicated across B batch slots.
        x1 = torch.randn(T, 1, in_size, device=device)
        out1, _ = layer(x1.clone())
        xN = x1.repeat(1, B, 1)
        outN, hTN = layer(xN.clone())
    
        assert torch.isfinite(outN).all(), f"butterfly out non-finite (B={B})"
        assert tuple(outN.shape) == (T, B, H), (
            f"butterfly out shape {tuple(outN.shape)} != {(T, B, H)} (B={B})"
        )
        # Batch-invariance: every batch slot got identical input, so every
        # slot MUST produce identical output. Per-batch deviation localizes a
        # partial-tile bug to a specific pid_b. TF32-independent — the bound
        # is correctness-tight (1e-5), not a numerical-noise tolerance.
        dev_per_b = [
            (outN[:, b] - out1[:, 0]).abs().max().item() for b in range(B)
        ]
        worst = max(dev_per_b)
>       assert worst < 1e-5, (
            f"butterfly batch-invariance violated: identical per-batch inputs "
            f"produced different outputs, worst abs dev {worst:.4e} (B={B}); "
            f"per-batch dev={[f'{d:.2e}' for d in dev_per_b]} — batch-tiling "
            f"correctness bug in scan_butterfly.py (bd gru-triton-c2a)"
        )
E       AssertionError: butterfly batch-invariance violated: identical per-batch inputs produced different outputs, worst abs dev 2.2821e-02 (B=7); per-batch dev=['2.13e-02', '2.12e-02', '2.28e-02', '2.28e-02', '2.28e-02', '2.28e-02', '2.28e-02'] — batch-tiling correctness bug in scan_butterfly.py (bd gru-triton-c2a)
E       assert 0.022820889949798584 < 1e-05

tests/test_edge_cases.py:580: AssertionError
_____________________ test_butterfly_partial_batch_tile[9] _____________________

B = 9

    @cuda_only
    @pytest.mark.parametrize("B", [1, 3, 5, 7, 9, 17, 33])
    def test_butterfly_partial_batch_tile(B: int) -> None:
        """Butterfly partial-last-batch-tile sweep at (T=16, H=512)
        (EDG-02, ROADMAP SC#2).
    
        The CONCERNS.md-suggested butterfly sweep covering the
        ``B % BLOCK_B != 0`` partial-last-tile corner — the butterfly OOB fix
        ``d8218d4`` shipped WITHOUT a regression test. B=1 is the extreme;
        odd B values exercise non-aligned final tiles.
    
        BATCH-INVARIANCE CONTRACT (the binding correctness statement): a
        correct kernel produces bit-identical output for identical per-batch
        inputs, REGARDLESS of the total batch count B. This test replicates
        a single B=1 input across all B batch slots and asserts the Triton
        kernel returns the same per-batch output everywhere. This assertion
        is TF32-INDEPENDENT — every batch runs the exact same arithmetic, so
        any divergence is a genuine batch-tiling / partial-tile correctness
        bug, not numerical drift.
    
        FINDING (bd gru-triton-c2a, D-04): the butterfly Triton kernel FAILS
        this contract at H=512 — replicated-input batches diverge by up to
        ~3e-2 on a B-index-dependent subset of batch slots, while the
        per-step reference path is batch-invariant to ~5e-4 (TF32 only). The
        root cause is in the butterfly persistent kernel's batch-tiling /
        twiddle-stage indexing (``scan_butterfly.py``). A deep kernel fix is
        required (D-06 — accepted phase enlargement); per the Task-3
        CONTEXT-BUDGET handoff this failing test + the bd issue land as the
        recoverable Commit-A artifact, and the kernel fix is handed off.
        NO ``@pytest.mark.xfail`` — this test stays RED until the kernel fix
        (Commit B) lands.
        """
        pytest.importorskip("triton")
        pytest.importorskip("torch_structured")
        torch.manual_seed(0)
        torch.set_float32_matmul_precision("high")
        device = torch.device("cuda")
        T, H = 16, 512
        in_size = H
    
        layer = _make_layer("butterfly_triton", in_size, H).to(device)
    
        # One B=1 input, then the SAME input replicated across B batch slots.
        x1 = torch.randn(T, 1, in_size, device=device)
        out1, _ = layer(x1.clone())
        xN = x1.repeat(1, B, 1)
        outN, hTN = layer(xN.clone())
    
        assert torch.isfinite(outN).all(), f"butterfly out non-finite (B={B})"
        assert tuple(outN.shape) == (T, B, H), (
            f"butterfly out shape {tuple(outN.shape)} != {(T, B, H)} (B={B})"
        )
        # Batch-invariance: every batch slot got identical input, so every
        # slot MUST produce identical output. Per-batch deviation localizes a
        # partial-tile bug to a specific pid_b. TF32-independent — the bound
        # is correctness-tight (1e-5), not a numerical-noise tolerance.
        dev_per_b = [
            (outN[:, b] - out1[:, 0]).abs().max().item() for b in range(B)
        ]
        worst = max(dev_per_b)
>       assert worst < 1e-5, (
            f"butterfly batch-invariance violated: identical per-batch inputs "
            f"produced different outputs, worst abs dev {worst:.4e} (B={B}); "
            f"per-batch dev={[f'{d:.2e}' for d in dev_per_b]} — batch-tiling "
            f"correctness bug in scan_butterfly.py (bd gru-triton-c2a)"
        )
E       AssertionError: butterfly batch-invariance violated: identical per-batch inputs produced different outputs, worst abs dev 6.1999e-02 (B=9); per-batch dev=['0.00e+00', '0.00e+00', '0.00e+00', '0.00e+00', '6.20e-02', '6.20e-02', '6.20e-02', '6.20e-02', '0.00e+00'] — batch-tiling correctness bug in scan_butterfly.py (bd gru-triton-c2a)
E       assert 0.06199939548969269 < 1e-05

tests/test_edge_cases.py:580: AssertionError
____________________ test_butterfly_partial_batch_tile[17] _____________________

B = 17

    @cuda_only
    @pytest.mark.parametrize("B", [1, 3, 5, 7, 9, 17, 33])
    def test_butterfly_partial_batch_tile(B: int) -> None:
        """Butterfly partial-last-batch-tile sweep at (T=16, H=512)
        (EDG-02, ROADMAP SC#2).
    
        The CONCERNS.md-suggested butterfly sweep covering the
        ``B % BLOCK_B != 0`` partial-last-tile corner — the butterfly OOB fix
        ``d8218d4`` shipped WITHOUT a regression test. B=1 is the extreme;
        odd B values exercise non-aligned final tiles.
    
        BATCH-INVARIANCE CONTRACT (the binding correctness statement): a
        correct kernel produces bit-identical output for identical per-batch
        inputs, REGARDLESS of the total batch count B. This test replicates
        a single B=1 input across all B batch slots and asserts the Triton
        kernel returns the same per-batch output everywhere. This assertion
        is TF32-INDEPENDENT — every batch runs the exact same arithmetic, so
        any divergence is a genuine batch-tiling / partial-tile correctness
        bug, not numerical drift.
    
        FINDING (bd gru-triton-c2a, D-04): the butterfly Triton kernel FAILS
        this contract at H=512 — replicated-input batches diverge by up to
        ~3e-2 on a B-index-dependent subset of batch slots, while the
        per-step reference path is batch-invariant to ~5e-4 (TF32 only). The
        root cause is in the butterfly persistent kernel's batch-tiling /
        twiddle-stage indexing (``scan_butterfly.py``). A deep kernel fix is
        required (D-06 — accepted phase enlargement); per the Task-3
        CONTEXT-BUDGET handoff this failing test + the bd issue land as the
        recoverable Commit-A artifact, and the kernel fix is handed off.
        NO ``@pytest.mark.xfail`` — this test stays RED until the kernel fix
        (Commit B) lands.
        """
        pytest.importorskip("triton")
        pytest.importorskip("torch_structured")
        torch.manual_seed(0)
        torch.set_float32_matmul_precision("high")
        device = torch.device("cuda")
        T, H = 16, 512
        in_size = H
    
        layer = _make_layer("butterfly_triton", in_size, H).to(device)
    
        # One B=1 input, then the SAME input replicated across B batch slots.
        x1 = torch.randn(T, 1, in_size, device=device)
        out1, _ = layer(x1.clone())
        xN = x1.repeat(1, B, 1)
        outN, hTN = layer(xN.clone())
    
        assert torch.isfinite(outN).all(), f"butterfly out non-finite (B={B})"
        assert tuple(outN.shape) == (T, B, H), (
            f"butterfly out shape {tuple(outN.shape)} != {(T, B, H)} (B={B})"
        )
        # Batch-invariance: every batch slot got identical input, so every
        # slot MUST produce identical output. Per-batch deviation localizes a
        # partial-tile bug to a specific pid_b. TF32-independent — the bound
        # is correctness-tight (1e-5), not a numerical-noise tolerance.
        dev_per_b = [
            (outN[:, b] - out1[:, 0]).abs().max().item() for b in range(B)
        ]
        worst = max(dev_per_b)
>       assert worst < 1e-5, (
            f"butterfly batch-invariance violated: identical per-batch inputs "
            f"produced different outputs, worst abs dev {worst:.4e} (B={B}); "
            f"per-batch dev={[f'{d:.2e}' for d in dev_per_b]} — batch-tiling "
            f"correctness bug in scan_butterfly.py (bd gru-triton-c2a)"
        )
E       AssertionError: butterfly batch-invariance violated: identical per-batch inputs produced different outputs, worst abs dev 1.5028e-01 (B=17); per-batch dev=['1.04e-02', '1.04e-02', '1.02e-02', '1.02e-02', '2.67e-02', '2.63e-02', '2.62e-02', '2.62e-02', '0.00e+00', '0.00e+00', '0.00e+00', '0.00e+00', '1.50e-01', '1.50e-01', '1.49e-01', '1.49e-01', '0.00e+00'] — batch-tiling correctness bug in scan_butterfly.py (bd gru-triton-c2a)
E       assert 0.15028199553489685 < 1e-05

tests/test_edge_cases.py:580: AssertionError
____________________ test_butterfly_partial_batch_tile[33] _____________________

B = 33

    @cuda_only
    @pytest.mark.parametrize("B", [1, 3, 5, 7, 9, 17, 33])
    def test_butterfly_partial_batch_tile(B: int) -> None:
        """Butterfly partial-last-batch-tile sweep at (T=16, H=512)
        (EDG-02, ROADMAP SC#2).
    
        The CONCERNS.md-suggested butterfly sweep covering the
        ``B % BLOCK_B != 0`` partial-last-tile corner — the butterfly OOB fix
        ``d8218d4`` shipped WITHOUT a regression test. B=1 is the extreme;
        odd B values exercise non-aligned final tiles.
    
        BATCH-INVARIANCE CONTRACT (the binding correctness statement): a
        correct kernel produces bit-identical output for identical per-batch
        inputs, REGARDLESS of the total batch count B. This test replicates
        a single B=1 input across all B batch slots and asserts the Triton
        kernel returns the same per-batch output everywhere. This assertion
        is TF32-INDEPENDENT — every batch runs the exact same arithmetic, so
        any divergence is a genuine batch-tiling / partial-tile correctness
        bug, not numerical drift.
    
        FINDING (bd gru-triton-c2a, D-04): the butterfly Triton kernel FAILS
        this contract at H=512 — replicated-input batches diverge by up to
        ~3e-2 on a B-index-dependent subset of batch slots, while the
        per-step reference path is batch-invariant to ~5e-4 (TF32 only). The
        root cause is in the butterfly persistent kernel's batch-tiling /
        twiddle-stage indexing (``scan_butterfly.py``). A deep kernel fix is
        required (D-06 — accepted phase enlargement); per the Task-3
        CONTEXT-BUDGET handoff this failing test + the bd issue land as the
        recoverable Commit-A artifact, and the kernel fix is handed off.
        NO ``@pytest.mark.xfail`` — this test stays RED until the kernel fix
        (Commit B) lands.
        """
        pytest.importorskip("triton")
        pytest.importorskip("torch_structured")
        torch.manual_seed(0)
        torch.set_float32_matmul_precision("high")
        device = torch.device("cuda")
        T, H = 16, 512
        in_size = H
    
        layer = _make_layer("butterfly_triton", in_size, H).to(device)
    
        # One B=1 input, then the SAME input replicated across B batch slots.
        x1 = torch.randn(T, 1, in_size, device=device)
        out1, _ = layer(x1.clone())
        xN = x1.repeat(1, B, 1)
        outN, hTN = layer(xN.clone())
    
        assert torch.isfinite(outN).all(), f"butterfly out non-finite (B={B})"
        assert tuple(outN.shape) == (T, B, H), (
            f"butterfly out shape {tuple(outN.shape)} != {(T, B, H)} (B={B})"
        )
        # Batch-invariance: every batch slot got identical input, so every
        # slot MUST produce identical output. Per-batch deviation localizes a
        # partial-tile bug to a specific pid_b. TF32-independent — the bound
        # is correctness-tight (1e-5), not a numerical-noise tolerance.
        dev_per_b = [
            (outN[:, b] - out1[:, 0]).abs().max().item() for b in range(B)
        ]
        worst = max(dev_per_b)
>       assert worst < 1e-5, (
            f"butterfly batch-invariance violated: identical per-batch inputs "
            f"produced different outputs, worst abs dev {worst:.4e} (B={B}); "
            f"per-batch dev={[f'{d:.2e}' for d in dev_per_b]} — batch-tiling "
            f"correctness bug in scan_butterfly.py (bd gru-triton-c2a)"
        )
E       AssertionError: butterfly batch-invariance violated: identical per-batch inputs produced different outputs, worst abs dev 1.2468e-01 (B=33); per-batch dev=['0.00e+00', '0.00e+00', '0.00e+00', '0.00e+00', '1.46e-02', '1.46e-02', '1.46e-02', '1.49e-02', '6.88e-03', '6.88e-03', '6.88e-03', '6.88e-03', '1.24e-01', '1.24e-01', '1.23e-01', '1.23e-01', '0.00e+00', '0.00e+00', '0.00e+00', '0.00e+00', '0.00e+00', '5.96e-03', '1.25e-01', '1.25e-01', '0.00e+00', '0.00e+00', '0.00e+00', '0.00e+00', '3.39e-02', '3.40e-02', '3.55e-02', '3.51e-02', '0.00e+00'] — batch-tiling correctness bug in scan_butterfly.py (bd gru-triton-c2a)
E       assert 0.12467984855175018 < 1e-05

tests/test_edge_cases.py:580: AssertionError
=========================== short test summary info ============================
SKIPPED [1] tests/test_triton_monarch_strict.py:678: F-04-VERIFIER-F (gru-triton-e0l): monarch bwd kernel cannot run on RTX 2000 Ada at blksz=2 (H=8, nb=4); SMEM OOM or tl.dot K<16 constraint
SKIPPED [2] tests/test_edge_cases.py:485: butterfly H=1 is rejected at construction (size-1 factorization undefined); see test_butterfly_h1_raises_valueerror
FAILED tests/test_edge_cases.py::test_butterfly_partial_batch_tile[7] - Asser...
FAILED tests/test_edge_cases.py::test_butterfly_partial_batch_tile[9] - Asser...
FAILED tests/test_edge_cases.py::test_butterfly_partial_batch_tile[17] - Asse...
FAILED tests/test_edge_cases.py::test_butterfly_partial_batch_tile[33] - Asse...
============= 4 failed, 78 passed, 3 skipped in 107.23s (0:01:47) ==============

### Run 2: uv run pytest tests/test_edge_cases.py -m slow -v  (slow tier)

============================= test session starts ==============================
platform linux -- Python 3.13.12, pytest-9.0.3, pluggy-1.5.0 -- /home/claroche/miniconda3/bin/python3
cachedir: .pytest_cache
rootdir: /home/claroche/gru-triton/.claude/worktrees/agent-a70ebacd575f69c84
configfile: pyproject.toml
plugins: xdist-3.8.0, anyio-4.10.0
collecting ... collected 85 items / 71 deselected / 14 selected

tests/test_edge_cases.py::test_long_t_drift[512-reference] PASSED        [  7%]
tests/test_edge_cases.py::test_long_t_drift[512-dense_triton] PASSED     [ 14%]
tests/test_edge_cases.py::test_long_t_drift[512-diagonal_triton] PASSED  [ 21%]
tests/test_edge_cases.py::test_long_t_drift[512-monarch_triton] PASSED   [ 28%]
tests/test_edge_cases.py::test_long_t_drift[512-butterfly_triton] PASSED [ 35%]
tests/test_edge_cases.py::test_long_t_drift[512-circulant] PASSED        [ 42%]
tests/test_edge_cases.py::test_long_t_drift[512-ldr] PASSED              [ 50%]
tests/test_edge_cases.py::test_long_t_drift[1024-reference] PASSED       [ 57%]
tests/test_edge_cases.py::test_long_t_drift[1024-dense_triton] PASSED    [ 64%]
tests/test_edge_cases.py::test_long_t_drift[1024-diagonal_triton] PASSED [ 71%]
tests/test_edge_cases.py::test_long_t_drift[1024-monarch_triton] PASSED  [ 78%]
tests/test_edge_cases.py::test_long_t_drift[1024-butterfly_triton] PASSED [ 85%]
tests/test_edge_cases.py::test_long_t_drift[1024-circulant] PASSED       [ 92%]
tests/test_edge_cases.py::test_long_t_drift[1024-ldr] PASSED             [100%]

================ 14 passed, 71 deselected in 124.57s (0:02:04) =================
