# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from std.testing import TestSuite

# Compile-time lane count of the Float32 SIMD vectors used in this file.
# `strsv` also uses it as the step size when walking its input and as the
# edge length of the simd_width x simd_width scratch buffer it allocates.
comptime simd_width = 8


def strsv[
    size: Int
](
    L_ptr_in: UnsafePointer[Float32, _],
    x_ptr_in: UnsafePointer[mut=True, Float32, _],
):
    # assuming size is a multiple of simd_width
    var x_ptr = x_ptr_in
    var L_ptr = L_ptr_in
    var n: Int = size
    var x_solved_storage = InlineArray[Float32, simd_width * simd_width](
        uninitialized=True
    )
    var x_solved = x_solved_storage.unsafe_ptr().mut_cast[True]()

    while True:
        for j in range(simd_width):
            var x_j = x_ptr[j]
            for i in range(j + 1, simd_width):
                x_ptr[i] = x_j.fma(-L_ptr[i + j * size], x_ptr[i])

        n -= simd_width
        if n <= 0:
            return

        # Save the solution of the triangular tile on the stack, packing
        # each solved value into its own SIMD vector.
        var x_vec: SIMD[DType.float32, simd_width]
        for i in range(simd_width):
            # Broadcast one solution value to a simd vector.
            x_vec = x_ptr[i]
            x_solved.store(i * simd_width, x_vec)
