# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Test the ping-pong matmul kernel across different shapes.

Known limitations:
- FP8: only works when M % 256 == 0 (M must be aligned to the BM tile size).
- BF16: works with all M values.
"""

from layout import TileTensor, row_major
from std.gpu.host import DeviceContext
import linalg.matmul.vendor.blas as vendor_blas
from linalg.matmul.gpu.amd.amd_ping_pong_matmul import (
    structured_ping_pong_matmul as ping_pong_matmul,
)
from std.testing import assert_true
from std.random import random_float64


def test_shape[
    in_dtype: DType, M: Int, N: Int, K: Int, enable_swizzle: Bool = True
](ctx: DeviceContext) raises:
    """Test a single shape."""
    var device_a = ctx.enqueue_create_buffer[in_dtype](M * K)
    var device_b = ctx.enqueue_create_buffer[in_dtype](N * K)
    var device_c = ctx.enqueue_create_buffer[DType.float32](M * N)
    var device_c_ref = ctx.enqueue_create_buffer[DType.float32](M * N)

    # Use random data to expose precision and swizzle bugs.
    # Small range [-0.5, 0.5] keeps values representable in low-precision formats.
    with device_a.map_to_host() as host_a, device_b.map_to_host() as host_b:
        for i in range(M * K):
            host_a[i] = random_float64(-0.5, 0.5).cast[in_dtype]()
        for i in range(K * N):
            host_b[i] = random_float64(-0.5, 0.5).cast[in_dtype]()

    var a_tt = TileTensor(device_a, row_major[M, K]())
    var b_tt = TileTensor(device_b, row_major[N, K]())
    var c_tt = TileTensor(device_c, row_major[M, N]())
