# flashmask_attention — GPU CUDA error 700
# Root cause: one dimension of Q/K/V is 0 (num_heads=0 / head_dim=0 / batch=0 / seq_len=0)
#
# Each case below zeroes exactly one dimension of a single input (query, key
# or value) and leaves the other arguments unchanged.
# Going by the case labels, the shape tuples are ordered
# (batch, seq_len, num_heads, head_dim).
# NOTE(review): `nn` and `Tensor` are not defined in this file; these lines look
# like one-case-per-line records for a test harness — confirm before reformatting.

# case 1: query num_heads=0
nn.functional.flashmask_attention(query=Tensor.float16((1, 2048, 0, 96)), key=Tensor.float16((1, 2048, 8, 96)), value=Tensor.float16((1, 2048, 8, 96)), startend_row_indices=Tensor.int32((1, 1, 2048, 1)), causal=True)

# case 2: key num_heads=0
nn.functional.flashmask_attention(query=Tensor.float16((1, 2048, 8, 96)), key=Tensor.float16((1, 2048, 0, 96)), value=Tensor.float16((1, 2048, 8, 96)), startend_row_indices=Tensor.int32((1, 1, 2048, 1)), causal=True)

# case 3: key head_dim=0
nn.functional.flashmask_attention(query=Tensor.float16((1, 2048, 8, 96)), key=Tensor.float16((1, 2048, 8, 0)), value=Tensor.float16((1, 2048, 8, 96)), startend_row_indices=Tensor.int32((1, 1, 2048, 1)), causal=True)

# case 4: value batch=0
nn.functional.flashmask_attention(query=Tensor.float16((1, 2048, 8, 96)), key=Tensor.float16((1, 2048, 8, 96)), value=Tensor.float16((0, 2048, 8, 96)), startend_row_indices=Tensor.int32((1, 1, 2048, 1)), causal=True)

# case 5: value seq_len=0
nn.functional.flashmask_attention(query=Tensor.float16((1, 2048, 8, 96)), key=Tensor.float16((1, 2048, 8, 96)), value=Tensor.float16((1, 0, 8, 96)), startend_row_indices=Tensor.int32((1, 1, 2048, 1)), causal=True)

# case 6: value num_heads=0
nn.functional.flashmask_attention(query=Tensor.float16((1, 2048, 8, 96)), key=Tensor.float16((1, 2048, 8, 96)), value=Tensor.float16((1, 2048, 0, 96)), startend_row_indices=Tensor.int32((1, 1, 2048, 1)), causal=True)

# case 7: value head_dim=0
nn.functional.flashmask_attention(query=Tensor.float16((1, 2048, 8, 96)), key=Tensor.float16((1, 2048, 8, 96)), value=Tensor.float16((1, 2048, 8, 0)), startend_row_indices=Tensor.int32((1, 1, 2048, 1)), causal=True)
