# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from layout import TileTensor, row_major
from nn.gather_scatter import scatter_nd_generator
from std.testing import assert_equal


@always_inline
@parameter
def use_update[
    dtype: DType, width: Int, //
](
    input_val: SIMD[dtype, width],
    update_val: SIMD[dtype, width],
) -> SIMD[dtype, width]:
    """Combines an existing element with an update by keeping the update.

    Parameters:
        dtype: Element type shared by both operands (inferred).
        width: SIMD width shared by both operands (inferred).

    Args:
        input_val: The current value. It is ignored by this function.
        update_val: The incoming value. It is returned unchanged.

    Returns:
        `update_val`, unmodified.
    """
    # NOTE(review): presumably handed to `scatter_nd_generator` as its combine
    # callback, so scattered positions are overwritten rather than accumulated.
    # Confirm against the call site.
    var chosen = update_val
    return chosen


def main() raises:
    def test_scatternd() raises:
        print("== test_scatternd")
        # data: 4x4x4 = 64 elements
        var data_ptr = alloc[Float32](64)
        var data_vals: InlineArray[Float32, 64] = [
            Float32(1),
            2,
            3,
            4,
            5,
            6,
            7,
            8,
            8,
            7,
            6,
            5,
            4,
            3,
            2,
            1,
