# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""MLA FP8 index kernel for computing attention scores with paged KV cache."""

from std.sys import size_of
from std.math import ceildiv

from layout import (
    Idx,
    TensorLayout,
    TileTensor,
    row_major,
)

from std.gpu import block_idx, thread_idx
from std.gpu.host import DeviceContext, FuncAttribute

from kv_cache.types import KVCollectionT

from nn.index_fp8 import fp8_index_kernel, IndexSmemStorage
from nn.attention.mha_mask import MHAMask, MaskName
from nn.attention.mha_operand import KVCacheMHAOperand, KVCacheScalesMHAOperand
from nn.attention.mha_utils import dispatch_mask
from nn.topk import topk_gpu

from std.utils.index import Index


# ===----------------------------------------------------------------------=== #
# Mask application kernel
# ===----------------------------------------------------------------------=== #


@__name(t"mla_apply_mask", mangle=True)
def apply_mask_kernel[
    mask_t: MHAMask,
    ScoresLayoutType: TensorLayout,
    scores_origin: MutOrigin,
    VLLayoutType: TensorLayout,
    vl_origin: ImmutOrigin,
