=== MFA Shader [backwardDKV D=64 bq=16 bk=64 bd=16 m3=1 dtype=2] ===

// -*- Metal -*-
//===-- metal_simdgroup_event ---------------------------------------------===//
// Copyright (c) 2024 Philip Turner. See MIT LICENSE
//===----------------------------------------------------------------------===//

#ifndef __METAL_SIMDGROUP_EVENT
#define __METAL_SIMDGROUP_EVENT

#pragma METAL internals : enable
namespace metal
{
  // Names the policy for destination elements that fall outside the source
  // tile during an async copy.
  // NOTE(review): the async_copy overloads in this file accept this enum as a
  // template parameter but never read it; their device -> threadgroup copies
  // always write zero for out-of-range elements, so clamp_to_edge currently
  // has no distinct effect here.
  enum class simdgroup_async_copy_clamp_mode {
    clamp_to_zero = 0,
    clamp_to_edge = 1
  };

  // Software stand-in for a SIMD-group async-copy event.
  //
  // Every async_copy overload below performs its copy immediately: the calling
  // thread starts at element `tid` and strides by `threadgroup_size` until the
  // destination tile is exhausted. Because nothing is left in flight, the
  // event object holds no state and wait() is an empty function.
  struct simdgroup_event {
    METAL_FUNC simdgroup_event() thread {}

    // 2D copy, device -> threadgroup.
    //
    // dst / dst_tile_dimensions: destination base and (x, y) extent; rows are
    //   `dst_elements_per_row` elements apart.
    // src / src_elements_per_row / src_tile_dimensions: source base, row
    //   stride and (x, y) extent.
    // tid: first linear element this thread handles.
    // transpose_matrix: when true, the x and y extents of both tiles are
    //   swapped before copying.
    //
    // Destination elements outside the source extent are written as zero.
    // `clamp_mode` is not consulted.
    template <ushort dst_elements_per_row, ushort threadgroup_size, simdgroup_async_copy_clamp_mode clamp_mode = simdgroup_async_copy_clamp_mode::clamp_to_zero, typename T>
    METAL_FUNC void async_copy(
      // Description of the destination.
      threadgroup T *dst,
      ushort2 dst_tile_dimensions,

      // Description of the source.
      const device T *src,
      uint src_elements_per_row,
      ushort2 src_tile_dimensions,
      ushort tid,
      bool transpose_matrix = false
    ) thread {
      if (transpose_matrix) {
        dst_tile_dimensions = dst_tile_dimensions.yx;
        src_tile_dimensions = src_tile_dimensions.yx;
      }

      const ushort dst_width = dst_tile_dimensions.x;
      const auto element_count = dst_tile_dimensions.y * dst_tile_dimensions.x;

      #pragma clang loop unroll(full)
      for (ushort index = tid; index < element_count; index += threadgroup_size) {
        // Unflatten the linear index into a (column, row) pair of the
        // destination tile.
        const ushort column = index % dst_width;
        const ushort row = index / dst_width;

        // Only read the source when the coordinate lies inside its tile.
        const bool inside_source =
          row < src_tile_dimensions.y && column < src_tile_dimensions.x;
        dst[row * dst_elements_per_row + column] =
          inside_source ? src[row * src_elements_per_row + column] : 0;
      }
    }

    // 2D copy, threadgroup -> device.
    //
    // Same traversal as the overload above, but every element of the
    // destination tile is copied unconditionally; src_tile_dimensions only
    // takes part in the optional transpose swap.
    template <ushort src_elements_per_row, ushort threadgroup_size, simdgroup_async_copy_clamp_mode clamp_mode = simdgroup_async_copy_clamp_mode::clamp_to_zero, typename T>
    METAL_FUNC void async_copy(
      // Description of the destination.
      device T *dst,
      uint dst_elements_per_row,
      ushort2 dst_tile_dimensions,

      // Description of the source.
      const threadgroup T *src,
      ushort2 src_tile_dimensions,
      ushort tid,
      bool transpose_matrix = false
    ) thread {
      if (transpose_matrix) {
        dst_tile_dimensions = dst_tile_dimensions.yx;
        src_tile_dimensions = src_tile_dimensions.yx;
      }

      const ushort dst_width = dst_tile_dimensions.x;
      const auto element_count = dst_tile_dimensions.y * dst_tile_dimensions.x;

      #pragma clang loop unroll(full)
      for (ushort index = tid; index < element_count; index += threadgroup_size) {
        const ushort column = index % dst_width;
        const ushort row = index / dst_width;
        dst[row * dst_elements_per_row + column] =
          src[row * src_elements_per_row + column];
      }
    }

    // 1D copy, device -> threadgroup. Writes dst_tile_dimensions elements;
    // positions at or beyond src_tile_dimensions are filled with zero.
    template <ushort threadgroup_size, typename T>
    METAL_FUNC void async_copy(
      // Description of the destination.
      threadgroup T *dst,
      ushort dst_tile_dimensions,

      // Description of the source.
      const device T *src,
      ushort src_tile_dimensions,
      ushort tid
    ) thread {
      #pragma clang loop unroll(full)
      for (ushort index = tid; index < dst_tile_dimensions; index += threadgroup_size) {
        const bool inside_source = index < src_tile_dimensions;
        dst[index] = inside_source ? src[index] : 0;
      }
    }

    // 1D copy, threadgroup -> device. Copies dst_tile_dimensions elements;
    // src_tile_dimensions is accepted but not used.
    template <ushort threadgroup_size, typename T>
    METAL_FUNC void async_copy(
      // Description of the destination.
      device T *dst,
      ushort dst_tile_dimensions,

      // Description of the source.
      const threadgroup T *src,
      ushort src_tile_dimensions,
      ushort tid
    ) thread {
      #pragma clang loop unroll(full)
      for (ushort index = tid; index < dst_tile_dimensions; index += threadgroup_size) {
        dst[index] = src[index];
      }
    }

    // Intentionally empty: the copies above finish before async_copy returns,
    // so there is nothing to wait for. Both arguments are ignored.
    METAL_FUNC static void wait(int count, thread simdgroup_event *events) {
    }
  };
} // namespace metal
#pragma METAL internals : disable

#endif // __METAL_SIMDGROUP_EVENT



// -*- Metal -*-
//===-- metal_simdgroup_matrix_storage ------------------------------------===//
// Copyright (c) 2024 Philip Turner. See MIT LICENSE
//===----------------------------------------------------------------------===//

#ifndef __METAL_SIMDGROUP_MATRIX_STORAGE
#define __METAL_SIMDGROUP_MATRIX_STORAGE

// The layout of threads within a SIMD matrix.
//
//  0  0  1  1  8  8  9  9
//  2  2  3  3 10 10 11 11
//  4  4  5  5 12 12 13 13
//  6  6  7  7 14 14 15 15
// 16 16 17 17 24 24 25 25
// 18 18 19 19 26 26 27 27
// 20 20 21 21 28 28 29 29
// 22 22 23 23 30 30 31 31
//
// This is Morton order, a method for coalescing data accesses. It is used
// in a variety of contexts, from ray tracing acceleration structures, to
// nodal-point Laplacians, to sorting large lattices of atoms.
//
// Source: https://patents.google.com/patent/US11256518B2
// Maps a lane index to the (x, y) position of the first of that lane's two
// horizontally adjacent elements, following the layout table above.
// Returned as ushort2(column, row).
METAL_FUNC static ushort2 morton_order(ushort thread_index_in_simdgroup) {
  const ushort lane = thread_index_in_simdgroup;

  // Row: for lanes 0-31, bit 4 selects the upper or lower group of four rows
  // (offset 0 or 4) and bits 1-2 select the row inside that group.
  const ushort row_base = (lane >> 4) << 2;
  const ushort row_step = (lane >> 1) & 3;
  const ushort row = row_base + row_step;

  // Column: bit 3 selects the left or right half (offset 0 or 4) and bit 0
  // selects which element pair inside that half (offset 0 or 2).
  const ushort column_base = (lane >> 1) & 4;
  const ushort column_step = (lane & 1) << 1;
  const ushort column = column_base + column_step;

  return ushort2(column, row);
}

#pragma METAL internals : enable
namespace metal
{
  // Per-thread view of one SIMD-group matrix operand.
  //
  // The object holds a vec<T, 64>, which is the operand type handed to
  // __metal_simdgroup_matrix_8x8_multiply_accumulate in multiply(). The
  // load/store members only touch the first two elements, reached through
  // thread_elements().
  //
  // Coordinates passed to the members are (x, y) = (column, row) unless
  // `transpose_matrix` is set, in which case the roles of x and y are swapped
  // when forming the memory address.
  template <typename T>
  struct simdgroup_matrix_storage {
    typedef vec<T, 64> storage_type;

    storage_type t;

    // Returns a pointer to the first two elements of the storage, viewed as a
    // 2-wide vector.
    METAL_FUNC thread vec<T, 2>* thread_elements() thread {
      return reinterpret_cast<thread vec<T, 2>*>(&t);
    }

