# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

"""Variable-length selective scan kernels for Mamba SSM architecture."""

from std.gpu import (
    block_dim,
    block_idx,
    thread_idx,
)
from layout import TensorLayout, TileTensor
from std.utils.index import IndexList
from std.algorithm import sync_parallelize
from std.gpu.host import DeviceContext
import std.math
from std.math import exp2
from state_space.causal_conv1d import silu
from state_space.selective_scan import softplus

# LOG2E is log2(e). Because exp(x) == exp2(x * LOG2E), scaling an exponent by
# this constant lets exp be evaluated through exp2 (faster on GPU).
comptime LOG2E = 1.4426950408889634
# NOTE(review): presumably the upper bound on the DSTATE kernel parameter; its
# use is not visible in this part of the file - confirm against the kernels.
comptime MAX_DSTATE = 256  # Larger for Mamba-2 models

# Stride types for passing tensor strides to kernels.
# Each alias is an IndexList with a fixed number of entries (1 through 4);
# presumably one stride per tensor dimension - confirm against the kernel
# signatures that take these types.
comptime Strides1D = IndexList[1]
comptime Strides2D = IndexList[2]
comptime Strides3D = IndexList[3]
comptime Strides4D = IndexList[4]


def varlen_selective_state_update_gpu[
    kernel_dtype: DType,
    DSTATE: Int,
    state_LT: TensorLayout,
    x_LT: TensorLayout,
    dt_LT: TensorLayout,
    A_LT: TensorLayout,
    B_LT: TensorLayout,
    C_LT: TensorLayout,
    D_LT: TensorLayout,
