# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from std.math import ceildiv, isclose
from std.random import random_float64

from std.gpu.host import DeviceContext
from std.gpu.host.info import A100
from layout import Coord, Idx, TileTensor, row_major
from linalg.bmm import _batched_matmul_gpu
from linalg.matmul.gpu import _matmul_gpu, matmul_kernel_naive, multistage_gemm
from linalg.utils_gpu import MatmulConfig, MatmulKernels, select_config
from std.testing import assert_almost_equal

from std.utils import IndexList


def run_matmul_naive(ctx: DeviceContext, M: Int, N: Int, K: Int) raises:
    print("== run_matmul naive kernel")

    var a_host = alloc[BFloat16](M * K)
    var b_host = alloc[BFloat16](K * N)
    var c_host = alloc[BFloat16](M * N)
    var a_host_n = alloc[Float32](M * K)
    var b_host_n = alloc[Float32](K * N)
    var c_host_n = alloc[Float32](M * N)

    var rand_min = -1.0
    var rand_max = 1.0

    for i in range(M * K):
        var val = random_float64(rand_min, rand_max).cast[DType.float32]()
        a_host[i] = val.cast[DType.bfloat16]()
        a_host_n[i] = a_host[i].cast[DType.float32]()

    for i in range(K * N):
        var val = random_float64(rand_min, rand_max).cast[DType.float32]()
        b_host[i] = val.cast[DType.bfloat16]()
        b_host_n[i] = b_host[i].cast[DType.float32]()
