# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
from std.sys import size_of
from std.math import align_up

from std.gpu.primitives.cluster import cluster_mask_base
from std.gpu.host._tensormap import SwizzleMode
from std.gpu.memory import AddressSpace
from std.gpu.host.nvidia.tma import TensorMapSwizzle
from std.gpu import block_id_in_cluster
from std.gpu.compute.arch.mma_nvidia_sm100 import *
from std.gpu.compute.arch.tcgen05 import *
from std.gpu.compute.arch.mma_nvidia_sm100 import MMASmemDescriptorPair
from layout import IntTuple, Layout, TileTensor
from layout.tile_layout import TensorLayout, _types_to_int_tuple
from layout.tensor_core_async import (
    _CM_ROW_BYTES,
    tile_to_descriptor,
    tile_layout_k_major,
    tile_layout_mn_major,
)

from std.utils.index import Index, IndexList, product
from linalg.fp4_utils import SF_MN_GROUP_SIZE, SF_ATOM_M, SF_ATOM_K


def _create_mma_desc_k_major[
    dtype: DType, swizzle_mode: TensorMapSwizzle
](
    ptr: UnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, ...]
) -> MMASmemDescriptor:
    """Builds a K-major MMA shared-memory descriptor from a swizzle mode.

    No `Layout` object is constructed here: both byte offsets handed to
    `MMASmemDescriptor.create` are compile-time values derived from
    `swizzle_mode` and `_CM_ROW_BYTES` alone.

    - Stride byte offset (SBO) is `8 * swizzle_mode.bytes()`.
    - Leading byte offset (LBO) is `_CM_ROW_BYTES`.

    Parameters:
        dtype: Element type of the scalars that `ptr` points to.
        swizzle_mode: Swizzle mode; supplies the byte width used for the SBO
            and is forwarded unchanged to the descriptor.

    Args:
        ptr: Pointer in the shared address space that the descriptor is
            created from.

    Returns:
        The `MMASmemDescriptor` produced by `MMASmemDescriptor.create` for
        `ptr` with the offsets above.
    """
    comptime leading_byte_offset = _CM_ROW_BYTES
    comptime stride_byte_offset = 8 * swizzle_mode.bytes()
    return MMASmemDescriptor.create[
        stride_byte_offset, leading_byte_offset, swizzle_mode
    ](ptr)
