# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #


from layout import TileTensor, row_major
from linalg.accumulate import _Accumulator, _simd_load_maybe_partial
from std.testing import *


# TODO: rewrite c-layout comments according to the new struct.
def test_maybe_partial_load() raises:
    """Checks `_simd_load_maybe_partial` for a full and a tail load.

    A buffer of `simd_size + 1` ones is read twice: once from offset 0 with
    the second compile-time parameter set to `False`, and once from offset
    `simd_size` (where only one element remains) with it set to `True`.

    Raises:
        If either loaded vector differs from the expected value.
    """
    comptime simd_size = 4
    comptime size = simd_size + 1
    comptime Vec = SIMD[DType.float32, simd_size]

    # One full vector's worth of ones plus a single trailing element.
    var buf = InlineArray[Float32, size](uninitialized=True)
    for idx in range(size):
        buf[idx] = 1.0

    # Load from the start of the buffer: every lane is expected to be 1.0.
    var full = _simd_load_maybe_partial[simd_size, False](buf.unsafe_ptr(), 0)
    assert_equal(full, Vec(1.0))

    # Load starting at the last element. The trailing `1` is presumably the
    # number of valid elements to read -- confirm against the helper. The
    # assertion expects lane 0 to be 1.0 and the remaining lanes to be 0.0.
    var tail = _simd_load_maybe_partial[simd_size, True](
        buf.unsafe_ptr(), simd_size, 1
    )
    assert_equal(tail, Vec(1.0, 0.0, 0.0, 0.0))


def test_accumulate[
    simd_size: Int = 4, num_rows: Int = 2, num_cols: Int = 2, length: Int = 2
]() raises:
    comptime type = DType.float32

    # A: [[ 0.0, 0.0 ],
    #     [ 1.0, 1.0 ],
    #     [ 2.0, 2.0 ],
    #     [ 3.0, 3.0 ]]
    var a = InlineArray[Scalar[type], 2 * num_rows * length](uninitialized=True)
    for i in range(2 * num_rows):
        var a_ptr = a.unsafe_ptr() + i * length
        a_ptr[0] = Scalar[type](i)
