# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from layout import *
from layout.int_tuple import product
from layout.layout_tensor import *
from std.testing import assert_equal


def test_vectorize_2() raises:
    var tensor = LayoutTensor[
        DType.float32,
        Layout(IntTuple(IntTuple(16, 32), 4), IntTuple(IntTuple(32, 1), 512)),
        MutAnyOrigin,
    ].stack_allocation[stack_alignment=16]()

    var n = product(tensor.layout.shape)
    for i in range(n):
        tensor.ptr[i] = Float32(i)

    var frag = tensor._vectorize_2[origin_of(), IntTuple(IntTuple(1, 4), 1)]()
    var crd = RuntimeTuple[IntTuple(2)]()
    var val = frag[crd]
    assert_equal(val[0], 64)
    assert_equal(val[1], 65)
    assert_equal(val[2], 66)
    assert_equal(val[3], 67)

    var frag_linear = tensor._vectorize_2[4]()
    var val_linear = frag_linear[crd]
    assert_equal(val_linear[0], 64)
    assert_equal(val_linear[1], 65)
    assert_equal(val_linear[2], 66)
    assert_equal(val_linear[3], 67)

    var three_dim_tensor = LayoutTensor[
        DType.float32,
        Layout(IntTuple(16, 32, 4), IntTuple(32, 1, 512)),
        MutAnyOrigin,
    ].stack_allocation[stack_alignment=16]()
