GPU: NVIDIA RTX PRO 6000 Blackwell Server Edition
Torch: 2.10.0+cu128
CUDA: 12.8
Triton: 3.6.0
Shape: B=128, T=8192, BT=1048576, D=1024, dtype=torch.float16, accum_dtype=torch.float32
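The derived size BT in the line above is simply B·T, the flattened batch-sequence length. A quick sanity check of the reported shape parameters:

```python
# Shape parameters as reported in the log header.
B, T, D = 128, 8192, 1024
BT = B * T  # flattened batch-sequence length
print(BT)  # 1048576
```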

Compiling/autotuning old path...
Autotuning kernel phase2_bwd_old_kernel with config num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_old_kernel with config num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_old_kernel with config num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_old_kernel with config num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_old_kernel with config num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_old_kernel with config num_warps: 2, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_old_kernel with config num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_old_kernel with config num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Triton autotuning for function phase2_bwd_old_kernel,
with key as (1024, 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32'),
finished after 5.19s,
best config selected: num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None;
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 32, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 64, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 128, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 128, BLOCK_HIDDEN: 32, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 128, BLOCK_HIDDEN: 64, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 128, BLOCK_HIDDEN: 128, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 256, BLOCK_HIDDEN: 32, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 256, BLOCK_HIDDEN: 64, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 256, BLOCK_HIDDEN: 128, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 512, BLOCK_HIDDEN: 32, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 512, BLOCK_HIDDEN: 64, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 512, BLOCK_HIDDEN: 128, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 32, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 64, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 128, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 128, BLOCK_HIDDEN: 32, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 128, BLOCK_HIDDEN: 64, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 128, BLOCK_HIDDEN: 128, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 256, BLOCK_HIDDEN: 32, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 256, BLOCK_HIDDEN: 64, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 256, BLOCK_HIDDEN: 128, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 512, BLOCK_HIDDEN: 32, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 512, BLOCK_HIDDEN: 64, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel reduce_query_grad_old_kernel with config BLOCK_BATCH_SEQ: 512, BLOCK_HIDDEN: 128, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Triton autotuning for function reduce_query_grad_old_kernel,
with key as (1048576, 1024, 'torch.float32', 'torch.float32'),
finished after 4.80s,
best config selected: BLOCK_BATCH_SEQ: 64, BLOCK_HIDDEN: 32, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None;
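The 24 reduce_query_grad_old_kernel configs tried above form a full Cartesian grid over block sizes and warp counts. The generation code is not shown in this log, but a sweep like this is typically built for `triton.autotune` along these lines (a sketch, not the actual source):

```python
from itertools import product

# Hypothetical reconstruction of the search space swept above:
# 4 BLOCK_BATCH_SEQ values x 3 BLOCK_HIDDEN values x 2 num_warps values = 24 configs.
configs = [
    {"BLOCK_BATCH_SEQ": bs, "BLOCK_HIDDEN": bh, "num_warps": w}
    for bs, bh, w in product((64, 128, 256, 512), (32, 64, 128), (4, 8))
]
print(len(configs))  # 24, matching the 24 autotune lines above
```

In the real kernel each dict would become a `triton.Config`, with the autotune key being the (BT, D, dtype) tuple shown in the summary.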
Compiling/autotuning new path...
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 1, num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 1, num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 1, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 1, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 2, num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 2, num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 2, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 2, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 4, num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 4, num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 4, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 4, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 8, num_warps: 1, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 8, num_warps: 2, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 8, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 8, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 16, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 16, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 1, num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 1, num_warps: 2, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 1, num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 1, num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 2, num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 2, num_warps: 2, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 2, num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 2, num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 4, num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 4, num_warps: 2, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 4, num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 4, num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 8, num_warps: 1, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 8, num_warps: 2, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 8, num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 8, num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 16, num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 16, num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 32, num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 32, num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 64, num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 64, num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 128, num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
Autotuning kernel phase2_bwd_fused_query_grad_kernel with config BLOCK_BT: 128, num_warps: 8, num_ctas: 1, num_stages: 2, maxnreg: None
Triton autotuning for function phase2_bwd_fused_query_grad_kernel,
with key as (1048576, 1024, 'torch.float16', 'torch.float16', 'torch.float16', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32', 'torch.float32'),
finished after 88.75s,
best config selected: BLOCK_BT: 16, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None;

Correctness vs old
  grad_intrablock_partial_sum                   ok=True  max_abs=6.67572e-05 max_rel=2.12756
  grad_pseudo_query                             ok=True  max_abs=0.0220947 max_rel=0.00107678
  grad_phase1_interblock_normalized_output      ok=True  max_abs=3.99351e-06 max_rel=1.2009e-05
  grad_phase1_interblock_logsumexp              ok=True  max_abs=2.00272e-05 max_rel=0.017116
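The max_abs/max_rel columns are presumably elementwise error metrics against the old path's outputs. The actual comparison helper is not shown in this log; a minimal sketch of how such metrics are commonly computed:

```python
import numpy as np

def error_metrics(actual, reference, eps=1e-12):
    """Max absolute and max relative elementwise error between two arrays."""
    diff = np.abs(actual - reference)
    max_abs = diff.max()
    # eps guards against division by zero where the reference is exactly 0.
    max_rel = (diff / (np.abs(reference) + eps)).max()
    return float(max_abs), float(max_rel)

a = np.array([1.0, 2.0, 3.0])
b = np.array([1.0, 2.5, 3.0])
print(error_metrics(a, b))  # approximately (0.5, 0.2)
```

Note that a large max_rel can still coexist with ok=True (as in the grad_intrablock_partial_sum row, max_rel≈2.13 but max_abs≈6.7e-05): an allclose-style mixed absolute/relative tolerance passes elements whose reference values are tiny.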

Benchmark
  old phase2_bwd + reduce: 20.3917 ms
  new fused autotuned:      14.5637 ms
  speedup:                  1.400x
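The reported speedup follows directly from the two timings:

```python
# Mean kernel times reported above, in milliseconds.
old_ms, new_ms = 20.3917, 14.5637
speedup = old_ms / new_ms
print(f"{speedup:.3f}x")  # 1.400x
```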

The autotuner caches the selected configs internally, keyed on the shapes/dtypes above, so subsequent calls skip re-tuning.
To benchmark a different problem size, change B/T/D/DTYPE at the top of the cell and rerun it.

