# weight_only_linear — GPU CUDA error 700 (illegal memory access)
# Root cause: the first dimension of `weight` is 0 (every case below uses a (0, K) weight).

# case 1: int8, with bias, 3d input (1, 32, 128); group_size=-1 passed explicitly.
# weight is (0, 128) while bias and weight_scale both have length 288.
nn.quant.weight_only_linear(x=Tensor.float16((1, 32, 128)), weight=Tensor.int8((0, 128)), bias=Tensor.float16((288,)), weight_scale=Tensor.float16((288,)), weight_dtype="int8", group_size=-1)

# case 2: int4, with bias, 3d input (1, 32, 64); group_size=-1 passed explicitly.
# The weight tensor is still built as int8 although weight_dtype="int4"
# (presumably int4 values are packed into int8 storage — TODO confirm).
nn.quant.weight_only_linear(x=Tensor.float16((1, 32, 64)), weight=Tensor.int8((0, 64)), bias=Tensor.float16((256,)), weight_scale=Tensor.float16((256,)), weight_dtype="int4", group_size=-1)

# case 3: int8, with bias, 3d input; same tensor shapes as case 2, only
# weight_dtype differs ("int8" instead of "int4").
nn.quant.weight_only_linear(x=Tensor.float16((1, 32, 64)), weight=Tensor.int8((0, 64)), bias=Tensor.float16((256,)), weight_scale=Tensor.float16((256,)), weight_dtype="int8", group_size=-1)

# case 4: int8, bias argument omitted, 2d input (100, 320); weight_scale length 512.
# Cases 4-11 do not pass group_size.
nn.quant.weight_only_linear(x=Tensor.float16((100, 320)), weight=Tensor.int8((0, 320)), weight_scale=Tensor.float16((512,)), weight_dtype="int8")

# case 5: int8, with bias, 2d input (100, 512); bias and weight_scale length 1024.
nn.quant.weight_only_linear(x=Tensor.float16((100, 512)), weight=Tensor.int8((0, 512)), bias=Tensor.float16((1024,)), weight_scale=Tensor.float16((1024,)), weight_dtype="int8")

# case 6: int8, bias argument omitted, 2d input (100, 512); same x and weight
# shapes as case 5, but weight_scale has length 512.
nn.quant.weight_only_linear(x=Tensor.float16((100, 512)), weight=Tensor.int8((0, 512)), weight_scale=Tensor.float16((512,)), weight_dtype="int8")

# case 7: int8, with bias, 2d input (101, 64) (small feature dim).
nn.quant.weight_only_linear(x=Tensor.float16((101, 64)), weight=Tensor.int8((0, 64)), bias=Tensor.float16((192,)), weight_scale=Tensor.float16((192,)), weight_dtype="int8")

# case 8: int8, bias=None passed explicitly, 2d input (123, 768).
nn.quant.weight_only_linear(x=Tensor.float16((123, 768)), weight=Tensor.int8((0, 768)), bias=None, weight_scale=Tensor.float16((2304,)), weight_dtype="int8")

# case 9: int8, bias=None passed explicitly, 2d input (131, 768); identical to
# case 8 except that x has 131 rows instead of 123.
nn.quant.weight_only_linear(x=Tensor.float16((131, 768)), weight=Tensor.int8((0, 768)), bias=None, weight_scale=Tensor.float16((2304,)), weight_dtype="int8")

# case 10: int8, with bias, 3d input (2, 1, 512) (batch=2).
nn.quant.weight_only_linear(x=Tensor.float16((2, 1, 512)), weight=Tensor.int8((0, 512)), bias=Tensor.float16((1024,)), weight_scale=Tensor.float16((1024,)), weight_dtype="int8")

# case 11: int8, with bias, 3d input (2, 1, 64) (batch=2, small feature dim).
nn.quant.weight_only_linear(x=Tensor.float16((2, 1, 64)), weight=Tensor.int8((0, 64)), bias=Tensor.float16((192,)), weight_scale=Tensor.float16((192,)), weight_dtype="int8")
