# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from layout import Coord, TileTensor
from layout.coord import DynamicCoord
from layout.tile_layout import Layout

from std.utils.index import IndexList


# Reshape assumes its input is contiguous. It should always be fused last, and
# a non-contiguous tensor cannot be fused *into* it as input.
@always_inline
def reshape[
    dtype: DType,
    //,
    output_rank: Int,
](
    input: TileTensor[dtype, ...],
    new_shape: IndexList[output_rank],
) -> TileTensor[
    dtype,
    Layout[
        shape_types=DynamicCoord[DType.int64, output_rank].element_types,
        stride_types=DynamicCoord[DType.int64, output_rank].element_types,
    ],
    input.origin,
    address_space=input.address_space,
]:
    var stride_tuple = type_of(new_shape)()
    var stride: Int = 1

    # Create contiguous strides.
    comptime for i in reversed(range(output_rank)):
        # Start from the back so we can accumulate the strides.
        stride_tuple[i] = stride
        stride *= new_shape[i]

    # Return a view with the new shape.
    return TileTensor(
