# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from std.random import random_ui64

from std.gpu.host import DeviceContext, DeviceBuffer
from layout import Idx, TileTensor, coord_to_index_list, row_major
from nn.index_tensor import _index_tensor_impl
from std.testing import assert_equal, assert_true

from std.utils import IndexList


def execute_index_tensor_test[
    data_type: DType,
    //,
    batch_dims: Int,
](
    data_device: TileTensor[data_type, address_space=AddressSpace.GENERIC, ...],
    indices_device: TileTensor[address_space=AddressSpace.GENERIC, ...],
    expected_output_device: TileTensor[
        data_type, address_space=AddressSpace.GENERIC, ...
    ],
    expected_output_device_buffer: DeviceBuffer[data_type],
    ctx: DeviceContext,
) raises:
    # execute the kernel
    var actual_output_device = ctx.enqueue_create_buffer[
        expected_output_device.dtype
    ](expected_output_device.num_elements())
    var actual_output_tensor = TileTensor(
        actual_output_device,
        row_major(
            expected_output_device.layout.shape_coord().make_dynamic[
                DType.int64
            ]()
        ),
    )
    # Convert all tensors to dynamic layouts before calling the kernel
    _index_tensor_impl[batch_dims, target="gpu"](