    METAL_FUNC simdgroup_matrix_storage() thread = default;

    // Initializes the first two elements; the rest of the storage is left
    // untouched.
    METAL_FUNC simdgroup_matrix_storage(vec<T, 2> thread_elements) thread {
      *(this->thread_elements()) = thread_elements;
    }

    // Advances a device pointer to `matrix_origin` within a row-major matrix
    // whose rows are `elements_per_row` elements apart.
    //
    // The row * stride product is formed in 64-bit arithmetic. The operands
    // are widened BEFORE multiplying; widening only the finished 32-bit
    // product would let it wrap for large matrices and yield a pointer into
    // the wrong row.
    METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) {
      if (transpose_matrix) {
        return src + (ulong(matrix_origin.x) * ulong(elements_per_row) + ulong(matrix_origin.y));
      } else {
        return src + (ulong(matrix_origin.y) * ulong(elements_per_row) + ulong(matrix_origin.x));
      }
    }

    // Threadgroup-memory counterpart of apply_offset; all quantities are
    // 16-bit here.
    METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
      if (transpose_matrix) {
        return src + matrix_origin.x * elements_per_row + matrix_origin.y;
      } else {
        return src + matrix_origin.y * elements_per_row + matrix_origin.x;
      }
    }

    // Loads this thread's two elements from device memory, converting each
    // from the memory type U to the register type T.
    //
    // Three addressing paths:
    //   * transposed: two scalar loads, one row stride apart;
    //   * odd row stride: two scalar loads of adjacent elements;
    //   * even row stride: one 2-wide vector load.
    // NOTE(review): the vector path presumably relies on `src` plus the
    // computed address being aligned for vec<U, 2> - confirm against callers.
    template <typename U>
    METAL_FUNC void load(const device U *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
      if (transpose_matrix) {
        uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y);
        uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y);
        U memoryForm0 = src[address0];
        U memoryForm1 = src[address1];
        ((thread T*)thread_elements())[0] = T(memoryForm0);
        ((thread T*)thread_elements())[1] = T(memoryForm1);
      } else if (elements_per_row % 2 != 0) {
        uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
        uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1);
        U memoryForm0 = src[address0];
        U memoryForm1 = src[address1];
        ((thread T*)thread_elements())[0] = T(memoryForm0);
        ((thread T*)thread_elements())[1] = T(memoryForm1);
      } else {
        auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
        vec<U, 2> memoryForm = *(const device vec<U, 2>*)(src + combinedAddress);
        *(thread_elements()) = vec<T, 2>(memoryForm);
      }
    }

    // Threadgroup-memory counterpart of load(); same three paths with 16-bit
    // address arithmetic.
    template <typename U>
    METAL_FUNC void load(const threadgroup U *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
      if (transpose_matrix) {
        ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y);
        ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y);
        U memoryForm0 = src[address0];
        U memoryForm1 = src[address1];
        ((thread T*)thread_elements())[0] = T(memoryForm0);
        ((thread T*)thread_elements())[1] = T(memoryForm1);
      } else if (elements_per_row % 2 != 0) {
        ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
        ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1);
        U memoryForm0 = src[address0];
        U memoryForm1 = src[address1];
        ((thread T*)thread_elements())[0] = T(memoryForm0);
        ((thread T*)thread_elements())[1] = T(memoryForm1);
      } else {
        auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
        vec<U, 2> memoryForm = *(const threadgroup vec<U, 2>*)(src + combinedAddress);
        *(thread_elements()) = vec<T, 2>(memoryForm);
      }
    }

    // Stores this thread's two elements to device memory, converting each
    // from the register type T to the memory type U. Uses the same three
    // addressing paths as load().
    template <typename U>
    METAL_FUNC void store(device U *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
      if (transpose_matrix) {
        uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y);
        uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y);
        T registerForm0 = ((thread T*)thread_elements())[0];
        T registerForm1 = ((thread T*)thread_elements())[1];
        dst[address0] = U(registerForm0);
        dst[address1] = U(registerForm1);
      } else if (elements_per_row % 2 != 0) {
        uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
        uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1);
        T registerForm0 = ((thread T*)thread_elements())[0];
        T registerForm1 = ((thread T*)thread_elements())[1];
        dst[address0] = U(registerForm0);
        dst[address1] = U(registerForm1);
      } else {
        auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
        vec<T, 2> registerForm = *(thread_elements());
        *(device vec<U, 2>*)(dst + combinedAddress) = vec<U, 2>(registerForm);
      }
    }

    // Threadgroup-memory counterpart of store(); same three paths with 16-bit
    // address arithmetic.
    template <typename U>
    METAL_FUNC void store(threadgroup U *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
      if (transpose_matrix) {
        ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y);
        ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y);
        T registerForm0 = ((thread T*)thread_elements())[0];
        T registerForm1 = ((thread T*)thread_elements())[1];
        dst[address0] = U(registerForm0);
        dst[address1] = U(registerForm1);
      } else if (elements_per_row % 2 != 0) {
        ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
        ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1);
        T registerForm0 = ((thread T*)thread_elements())[0];
        T registerForm1 = ((thread T*)thread_elements())[1];
        dst[address0] = U(registerForm0);
        dst[address1] = U(registerForm1);
      } else {
        auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
        vec<T, 2> registerForm = *(thread_elements());
        *(threadgroup vec<U, 2>*)(dst + combinedAddress) = vec<U, 2>(registerForm);
      }
    }


    // Computes this = a * b (+ this). When `accumulate` is false the thread's
    // two accumulator elements are zeroed first, so the multiply-accumulate
    // intrinsic starts from zero.
    template <typename U, typename V>
    METAL_FUNC void multiply(simdgroup_matrix_storage<U> a, simdgroup_matrix_storage<V> b, bool accumulate = true) {
      if (!accumulate) {
        *(thread_elements()) = vec<T, 2>(0);
      }
      t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type());
    }
  };
} // namespace metal
#pragma METAL internals : disable

#endif // __METAL_SIMDGROUP_MATRIX_STORAGE



using namespace metal;


