# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from std.gpu.host import DeviceContext
from layout import Idx, TileTensor, row_major
from layout._fillers import random
from nn.conv.conv import conv_cudnn, conv_gpu
from std.testing import assert_almost_equal

from std.utils.index import IndexList


# input: NHWC
# filter: RSCF
def test_conv_cudnn[
    input_dim: IndexList[4],
    filter_dim: IndexList[4],
    output_dim: IndexList[4],
    input_type: DType,
    filter_type: DType,
    output_type: DType,
    stride_dim: IndexList[2],
    dilation_dim: IndexList[2],
    pad_dim: IndexList[
        4
    ],  # Format: [pad_h_before, pad_h_after, pad_w_before, pad_w_after]
    num_groups: Int = 1,
](ctx: DeviceContext) raises:
    print(
        "== test_cudnn_conv_gpu: dtype_in=",
        input_type,
        " dtype_filter=",
        filter_type,
        " dtype_out=",
        output_type,
        " num_groups=",
        num_groups,
    )

    # Extract dimensions
