# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #


# The following inputs and golden values were generated using the upstream code
# [here](https://github.com/meta-llama/llama/blob/8fac8befd776bc03242fe7bc2236cdb41b6c609c/llama/model.py).
def k_cache_input[dtype: DType]() raises -> List[Scalar[dtype]]:
    # Returns a list of the 96 hard-coded values below (12 source rows of
    # 8 values each), typed as `List[Scalar[dtype]]`.
    #
    # Only the first element is spelled `Float32(...)`; every other element
    # is a bare float literal.
    # NOTE(review): presumably that first element fixes the literal's element
    # type to Float32, in which case the `rebind` is only valid when `dtype`
    # is `DType.float32` -- confirm against the callers.
    # NOTE(review): the name suggests these values feed a key-cache input in
    # a test; the intended tensor shape is not visible here -- confirm.
    #
    # The list literal is rebound to `List[Scalar[dtype]]` and the result of
    # that rebind is copied with `.copy()` before being returned. The
    # `# fmt: off` / `# fmt: on` markers bracket the rows, presumably so the
    # formatter leaves the 8-values-per-line layout alone.
    return rebind[List[Scalar[dtype]]](
        [
            # fmt: off
            Float32(-2.5095443725585938), 0.4880010485649109, 0.7845868468284607, 0.02864718623459339, 0.640755295753479, 0.5832474231719971, 1.0669267177581787, -0.4501533806324005,
            -0.18526747822761536, 0.7527588605880737, 0.4047577977180481, 0.17846599221229553, 0.2649095058441162, 1.2731683254241943, -0.0013108636485412717, -0.30360376834869385,
            -1.457029104232788, -0.10233523696660995, -0.5991530418395996, 0.4770564138889313, 0.7261772155761719, 0.09115186333656311, -0.3890652060508728, 0.5279164910316467,
            -0.012685478664934635, 0.24083632230758667, 0.13253536820411682, 0.7642406225204468, 1.095009684562683, 0.3398909568786621, 0.7199674844741821, 0.41140761971473694,
            1.931160569190979, 1.0118638277053833, -1.4364064931869507, -1.1298598051071167, -0.1360345333814621, 1.6354095935821533, 0.6547407507896423, 0.5760045647621155,
            1.1415079832077026, 0.018564576283097267, -1.8058050870895386, 0.9254348874092102, -0.3753443658351898, 1.0330873727798462, -0.6866509318351746, 0.6368136405944824,
            -0.9726738929748535, 0.9584577679634094, 1.6192004680633545, 1.450609803199768, 0.2694815397262573, -0.21037597954273224, -0.7328027486801147, 0.10429783165454865,
            0.3487516939640045, 0.9675941467285156, -0.46568843722343445, 1.6047972440719604, -2.4801201820373535, -0.4175437390804291, -1.1954537630081177, 0.8123369216918945,
            -1.9005532264709473, 0.22857652604579926, 0.02485940419137478, -0.34595024585723877, 0.2868328094482422, -0.7308424115180969, 0.17482025921344757, -1.0939292907714844,
            -1.6021603345870972, 1.3528969287872314, 1.288827657699585, 0.05229547247290611, -1.5468504428863525, 0.7567060589790344, 0.7755194902420044, 2.0265355110168457,
            0.03581761196255684, 0.12058872729539871, -0.8056637048721313, -0.20757682621479034, -0.9319478273391724, -1.5909662246704102, -1.13597571849823, -0.52259761095047,
            -0.5187733173370361, -1.5012763738632202, -1.9266542196273804, 0.1278512328863144, 1.0229133367538452, -0.5557951331138611, 0.7042727470397949, 0.7098760008811951,
            # fmt: on
        ]
    ).copy()


def q_input[dtype: DType]() raises -> List[Scalar[dtype]]:
    return rebind[List[Scalar[dtype]]](
        [
            # fmt: off
            Float32(1.9269152879714966), 1.4872840642929077, 0.9007171988487244, -2.1055209636688232, 0.6784184575080872, -1.2345448732376099, -0.04306747764348984, -1.6046669483184814,
            -0.7521352767944336, 1.6487230062484741, -0.3924786448478699, -1.4036071300506592, -0.7278813123703003, -0.5594301819801331, -0.7688388824462891, 0.7624453902244568,
            1.6423169374465942, -0.1595974713563919, -0.4973975419998169, 0.439589262008667, -0.7581311464309692, 1.078317642211914, 0.8008005619049072, 1.680620551109314,
            1.27912437915802, 1.2964228391647339, 0.610466480255127, 1.334737777709961, -0.2316243201494217, 0.041759490966796875, -0.2515752911567688, 0.859858512878418,
            -1.3846737146377563, -0.8712361454963684, -0.223365917801857, 1.7173614501953125, 0.3188803195953369, -0.42451897263526917, 0.3057209253311157, -0.7745925188064575,
            -1.5575724840164185, 0.9956361055374146, -0.8797858357429504, -0.6011420488357544, -1.2741512060165405, 2.1227850914001465, -1.234653115272522, -0.4879138767719269,
            -0.9138230085372925, -0.6581372618675232, 0.07802387326955795, 0.5258087515830994, -0.48799172043800354, 1.1913690567016602, -0.8140076398849487, -0.7359927892684937,
            -1.4032478332519531, 0.03600366786122322, -0.06347727030515671, 0.6756148934364319, -0.0978068932890892, 1.8445940017700195, -1.184537410736084, 1.3835493326187134,
            1.4451338052749634, 0.8564125299453735, 2.218075752258301, 0.5231655240058899, 0.34664666652679443, -0.19733144342899323, -1.0545889139175415, 1.2779955863952637,
