# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

"""Test TMA -> SMEM -> TMEM -> MMA TS .ws pipeline for P = Q x K^T.

Loads a 64x512 BF16 Q matrix into SMEM via TMA (k-major, SWIZZLE_128B),
then copies it to TMEM via tcgen05_cp.  Loads a 64x512 BF16 K matrix
into a separate SMEM region via TMA (same layout).

The MMA computation uses tcgen05.mma.ws.cta_group::1.kind::f16 in TS
mode (A from TMEM, B from SMEM).  The K [64,512] SMEM data is
reinterpreted as [128,256] via the dual GEMM fold trick.  The MMA output
is [64,128] in TMEM: columns 0-63 = partial Q_even * K_even^T, columns
64-127 = partial Q_odd * K_odd^T.  Summing these two halves gives the
full 64x64 Q x K^T result, which is verified against a GPU naive matmul
reference (matmul_kernel_naive with transpose_b=True).

A passing P = Q x K^T check also validates both load paths (Q through
TMA -> SMEM -> TMEM, K through TMA -> SMEM), because every element of Q
and K contributes to the product; no separate readback verification is
needed for either.
"""

from std.math import ceildiv
from std.memory import UnsafePointer, alloc
from std.random import rand, randn, seed
from std.sys import size_of

from std.gpu import barrier, thread_idx, warp_id as get_warp_id
from std.gpu.host import DeviceBuffer, DeviceContext, FuncAttribute
from std.gpu.host.nvidia.tma import (
    TensorMapSwizzle,
    prefetch_tma_descriptor,
)
from std.gpu.memory import (
    AddressSpace,
    external_memory,
)
from std.gpu.compute.arch.mma_nvidia_sm100 import (
    MMASmemDescriptor,
    UMMAInsDescriptor,
