# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from std.sys import get_defined_int

from std.benchmark import (
    Bench,
    Bencher,
    BenchId,
    BenchMetric,
    ThroughputMeasure,
)
from std.builtin._closure import __ownership_keepalive
from std.gpu import global_idx
from std.gpu.host import DeviceContext
from internal_utils import update_bench_config_args
from std.testing import assert_equal


def vec_func(
    in0: UnsafePointer[Float32, ImmutAnyOrigin],
    in1: UnsafePointer[Float32, ImmutAnyOrigin],
    output: UnsafePointer[Float32, MutAnyOrigin],
    len: Int,
):
    """Adds one pair of `Float32` elements, selected by `global_idx.x`.

    Each invocation handles a single index: it reads `in0` and `in1` at
    that index and stores their sum at the same index in `output`.

    Args:
        in0: Read-only pointer to the first operand buffer.
        in1: Read-only pointer to the second operand buffer.
        output: Writable pointer that receives the element-wise sums.
        len: Exclusive upper bound on the indices that are processed;
            an invocation whose index is at or beyond it writes nothing.
    """
    var idx = global_idx.x
    # Bounds guard: indices at or past `len` must not touch any buffer.
    if idx < len:
        output[idx] = in0[idx] + in1[idx]


@no_inline
def bench_vec_add(
    mut b: Bench, *, block_dim: Int, length: Int, context: DeviceContext
) raises:
    comptime dtype = DType.float32
    var in0_host = alloc[Scalar[dtype]](length)
    var in1_host = alloc[Scalar[dtype]](length)
    var out_host = alloc[Scalar[dtype]](length)
