# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

"""Tile loader for SM100 convolution with hardware im2col TMA.

This module provides TileLoaderTMAIm2col, which uses TMA's im2col addressing
mode to perform implicit GEMM convolution without explicit im2col buffers.
The TMA descriptor encodes the convolution geometry, and the TMA hardware
transforms coordinates on the fly during memory loads.
"""

from std.gpu.memory import AddressSpace
from layout.tma_async import SharedMemBarrier, TMATensorTileIm2col
from layout import TensorLayout, TileTensor
from std.utils.index import IndexList


struct TileLoaderTMAIm2col[
    tma_origin: ImmutOrigin,
    dtype: DType,
    tma_rank: Int,
    tile_shape: IndexList[tma_rank],
    desc_shape: IndexList[tma_rank],
    /,
    *,
    cta_group: Int,
](TrivialRegisterPassable):
    """TMA tile loader using hardware im2col for implicit GEMM convolution.

    Uses a TMATensorTileIm2col descriptor (cuTensorMapEncodeIm2col) to perform
    coordinate transformation in TMA hardware. Coordinates are in GEMM space:
    - k_coord: K dimension (C * R * S reduction)
    - m_coord: M dimension (batch * H_out * W_out spatial)

    Parameters:
        tma_origin: Origin of the TMA descriptor pointer.
        dtype: Element data type.
        tma_rank: Rank of the TMA tile/descriptor shapes.
        tile_shape: TMA tile shape as IndexList.
        desc_shape: TMA descriptor shape as IndexList.
