
phase 1 backward microbench (B=32, T=1024, D=512, BT=32768, BLOCK_SIZE=8)
zero_each_run=False

case                                        triton full  triton main   triton red    torch bwd
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 4, maxnreg: None
Triton autotuning for function phase_1_batched_interblock_attention_kernel,
with key as (1, 512, 8, 1, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32'),
finished after 11.17s,
best config selected: num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None;
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 4, maxnreg: None
Triton autotuning for function phase_1_batched_interblock_attention_backward_kernel,
with key as (1, 512, 8, 1, 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32'),
finished after 21.57s,
best config selected: num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None;
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 16, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 16, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 32, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 32, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 128, BLOCK_HIDDEN: 16, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 128, BLOCK_HIDDEN: 16, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 128, BLOCK_HIDDEN: 32, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 128, BLOCK_HIDDEN: 32, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 256, BLOCK_HIDDEN: 16, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 256, BLOCK_HIDDEN: 16, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 256, BLOCK_HIDDEN: 32, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 256, BLOCK_HIDDEN: 32, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Triton autotuning for function phase_1_reduce_grad_pseudo_queries_kernel,
with key as (32768, 512, 8, 'torch.float32', 'torch.float32'),
finished after 2.80s,
best config selected: BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 32, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None;

/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py:321: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/_inductor/select_algorithm.py:3464: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  current_size = base.storage().size()
Autotune Choices Stats:
{"num_choices": 15, "num_triton_choices": 14, "best_kernel": "mm", "best_time": 0.11673600226640701, "best_triton_pos": 1, "best_triton_time": 0.1228799968957901, "best_triton_kernel": "triton_mm_10", "best_triton_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4"}
AUTOTUNE mm(32768x512, 512x8)
strides: [512, 1], [1, 512]
dtypes: torch.float32, torch.float32
  mm 0.1167 ms 100.0% 
  triton_mm_10 0.1229 ms 95.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_13 0.1249 ms 93.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_7 0.1618 ms 72.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_8 0.1649 ms 70.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_5 0.1669 ms 69.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_6 0.1690 ms 69.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_4 0.1731 ms 67.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_9 0.1792 ms 65.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_11 0.1925 ms 60.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.6621 seconds and 1.0095 seconds precompiling for 15 choices
Autotune Choices Stats:
{"num_choices": 7, "num_triton_choices": 0, "best_kernel": "mm", "best_time": 0.10956799983978271}
AUTOTUNE mm(512x32768, 32768x8)
strides: [1, 512], [8, 1]
dtypes: torch.float32, torch.float32
  mm 0.1096 ms 100.0% 
  decompose_k_mm_64_split_5 0.3820 ms 28.7% k_split=64
  decompose_k_mm_16_split_3 0.4987 ms 22.0% k_split=16
  decompose_k_mm_32_split_4 0.5007 ms 21.9% k_split=32
  decompose_k_mm_8_split_2 0.9810 ms 11.2% k_split=8
  decompose_k_mm_4_split_1 1.9446 ms 5.6% k_split=4
  decompose_k_mm_2_split_0 3.8738 ms 2.8% k_split=2
SingleProcess AUTOTUNE benchmarking takes 3.8155 seconds and 0.0004 seconds precompiling for 7 choices
Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "triton_mm_19", "best_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4", "best_time": 0.053247999399900436, "best_triton_pos": 0}
AUTOTUNE mm(32768x8, 8x512)
strides: [8, 1], [512, 1]
dtypes: torch.float32, torch.float32
  triton_mm_19 0.0532 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_21 0.0532 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_20 0.0543 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_22 0.0543 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_23 0.0543 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_24 0.0543 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_25 0.0543 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_26 0.0543 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_29 0.0563 ms 94.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_30 0.0563 ms 94.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.3493 seconds and 0.0003 seconds precompiling for 18 choices

block_start=0  src=1  q=8  grad_lse=True          1.314        0.926        0.385        2.774
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 4, maxnreg: None
Triton autotuning for function phase_1_batched_interblock_attention_kernel,
with key as (2, 512, 8, 2, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32'),
finished after 14.55s,
best config selected: num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None;
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 4, maxnreg: None
Triton autotuning for function phase_1_batched_interblock_attention_backward_kernel,
with key as (2, 512, 8, 2, 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32'),
finished after 34.44s,
best config selected: num_warps: 2, num_ctas: 1, num_stages: 3, maxnreg: None;

Autotune Choices Stats:
{"num_choices": 15, "num_triton_choices": 14, "best_kernel": "triton_mm_41", "best_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4", "best_time": 0.19660800695419312, "best_triton_pos": 0}
AUTOTUNE mm(65536x512, 512x8)
strides: [512, 1], [1, 512]
dtypes: torch.float32, torch.float32
  triton_mm_41 0.1966 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  mm 0.2007 ms 98.0% 
  triton_mm_44 0.2028 ms 97.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_38 0.3144 ms 62.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_42 0.3144 ms 62.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_39 0.3164 ms 62.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_43 0.3185 ms 61.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=8
  triton_mm_40 0.3461 ms 56.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_37 0.3645 ms 53.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_36 0.3656 ms 53.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.5788 seconds and 0.7068 seconds precompiling for 15 choices
Autotune Choices Stats:
{"num_choices": 6, "num_triton_choices": 0, "best_kernel": "mm", "best_time": 0.1730560064315796}
AUTOTUNE mm(512x65536, 65536x8)
strides: [1, 512], [8, 1]
dtypes: torch.float32, torch.float32
  mm 0.1731 ms 100.0% 
  decompose_k_mm_128_split_9 0.5581 ms 31.0% k_split=128
  decompose_k_mm_256_split_10 0.6369 ms 27.2% k_split=256
  decompose_k_mm_64_split_8 0.6625 ms 26.1% k_split=64
  decompose_k_mm_16_split_6 0.9810 ms 17.6% k_split=16
  decompose_k_mm_32_split_7 0.9830 ms 17.6% k_split=32
SingleProcess AUTOTUNE benchmarking takes 3.2613 seconds and 0.0004 seconds precompiling for 6 choices
Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "triton_mm_51", "best_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8", "best_time": 0.09932799637317657, "best_triton_pos": 0}
AUTOTUNE mm(65536x8, 8x512)
strides: [8, 1], [512, 1]
dtypes: torch.float32, torch.float32
  triton_mm_51 0.0993 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_52 0.0993 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_56 0.0993 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_57 0.0993 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_53 0.1004 ms 99.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_54 0.1004 ms 99.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_55 0.1004 ms 99.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_50 0.1014 ms 98.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_59 0.1024 ms 97.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  mm 0.1034 ms 96.0% 
SingleProcess AUTOTUNE benchmarking takes 0.4724 seconds and 0.0004 seconds precompiling for 18 choices

block_start=8  src=2  q=8  grad_lse=True          1.475        1.081        0.387        4.582
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 4, maxnreg: None
Triton autotuning for function phase_1_batched_interblock_attention_kernel,
with key as (3, 512, 8, 4, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32'),
finished after 17.48s,
best config selected: num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None;
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 4, maxnreg: None
Triton autotuning for function phase_1_batched_interblock_attention_backward_kernel,
with key as (3, 512, 8, 4, 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32'),
finished after 51.11s,
best config selected: num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None;
block_start=16 src=3  q=8  grad_lse=True          1.655        1.270        0.384        5.636
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 4, maxnreg: None
Triton autotuning for function phase_1_batched_interblock_attention_kernel,
with key as (4, 512, 8, 4, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32'),
finished after 17.77s,
best config selected: num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None;
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 4, maxnreg: None
Triton autotuning for function phase_1_batched_interblock_attention_backward_kernel,
with key as (4, 512, 8, 4, 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32'),
finished after 56.88s,
best config selected: num_warps: 2, num_ctas: 1, num_stages: 4, maxnreg: None;
block_start=24 src=4  q=8  grad_lse=True          1.932        1.537        0.388        6.781
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 1, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 2, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_kernel with config num_warps: 16, num_ctas: 1, num_stages: 4, maxnreg: None
Triton autotuning for function phase_1_batched_interblock_attention_kernel,
with key as (5, 512, 1, 8, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.float32'),
finished after 8.32s,
best config selected: num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None;
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 1, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 2, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 4, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 8, num_ctas: 1, num_stages: 4, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 3, maxnreg: None
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 4, maxnreg: None
Triton autotuning for function phase_1_batched_interblock_attention_backward_kernel,
with key as (5, 512, 1, 8, 'torch.bfloat16', 'torch.bfloat16', 'torch.float32', 'torch.bfloat16', 'torch.float32', 'torch.float32', 'torch.float32'),
finished after 13.45s,
best config selected: num_warps: 2, num_ctas: 1, num_stages: 2, maxnreg: None;
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 16, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 16, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 32, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 32, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 128, BLOCK_HIDDEN: 16, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 128, BLOCK_HIDDEN: 16, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 128, BLOCK_HIDDEN: 32, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 128, BLOCK_HIDDEN: 32, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 256, BLOCK_HIDDEN: 16, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 256, BLOCK_HIDDEN: 16, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 256, BLOCK_HIDDEN: 32, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase_1_reduce_grad_pseudo_queries_kernel with config BLOCK_BATCH_SEQ: 256, BLOCK_HIDDEN: 32, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Triton autotuning for function phase_1_reduce_grad_pseudo_queries_kernel,
with key as (32768, 512, 1, 'torch.float32', 'torch.float32'),
finished after 2.61s,
best config selected: BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 32, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None;

Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "triton_mm_69", "best_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4", "best_time": 0.22630399465560913, "best_triton_pos": 0}
AUTOTUNE mm(163840x1, 1x512)
strides: [1, 0], [512, 1]
dtypes: torch.float32, torch.float32
  triton_mm_69 0.2263 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_74 0.2263 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_67 0.2273 ms 99.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_68 0.2273 ms 99.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_72 0.2273 ms 99.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_73 0.2273 ms 99.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_70 0.2284 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_71 0.2284 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_76 0.2304 ms 98.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_77 0.2314 ms 97.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.7127 seconds and 0.9022 seconds precompiling for 18 choices

final src=5  q=1  grad_lse=False                  0.750        0.695        0.056        2.317


phase 1 backward sweep only
Includes final phase-1 backward plus all per-block phase-1 backward calls.
triton zero_each_run=True
triton phase1 backward sweep: 7.360 ms
torch  phase1 backward sweep: 21.884 ms
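
Note: the result rows above are interleaved with autotune output, which makes them hard to compare by eye. The following is a minimal sketch, not part of the benchmark script, that pulls out the rows ending in four timing columns and the two sweep totals, then reports torch / triton ratios. The names `SAMPLE_LOG`, `parse_rows` and `parse_sweep` are hypothetical, and the per-case columns are assumed to be in milliseconds like the sweep totals (the table header does not state a unit).

```python
import re

# A few lines copied from the log above, including noise that must be skipped.
SAMPLE_LOG = """\
Autotuning kernel phase_1_batched_interblock_attention_backward_kernel with config num_warps: 16, num_ctas: 1, num_stages: 4, maxnreg: None
finished after 51.11s,
block_start=16 src=3  q=8  grad_lse=True          1.655        1.270        0.384        5.636
block_start=24 src=4  q=8  grad_lse=True          1.932        1.537        0.388        6.781
  triton_mm_69 0.2263 ms 100.0% ACC_TYPE='tl.float32'
final src=5  q=1  grad_lse=False                  0.750        0.695        0.056        2.317
triton phase1 backward sweep: 7.360 ms
torch  phase1 backward sweep: 21.884 ms
"""

# A result row is a case label followed by exactly four decimal columns:
# triton full, triton main, triton red, torch bwd.
ROW = re.compile(
    r"^(\S.*?)\s+(\d+\.\d+)\s+(\d+\.\d+)\s+(\d+\.\d+)\s+(\d+\.\d+)\s*$"
)
SWEEP = re.compile(r"^(triton|torch)\s+phase1 backward sweep:\s+(\d+\.\d+) ms\s*$")


def parse_rows(log_text):
    """Return one dict per result row, ignoring autotune chatter."""
    rows = []
    for line in log_text.splitlines():
        m = ROW.match(line)
        if m:
            rows.append({
                "case": m.group(1),
                "triton_full": float(m.group(2)),
                "triton_main": float(m.group(3)),
                "triton_red": float(m.group(4)),
                "torch_bwd": float(m.group(5)),
            })
    return rows


def parse_sweep(log_text):
    """Return {'triton': total, 'torch': total} from the sweep summary lines."""
    totals = {}
    for line in log_text.splitlines():
        m = SWEEP.match(line)
        if m:
            totals[m.group(1)] = float(m.group(2))
    return totals


if __name__ == "__main__":
    for row in parse_rows(SAMPLE_LOG):
        ratio = row["torch_bwd"] / row["triton_full"]
        print(f"{row['case']:<45} torch/triton = {ratio:.2f}x")
    sweep = parse_sweep(SAMPLE_LOG)
    print(f"sweep torch/triton = {sweep['torch'] / sweep['triton']:.2f}x")
```

To run it on the full log, read the file contents into a string and pass that in place of `SAMPLE_LOG`.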

