Does this look good?


B=2, T=3, D=16, dtype=torch.float32, include_lse_grad=True
    grad_x: max_abs=2.384186e-07, max_rel=1.212530e-05
    grad_q: max_abs=2.384186e-07, max_rel=2.491422e-06
   grad_y0: max_abs=3.725290e-09, max_rel=1.876972e-07
   grad_l0: max_abs=5.960464e-08, max_rel=1.618189e-07
PASS

B=2, T=3, D=16, dtype=torch.float32, include_lse_grad=False
    grad_x: max_abs=8.344650e-07, max_rel=3.468968e-06
    grad_q: max_abs=1.192093e-06, max_rel=9.939597e-06
   grad_y0: max_abs=2.384186e-07, max_rel=3.748511e-07
   grad_l0: max_abs=3.576279e-07, max_rel=3.195289e-07
PASS

B=2, T=3, D=32, dtype=torch.float32, include_lse_grad=True
    grad_x: max_abs=1.907349e-06, max_rel=1.661387e-05
    grad_q: max_abs=9.536743e-07, max_rel=7.756099e-06
   grad_y0: max_abs=4.768372e-07, max_rel=5.349277e-07
   grad_l0: max_abs=3.576279e-07, max_rel=1.382988e-05
PASS

B=2, T=3, D=32, dtype=torch.float32, include_lse_grad=False
    grad_x: max_abs=9.536743e-07, max_rel=4.482382e-06
    grad_q: max_abs=1.907349e-06, max_rel=7.872198e-06
   grad_y0: max_abs=5.960464e-08, max_rel=2.358896e-07
   grad_l0: max_abs=7.152557e-07, max_rel=6.606296e-07
PASS

B=2, T=3, D=64, dtype=torch.float32, include_lse_grad=True
    grad_x: max_abs=1.668930e-06, max_rel=8.514135e-06
    grad_q: max_abs=1.370907e-06, max_rel=7.046842e-06
   grad_y0: max_abs=7.152557e-07, max_rel=8.837702e-07
   grad_l0: max_abs=4.768372e-07, max_rel=8.177361e-07
PASS

B=2, T=3, D=64, dtype=torch.float32, include_lse_grad=False
    grad_x: max_abs=7.629395e-06, max_rel=3.870065e-06
    grad_q: max_abs=4.768372e-06, max_rel=6.393007e-06
   grad_y0: max_abs=5.960464e-07, max_rel=9.151361e-07
   grad_l0: max_abs=1.907349e-06, max_rel=5.720813e-06
PASS

B=2, T=3, D=128, dtype=torch.float32, include_lse_grad=True
    grad_x: max_abs=7.629395e-06, max_rel=2.756481e-05
    grad_q: max_abs=8.583069e-06, max_rel=2.624603e-05
   grad_y0: max_abs=2.384186e-07, max_rel=3.616927e-07
   grad_l0: max_abs=2.503395e-06, max_rel=1.942588e-06
PASS

B=2, T=3, D=128, dtype=torch.float32, include_lse_grad=False
    grad_x: max_abs=7.152557e-06, max_rel=1.226199e-05
    grad_q: max_abs=5.722046e-06, max_rel=2.085558e-05
   grad_y0: max_abs=9.536743e-07, max_rel=2.533109e-06
   grad_l0: max_abs=2.384186e-06, max_rel=2.191141e-06
PASS

B=2, T=3, D=256, dtype=torch.float32, include_lse_grad=True
    grad_x: max_abs=2.145767e-06, max_rel=8.211349e-04
    grad_q: max_abs=2.861023e-06, max_rel=4.158281e-05
   grad_y0: max_abs=1.005828e-07, max_rel=2.002195e-06
   grad_l0: max_abs=6.854534e-07, max_rel=1.991570e-06
PASS

B=2, T=3, D=256, dtype=torch.float32, include_lse_grad=False
    grad_x: max_abs=7.629395e-06, max_rel=3.663079e-05
    grad_q: max_abs=9.536743e-06, max_rel=4.279816e-03
   grad_y0: max_abs=1.072884e-06, max_rel=1.012024e-06
   grad_l0: max_abs=2.145767e-06, max_rel=3.578054e-06
PASS

B=4, T=5, D=32, dtype=torch.float16, include_lse_grad=True
    grad_x: max_abs=9.832382e-04, max_rel=8.211740e-03
    grad_q: max_abs=1.672745e-03, max_rel=4.105079e-04
   grad_y0: max_abs=9.577274e-04, max_rel=4.872094e-04
   grad_l0: max_abs=2.843142e-04, max_rel=4.130559e-04
PASS

B=4, T=5, D=64, dtype=torch.float16, include_lse_grad=True
    grad_x: max_abs=1.874447e-03, max_rel=2.341424e+00
    grad_q: max_abs=1.676083e-03, max_rel=4.407155e-04
   grad_y0: max_abs=9.620190e-04, max_rel=3.592997e-02
   grad_l0: max_abs=3.920794e-04, max_rel=3.966756e-04
PASS

B=4, T=5, D=128, dtype=torch.float16, include_lse_grad=True
    grad_x: max_abs=3.638268e-03, max_rel=2.797804e+00
    grad_q: max_abs=3.841400e-03, max_rel=4.031767e-04
   grad_y0: max_abs=8.685589e-04, max_rel=2.459664e+00
   grad_l0: max_abs=4.856586e-04, max_rel=4.439930e-01
PASS

import torch
import triton
import triton.language as tl

@triton.jit
def phase_2_online_softmax_merge_intrablock_backward_kernel(
    intrablock_partial_sum_ptr,
    pseudo_query_ptr,
    prev_interblock_normalized_output_ptr,
    prev_interblock_lse_ptr,
    grad_merged_output_ptr,
    grad_merged_lse_ptr,
    grad_intrablock_partial_sum_ptr,
    grad_pseudo_query_ptr,
    grad_prev_interblock_normalized_output_ptr,
    grad_prev_interblock_lse_ptr,
    eps,
    HIDDEN_DIM: tl.constexpr,
):
    batch_seq_idx = tl.program_id(0)  # one program per (batch * seq) row
    hidden_dim_range = tl.arange(0, HIDDEN_DIM)

    # Load the per-row inputs, upcast to float32 so the recomputed forward
    # pass and the gradient math run in full precision regardless of the
    # storage dtype.
    x = tl.load(
        intrablock_partial_sum_ptr
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range
    ).to(tl.float32)

    q = tl.load(
        pseudo_query_ptr + hidden_dim_range,
        eviction_policy="evict_last",
    ).to(tl.float32)

    y0 = tl.load(
        prev_interblock_normalized_output_ptr
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range
    ).to(tl.float32)

    l0 = tl.load(prev_interblock_lse_ptr + batch_seq_idx).to(tl.float32)

    grad_y = tl.load(
        grad_merged_output_ptr
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range
    ).to(tl.float32)

    grad_l = tl.load(grad_merged_lse_ptr + batch_seq_idx).to(tl.float32)

    # Recompute the forward logit: l1 = <x, q> * rsqrt(mean(x^2) + eps).
    squared_norm_sum = tl.sum(x * x)
    inverse_rms_norm = tl.rsqrt(squared_norm_sum / float(HIDDEN_DIM) + eps)

    dot_xq = tl.sum(x * q)
    l1 = dot_xq * inverse_rms_norm

    # Recompute the numerically stable online-softmax merge weights:
    # alpha = exp(l0) / (exp(l0) + exp(l1)), beta = 1 - alpha.
    merged_max = tl.maximum(l0, l1)
    w0 = tl.exp(l0 - merged_max)
    w1 = tl.exp(l1 - merged_max)
    exp_sum = w0 + w1

    alpha = w0 / exp_sum
    beta = w1 / exp_sum

    # Backward through the value path y = alpha * y0 + beta * x.
    grad_y0 = alpha * grad_y
    grad_x_from_value = beta * grad_y

    dot_grad_y_y0_minus_x = tl.sum(grad_y * (y0 - x))

    # Backward through the merge weights and the merged lse; see the
    # derivation note after the kernel.
    grad_l0 = alpha * grad_l + alpha * beta * dot_grad_y_y0_minus_x
    grad_l1 = beta * grad_l - alpha * beta * dot_grad_y_y0_minus_x

    inv_rms_cubed = inverse_rms_norm * inverse_rms_norm * inverse_rms_norm

    # Backward through the RMS-normalized logit (see the derivation note
    # after the kernel).
    grad_x_from_logit = grad_l1 * (
        inverse_rms_norm * q
        - dot_xq * inv_rms_cubed * x / float(HIDDEN_DIM)
    )

    grad_q = grad_l1 * inverse_rms_norm * x
    grad_x = grad_x_from_value + grad_x_from_logit

    # grad_x and grad_q are accumulated with atomics rather than overwritten:
    # grad_q is shared by every program in the grid, so the caller must
    # zero-initialize these buffers (the test below does).
    tl.atomic_add(
        grad_intrablock_partial_sum_ptr
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range,
        grad_x,
        sem="relaxed",
    )

    tl.atomic_add(
        grad_pseudo_query_ptr + hidden_dim_range,
        grad_q,
        sem="relaxed",
    )

    tl.store(
        grad_prev_interblock_normalized_output_ptr
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range,
        grad_y0,
    )

    tl.store(
        grad_prev_interblock_lse_ptr + batch_seq_idx,
        grad_l0,
    )
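
# Derivation note (a sketch reconstructed from reference_forward below, in the
# kernel's notation). With r = rsqrt(mean(x^2) + eps), the recomputed logit is
# l1 = <x, q> * r, and the forward merge is
#     y   = alpha * y0 + beta * x,
#     lse = log(exp(l0) + exp(l1)),
# with alpha = exp(l0) / (exp(l0) + exp(l1)) and beta = 1 - alpha. Using
# d(alpha)/d(l0) = alpha * beta = -d(beta)/d(l0) and d(lse)/d(l0) = alpha:
#     grad_l0 = alpha * grad_l + alpha * beta * <grad_y, y0 - x>
#     grad_l1 = beta  * grad_l - alpha * beta * <grad_y, y0 - x>
# Pushing grad_l1 through the logit,
#     dl1/dx = r * q - <x, q> * r^3 * x / HIDDEN_DIM,    dl1/dq = r * x,
# which, added to the value path beta * grad_y, gives grad_x and grad_q.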


def phase_2_online_softmax_merge_intrablock_backward(
    intrablock_partial_sum,
    pseudo_query,
    prev_interblock_normalized_output,
    prev_interblock_lse,
    grad_merged_output,
    grad_merged_lse,
    grad_intrablock_partial_sum,
    grad_pseudo_query,
    grad_prev_interblock_normalized_output,
    grad_prev_interblock_lse,
    eps=None,
):
    if eps is None:
        eps = torch.finfo(torch.float32).eps

    if grad_merged_lse is None:
        grad_merged_lse = torch.zeros_like(prev_interblock_lse)

    # Derive the grid and block size from the input shapes; HIDDEN_DIM feeds
    # tl.arange, so the hidden dimension must be a power of two.
    num_rows, hidden_dim = intrablock_partial_sum.shape
    phase_2_online_softmax_merge_intrablock_backward_kernel[(num_rows,)](
        intrablock_partial_sum,
        pseudo_query,
        prev_interblock_normalized_output,
        prev_interblock_lse,
        grad_merged_output,
        grad_merged_lse,
        grad_intrablock_partial_sum,
        grad_pseudo_query,
        grad_prev_interblock_normalized_output,
        grad_prev_interblock_lse,
        eps,
        hidden_dim,
    )


def reference_forward(
    x,
    q,
    y0,
    l0,
    eps=torch.finfo(torch.float32).eps,
):
    x_f = x.float()
    q_f = q.float()
    y0_f = y0.float()
    l0_f = l0.float()

    inv_rms = torch.rsqrt(x_f.square().mean(dim=-1) + eps)
    l1 = (x_f * q_f).sum(dim=-1) * inv_rms

    merged_max = torch.maximum(l0_f, l1)
    w0 = torch.exp(l0_f - merged_max)
    w1 = torch.exp(l1 - merged_max)
    exp_sum = w0 + w1

    alpha = w0 / exp_sum
    beta = w1 / exp_sum

    merged_y = alpha[..., None] * y0_f + beta[..., None] * x_f
    merged_lse = merged_max + torch.log(exp_sum)

    return merged_y, merged_lse
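
# Equivalence note: the numerically stable merge above is a guarded form of
#     merged_lse = torch.logaddexp(l0_f, l1)
#     merged_y   = (torch.exp(l0_f - merged_lse)[..., None] * y0_f
#                   + torch.exp(l1 - merged_lse)[..., None] * x_f)
# The max-subtraction only protects the exponentials from overflow.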


def run_backward_check(
    B=4,
    T=8,
    D=64,
    dtype=torch.float32,
    eps=torch.finfo(torch.float32).eps,
    include_lse_grad=True,
    seed=0,
    atol=2e-4,
    rtol=2e-4,
):
    torch.manual_seed(seed)
    device = "cuda"
    N = B * T

    x = torch.randn(N, D, device=device, dtype=dtype, requires_grad=True)
    q = torch.randn(D, device=device, dtype=dtype, requires_grad=True)
    y0 = torch.randn(N, D, device=device, dtype=dtype, requires_grad=True)
    l0 = torch.randn(N, device=device, dtype=dtype, requires_grad=True)

    grad_y = torch.randn(N, D, device=device, dtype=torch.float32)
    grad_l = (
        torch.randn(N, device=device, dtype=torch.float32)
        if include_lse_grad
        else torch.zeros(N, device=device, dtype=torch.float32)
    )

    # Autograd reference: with this scalar loss, the leaf .grad fields are
    # exactly the VJPs of (merged_y, merged_lse) with cotangents
    # (grad_y, grad_l).
    ref_y, ref_lse = reference_forward(x, q, y0, l0, eps=eps)
    loss = (ref_y * grad_y).sum() + (ref_lse * grad_l).sum()
    loss.backward()

    ref_grad_x = x.grad.detach().clone()
    ref_grad_q = q.grad.detach().clone()
    ref_grad_y0 = y0.grad.detach().clone()
    ref_grad_l0 = l0.grad.detach().clone()

    # Buffers written with atomic_add must start at zero; buffers written
    # with plain stores may be left uninitialized.
    tri_grad_x = torch.zeros_like(x, dtype=torch.float32)
    tri_grad_q = torch.zeros_like(q, dtype=torch.float32)
    tri_grad_y0 = torch.empty_like(y0, dtype=torch.float32)
    tri_grad_l0 = torch.empty_like(l0, dtype=torch.float32)

    phase_2_online_softmax_merge_intrablock_backward_kernel[(N,)](
        x.detach(),
        q.detach(),
        y0.detach(),
        l0.detach(),
        grad_y,
        grad_l,
        tri_grad_x,
        tri_grad_q,
        tri_grad_y0,
        tri_grad_l0,
        eps,
        D,
    )

    torch.cuda.synchronize()

    def report(name, got, expected):
        max_abs = (got.float() - expected.float()).abs().max().item()
        max_rel = (
            (got.float() - expected.float()).abs()
            / expected.float().abs().clamp_min(1e-8)
        ).max().item()
        print(f"{name:>10}: max_abs={max_abs:.6e}, max_rel={max_rel:.6e}")
        torch.testing.assert_close(
            got.float(),
            expected.float(),
            atol=atol,
            rtol=rtol,
        )

    print(f"\nB={B}, T={T}, D={D}, dtype={dtype}, include_lse_grad={include_lse_grad}")

    report("grad_x", tri_grad_x, ref_grad_x)
    report("grad_q", tri_grad_q, ref_grad_q)
    report("grad_y0", tri_grad_y0, ref_grad_y0)
    report("grad_l0", tri_grad_l0, ref_grad_l0)

    print("PASS")


def run_many_backward_checks():
    for dtype in [torch.float32]:
        for D in [16, 32, 64, 128, 256]:
            run_backward_check(B=2, T=3, D=D, dtype=dtype, include_lse_grad=True, seed=123)
            run_backward_check(B=2, T=3, D=D, dtype=dtype, include_lse_grad=False, seed=456)

    for D in [32, 64, 128]:
        run_backward_check(
            B=4,
            T=5,
            D=D,
            dtype=torch.float16,
            include_lse_grad=True,
            seed=789,
            atol=2e-3,
            rtol=2e-3,
        )


run_many_backward_checks()
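
For completeness, a minimal sketch of how the wrapper might be driven (tensor
names and shapes here are illustrative, not taken from the test harness above;
the atomically accumulated outputs start at zero):

def example_wrapper_usage():
    N, D = 6, 64  # D must be a power of two (tl.arange constraint)
    device = "cuda"
    x = torch.randn(N, D, device=device)
    q = torch.randn(D, device=device)
    y0 = torch.randn(N, D, device=device)
    l0 = torch.randn(N, device=device)
    grad_y = torch.randn(N, D, device=device)
    grad_l = torch.randn(N, device=device)

    grad_x = torch.zeros(N, D, device=device)   # atomic_add target: zero-init
    grad_q = torch.zeros(D, device=device)      # atomic_add target: zero-init
    grad_y0 = torch.empty(N, D, device=device)  # written with plain stores
    grad_l0 = torch.empty(N, device=device)     # written with plain stores

    phase_2_online_softmax_merge_intrablock_backward(
        x, q, y0, l0,
        grad_y, grad_l,
        grad_x, grad_q, grad_y0, grad_l0,
    )
    return grad_x, grad_q, grad_y0, grad_l0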
