# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from std.sys import size_of

from std.gpu import barrier
from std.gpu.host import DeviceContext
from std.gpu.host.nvidia.tma import TensorMapSwizzle, TMADescriptor
from std.gpu import block_idx, thread_idx
from std.gpu.sync import syncwarp
from layout import Layout, LayoutTensor
from layout._fillers import arange
from layout._utils import ManagedLayoutTensor
from layout.layout_tensor import copy_sram_to_dram
from layout.swizzle import make_swizzle
from layout.tma_async import (
    create_tensor_tile,
    SharedMemBarrier,
    TMATensorTile,
    TMATensorTileArray,
    _idx_product,
)
from std.memory import stack_allocation
from std.testing import assert_equal

from std.utils.index import Index, IndexList


@__llvm_arg_metadata(template_tma_tensormap, `nvvm.grid_constant`)
def test_tma_replace_global_addr_in_gmem_descriptor_kernel[
    dtype: DType,
    num_of_tensormaps: Int,
    src_layout: Layout,
    dst_layout: Layout,
    tile_rank: Int,
    cta_tile_shape: IndexList[tile_rank],
    desc_shape: IndexList[tile_rank],
    thread_layout: Layout,
](
    dst: LayoutTensor[dtype, dst_layout, MutAnyOrigin],
