production_forward vs paper_forward output: mean_abs=0.0015909688081592321, max_abs=0.037109375
production_forward grad[0] vs paper_forward: mean_abs=0.008148351684212685, max_abs=0.296875, mean_rel=0.07191670686006546, max_rel=82.24017333984375, norm_rel=0.019787410274147987, ref_abs_avg=0.4483058452606201, test_abs_avg=0.44830867648124695
production_forward grad[1] vs paper_forward: mean_abs=3.5689146518707275, max_abs=28.0, mean_rel=0.2113567590713501, max_rel=853.3299560546875, norm_rel=0.020402884110808372, ref_abs_avg=154.62057495117188, test_abs_avg=154.6934051513672
production_forward grad[2] vs paper_forward: mean_abs=0.5898212790489197, max_abs=2.25, mean_rel=0.1548423171043396, max_rel=15.33864974975586, norm_rel=0.02379394881427288, ref_abs_avg=25.02594757080078, test_abs_avg=25.029855728149414
production_forward grad[3] vs paper_forward: mean_abs=0.7476010322570801, max_abs=5.5, mean_rel=0.4033028483390808, max_rel=2624.999755859375, norm_rel=0.023227201774716377, ref_abs_avg=32.39134216308594, test_abs_avg=32.388877868652344
production_forward grad[4] vs paper_forward: mean_abs=0.728201150894165, max_abs=4.5, mean_rel=0.2097283899784088, max_rel=1874.9998779296875, norm_rel=0.022768892347812653, ref_abs_avg=32.15387725830078, test_abs_avg=32.15330505371094
production_forward grad[5] vs paper_forward: mean_abs=0.5406272411346436, max_abs=2.43359375, mean_rel=0.2477346658706665, max_rel=48.58466720581055, norm_rel=0.023470386862754822, ref_abs_avg=23.264392852783203, test_abs_avg=23.244359970092773
production_forward grad[6] vs paper_forward: mean_abs=0.6510639190673828, max_abs=4.125, mean_rel=0.3615352511405945, max_rel=2624.999755859375, norm_rel=0.022956449538469315, ref_abs_avg=28.515300750732422, test_abs_avg=28.51598358154297
production_forward grad[7] vs paper_forward: mean_abs=0.6385107040405273, max_abs=3.875, mean_rel=0.21351633965969086, max_rel=2140.625, norm_rel=0.022718623280525208, ref_abs_avg=28.224096298217773, test_abs_avg=28.222366333007812
production_forward grad[8] vs paper_forward: mean_abs=0.45709002017974854, max_abs=2.060546875, mean_rel=0.13297662138938904, max_rel=20.48066520690918, norm_rel=0.02224912866950035, ref_abs_avg=20.56022834777832, test_abs_avg=20.579389572143555
production_forward grad[9] vs paper_forward: mean_abs=0.5902024507522583, max_abs=4.0, mean_rel=0.3360888957977295, max_rel=2375.0, norm_rel=0.02288617379963398, ref_abs_avg=25.93301010131836, test_abs_avg=25.933307647705078
production_forward grad[10] vs paper_forward: mean_abs=0.57527756690979, max_abs=3.46875, mean_rel=0.2056526392698288, max_rel=1656.2498779296875, norm_rel=0.022430511191487312, ref_abs_avg=25.801651000976562, test_abs_avg=25.799903869628906
production_forward grad[11] vs paper_forward: mean_abs=0.44624602794647217, max_abs=1.75, mean_rel=0.1299598515033722, max_rel=20.287446975708008, norm_rel=0.02215731143951416, ref_abs_avg=19.992389678955078, test_abs_avg=20.032703399658203
production_forward grad[12] vs paper_forward: mean_abs=0.5405467748641968, max_abs=3.25, mean_rel=0.2874146103858948, max_rel=1874.9998779296875, norm_rel=0.022545846179127693, ref_abs_avg=24.124713897705078, test_abs_avg=24.124874114990234
production_forward grad[13] vs paper_forward: mean_abs=0.5283933877944946, max_abs=3.125, mean_rel=0.19668562710285187, max_rel=1976.5623779296875, norm_rel=0.022341901436448097, ref_abs_avg=23.791858673095703, test_abs_avg=23.790372848510742
production_forward grad[14] vs paper_forward: mean_abs=0.4023003578186035, max_abs=1.75, mean_rel=0.08061817288398743, max_rel=4.860461711883545, norm_rel=0.020989591255784035, ref_abs_avg=19.16363525390625, test_abs_avg=19.15840721130371
production_forward grad[15] vs paper_forward: mean_abs=0.5082268118858337, max_abs=3.1875, mean_rel=0.3006656765937805, max_rel=2062.5, norm_rel=0.02246699668467045, ref_abs_avg=22.763225555419922, test_abs_avg=22.764339447021484
production_forward grad[16] vs paper_forward: mean_abs=0.49460893869400024, max_abs=3.25, mean_rel=0.21930190920829773, max_rel=2781.249755859375, norm_rel=0.02221773937344551, ref_abs_avg=22.398113250732422, test_abs_avg=22.397003173828125
production_forward grad[17] vs paper_forward: mean_abs=0.39930903911590576, max_abs=1.375, mean_rel=0.20159776508808136, max_rel=35.49483871459961, norm_rel=0.02353324368596077, ref_abs_avg=16.994983673095703, test_abs_avg=16.991619110107422
production_forward grad[18] vs paper_forward: mean_abs=0.4802919626235962, max_abs=3.0, mean_rel=0.2569091320037842, max_rel=1687.4998779296875, norm_rel=0.02232186868786812, ref_abs_avg=21.628864288330078, test_abs_avg=21.630016326904297
production_forward grad[19] vs paper_forward: mean_abs=0.46895653009414673, max_abs=3.0, mean_rel=0.20424258708953857, max_rel=1632.8123779296875, norm_rel=0.022070012986660004, ref_abs_avg=21.351871490478516, test_abs_avg=21.349647521972656
production_forward grad[20] vs paper_forward: mean_abs=0.35968124866485596, max_abs=1.625, mean_rel=0.20471055805683136, max_rel=67.39720916748047, norm_rel=0.022027965635061264, ref_abs_avg=16.492218017578125, test_abs_avg=16.514856338500977
production_forward grad[21] vs paper_forward: mean_abs=0.4542326331138611, max_abs=2.9375, mean_rel=0.2606085538864136, max_rel=1687.4998779296875, norm_rel=0.02222181297838688, ref_abs_avg=20.542743682861328, test_abs_avg=20.541393280029297
production_forward grad[22] vs paper_forward: mean_abs=0.4448806047439575, max_abs=2.75, mean_rel=0.1723000854253769, max_rel=1203.125, norm_rel=0.02195911295711994, ref_abs_avg=20.37023162841797, test_abs_avg=20.3709659576416
production_forward grad[23] vs paper_forward: mean_abs=0.3711378276348114, max_abs=1.5, mean_rel=0.1287778913974762, max_rel=21.703516006469727, norm_rel=0.02277565747499466, ref_abs_avg=16.08954620361328, test_abs_avg=16.035297393798828
production_forward grad[24] vs paper_forward: mean_abs=0.43281465768814087, max_abs=2.75, mean_rel=0.27190035581588745, max_rel=1906.2498779296875, norm_rel=0.022114479914307594, ref_abs_avg=19.670291900634766, test_abs_avg=19.67083168029785
production_forward grad[25] vs paper_forward: mean_abs=0.42111504077911377, max_abs=2.5, mean_rel=0.1787712275981903, max_rel=1499.9998779296875, norm_rel=0.02184338867664337, ref_abs_avg=19.37619400024414, test_abs_avg=19.374610900878906
production_forward grad[26] vs paper_forward: mean_abs=0.41144195199012756, max_abs=1.484375, mean_rel=0.16254611313343048, max_rel=19.959196090698242, norm_rel=0.024621307849884033, ref_abs_avg=16.925765991210938, test_abs_avg=16.91206169128418
production_forward grad[27] vs paper_forward: mean_abs=0.4981827437877655, max_abs=3.3125, mean_rel=0.30906975269317627, max_rel=3312.499755859375, norm_rel=0.024059277027845383, ref_abs_avg=20.825206756591797, test_abs_avg=20.825448989868164
production_forward grad[28] vs paper_forward: mean_abs=0.4854714870452881, max_abs=3.1875, mean_rel=0.1910092532634735, max_rel=1257.8125, norm_rel=0.02363063022494316, ref_abs_avg=20.63259506225586, test_abs_avg=20.63321304321289
production_forward grad[29] vs paper_forward: mean_abs=0.36438316106796265, max_abs=1.3125, mean_rel=0.1401636004447937, max_rel=8.721378326416016, norm_rel=0.024020345881581306, ref_abs_avg=15.474939346313477, test_abs_avg=15.494500160217285
production_forward grad[30] vs paper_forward: mean_abs=0.46023738384246826, max_abs=3.0, mean_rel=0.29477471113204956, max_rel=1499.9998779296875, norm_rel=0.024430107325315475, ref_abs_avg=18.909706115722656, test_abs_avg=18.90925407409668
production_forward grad[31] vs paper_forward: mean_abs=0.451096773147583, max_abs=3.0, mean_rel=0.2217235416173935, max_rel=1398.4373779296875, norm_rel=0.02409554459154606, ref_abs_avg=18.800636291503906, test_abs_avg=18.79907989501953
production_forward grad[32] vs paper_forward: mean_abs=0.3668745458126068, max_abs=1.5, mean_rel=0.12749116122722626, max_rel=13.808879852294922, norm_rel=0.027472764253616333, ref_abs_avg=13.538631439208984, test_abs_avg=13.511730194091797
production_forward grad[33] vs paper_forward: mean_abs=0.4240168035030365, max_abs=2.75, mean_rel=0.27592530846595764, max_rel=1796.8748779296875, norm_rel=0.024077044799923897, ref_abs_avg=17.677690505981445, test_abs_avg=17.67664337158203
production_forward grad[34] vs paper_forward: mean_abs=0.4143846035003662, max_abs=2.75, mean_rel=0.16893932223320007, max_rel=765.6249389648438, norm_rel=0.0238331388682127, ref_abs_avg=17.446697235107422, test_abs_avg=17.448902130126953
production_forward grad[35] vs paper_forward: mean_abs=0.3267608880996704, max_abs=1.3125, mean_rel=0.18473778665065765, max_rel=22.529098510742188, norm_rel=0.02419128082692623, ref_abs_avg=13.292194366455078, test_abs_avg=13.322484016418457
production_forward grad[36] vs paper_forward: mean_abs=0.3971181809902191, max_abs=2.5, mean_rel=0.24194394052028656, max_rel=1437.4998779296875, norm_rel=0.023917190730571747, ref_abs_avg=16.651264190673828, test_abs_avg=16.650447845458984
production_forward grad[37] vs paper_forward: mean_abs=0.39044612646102905, max_abs=2.6640625, mean_rel=0.1921447515487671, max_rel=1250.0, norm_rel=0.02374613657593727, ref_abs_avg=16.4951114654541, test_abs_avg=16.493083953857422
production_forward grad[38] vs paper_forward: mean_abs=0.31398487091064453, max_abs=1.234375, mean_rel=0.12364263087511063, max_rel=13.9302978515625, norm_rel=0.024217911064624786, ref_abs_avg=13.183891296386719, test_abs_avg=13.150862693786621
production_forward grad[39] vs paper_forward: mean_abs=0.37832361459732056, max_abs=2.671875, mean_rel=0.2620084285736084, max_rel=1281.25, norm_rel=0.023654552176594734, ref_abs_avg=16.048559188842773, test_abs_avg=16.04684829711914
production_forward grad[40] vs paper_forward: mean_abs=0.3692372143268585, max_abs=2.75, mean_rel=0.19061708450317383, max_rel=1046.875, norm_rel=0.023714719340205193, ref_abs_avg=15.611117362976074, test_abs_avg=15.61070442199707
production_forward grad[41] vs paper_forward: mean_abs=0.29868173599243164, max_abs=1.0625, mean_rel=0.11060123145580292, max_rel=12.736675262451172, norm_rel=0.023662660270929337, ref_abs_avg=12.5090913772583, test_abs_avg=12.503185272216797
production_forward grad[42] vs paper_forward: mean_abs=0.35728245973587036, max_abs=2.25, mean_rel=0.24720554053783417, max_rel=1125.0, norm_rel=0.023499753326177597, ref_abs_avg=15.252750396728516, test_abs_avg=15.251564025878906
production_forward grad[43] vs paper_forward: mean_abs=0.3496260643005371, max_abs=2.125, mean_rel=0.19490176439285278, max_rel=957.0311889648438, norm_rel=0.023254912346601486, ref_abs_avg=15.073389053344727, test_abs_avg=15.072786331176758
production_forward grad[44] vs paper_forward: mean_abs=0.27404510974884033, max_abs=1.25, mean_rel=0.18929952383041382, max_rel=19.225034713745117, norm_rel=0.023769523948431015, ref_abs_avg=11.71425724029541, test_abs_avg=11.731437683105469
production_forward grad[45] vs paper_forward: mean_abs=0.3381339907646179, max_abs=2.1875, mean_rel=0.23270508646965027, max_rel=1312.4998779296875, norm_rel=0.023293444886803627, ref_abs_avg=14.573936462402344, test_abs_avg=14.573966026306152
production_forward grad[46] vs paper_forward: mean_abs=0.33158472180366516, max_abs=2.125, mean_rel=0.15034709870815277, max_rel=718.7499389648438, norm_rel=0.022817373275756836, ref_abs_avg=14.576032638549805, test_abs_avg=14.575080871582031
production_forward grad[47] vs paper_forward: mean_abs=0.26073092222213745, max_abs=1.125, mean_rel=0.10674865543842316, max_rel=4.84916877746582, norm_rel=0.022230908274650574, ref_abs_avg=11.541099548339844, test_abs_avg=11.527880668640137
production_forward grad[48] vs paper_forward: mean_abs=0.3244280517101288, max_abs=2.0625, mean_rel=0.21469107270240784, max_rel=1015.6249389648438, norm_rel=0.02299935556948185, ref_abs_avg=14.142468452453613, test_abs_avg=14.142011642456055
production_forward grad[49] vs paper_forward: mean_abs=0.3178289830684662, max_abs=2.0, mean_rel=0.16265767812728882, max_rel=1078.125, norm_rel=0.02282097190618515, ref_abs_avg=13.964330673217773, test_abs_avg=13.964764595031738
production_forward grad[50] vs paper_forward: mean_abs=0.28514564037323, max_abs=1.1953125, mean_rel=0.09831731766462326, max_rel=6.065698146820068, norm_rel=0.023774882778525352, ref_abs_avg=12.328948974609375, test_abs_avg=12.359790802001953
production_forward grad[51] vs paper_forward: mean_abs=0.35961735248565674, max_abs=2.25, mean_rel=0.25456392765045166, max_rel=1437.4998779296875, norm_rel=0.02450413443148136, ref_abs_avg=14.723672866821289, test_abs_avg=14.723365783691406
production_forward grad[52] vs paper_forward: mean_abs=0.3544885516166687, max_abs=2.375, mean_rel=0.19723734259605408, max_rel=1140.625, norm_rel=0.02421320416033268, ref_abs_avg=14.683416366577148, test_abs_avg=14.684968948364258
production_forward grad[53] vs paper_forward: mean_abs=0.2677767276763916, max_abs=1.25, mean_rel=0.0915558710694313, max_rel=12.10095500946045, norm_rel=0.022357672452926636, ref_abs_avg=12.4222412109375, test_abs_avg=12.410839080810547
production_forward grad[54] vs paper_forward: mean_abs=0.33882784843444824, max_abs=2.28125, mean_rel=0.22980639338493347, max_rel=1218.75, norm_rel=0.02430274337530136, ref_abs_avg=13.97381591796875, test_abs_avg=13.973095893859863
production_forward grad[55] vs paper_forward: mean_abs=0.3326941728591919, max_abs=2.375, mean_rel=0.17475172877311707, max_rel=992.1874389648438, norm_rel=0.024161798879504204, ref_abs_avg=13.775853157043457, test_abs_avg=13.774063110351562
production_forward grad[56] vs paper_forward: mean_abs=0.2715741991996765, max_abs=1.125, mean_rel=0.35244113206863403, max_rel=114.36035919189453, norm_rel=0.02571619302034378, ref_abs_avg=10.35863971710205, test_abs_avg=10.358856201171875
production_forward grad[57] vs paper_forward: mean_abs=0.3115929067134857, max_abs=2.25, mean_rel=0.2098907083272934, max_rel=1398.4373779296875, norm_rel=0.024011926725506783, ref_abs_avg=13.018994331359863, test_abs_avg=13.01828670501709
production_forward grad[58] vs paper_forward: mean_abs=0.3083970546722412, max_abs=2.25, mean_rel=0.16161879897117615, max_rel=1312.4998779296875, norm_rel=0.023684460669755936, ref_abs_avg=13.048135757446289, test_abs_avg=13.047735214233398
production_forward grad[59] vs paper_forward: mean_abs=0.23486453294754028, max_abs=1.0, mean_rel=0.2652820348739624, max_rel=79.45822143554688, norm_rel=0.02347489632666111, ref_abs_avg=10.244096755981445, test_abs_avg=10.239326477050781
production_forward grad[60] vs paper_forward: mean_abs=0.2926396131515503, max_abs=2.0, mean_rel=0.22195960581302643, max_rel=1312.4998779296875, norm_rel=0.02358296699821949, ref_abs_avg=12.435601234436035, test_abs_avg=12.434669494628906
production_forward grad[61] vs paper_forward: mean_abs=0.28892982006073, max_abs=1.875, mean_rel=0.17660272121429443, max_rel=1156.25, norm_rel=0.023625481873750687, ref_abs_avg=12.24772834777832, test_abs_avg=12.247979164123535
production_forward grad[62] vs paper_forward: mean_abs=0.233007550239563, max_abs=1.03125, mean_rel=0.07134227454662323, max_rel=5.401581287384033, norm_rel=0.02202901430428028, ref_abs_avg=10.811806678771973, test_abs_avg=10.81933307647705
production_forward grad[63] vs paper_forward: mean_abs=0.27649229764938354, max_abs=1.9375, mean_rel=0.21526473760604858, max_rel=1250.0, norm_rel=0.023007238283753395, ref_abs_avg=12.02518367767334, test_abs_avg=12.024755477905273
production_forward grad[64] vs paper_forward: mean_abs=0.27042174339294434, max_abs=2.5, mean_rel=0.17606981098651886, max_rel=843.7499389648438, norm_rel=0.022845257073640823, ref_abs_avg=11.863580703735352, test_abs_avg=11.860772132873535
production_forward grad[65] vs paper_forward: mean_abs=0.20214855670928955, max_abs=0.84375, mean_rel=0.10889343917369843, max_rel=15.399094581604004, norm_rel=0.02115677483379841, ref_abs_avg=9.521017074584961, test_abs_avg=9.51398754119873
production_forward grad[66] vs paper_forward: mean_abs=0.2624596953392029, max_abs=1.8125, mean_rel=0.21145588159561157, max_rel=843.7499389648438, norm_rel=0.022746438160538673, ref_abs_avg=11.54885482788086, test_abs_avg=11.548436164855957
production_forward grad[67] vs paper_forward: mean_abs=0.25806155800819397, max_abs=2.0, mean_rel=0.1643695831298828, max_rel=749.9999389648438, norm_rel=0.022284872829914093, ref_abs_avg=11.56436538696289, test_abs_avg=11.564817428588867
production_forward grad[68] vs paper_forward: mean_abs=0.20068371295928955, max_abs=1.0, mean_rel=0.2432432472705841, max_rel=47.52890396118164, norm_rel=0.021562591195106506, ref_abs_avg=9.237298965454102, test_abs_avg=9.240493774414062
production_forward grad[69] vs paper_forward: mean_abs=0.250171422958374, max_abs=1.75, mean_rel=0.1942710131406784, max_rel=1062.5, norm_rel=0.022146165370941162, ref_abs_avg=11.291823387145996, test_abs_avg=11.291874885559082
production_forward grad[70] vs paper_forward: mean_abs=0.24436752498149872, max_abs=1.8125, mean_rel=0.15870419144630432, max_rel=699.2186889648438, norm_rel=0.021884413436055183, ref_abs_avg=11.157407760620117, test_abs_avg=11.160604476928711
production_forward grad[71] vs paper_forward: mean_abs=0.1847977638244629, max_abs=0.875, mean_rel=0.0690447986125946, max_rel=4.7681450843811035, norm_rel=0.02127794362604618, ref_abs_avg=8.999303817749023, test_abs_avg=8.986897468566895
production_forward grad[72] vs paper_forward: mean_abs=0.23882758617401123, max_abs=1.9375, mean_rel=0.18314872682094574, max_rel=968.7499389648438, norm_rel=0.021820783615112305, ref_abs_avg=10.954387664794922, test_abs_avg=10.953927993774414
production_forward grad[73] vs paper_forward: mean_abs=0.23161327838897705, max_abs=1.875, mean_rel=0.1582535207271576, max_rel=734.3749389648438, norm_rel=0.02190561406314373, ref_abs_avg=10.604012489318848, test_abs_avg=10.604293823242188
production_forward grad[74] vs paper_forward: mean_abs=0.22736108303070068, max_abs=0.9375, mean_rel=0.2462705671787262, max_rel=61.3666877746582, norm_rel=0.02324555814266205, ref_abs_avg=9.898711204528809, test_abs_avg=9.888178825378418
production_forward grad[75] vs paper_forward: mean_abs=0.2653729319572449, max_abs=2.125, mean_rel=0.2266218662261963, max_rel=1398.4373779296875, norm_rel=0.02279689721763134, ref_abs_avg=11.635418891906738, test_abs_avg=11.635598182678223
production_forward grad[76] vs paper_forward: mean_abs=0.2591458559036255, max_abs=2.125, mean_rel=0.15401917695999146, max_rel=1171.875, norm_rel=0.022790784016251564, ref_abs_avg=11.367654800415039, test_abs_avg=11.366369247436523
production_forward grad[77] vs paper_forward: mean_abs=0.20824408531188965, max_abs=0.6875, mean_rel=0.08602194488048553, max_rel=3.1659905910491943, norm_rel=0.022192703559994698, ref_abs_avg=9.294315338134766, test_abs_avg=9.290700912475586
production_forward grad[78] vs paper_forward: mean_abs=0.2442096471786499, max_abs=2.0, mean_rel=0.18323254585266113, max_rel=851.5624389648438, norm_rel=0.02264542691409588, ref_abs_avg=10.799137115478516, test_abs_avg=10.79883098602295
production_forward grad[79] vs paper_forward: mean_abs=0.2427668273448944, max_abs=2.0, mean_rel=0.16851770877838135, max_rel=773.4374389648438, norm_rel=0.022554604336619377, ref_abs_avg=10.784515380859375, test_abs_avg=10.784950256347656
production_forward grad[80] vs paper_forward: mean_abs=0.20513224601745605, max_abs=0.75, mean_rel=0.07180675864219666, max_rel=2.635840892791748, norm_rel=0.022752469405531883, ref_abs_avg=9.140271186828613, test_abs_avg=9.147586822509766
production_forward grad[81] vs paper_forward: mean_abs=0.23243765532970428, max_abs=2.25, mean_rel=0.18436582386493683, max_rel=874.9999389648438, norm_rel=0.022143373265862465, ref_abs_avg=10.504345893859863, test_abs_avg=10.50464153289795
production_forward grad[82] vs paper_forward: mean_abs=0.22689133882522583, max_abs=1.75, mean_rel=0.1542598009109497, max_rel=843.7499389648438, norm_rel=0.022128058597445488, ref_abs_avg=10.30517578125, test_abs_avg=10.30815601348877
production_forward grad[83] vs paper_forward: mean_abs=0.18149924278259277, max_abs=0.84375, mean_rel=0.10535785555839539, max_rel=6.275761127471924, norm_rel=0.02194127067923546, ref_abs_avg=8.273307800292969, test_abs_avg=8.269660949707031
production_forward grad[84] vs paper_forward: mean_abs=0.21010860800743103, max_abs=1.75, mean_rel=0.17496149241924286, max_rel=812.4999389648438, norm_rel=0.021580884233117104, ref_abs_avg=9.778024673461914, test_abs_avg=9.778023719787598
production_forward grad[85] vs paper_forward: mean_abs=0.20928502082824707, max_abs=1.9375, mean_rel=0.13949403166770935, max_rel=664.0624389648438, norm_rel=0.021028006449341774, ref_abs_avg=9.985664367675781, test_abs_avg=9.981639862060547
production_forward grad[86] vs paper_forward: mean_abs=0.1724340319633484, max_abs=0.78125, mean_rel=0.07224014401435852, max_rel=2.522777795791626, norm_rel=0.020845942199230194, ref_abs_avg=8.282792091369629, test_abs_avg=8.273399353027344
production_forward grad[87] vs paper_forward: mean_abs=0.20402896404266357, max_abs=1.96875, mean_rel=0.17440493404865265, max_rel=749.9999389648438, norm_rel=0.02095373347401619, ref_abs_avg=9.784822463989258, test_abs_avg=9.784955978393555
production_forward grad[88] vs paper_forward: mean_abs=0.19388791918754578, max_abs=1.625, mean_rel=0.14899751543998718, max_rel=773.4374389648438, norm_rel=0.020516443997621536, ref_abs_avg=9.48837947845459, test_abs_avg=9.48553466796875
production_forward grad[89] vs paper_forward: mean_abs=0.15633082389831543, max_abs=0.63671875, mean_rel=0.07911847531795502, max_rel=7.269948482513428, norm_rel=0.021048393100500107, ref_abs_avg=7.524419784545898, test_abs_avg=7.534184455871582
production_forward grad[90] vs paper_forward: mean_abs=0.1895187795162201, max_abs=2.0, mean_rel=0.16729584336280823, max_rel=625.0, norm_rel=0.020464692264795303, ref_abs_avg=9.305545806884766, test_abs_avg=9.306333541870117
production_forward grad[91] vs paper_forward: mean_abs=0.18897303938865662, max_abs=2.3125, mean_rel=0.13934631645679474, max_rel=1046.875, norm_rel=0.020806219428777695, ref_abs_avg=9.18683910369873, test_abs_avg=9.185504913330078
production_forward grad[92] vs paper_forward: mean_abs=0.1520010530948639, max_abs=0.53125, mean_rel=0.718239426612854, max_rel=273.5296630859375, norm_rel=0.019237320870161057, ref_abs_avg=7.884075164794922, test_abs_avg=7.8904829025268555
production_forward grad[93] vs paper_forward: mean_abs=0.18095272779464722, max_abs=2.0, mean_rel=0.1590661108493805, max_rel=781.2499389648438, norm_rel=0.019939236342906952, ref_abs_avg=9.168222427368164, test_abs_avg=9.16820240020752
production_forward grad[94] vs paper_forward: mean_abs=0.17392084002494812, max_abs=2.0625, mean_rel=0.13798888027668, max_rel=648.4374389648438, norm_rel=0.02014978788793087, ref_abs_avg=8.767361640930176, test_abs_avg=8.768778800964355
production_forward grad[95] vs paper_forward: mean_abs=0.143407940864563, max_abs=0.53125, mean_rel=0.09281758964061737, max_rel=10.561132431030273, norm_rel=0.019541168585419655, ref_abs_avg=7.295104026794434, test_abs_avg=7.299098968505859
production_forward grad[96] vs paper_forward: mean_abs=0.17028909921646118, max_abs=1.6875, mean_rel=0.15974709391593933, max_rel=937.4999389648438, norm_rel=0.019331052899360657, ref_abs_avg=8.941450119018555, test_abs_avg=8.942192077636719
production_forward grad[97] vs paper_forward: mean_abs=0.160007506608963, max_abs=1.875, mean_rel=0.12342413514852524, max_rel=507.8124694824219, norm_rel=0.018361635506153107, ref_abs_avg=8.836846351623535, test_abs_avg=8.830492973327637

import math
import random
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl

DEVICE = "cuda"
DTYPE = torch.bfloat16

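# L layers are grouped into blocks of BLOCK_SIZE queries; block 0 holds the
# raw inputs, so NUM_BLOCKS reserves one extra slot for it.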
L = 32
BLOCK_SIZE = 8
NUM_BLOCKS = math.ceil(L / BLOCK_SIZE) + 1

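# Batch size, sequence length, hidden dim; each Triton program handles one of
# the BT = B * T (batch, position) pairs.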
B, T, D = 32, 1024, 512
BT = B * T

EPS = torch.finfo(torch.float32).eps

autotune_configs = [
    triton.Config({}, num_warps=num_warps, num_stages=num_stages)
    for num_warps in [1, 2, 4, 8, 16]
    for num_stages in [1, 2, 3, 4]
]


# @triton.autotune(
#     configs=autotune_configs,
#     key=["NUM_SOURCE_BLOCKS", "HIDDEN_DIM", "NUM_QUERIES_PER_BLOCK", "PADDED_SRC"],
# )
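# Phase 1 (inter-block attention): one program per (batch, position) pair.
# For every pseudo-query of the current block, attend over all completed
# source blocks. Keys are the RMS-normalized block values, so each logit is
# (q . v) * rsqrt(mean(v^2) + eps); values are the raw block values. The
# kernel writes the softmax-normalized output plus its log-sum-exp, which
# phase 2 needs for the online-softmax merge.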
@triton.jit
def phase_1_batched_interblock_attention_kernel(
    block_representations_ptr,
    pseudo_queries_ptr,
    softmax_normalized_output_ptr,
    lse_ptr,
    eps,
    NUM_SOURCE_BLOCKS: tl.constexpr,
    BT: tl.constexpr,
    HIDDEN_DIM: tl.constexpr,
    NUM_QUERIES_PER_BLOCK: tl.constexpr,
    PADDED_SRC: tl.constexpr,
):
    batch_seq_idx = tl.program_id(0)

    source_block_range = tl.arange(0, PADDED_SRC)[:, None]
    hidden_dim_range = tl.arange(0, HIDDEN_DIM)[None, :]
    valid_block_mask_2d = source_block_range < NUM_SOURCE_BLOCKS

    valid_block_mask_1d = tl.arange(0, PADDED_SRC) < NUM_SOURCE_BLOCKS

    source_block_values = tl.load(
        block_representations_ptr
        + source_block_range * (BT * HIDDEN_DIM)
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range,
        mask=valid_block_mask_2d,
        other=0.0,
    ).to(tl.float32)

    squared_norm_sum = tl.sum(source_block_values * source_block_values, axis=1)
    inverse_rms_norm = tl.rsqrt(squared_norm_sum / float(HIDDEN_DIM) + eps)

    hidden_dim_range_1d = tl.arange(0, HIDDEN_DIM)

    for layer_offset in tl.static_range(NUM_QUERIES_PER_BLOCK):
        pseudo_query_vector = tl.load(
            pseudo_queries_ptr + layer_offset * HIDDEN_DIM + hidden_dim_range,
            eviction_policy="evict_last",
        ).to(tl.float32)

        attention_logits = (
            tl.sum(source_block_values * pseudo_query_vector, axis=1) * inverse_rms_norm
        )
        attention_logits = tl.where(
            valid_block_mask_1d, attention_logits, float("-inf")
        )

        max_attention_logit = tl.max(attention_logits)
        exp_attention_logits = tl.exp(attention_logits - max_attention_logit)
        exp_sum = tl.sum(exp_attention_logits)

        unnormalized_output = tl.sum(
            exp_attention_logits[:, None] * source_block_values, axis=0
        )
        normalized_output = (unnormalized_output / exp_sum).to(tl.bfloat16)

        tl.store(
            softmax_normalized_output_ptr
            + layer_offset * BT * HIDDEN_DIM
            + batch_seq_idx * HIDDEN_DIM
            + hidden_dim_range_1d,
            normalized_output,
        )
        tl.store(
            lse_ptr + layer_offset * BT + batch_seq_idx,
            max_attention_logit + tl.log(exp_sum),
        )


def phase_1_batched_interblock_attention(
    block_representations,
    pseudo_queries,
    softmax_outputs,
    lses,
    eps=None,
):
    NUM_QUERIES = pseudo_queries.shape[0]
    NUM_SOURCE_BLOCKS = block_representations.shape[0]

    if eps is None:
        eps = EPS

    phase_1_batched_interblock_attention_kernel[(BT,)](
        block_representations,
        pseudo_queries,
        softmax_outputs,
        lses,
        eps,
        NUM_SOURCE_BLOCKS,
        BT,
        D,
        NUM_QUERIES,
        triton.next_power_of_2(NUM_SOURCE_BLOCKS),
    )


# @triton.autotune(
#     configs=autotune_configs,
#     key=["HIDDEN_DIM"],
#     restore_value=[
#         "interblock_normalized_output_ptr",
#     ],
# )
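# Phase 2 (online-softmax merge): folds the current intrablock partial sum
# into the phase-1 result as one extra attention entry. Its logit is computed
# the same way as in phase 1; both sides are rescaled by exp(. - max) and the
# merged output overwrites the phase-1 output in place.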
@triton.jit
def phase_2_online_softmax_merge_intrablock_kernel(
    intrablock_partial_sum_ptr,
    pseudo_query_ptr,
    interblock_normalized_output_ptr,
    interblock_lse_ptr,
    eps,
    HIDDEN_DIM: tl.constexpr,
):
    batch_seq_idx = tl.program_id(0)
    hidden_dim_range = tl.arange(0, HIDDEN_DIM)

    intrablock_partial_sum = tl.load(
        intrablock_partial_sum_ptr + batch_seq_idx * HIDDEN_DIM + hidden_dim_range
    ).to(tl.float32)
    pseudo_query_vector = tl.load(
        pseudo_query_ptr + hidden_dim_range, eviction_policy="evict_last"
    ).to(tl.float32)

    interblock_lse = tl.load(interblock_lse_ptr + batch_seq_idx)
    interblock_normalized_output = tl.load(
        interblock_normalized_output_ptr + batch_seq_idx * HIDDEN_DIM + hidden_dim_range
    ).to(tl.float32)

    squared_norm_sum = tl.sum(intrablock_partial_sum * intrablock_partial_sum)
    inverse_rms_norm = tl.rsqrt(squared_norm_sum / float(HIDDEN_DIM) + eps)

    intrablock_logit = (
        tl.sum(intrablock_partial_sum * pseudo_query_vector) * inverse_rms_norm
    )
    merged_max = tl.maximum(interblock_lse, intrablock_logit)
    interblock_weight = tl.exp(interblock_lse - merged_max)
    intrablock_weight = tl.exp(intrablock_logit - merged_max)
    exp_sum = interblock_weight + intrablock_weight
    merged_output = (
        interblock_weight * interblock_normalized_output
        + intrablock_weight * intrablock_partial_sum
    ) / exp_sum

    tl.store(
        interblock_normalized_output_ptr
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range,
        merged_output.to(tl.bfloat16),
    )


def phase_2_online_softmax_merge_intrablock(
    intrablock_partial_sum,
    pseudo_query,
    interblock_normalized_output,
    interblock_lse,
    eps=None,
):
    if eps is None:
        eps = EPS

    phase_2_online_softmax_merge_intrablock_kernel[(BT,)](
        intrablock_partial_sum,
        pseudo_query,
        interblock_normalized_output,
        interblock_lse,
        eps,
        D,
    )


# @triton.autotune(
#     configs=autotune_configs,
#     key=["NUM_SOURCE_BLOCKS", "HIDDEN_DIM", "NUM_QUERIES_PER_BLOCK", "PADDED_SRC"],
#     restore_value=[
#         "grad_block_representations_accumulator_ptr",
#     ],
# )
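# Backward of phase 1. Softmax probabilities are recomputed from the saved
# log-sum-exp instead of being stored. Gradients flow through two paths per
# source block: the value path (p_i * dL/dout) and the logit path through the
# RMS normalization. Block gradients are accumulated with relaxed atomics
# (many positions touch each block); pseudo-query gradients are staged per
# position and reduced by phase_1_reduce_grad_pseudo_queries_kernel.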
@triton.jit
def phase_1_batched_interblock_attention_backward_kernel(
    block_representations_ptr,
    pseudo_queries_ptr,
    lse_ptr,
    grad_softmax_normalized_output_ptr,
    grad_lse_ptr,
    grad_block_representations_accumulator_ptr,
    grad_pseudo_queries_partial_ptr,
    eps,
    NUM_SOURCE_BLOCKS: tl.constexpr,
    BT: tl.constexpr,
    HIDDEN_DIM: tl.constexpr,
    NUM_QUERIES_PER_BLOCK: tl.constexpr,
    PADDED_SRC: tl.constexpr,
    HAS_GRAD_LSE: tl.constexpr,
):
    batch_seq_idx = tl.program_id(0)

    source_block_range = tl.arange(0, PADDED_SRC)[:, None]
    source_block_range_1d = tl.arange(0, PADDED_SRC)

    hidden_dim_range = tl.arange(0, HIDDEN_DIM)[None, :]
    hidden_dim_range_1d = tl.arange(0, HIDDEN_DIM)

    valid_block_mask_2d = source_block_range < NUM_SOURCE_BLOCKS
    valid_block_mask_1d = source_block_range_1d < NUM_SOURCE_BLOCKS

    source_block_values = tl.load(
        block_representations_ptr
        + source_block_range * (BT * HIDDEN_DIM)
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range,
        mask=valid_block_mask_2d,
        other=0.0,
    ).to(tl.float32)

    squared_norm_sum = tl.sum(source_block_values * source_block_values, axis=1)
    inverse_rms_norm = tl.rsqrt(squared_norm_sum / float(HIDDEN_DIM) + eps)
    inverse_rms_norm_cubed = inverse_rms_norm * inverse_rms_norm * inverse_rms_norm

    for layer_offset in tl.static_range(NUM_QUERIES_PER_BLOCK):
        pseudo_query_vector = tl.load(
            pseudo_queries_ptr + layer_offset * HIDDEN_DIM + hidden_dim_range,
            eviction_policy="evict_last",
        ).to(tl.float32)

        grad_attention_output = tl.load(
            grad_softmax_normalized_output_ptr
            + layer_offset * BT * HIDDEN_DIM
            + batch_seq_idx * HIDDEN_DIM
            + hidden_dim_range_1d,
        ).to(tl.float32)

        if HAS_GRAD_LSE:
            grad_logsumexp = tl.load(
                grad_lse_ptr + layer_offset * BT + batch_seq_idx
            ).to(tl.float32)
        else:
            grad_logsumexp = 0.0

        forward_logsumexp = tl.load(lse_ptr + layer_offset * BT + batch_seq_idx).to(
            tl.float32
        )

        pseudo_query_source_dot = tl.sum(
            source_block_values * pseudo_query_vector,
            axis=1,
        )

        attention_logits = pseudo_query_source_dot * inverse_rms_norm
        attention_logits = tl.where(
            valid_block_mask_1d,
            attention_logits,
            float("-inf"),
        )

        softmax_probabilities = tl.exp(attention_logits - forward_logsumexp)

        grad_output_dot_source_values = tl.sum(
            source_block_values * grad_attention_output[None, :],
            axis=1,
        )

        grad_output_dot_expected_value = tl.sum(
            softmax_probabilities * grad_output_dot_source_values,
            axis=0,
        )

        grad_attention_logits = softmax_probabilities * (
            grad_logsumexp
            + grad_output_dot_source_values
            - grad_output_dot_expected_value
        )

        grad_source_from_value_path = (
            softmax_probabilities[:, None] * grad_attention_output[None, :]
        )

        grad_source_from_logit_path = grad_attention_logits[:, None] * (
            inverse_rms_norm[:, None] * pseudo_query_vector
            - pseudo_query_source_dot[:, None]
            * inverse_rms_norm_cubed[:, None]
            * source_block_values
            / float(HIDDEN_DIM)
        )

        grad_source_block_values = (
            grad_source_from_value_path + grad_source_from_logit_path
        )

        grad_source_block_values = tl.where(
            valid_block_mask_2d,
            grad_source_block_values,
            0.0,
        )

        grad_pseudo_query = tl.sum(
            grad_attention_logits[:, None]
            * inverse_rms_norm[:, None]
            * source_block_values,
            axis=0,
        )

        tl.atomic_add(
            grad_block_representations_accumulator_ptr
            + source_block_range * (BT * HIDDEN_DIM)
            + batch_seq_idx * HIDDEN_DIM
            + hidden_dim_range,
            grad_source_block_values,
            mask=valid_block_mask_2d,
            sem="relaxed",
        )

        tl.store(
            grad_pseudo_queries_partial_ptr
            + layer_offset * BT * HIDDEN_DIM
            + batch_seq_idx * HIDDEN_DIM
            + hidden_dim_range_1d,
            grad_pseudo_query,
        )


reduce_configs = [
    triton.Config(
        {
            "BLOCK_BATCH_SEQ": block_batch_seq,
            "BLOCK_HIDDEN": block_hidden,
        },
        num_warps=num_warps,
        num_stages=1,
    )
    for block_batch_seq in [64, 128, 256]
    for block_hidden in [16, 32]
    for num_warps in [4, 8]
]

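# Reduces the per-position pseudo-query gradient partials over the BT axis,
# one (query, batch-seq tile, hidden tile) slab per program.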

@triton.autotune(
    configs=reduce_configs,
    key=["NUM_BATCH_SEQ", "HIDDEN_DIM", "NUM_QUERIES_PER_BLOCK"],
    restore_value=[
        "grad_pseudo_queries_accumulator_ptr",
    ],
)
@triton.jit
def phase_1_reduce_grad_pseudo_queries_kernel(
    grad_pseudo_queries_partial_ptr,
    grad_pseudo_queries_accumulator_ptr,
    NUM_BATCH_SEQ: tl.constexpr,
    HIDDEN_DIM: tl.constexpr,
    NUM_QUERIES_PER_BLOCK: tl.constexpr,
    BLOCK_BATCH_SEQ: tl.constexpr,
    BLOCK_HIDDEN: tl.constexpr,
):
    batch_seq_block_idx = tl.program_id(0)
    query_idx = tl.program_id(1)
    hidden_block_idx = tl.program_id(2)

    batch_seq_offsets = batch_seq_block_idx * BLOCK_BATCH_SEQ + tl.arange(
        0, BLOCK_BATCH_SEQ
    )

    hidden_offsets = hidden_block_idx * BLOCK_HIDDEN + tl.arange(0, BLOCK_HIDDEN)

    grad_tile = tl.load(
        grad_pseudo_queries_partial_ptr
        + query_idx * NUM_BATCH_SEQ * HIDDEN_DIM
        + batch_seq_offsets[:, None] * HIDDEN_DIM
        + hidden_offsets[None, :],
        mask=(
            (batch_seq_offsets[:, None] < NUM_BATCH_SEQ)
            & (hidden_offsets[None, :] < HIDDEN_DIM)
            & (query_idx < NUM_QUERIES_PER_BLOCK)
        ),
        other=0.0,
    ).to(tl.float32)

    grad_reduced = tl.sum(grad_tile, axis=0)

    tl.atomic_add(
        grad_pseudo_queries_accumulator_ptr + query_idx * HIDDEN_DIM + hidden_offsets,
        grad_reduced,
        mask=((hidden_offsets < HIDDEN_DIM) & (query_idx < NUM_QUERIES_PER_BLOCK)),
        sem="relaxed",
    )


def phase_1_batched_interblock_attention_backward(
    block_representations,
    pseudo_queries,
    lses,
    grad_softmax_outputs,
    grad_lses,
    grad_block_representations,
    grad_pseudo_queries,
    grad_pseudo_queries_partial,
    eps=None,
):
    NUM_QUERIES = pseudo_queries.shape[0]
    NUM_SOURCE_BLOCKS = block_representations.shape[0]

    if eps is None:
        eps = EPS

    has_grad_lses = grad_lses is not None
    if grad_lses is None:
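        # Dummy pointer: the kernel never loads it when HAS_GRAD_LSE is False.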
        grad_lses = lses

    phase_1_batched_interblock_attention_backward_kernel[(BT,)](
        block_representations,
        pseudo_queries,
        lses,
        grad_softmax_outputs,
        grad_lses,
        grad_block_representations,
        grad_pseudo_queries_partial,
        eps,
        NUM_SOURCE_BLOCKS,
        BT,
        D,
        NUM_QUERIES,
        triton.next_power_of_2(NUM_SOURCE_BLOCKS),
        has_grad_lses,
    )

    phase_1_reduce_grad_pseudo_queries_kernel[
        lambda META: (
            triton.cdiv(BT, META["BLOCK_BATCH_SEQ"]),
            NUM_QUERIES,
            triton.cdiv(D, META["BLOCK_HIDDEN"]),
        )
    ](
        grad_pseudo_queries_partial,
        grad_pseudo_queries,
        BT,
        D,
        NUM_QUERIES,
    )


# @triton.autotune(
#     configs=autotune_configs,
#     key=["HIDDEN_DIM"],
#     restore_value=[
#         "grad_intrablock_partial_sum_accumulator_ptr",
#     ],
# )
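# Backward of the phase-2 merge. With p1 = w1 / (w1 + w2) and
# p2 = w2 / (w1 + w2), the merged output is p1 * out1 + p2 * v, so
# dL/d lse1 = p1 * p2 * g . (out1 - v) and dL/d logit2 is its negation;
# the logit path then flows through the RMS normalization as in phase 1.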
@triton.jit
def phase_2_online_softmax_merge_intrablock_backward_kernel(
    intrablock_partial_sum_ptr,
    pseudo_query_ptr,
    phase1_interblock_normalized_output_ptr,
    phase1_interblock_logsumexp_ptr,
    grad_merged_attention_output_ptr,
    grad_intrablock_partial_sum_accumulator_ptr,
    grad_pseudo_query_partial_ptr,
    grad_phase1_interblock_normalized_output_ptr,
    grad_phase1_interblock_logsumexp_ptr,
    eps,
    HIDDEN_DIM: tl.constexpr,
):
    batch_seq_idx = tl.program_id(0)
    hidden_dim_range = tl.arange(0, HIDDEN_DIM)

    intrablock_partial_sum = tl.load(
        intrablock_partial_sum_ptr + batch_seq_idx * HIDDEN_DIM + hidden_dim_range
    ).to(tl.float32)

    pseudo_query = tl.load(
        pseudo_query_ptr + hidden_dim_range,
        eviction_policy="evict_last",
    ).to(tl.float32)

    phase1_interblock_normalized_output = tl.load(
        phase1_interblock_normalized_output_ptr
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range
    ).to(tl.float32)

    phase1_interblock_logsumexp = tl.load(
        phase1_interblock_logsumexp_ptr + batch_seq_idx
    ).to(tl.float32)

    grad_merged_attention_output = tl.load(
        grad_merged_attention_output_ptr + batch_seq_idx * HIDDEN_DIM + hidden_dim_range
    ).to(tl.float32)

    intrablock_partial_sum_squared_norm = tl.sum(
        intrablock_partial_sum * intrablock_partial_sum
    )
    intrablock_inverse_rms_norm = tl.rsqrt(
        intrablock_partial_sum_squared_norm / float(HIDDEN_DIM) + eps
    )

    pseudo_query_intrablock_dot = tl.sum(intrablock_partial_sum * pseudo_query)
    phase2_intrablock_logit = pseudo_query_intrablock_dot * intrablock_inverse_rms_norm

    online_softmax_shift = tl.maximum(
        phase1_interblock_logsumexp,
        phase2_intrablock_logit,
    )
    phase1_partition_weight = tl.exp(phase1_interblock_logsumexp - online_softmax_shift)
    phase2_partition_weight = tl.exp(phase2_intrablock_logit - online_softmax_shift)
    merged_partition_weight_sum = phase1_partition_weight + phase2_partition_weight

    phase1_merge_probability = phase1_partition_weight / merged_partition_weight_sum
    phase2_merge_probability = phase2_partition_weight / merged_partition_weight_sum

    grad_phase1_interblock_normalized_output = (
        phase1_merge_probability * grad_merged_attention_output
    )
    grad_intrablock_partial_sum_from_value_path = (
        phase2_merge_probability * grad_merged_attention_output
    )

    grad_output_dot_interblock_minus_intrablock = tl.sum(
        grad_merged_attention_output
        * (phase1_interblock_normalized_output - intrablock_partial_sum)
    )

    merge_probability_product = phase1_merge_probability * phase2_merge_probability

    grad_phase1_interblock_logsumexp = (
        merge_probability_product * grad_output_dot_interblock_minus_intrablock
    )

    grad_phase2_intrablock_logit = (
        -merge_probability_product * grad_output_dot_interblock_minus_intrablock
    )

    intrablock_inverse_rms_norm_cubed = (
        intrablock_inverse_rms_norm
        * intrablock_inverse_rms_norm
        * intrablock_inverse_rms_norm
    )

    grad_intrablock_partial_sum_from_logit_path = grad_phase2_intrablock_logit * (
        intrablock_inverse_rms_norm * pseudo_query
        - pseudo_query_intrablock_dot
        * intrablock_inverse_rms_norm_cubed
        * intrablock_partial_sum
        / float(HIDDEN_DIM)
    )

    grad_pseudo_query = (
        grad_phase2_intrablock_logit
        * intrablock_inverse_rms_norm
        * intrablock_partial_sum
    )
    grad_intrablock_partial_sum = (
        grad_intrablock_partial_sum_from_value_path
        + grad_intrablock_partial_sum_from_logit_path
    )

    grad_intrablock_ptr = (
        grad_intrablock_partial_sum_accumulator_ptr
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range
    )

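    # A plain read-modify-write is safe here: exactly one program owns each
    # (batch, position) row of the accumulator.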
    tl.store(
        grad_intrablock_ptr,
        tl.load(grad_intrablock_ptr).to(tl.float32) + grad_intrablock_partial_sum,
    )

    tl.store(
        grad_pseudo_query_partial_ptr + batch_seq_idx * HIDDEN_DIM + hidden_dim_range,
        grad_pseudo_query,
    )

    tl.store(
        grad_phase1_interblock_normalized_output_ptr
        + batch_seq_idx * HIDDEN_DIM
        + hidden_dim_range,
        grad_phase1_interblock_normalized_output,
    )

    tl.store(
        grad_phase1_interblock_logsumexp_ptr + batch_seq_idx,
        grad_phase1_interblock_logsumexp,
    )


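# Same BT-axis reduction as in phase 1, but for a single pseudo-query.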
@triton.autotune(
    configs=reduce_configs,
    key=["NUM_BATCH_SEQ", "HIDDEN_DIM"],
    restore_value=[
        "grad_pseudo_query_accumulator_ptr",
    ],
)
@triton.jit
def phase_2_reduce_grad_pseudo_query_kernel(
    grad_pseudo_query_partial_ptr,
    grad_pseudo_query_accumulator_ptr,
    NUM_BATCH_SEQ: tl.constexpr,
    HIDDEN_DIM: tl.constexpr,
    BLOCK_BATCH_SEQ: tl.constexpr,
    BLOCK_HIDDEN: tl.constexpr,
):
    batch_seq_block_idx = tl.program_id(0)
    hidden_block_idx = tl.program_id(1)

    batch_seq_offsets = batch_seq_block_idx * BLOCK_BATCH_SEQ + tl.arange(
        0, BLOCK_BATCH_SEQ
    )
    hidden_offsets = hidden_block_idx * BLOCK_HIDDEN + tl.arange(0, BLOCK_HIDDEN)

    grad_tile = tl.load(
        grad_pseudo_query_partial_ptr
        + batch_seq_offsets[:, None] * HIDDEN_DIM
        + hidden_offsets[None, :],
        mask=(
            (batch_seq_offsets[:, None] < NUM_BATCH_SEQ)
            & (hidden_offsets[None, :] < HIDDEN_DIM)
        ),
        other=0.0,
    ).to(tl.float32)

    grad_reduced = tl.sum(grad_tile, axis=0)

    tl.atomic_add(
        grad_pseudo_query_accumulator_ptr + hidden_offsets,
        grad_reduced,
        mask=hidden_offsets < HIDDEN_DIM,
        sem="relaxed",
    )


def phase_2_online_softmax_merge_intrablock_backward(
    intrablock_partial_sum,
    pseudo_query,
    phase1_interblock_normalized_output,
    phase1_interblock_logsumexp,
    grad_merged_attention_output,
    grad_intrablock_partial_sum,
    grad_pseudo_query,
    grad_phase1_interblock_normalized_output,
    grad_phase1_interblock_logsumexp,
    grad_pseudo_query_partial,
    eps=None,
):
    if eps is None:
        eps = EPS

    phase_2_online_softmax_merge_intrablock_backward_kernel[(BT,)](
        intrablock_partial_sum,
        pseudo_query,
        phase1_interblock_normalized_output,
        phase1_interblock_logsumexp,
        grad_merged_attention_output,
        grad_intrablock_partial_sum,
        grad_pseudo_query_partial,
        grad_phase1_interblock_normalized_output,
        grad_phase1_interblock_logsumexp,
        eps,
        D,
    )

    phase_2_reduce_grad_pseudo_query_kernel[
        lambda META: (
            triton.cdiv(BT, META["BLOCK_BATCH_SEQ"]),
            triton.cdiv(D, META["BLOCK_HIDDEN"]),
        )
    ](
        grad_pseudo_query_partial,
        grad_pseudo_query,
        BT,
        D,
    )


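# Custom autograd wrapper. Forward saves only the block representations and
# pseudo-queries for backward; phase-1 outputs, LSEs, intrablock partials,
# and layer activations are all recomputed block by block during backward.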
class BlockwiseAttentionFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, pseudo_queries, layers, eps, *flat_layer_params):
        block_representations = torch.empty(
            NUM_BLOCKS,
            B,
            T,
            D,
            device=DEVICE,
            dtype=inputs.dtype,
        )
        block_representations[0].copy_(inputs)

        block_attn_out_scratch = torch.empty(
            BLOCK_SIZE,
            B,
            T,
            D,
            device=DEVICE,
            dtype=torch.bfloat16,
        )

        block_lse_scratch = torch.empty(
            BLOCK_SIZE,
            B,
            T,
            device=DEVICE,
            dtype=torch.float32,
        )

        for block_start in range(0, L, BLOCK_SIZE):
            curr_block_idx = block_start // BLOCK_SIZE + 1
            num_queries = min(BLOCK_SIZE, L - block_start)

            block_attn_out = block_attn_out_scratch[:num_queries]
            block_lse = block_lse_scratch[:num_queries]

            phase_1_batched_interblock_attention(
                block_representations[:curr_block_idx],
                pseudo_queries[block_start : block_start + num_queries],
                block_attn_out,
                block_lse,
                eps=eps,
            )

            curr_block = block_representations[curr_block_idx]

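            # The first query of the block attends only to completed blocks;
            # later queries first merge the running intrablock partial sum,
            # then the per-layer update is accumulated into the current block.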
            for query_offset in range(num_queries):
                i = block_start + query_offset

                if query_offset != 0:
                    phase_2_online_softmax_merge_intrablock(
                        curr_block,
                        pseudo_queries[i],
                        block_attn_out[query_offset],
                        block_lse[query_offset],
                        eps=eps,
                    )

                update = layers[i](block_attn_out[query_offset])

                if query_offset == 0:
                    curr_block.copy_(update)
                else:
                    curr_block.add_(update)

        final_out = torch.empty(
            B,
            T,
            D,
            device=DEVICE,
            dtype=inputs.dtype,
        )

        final_lse_scratch = torch.empty(
            1,
            B,
            T,
            device=DEVICE,
            dtype=torch.float32,
        )

        phase_1_batched_interblock_attention(
            block_representations,
            pseudo_queries[-1:],
            final_out.unsqueeze(0),
            final_lse_scratch,
            eps=eps,
        )

        ctx.save_for_backward(
            block_representations,
            pseudo_queries,
        )
        ctx.layers = layers
        ctx.eps = eps
        ctx.num_layer_params = len(flat_layer_params)

        return final_out

    @staticmethod
    def backward(ctx, *grad_outputs):
        grad_output = grad_outputs[0]
        if grad_output is None:
            return (None, None, None, None, *([None] * ctx.num_layer_params))

        block_representations, pseudo_queries = ctx.saved_tensors
        layers = ctx.layers
        eps = ctx.eps

        device = block_representations.device
        block_dtype = block_representations.dtype
        attn_dtype = torch.bfloat16

        grad_output = grad_output.contiguous()

        layer_param_groups = [tuple(layer.parameters()) for layer in layers]
        flat_layer_params = [p for group in layer_param_groups for p in group]

        param_offsets = []
        offset = 0
        for group in layer_param_groups:
            param_offsets.append(offset)
            offset += len(group)

        grad_flat_layer_params = [
            torch.zeros_like(p, dtype=torch.float32) if p.requires_grad else None
            for p in flat_layer_params
        ]

        grad_block_representations = torch.zeros_like(
            block_representations,
            dtype=torch.float32,
        )
        grad_pseudo_queries = torch.zeros_like(
            pseudo_queries,
            dtype=torch.float32,
        )

        grad_pseudo_queries_partial = torch.empty(
            BLOCK_SIZE,
            B,
            T,
            D,
            device=device,
            dtype=torch.float32,
        )

        grad_phase2_pseudo_query_partial = torch.empty(
            B,
            T,
            D,
            device=device,
            dtype=torch.float32,
        )

        def run_layer_backward(layer_idx, layer_input_buf, grad_update_f32):
            params_i = layer_param_groups[layer_idx]
            active_param_indices = [
                j for j, p in enumerate(params_i) if p.requires_grad
            ]
            active_params = [params_i[j] for j in active_param_indices]

            with torch.enable_grad():
                layer_input = layer_input_buf.detach().requires_grad_(True)
                update = layers[layer_idx](layer_input)

                grad_results = torch.autograd.grad(
                    outputs=update,
                    inputs=(layer_input, *active_params),
                    grad_outputs=grad_update_f32.to(dtype=update.dtype),
                    retain_graph=False,
                    create_graph=False,
                    # Allow unused inputs so the None-grad handling below is
                    # reachable instead of raising for layers that skip a param.
                    allow_unused=True,
                )

            grad_layer_input = grad_results[0]
            if grad_layer_input is None:
                grad_layer_input_f32 = torch.zeros_like(
                    layer_input_buf,
                    dtype=torch.float32,
                )
            else:
                grad_layer_input_f32 = grad_layer_input.to(torch.float32).contiguous()

            base = param_offsets[layer_idx]
            for local_idx, param_grad in zip(active_param_indices, grad_results[1:]):
                if param_grad is not None:
                    grad_flat_layer_params[base + local_idx].add_(
                        param_grad.to(torch.float32)
                    )

            return grad_layer_input_f32

        final_out_recomputed = torch.empty(
            1,
            B,
            T,
            D,
            device=device,
            dtype=attn_dtype,
        )
        final_lse = torch.empty(
            1,
            B,
            T,
            device=device,
            dtype=torch.float32,
        )

        with torch.no_grad():
            phase_1_batched_interblock_attention(
                block_representations,
                pseudo_queries[-1:],
                final_out_recomputed,
                final_lse,
                eps=eps,
            )

        phase_1_batched_interblock_attention_backward(
            block_representations,
            pseudo_queries[-1:],
            final_lse,
            grad_output.unsqueeze(0),
            None,
            grad_block_representations,
            grad_pseudo_queries[-1:],
            grad_pseudo_queries_partial[:1],
            eps=eps,
        )

        block_phase1_out_scratch = torch.empty(
            BLOCK_SIZE,
            B,
            T,
            D,
            device=device,
            dtype=attn_dtype,
        )
        block_lse_scratch = torch.empty(
            BLOCK_SIZE,
            B,
            T,
            device=device,
            dtype=torch.float32,
        )

        grad_block_phase1_out_scratch = torch.empty(
            BLOCK_SIZE,
            B,
            T,
            D,
            device=device,
            dtype=torch.float32,
        )
        grad_block_lse_scratch = torch.empty(
            BLOCK_SIZE,
            B,
            T,
            device=device,
            dtype=torch.float32,
        )

        intrablock_partial_before_scratch = torch.empty(
            max(BLOCK_SIZE - 1, 1),
            B,
            T,
            D,
            device=device,
            dtype=block_dtype,
        )

        partial_recompute = torch.empty(
            B,
            T,
            D,
            device=device,
            dtype=block_dtype,
        )

        layer_input_tmp = torch.empty(
            B,
            T,
            D,
            device=device,
            dtype=attn_dtype,
        )

        grad_curr_partial = torch.empty(
            B,
            T,
            D,
            device=device,
            dtype=torch.float32,
        )
        grad_prev_partial = torch.empty_like(grad_curr_partial)

        last_block_start = ((L - 1) // BLOCK_SIZE) * BLOCK_SIZE

        for block_start in range(last_block_start, -1, -BLOCK_SIZE):
            curr_block_idx = block_start // BLOCK_SIZE + 1
            num_queries = min(BLOCK_SIZE, L - block_start)

            phase1_out = block_phase1_out_scratch[:num_queries]
            phase1_lse = block_lse_scratch[:num_queries]

            grad_phase1_out = grad_block_phase1_out_scratch[:num_queries]
            grad_phase1_lse = grad_block_lse_scratch[:num_queries]

            grad_phase1_out.zero_()
            grad_phase1_lse.zero_()

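            # Recompute this block's forward pass under no_grad: phase-1
            # outputs and LSEs, plus the intrablock partial sums that each
            # phase-2 merge consumed.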
            with torch.no_grad():
                phase_1_batched_interblock_attention(
                    block_representations[:curr_block_idx],
                    pseudo_queries[block_start : block_start + num_queries],
                    phase1_out,
                    phase1_lse,
                    eps=eps,
                )

                for query_offset in range(num_queries):
                    layer_idx = block_start + query_offset

                    layer_input_tmp.copy_(phase1_out[query_offset])

                    if query_offset != 0:
                        intrablock_partial_before_scratch[query_offset - 1].copy_(
                            partial_recompute
                        )

                        phase_2_online_softmax_merge_intrablock(
                            intrablock_partial_before_scratch[query_offset - 1],
                            pseudo_queries[layer_idx],
                            layer_input_tmp,
                            phase1_lse[query_offset],
                            eps=eps,
                        )

                    update = layers[layer_idx](layer_input_tmp)

                    if query_offset == 0:
                        partial_recompute.copy_(update)
                    else:
                        partial_recompute.add_(update)

            grad_curr_partial.copy_(grad_block_representations[curr_block_idx])

            for query_offset in range(num_queries - 1, -1, -1):
                layer_idx = block_start + query_offset

                with torch.no_grad():
                    layer_input_tmp.copy_(phase1_out[query_offset])

                    if query_offset != 0:
                        phase_2_online_softmax_merge_intrablock(
                            intrablock_partial_before_scratch[query_offset - 1],
                            pseudo_queries[layer_idx],
                            layer_input_tmp,
                            phase1_lse[query_offset],
                            eps=eps,
                        )

                grad_layer_input = run_layer_backward(
                    layer_idx,
                    layer_input_tmp,
                    grad_curr_partial,
                )

                if query_offset == 0:
                    grad_phase1_out[query_offset].copy_(grad_layer_input)
                else:
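                    # partial_after = partial_before + update, so the identity
                    # path seeds grad_prev_partial; the merge backward below
                    # accumulates its contribution in place.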
                    grad_prev_partial.copy_(grad_curr_partial)

                    phase_2_online_softmax_merge_intrablock_backward(
                        intrablock_partial_before_scratch[query_offset - 1],
                        pseudo_queries[layer_idx],
                        phase1_out[query_offset],
                        phase1_lse[query_offset],
                        grad_layer_input,
                        grad_prev_partial,
                        grad_pseudo_queries[layer_idx],
                        grad_phase1_out[query_offset],
                        grad_phase1_lse[query_offset],
                        grad_phase2_pseudo_query_partial,
                        eps=eps,
                    )

                    grad_curr_partial, grad_prev_partial = (
                        grad_prev_partial,
                        grad_curr_partial,
                    )

            phase_1_batched_interblock_attention_backward(
                block_representations[:curr_block_idx],
                pseudo_queries[block_start : block_start + num_queries],
                phase1_lse,
                grad_phase1_out,
                grad_phase1_lse,
                grad_block_representations[:curr_block_idx],
                grad_pseudo_queries[block_start : block_start + num_queries],
                grad_pseudo_queries_partial[:num_queries],
                eps=eps,
            )

        grad_inputs = (
            grad_block_representations[0].to(block_dtype)
            if ctx.needs_input_grad[0]
            else None
        )

        grad_pseudo_queries_out = (
            grad_pseudo_queries.to(pseudo_queries.dtype)
            if ctx.needs_input_grad[1]
            else None
        )

        grad_flat_layer_params_out = []
        for j, (param, grad_param) in enumerate(
            zip(flat_layer_params, grad_flat_layer_params)
        ):
            needs_grad = ctx.needs_input_grad[4 + j]
            if not needs_grad or grad_param is None:
                grad_flat_layer_params_out.append(None)
            else:
                grad_flat_layer_params_out.append(grad_param.to(param.dtype))

        return (
            grad_inputs,
            grad_pseudo_queries_out,
            None,
            None,
            *grad_flat_layer_params_out,
        )


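# Entry point for the fused implementation: layer parameters are flattened into
# explicit autograd.Function arguments so autograd tracks gradients for them.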
def production_forward(inputs, pseudo_queries, layers, eps=None):
    if eps is None:
        eps = EPS

    flat_layer_params = tuple(p for layer in layers for p in layer.parameters())

    return BlockwiseAttentionFunction.apply(
        inputs,
        pseudo_queries,
        layers,
        eps,
        *flat_layer_params,
    )

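# Example call (a minimal sketch; assumes the B, T, D, L, DEVICE, DTYPE globals
# defined earlier and the SwiGLU layer class defined below):
#
#   x = torch.randn(B, T, D, device=DEVICE, dtype=DTYPE, requires_grad=True)
#   q = torch.randn(L + 1, D, device=DEVICE, dtype=DTYPE, requires_grad=True)
#   out = production_forward(x, q, [SwiGLU() for _ in range(L)])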

# @torch.compile(mode="max-autotune-no-cudagraphs")
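# fp32 reference attention: RMS-normalize the stacked block values into keys,
# score them against a single pseudo query, and softmax over the block axis
# (the explicit max subtraction keeps the logits numerically stable).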
def naive_attention_residual(pseudo_query, values):
    keys = F.rms_norm(values, (values.shape[-1],), eps=EPS)

    logits = torch.einsum("d, n b t d -> n b t", pseudo_query, keys)
    logits = logits - logits.max(dim=0, keepdim=True).values

    return torch.einsum(
        "n b t, n b t d -> b t d",
        logits.softmax(0),
        values,
    ).to(DTYPE)


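# Straight-line fp32 reference implementation of the blockwise recurrence from
# the paper; used as the ground truth in compare_grads below.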
def paper_forward(inputs, pseudo_queries, layers):
    inputs = inputs.to(torch.float32)
    pseudo_queries = pseudo_queries.to(torch.float32)

    blocks = [inputs]

    for i in range(len(layers)):
        outputs = naive_attention_residual(
            pseudo_queries[i],
            torch.stack(blocks, dim=0),
        )

        update = layers[i](outputs)

        if i % BLOCK_SIZE == 0:
            blocks.append(update)
        else:
            blocks[-1] = blocks[-1] + update

    return naive_attention_residual(
        pseudo_queries[-1],
        torch.stack(blocks, dim=0),
    )


# @torch.compile(mode="max-autotune-no-cudagraphs")
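# Phase 1: batched inter-block attention. Attends a whole block of queries over
# the frozen block representations at once, returning the per-query log-sum-exp
# (for later online-softmax merging), the normalized partial outputs, and the
# first query's output h for immediate use.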
def phase_1_fn(query, value):
    query = query.to(torch.float32)
    value = value.to(torch.float32)

    D_ = value.shape[-1]

    squared_norm_sum = (value * value).sum(dim=-1)
    inverse_rms_norm = torch.rsqrt(squared_norm_sum / float(D_) + EPS)
    raw_dot = torch.einsum("nbtd,sd->nbts", value, query)
    logits = raw_dot * inverse_rms_norm.unsqueeze(-1)

    max_logits = logits.amax(dim=0)
    exp_weights = torch.exp(logits - max_logits.unsqueeze(0))
    exp_sum = exp_weights.sum(dim=0)

    weighted_sum = (exp_weights.unsqueeze(-1) * value.unsqueeze(3)).sum(dim=0)
    normalized = (weighted_sum / exp_sum[..., None]).permute(2, 0, 1, 3).contiguous()

    lse = (max_logits + torch.log(exp_sum)).permute(2, 0, 1).contiguous()

    h = normalized[0]
    return lse, normalized.to(torch.bfloat16), h


# @torch.compile(mode="max-autotune-no-cudagraphs")
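# Phase 2: fold the still-growing current block back into a phase-1 partial
# result via a two-way online-softmax merge. Given the running state
# (prev_lse, prev_normalized) and one new value v with logit l:
#   m   = max(prev_lse, l)
#   out = (exp(prev_lse - m) * prev_normalized + exp(l - m) * v)
#         / (exp(prev_lse - m) + exp(l - m))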
def phase_2_fn(current_block_values, query_vector, prev_lse, prev_normalized):
    query_vector_f32 = query_vector.to(torch.float32)
    prev_normalized_f32 = prev_normalized.to(torch.float32)

    current_block_values_f32 = current_block_values.to(torch.float32)

    squared_norm_sum = (current_block_values_f32 * current_block_values_f32).sum(dim=-1)

    inverse_rms_norm = torch.rsqrt(
        squared_norm_sum / current_block_values_f32.shape[-1] + EPS
    )

    current_logit = (current_block_values_f32 @ query_vector_f32) * inverse_rms_norm

    merged_max = torch.maximum(prev_lse, current_logit)
    interblock_weight = torch.exp(prev_lse - merged_max)
    intrablock_weight = torch.exp(current_logit - merged_max)

    out = (
        interblock_weight[..., None] * prev_normalized_f32
        + intrablock_weight[..., None] * current_block_values_f32
    ) / (interblock_weight + intrablock_weight)[..., None]

    return out.to(torch.bfloat16)

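# Consistency sketch (shapes as in torch_compile_phases_forward: values is
# (n, b, t, d), query is (d,), new_block is (b, t, d)); merging one extra block
# with phase_2_fn should approximately reproduce attention over all n + 1
# blocks, up to bf16 rounding:
#
#   lse, normalized, _ = phase_1_fn(query[None], values)
#   merged = phase_2_fn(new_block, query, lse[0], normalized[0])
#   full = naive_attention_residual(query, torch.cat([values, new_block[None]], 0))
#   # merged ≈ full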

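# Forward-only reference built from the two compiled phases: phase 1 runs once
# per block boundary (batching the next BLOCK_SIZE queries against the frozen
# blocks, which is why lse/normalized are always bound before the else branch
# reads them), and phase 2 folds the growing block back in at the remaining
# offsets.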
def torch_compile_phases_forward(inputs, query_w, layers):
    blocks = [inputs]

    for i in range(len(layers)):
        offset = i % BLOCK_SIZE

        if offset == 0:
            values = torch.stack(blocks, dim=0)

            lse, normalized, h = phase_1_fn(query_w[i : i + BLOCK_SIZE], values)
            blocks.append(layers[i](h.to(inputs.dtype)))
        else:
            h = phase_2_fn(
                blocks[-1],
                query_w[i],
                lse[offset],
                normalized[offset],
            )

            blocks[-1] = blocks[-1] + layers[i](h.to(inputs.dtype))

    _, _, h = phase_1_fn(query_w[-1:], torch.stack(blocks, dim=0))
    return h.to(inputs.dtype)


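# Pre-norm SwiGLU MLP used as the per-layer update function.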
class SwiGLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.RMSNorm(D, device=DEVICE, dtype=DTYPE, eps=EPS)
        self.linear1 = nn.Linear(D, D * 2, bias=False, device=DEVICE, dtype=DTYPE)
        self.linear2 = nn.Linear(D, D, bias=False, device=DEVICE, dtype=DTYPE)

    def forward(self, x):
        h1, gate = self.linear1(self.norm(x)).chunk(2, dim=-1)
        return self.linear2(F.silu(gate) * h1)


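# Parameter-free stand-in layer, handy for benchmarking the attention path
# without any MLP cost.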
class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


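# Tensors to differentiate with respect to: the inputs, the pseudo queries, and
# every trainable layer parameter.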
def grad_targets(inputs, pseudo_queries, layers):
    params = tuple(p for layer in layers for p in layer.parameters() if p.requires_grad)
    return (inputs, pseudo_queries, *params)


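# Wall-clock forward+backward timing in milliseconds per run, with CUDA
# synchronization around the timed region.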
def bench_fwd_bwd(fn, inputs, pseudo_queries, layers, grad_out, warmup=3, runs=10):
    targets = grad_targets(inputs, pseudo_queries, layers)

    for _ in range(warmup):
        out = fn(inputs, pseudo_queries, layers)
        torch.autograd.grad(
            outputs=out,
            inputs=targets,
            grad_outputs=grad_out,
            retain_graph=False,
            create_graph=False,
            allow_unused=False,
        )

    torch.cuda.synchronize()
    t0 = time.perf_counter()

    for _ in range(runs):
        out = fn(inputs, pseudo_queries, layers)
        torch.autograd.grad(
            outputs=out,
            inputs=targets,
            grad_outputs=grad_out,
            retain_graph=False,
            create_graph=False,
            allow_unused=False,
        )

    torch.cuda.synchronize()

    return (time.perf_counter() - t0) / runs * 1000


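# Run a single forward/backward and return the detached output plus fp32 copies
# of all gradients for comparison.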
def collect_grads(fn, inputs, pseudo_queries, layers, grad_out):
    targets = grad_targets(inputs, pseudo_queries, layers)

    out = fn(inputs, pseudo_queries, layers)

    grads = torch.autograd.grad(
        outputs=out,
        inputs=targets,
        grad_outputs=grad_out,
        retain_graph=False,
        create_graph=False,
        allow_unused=False,
    )

    grads = [grad.detach().to(torch.float32) for grad in grads]
    return out.detach(), grads


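# Print error metrics between a reference and a test implementation: mean/max
# absolute difference, mean/max relative difference (with a 1e-3 floor on the
# denominator), the ratio of difference norm to reference norm, and the mean
# absolute magnitudes of both gradient tensors.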
def compare_grads(
    ref_name, ref_fn, test_name, test_fn, inputs, pseudo_queries, layers, grad_out
):
    ref_out, ref_grads = collect_grads(ref_fn, inputs, pseudo_queries, layers, grad_out)
    test_out, test_grads = collect_grads(
        test_fn, inputs, pseudo_queries, layers, grad_out
    )

    out_abs = (ref_out.to(torch.float32) - test_out.to(torch.float32)).abs()
    print(
        f"{test_name} vs {ref_name} output: "
        f"mean_abs={out_abs.mean()}, max_abs={out_abs.max()}"
    )

    for idx, (rg, tg) in enumerate(zip(ref_grads, test_grads)):
        if rg is None or tg is None:
            print(
                f"{test_name} grad[{idx}] vs {ref_name}: "
                f"None mismatch: ref_is_none={rg is None}, test_is_none={tg is None}"
            )
            continue

        diff = (rg - tg).abs()
        rel = diff / (rg.abs() + 1e-3)

        norm_rel = (rg - tg).norm() / (rg.norm() + 1e-12)

        rg_abs_avg = rg.abs().mean()
        tg_abs_avg = tg.abs().mean()

        print(
            f"{test_name} grad[{idx}] vs {ref_name}: "
            f"mean_abs={diff.mean()}, max_abs={diff.max()}, "
            f"mean_rel={rel.mean()}, max_rel={rel.max()}, "
            f"norm_rel={norm_rel}, "
            f"ref_abs_avg={rg_abs_avg}, test_abs_avg={tg_abs_avg}"
        )


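# Time only the autograd.grad call; the forward pass runs outside the timed
# region so the number isolates backward cost.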
def bench_backward_only(
    fn, inputs, pseudo_queries, layers, grad_out, warmup=3, runs=10
):
    targets = grad_targets(inputs, pseudo_queries, layers)

    for _ in range(warmup):
        out = fn(inputs, pseudo_queries, layers)
        torch.cuda.synchronize()

        torch.autograd.grad(
            outputs=out,
            inputs=targets,
            grad_outputs=grad_out,
            retain_graph=False,
            create_graph=False,
            allow_unused=False,
        )
        torch.cuda.synchronize()

    total = 0.0

    for _ in range(runs):
        out = fn(inputs, pseudo_queries, layers)
        torch.cuda.synchronize()

        t0 = time.perf_counter()
        torch.autograd.grad(
            outputs=out,
            inputs=targets,
            grad_outputs=grad_out,
            retain_graph=False,
            create_graph=False,
            allow_unused=False,
        )
        torch.cuda.synchronize()

        total += time.perf_counter() - t0

    return total / runs * 1000


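# Relies on the module-level funcs_to_bench and grad_out bound in the driver
# loop below; the (currently commented-out) calls run after those exist.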
def print_bench_group(title, args):
    print(title)
    for name, func in funcs_to_bench:
        fwd_bwd = bench_fwd_bwd(func, *args, grad_out)
        bwd = bench_backward_only(func, *args, grad_out)
        print(f"{name} fwd+bwd:  {fwd_bwd:.3f} ms")
        print(f"{name} bwd-only: {bwd:.3f} ms")
    print()


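# Driver: build random inputs, pseudo queries, and layer stacks, then compare
# gradients between the fused and reference implementations. The benchmark
# groups and the torch.compile comparison are left commented out.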
for _ in range(1):
    inputs = torch.randn(
        B,
        T,
        D,
        device=DEVICE,
        dtype=DTYPE,
        requires_grad=True,
    )

    layers_swiglu = [SwiGLU() for _ in range(L)]
    layers_identity = [Identity() for _ in range(L)]

    pseudo_queries_zeros = torch.zeros(
        L + 1,
        D,
        device=DEVICE,
        dtype=DTYPE,
        requires_grad=True,
    )

    pseudo_queries_randn = torch.randn(
        L + 1,
        D,
        device=DEVICE,
        dtype=DTYPE,
        requires_grad=True,
    ) / math.sqrt(D)

    grad_out = torch.randn(
        B,
        T,
        D,
        device=DEVICE,
        dtype=DTYPE,
    )

    args_identity = (inputs, pseudo_queries_randn, layers_identity)

    args_swiglu_zeros = (inputs, pseudo_queries_zeros, layers_swiglu)
    args_swiglu_randn = (inputs, pseudo_queries_randn, layers_swiglu)

    funcs_to_bench = [
        ("torch_compile_phases_forward", torch_compile_phases_forward),
        ("production_forward", production_forward),
        ("paper_forward", paper_forward),
    ]

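    # Randomize benchmark order so warmup/caching effects are not attributed
    # to a fixed implementation.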
    random.shuffle(funcs_to_bench)

    # print_bench_group("identity / randn queries", args_identity)
    # print_bench_group("swiglu / zero queries", args_swiglu_zeros)
    # print_bench_group("swiglu / randn queries", args_swiglu_randn)

    compare_grads(
        "paper_forward",
        paper_forward,
        "production_forward",
        production_forward,
        *args_swiglu_randn,
        grad_out,
    )

    # compare_grads(
    #     "paper_forward",
    #     paper_forward,
    #     "torch_compile_phases_forward",
    #     torch_compile_phases_forward,
    #     *args_swiglu_randn,
    #     grad_out,
    # )
