import torch
import torch.nn.functional as F
import triton

torch.manual_seed(0)
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


@torch.compile(mode="max-autotune-no-cudagraphs")
def phase_1_fn(query, value):
    S, D = query.shape
    N, B, T, _ = value.shape

    # Logits of every query against every RMS-normed value vector: [N, B, T, S].
    logits = (F.rms_norm(value, (D,)).reshape(-1, D) @ query.T).view(N, B, T, S)

    # Numerically stable softmax statistics over the N dimension.
    max_logits = logits.amax(dim=0)  # [B, T, S]
    exp_weights = torch.exp(logits - max_logits.unsqueeze(0))  # [N, B, T, S]
    o_weighted_sum = (exp_weights.unsqueeze(-1) * value.unsqueeze(3)).sum(dim=0)  # [B, T, S, D]

    # Move the query dimension to the front: [S, B, T, ...].
    max_logits = max_logits.permute(2, 0, 1)
    o_weighted_sum = o_weighted_sum.permute(2, 0, 1, 3)
    l_exp_sum = exp_weights.sum(dim=0).permute(2, 0, 1)
    # Normalized output for the first query only.
    h = o_weighted_sum[0] / l_exp_sum[0][..., None]
    return max_logits, o_weighted_sum, l_exp_sum, h


def phase_1_compiled_ref(query, value):
    max_logits, o_weighted_sum, l_exp_sum, h = phase_1_fn(query, value)

    # Recombine the partial statistics: softmax-normalized output and
    # log-sum-exp over the N dimension.
    out = o_weighted_sum / l_exp_sum[..., None]
    lse = max_logits + torch.log(l_exp_sum)

    return out, lse, h
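

# For reference, the recombination above is the standard numerically stable
# log-sum-exp / softmax identity over the N dimension (a restatement of the
# math the functions above already implement, not new behavior):
#
#     lse[s, b, t]    = log( sum_n exp(logits[n, b, t, s]) )
#                     = max_n logits[n, b, t, s]
#                       + log( sum_n exp(logits[n, b, t, s] - max_n logits[n, b, t, s]) )
#     out[s, b, t, :] = sum_n softmax_n(logits[n, b, t, s]) * value[n, b, t, :]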


def phase_1_fp32_math_ref(query, value, eps, cast_output_to_bf16=False):
    """
    Matches the math implemented by the Triton kernel:
      logits are accumulated in fp32 from BF16-rounded inputs.
    """
    S, D = query.shape
    N, B, T, _ = value.shape

    q = query.float()
    v = value.float()

    inv_rms = torch.rsqrt(v.square().mean(dim=-1) + eps)  # [N, B, T]
    normed_v = v * inv_rms[..., None]

    logits = (normed_v.reshape(-1, D) @ q.T).view(N, B, T, S)

    max_logits = logits.amax(dim=0)  # [B, T, S]
    exp_weights = torch.exp(logits - max_logits.unsqueeze(0))
    exp_sum = exp_weights.sum(dim=0)  # [B, T, S]

    weighted_sum = (exp_weights.unsqueeze(-1) * v.unsqueeze(3)).sum(dim=0)  # [B, T, S, D]
    out = weighted_sum / exp_sum.unsqueeze(-1)  # [B, T, S, D]

    out = out.permute(2, 0, 1, 3).contiguous()  # [S, B, T, D]
    lse = (max_logits + torch.log(exp_sum)).permute(2, 0, 1).contiguous()
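
    # An equivalent direct formulation (hypothetical cross-check, not executed
    # here) via torch.logsumexp / torch.softmax; with the same fp32 inputs it
    # agrees with the max-subtraction form above up to rounding:
    #
    #     lse_direct = torch.logsumexp(logits, dim=0).permute(2, 0, 1)
    #     attn       = torch.softmax(logits, dim=0)  # softmax weights over N
    #     out_direct = torch.einsum("nbts,nbtd->btsd", attn, v).permute(2, 0, 1, 3)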

    if cast_output_to_bf16:
        out = out.to(torch.bfloat16)

    return out, lse


def phase_1_triton_forward(value, query, eps):
    N, B, T, D = value.shape
    S, Dq = query.shape
    assert Dq == D

    # The kernel writes the attention output in BF16 and the log-sum-exp in FP32.
    out = torch.empty((S, B, T, D), device=value.device, dtype=torch.bfloat16)
    lse = torch.empty((S, B, T), device=value.device, dtype=torch.float32)

    BT = B * T

    # One program per flattened (batch, time) position.
    phase_1_batched_interblock_attention_kernel[(BT,)](
        value,
        query,
        out,
        lse,
        eps,
        N,
        BT,
        D,
        S,
        triton.next_power_of_2(N),
    )

    return out, lse
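

# Minimal usage sketch (hypothetical shapes; assumes the Triton kernel above is
# defined/imported in this module):
#
#     value = torch.randn(4, 2, 16, 64, device="cuda", dtype=torch.bfloat16)
#     query = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)
#     out, lse = phase_1_triton_forward(value, query, torch.finfo(torch.bfloat16).eps)
#     # out: [8, 2, 16, 64] BF16, lse: [8, 2, 16] FP32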


def phase_1_triton_backward(
    value,
    query,
    lse,
    grad_out,
    grad_lse,
    eps,
):
    N, B, T, D = value.shape
    S, Dq = query.shape
    assert Dq == D

    # Zero-initialized FP32 buffers so partial gradient contributions can be
    # accumulated at full precision.
    grad_value = torch.zeros((N, B, T, D), device=value.device, dtype=torch.float32)
    grad_query = torch.zeros((S, D), device=query.device, dtype=torch.float32)

    BT = B * T

    phase_1_batched_interblock_attention_backward_kernel[(BT,)](
        value,
        query,
        lse,
        grad_out,
        grad_lse,
        grad_value,
        grad_query,
        eps,
        N,
        BT,
        D,
        S,
        triton.next_power_of_2(N),
    )

    return grad_value, grad_query
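

# A minimal sketch (not exercised by the verifier below) of how the two
# wrappers could be tied into autograd via torch.autograd.Function. Casting the
# FP32 gradients back to the input dtypes is an assumption here, not something
# the kernels require.
class Phase1InterblockAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, value, query, eps):
        out, lse = phase_1_triton_forward(value, query, eps)
        ctx.save_for_backward(value, query, lse)
        ctx.eps = eps
        return out, lse

    @staticmethod
    def backward(ctx, grad_out, grad_lse):
        value, query, lse = ctx.saved_tensors
        grad_value, grad_query = phase_1_triton_backward(
            value, query, lse, grad_out, grad_lse, ctx.eps
        )
        # One gradient per forward input; eps receives no gradient.
        return grad_value.to(value.dtype), grad_query.to(query.dtype), None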


def assert_close(name, actual, expected, *, atol, rtol, fail=True):
    """Print max abs/rel error for a pair of tensors and optionally raise on mismatch."""
    actual_f = actual.float()
    expected_f = expected.float()

    abs_err = (actual_f - expected_f).abs()
    rel_err = abs_err / expected_f.abs().clamp_min(1e-6)  # clamp avoids division by zero

    max_abs = abs_err.max().item()
    max_rel = rel_err.max().item()
    ok = torch.allclose(actual_f, expected_f, atol=atol, rtol=rtol)

    print(f"{name:34s} ok={ok}  max_abs={max_abs:.6g}  max_rel={max_rel:.6g}")

    if not ok:
        idx = abs_err.argmax()
        print("  actual  :", actual_f.flatten()[idx].item())
        print("  expected:", expected_f.flatten()[idx].item())

        if fail:
            raise AssertionError(f"{name} mismatch")


def verify_phase_1_once(
    *,
    N,
    B_,
    T_,
    D_,
    S,
    dtype=torch.bfloat16,
):
    device = "cuda"

    # F.rms_norm with eps=None defaults to torch.finfo(input.dtype).eps, so this
    # matches the compiled reference's rms_norm for BF16 inputs.
    eps = torch.finfo(dtype).eps

    value = torch.randn(N, B_, T_, D_, device=device, dtype=dtype)
    query = torch.randn(S, D_, device=device, dtype=dtype)

    # Warm up compiled reference.
    for _ in range(3):
        phase_1_compiled_ref(query, value)

    torch.cuda.synchronize()

    tri_out, tri_lse = phase_1_triton_forward(value, query, eps)
    torch.cuda.synchronize()

    fp32_out_bf16, fp32_lse = phase_1_fp32_math_ref(
        query,
        value,
        eps,
        cast_output_to_bf16=True,
    )

    compiled_out, compiled_lse, compiled_h = phase_1_compiled_ref(query, value)

    print("\nStrict check against FP32 math reference")
    assert_close(
        "forward out vs fp32_ref",
        tri_out,
        fp32_out_bf16,
        atol=2.5e-2,
        rtol=2.5e-2,
    )
    assert_close(
        "forward lse vs fp32_ref",
        tri_lse,
        fp32_lse,
        atol=5.0e-2,
        rtol=5.0e-2,
    )

    print("\nLoose check against compiled BF16 reference")
    assert_close(
        "forward out vs compiled_ref",
        tri_out,
        compiled_out,
        atol=8.0e-2,
        rtol=8.0e-2,
    )
    assert_close(
        "forward query0 vs compiled h",
        tri_out[0],
        compiled_h,
        atol=8.0e-2,
        rtol=8.0e-2,
    )
    assert_close(
        "forward lse vs compiled_ref",
        tri_lse,
        compiled_lse,
        atol=8.0e-2,
        rtol=8.0e-2,
    )

    # Backward reference in FP32.
    value_ref = value.detach().float().requires_grad_(True)
    query_ref = query.detach().float().requires_grad_(True)

    ref_out, ref_lse = phase_1_fp32_math_ref(
        query_ref,
        value_ref,
        eps,
        cast_output_to_bf16=False,
    )

    grad_out = torch.randn_like(ref_out)
    grad_lse = torch.randn_like(ref_lse)

    ref_loss = (ref_out * grad_out).sum() + (ref_lse * grad_lse).sum()
    ref_loss.backward()

    # Use ref_lse here to isolate the backward kernel formula from small
    # forward-lse approximation differences.
    tri_grad_value, tri_grad_query = phase_1_triton_backward(
        value,
        query,
        ref_lse.detach(),
        grad_out,
        grad_lse,
        eps,
    )
    torch.cuda.synchronize()

    print("\nBackward check against FP32 math reference")
    assert_close(
        "grad value vs fp32_ref",
        tri_grad_value,
        value_ref.grad,
        atol=7.5e-2,
        rtol=7.5e-2,
    )
    assert_close(
        "grad query vs fp32_ref",
        tri_grad_query,
        query_ref.grad,
        atol=7.5e-2,
        rtol=7.5e-2,
    )

    print("\nphase_1 verifier passed")


def verify_phase_1_many():
    test_cases = [
        dict(N=1, B_=1, T_=8, D_=32, S=1),
        dict(N=2, B_=2, T_=11, D_=64, S=2),
        dict(N=5, B_=2, T_=17, D_=64, S=4),
        dict(N=9, B_=1, T_=23, D_=128, S=8),
        dict(N=10, B_=2, T_=19, D_=128, S=8),
    ]

    for cfg in test_cases:
        print(f"\n=== cfg={cfg} ===")
        verify_phase_1_once(**cfg)


if __name__ == "__main__":
    verify_phase_1_many()


=== cfg={'N': 1, 'B_': 1, 'T_': 8, 'D_': 32, 'S': 1} ===

Strict check against FP32 math reference
forward out vs fp32_ref            ok=True  max_abs=0  max_rel=0
forward lse vs fp32_ref            ok=True  max_abs=9.53674e-07  max_rel=2.75492e-07

Loose check against compiled BF16 reference
forward out vs compiled_ref        ok=True  max_abs=0  max_rel=0
forward query0 vs compiled h       ok=True  max_abs=0  max_rel=0
forward lse vs compiled_ref        ok=True  max_abs=0.0330162  max_rel=0.00650462

Backward check against FP32 math reference
grad value vs fp32_ref             ok=True  max_abs=6.22869e-06  max_rel=0.000823695
grad query vs fp32_ref             ok=True  max_abs=1.45435e-05  max_rel=5.49677e-05

phase_1 verifier passed

=== cfg={'N': 2, 'B_': 2, 'T_': 11, 'D_': 64, 'S': 2} ===

Strict check against FP32 math reference
forward out vs fp32_ref            ok=True  max_abs=0  max_rel=0
forward lse vs fp32_ref            ok=True  max_abs=1.90735e-06  max_rel=5.97271e-06

Loose check against compiled BF16 reference
forward out vs compiled_ref        ok=True  max_abs=0.046875  max_rel=0.985915
forward query0 vs compiled h       ok=True  max_abs=0.046875  max_rel=0.985915
forward lse vs compiled_ref        ok=True  max_abs=0.0903473  max_rel=0.121605

Backward check against FP32 math reference
grad value vs fp32_ref             ok=True  max_abs=3.67165e-05  max_rel=0.0205433
grad query vs fp32_ref             ok=True  max_abs=7.58171e-05  max_rel=0.000272084

phase_1 verifier passed

=== cfg={'N': 5, 'B_': 2, 'T_': 17, 'D_': 64, 'S': 4} ===

Strict check against FP32 math reference
forward out vs fp32_ref            ok=True  max_abs=0.00390625  max_rel=0.00641026
forward lse vs fp32_ref            ok=True  max_abs=2.86102e-06  max_rel=6.1498e-06

Loose check against compiled BF16 reference
forward out vs compiled_ref        ok=True  max_abs=0.0351562  max_rel=2.18902
forward query0 vs compiled h       ok=True  max_abs=0.03125  max_rel=1.29586
forward lse vs compiled_ref        ok=True  max_abs=0.136179  max_rel=0.0977561

Backward check against FP32 math reference
grad value vs fp32_ref             ok=True  max_abs=5.4121e-05  max_rel=0.035051
grad query vs fp32_ref             ok=True  max_abs=9.76324e-05  max_rel=0.00187211

phase_1 verifier passed

=== cfg={'N': 9, 'B_': 1, 'T_': 23, 'D_': 128, 'S': 8} ===

Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "mm", "best_time": 0.010239999741315842, "best_triton_pos": 1, "best_triton_time": 0.010239999741315842, "best_triton_kernel": "triton_mm_53", "best_triton_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=2"}
AUTOTUNE mm(207x128, 128x8)
strides: [s31, 1], [1, s31]
dtypes: torch.bfloat16, torch.bfloat16
  mm 0.0102 ms 100.0% 
  triton_mm_53 0.0102 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=2
  triton_mm_56 0.0102 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_59 0.0102 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_62 0.0102 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_67 0.0113 ms 90.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_52 0.0123 ms 83.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=1, num_warps=2
  triton_mm_54 0.0123 ms 83.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=2
  triton_mm_55 0.0123 ms 83.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_57 0.0123 ms 83.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.3206 seconds and 0.8070 seconds precompiling for 18 choices


Strict check against FP32 math reference
forward out vs fp32_ref            ok=True  max_abs=0.00195312  max_rel=0.0177515
forward lse vs fp32_ref            ok=True  max_abs=3.8147e-06  max_rel=4.39347e-07

Loose check against compiled BF16 reference
forward out vs compiled_ref        ok=True  max_abs=0.0507812  max_rel=62.1097
forward query0 vs compiled h       ok=True  max_abs=0.046875  max_rel=33
forward lse vs compiled_ref        ok=True  max_abs=0.260536  max_rel=0.00976179

Backward check against FP32 math reference
grad value vs fp32_ref             ok=True  max_abs=0.000273705  max_rel=0.018092
grad query vs fp32_ref             ok=True  max_abs=0.000303268  max_rel=0.00276758

phase_1 verifier passed

=== cfg={'N': 10, 'B_': 2, 'T_': 19, 'D_': 128, 'S': 8} ===

Strict check against FP32 math reference
forward out vs fp32_ref            ok=True  max_abs=0.00390625  max_rel=0.00775194
forward lse vs fp32_ref            ok=True  max_abs=7.62939e-06  max_rel=3.55747e-06

Loose check against compiled BF16 reference
forward out vs compiled_ref        ok=True  max_abs=0.0859375  max_rel=152.081
forward query0 vs compiled h       ok=True  max_abs=0.046875  max_rel=12.042
forward lse vs compiled_ref        ok=True  max_abs=0.24152  max_rel=0.142159

Backward check against FP32 math reference
grad value vs fp32_ref             ok=True  max_abs=0.000224352  max_rel=0.166864
grad query vs fp32_ref             ok=True  max_abs=0.000386715  max_rel=0.0035152

phase_1 verifier passed

