# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Test scatter+broadcast kernel.

Uses the example from KERN-2435: DP=4, TP=2, 8 GPUs distributing row_offsets.

  row_offsets = [0, 5, 12, 20, 28, 35, 40, 48, 56]
  Sequence lengths: [5, 7, 8, 8, 7, 5, 8, 8]

  Split by replica (2 sequences each, offsets rebased to start at 0):
    Replica A (seq 0-1): [0, 5, 12]
    Replica B (seq 2-3): [0, 8, 16]
    Replica C (seq 4-5): [0, 7, 12]
    Replica D (seq 6-7): [0, 8, 16]

  Distribution (DP=4, TP=2):
    Replica A [0,5,12]  -> GPU 0, GPU 1
    Replica B [0,8,16]  -> GPU 2, GPU 3
    Replica C [0,7,12]  -> GPU 4, GPU 5
    Replica D [0,8,16]  -> GPU 6, GPU 7
"""

from layout import Idx, TileTensor, row_major
from std.collections import InlineArray
from std.math import ceildiv
from std.sys import size_of
from std.gpu.host import DeviceBuffer, DeviceContext
from std.testing import assert_true

from comm import Signal, MAX_GPUS
from comm.scatter import scatter
from comm.sync import enable_p2p

# Element type of the offset values exercised by this test; expected results
# are expressed as Scalar[dtype] (see `_test_pull`).
comptime dtype = DType.uint32


def _test_pull[
    ngpus: Int,
    dp_size: Int,
](expected: List[List[Scalar[dtype]]]) raises:
