import math
import random
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
import os

os.environ["TRITON_PRINT_AUTOTUNING"] = "1"  # print Triton autotuning info for each autotuned kernel

DEVICE = "cuda"
DTYPE = torch.bfloat16

L = 32  # number of layers (pseudo_queries holds L + 1 rows: one per layer plus a final readout query)
BLOCK_SIZE = 8  # consecutive layers grouped into one residual block
NUM_BLOCKS = math.ceil(L / BLOCK_SIZE) + 1  # input block plus one block per group of BLOCK_SIZE layers

B, T, D = 32, 1024, 512  # batch size, sequence length, hidden dim
BT = B * T  # total (batch, position) pairs; also the kernel grid size

EPS = torch.finfo(torch.float32).eps  # epsilon inside the RMS normalization of keys

autotune_configs = [
    triton.Config({}, num_warps=num_warps, num_stages=num_stages)
    for num_warps in [1, 2, 4, 8, 16]
    for num_stages in [1, 2, 3, 4]
]


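# Phase 1 (forward): inter-block attention.
# One program handles one (batch, position) pair and loops over the pseudo-queries
# of the current block. Keys are the RMS-normalized block representations v_s, so
#     logit_s = <v_s, q> * rsqrt(mean(v_s^2) + eps)
#     out     = sum_s softmax(logit)_s * v_s
#     lse     = logsumexp_s(logit)
# Both the softmax-weighted output and the logsumexp are stored per query.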
@triton.autotune(
    configs=autotune_configs,
    key=["NUM_SOURCE_BLOCKS", "HIDDEN_DIM", "NUM_QUERIES_PER_BLOCK", "PADDED_SRC"],
)
@triton.jit
def phase_1_batched_interblock_attention_kernel(
    block_representations_ptr,
    pseudo_queries_ptr,
    softmax_normalized_output_ptr,
    lse_ptr,
    eps,
    NUM_SOURCE_BLOCKS: tl.constexpr,
    BT: tl.constexpr,
    HIDDEN_DIM: tl.constexpr,
    NUM_QUERIES_PER_BLOCK: tl.constexpr,
    PADDED_SRC: tl.constexpr,
):
    batch_seq_idx = tl.program_id(0)

    source_block_range = tl.arange(0, PADDED_SRC)[:, None]
    hidden_dim_range = tl.arange(0, HIDDEN_DIM)[None, :]
    valid_block_mask_2d = source_block_range < NUM_SOURCE_BLOCKS

    valid_block_mask_1d = tl.arange(0, PADDED_SRC) < NUM_SOURCE_BLOCKS

    source_block_values = tl.load(
        block_representations_ptr
        + source_block_range * (BT * HIDDEN_DIM)
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range,
        mask=valid_block_mask_2d,
        other=0.0,
    ).to(tl.float32)

    squared_norm_sum = tl.sum(source_block_values * source_block_values, axis=1)
    inverse_rms_norm = tl.rsqrt(squared_norm_sum / float(HIDDEN_DIM) + eps)

    hidden_dim_range_1d = tl.arange(0, HIDDEN_DIM)

    for layer_offset in tl.static_range(NUM_QUERIES_PER_BLOCK):
        pseudo_query_vector = tl.load(
            pseudo_queries_ptr + layer_offset * HIDDEN_DIM + hidden_dim_range,
            eviction_policy="evict_last",
        ).to(tl.float32)

        attention_logits = (
            tl.sum(source_block_values * pseudo_query_vector, axis=1) * inverse_rms_norm
        )
        attention_logits = tl.where(
            valid_block_mask_1d, attention_logits, float("-inf")
        )

        max_attention_logit = tl.max(attention_logits)
        exp_attention_logits = tl.exp(attention_logits - max_attention_logit)
        exp_sum = tl.sum(exp_attention_logits)

        unnormalized_output = tl.sum(
            exp_attention_logits[:, None] * source_block_values, axis=0
        )
        normalized_output = (unnormalized_output / exp_sum).to(tl.bfloat16)

        tl.store(
            softmax_normalized_output_ptr
            + layer_offset * BT * HIDDEN_DIM
            + batch_seq_idx * HIDDEN_DIM
            + hidden_dim_range_1d,
            normalized_output,
        )
        tl.store(
            lse_ptr + layer_offset * BT + batch_seq_idx,
            max_attention_logit + tl.log(exp_sum),
        )


def phase_1_batched_interblock_attention(
    block_representations,
    pseudo_queries,
    softmax_outputs,
    lses,
    eps=None,
):
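    """Launch the phase-1 kernel over a (BT,) grid.

    Expected contiguous shapes:
        block_representations: (NUM_SOURCE_BLOCKS, B, T, D)
        pseudo_queries:        (NUM_QUERIES, D)
        softmax_outputs:       (NUM_QUERIES, B, T, D), written by the kernel (bfloat16)
        lses:                  (NUM_QUERIES, B, T), float32, written by the kernel
    """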
    NUM_QUERIES = pseudo_queries.shape[0]
    NUM_SOURCE_BLOCKS = block_representations.shape[0]

    if eps is None:
        eps = EPS

    phase_1_batched_interblock_attention_kernel[(BT,)](
        block_representations,
        pseudo_queries,
        softmax_outputs,
        lses,
        eps,
        NUM_SOURCE_BLOCKS,
        BT,
        D,
        NUM_QUERIES,
        triton.next_power_of_2(NUM_SOURCE_BLOCKS),
    )


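# Phase 2 (forward): online-softmax merge of the current block's running partial sum.
# While a block is being built, its partial sum acts as one extra "value": this kernel
# computes that value's RMS-normalized logit against the pseudo-query, merges it with
# the stored interblock logsumexp, and overwrites interblock_normalized_output in place
# with the merged attention output. restore_value is needed so autotuning does not
# apply the in-place update repeatedly while benchmarking configs.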
@triton.autotune(
    configs=autotune_configs,
    key=["HIDDEN_DIM"],
    restore_value=[
        "interblock_normalized_output_ptr",
    ],
)
@triton.jit
def phase_2_online_softmax_merge_intrablock_kernel(
    intrablock_partial_sum_ptr,
    pseudo_query_ptr,
    interblock_normalized_output_ptr,
    interblock_lse_ptr,
    eps,
    HIDDEN_DIM: tl.constexpr,
):
    batch_seq_idx = tl.program_id(0)
    hidden_dim_range = tl.arange(0, HIDDEN_DIM)

    intrablock_partial_sum = tl.load(
        intrablock_partial_sum_ptr + batch_seq_idx * HIDDEN_DIM + hidden_dim_range
    ).to(tl.float32)
    pseudo_query_vector = tl.load(
        pseudo_query_ptr + hidden_dim_range, eviction_policy="evict_last"
    ).to(tl.float32)

    interblock_lse = tl.load(interblock_lse_ptr + batch_seq_idx)
    interblock_normalized_output = tl.load(
        interblock_normalized_output_ptr + batch_seq_idx * HIDDEN_DIM + hidden_dim_range
    ).to(tl.float32)

    squared_norm_sum = tl.sum(intrablock_partial_sum * intrablock_partial_sum)
    inverse_rms_norm = tl.rsqrt(squared_norm_sum / float(HIDDEN_DIM) + eps)

    intrablock_logit = (
        tl.sum(intrablock_partial_sum * pseudo_query_vector) * inverse_rms_norm
    )
    merged_max = tl.maximum(interblock_lse, intrablock_logit)
    interblock_weight = tl.exp(interblock_lse - merged_max)
    intrablock_weight = tl.exp(intrablock_logit - merged_max)
    exp_sum = interblock_weight + intrablock_weight
    merged_output = (
        interblock_weight * interblock_normalized_output
        + intrablock_weight * intrablock_partial_sum
    ) / exp_sum

    tl.store(
        interblock_normalized_output_ptr
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range,
        merged_output.to(tl.bfloat16),
    )


def phase_2_online_softmax_merge_intrablock(
    intrablock_partial_sum,
    pseudo_query,
    interblock_normalized_output,
    interblock_lse,
    eps=None,
):
    if eps is None:
        eps = EPS

    phase_2_online_softmax_merge_intrablock_kernel[(BT,)](
        intrablock_partial_sum,
        pseudo_query,
        interblock_normalized_output,
        interblock_lse,
        eps,
        D,
    )


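# Phase 1 (backward): recompute the attention probabilities from the saved logsumexp,
# then apply the softmax backward together with the RMS-normalized-key backward.
# With r_s = rsqrt(mean(v_s^2) + eps) and p_s = exp(logit_s - lse):
#     d logit_s / d v_s   = r_s * q - <v_s, q> * r_s^3 * v_s / HIDDEN_DIM
#     d logit_s / d q     = r_s * v_s
#     d loss   / d logit_s = p_s * (d_lse + <dY, v_s> - sum_t p_t * <dY, v_t>)
# Block-value gradients are accumulated with relaxed atomic adds (the accumulator is
# shared across the per-query loop and across successive backward launches); each
# program writes its pseudo-query gradient to a (query, batch, position, dim) partial
# buffer that a separate kernel reduces over the (batch, position) axis.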
@triton.autotune(
    configs=autotune_configs,
    key=["NUM_SOURCE_BLOCKS", "HIDDEN_DIM", "NUM_QUERIES_PER_BLOCK", "PADDED_SRC"],
    restore_value=[
        "grad_block_representations_accumulator_ptr",
    ],
)
@triton.jit
def phase_1_batched_interblock_attention_backward_kernel(
    block_representations_ptr,
    pseudo_queries_ptr,
    lse_ptr,
    grad_softmax_normalized_output_ptr,
    grad_lse_ptr,
    grad_block_representations_accumulator_ptr,
    grad_pseudo_queries_partial_ptr,
    eps,
    NUM_SOURCE_BLOCKS: tl.constexpr,
    BT: tl.constexpr,
    HIDDEN_DIM: tl.constexpr,
    NUM_QUERIES_PER_BLOCK: tl.constexpr,
    PADDED_SRC: tl.constexpr,
    HAS_GRAD_LSE: tl.constexpr,
):
    batch_seq_idx = tl.program_id(0)

    source_block_range = tl.arange(0, PADDED_SRC)[:, None]
    source_block_range_1d = tl.arange(0, PADDED_SRC)

    hidden_dim_range = tl.arange(0, HIDDEN_DIM)[None, :]
    hidden_dim_range_1d = tl.arange(0, HIDDEN_DIM)

    valid_block_mask_2d = source_block_range < NUM_SOURCE_BLOCKS
    valid_block_mask_1d = source_block_range_1d < NUM_SOURCE_BLOCKS

    source_block_values = tl.load(
        block_representations_ptr
        + source_block_range * (BT * HIDDEN_DIM)
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range,
        mask=valid_block_mask_2d,
        other=0.0,
    ).to(tl.float32)

    squared_norm_sum = tl.sum(source_block_values * source_block_values, axis=1)
    inverse_rms_norm = tl.rsqrt(squared_norm_sum / float(HIDDEN_DIM) + eps)
    inverse_rms_norm_cubed = inverse_rms_norm * inverse_rms_norm * inverse_rms_norm

    for layer_offset in tl.static_range(NUM_QUERIES_PER_BLOCK):
        pseudo_query_vector = tl.load(
            pseudo_queries_ptr + layer_offset * HIDDEN_DIM + hidden_dim_range,
            eviction_policy="evict_last",
        ).to(tl.float32)

        grad_attention_output = tl.load(
            grad_softmax_normalized_output_ptr
            + layer_offset * BT * HIDDEN_DIM
            + batch_seq_idx * HIDDEN_DIM
            + hidden_dim_range_1d,
        ).to(tl.float32)

        if HAS_GRAD_LSE:
            grad_logsumexp = tl.load(
                grad_lse_ptr + layer_offset * BT + batch_seq_idx
            ).to(tl.float32)
        else:
            grad_logsumexp = 0.0

        forward_logsumexp = tl.load(lse_ptr + layer_offset * BT + batch_seq_idx).to(
            tl.float32
        )

        pseudo_query_source_dot = tl.sum(
            source_block_values * pseudo_query_vector,
            axis=1,
        )

        attention_logits = pseudo_query_source_dot * inverse_rms_norm
        attention_logits = tl.where(
            valid_block_mask_1d,
            attention_logits,
            float("-inf"),
        )

        softmax_probabilities = tl.exp(attention_logits - forward_logsumexp)

        grad_output_dot_source_values = tl.sum(
            source_block_values * grad_attention_output[None, :],
            axis=1,
        )

        grad_output_dot_expected_value = tl.sum(
            softmax_probabilities * grad_output_dot_source_values,
            axis=0,
        )

        grad_attention_logits = softmax_probabilities * (
            grad_logsumexp
            + grad_output_dot_source_values
            - grad_output_dot_expected_value
        )

        grad_source_from_value_path = (
            softmax_probabilities[:, None] * grad_attention_output[None, :]
        )

        grad_source_from_logit_path = grad_attention_logits[:, None] * (
            inverse_rms_norm[:, None] * pseudo_query_vector
            - pseudo_query_source_dot[:, None]
            * inverse_rms_norm_cubed[:, None]
            * source_block_values
            / float(HIDDEN_DIM)
        )

        grad_source_block_values = (
            grad_source_from_value_path + grad_source_from_logit_path
        )

        grad_source_block_values = tl.where(
            valid_block_mask_2d,
            grad_source_block_values,
            0.0,
        )

        grad_pseudo_query = tl.sum(
            grad_attention_logits[:, None]
            * inverse_rms_norm[:, None]
            * source_block_values,
            axis=0,
        )

        tl.atomic_add(
            grad_block_representations_accumulator_ptr
            + source_block_range * (BT * HIDDEN_DIM)
            + batch_seq_idx * HIDDEN_DIM
            + hidden_dim_range,
            grad_source_block_values,
            mask=valid_block_mask_2d,
            sem="relaxed",
        )

        tl.store(
            grad_pseudo_queries_partial_ptr
            + layer_offset * BT * HIDDEN_DIM
            + batch_seq_idx * HIDDEN_DIM
            + hidden_dim_range_1d,
            grad_pseudo_query,
        )


reduce_configs = [
    triton.Config(
        {
            "BLOCK_BATCH_SEQ": block_batch_seq,
            "BLOCK_HIDDEN": block_hidden,
        },
        num_warps=num_warps,
        num_stages=1,
    )
    for block_batch_seq in [64, 128, 256]
    for block_hidden in [16, 32]
    for num_warps in [4, 8]
]


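# Reduction kernel: sums the per-(batch, position) pseudo-query gradient partials over
# the BT axis in (BLOCK_BATCH_SEQ, BLOCK_HIDDEN) tiles and atomically adds the result
# into the (NUM_QUERIES, HIDDEN_DIM) accumulator.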
@triton.autotune(
    configs=reduce_configs,
    key=["NUM_BATCH_SEQ", "HIDDEN_DIM", "NUM_QUERIES_PER_BLOCK"],
    restore_value=[
        "grad_pseudo_queries_accumulator_ptr",
    ],
)
@triton.jit
def phase_1_reduce_grad_pseudo_queries_kernel(
    grad_pseudo_queries_partial_ptr,
    grad_pseudo_queries_accumulator_ptr,
    NUM_BATCH_SEQ: tl.constexpr,
    HIDDEN_DIM: tl.constexpr,
    NUM_QUERIES_PER_BLOCK: tl.constexpr,
    BLOCK_BATCH_SEQ: tl.constexpr,
    BLOCK_HIDDEN: tl.constexpr,
):
    batch_seq_block_idx = tl.program_id(0)
    query_idx = tl.program_id(1)
    hidden_block_idx = tl.program_id(2)

    batch_seq_offsets = batch_seq_block_idx * BLOCK_BATCH_SEQ + tl.arange(
        0, BLOCK_BATCH_SEQ
    )

    hidden_offsets = hidden_block_idx * BLOCK_HIDDEN + tl.arange(0, BLOCK_HIDDEN)

    grad_tile = tl.load(
        grad_pseudo_queries_partial_ptr
        + query_idx * NUM_BATCH_SEQ * HIDDEN_DIM
        + batch_seq_offsets[:, None] * HIDDEN_DIM
        + hidden_offsets[None, :],
        mask=(
            (batch_seq_offsets[:, None] < NUM_BATCH_SEQ)
            & (hidden_offsets[None, :] < HIDDEN_DIM)
            & (query_idx < NUM_QUERIES_PER_BLOCK)
        ),
        other=0.0,
    ).to(tl.float32)

    grad_reduced = tl.sum(grad_tile, axis=0)

    tl.atomic_add(
        grad_pseudo_queries_accumulator_ptr + query_idx * HIDDEN_DIM + hidden_offsets,
        grad_reduced,
        mask=((hidden_offsets < HIDDEN_DIM) & (query_idx < NUM_QUERIES_PER_BLOCK)),
        sem="relaxed",
    )


def phase_1_batched_interblock_attention_backward(
    block_representations,
    pseudo_queries,
    lses,
    grad_softmax_outputs,
    grad_lses,
    grad_block_representations,
    grad_pseudo_queries,
    grad_pseudo_queries_partial,
    eps=None,
):
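    """Backward for phase 1: main kernel plus the pseudo-query gradient reduction.

    grad_block_representations and grad_pseudo_queries are float32 accumulators that
    are added to, not overwritten; grad_pseudo_queries_partial is (NUM_QUERIES, B, T, D)
    scratch. Pass grad_lses=None when the logsumexp received no gradient; lses is then
    reused as a dummy pointer and the kernel skips the load (HAS_GRAD_LSE=False).
    """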
    NUM_QUERIES = pseudo_queries.shape[0]
    NUM_SOURCE_BLOCKS = block_representations.shape[0]

    if eps is None:
        eps = EPS

    has_grad_lses = grad_lses is not None
    if grad_lses is None:
        grad_lses = lses

    phase_1_batched_interblock_attention_backward_kernel[(BT,)](
        block_representations,
        pseudo_queries,
        lses,
        grad_softmax_outputs,
        grad_lses,
        grad_block_representations,
        grad_pseudo_queries_partial,
        eps,
        NUM_SOURCE_BLOCKS,
        BT,
        D,
        NUM_QUERIES,
        triton.next_power_of_2(NUM_SOURCE_BLOCKS),
        has_grad_lses,
    )

    phase_1_reduce_grad_pseudo_queries_kernel[
        lambda META: (
            triton.cdiv(BT, META["BLOCK_BATCH_SEQ"]),
            NUM_QUERIES,
            triton.cdiv(D, META["BLOCK_HIDDEN"]),
        )
    ](
        grad_pseudo_queries_partial,
        grad_pseudo_queries,
        BT,
        D,
        NUM_QUERIES,
    )


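# Phase 2 (backward): re-derive the online-softmax merge and differentiate it.
# With p1 = exp(lse1 - m) / Z, p2 = exp(logit2 - m) / Z and merged = p1 * out1 + p2 * v:
#     d merged / d lse1   = p1 * p2 * (out1 - v)
#     d merged / d logit2 = p1 * p2 * (v - out1)
# The value/query gradients through the RMS-normalized logit follow the same pattern
# as in phase 1. grad_intrablock_partial_sum_accumulator already holds the gradient
# arriving through the residual add, so the kernel adds the merge path's contribution
# on top with a load + store per program.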
@triton.autotune(
    configs=autotune_configs,
    key=["HIDDEN_DIM"],
    restore_value=[
        "grad_intrablock_partial_sum_accumulator_ptr",
    ],
)
@triton.jit
def phase_2_online_softmax_merge_intrablock_backward_kernel(
    intrablock_partial_sum_ptr,
    pseudo_query_ptr,
    phase1_interblock_normalized_output_ptr,
    phase1_interblock_logsumexp_ptr,
    grad_merged_attention_output_ptr,
    grad_intrablock_partial_sum_accumulator_ptr,
    grad_pseudo_query_partial_ptr,
    grad_phase1_interblock_normalized_output_ptr,
    grad_phase1_interblock_logsumexp_ptr,
    eps,
    HIDDEN_DIM: tl.constexpr,
):
    batch_seq_idx = tl.program_id(0)
    hidden_dim_range = tl.arange(0, HIDDEN_DIM)

    intrablock_partial_sum = tl.load(
        intrablock_partial_sum_ptr + batch_seq_idx * HIDDEN_DIM + hidden_dim_range
    ).to(tl.float32)

    pseudo_query = tl.load(
        pseudo_query_ptr + hidden_dim_range,
        eviction_policy="evict_last",
    ).to(tl.float32)

    phase1_interblock_normalized_output = tl.load(
        phase1_interblock_normalized_output_ptr
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range
    ).to(tl.float32)

    phase1_interblock_logsumexp = tl.load(
        phase1_interblock_logsumexp_ptr + batch_seq_idx
    ).to(tl.float32)

    grad_merged_attention_output = tl.load(
        grad_merged_attention_output_ptr + batch_seq_idx * HIDDEN_DIM + hidden_dim_range
    ).to(tl.float32)

    intrablock_partial_sum_squared_norm = tl.sum(
        intrablock_partial_sum * intrablock_partial_sum
    )
    intrablock_inverse_rms_norm = tl.rsqrt(
        intrablock_partial_sum_squared_norm / float(HIDDEN_DIM) + eps
    )

    pseudo_query_intrablock_dot = tl.sum(intrablock_partial_sum * pseudo_query)
    phase2_intrablock_logit = pseudo_query_intrablock_dot * intrablock_inverse_rms_norm

    online_softmax_shift = tl.maximum(
        phase1_interblock_logsumexp,
        phase2_intrablock_logit,
    )
    phase1_partition_weight = tl.exp(phase1_interblock_logsumexp - online_softmax_shift)
    phase2_partition_weight = tl.exp(phase2_intrablock_logit - online_softmax_shift)
    merged_partition_weight_sum = phase1_partition_weight + phase2_partition_weight

    phase1_merge_probability = phase1_partition_weight / merged_partition_weight_sum
    phase2_merge_probability = phase2_partition_weight / merged_partition_weight_sum

    grad_phase1_interblock_normalized_output = (
        phase1_merge_probability * grad_merged_attention_output
    )
    grad_intrablock_partial_sum_from_value_path = (
        phase2_merge_probability * grad_merged_attention_output
    )

    grad_output_dot_interblock_minus_intrablock = tl.sum(
        grad_merged_attention_output
        * (phase1_interblock_normalized_output - intrablock_partial_sum)
    )

    merge_probability_product = phase1_merge_probability * phase2_merge_probability

    grad_phase1_interblock_logsumexp = (
        merge_probability_product * grad_output_dot_interblock_minus_intrablock
    )

    grad_phase2_intrablock_logit = (
        -merge_probability_product * grad_output_dot_interblock_minus_intrablock
    )

    intrablock_inverse_rms_norm_cubed = (
        intrablock_inverse_rms_norm
        * intrablock_inverse_rms_norm
        * intrablock_inverse_rms_norm
    )

    grad_intrablock_partial_sum_from_logit_path = grad_phase2_intrablock_logit * (
        intrablock_inverse_rms_norm * pseudo_query
        - pseudo_query_intrablock_dot
        * intrablock_inverse_rms_norm_cubed
        * intrablock_partial_sum
        / float(HIDDEN_DIM)
    )

    grad_pseudo_query = (
        grad_phase2_intrablock_logit
        * intrablock_inverse_rms_norm
        * intrablock_partial_sum
    )
    grad_intrablock_partial_sum = (
        grad_intrablock_partial_sum_from_value_path
        + grad_intrablock_partial_sum_from_logit_path
    )

    grad_intrablock_ptr = (
        grad_intrablock_partial_sum_accumulator_ptr
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range
    )

    tl.store(
        grad_intrablock_ptr,
        tl.load(grad_intrablock_ptr).to(tl.float32) + grad_intrablock_partial_sum,
    )

    tl.store(
        grad_pseudo_query_partial_ptr + batch_seq_idx * HIDDEN_DIM + hidden_dim_range,
        grad_pseudo_query,
    )

    tl.store(
        grad_phase1_interblock_normalized_output_ptr
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range,
        grad_phase1_interblock_normalized_output,
    )

    tl.store(
        grad_phase1_interblock_logsumexp_ptr + batch_seq_idx,
        grad_phase1_interblock_logsumexp,
    )


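# Same reduction pattern as in phase 1, but for a single pseudo-query: sum the
# per-(batch, position) partial gradients over BT and atomically add into (HIDDEN_DIM,).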
@triton.autotune(
    configs=reduce_configs,
    key=["NUM_BATCH_SEQ", "HIDDEN_DIM"],
    restore_value=[
        "grad_pseudo_query_accumulator_ptr",
    ],
)
@triton.jit
def phase_2_reduce_grad_pseudo_query_kernel(
    grad_pseudo_query_partial_ptr,
    grad_pseudo_query_accumulator_ptr,
    NUM_BATCH_SEQ: tl.constexpr,
    HIDDEN_DIM: tl.constexpr,
    BLOCK_BATCH_SEQ: tl.constexpr,
    BLOCK_HIDDEN: tl.constexpr,
):
    batch_seq_block_idx = tl.program_id(0)
    hidden_block_idx = tl.program_id(1)

    batch_seq_offsets = batch_seq_block_idx * BLOCK_BATCH_SEQ + tl.arange(
        0, BLOCK_BATCH_SEQ
    )
    hidden_offsets = hidden_block_idx * BLOCK_HIDDEN + tl.arange(0, BLOCK_HIDDEN)

    grad_tile = tl.load(
        grad_pseudo_query_partial_ptr
        + batch_seq_offsets[:, None] * HIDDEN_DIM
        + hidden_offsets[None, :],
        mask=(
            (batch_seq_offsets[:, None] < NUM_BATCH_SEQ)
            & (hidden_offsets[None, :] < HIDDEN_DIM)
        ),
        other=0.0,
    ).to(tl.float32)

    grad_reduced = tl.sum(grad_tile, axis=0)

    tl.atomic_add(
        grad_pseudo_query_accumulator_ptr + hidden_offsets,
        grad_reduced,
        mask=hidden_offsets < HIDDEN_DIM,
        sem="relaxed",
    )


def phase_2_online_softmax_merge_intrablock_backward(
    intrablock_partial_sum,
    pseudo_query,
    phase1_interblock_normalized_output,
    phase1_interblock_logsumexp,
    grad_merged_attention_output,
    grad_intrablock_partial_sum,
    grad_pseudo_query,
    grad_phase1_interblock_normalized_output,
    grad_phase1_interblock_logsumexp,
    grad_pseudo_query_partial,
    eps=None,
):
    if eps is None:
        eps = EPS

    phase_2_online_softmax_merge_intrablock_backward_kernel[(BT,)](
        intrablock_partial_sum,
        pseudo_query,
        phase1_interblock_normalized_output,
        phase1_interblock_logsumexp,
        grad_merged_attention_output,
        grad_intrablock_partial_sum,
        grad_pseudo_query_partial,
        grad_phase1_interblock_normalized_output,
        grad_phase1_interblock_logsumexp,
        eps,
        D,
    )

    phase_2_reduce_grad_pseudo_query_kernel[
        lambda META: (
            triton.cdiv(BT, META["BLOCK_BATCH_SEQ"]),
            triton.cdiv(D, META["BLOCK_HIDDEN"]),
        )
    ](
        grad_pseudo_query_partial,
        grad_pseudo_query,
        BT,
        D,
    )


class BlockwiseAttentionFunction(torch.autograd.Function):
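    """Fused blockwise attention with recomputation-based backward.

    forward(): builds `block_representations` block by block. Block 0 holds the
    inputs; each later block is the sum of per-layer updates, where every layer
    first attends over all previously completed blocks (phase 1) and, after the
    first layer of a block, also over the partially built current block via the
    phase-2 online-softmax merge. The final output is one more phase-1 attention
    over all blocks with the last pseudo-query.

    backward(): only `block_representations` and `pseudo_queries` are saved.
    Per-layer inputs and the phase-1 / phase-2 intermediates are recomputed block
    by block, layer gradients are obtained with `torch.autograd.grad`, and the
    attention gradients flow through the custom Triton backward kernels.
    """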
    @staticmethod
    def forward(ctx, inputs, pseudo_queries, layers, eps, *flat_layer_params):
        block_representations = torch.empty(
            NUM_BLOCKS,
            B,
            T,
            D,
            device=DEVICE,
            dtype=inputs.dtype,
        )
        block_representations[0].copy_(inputs)

        block_attn_out_scratch = torch.empty(
            BLOCK_SIZE,
            B,
            T,
            D,
            device=DEVICE,
            dtype=torch.bfloat16,
        )

        block_lse_scratch = torch.empty(
            BLOCK_SIZE,
            B,
            T,
            device=DEVICE,
            dtype=torch.float32,
        )

        for block_start in range(0, L, BLOCK_SIZE):
            curr_block_idx = block_start // BLOCK_SIZE + 1
            num_queries = min(BLOCK_SIZE, L - block_start)

            block_attn_out = block_attn_out_scratch[:num_queries]
            block_lse = block_lse_scratch[:num_queries]

            phase_1_batched_interblock_attention(
                block_representations[:curr_block_idx],
                pseudo_queries[block_start : block_start + num_queries],
                block_attn_out,
                block_lse,
                eps=eps,
            )

            curr_block = block_representations[curr_block_idx]

            for query_offset in range(num_queries):
                i = block_start + query_offset

                if query_offset != 0:
                    phase_2_online_softmax_merge_intrablock(
                        curr_block,
                        pseudo_queries[i],
                        block_attn_out[query_offset],
                        block_lse[query_offset],
                        eps=eps,
                    )

                update = layers[i](block_attn_out[query_offset])

                if query_offset == 0:
                    curr_block.copy_(update)
                else:
                    curr_block.add_(update)

        final_out = torch.empty(
            B,
            T,
            D,
            device=DEVICE,
            dtype=inputs.dtype,
        )

        final_lse_scratch = torch.empty(
            1,
            B,
            T,
            device=DEVICE,
            dtype=torch.float32,
        )

        phase_1_batched_interblock_attention(
            block_representations,
            pseudo_queries[-1:],
            final_out.unsqueeze(0),
            final_lse_scratch,
            eps=eps,
        )

        ctx.save_for_backward(
            block_representations,
            pseudo_queries,
        )
        ctx.layers = layers
        ctx.eps = eps
        ctx.num_layer_params = len(flat_layer_params)

        return final_out

    @staticmethod
    def backward(ctx, *grad_outputs):
        grad_output = grad_outputs[0]
        if grad_output is None:
            return (None, None, None, None, *([None] * ctx.num_layer_params))

        block_representations, pseudo_queries = ctx.saved_tensors
        layers = ctx.layers
        eps = ctx.eps

        device = block_representations.device
        block_dtype = block_representations.dtype
        attn_dtype = torch.bfloat16

        grad_output = grad_output.contiguous()

        layer_param_groups = [tuple(layer.parameters()) for layer in layers]
        flat_layer_params = [p for group in layer_param_groups for p in group]

        param_offsets = []
        offset = 0
        for group in layer_param_groups:
            param_offsets.append(offset)
            offset += len(group)

        grad_flat_layer_params = [
            torch.zeros_like(p, dtype=torch.float32) if p.requires_grad else None
            for p in flat_layer_params
        ]

        grad_block_representations = torch.zeros_like(
            block_representations,
            dtype=torch.float32,
        )
        grad_pseudo_queries = torch.zeros_like(
            pseudo_queries,
            dtype=torch.float32,
        )

        grad_pseudo_queries_partial = torch.empty(
            BLOCK_SIZE,
            B,
            T,
            D,
            device=device,
            dtype=torch.float32,
        )

        grad_phase2_pseudo_query_partial = torch.empty(
            B,
            T,
            D,
            device=device,
            dtype=torch.float32,
        )

        def run_layer_backward(layer_idx, layer_input_buf, grad_update_f32):
            params_i = layer_param_groups[layer_idx]
            active_param_indices = [
                j for j, p in enumerate(params_i) if p.requires_grad
            ]
            active_params = [params_i[j] for j in active_param_indices]

            with torch.enable_grad():
                layer_input = layer_input_buf.detach().requires_grad_(True)
                update = layers[layer_idx](layer_input)

                grad_results = torch.autograd.grad(
                    outputs=update,
                    inputs=(layer_input, *active_params),
                    grad_outputs=grad_update_f32.to(dtype=update.dtype),
                    retain_graph=False,
                    create_graph=False,
                    allow_unused=False,
                )

            grad_layer_input = grad_results[0]
            if grad_layer_input is None:
                grad_layer_input_f32 = torch.zeros_like(
                    layer_input_buf,
                    dtype=torch.float32,
                )
            else:
                grad_layer_input_f32 = grad_layer_input.to(torch.float32).contiguous()

            base = param_offsets[layer_idx]
            for local_idx, param_grad in zip(active_param_indices, grad_results[1:]):
                if param_grad is not None:
                    grad_flat_layer_params[base + local_idx].add_(
                        param_grad.to(torch.float32)
                    )

            return grad_layer_input_f32

        final_out_recomputed = torch.empty(
            1,
            B,
            T,
            D,
            device=device,
            dtype=attn_dtype,
        )
        final_lse = torch.empty(
            1,
            B,
            T,
            device=device,
            dtype=torch.float32,
        )

        with torch.no_grad():
            phase_1_batched_interblock_attention(
                block_representations,
                pseudo_queries[-1:],
                final_out_recomputed,
                final_lse,
                eps=eps,
            )

        phase_1_batched_interblock_attention_backward(
            block_representations,
            pseudo_queries[-1:],
            final_lse,
            grad_output.unsqueeze(0),
            None,
            grad_block_representations,
            grad_pseudo_queries[-1:],
            grad_pseudo_queries_partial[:1],
            eps=eps,
        )

        block_phase1_out_scratch = torch.empty(
            BLOCK_SIZE,
            B,
            T,
            D,
            device=device,
            dtype=attn_dtype,
        )
        block_lse_scratch = torch.empty(
            BLOCK_SIZE,
            B,
            T,
            device=device,
            dtype=torch.float32,
        )

        grad_block_phase1_out_scratch = torch.empty(
            BLOCK_SIZE,
            B,
            T,
            D,
            device=device,
            dtype=torch.float32,
        )
        grad_block_lse_scratch = torch.empty(
            BLOCK_SIZE,
            B,
            T,
            device=device,
            dtype=torch.float32,
        )

        intrablock_partial_before_scratch = torch.empty(
            max(BLOCK_SIZE - 1, 1),
            B,
            T,
            D,
            device=device,
            dtype=block_dtype,
        )

        partial_recompute = torch.empty(
            B,
            T,
            D,
            device=device,
            dtype=block_dtype,
        )

        layer_input_tmp = torch.empty(
            B,
            T,
            D,
            device=device,
            dtype=attn_dtype,
        )

        grad_curr_partial = torch.empty(
            B,
            T,
            D,
            device=device,
            dtype=torch.float32,
        )
        grad_prev_partial = torch.empty_like(grad_curr_partial)

        last_block_start = ((L - 1) // BLOCK_SIZE) * BLOCK_SIZE

        for block_start in range(last_block_start, -1, -BLOCK_SIZE):
            curr_block_idx = block_start // BLOCK_SIZE + 1
            num_queries = min(BLOCK_SIZE, L - block_start)

            phase1_out = block_phase1_out_scratch[:num_queries]
            phase1_lse = block_lse_scratch[:num_queries]

            grad_phase1_out = grad_block_phase1_out_scratch[:num_queries]
            grad_phase1_lse = grad_block_lse_scratch[:num_queries]

            grad_phase1_out.zero_()
            grad_phase1_lse.zero_()

            with torch.no_grad():
                phase_1_batched_interblock_attention(
                    block_representations[:curr_block_idx],
                    pseudo_queries[block_start : block_start + num_queries],
                    phase1_out,
                    phase1_lse,
                    eps=eps,
                )

                for query_offset in range(num_queries):
                    layer_idx = block_start + query_offset

                    layer_input_tmp.copy_(phase1_out[query_offset])

                    if query_offset != 0:
                        intrablock_partial_before_scratch[query_offset - 1].copy_(
                            partial_recompute
                        )

                        phase_2_online_softmax_merge_intrablock(
                            intrablock_partial_before_scratch[query_offset - 1],
                            pseudo_queries[layer_idx],
                            layer_input_tmp,
                            phase1_lse[query_offset],
                            eps=eps,
                        )

                    update = layers[layer_idx](layer_input_tmp)

                    if query_offset == 0:
                        partial_recompute.copy_(update)
                    else:
                        partial_recompute.add_(update)

            grad_curr_partial.copy_(grad_block_representations[curr_block_idx])

            for query_offset in range(num_queries - 1, -1, -1):
                layer_idx = block_start + query_offset

                with torch.no_grad():
                    layer_input_tmp.copy_(phase1_out[query_offset])

                    if query_offset != 0:
                        phase_2_online_softmax_merge_intrablock(
                            intrablock_partial_before_scratch[query_offset - 1],
                            pseudo_queries[layer_idx],
                            layer_input_tmp,
                            phase1_lse[query_offset],
                            eps=eps,
                        )

                grad_layer_input = run_layer_backward(
                    layer_idx,
                    layer_input_tmp,
                    grad_curr_partial,
                )

                if query_offset == 0:
                    grad_phase1_out[query_offset].copy_(grad_layer_input)
                else:
                    grad_prev_partial.copy_(grad_curr_partial)

                    phase_2_online_softmax_merge_intrablock_backward(
                        intrablock_partial_before_scratch[query_offset - 1],
                        pseudo_queries[layer_idx],
                        phase1_out[query_offset],
                        phase1_lse[query_offset],
                        grad_layer_input,
                        grad_prev_partial,
                        grad_pseudo_queries[layer_idx],
                        grad_phase1_out[query_offset],
                        grad_phase1_lse[query_offset],
                        grad_phase2_pseudo_query_partial,
                        eps=eps,
                    )

                    grad_curr_partial, grad_prev_partial = (
                        grad_prev_partial,
                        grad_curr_partial,
                    )

            phase_1_batched_interblock_attention_backward(
                block_representations[:curr_block_idx],
                pseudo_queries[block_start : block_start + num_queries],
                phase1_lse,
                grad_phase1_out,
                grad_phase1_lse,
                grad_block_representations[:curr_block_idx],
                grad_pseudo_queries[block_start : block_start + num_queries],
                grad_pseudo_queries_partial[:num_queries],
                eps=eps,
            )

        grad_inputs = (
            grad_block_representations[0].to(block_dtype)
            if ctx.needs_input_grad[0]
            else None
        )

        grad_pseudo_queries_out = (
            grad_pseudo_queries.to(pseudo_queries.dtype)
            if ctx.needs_input_grad[1]
            else None
        )

        grad_flat_layer_params_out = []
        for j, (param, grad_param) in enumerate(
            zip(flat_layer_params, grad_flat_layer_params)
        ):
            needs_grad = ctx.needs_input_grad[4 + j]
            if not needs_grad or grad_param is None:
                grad_flat_layer_params_out.append(None)
            else:
                grad_flat_layer_params_out.append(grad_param.to(param.dtype))

        return (
            grad_inputs,
            grad_pseudo_queries_out,
            None,
            None,
            *grad_flat_layer_params_out,
        )


def production_forward(inputs, pseudo_queries, layers, eps=None):
    if eps is None:
        eps = EPS

    flat_layer_params = tuple(p for layer in layers for p in layer.parameters())

    return BlockwiseAttentionFunction.apply(
        inputs,
        pseudo_queries,
        layers,
        eps,
        *flat_layer_params,
    )
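
# A minimal usage sketch (shapes follow the globals at the top of this file):
#     inputs         = torch.randn(B, T, D, device=DEVICE, dtype=DTYPE, requires_grad=True)
#     pseudo_queries = torch.randn(L + 1, D, device=DEVICE, dtype=DTYPE, requires_grad=True)
#     layers         = [SwiGLU() for _ in range(L)]
#     out = production_forward(inputs, pseudo_queries, layers)  # (B, T, D)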


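# Reference implementations for the gradient-comparison and benchmark code below:
# `paper_forward` is a straightforward torch.compile version of the same blockwise
# recurrence, and `torch_compile_phases_forward` reproduces the phase-1 / phase-2
# split with torch.compile instead of the Triton kernels.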
@torch.compile(mode="max-autotune-no-cudagraphs")
def naive_attention_residual(pseudo_query, values):
    keys = F.rms_norm(values, (values.shape[-1],), eps=EPS)

    logits = torch.einsum("d, n b t d -> n b t", pseudo_query, keys)
    logits = logits - logits.max(dim=0, keepdim=True).values

    return torch.einsum(
        "n b t, n b t d -> b t d",
        logits.softmax(0),
        values,
    ).to(DTYPE)


def paper_forward(inputs, pseudo_queries, layers):
    inputs = inputs.to(torch.float32)
    pseudo_queries = pseudo_queries.to(torch.float32)

    blocks = [inputs]

    for i in range(len(layers)):
        outputs = naive_attention_residual(
            pseudo_queries[i],
            torch.stack(blocks, dim=0),
        )

        update = layers[i](outputs)

        if i % BLOCK_SIZE == 0:
            blocks.append(update)
        else:
            blocks[-1] = blocks[-1] + update

    return naive_attention_residual(
        pseudo_queries[-1],
        torch.stack(blocks, dim=0),
    )


@torch.compile(mode="max-autotune-no-cudagraphs")
def phase_1_fn(query, value):
    query = query.to(torch.float32)
    value = value.to(torch.float32)

    D_ = value.shape[-1]

    squared_norm_sum = (value * value).sum(dim=-1)
    inverse_rms_norm = torch.rsqrt(squared_norm_sum / float(D_) + EPS)
    raw_dot = torch.einsum("nbtd,sd->nbts", value, query)
    logits = raw_dot * inverse_rms_norm.unsqueeze(-1)

    max_logits = logits.amax(dim=0)
    exp_weights = torch.exp(logits - max_logits.unsqueeze(0))
    exp_sum = exp_weights.sum(dim=0)

    weighted_sum = (exp_weights.unsqueeze(-1) * value.unsqueeze(3)).sum(dim=0)
    normalized = (weighted_sum / exp_sum[..., None]).permute(2, 0, 1, 3).contiguous()

    lse = (max_logits + torch.log(exp_sum)).permute(2, 0, 1).contiguous()

    h = normalized[0]
    return lse, normalized.to(torch.bfloat16), h


@torch.compile(mode="max-autotune-no-cudagraphs")
def phase_2_fn(current_block_values, query_vector, prev_lse, prev_normalized):
    query_vector_f32 = query_vector.to(torch.float32)
    prev_normalized_f32 = prev_normalized.to(torch.float32)

    current_block_values_f32 = current_block_values.to(torch.float32)

    squared_norm_sum = (current_block_values_f32 * current_block_values_f32).sum(dim=-1)

    inverse_rms_norm = torch.rsqrt(
        squared_norm_sum / current_block_values_f32.shape[-1] + EPS
    )

    current_logit = (current_block_values_f32 @ query_vector_f32) * inverse_rms_norm

    merged_max = torch.maximum(prev_lse, current_logit)
    interblock_weight = torch.exp(prev_lse - merged_max)
    intrablock_weight = torch.exp(current_logit - merged_max)

    out = (
        interblock_weight[..., None] * prev_normalized_f32
        + intrablock_weight[..., None] * current_block_values_f32
    ) / (interblock_weight + intrablock_weight)[..., None]

    return out.to(torch.bfloat16)


def torch_compile_phases_forward(inputs, query_w, layers):
    blocks = [inputs]

    for i in range(len(layers)):
        offset = i % BLOCK_SIZE

        if offset == 0:
            values = torch.stack(blocks, dim=0)

            lse, normalized, h = phase_1_fn(query_w[i : i + BLOCK_SIZE], values)
            blocks.append(layers[i](h.to(inputs.dtype)))
        else:
            h = phase_2_fn(
                blocks[-1],
                query_w[i],
                lse[offset],
                normalized[offset],
            )

            blocks[-1] = blocks[-1] + layers[i](h.to(inputs.dtype))

    _, _, h = phase_1_fn(query_w[-1:], torch.stack(blocks, dim=0))
    return h.to(inputs.dtype)


class SwiGLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.RMSNorm(D, device=DEVICE, dtype=DTYPE, eps=EPS)
        self.linear1 = nn.Linear(D, D * 2, bias=False, device=DEVICE, dtype=DTYPE)
        self.linear2 = nn.Linear(D, D, bias=False, device=DEVICE, dtype=DTYPE)

    def forward(self, x):
        h1, gate = self.linear1(self.norm(x)).chunk(2, dim=-1)
        return self.linear2(F.silu(gate) * h1)


class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


def grad_targets(inputs, pseudo_queries, layers):
    params = tuple(p for layer in layers for p in layer.parameters() if p.requires_grad)
    return (inputs, pseudo_queries, *params)


def bench_fwd_bwd(fn, inputs, pseudo_queries, layers, grad_out, warmup=3, runs=10):
    targets = grad_targets(inputs, pseudo_queries, layers)

    for _ in range(warmup):
        out = fn(inputs, pseudo_queries, layers)
        torch.autograd.grad(
            outputs=out,
            inputs=targets,
            grad_outputs=grad_out,
            retain_graph=False,
            create_graph=False,
            allow_unused=False,
        )

    torch.cuda.synchronize()
    t0 = time.perf_counter()

    for _ in range(runs):
        out = fn(inputs, pseudo_queries, layers)
        torch.autograd.grad(
            outputs=out,
            inputs=targets,
            grad_outputs=grad_out,
            retain_graph=False,
            create_graph=False,
            allow_unused=True,
        )

    torch.cuda.synchronize()

    return (time.perf_counter() - t0) / runs * 1000


def collect_grads(fn, inputs, pseudo_queries, layers, grad_out):
    targets = grad_targets(inputs, pseudo_queries, layers)

    out = fn(inputs, pseudo_queries, layers)

    grads = torch.autograd.grad(
        outputs=out,
        inputs=targets,
        grad_outputs=grad_out,
        retain_graph=False,
        create_graph=False,
        allow_unused=False,
    )

    grads = [grad.detach().to(torch.float32) for grad in grads]
    return out.detach(), grads


def compare_grads(
    ref_name, ref_fn, test_name, test_fn, inputs, pseudo_queries, layers, grad_out
):
    ref_out, ref_grads = collect_grads(ref_fn, inputs, pseudo_queries, layers, grad_out)
    test_out, test_grads = collect_grads(
        test_fn, inputs, pseudo_queries, layers, grad_out
    )

    out_abs = (ref_out.to(torch.float32) - test_out.to(torch.float32)).abs()
    print(
        f"{test_name} vs {ref_name} output: "
        f"mean_abs={out_abs.mean()}, max_abs={out_abs.max()}"
    )

    for idx, (rg, tg) in enumerate(zip(ref_grads, test_grads)):
        if rg is None or tg is None:
            print(
                f"{test_name} grad[{idx}] vs {ref_name}: "
                f"None mismatch: ref_is_none={rg is None}, test_is_none={tg is None}"
            )
            continue

        diff = (rg - tg).abs()
        rel = diff / (rg.abs() + 1e-3)

        norm_rel = (rg - tg).norm() / (rg.norm() + 1e-12)

        rg_abs_avg = rg.abs().mean()
        tg_abs_avg = tg.abs().mean()

        print(
            f"{test_name} grad[{idx}] vs {ref_name}: "
            f"mean_abs={diff.mean()}, max_abs={diff.max()}, "
            f"mean_rel={rel.mean()}, max_rel={rel.max()}, "
            f"norm_rel={norm_rel}, "
            f"ref_abs_avg={rg_abs_avg}, test_abs_avg={tg_abs_avg}"
        )


def bench_backward_only(
    fn, inputs, pseudo_queries, layers, grad_out, warmup=3, runs=10
):
    targets = grad_targets(inputs, pseudo_queries, layers)

    for _ in range(warmup):
        out = fn(inputs, pseudo_queries, layers)
        torch.cuda.synchronize()

        torch.autograd.grad(
            outputs=out,
            inputs=targets,
            grad_outputs=grad_out,
            retain_graph=False,
            create_graph=False,
            allow_unused=False,
        )
        torch.cuda.synchronize()

    total = 0.0

    for _ in range(runs):
        out = fn(inputs, pseudo_queries, layers)
        torch.cuda.synchronize()

        t0 = time.perf_counter()
        torch.autograd.grad(
            outputs=out,
            inputs=targets,
            grad_outputs=grad_out,
            retain_graph=False,
            create_graph=False,
            allow_unused=True,
        )
        torch.cuda.synchronize()

        total += time.perf_counter() - t0

    return total / runs * 1000


def print_bench_group(title, args):
    print(title)
    for name, func in funcs_to_bench:
        fwd_bwd = bench_fwd_bwd(func, *args, grad_out)
        bwd = bench_backward_only(func, *args, grad_out)
        print(f"{name} fwd+bwd:  {fwd_bwd:.3f} ms")
        print(f"{name} bwd-only: {bwd:.3f} ms")
    print()


for i in range(0):  # end-to-end benchmark / gradient-comparison block; disabled (range(0) never runs)
    inputs = torch.randn(
        B,
        T,
        D,
        device=DEVICE,
        dtype=DTYPE,
        requires_grad=True,
    )

    layers_swiglu = [SwiGLU() for _ in range(L)]
    layers_identity = [Identity() for _ in range(L)]

    pseudo_queries_zeros = torch.zeros(
        L + 1,
        D,
        device=DEVICE,
        dtype=DTYPE,
        requires_grad=True,
    )

    pseudo_queries_randn = torch.randn(
        L + 1,
        D,
        device=DEVICE,
        dtype=DTYPE,
        requires_grad=True,
    ) / math.sqrt(D)

    grad_out = torch.randn(
        B,
        T,
        D,
        device=DEVICE,
        dtype=DTYPE,
    )

    args_identity = (inputs, pseudo_queries_randn, layers_identity)

    args_swiglu_zeros = (inputs, pseudo_queries_zeros, layers_swiglu)
    args_swiglu_randn = (inputs, pseudo_queries_randn, layers_swiglu)

    funcs_to_bench = [
        ("torch_compile_phases_forward", torch_compile_phases_forward),
        # ("production_forward", production_forward),
        ("paper_forward", paper_forward),
    ]

    random.shuffle(funcs_to_bench)

    print_bench_group("identity / randn queries", args_identity)
    print_bench_group("swiglu / zero queries", args_swiglu_zeros)
    print_bench_group("swiglu / randn queries", args_swiglu_randn)

    # compare_grads(
    #     "paper_forward",
    #     paper_forward,
    #     "production_forward",
    #     production_forward,
    #     *args_identity,
    #     grad_out,
    # )

    # compare_grads(
    #     "paper_forward",
    #     paper_forward,
    #     "torch_compile_phases_forward",
    #     torch_compile_phases_forward,
    #     *args_identity,
    #     grad_out,
    # )

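# The remainder of the file is a standalone microbenchmark of the phase-1 backward
# pass: the Triton path (full launch, main kernel only, reduction only) is timed
# against a torch.compile autograd reference, both per block configuration and for
# the full backward sweep over all blocks.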
def phase_1_torch_backward_once(
    block_representations,
    pseudo_queries,
    grad_softmax_outputs,
    grad_lses,
):
    values = block_representations.detach().requires_grad_(True)
    queries = pseudo_queries.detach().requires_grad_(True)

    lse, normalized, _ = phase_1_fn(queries, values)

    if grad_lses is None:
        outputs = (normalized,)
        grad_outputs = (grad_softmax_outputs.to(normalized.dtype),)
    else:
        outputs = (normalized, lse)
        grad_outputs = (
            grad_softmax_outputs.to(normalized.dtype),
            grad_lses,
        )

    grad_values, grad_queries = torch.autograd.grad(
        outputs=outputs,
        inputs=(values, queries),
        grad_outputs=grad_outputs,
        retain_graph=False,
        create_graph=False,
        allow_unused=False,
    )

    return grad_values, grad_queries


def _time_cuda_ms(fn, warmup=5, runs=20):
    for _ in range(warmup):
        fn()

    torch.cuda.synchronize()
    t0 = time.perf_counter()

    for _ in range(runs):
        fn()

    torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000.0 / runs


@torch.no_grad()
def make_phase1_backward_case(
    num_source_blocks,
    num_queries,
    *,
    has_grad_lse,
    grad_out_dtype=torch.float32,
    device=DEVICE,
    dtype=DTYPE,
):
    block_representations = torch.randn(
        num_source_blocks, B, T, D, device=device, dtype=dtype
    )

    pseudo_queries = torch.randn(
        num_queries, D, device=device, dtype=dtype
    ) / math.sqrt(D)

    phase1_out = torch.empty(
        num_queries, B, T, D, device=device, dtype=dtype
    )

    lses = torch.empty(
        num_queries, B, T, device=device, dtype=torch.float32
    )

    phase_1_batched_interblock_attention(
        block_representations,
        pseudo_queries,
        phase1_out,
        lses,
        eps=EPS,
    )

    grad_softmax_outputs = torch.randn(
        num_queries, B, T, D, device=device, dtype=grad_out_dtype
    )

    grad_lses = (
        torch.randn(num_queries, B, T, device=device, dtype=torch.float32)
        if has_grad_lse
        else None
    )

    grad_block_representations = torch.zeros(
        num_source_blocks, B, T, D, device=device, dtype=torch.float32
    )

    grad_pseudo_queries = torch.zeros(
        num_queries, D, device=device, dtype=torch.float32
    )

    grad_pseudo_queries_partial = torch.zeros(
        num_queries, B, T, D, device=device, dtype=torch.float32
    )

    torch.cuda.synchronize()

    return {
        "num_source_blocks": num_source_blocks,
        "num_queries": num_queries,
        "has_grad_lse": has_grad_lse,
        "block_representations": block_representations,
        "pseudo_queries": pseudo_queries,
        "lses": lses,
        "grad_softmax_outputs": grad_softmax_outputs,
        "grad_lses": grad_lses,
        "grad_block_representations": grad_block_representations,
        "grad_pseudo_queries": grad_pseudo_queries,
        "grad_pseudo_queries_partial": grad_pseudo_queries_partial,
    }


def _launch_phase1_backward_main(case):
    grad_lses_arg = case["grad_lses"]
    if grad_lses_arg is None:
        grad_lses_arg = case["lses"]

    phase_1_batched_interblock_attention_backward_kernel[(BT,)](
        case["block_representations"],
        case["pseudo_queries"],
        case["lses"],
        case["grad_softmax_outputs"],
        grad_lses_arg,
        case["grad_block_representations"],
        case["grad_pseudo_queries_partial"],
        EPS,
        case["num_source_blocks"],
        BT,
        D,
        case["num_queries"],
        triton.next_power_of_2(case["num_source_blocks"]),
        case["has_grad_lse"],
    )


def _launch_phase1_backward_reduce(case):
    phase_1_reduce_grad_pseudo_queries_kernel[
        lambda META: (
            triton.cdiv(BT, META["BLOCK_BATCH_SEQ"]),
            case["num_queries"],
            triton.cdiv(D, META["BLOCK_HIDDEN"]),
        )
    ](
        case["grad_pseudo_queries_partial"],
        case["grad_pseudo_queries"],
        BT,
        D,
        case["num_queries"],
    )


def _launch_phase1_backward_full(case):
    phase_1_batched_interblock_attention_backward(
        case["block_representations"],
        case["pseudo_queries"],
        case["lses"],
        case["grad_softmax_outputs"],
        case["grad_lses"],
        case["grad_block_representations"],
        case["grad_pseudo_queries"],
        case["grad_pseudo_queries_partial"],
        eps=EPS,
    )


def bench_phase1_backward_case(
    case,
    *,
    mode,
    warmup=5,
    runs=20,
    zero_each_run=False,
):
    if mode == "full":
        launcher = _launch_phase1_backward_full
    elif mode == "main":
        launcher = _launch_phase1_backward_main
    elif mode == "reduce":
        launcher = _launch_phase1_backward_reduce
    elif mode == "torch":
        launcher = lambda c: phase_1_torch_backward_once(
            c["block_representations"],
            c["pseudo_queries"],
            c["grad_softmax_outputs"],
            c["grad_lses"],
        )
    else:
        raise ValueError(f"unknown mode: {mode}")

    def run_once():
        if zero_each_run:
            if mode in ("full", "main"):
                case["grad_block_representations"].zero_()
            if mode in ("full", "reduce"):
                case["grad_pseudo_queries"].zero_()

        launcher(case)

    return _time_cuda_ms(run_once, warmup=warmup, runs=runs)


def print_phase1_backward_microbench(
    *,
    warmup=5,
    runs=20,
    zero_each_run=False,
):
    print()
    print(
        "phase 1 backward microbench "
        f"(B={B}, T={T}, D={D}, BT={BT}, BLOCK_SIZE={BLOCK_SIZE})"
    )
    print(f"zero_each_run={zero_each_run}")
    print()

    print(
        f"{'case':<42} "
        f"{'triton full':>12} "
        f"{'triton main':>12} "
        f"{'triton red':>12} "
        f"{'torch bwd':>12}"
    )

    for block_start in range(0, L, BLOCK_SIZE):
        num_source_blocks = block_start // BLOCK_SIZE + 1
        num_queries = min(BLOCK_SIZE, L - block_start)

        label = (
            f"block_start={block_start:<2} "
            f"src={num_source_blocks:<2} "
            f"q={num_queries:<2} "
            f"grad_lse=True"
        )

        case = make_phase1_backward_case(
            num_source_blocks,
            num_queries,
            has_grad_lse=True,
            grad_out_dtype=torch.float32,
        )

        full_ms = bench_phase1_backward_case(
            case, mode="full", warmup=warmup, runs=runs, zero_each_run=zero_each_run
        )
        main_ms = bench_phase1_backward_case(
            case, mode="main", warmup=warmup, runs=runs, zero_each_run=zero_each_run
        )
        reduce_ms = bench_phase1_backward_case(
            case, mode="reduce", warmup=warmup, runs=runs, zero_each_run=zero_each_run
        )
        torch_ms = bench_phase1_backward_case(
            case, mode="torch", warmup=warmup, runs=runs, zero_each_run=False
        )

        print(
            f"{label:<42} "
            f"{full_ms:12.3f} "
            f"{main_ms:12.3f} "
            f"{reduce_ms:12.3f} "
            f"{torch_ms:12.3f}"
        )

    final_case = make_phase1_backward_case(
        NUM_BLOCKS,
        1,
        has_grad_lse=False,
        grad_out_dtype=DTYPE,
    )

    final_label = f"final src={NUM_BLOCKS:<2} q={1:<2} grad_lse=False"

    full_ms = bench_phase1_backward_case(
        final_case, mode="full", warmup=warmup, runs=runs, zero_each_run=zero_each_run
    )
    main_ms = bench_phase1_backward_case(
        final_case, mode="main", warmup=warmup, runs=runs, zero_each_run=zero_each_run
    )
    reduce_ms = bench_phase1_backward_case(
        final_case, mode="reduce", warmup=warmup, runs=runs, zero_each_run=zero_each_run
    )
    torch_ms = bench_phase1_backward_case(
        final_case, mode="torch", warmup=warmup, runs=runs, zero_each_run=False
    )

    print(
        f"{final_label:<42} "
        f"{full_ms:12.3f} "
        f"{main_ms:12.3f} "
        f"{reduce_ms:12.3f} "
        f"{torch_ms:12.3f}"
    )

    print()


@torch.no_grad()
def make_phase1_backward_sweep_case(
    *,
    block_grad_out_dtype=torch.float32,
    final_grad_out_dtype=DTYPE,
    device=DEVICE,
    dtype=DTYPE,
):
    block_representations = torch.randn(
        NUM_BLOCKS, B, T, D, device=device, dtype=dtype
    )

    pseudo_queries = torch.randn(
        L + 1, D, device=device, dtype=dtype
    ) / math.sqrt(D)

    phase1_out_scratch = torch.empty(
        BLOCK_SIZE, B, T, D, device=device, dtype=dtype
    )

    final_lse = torch.empty(
        1, B, T, device=device, dtype=torch.float32
    )

    phase_1_batched_interblock_attention(
        block_representations,
        pseudo_queries[-1:],
        phase1_out_scratch[:1],
        final_lse,
        eps=EPS,
    )

    block_specs = []

    for block_start in range(0, L, BLOCK_SIZE):
        curr_block_idx = block_start // BLOCK_SIZE + 1
        num_queries = min(BLOCK_SIZE, L - block_start)

        lse = torch.empty(
            num_queries, B, T, device=device, dtype=torch.float32
        )

        phase_1_batched_interblock_attention(
            block_representations[:curr_block_idx],
            pseudo_queries[block_start : block_start + num_queries],
            phase1_out_scratch[:num_queries],
            lse,
            eps=EPS,
        )

        block_specs.append(
            {
                "block_start": block_start,
                "curr_block_idx": curr_block_idx,
                "num_queries": num_queries,
                "lse": lse,
            }
        )

    grad_final_out = torch.randn(
        1, B, T, D, device=device, dtype=final_grad_out_dtype
    )

    grad_block_phase1_out = torch.randn(
        BLOCK_SIZE, B, T, D, device=device, dtype=block_grad_out_dtype
    )

    grad_block_phase1_lse = torch.randn(
        BLOCK_SIZE, B, T, device=device, dtype=torch.float32
    )

    grad_block_representations = torch.zeros(
        NUM_BLOCKS, B, T, D, device=device, dtype=torch.float32
    )

    grad_pseudo_queries = torch.zeros(
        L + 1, D, device=device, dtype=torch.float32
    )

    grad_pseudo_queries_partial = torch.zeros(
        BLOCK_SIZE, B, T, D, device=device, dtype=torch.float32
    )

    torch.cuda.synchronize()

    return {
        "block_representations": block_representations,
        "pseudo_queries": pseudo_queries,
        "final_lse": final_lse,
        "block_specs": block_specs,
        "grad_final_out": grad_final_out,
        "grad_block_phase1_out": grad_block_phase1_out,
        "grad_block_phase1_lse": grad_block_phase1_lse,
        "grad_block_representations": grad_block_representations,
        "grad_pseudo_queries": grad_pseudo_queries,
        "grad_pseudo_queries_partial": grad_pseudo_queries_partial,
    }


def run_phase1_backward_sweep_once(case, *, zero_each_run=True):
    if zero_each_run:
        case["grad_block_representations"].zero_()
        case["grad_pseudo_queries"].zero_()

    phase_1_batched_interblock_attention_backward(
        case["block_representations"],
        case["pseudo_queries"][-1:],
        case["final_lse"],
        case["grad_final_out"],
        None,
        case["grad_block_representations"],
        case["grad_pseudo_queries"][-1:],
        case["grad_pseudo_queries_partial"][:1],
        eps=EPS,
    )

    for spec in reversed(case["block_specs"]):
        block_start = spec["block_start"]
        curr_block_idx = spec["curr_block_idx"]
        num_queries = spec["num_queries"]

        phase_1_batched_interblock_attention_backward(
            case["block_representations"][:curr_block_idx],
            case["pseudo_queries"][block_start : block_start + num_queries],
            spec["lse"],
            case["grad_block_phase1_out"][:num_queries],
            case["grad_block_phase1_lse"][:num_queries],
            case["grad_block_representations"][:curr_block_idx],
            case["grad_pseudo_queries"][block_start : block_start + num_queries],
            case["grad_pseudo_queries_partial"][:num_queries],
            eps=EPS,
        )


def run_phase1_torch_backward_sweep_once(case):
    phase_1_torch_backward_once(
        case["block_representations"],
        case["pseudo_queries"][-1:],
        case["grad_final_out"],
        None,
    )

    for spec in reversed(case["block_specs"]):
        block_start = spec["block_start"]
        curr_block_idx = spec["curr_block_idx"]
        num_queries = spec["num_queries"]

        phase_1_torch_backward_once(
            case["block_representations"][:curr_block_idx],
            case["pseudo_queries"][block_start : block_start + num_queries],
            case["grad_block_phase1_out"][:num_queries],
            case["grad_block_phase1_lse"][:num_queries],
        )


def bench_phase1_backward_sweep(
    *,
    warmup=5,
    runs=20,
    zero_each_run=True,
):
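    """Time the full Triton backward sweep (final query plus all per-block calls)
    and return the time in milliseconds reported by _time_cuda_ms."""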
    case = make_phase1_backward_sweep_case()

    def run_once():
        run_phase1_backward_sweep_once(case, zero_each_run=zero_each_run)

    return _time_cuda_ms(run_once, warmup=warmup, runs=runs)


def bench_phase1_torch_backward_sweep(
    *,
    warmup=5,
    runs=20,
):
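    """Same measurement for the torch-reference backward sweep."""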
    case = make_phase1_backward_sweep_case()

    def run_once():
        run_phase1_torch_backward_sweep_once(case)

    return _time_cuda_ms(run_once, warmup=warmup, runs=runs)


def print_phase1_backward_sweep_bench(
    *,
    warmup=5,
    runs=20,
    zero_each_run=True,
):
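    """Run both sweep benchmarks with identical settings and print the timings."""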
    triton_ms = bench_phase1_backward_sweep(
        warmup=warmup,
        runs=runs,
        zero_each_run=zero_each_run,
    )

    torch_ms = bench_phase1_torch_backward_sweep(
        warmup=warmup,
        runs=runs,
    )

    print()
    print("phase 1 backward sweep only")
    print(
        "Includes final phase-1 backward plus all per-block phase-1 backward calls."
    )
    print(f"triton zero_each_run={zero_each_run}")
    print(f"triton phase1 backward sweep: {triton_ms:.3f} ms")
    print(f"torch  phase1 backward sweep: {torch_ms:.3f} ms")
    print()
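

# Optional correctness cross-check (a minimal sketch, not wired into __main__): before
# trusting the timings it can be worth confirming that the Triton backward for the
# single final pseudo query matches a plain-torch autograd recomputation. The reference
# below assumes the phase-1 forward is: RMS-normalized dot-product logits over the
# source blocks, a softmax across blocks, and a softmax-weighted sum of the raw block
# representations; the tolerances are rough guesses for bf16 inputs.
def check_final_query_backward(case, atol=2e-2, rtol=2e-2):
    case["grad_block_representations"].zero_()
    case["grad_pseudo_queries"].zero_()

    # Triton backward for the final query only (same call as in the sweep).
    phase_1_batched_interblock_attention_backward(
        case["block_representations"],
        case["pseudo_queries"][-1:],
        case["final_lse"],
        case["grad_final_out"],
        None,
        case["grad_block_representations"],
        case["grad_pseudo_queries"][-1:],
        case["grad_pseudo_queries_partial"][:1],
        eps=EPS,
    )

    # Plain-torch reference via autograd.
    x = case["block_representations"].detach().float().requires_grad_(True)
    q = case["pseudo_queries"][-1].detach().float().requires_grad_(True)

    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1) + EPS)     # (NUM_BLOCKS, B, T)
    logits = torch.einsum("nbtd,d->nbt", x, q) * inv_rms   # (NUM_BLOCKS, B, T)
    probs = torch.softmax(logits, dim=0)
    out = torch.einsum("nbt,nbtd->btd", probs, x)           # (B, T, D)
    (out * case["grad_final_out"][0].float()).sum().backward()

    torch.testing.assert_close(
        case["grad_block_representations"], x.grad, atol=atol, rtol=rtol
    )
    torch.testing.assert_close(
        case["grad_pseudo_queries"][-1], q.grad, atol=atol, rtol=rtol
    )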


if __name__ == "__main__":
    torch.manual_seed(0)

    print_phase1_backward_microbench(
        warmup=5,
        runs=20,
        zero_each_run=False,
    )

    print_phase1_backward_sweep_bench(
        warmup=5,
        runs=20,
        zero_each_run=True,
    )



phase 1 backward microbench (B=32, T=1024, D=512, BT=32768, BLOCK_SIZE=8)
zero_each_run=False

case                                        triton full  triton main   triton red    torch bwd
Triton autotuning for function phase_1_batched_interblock_attention_kernel,
with key as (1, 512, 8, 1, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32'),
finished after 2.46s,
best config selected: num_warps: 1, num_ctas: 1, num_stages: 4, maxnreg: None;
Triton autotuning for function phase_1_batched_interblock_attention_backward_kernel,
with key as (1, 512, 8, 1, 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32'),
finished after 2.71s,
best config selected: num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None;
Triton autotuning for function phase_1_reduce_grad_pseudo_queries_kernel,
with key as (32768, 512, 8, 'torch.float32', 'torch.float32'),
finished after 1.51s,
best config selected: BLOCK_BATCH_SEQ: 256, BLOCK_HIDDEN: 32, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None;

Autotune Choices Stats:
{"num_choices": 15, "num_triton_choices": 14, "best_kernel": "mm", "best_time": 0.11673600226640701, "best_triton_pos": 1, "best_triton_time": 0.1228799968957901, "best_triton_kernel": "triton_mm_10", "best_triton_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4"}
AUTOTUNE mm(32768x512, 512x8)
strides: [512, 1], [1, 512]
dtypes: torch.float32, torch.float32
  mm 0.1167 ms 100.0% 
  triton_mm_10 0.1229 ms 95.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_13 0.1249 ms 93.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_7 0.1628 ms 71.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_8 0.1659 ms 70.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_5 0.1679 ms 69.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_6 0.1700 ms 68.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_4 0.1741 ms 67.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_9 0.1802 ms 64.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_11 0.1915 ms 61.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.5945 seconds and 0.9111 seconds precompiling for 15 choices
Autotune Choices Stats:
{"num_choices": 7, "num_triton_choices": 0, "best_kernel": "mm", "best_time": 0.10547199845314026}
AUTOTUNE mm(512x32768, 32768x8)
strides: [1, 512], [8, 1]
dtypes: torch.float32, torch.float32
  mm 0.1055 ms 100.0% 
  decompose_k_mm_64_split_5 0.3820 ms 27.6% k_split=64
  decompose_k_mm_16_split_3 0.4987 ms 21.1% k_split=16
  decompose_k_mm_32_split_4 0.5028 ms 21.0% k_split=32
  decompose_k_mm_8_split_2 0.9800 ms 10.8% k_split=8
  decompose_k_mm_4_split_1 1.9558 ms 5.4% k_split=4
  decompose_k_mm_2_split_0 3.8728 ms 2.7% k_split=2
SingleProcess AUTOTUNE benchmarking takes 3.7342 seconds and 0.0004 seconds precompiling for 7 choices
Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "triton_mm_20", "best_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8", "best_time": 0.053247999399900436, "best_triton_pos": 0}
AUTOTUNE mm(32768x8, 8x512)
strides: [8, 1], [512, 1]
dtypes: torch.float32, torch.float32
  triton_mm_20 0.0532 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_21 0.0532 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_19 0.0543 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_22 0.0543 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_23 0.0543 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_24 0.0543 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_25 0.0543 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_26 0.0543 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_30 0.0563 ms 94.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  mm 0.0584 ms 91.2% 
SingleProcess AUTOTUNE benchmarking takes 0.3440 seconds and 0.0002 seconds precompiling for 18 choices

block_start=0  src=1  q=8  grad_lse=True          1.441        1.053        0.383        2.727
Triton autotuning for function phase_1_batched_interblock_attention_kernel,
with key as (2, 512, 8, 2, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32'),
finished after 16.83s,
best config selected: num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None;
Triton autotuning for function phase_1_batched_interblock_attention_backward_kernel,
with key as (2, 512, 8, 2, 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32'),
finished after 35.75s,
best config selected: num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None;

Autotune Choices Stats:
{"num_choices": 15, "num_triton_choices": 14, "best_kernel": "triton_mm_41", "best_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4", "best_time": 0.19660800695419312, "best_triton_pos": 0}
AUTOTUNE mm(65536x512, 512x8)
strides: [512, 1], [1, 512]
dtypes: torch.float32, torch.float32
  triton_mm_41 0.1966 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  mm 0.2017 ms 97.5% 
  triton_mm_44 0.2028 ms 97.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_38 0.3144 ms 62.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_36 0.3154 ms 62.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_37 0.3154 ms 62.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_42 0.3154 ms 62.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_39 0.3164 ms 62.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_43 0.3174 ms 61.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=8
  triton_mm_34 0.3195 ms 61.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.5753 seconds and 0.8035 seconds precompiling for 15 choices
Autotune Choices Stats:
{"num_choices": 6, "num_triton_choices": 0, "best_kernel": "mm", "best_time": 0.16998399794101715}
AUTOTUNE mm(512x65536, 65536x8)
strides: [1, 512], [8, 1]
dtypes: torch.float32, torch.float32
  mm 0.1700 ms 100.0% 
  decompose_k_mm_128_split_9 0.6277 ms 27.1% k_split=128
  decompose_k_mm_256_split_10 0.6359 ms 26.7% k_split=256
  decompose_k_mm_64_split_8 0.7434 ms 22.9% k_split=64
  decompose_k_mm_32_split_7 0.9830 ms 17.3% k_split=32
  decompose_k_mm_16_split_6 0.9871 ms 17.2% k_split=16
SingleProcess AUTOTUNE benchmarking takes 3.8403 seconds and 0.0004 seconds precompiling for 6 choices
Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "triton_mm_51", "best_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8", "best_time": 0.09932799637317657, "best_triton_pos": 0}
AUTOTUNE mm(65536x8, 8x512)
strides: [8, 1], [512, 1]
dtypes: torch.float32, torch.float32
  triton_mm_51 0.0993 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_52 0.0993 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_56 0.0993 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_57 0.0993 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_53 0.1004 ms 99.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_54 0.1004 ms 99.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_55 0.1004 ms 99.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_50 0.1014 ms 98.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_59 0.1024 ms 97.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  mm 0.1034 ms 96.0% 
SingleProcess AUTOTUNE benchmarking takes 0.4842 seconds and 0.0003 seconds precompiling for 18 choices

block_start=8  src=2  q=8  grad_lse=True          3.126        2.732        0.386        4.629
Triton autotuning for function phase_1_batched_interblock_attention_kernel,
with key as (3, 512, 8, 4, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32'),
finished after 19.02s,
best config selected: num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None;
Triton autotuning for function phase_1_batched_interblock_attention_backward_kernel,
with key as (3, 512, 8, 4, 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32'),
finished after 45.86s,
best config selected: num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None;
block_start=16 src=3  q=8  grad_lse=True          4.312        3.918        0.384        5.664
Triton autotuning for function phase_1_batched_interblock_attention_kernel,
with key as (4, 512, 8, 4, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32'),
finished after 19.05s,
best config selected: num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None;
Triton autotuning for function phase_1_batched_interblock_attention_backward_kernel,
with key as (4, 512, 8, 4, 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32'),
finished after 50.46s,
best config selected: num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None;
block_start=24 src=4  q=8  grad_lse=True          5.484        5.088        0.384        6.812
Triton autotuning for function phase_1_batched_interblock_attention_kernel,
with key as (5, 512, 1, 8, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32'),
finished after 8.60s,
best config selected: num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None;
Triton autotuning for function phase_1_batched_interblock_attention_backward_kernel,
with key as (5, 512, 1, 8, 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.float32'),
finished after 13.12s,
best config selected: num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None;
Triton autotuning for function phase_1_reduce_grad_pseudo_queries_kernel,
with key as (32768, 512, 1, 'torch.float32', 'torch.float32'),
finished after 2.87s,
best config selected: BLOCK_BATCH_SEQ: 256, BLOCK_HIDDEN: 32, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None;

Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "triton_mm_68", "best_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8", "best_time": 0.22630399465560913, "best_triton_pos": 0}
AUTOTUNE mm(163840x1, 1x512)
strides: [1, 0], [512, 1]
dtypes: torch.float32, torch.float32
  triton_mm_68 0.2263 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_74 0.2263 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_69 0.2273 ms 99.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_70 0.2273 ms 99.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_71 0.2273 ms 99.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_72 0.2273 ms 99.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_73 0.2273 ms 99.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_67 0.2294 ms 98.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_76 0.2304 ms 98.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  mm 0.2324 ms 97.4% 
SingleProcess AUTOTUNE benchmarking takes 0.7192 seconds and 0.8252 seconds precompiling for 18 choices

final src=5  q=1  grad_lse=False                  1.084        1.023        0.056        2.378


phase 1 backward sweep only
Includes final phase-1 backward plus all per-block phase-1 backward calls.
triton zero_each_run=True
triton phase1 backward sweep: 15.630 ms
torch  phase1 backward sweep: 21.866 ms

