import triton
import torch


def get_mma_config(M, N, K, GROUP_M, num_warps, num_stages):
    """Build a ``triton.Config`` for a tiled matmul-style kernel.

    The tile sizes become the kernel meta-parameters ``BLOCK_SIZE_M``,
    ``BLOCK_SIZE_N``, ``BLOCK_SIZE_K`` and ``GROUP_SIZE_M``. ``num_warps`` and
    ``num_stages`` are forwarded unchanged as launch options.
    """
    meta_params = dict(
        BLOCK_SIZE_M=M,
        BLOCK_SIZE_N=N,
        BLOCK_SIZE_K=K,
        GROUP_SIZE_M=GROUP_M,
    )
    return triton.Config(meta_params, num_warps=num_warps, num_stages=num_stages)


def get_autotune_configs():
    """Return the list of candidate ``triton.Config`` objects for autotuning.

    Each row of the table below is unpacked into ``get_mma_config``. The rows
    are kept in the order shown.
    """
    # (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps, num_stages)
    candidates = (
        (64, 32, 16, 8, 4, 4),
        (64, 32, 32, 8, 4, 4),
        (128, 32, 32, 8, 4, 4),
        (128, 64, 16, 4, 4, 4),
        (128, 64, 32, 8, 4, 4),
        (128, 64, 64, 8, 4, 4),
        (128, 64, 64, 8, 8, 4),
        (128, 128, 32, 8, 4, 2),
        (128, 128, 32, 8, 4, 4),
        (128, 128, 32, 8, 4, 6),
        (128, 128, 64, 8, 4, 4),
        (128, 128, 64, 8, 4, 6),
        (128, 128, 64, 8, 8, 4),
    )
    return [get_mma_config(*params) for params in candidates]


def get_autotune_conv2d_bwd_configs():
    """Return candidate ``triton.Config`` objects for the conv2d backward kernel.

    These use smaller M/N tiles than ``get_autotune_configs``. Rows are passed
    to ``get_mma_config`` in the order listed.
    """
    # Row layout: BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps, num_stages
    # GEMM mapping: GEMM_M=C_OUT, GEMM_N=C_IN*FILTER_H*FILTER_W, GEMM_K=BATCH_SIZE*H_OUT*W_OUT
    rows = (
        (16, 16, 32, 4, 4, 4),
        (16, 16, 32, 4, 4, 6),
        (16, 16, 32, 4, 4, 8),
        (16, 16, 64, 4, 4, 4),
        (16, 16, 64, 4, 4, 6),
        (16, 16, 64, 4, 4, 8),
        (16, 16, 64, 8, 4, 4),
        (16, 16, 128, 4, 4, 4),
        (16, 16, 128, 8, 4, 4),
        (16, 32, 64, 8, 4, 4),
        (16, 32, 128, 4, 4, 4),
        (16, 32, 128, 8, 4, 4),
        (32, 32, 64, 4, 4, 4),
        (32, 32, 64, 8, 4, 4),
        (32, 32, 64, 8, 4, 6),
        (32, 32, 64, 8, 4, 8),
        (32, 32, 128, 8, 4, 4),
        (32, 64, 64, 8, 4, 4),
    )
    return [get_mma_config(*row) for row in rows]


def enable_cudnn_optimizations():
    """Set ``torch.backends.cudnn.benchmark`` to True if cuDNN is available.

    Prints one status line saying whether the flag was set or cuDNN is missing.
    """
    if not torch.backends.cudnn.is_available():
        print("cuDNN is not available.")
        return
    torch.backends.cudnn.benchmark = True
    print("cuDNN benchmark enabled.")


def enable_torch_optimizations(
    allow_tf32=True,
    high_precision=True,
    fp16_reduced_precision=True,
):
    """
    Enable PyTorch performance settings for cuDNN and matmul operations.

    If cuDNN is available, cuDNN benchmark mode is always turned on.

    Args:
        allow_tf32: Only used on CUDA devices with compute capability >= 8.
            Sets both ``torch.backends.cuda.matmul.allow_tf32`` and
            ``torch.backends.cudnn.allow_tf32`` to this value.
        high_precision: Only used on CUDA devices with compute capability >= 8.
            If True, sets the float32 matmul precision. It uses ``"high"`` when
            TF32 is allowed and ``"highest"`` otherwise, so this option never
            overrides ``allow_tf32=False``.
        fp16_reduced_precision: If True, allows reduced-precision reductions in
            fp16/bf16 cuBLAS matmuls. If False, those flags are left unchanged.

    Prints a status line for each setting it changes.
    """
    if torch.backends.cudnn.is_available():
        torch.backends.cudnn.benchmark = True
        # torch.backends.cudnn.deterministic =True
        print("cuDNN benchmark enabled.")

    # TF32 tensor cores exist only on compute capability 8.x+ (Ampere or newer).
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        tf32 = bool(allow_tf32)
        torch.backends.cuda.matmul.allow_tf32 = tf32
        torch.backends.cudnn.allow_tf32 = tf32
        print(f"TF32 {'enabled' if tf32 else 'disabled'} for matmul.")

        if high_precision:
            # In PyTorch, matmul.allow_tf32 and the float32 matmul precision
            # are one shared setting. Calling set_float32_matmul_precision("high")
            # would turn TF32 back on for cuBLAS matmuls, so fall back to
            # "highest" when the caller refused TF32.
            precision = "high" if tf32 else "highest"
            torch.set_float32_matmul_precision(precision)
            print(f"{precision.capitalize()} precision matmul enabled.")

    if fp16_reduced_precision:
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
        print("Reduced precision reductions enabled for fp16/bf16.")


def disable_torch_optimizations():
    """
    Disable PyTorch performance settings in favor of reproducible, full-precision math.

    - When cuDNN is available: turns cuDNN benchmark off and cuDNN
      deterministic mode on.
    - On CUDA devices with compute capability >= 8: disables TF32 for cuBLAS
      matmuls and cuDNN.
    - On every device: sets the float32 matmul precision to ``"highest"``.
    - Disallows reduced-precision reductions in fp16/bf16 cuBLAS matmuls.

    Prints a status line for each group of settings.
    """
    if torch.backends.cudnn.is_available():
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        print("cuDNN benchmark disabled.")
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        print("TF32 disabled for matmul.")

    # Must be "highest", not "high". "high" shares state with
    # torch.backends.cuda.matmul.allow_tf32 and would turn TF32 back on right
    # after it was disabled above. "highest" is PyTorch's default, so applying
    # it unconditionally is also safe on CPU-only machines.
    torch.set_float32_matmul_precision("highest")
    print("Highest precision matmul enabled.")

    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
    print("Reduced precision reductions disabled for fp16/bf16.")