// Runtime parameters of the `attention` kernel, bound as a constant reference
// at buffer(10).
// NOTE(review): the field order and widths presumably mirror a host-side
// struct - confirm before reordering or resizing anything.
struct MFAParams {
  uint R;                  // extent of the dimension walked by the kernel's outer `r` loop; also scales the Q/O/dO/L/D base offsets
  uint C;                  // extent of the dimension split across threadgroups in blocks of 16; also scales the K/V/dK/dV base offsets
  uint Hq;                 // modulus/divisor used to split the flat threadgroup index into gid.y and gid.z (presumably the query head count)
  uint H_Hk_ratio;         // gid.y is divided by this when forming the K/V/dK/dV base offsets (presumably query heads per K/V head)
  float dot_product_scale; // multiplies S before the exp2 that produces P
  uint causal;             // not read in the visible part of the kernel
  uint Q_batch_stride;     // element offset applied per gid.z step to Q
  uint K_batch_stride;     // element offset applied per gid.z step to K
  uint V_batch_stride;     // element offset applied per gid.z step to V
  uint O_batch_stride;     // element offset applied per gid.z step to O
  uint dO_batch_stride;    // element offset applied per gid.z step to dO
  uint dQ_batch_stride;    // no accessor macro is defined for this field in this shader
  uint dK_batch_stride;    // element offset applied per gid.z step to dK
  uint dV_batch_stride;    // element offset applied per gid.z step to dV
};
// Accessor macros: each name expands to the matching field of an object named
// `params`, so they only work where such an identifier is in scope (the
// kernel declares it as `constant MFAParams& params`).
// NOTE(review): these are plain textual macros, so short names such as R and
// C also rewrite any later identifier with the same spelling.
#define R (params.R)
#define C (params.C)
#define Hq (params.Hq)
#define H_Hk_ratio (params.H_Hk_ratio)
#define dot_product_scale (params.dot_product_scale)
#define causal (params.causal)
// 0.6931471805599453f is ln(2); presumably this rescales a base-2 softmax
// scale for the derivative computation - TODO confirm against the use site.
#define dot_product_scale_derivative (dot_product_scale * 0.6931471805599453f)
#define Q_batch_stride (params.Q_batch_stride)
#define K_batch_stride (params.K_batch_stride)
#define V_batch_stride (params.V_batch_stride)
#define O_batch_stride (params.O_batch_stride)
#define dO_batch_stride (params.dO_batch_stride)
#define dV_batch_stride (params.dV_batch_stride)
#define dK_batch_stride (params.dK_batch_stride)



    
    // Declare the function.
    kernel void attention(

  device float* Q [[buffer(0)]],
  device float* K [[buffer(1)]],
  device float* V [[buffer(2)]],
  device float* O [[buffer(3)]],
  device float* L [[buffer(4)]],
  device float* D [[buffer(5)]],
  device float* dO [[buffer(6)]],
  device float* dV [[buffer(7)]],
  device float* dK [[buffer(8)]],
  constant MFAParams& params [[buffer(10)]],



      uint3 gid [[threadgroup_position_in_grid]],
      ushort sidx [[simdgroup_index_in_threadgroup]],
      ushort lane_id [[thread_index_in_simdgroup]]
    ) {
      threadgroup uchar threadgroup_block[4096];
      ushort2 morton_offset = morton_order(lane_id);
      gid = { gid.x % ((C + 16 - 1) / 16), (gid.x / ((C + 16 - 1) / 16)) % Hq, gid.x / (Hq * ((C + 16 - 1) / 16))};
      uint parallelization_group_offset = gid.x;
      parallelization_group_offset *= 16;

      // Return early if the entire SIMD is out of bounds.
      if (parallelization_group_offset >= C) {
        return;
      }


    Q = Q + gid.z * Q_batch_stride + gid.y * 64 * R;


    K = K + gid.z * K_batch_stride + gid.y / H_Hk_ratio * 64 * C;


    V = V + gid.z * V_batch_stride + gid.y / H_Hk_ratio * 64 * C;


    O = O + gid.z * O_batch_stride + gid.y * 64 * R;


    L = L + (gid.z * Hq + gid.y) * R;


    D = D + (gid.z * Hq + gid.y) * R;


    dO = dO + gid.z * dO_batch_stride + gid.y * 64 * R;


    dV = dV + gid.z * dV_batch_stride + gid.y / H_Hk_ratio * 64 * C;


    dK = dK + gid.z * dK_batch_stride + gid.y / H_Hk_ratio * 64 * C;







  // Outer loop over the traversal dimension.
  for (uint r = 0; r < R; r += 64) {
    // S^T = K * Q^T
    


    simdgroup_matrix_storage<float> S_sram[64 / 8];





    #pragma clang loop unroll(full)
    for (ushort c = 0; c < 64; c += 8) {
      auto S = S_sram + c / 8;
      *S = simdgroup_matrix_storage<float>(0);
    }





    #pragma clang loop unroll(disable)
    for (
      ushort d_outer = 0;
      d_outer < 64;
      d_outer += 16
    ) {
      

    if ((
        (R % 64 == 0) ||
        (r + 64 <= R)
      ) && (
        (64 % 8 == 0) ||
        (d_outer + 16 <= 64)
      )) {
      


    simdgroup_matrix_storage<float> K_sram[16 / 8];





      

    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (sidx == 0) {
      uint2 K_offset(d_outer, parallelization_group_offset);
      auto src = simdgroup_matrix_storage<float>
      ::apply_offset(
        K, 64,
        K_offset, false);
      auto dst = (threadgroup float*)(threadgroup_block);

      ushort D_src_dimension = min(
        ushort(16),
        ushort(64 - d_outer));
      ushort D_dst_dimension = 16;
      ushort R_dimension = min(
        uint(16),
        uint(C - parallelization_group_offset));
      ushort2 tile_src(D_src_dimension, R_dimension);
      ushort2 tile_dst(D_dst_dimension, R_dimension);

      simdgroup_event event;
      event.async_copy<16, 32>(
        dst, tile_dst,
        src, 64, tile_src, lane_id, false);
      simdgroup_event::wait(1, &event);
    }



      

      ushort2 K_block_offset(
        morton_offset.x, 
        morton_offset.y + sidx * 8);
      auto K_src = (threadgroup float*)(threadgroup_block);
      K_src = simdgroup_matrix_storage<float>
      ::apply_offset(
        K_src, 16,
        K_block_offset, false);
      threadgroup_barrier(mem_flags::mem_threadgroup);




      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 K_origin(d, 0);
        K_sram[d / 8].load(
          K_src, 16,
          K_origin, false);
      }





      uint2 Q_src_offset(
        morton_offset.y + d_outer,
        morton_offset.x + r);
      auto Q_src = simdgroup_matrix_storage<float>
      ::apply_offset(
        Q, 64,
        Q_src_offset, false);





      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        

    #pragma clang loop unroll(full)
    for (ushort c = 0; c < 64; c += 8) {
      // Load the RHS from memory.
      ushort2 Q_origin(c, d);
      simdgroup_matrix_storage<float> Q;
      Q.load(
        Q_src, 64,
        Q_origin, true);

      // Issue one SIMD matmul instruction.
      S_sram[c / 8].multiply(
        K_sram[(0 + d) / 8],
        Q, true);
    }



      }




    } else {
      


    simdgroup_matrix_storage<float> K_sram[16 / 8];





      

    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (sidx == 0) {
      uint2 K_offset(d_outer, parallelization_group_offset);
      auto src = simdgroup_matrix_storage<float>
      ::apply_offset(
        K, 64,
        K_offset, false);
      auto dst = (threadgroup float*)(threadgroup_block);

      ushort D_src_dimension = min(
        ushort(16),
        ushort(64 - d_outer));
      ushort D_dst_dimension = 16;
      ushort R_dimension = min(
        uint(16),
        uint(C - parallelization_group_offset));
      ushort2 tile_src(D_src_dimension, R_dimension);
      ushort2 tile_dst(D_dst_dimension, R_dimension);

      simdgroup_event event;
      event.async_copy<16, 32>(
        dst, tile_dst,
        src, 64, tile_src, lane_id, false);
      simdgroup_event::wait(1, &event);
    }



      

      ushort2 K_block_offset(
        morton_offset.x, 
        morton_offset.y + sidx * 8);
      auto K_src = (threadgroup float*)(threadgroup_block);
      K_src = simdgroup_matrix_storage<float>
      ::apply_offset(
        K_src, 16,
        K_block_offset, false);
      threadgroup_barrier(mem_flags::mem_threadgroup);




      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 K_origin(d, 0);
        K_sram[d / 8].load(
          K_src, 16,
          K_origin, false);
      }





      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (sidx == 0) {
        uint2 Q_offset(d_outer, r);
        auto src = simdgroup_matrix_storage<float>
        ::apply_offset(
          Q, 64,
          Q_offset, false);
        auto dst = (threadgroup float*)(threadgroup_block);
 
        ushort D_src_dimension = min(
          ushort(16),
          ushort(64 - d_outer));
        ushort D_dst_dimension = 16;
        ushort C_src_dimension = min(
          uint(64),
          uint(R - r));
        ushort C_dst_dimension = max(
          ushort((((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8),
          ushort(C_src_dimension));
        ushort2 tile_src(D_src_dimension, C_src_dimension);
        ushort2 tile_dst(D_dst_dimension, C_dst_dimension);

        simdgroup_event event;
        event.async_copy<16, 32>(
          dst, tile_dst,
          src, 64, tile_src, lane_id, false);
        simdgroup_event::wait(1, &event);
      }

      

      ushort2 Q_block_offset(
        morton_offset.x,
        morton_offset.y);
      auto Q_src = (threadgroup float*)(threadgroup_block);
      Q_src = simdgroup_matrix_storage<float>
      ::apply_offset(
        Q_src, 16,
        Q_block_offset, true);
      threadgroup_barrier(mem_flags::mem_threadgroup);








      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        

    #pragma clang loop unroll(full)
    for (ushort c = 0; c < (((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8; c += 8) {
      // Load the RHS from memory.
      ushort2 Q_origin(c, d);
      simdgroup_matrix_storage<float> Q;
      Q.load(
        Q_src, 16,
        Q_origin, true);

      // Issue one SIMD matmul instruction.
      S_sram[c / 8].multiply(
        K_sram[(0 + d) / 8],
        Q, true);
    }



        if (r + 64
            < R) {
          

    #pragma clang loop unroll(full)
    for (ushort c = (((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8; c < 64; c += 8) {
      // Load the RHS from memory.
      ushort2 Q_origin(c, d);
      simdgroup_matrix_storage<float> Q;
      Q.load(
        Q_src, 16,
        Q_origin, true);

      // Issue one SIMD matmul instruction.
      S_sram[c / 8].multiply(
        K_sram[(0 + d) / 8],
        Q, true);
    }



        }
      }




    }



    }





    if (false) {
      ushort d_outer = 64;
      

    if ((
        (R % 64 == 0) ||
        (r + 64 <= R)
      ) && (
        (64 % 8 == 0) ||
        (d_outer + 16 <= 64)
      )) {
      


    simdgroup_matrix_storage<float> K_sram[16 / 8];





      

    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (sidx == 0) {
      uint2 K_offset(d_outer, parallelization_group_offset);
      auto src = simdgroup_matrix_storage<float>
      ::apply_offset(
        K, 64,
        K_offset, false);
      auto dst = (threadgroup float*)(threadgroup_block);

      ushort D_src_dimension = min(
        ushort(16),
        ushort(64 - d_outer));
      ushort D_dst_dimension = 16;
      ushort R_dimension = min(
        uint(16),
        uint(C - parallelization_group_offset));
      ushort2 tile_src(D_src_dimension, R_dimension);
      ushort2 tile_dst(D_dst_dimension, R_dimension);

      simdgroup_event event;
      event.async_copy<16, 32>(
        dst, tile_dst,
        src, 64, tile_src, lane_id, false);
      simdgroup_event::wait(1, &event);
    }



      

      ushort2 K_block_offset(
        morton_offset.x, 
        morton_offset.y + sidx * 8);
      auto K_src = (threadgroup float*)(threadgroup_block);
      K_src = simdgroup_matrix_storage<float>
      ::apply_offset(
        K_src, 16,
        K_block_offset, false);
      threadgroup_barrier(mem_flags::mem_threadgroup);




      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 K_origin(d, 0);
        K_sram[d / 8].load(
          K_src, 16,
          K_origin, false);
      }





      uint2 Q_src_offset(
        morton_offset.y + d_outer,
        morton_offset.x + r);
      auto Q_src = simdgroup_matrix_storage<float>
      ::apply_offset(
        Q, 64,
        Q_src_offset, false);





      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        

    #pragma clang loop unroll(full)
    for (ushort c = 0; c < 64; c += 8) {
      // Load the RHS from memory.
      ushort2 Q_origin(c, d);
      simdgroup_matrix_storage<float> Q;
      Q.load(
        Q_src, 64,
        Q_origin, true);

      // Issue one SIMD matmul instruction.
      S_sram[c / 8].multiply(
        K_sram[(0 + d) / 8],
        Q, true);
    }



      }




    } else {
      


    simdgroup_matrix_storage<float> K_sram[16 / 8];





      

    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (sidx == 0) {
      uint2 K_offset(d_outer, parallelization_group_offset);
      auto src = simdgroup_matrix_storage<float>
      ::apply_offset(
        K, 64,
        K_offset, false);
      auto dst = (threadgroup float*)(threadgroup_block);

      ushort D_src_dimension = min(
        ushort(16),
        ushort(64 - d_outer));
      ushort D_dst_dimension = 16;
      ushort R_dimension = min(
        uint(16),
        uint(C - parallelization_group_offset));
      ushort2 tile_src(D_src_dimension, R_dimension);
      ushort2 tile_dst(D_dst_dimension, R_dimension);

      simdgroup_event event;
      event.async_copy<16, 32>(
        dst, tile_dst,
        src, 64, tile_src, lane_id, false);
      simdgroup_event::wait(1, &event);
    }



      

      ushort2 K_block_offset(
        morton_offset.x, 
        morton_offset.y + sidx * 8);
      auto K_src = (threadgroup float*)(threadgroup_block);
      K_src = simdgroup_matrix_storage<float>
      ::apply_offset(
        K_src, 16,
        K_block_offset, false);
      threadgroup_barrier(mem_flags::mem_threadgroup);




      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 K_origin(d, 0);
        K_sram[d / 8].load(
          K_src, 16,
          K_origin, false);
      }





      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (sidx == 0) {
        uint2 Q_offset(d_outer, r);
        auto src = simdgroup_matrix_storage<float>
        ::apply_offset(
          Q, 64,
          Q_offset, false);
        auto dst = (threadgroup float*)(threadgroup_block);
 
        ushort D_src_dimension = min(
          ushort(16),
          ushort(64 - d_outer));
        ushort D_dst_dimension = 16;
        ushort C_src_dimension = min(
          uint(64),
          uint(R - r));
        ushort C_dst_dimension = max(
          ushort((((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8),
          ushort(C_src_dimension));
        ushort2 tile_src(D_src_dimension, C_src_dimension);
        ushort2 tile_dst(D_dst_dimension, C_dst_dimension);

        simdgroup_event event;
        event.async_copy<16, 32>(
          dst, tile_dst,
          src, 64, tile_src, lane_id, false);
        simdgroup_event::wait(1, &event);
      }

      

      ushort2 Q_block_offset(
        morton_offset.x,
        morton_offset.y);
      auto Q_src = (threadgroup float*)(threadgroup_block);
      Q_src = simdgroup_matrix_storage<float>
      ::apply_offset(
        Q_src, 16,
        Q_block_offset, true);
      threadgroup_barrier(mem_flags::mem_threadgroup);








      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        

    #pragma clang loop unroll(full)
    for (ushort c = 0; c < (((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8; c += 8) {
      // Load the RHS from memory.
      ushort2 Q_origin(c, d);
      simdgroup_matrix_storage<float> Q;
      Q.load(
        Q_src, 16,
        Q_origin, true);

      // Issue one SIMD matmul instruction.
      S_sram[c / 8].multiply(
        K_sram[(0 + d) / 8],
        Q, true);
    }



        if (r + 64
            < R) {
          

    #pragma clang loop unroll(full)
    for (ushort c = (((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8; c < 64; c += 8) {
      // Load the RHS from memory.
      ushort2 Q_origin(c, d);
      simdgroup_matrix_storage<float> Q;
      Q.load(
        Q_src, 16,
        Q_origin, true);

      // Issue one SIMD matmul instruction.
      S_sram[c / 8].multiply(
        K_sram[(0 + d) / 8],
        Q, true);
    }



        }
      }




    }



    }





    // P^T = exp(S^T - L)
    

    

      simdgroup_matrix_storage<float> P_sram[64 / 8];



    if (true && (
        (R % 64 == 0) ||
        (r + 64 <= R)
      )) {
      

      auto L_src = L;
      L_src += r + morton_offset.x;



      

      #pragma clang loop unroll(full)
      for (ushort c = 0; c < 64; c += 8) {
        ushort2 L_origin(c, 0);
        simdgroup_matrix_storage<float> L;
        L.load(
          L_src, 1,
          L_origin, false);
        auto L_elements = *(L.thread_elements());

        

      auto S = *(S_sram[c / 8].thread_elements());
      auto P = vec<float, 2>(
        fast::exp2(float2(S) * dot_product_scale - float2(L_elements)));
      *(P_sram[c / 8].thread_elements()) = P;



      }



    } else {
      

    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (sidx == 0) {
      auto L_src = L + r;
      auto L_dst =
      (threadgroup float*)(threadgroup_block);

      ushort R_src_dimension = min(
        uint(64),
        uint(R - r));
      ushort R_dst_dimension = max(
        ushort((((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8),
        ushort(R_src_dimension));

      // Issue an async copy.
      simdgroup_event event;
      event.async_copy<32>(
        L_dst, R_dst_dimension,
        L_src, R_src_dimension, lane_id);
      simdgroup_event::wait(1, &event);
    }



      

      auto L_src =
      (threadgroup float*)(threadgroup_block);
      L_src += morton_offset.x;
      threadgroup_barrier(mem_flags::mem_threadgroup);



      

      #pragma clang loop unroll(full)
      for (ushort c = 0; c < 64; c += 8) {
        ushort2 L_origin(c, 0);
        simdgroup_matrix_storage<float> L;
        L.load(
          L_src, 1,
          L_origin, false);
        auto L_elements = *(L.thread_elements());

        

      auto S = *(S_sram[c / 8].thread_elements());
      auto P = vec<float, 2>(
        fast::exp2(float2(S) * dot_product_scale - float2(L_elements)));
      *(P_sram[c / 8].thread_elements()) = P;



      }



    }




    // dV += P^T * dO
    

    
    #pragma clang loop unroll(disable)
    for (
      ushort d_outer = 0;
      d_outer < 64;
      d_outer += 16
    ) {
      

    if ((
          (R % 64 == 0) ||
          (r + 64 <= R)
        ) && (
          (64 % 8 == 0) ||
          (d_outer + 16 <= 64)
        )) {
      

    

    simdgroup_matrix_storage<float> dV_sram[16 / 8];



    if (r == 0) {
      
    
    #pragma clang loop unroll(full)
    for (ushort d = 0; d < 16; d += 8) {
      auto dV = dV_sram + (0 + d) / 8;
      *dV = simdgroup_matrix_storage<float>(0);
    }



    } else {
      

      

     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (sidx == 0) {
       uint2 dV_offset(d_outer, parallelization_group_offset);
       auto src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dV, 64,
         dV_offset, false);
       auto dst = (threadgroup float*)(threadgroup_block);
       
       ushort D_dimension = min(
         ushort(16),
         ushort(64 - d_outer));
       ushort R_dimension = min(
         uint(16),
         uint(C - parallelization_group_offset));
       ushort2 tile(D_dimension, R_dimension);
       
       simdgroup_event event;
       event.async_copy<16, 32>(
         dst, tile,
         src, 64, tile, lane_id, false);
       simdgroup_event::wait(1, &event);
     }



      

       ushort2 dV_block_offset(
         morton_offset.x,
         morton_offset.y + sidx * 8);
       auto dV_src = (threadgroup float*)(threadgroup_block);
       dV_src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dV_src, 16,
         dV_block_offset, false);
       threadgroup_barrier(mem_flags::mem_threadgroup);



      
      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 dV_origin(d, 0);
        dV_sram[d / 8].load(
          dV_src, 16, 
          dV_origin, false);
      }
      


      
    }
    

      uint2 dO_src_offset(
        morton_offset.x + d_outer,
        morton_offset.y + r);
      auto dO_src = simdgroup_matrix_storage<float>
      ::apply_offset(
        dO, 64,
        dO_src_offset, false);



    

        

    #pragma clang loop unroll(full)
    for (ushort c = 0; c < 64; c += 8) {
      

    #pragma clang loop unroll(full)
    for (ushort d = 0; d < 16; d += 8) {
      // Load the RHS from memory.
      ushort2 dO_origin(d, c);
      simdgroup_matrix_storage<float> dO;
      dO.load(
        dO_src, 64,
        dO_origin, false);

      // Issue one SIMD matmul instruction.
      dV_sram[(0 + d) / 8].multiply(
        P_sram[c / 8], dO, /*accumulate=*/true);
    }



    }



        if (
          (R % 64 == 0) &&
          (r + 64 == R)
        ) {
           
        }



    

      

       ushort2 dV_block_offset(
         morton_offset.x,
         morton_offset.y + sidx * 8);
       auto dV_src = (threadgroup float*)(threadgroup_block);
       dV_src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dV_src, 16,
         dV_block_offset, false);
       threadgroup_barrier(mem_flags::mem_threadgroup);




      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 dV_origin(d, 0);
        dV_sram[d / 8].store(
          dV_src, 16,
          dV_origin, false);
      }

      

     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (sidx == 0) {
       uint2 dV_offset(d_outer, parallelization_group_offset);
       auto src = (threadgroup float*)(threadgroup_block);
       auto dst = simdgroup_matrix_storage<float>
       ::apply_offset(
         dV, 64,
         dV_offset, false);
       
       ushort D_dimension = min(
         ushort(16),
         ushort(64 - d_outer));
       ushort R_dimension = min(
         uint(16),
         uint(C - parallelization_group_offset));
       ushort2 tile(D_dimension, R_dimension);
       
       simdgroup_event event;
       event.async_copy<16, 32>(
         dst, 64, tile,
         src, tile, lane_id, false);
       simdgroup_event::wait(1, &event);
     }
     








    } else {
      

    

    simdgroup_matrix_storage<float> dV_sram[16 / 8];



    if (r == 0) {
      
    
    #pragma clang loop unroll(full)
    for (ushort d = 0; d < 16; d += 8) {
      auto dV = dV_sram + (0 + d) / 8;
      *dV = simdgroup_matrix_storage<float>(0);
    }



    } else {
      

      

     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (sidx == 0) {
       uint2 dV_offset(d_outer, parallelization_group_offset);
       auto src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dV, 64,
         dV_offset, false);
       auto dst = (threadgroup float*)(threadgroup_block);
       
       ushort D_dimension = min(
         ushort(16),
         ushort(64 - d_outer));
       ushort R_dimension = min(
         uint(16),
         uint(C - parallelization_group_offset));
       ushort2 tile(D_dimension, R_dimension);
       
       simdgroup_event event;
       event.async_copy<16, 32>(
         dst, tile,
         src, 64, tile, lane_id, false);
       simdgroup_event::wait(1, &event);
     }



      

       ushort2 dV_block_offset(
         morton_offset.x,
         morton_offset.y + sidx * 8);
       auto dV_src = (threadgroup float*)(threadgroup_block);
       dV_src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dV_src, 16,
         dV_block_offset, false);
       threadgroup_barrier(mem_flags::mem_threadgroup);



      
      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 dV_origin(d, 0);
        dV_sram[d / 8].load(
          dV_src, 16, 
          dV_origin, false);
      }
      


      
    }
    
      
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (sidx == 0) {
        uint2 dO_offset(d_outer, r);
        auto src = simdgroup_matrix_storage<float>
        ::apply_offset(
          dO, 64,
          dO_offset, false);
        auto dst = (threadgroup float*)(threadgroup_block);
        
        ushort D_dimension = min(
          ushort(16),
          ushort(64 - d_outer));
        ushort C_src_dimension = min(
          uint(64),
          uint(R - r));
        ushort C_dst_dimension = max(
          ushort((((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8),
          ushort(C_src_dimension));
        ushort2 tile_src(D_dimension, C_src_dimension);
        ushort2 tile_dst(D_dimension, C_dst_dimension);
        
        simdgroup_event event;
        event.async_copy<16, 32>(
          dst, tile_dst,
          src, 64, tile_src, lane_id, false);
        simdgroup_event::wait(1, &event);
      }

      

      ushort2 dO_block_offset(
        morton_offset.x,
        morton_offset.y);
      auto dO_src = (threadgroup float*)(threadgroup_block);
      dO_src = simdgroup_matrix_storage<float>
      ::apply_offset(
        dO_src, 16,
        dO_block_offset, false);
      threadgroup_barrier(mem_flags::mem_threadgroup);



      
 

    

        

    #pragma clang loop unroll(full)
    for (ushort c = 0; c < (((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8; c += 8) {
      

    #pragma clang loop unroll(full)
    for (ushort d = 0; d < 16; d += 8) {
      // Load the RHS from memory.
      ushort2 dO_origin(d, c);
      simdgroup_matrix_storage<float> dO;
      dO.load(
        dO_src, 16,
        dO_origin, false);

      // Issue one SIMD matmul instruction.
      dV_sram[(0 + d) / 8].multiply(
        P_sram[c / 8], dO, /*accumulate=*/true);
    }



    }



        if (r + 64
            < R) {
          

    #pragma clang loop unroll(full)
    for (ushort c = (((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8; c < 64; c += 8) {
      

    #pragma clang loop unroll(full)
    for (ushort d = 0; d < 16; d += 8) {
      // Load the RHS from memory.
      ushort2 dO_origin(d, c);
      simdgroup_matrix_storage<float> dO;
      dO.load(
        dO_src, 16,
        dO_origin, false);

      // Issue one SIMD matmul instruction.
      dV_sram[(0 + d) / 8].multiply(
        P_sram[c / 8], dO, /*accumulate=*/true);
    }



    }



        } else {
          
        }



    

      

       ushort2 dV_block_offset(
         morton_offset.x,
         morton_offset.y + sidx * 8);
       auto dV_src = (threadgroup float*)(threadgroup_block);
       dV_src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dV_src, 16,
         dV_block_offset, false);
       threadgroup_barrier(mem_flags::mem_threadgroup);




      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 dV_origin(d, 0);
        dV_sram[d / 8].store(
          dV_src, 16,
          dV_origin, false);
      }

      

     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (sidx == 0) {
       uint2 dV_offset(d_outer, parallelization_group_offset);
       auto src = (threadgroup float*)(threadgroup_block);
       auto dst = simdgroup_matrix_storage<float>
       ::apply_offset(
         dV, 64,
         dV_offset, false);
       
       ushort D_dimension = min(
         ushort(16),
         ushort(64 - d_outer));
       ushort R_dimension = min(
         uint(16),
         uint(C - parallelization_group_offset));
       ushort2 tile(D_dimension, R_dimension);
       
       simdgroup_event event;
       event.async_copy<16, 32>(
         dst, 64, tile,
         src, tile, lane_id, false);
       simdgroup_event::wait(1, &event);
     }
     








    }



    }
    




    // dV remainder pass: intentionally empty.
    //
    // The code that used to live here was wrapped in a constant `if (false)`
    // guard, so it could never execute. It re-ran the dV accumulate / store
    // sequence with `d_outer = 64`, where every tile width it computed was
    // `min(16, 64 - d_outer) = 0` columns. Removing the unreachable branch
    // leaves behavior unchanged and drops roughly 450 lines of dead code.





    // dP^T = V * dO^T
    //
    // Accumulator tiles for dP^T: 64 / 8 matrices, all cleared to zero here
    // before anything is accumulated into them.
    simdgroup_matrix_storage<float> dP_sram[64 / 8];

    #pragma clang loop unroll(full)
    for (ushort tile_index = 0; tile_index < 64 / 8; ++tile_index) {
      dP_sram[tile_index] = simdgroup_matrix_storage<float>(0);
    }





    // Accumulate into dP_sram, walking the head dimension in steps of 16
    // (d_outer = 0, 16, 32, 48).
    #pragma clang loop unroll(disable)
    for (ushort d_outer = 0; d_outer < 64; d_outer += 16) {
      // ---- Stage the V tile ------------------------------------------------
      // Both dO paths below consume the same V tile and fetch it with the same
      // barrier / copy / barrier / load sequence, so that sequence is written
      // once, ahead of the branch.
      simdgroup_matrix_storage<float> V_sram[16 / 8];

      // Barrier before the copy: threadgroup_block is about to be overwritten.
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (sidx == 0) {
        // Only simdgroup 0 issues the copy into threadgroup memory.
        uint2 V_offset(d_outer, parallelization_group_offset);
        auto V_global = simdgroup_matrix_storage<float>::apply_offset(
          V, 64, V_offset, false);
        auto V_staging = (threadgroup float*)(threadgroup_block);

        ushort D_src_dimension = min(ushort(16), ushort(64 - d_outer));
        ushort D_dst_dimension = 16;
        ushort R_dimension = min(
          uint(16), uint(C - parallelization_group_offset));
        ushort2 tile_src(D_src_dimension, R_dimension);
        ushort2 tile_dst(D_dst_dimension, R_dimension);

        simdgroup_event V_event;
        V_event.async_copy<16, 32>(
          V_staging, tile_dst,
          V_global, 64, tile_src, lane_id, false);
        simdgroup_event::wait(1, &V_event);
      }

      // Each simdgroup reads its own 8-row slice of the staged tile.
      ushort2 V_block_offset(morton_offset.x, morton_offset.y + sidx * 8);
      threadgroup float *V_src = (threadgroup float*)(threadgroup_block);
      V_src = simdgroup_matrix_storage<float>::apply_offset(
        V_src, 16, V_block_offset, false);
      // Barrier after the copy: the staged data must be visible before loading.
      threadgroup_barrier(mem_flags::mem_threadgroup);

      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 V_origin(d, 0);
        V_sram[d / 8].load(V_src, 16, V_origin, false);
      }

      // ---- Multiply against dO ---------------------------------------------
      // The generated condition also carried `(64 % 8 == 0) || ...`, which is
      // always true, so only the row-bounds test remains.
      if ((R % 64 == 0) || (r + 64 <= R)) {
        // All 64 rows starting at r are in bounds: read dO straight from
        // device memory (row stride 64), transposed on load.
        uint2 dO_src_offset(
          morton_offset.y + d_outer,
          morton_offset.x + r);
        auto dO_src = simdgroup_matrix_storage<float>::apply_offset(
          dO, 64, dO_src_offset, false);

        #pragma clang loop unroll(full)
        for (ushort d = 0; d < 16; d += 8) {
          #pragma clang loop unroll(full)
          for (ushort c = 0; c < 64; c += 8) {
            // Load the RHS from memory.
            ushort2 dO_origin(c, d);
            simdgroup_matrix_storage<float> dO_tile;
            dO_tile.load(dO_src, 64, dO_origin, true);

            // Issue one SIMD matmul instruction.
            dP_sram[c / 8].multiply(V_sram[d / 8], dO_tile, true);
          }
        }
      } else {
        // Edge block: stage dO through threadgroup memory. The source tile is
        // clipped to the rows that exist (R - r) and the destination tile is
        // at least as tall as the 8-aligned remainder of R.
        // Barrier first: V_sram loads from threadgroup_block must be finished
        // before the block is reused for dO.
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (sidx == 0) {
          uint2 dO_offset(d_outer, r);
          auto dO_global = simdgroup_matrix_storage<float>::apply_offset(
            dO, 64, dO_offset, false);
          auto dO_staging = (threadgroup float*)(threadgroup_block);

          ushort D_src_dimension = min(ushort(16), ushort(64 - d_outer));
          ushort D_dst_dimension = 16;
          ushort C_src_dimension = min(uint(64), uint(R - r));
          ushort C_dst_dimension = max(
            ushort((((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8),
            ushort(C_src_dimension));
          ushort2 tile_src(D_src_dimension, C_src_dimension);
          ushort2 tile_dst(D_dst_dimension, C_dst_dimension);

          simdgroup_event dO_event;
          dO_event.async_copy<16, 32>(
            dO_staging, tile_dst,
            dO_global, 64, tile_src, lane_id, false);
          simdgroup_event::wait(1, &dO_event);
        }

        // The per-thread offset is applied with the transpose flag set,
        // matching the transposed loads below.
        ushort2 dO_block_offset(morton_offset.x, morton_offset.y);
        threadgroup float *dO_src = (threadgroup float*)(threadgroup_block);
        dO_src = simdgroup_matrix_storage<float>::apply_offset(
          dO_src, 16, dO_block_offset, true);
        threadgroup_barrier(mem_flags::mem_threadgroup);

        #pragma clang loop unroll(full)
        for (ushort d = 0; d < 16; d += 8) {
          // Columns below the 8-aligned remainder of R are always processed.
          #pragma clang loop unroll(full)
          for (ushort c = 0; c < (((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8; c += 8) {
            // Load the RHS from memory.
            ushort2 dO_origin(c, d);
            simdgroup_matrix_storage<float> dO_tile;
            dO_tile.load(dO_src, 16, dO_origin, true);

            // Issue one SIMD matmul instruction.
            dP_sram[c / 8].multiply(V_sram[d / 8], dO_tile, true);
          }

          // The remaining columns are only processed when another full block
          // follows this one (r + 64 < R).
          if (r + 64 < R) {
            #pragma clang loop unroll(full)
            for (ushort c = (((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8; c < 64; c += 8) {
              // Load the RHS from memory.
              ushort2 dO_origin(c, d);
              simdgroup_matrix_storage<float> dO_tile;
              dO_tile.load(dO_src, 16, dO_origin, true);

              // Issue one SIMD matmul instruction.
              dP_sram[c / 8].multiply(V_sram[d / 8], dO_tile, true);
            }
          }
        }
      }
    }





    // dP remainder pass: intentionally empty.
    //
    // The code that used to live here was wrapped in a constant `if (false)`
    // guard, so it could never execute. It repeated the body of the `d_outer`
    // loop above with `d_outer = 64`, where the V / dO source tile width
    // `min(16, 64 - d_outer)` is 0. The loop above already covers
    // d_outer = 0, 16, 32 and 48, so the unreachable branch has been removed
    // with no change in behavior.





    // dS^T = P^T * (dP^T * dot_product_scale_derivative - D), element-wise.
    //
    // Every dS tile is computed from the matching P and dP tiles plus a tile of
    // D values. The two branches differ only in where the D values are read
    // from: device memory for a fully in-bounds block, threadgroup memory for
    // an edge block.
    simdgroup_matrix_storage<float> dS_sram[64 / 8];

    if ((R % 64 == 0) || (r + 64 <= R)) {
      // In-bounds block: read D directly, starting at row r plus this
      // thread's morton x offset.
      auto D_src = D + (r + morton_offset.x);

      #pragma clang loop unroll(full)
      for (ushort c = 0; c < 64; c += 8) {
        ushort2 D_origin(c, 0);
        simdgroup_matrix_storage<float> D_tile;
        D_tile.load(D_src, 1, D_origin, false);
        auto D_values = *(D_tile.thread_elements());

        auto P_values = *(P_sram[c / 8].thread_elements());
        auto dP_values = *(dP_sram[c / 8].thread_elements());
        auto dS_values = vec<float, 2>(
          float2(P_values) * (float2(dP_values) * dot_product_scale_derivative - float2(D_values)));
        *(dS_sram[c / 8].thread_elements()) = dS_values;
      }
    } else {
      // Edge block: simdgroup 0 copies the D values into threadgroup memory.
      // The source length is clipped to R - r; the destination length is at
      // least the 8-aligned remainder of R.
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (sidx == 0) {
        auto D_global = D + r;
        auto D_staging = (threadgroup float*)(threadgroup_block);

        ushort R_src_dimension = min(uint(64), uint(R - r));
        ushort R_dst_dimension = max(
          ushort((((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8),
          ushort(R_src_dimension));

        // Issue an async copy.
        simdgroup_event D_event;
        D_event.async_copy<32>(
          D_staging, R_dst_dimension,
          D_global, R_src_dimension, lane_id);
        simdgroup_event::wait(1, &D_event);
      }

      auto D_src = (threadgroup float*)(threadgroup_block) + morton_offset.x;
      // The staged values must be visible to every simdgroup before loading.
      threadgroup_barrier(mem_flags::mem_threadgroup);

      #pragma clang loop unroll(full)
      for (ushort c = 0; c < 64; c += 8) {
        ushort2 D_origin(c, 0);
        simdgroup_matrix_storage<float> D_tile;
        D_tile.load(D_src, 1, D_origin, false);
        auto D_values = *(D_tile.thread_elements());

        auto P_values = *(P_sram[c / 8].thread_elements());
        auto dP_values = *(dP_sram[c / 8].thread_elements());
        auto dS_values = vec<float, 2>(
          float2(P_values) * (float2(dP_values) * dot_product_scale_derivative - float2(D_values)));
        *(dS_sram[c / 8].thread_elements()) = dS_values;
      }
    }




    // dK += dS^T * Q
    

    
    #pragma clang loop unroll(disable)
    for (
      ushort d_outer = 0;
      d_outer < 64;
      d_outer += 16
    ) {
      

    if ((
          (R % 64 == 0) ||
          (r + 64 <= R)
        ) && (
          (64 % 8 == 0) ||
          (d_outer + 16 <= 64)
        )) {
      

    

    simdgroup_matrix_storage<float> dK_sram[16 / 8];



    if (r == 0) {
      
    
    #pragma clang loop unroll(full)
    for (ushort d = 0; d < 16; d += 8) {
      auto dK = dK_sram + (0 + d) / 8;
      *dK = simdgroup_matrix_storage<float>(0);
    }



    } else {
      

      

     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (sidx == 0) {
       uint2 dK_offset(d_outer, parallelization_group_offset);
       auto src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dK, 64,
         dK_offset, false);
       auto dst = (threadgroup float*)(threadgroup_block);
       
       ushort D_dimension = min(
         ushort(16),
         ushort(64 - d_outer));
       ushort R_dimension = min(
         uint(16),
         uint(C - parallelization_group_offset));
       ushort2 tile(D_dimension, R_dimension);
       
       simdgroup_event event;
       event.async_copy<16, 32>(
         dst, tile,
         src, 64, tile, lane_id, false);
       simdgroup_event::wait(1, &event);
     }



      

       ushort2 dK_block_offset(
         morton_offset.x,
         morton_offset.y + sidx * 8);
       auto dK_src = (threadgroup float*)(threadgroup_block);
       dK_src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dK_src, 16,
         dK_block_offset, false);
       threadgroup_barrier(mem_flags::mem_threadgroup);



      
      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 dK_origin(d, 0);
        dK_sram[d / 8].load(
          dK_src, 16, 
          dK_origin, false);
      }
      


      
    }
    

      uint2 Q_src_offset(
        morton_offset.x + d_outer,
        morton_offset.y + r);
      auto Q_src = simdgroup_matrix_storage<float>
      ::apply_offset(
        Q, 64,
        Q_src_offset, false);



    

        

    #pragma clang loop unroll(full)
    for (ushort c = 0; c < 64; c += 8) {
      

    #pragma clang loop unroll(full)
    for (ushort d = 0; d < 16; d += 8) {
      // Load the RHS from memory.
      ushort2 Q_origin(d, c);
      simdgroup_matrix_storage<float> Q;
      Q.load(
        Q_src, 64,
        Q_origin, false);

      // Issue one SIMD matmul instruction.
      dK_sram[(0 + d) / 8].multiply(
        dS_sram[c / 8], Q, /*accumulate=*/true);
    }



    }



        if (
          (R % 64 == 0) &&
          (r + 64 == R)
        ) {
           
        }



    

      

       ushort2 dK_block_offset(
         morton_offset.x,
         morton_offset.y + sidx * 8);
       auto dK_src = (threadgroup float*)(threadgroup_block);
       dK_src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dK_src, 16,
         dK_block_offset, false);
       threadgroup_barrier(mem_flags::mem_threadgroup);




      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 dK_origin(d, 0);
        dK_sram[d / 8].store(
          dK_src, 16,
          dK_origin, false);
      }

      

     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (sidx == 0) {
       uint2 dK_offset(d_outer, parallelization_group_offset);
       auto src = (threadgroup float*)(threadgroup_block);
       auto dst = simdgroup_matrix_storage<float>
       ::apply_offset(
         dK, 64,
         dK_offset, false);
       
       ushort D_dimension = min(
         ushort(16),
         ushort(64 - d_outer));
       ushort R_dimension = min(
         uint(16),
         uint(C - parallelization_group_offset));
       ushort2 tile(D_dimension, R_dimension);
       
       simdgroup_event event;
       event.async_copy<16, 32>(
         dst, 64, tile,
         src, tile, lane_id, false);
       simdgroup_event::wait(1, &event);
     }
     








    } else {
      

    

    simdgroup_matrix_storage<float> dK_sram[16 / 8];



    if (r == 0) {
      
    
    #pragma clang loop unroll(full)
    for (ushort d = 0; d < 16; d += 8) {
      auto dK = dK_sram + (0 + d) / 8;
      *dK = simdgroup_matrix_storage<float>(0);
    }



    } else {
      

      

     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (sidx == 0) {
       uint2 dK_offset(d_outer, parallelization_group_offset);
       auto src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dK, 64,
         dK_offset, false);
       auto dst = (threadgroup float*)(threadgroup_block);
       
       ushort D_dimension = min(
         ushort(16),
         ushort(64 - d_outer));
       ushort R_dimension = min(
         uint(16),
         uint(C - parallelization_group_offset));
       ushort2 tile(D_dimension, R_dimension);
       
       simdgroup_event event;
       event.async_copy<16, 32>(
         dst, tile,
         src, 64, tile, lane_id, false);
       simdgroup_event::wait(1, &event);
     }



      

       ushort2 dK_block_offset(
         morton_offset.x,
         morton_offset.y + sidx * 8);
       auto dK_src = (threadgroup float*)(threadgroup_block);
       dK_src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dK_src, 16,
         dK_block_offset, false);
       threadgroup_barrier(mem_flags::mem_threadgroup);



      
      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 dK_origin(d, 0);
        dK_sram[d / 8].load(
          dK_src, 16, 
          dK_origin, false);
      }
      


      
    }
    
      
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (sidx == 0) {
        uint2 Q_offset(d_outer, r);
        auto src = simdgroup_matrix_storage<float>
        ::apply_offset(
          Q, 64,
          Q_offset, false);
        auto dst = (threadgroup float*)(threadgroup_block);
        
        ushort D_dimension = min(
          ushort(16),
          ushort(64 - d_outer));
        ushort C_src_dimension = min(
          uint(64),
          uint(R - r));
        ushort C_dst_dimension = max(
          ushort((((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8),
          ushort(C_src_dimension));
        ushort2 tile_src(D_dimension, C_src_dimension);
        ushort2 tile_dst(D_dimension, C_dst_dimension);
        
        simdgroup_event event;
        event.async_copy<16, 32>(
          dst, tile_dst,
          src, 64, tile_src, lane_id, false);
        simdgroup_event::wait(1, &event);
      }

      

      ushort2 Q_block_offset(
        morton_offset.x,
        morton_offset.y);
      auto Q_src = (threadgroup float*)(threadgroup_block);
      Q_src = simdgroup_matrix_storage<float>
      ::apply_offset(
        Q_src, 16,
        Q_block_offset, false);
      threadgroup_barrier(mem_flags::mem_threadgroup);



      
 

    

        

    #pragma clang loop unroll(full)
    for (ushort c = 0; c < (((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8; c += 8) {
      

    #pragma clang loop unroll(full)
    for (ushort d = 0; d < 16; d += 8) {
      // Load the RHS from memory.
      ushort2 Q_origin(d, c);
      simdgroup_matrix_storage<float> Q;
      Q.load(
        Q_src, 16,
        Q_origin, false);

      // Issue one SIMD matmul instruction.
      dK_sram[(0 + d) / 8].multiply(
        dS_sram[c / 8], Q, /*accumulate=*/true);
    }



    }



        if (r + 64
            < R) {
          

    #pragma clang loop unroll(full)
    for (ushort c = (((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8; c < 64; c += 8) {
      

    #pragma clang loop unroll(full)
    for (ushort d = 0; d < 16; d += 8) {
      // Load the RHS from memory.
      ushort2 Q_origin(d, c);
      simdgroup_matrix_storage<float> Q;
      Q.load(
        Q_src, 16,
        Q_origin, false);

      // Issue one SIMD matmul instruction.
      dK_sram[(0 + d) / 8].multiply(
        dS_sram[c / 8], Q, /*accumulate=*/true);
    }



    }



        } else {
          
        }



    

      

       ushort2 dK_block_offset(
         morton_offset.x,
         morton_offset.y + sidx * 8);
       auto dK_src = (threadgroup float*)(threadgroup_block);
       dK_src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dK_src, 16,
         dK_block_offset, false);
       threadgroup_barrier(mem_flags::mem_threadgroup);




      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 dK_origin(d, 0);
        dK_sram[d / 8].store(
          dK_src, 16,
          dK_origin, false);
      }

      

     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (sidx == 0) {
       uint2 dK_offset(d_outer, parallelization_group_offset);
       auto src = (threadgroup float*)(threadgroup_block);
       auto dst = simdgroup_matrix_storage<float>
       ::apply_offset(
         dK, 64,
         dK_offset, false);
       
       ushort D_dimension = min(
         ushort(16),
         ushort(64 - d_outer));
       ushort R_dimension = min(
         uint(16),
         uint(C - parallelization_group_offset));
       ushort2 tile(D_dimension, R_dimension);
       
       simdgroup_event event;
       event.async_copy<16, 32>(
         dst, 64, tile,
         src, tile, lane_id, false);
       simdgroup_event::wait(1, &event);
     }
     








    }



    }
    




    if (false) {
      ushort d_outer = 64;
      

    if ((
          (R % 64 == 0) ||
          (r + 64 <= R)
        ) && (
          (64 % 8 == 0) ||
          (d_outer + 16 <= 64)
        )) {
      

    

    simdgroup_matrix_storage<float> dK_sram[16 / 8];



    if (r == 0) {
      
    
    #pragma clang loop unroll(full)
    for (ushort d = 0; d < 16; d += 8) {
      auto dK = dK_sram + (0 + d) / 8;
      *dK = simdgroup_matrix_storage<float>(0);
    }



    } else {
      

      

     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (sidx == 0) {
       uint2 dK_offset(d_outer, parallelization_group_offset);
       auto src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dK, 64,
         dK_offset, false);
       auto dst = (threadgroup float*)(threadgroup_block);
       
       ushort D_dimension = min(
         ushort(16),
         ushort(64 - d_outer));
       ushort R_dimension = min(
         uint(16),
         uint(C - parallelization_group_offset));
       ushort2 tile(D_dimension, R_dimension);
       
       simdgroup_event event;
       event.async_copy<16, 32>(
         dst, tile,
         src, 64, tile, lane_id, false);
       simdgroup_event::wait(1, &event);
     }



      

       ushort2 dK_block_offset(
         morton_offset.x,
         morton_offset.y + sidx * 8);
       auto dK_src = (threadgroup float*)(threadgroup_block);
       dK_src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dK_src, 16,
         dK_block_offset, false);
       threadgroup_barrier(mem_flags::mem_threadgroup);



      
      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 dK_origin(d, 0);
        dK_sram[d / 8].load(
          dK_src, 16, 
          dK_origin, false);
      }
      


      
    }
    

      uint2 Q_src_offset(
        morton_offset.x + d_outer,
        morton_offset.y + r);
      auto Q_src = simdgroup_matrix_storage<float>
      ::apply_offset(
        Q, 64,
        Q_src_offset, false);



    

        

    #pragma clang loop unroll(full)
    for (ushort c = 0; c < 64; c += 8) {
      

    #pragma clang loop unroll(full)
    for (ushort d = 0; d < 16; d += 8) {
      // Load the RHS from memory.
      ushort2 Q_origin(d, c);
      simdgroup_matrix_storage<float> Q;
      Q.load(
        Q_src, 64,
        Q_origin, false);

      // Issue one SIMD matmul instruction.
      dK_sram[(0 + d) / 8].multiply(
        dS_sram[c / 8], Q, /*accumulate=*/true);
    }



    }



        if (
          (R % 64 == 0) &&
          (r + 64 == R)
        ) {
           
        }



    

      

       ushort2 dK_block_offset(
         morton_offset.x,
         morton_offset.y + sidx * 8);
       auto dK_src = (threadgroup float*)(threadgroup_block);
       dK_src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dK_src, 16,
         dK_block_offset, false);
       threadgroup_barrier(mem_flags::mem_threadgroup);




      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 dK_origin(d, 0);
        dK_sram[d / 8].store(
          dK_src, 16,
          dK_origin, false);
      }

      

     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (sidx == 0) {
       uint2 dK_offset(d_outer, parallelization_group_offset);
       auto src = (threadgroup float*)(threadgroup_block);
       auto dst = simdgroup_matrix_storage<float>
       ::apply_offset(
         dK, 64,
         dK_offset, false);
       
       ushort D_dimension = min(
         ushort(16),
         ushort(64 - d_outer));
       ushort R_dimension = min(
         uint(16),
         uint(C - parallelization_group_offset));
       ushort2 tile(D_dimension, R_dimension);
       
       simdgroup_event event;
       event.async_copy<16, 32>(
         dst, 64, tile,
         src, tile, lane_id, false);
       simdgroup_event::wait(1, &event);
     }
     








    } else {
      

    

    simdgroup_matrix_storage<float> dK_sram[16 / 8];



    if (r == 0) {
      
    
    #pragma clang loop unroll(full)
    for (ushort d = 0; d < 16; d += 8) {
      auto dK = dK_sram + (0 + d) / 8;
      *dK = simdgroup_matrix_storage<float>(0);
    }



    } else {
      

      

     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (sidx == 0) {
       uint2 dK_offset(d_outer, parallelization_group_offset);
       auto src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dK, 64,
         dK_offset, false);
       auto dst = (threadgroup float*)(threadgroup_block);
       
       ushort D_dimension = min(
         ushort(16),
         ushort(64 - d_outer));
       ushort R_dimension = min(
         uint(16),
         uint(C - parallelization_group_offset));
       ushort2 tile(D_dimension, R_dimension);
       
       simdgroup_event event;
       event.async_copy<16, 32>(
         dst, tile,
         src, 64, tile, lane_id, false);
       simdgroup_event::wait(1, &event);
     }



      

       ushort2 dK_block_offset(
         morton_offset.x,
         morton_offset.y + sidx * 8);
       auto dK_src = (threadgroup float*)(threadgroup_block);
       dK_src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dK_src, 16,
         dK_block_offset, false);
       threadgroup_barrier(mem_flags::mem_threadgroup);



      
      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 dK_origin(d, 0);
        dK_sram[d / 8].load(
          dK_src, 16, 
          dK_origin, false);
      }
      


      
    }
    
      
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (sidx == 0) {
        uint2 Q_offset(d_outer, r);
        auto src = simdgroup_matrix_storage<float>
        ::apply_offset(
          Q, 64,
          Q_offset, false);
        auto dst = (threadgroup float*)(threadgroup_block);
        
        ushort D_dimension = min(
          ushort(16),
          ushort(64 - d_outer));
        ushort C_src_dimension = min(
          uint(64),
          uint(R - r));
        ushort C_dst_dimension = max(
          ushort((((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8),
          ushort(C_src_dimension));
        ushort2 tile_src(D_dimension, C_src_dimension);
        ushort2 tile_dst(D_dimension, C_dst_dimension);
        
        simdgroup_event event;
        event.async_copy<16, 32>(
          dst, tile_dst,
          src, 64, tile_src, lane_id, false);
        simdgroup_event::wait(1, &event);
      }

      

      ushort2 Q_block_offset(
        morton_offset.x,
        morton_offset.y);
      auto Q_src = (threadgroup float*)(threadgroup_block);
      Q_src = simdgroup_matrix_storage<float>
      ::apply_offset(
        Q_src, 16,
        Q_block_offset, false);
      threadgroup_barrier(mem_flags::mem_threadgroup);



      
 

    

        

    #pragma clang loop unroll(full)
    for (ushort c = 0; c < (((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8; c += 8) {
      

    #pragma clang loop unroll(full)
    for (ushort d = 0; d < 16; d += 8) {
      // Load the RHS from memory.
      ushort2 Q_origin(d, c);
      simdgroup_matrix_storage<float> Q;
      Q.load(
        Q_src, 16,
        Q_origin, false);

      // Issue one SIMD matmul instruction.
      dK_sram[(0 + d) / 8].multiply(
        dS_sram[c / 8], Q, /*accumulate=*/true);
    }



    }



        if (r + 64
            < R) {
          

    #pragma clang loop unroll(full)
    for (ushort c = (((R % 64 == 0) ? 64 : R % 64) + 7) / 8 * 8; c < 64; c += 8) {
      

    #pragma clang loop unroll(full)
    for (ushort d = 0; d < 16; d += 8) {
      // Load the RHS from memory.
      ushort2 Q_origin(d, c);
      simdgroup_matrix_storage<float> Q;
      Q.load(
        Q_src, 16,
        Q_origin, false);

      // Issue one SIMD matmul instruction.
      dK_sram[(0 + d) / 8].multiply(
        dS_sram[c / 8], Q, /*accumulate=*/true);
    }



    }



        } else {
          
        }



    

      

       ushort2 dK_block_offset(
         morton_offset.x,
         morton_offset.y + sidx * 8);
       auto dK_src = (threadgroup float*)(threadgroup_block);
       dK_src = simdgroup_matrix_storage<float>
       ::apply_offset(
         dK_src, 16,
         dK_block_offset, false);
       threadgroup_barrier(mem_flags::mem_threadgroup);




      #pragma clang loop unroll(full)
      for (ushort d = 0; d < 16; d += 8) {
        ushort2 dK_origin(d, 0);
        dK_sram[d / 8].store(
          dK_src, 16,
          dK_origin, false);
      }

      

     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (sidx == 0) {
       uint2 dK_offset(d_outer, parallelization_group_offset);
       auto src = (threadgroup float*)(threadgroup_block);
       auto dst = simdgroup_matrix_storage<float>
       ::apply_offset(
         dK, 64,
         dK_offset, false);
       
       ushort D_dimension = min(
         ushort(16),
         ushort(64 - d_outer));
       ushort R_dimension = min(
         uint(16),
         uint(C - parallelization_group_offset));
       ushort2 tile(D_dimension, R_dimension);
       
       simdgroup_event event;
       event.async_copy<16, 32>(
         dst, 64, tile,
         src, tile, lane_id, false);
       simdgroup_event::wait(1, &event);
     }
     








    }



    }




  }





}


=== END MFA Shader ===
