Source code for cutlass.op.gemm

#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
    Ease-of-use interface for constructing, compiling, and running GEMMs.

    The ``Gemm`` interface is meant to allow one to easily instantiate, compile, and run
    GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
    Under the hood, the interface will select sensible default parameters for the many template
    parameters for CUTLASS GEMMs.

    Note: optimal performance is not to be expected from this interface. To achieve optimal
    performance, one should specify and tune each configuration parameter.

    The simplest example of using this interface is the following:

    .. highlight:: python
    .. code-block:: python

        # A, B, C, and D are torch/numpy/cupy tensor objects
        plan = cutlass.op.Gemm(A, B, C, D)
        plan.run()


    One can also use the interface by specifying data types of operands at construction
    and using different tensor objects with these data types at runtime:

    .. highlight:: python
    .. code-block:: python

        # The following is shorthand for:
        #        cutlass.op.Gemm(element_A=torch.float32, element_B=torch.float32,
        #                        element_C=torch.float32, element_D=torch.float32,
        #                        element_accumulator=torch.float32,
        #                        layout=cutlass.LayoutType.RowMajor)
        plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)

        A0 = torch.rand((128, 256), device='cuda')
        B0 = torch.rand((256, 64), device='cuda')
        C0 = torch.zeros((128, 64), device='cuda')
        D0 = torch.zeros((128, 64), device='cuda')
        plan.run(A0, B0, C0, D0)

        A1 = torch.rand((32, 128), device='cuda')
        B1 = torch.rand((128, 256), device='cuda')
        C1 = torch.zeros((32, 256), device='cuda')
        D1 = torch.zeros((32, 256), device='cuda')
        plan.run(A1, B1, C1, D1)

    The interface additionally enables one to decouple the compilation of the underlying CUTLASS
    kernel from its execution:

    .. highlight:: python
    .. code-block:: python

        plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
        plan.compile()

        # Do other work...

        plan.run(A0, B0, C0, D0)

        # Do other work...

        plan.run(A1, B1, C1, D1)

    Elementwise activation functions are easily fused to the GEMM via the interface:

    .. highlight:: python
    .. code-block:: python

        plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
        plan.activation = cutlass.epilogue.relu

    Operations can also be run asynchronously:

    .. highlight:: python
    .. code-block:: python

        plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
        args = plan.run(A0, B0, C0, D0, sync=False)

        # Do other work...

        args.sync()
"""

import cutlass_bindings

import cutlass
from cutlass import epilogue, swizzle
from cutlass.backend import compiler
from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal
from cutlass.backend.library import TensorDescription, TileDescription
from cutlass.op.op import OperationBase
from cutlass.utils import check, datatypes


class Gemm(OperationBase):
    """
    Constructs a ``Gemm`` object.

    The data types and layouts of operands A, B, and C, along with the data type of output D
    and that used for accumulation, are bound to the ``Gemm`` object throughout its lifetime --
    these are not to be changed after a ``Gemm`` has been constructed.

    The constructor has optional parameters for flexibly setting these parameters. The following
    constructors are equivalent:

    .. highlight:: python
    .. code-block:: python

        # Use F32 for A, B, C, D, and accumulation. All operands are row major.

        # Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts
        # for operands to the same values.
        Gemm(element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)

        # Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``.
        Gemm(element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
            element_D=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)

        # Set the data types and elements from existing tensors. Note that one can use different tensors when
        # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
        # have the same data type and layout as those passed in here).
        # A, B, C, and D are row-major torch.Tensor objects of type torch.float32
        Gemm(A=A, B=B, C=C, D=D)

        # Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is
        # the same as that for C, at present)
        Gemm(element=cutlass.DataType.f32, layout_A=cutlass.LayoutType.RowMajor,
            layout_B=cutlass.LayoutType.RowMajor, layout_C=cutlass.LayoutType.RowMajor)

        # Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types
        # and layouts will inherit those passed in via the generic ``element`` and ``layout``
        Gemm(element_A=cutlass.DataType.f32, layout_B=cutlass.LayoutType.RowMajor,
            element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)

    The order of precedence for the setting of the data type and layout for a given operand/output is as follows:
        1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor
        2) Otherwise, if the data type/layout (e.g., ``element_A``, ``layout_A``) is specified, use those
        3) Otherwise, use the generic values (e.g., ``element``, ``layout``)

    There is no ``layout_D`` parameter: when tensor ``D`` is not given, the layout of D is taken
    from ``layout_C`` (or from the generic ``layout`` if ``layout_C`` is not set either).

    :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
    :type cc: int
    :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
    :type kernel_cc: int
    :param A: tensor representing data type and layout of operand A
    :param B: tensor representing data type and layout of operand B
    :param C: tensor representing data type and layout of operand C
    :param D: tensor representing data type and layout of operand D
    :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
    :param beta: scalar parameter beta from GEMM operation that scales operand C
    :param element_accumulator: data type to be used in accumulation of the product of operands A and B
        (defaults to the data type of operand C)
    :type element_accumulator: cutlass.DataType
    :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
    :type element: cutlass.DataType
    :param layout: generic layout type to be used for operands A, B, C, and D
    :type layout: cutlass.LayoutType
    :param element_A: data type to be used for operand A
    :type element_A: cutlass.DataType
    :param element_B: data type to be used for operand B
    :type element_B: cutlass.DataType
    :param element_C: data type to be used for operand C
    :type element_C: cutlass.DataType
    :param element_D: data type to be used for operand D
    :type element_D: cutlass.DataType
    :param layout_A: layout of operand A
    :type layout_A: cutlass.LayoutType
    :param layout_B: layout of operand B
    :type layout_B: cutlass.LayoutType
    :param layout_C: layout of operand C (also used for output D when tensor ``D`` is not given)
    :type layout_C: cutlass.LayoutType
    """

    def __init__(
        self, A=None, B=None, C=None, D=None,
        alpha=1.0, beta=0.0, element_accumulator=None,
        element=None, layout=None,
        element_A=None, element_B=None, element_C=None, element_D=None,
        layout_A=None, layout_B=None, layout_C=None,
        cc: int = None, kernel_cc: int = None
    ):
        super().__init__(cc=cc, kernel_cc=kernel_cc)
        self.name = "gemm"
        self.compiled = False

        elements = []
        layouts = []

        # Check that at least one of the following is set for each tensor (illustrated assuming tensor A):
        # ``A``, ``element_A``, ``element`` and ``A``, ``layout_A``, ``layout``
        # Note that D has no dedicated layout parameter and is paired with ``layout_C``.
        for elt, lay, tens, name in zip([element_A, element_B, element_C, element_D],
                                        [layout_A, layout_B, layout_C, layout_C],
                                        [A, B, C, D],
                                        ["A", "B", "C", "D"]):
            if elt is not None and tens is not None:
                raise Exception(f'Must not specify both element_{name} and tensor {name}')
            if lay is not None and tens is not None:
                raise Exception(f'Must not specify both layout_{name} and tensor {name}')
            if elt is None and tens is None and element is None:
                raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')
            if lay is None and tens is None and layout is None:
                raise Exception(f'Must specify one of layout_{name}, tensor {name}, or generic layout.')

            # Precedence: tensor, then operand-specific value, then generic value
            if tens is not None:
                elt_to_set, lay_to_set = datatypes.get_datatype_and_layout(tens)
            else:
                elt_to_set = elt if elt is not None else element
                lay_to_set = lay if lay is not None else layout

            elements.append(datatypes.library_type(elt_to_set))
            layouts.append(datatypes.library_layout(lay_to_set))

        self._element_a, self._element_b, self._element_c, self._element_d = elements
        self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts

        # Accumulate in the data type of C unless told otherwise
        if element_accumulator is None:
            self._element_accumulator = self._element_c
        else:
            self._element_accumulator = datatypes.library_type(element_accumulator)

        self.A = A
        self.B = B
        self.C = C
        self.D = D

        self.alpha = alpha
        self.beta = beta

        # Both must exist before ``_reset_operations`` runs: the ``opclass`` setter and the
        # epilogue reset read them.
        self.epilogue_functor = None
        self.op_class = None

        self._reset_operations()

        self._swizzling_functor = cutlass.swizzle.IdentitySwizzle1

    def _reset_operations(self, reset_epilogue: bool = True):
        """
        Selects the default opcode class for the current data type and layout combination
        (preferring TensorOp over Simt) and, optionally, resets the epilogue to the identity
        activation.

        :param reset_epilogue: whether to reset the epilogue functor to use the identity activation
        :type reset_epilogue: bool
        """
        # Set the default op class
        datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
        layout_comb = (self._layout_a, self._layout_b)
        self.possible_op_classes = self.options.supporting_opclasses(
            self._element_a, self._element_b, self._element_accumulator,
            self._layout_a, self._layout_b)

        if cutlass.OpcodeClass.TensorOp in self.possible_op_classes:
            self.opclass = cutlass.OpcodeClass.TensorOp
        elif cutlass.OpcodeClass.Simt in self.possible_op_classes:
            self.opclass = cutlass.OpcodeClass.Simt
        else:
            raise Exception(f'No kernel configuration found for supported data type and layout '
                            f'combination {datatype_comb}x{layout_comb}')

        if reset_epilogue:
            self._reset_epilogue_functor_activation(epilogue.identity)

    def _default_elements_per_access(self) -> int:
        """
        Returns the default epilogue vector length for the opcode class currently in use:
        1 for Simt, otherwise ``128 // cutlass.DataTypeSize[element_C]``.

        :return: default number of elements per epilogue access
        :rtype: int
        """
        # NOTE(review): presumably DataTypeSize is in bits, making this the number of C
        # elements in a 128-bit access -- confirm against cutlass.DataTypeSize.
        if self.op_class == cutlass.OpcodeClass.Simt:
            return 1
        return 128 // cutlass.DataTypeSize[self._element_c]

    def _reset_epilogue_functor_activation(self, activation):
        """
        Rebuilds the epilogue functor so that it applies ``activation``, keeping the vector
        length of the current epilogue functor if one exists.

        If no ``kernel_cc`` was specified on construction, the set of kernels considered may be
        switched between SM90 and SM80-tagged kernels depending on whether ``activation`` is the
        identity. If ``kernel_cc`` was specified and SM90 kernels are in use, a non-identity
        activation raises an exception.

        :param activation: activation function to fuse into the epilogue
        """
        # Computed before any CC switch below; statement order is intentionally preserved.
        if self.epilogue_functor is None:
            elements_per_access = self._default_elements_per_access()
        else:
            elements_per_access = self.epilogue_functor.epilogue_vector_length

        if not self.specified_kernel_cc:
            if self.current_cc == 90 and activation != epilogue.identity:
                # CUTLASS 3.0 kernels currently only support identity activation. If one requests a non-identity activation,
                # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
                cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
                self._reset_options(80)
                self._reset_operations(reset_epilogue=False)
            elif (self.cc == 90 and self.current_cc != 90 and activation == epilogue.identity):
                # SM80 fallback kernels are currently used. Since an identity activation is requested,
                # we can switch back to using SM90 kernels.
                self._reset_options(90)
                self._reset_operations(reset_epilogue=False)
        else:
            if self.current_cc == 90 and activation != epilogue.identity:
                raise Exception("Epilogues with elementwise fusion are not currently supported "
                                "in the Python interface for 3.x kernels. To use 2.x kernels "
                                "with fused elementwise epilogues, do not set the `kernel_cc` "
                                "parameter when constructing the Gemm object.")

        self.epilogue_functor = epilogue.get_activation_epilogue(
            activation,
            datatypes.binding_type(self._element_c),
            elements_per_access,
            datatypes.binding_type(self._element_accumulator),
            datatypes.binding_type(self._element_accumulator),
        )

    def _reset_epilogue_functor_alignment(self, alignment):
        """
        Rebuilds the epilogue functor with vector length ``alignment``, keeping the activation
        of the current epilogue functor (identity if there is none).

        :param alignment: number of elements per access to use in the epilogue
        :type alignment: int
        """
        if self.epilogue_functor is None or not hasattr(self.epilogue_functor, 'activation_functor'):
            activation = epilogue.identity
        else:
            activation = type(self.epilogue_functor.activation_functor)

        self.epilogue_functor = epilogue.get_activation_epilogue(
            activation,
            datatypes.binding_type(self._element_c),
            alignment,
            datatypes.binding_type(self._element_accumulator),
            datatypes.binding_type(self._element_accumulator),
        )

    @property
    def activation(self):
        """
        Returns the type of the current activation function used
        """
        return type(self.epilogue_functor.activation_functor)

    @activation.setter
    def activation(self, act):
        """
        Sets the type of the activation function to use
        """
        self._reset_epilogue_functor_activation(act)

    @property
    def opclass(self) -> cutlass.OpcodeClass:
        """
        Returns the opcode class currently in use by the GEMM

        :return: opcode class currently in use
        :rtype: cutlass.OpcodeClass
        """
        return self.op_class

    @opclass.setter
    def opclass(self, oc: cutlass.OpcodeClass):
        """
        Sets the opcode class to use in the GEMM. If the opcode class is not supported under
        the given compute capability and element/layout combinations of the GEMM, an exception is raised.
        """
        if oc in self.possible_op_classes:
            self.op_class = oc
        else:
            raise Exception(
                f'Unsupported operation class {oc} for CC {self.cc} and data type combination '
                f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and '
                f'layout combination ({self._layout_a}, {self._layout_b}).')

        # Changing the op class changes the elements per access in the epilogue. Reset this.
        if self.epilogue_functor is not None:
            self._reset_epilogue_functor_alignment(self._default_elements_per_access())

        # Changing the op class also changes the possible operations available. Reset these.
        self.possible_operations = self.options.operations(
            self.op_class, self._element_a, self._element_b,
            self._element_accumulator, self._layout_a, self._layout_b)

    @property
    def swizzling_functor(self):
        """
        Returns the type of the swizzling functor currently being used by the GEMM

        :return: swizzling functor type
        """
        return self._swizzling_functor

    @swizzling_functor.setter
    def swizzling_functor(self, swizzling_functor):
        """
        Sets the swizzling functor to the type specified by `swizzling_functor`
        """
        if swizzling_functor == swizzle.ThreadblockSwizzleStreamK:
            if self.op_class == cutlass.OpcodeClass.Simt:
                raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp')

            if self.current_cc == 90:
                raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90')
        self._swizzling_functor = swizzling_functor

    def _valid_tile_description(self, td: TileDescription) -> tuple:
        """
        Checks whether the provided tile description is valid for the given compute capability. At present,
        this checks the following:

        - Does the tile description use a number of stages supported by the compute capability in question?
        - Does the tile size requested fit within shared memory?
        - Are cluster dimensions outside the valid range requested for a given architecture (e.g.,
          more non-unit cluster dimensions for pre-SM90 architectures)?
        - Is the kernel schedule being used supported on the architecture in question?

        :param td: tile description to validate
        :type td: cutlass.backend.TileDescription
        :return: tuple in which the first element is a bool indicating that the tile description is valid
                 and the second element is a string providing an optional error message.
        :rtype: tuple
        """
        # Check stage count based on the CC to which we are compiling (self.cc), rather
        # than the CC from which we find kernels (self.current_cc)
        valid, msg = check.valid_stage_count(self.cc, td)
        if not valid:
            return (valid, msg)

        valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
        if not valid:
            return (valid, msg)

        valid, msg = check.valid_kernel_schedule(self.current_cc, td.kernel_schedule)
        return valid, msg

    def tile_descriptions(self) -> list:
        """
        Returns a list of valid tile descriptions for the operations

        :returns: list of valid tile descriptions for the operations
        :rtype: list
        """
        return [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations]

    def construct(
        self, tile_description: TileDescription = None,
        alignment_A: int = None, alignment_B: int = None, alignment_C: int = None
    ) -> GemmOperationUniversal:
        """
        Constructs a ``cutlass.backend.GemmOperationUniversal`` based on the input parameters and current
        kernel specification of the ``Gemm`` object.

        :param tile_description: tile description specifying shapes and operand types to use in the kernel
        :type tile_description: cutlass.backend.TileDescription
        :param alignment_A: alignment of operand A
        :type alignment_A: int
        :param alignment_B: alignment of operand B
        :type alignment_B: int
        :param alignment_C: alignment of operand C
        :type alignment_C: int

        :return: operation that was constructed
        :rtype: cutlass.backend.GemmOperationUniversal
        """
        # Preferred alignment per operand, capped by the largest alignment available
        # among the possible operations.
        max_alignment = max(self.possible_operations.alignments)
        alignment_pref_A = min(128 // cutlass.DataTypeSize[self._element_a], max_alignment)
        alignment_pref_B = min(128 // cutlass.DataTypeSize[self._element_b], max_alignment)
        alignment_pref_C = min(128 // cutlass.DataTypeSize[self._element_c], max_alignment)
        alignment_A = check.alignment_or_default(alignment_A, alignment_pref_A)
        alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B)
        alignment_C = check.alignment_or_default(alignment_C, alignment_pref_C)

        # The epilogue vector length must follow the alignment chosen for C
        self._reset_epilogue_functor_alignment(alignment_C)

        tensor_A = TensorDescription(
            datatypes.binding_type(self._element_a),
            datatypes.binding_layout(self._layout_a),
            alignment_A
        )
        tensor_B = TensorDescription(
            datatypes.binding_type(self._element_b),
            datatypes.binding_layout(self._layout_b),
            alignment_B
        )
        tensor_C = TensorDescription(
            datatypes.binding_type(self._element_c),
            datatypes.binding_layout(self._layout_c),
            alignment_C
        )

        if tile_description is None:
            # Default to the first operation available for the alignment of A
            op = self.possible_operations.operations(alignment_A)[0]
            tile_description = datatypes.td_from_profiler_op(op)
        else:
            valid, err_str = self._valid_tile_description(tile_description)
            if not valid:
                raise Exception(f"Invalid tile description. {err_str}")
            # Only a user-provided tile description is remembered for later calls
            self.tile_description = tile_description

        operation = GemmOperationUniversal(
            arch=self.current_cc,
            tile_description=tile_description,
            A=tensor_A,
            B=tensor_B,
            C=tensor_C,
            epilogue_functor=self.epilogue_functor,
            swizzling_functor=self._swizzling_functor,
        )

        return operation

    def compile(self, tile_description: TileDescription = None,
                alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
                print_module: bool = False) -> GemmOperationUniversal:
        """
        Emits and compiles the kernel currently specified. If ``tile_description`` and any
        of the ``alignment`` parameters are set, the kernel will be chosen using this
        tile description and alignments. Otherwise, a default tile description and alignment
        will be used.

        :param tile_description: tile description specifying shapes and operand types to use in the kernel
        :type tile_description: cutlass.backend.TileDescription
        :param alignment_A: alignment of operand A
        :type alignment_A: int
        :param alignment_B: alignment of operand B
        :type alignment_B: int
        :param alignment_C: alignment of operand C
        :type alignment_C: int
        :param print_module: whether to print the emitted C++ code
        :type print_module: bool

        :return: operation that was compiled
        :rtype: cutlass.backend.GemmOperationUniversal
        """
        self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C)

        if print_module:
            print(self.operation.rt_module.emit())

        compiler.add_module([self.operation,])
        return self.operation

    def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
        """
        Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
        is raised if it does not.

        :param tensor: object representing a tensor passed in to verify
        :type tensor: numpy/cupy/torch array/tensor object
        :param ref_type: data type for the tensor that this object was initialized to
        :param ref_layout: layout for the tensor that this object was initialized to
        :param name: identifier of the tensor to verify. Used in raising exceptions
        :type name: str
        """
        dtype, layout = datatypes.get_datatype_and_layout(tensor)
        if dtype != ref_type or layout != ref_layout:
            raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) '
                            f'does not match the expected type and '
                            f'layout of ({ref_type}, {ref_layout}).')

    def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
        """
        Verifies the following properties:
            1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``)
            2) If ``tensor`` is not ``None``, its datatype and layout must match the current versions
               set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``)

        If either of these properties does not hold, an exception is raised. If these properties hold and
        ``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned.

        :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
        :type tensor: numpy/cupy/torch array/tensor object
        :param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
        :type ref_tensor: numpy/cupy/torch array/tensor object
        :param ref_dtype: data type for the tensor that this object was initialized to
        :param ref_layout: layout for the tensor that this object was initialized to
        :param name: identifier of the tensor to verify. Used in raising exceptions
        :type name: str

        :return: valid tensor object to use
        :rtype: numpy/cupy/torch array/tensor object
        """
        if tensor is None:
            if ref_tensor is None:
                raise Exception(f"Tensor {name} must be set.")
            return ref_tensor

        self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
        return tensor

    def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
        """
        Verifies the following properties:
            1) Either ``scalar`` or ``ref_scalar`` must be set (i.e., not ``None``)
            2) If ``scalar`` is not ``None`` and carries a ``dtype``, that datatype must match the
               current version set by the plan (i.e., ``ref_dtype``)

        If either of these properties does not hold, an exception is raised. If these properties hold and
        ``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned.

        Scalars without a ``dtype`` attribute (e.g., built-in Python numbers such as the constructor
        defaults ``alpha=1.0`` and ``beta=0.0``) are accepted without a type check.

        :param scalar: object representing a scalar passed in to verify, or ``None`` if no scalar was passed in
        :type scalar: numpy/cupy/torch scalar, or Python number
        :param ref_scalar: object representing a scalar passed in on construction of this object, or ``None`` if no scalar was passed in
        :type ref_scalar: numpy/cupy/torch scalar, or Python number
        :param ref_dtype: data type for the scalar that this object was initialized to
        :param name: identifier of the scalar to verify. Used in raising exceptions
        :type name: str

        :return: valid scalar to use
        :rtype: numpy/cupy/torch scalar, or Python number
        """
        if scalar is None:
            if ref_scalar is None:
                raise Exception(f"Scalar {name} must be set.")
            return ref_scalar

        # Plain Python numbers have no ``dtype``; reading it unconditionally raised
        # AttributeError for calls such as ``plan.run(..., alpha=2.0)``.
        if hasattr(scalar, "dtype"):
            dtype = datatypes.library_type(scalar.dtype)
            if dtype != ref_dtype:
                raise Exception(
                    f"Scalar {name} with type {dtype} does not match expected type {ref_dtype}."
                )
        return scalar

    def run(self, A=None, B=None, C=None, D=None,
            alpha=None, beta=None, batch_count: int = 1,
            sync: bool = True, print_module: bool = False) -> GemmArguments:
        """
        Runs the kernel currently specified. If it has not already been, the kernel is emitted and
        compiled. Tensors holding operands and outputs of the kernel are sourced either from the
        ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
        parameters provided in this call, or from those
        passed in on the construction of this object -- one of the two must be specified.

        By default, this call returns only once the kernel has completed. To launch the kernel
        and immediately return, set ``sync=False``. In this case, it is the responsibility of the
        caller to synchronize the results of the kernel before attempting to access outputs
        by calling ``sync()`` on the arguments returned from this call.

        :param A: tensor representing data type and layout of operand A
        :param B: tensor representing data type and layout of operand B
        :param C: tensor representing data type and layout of operand C
        :param D: tensor representing data type and layout of operand D
        :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
        :param beta: scalar parameter beta from GEMM operation that scales operand C
        :param batch_count: number of GEMMs in the batch
        :type batch_count: int
        :param sync: whether the call should wait for the kernel to complete before returning
        :type sync: bool
        :param print_module: whether to print the emitted C++ code
        :type print_module: bool

        :return: arguments passed in to the kernel
        :rtype: cutlass.backend.GemmArguments
        """
        if batch_count < 1:
            raise Exception(f"Invalid batch count {batch_count}. Value must be an integer >= 1.")

        A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
        B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
        C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
        D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
        alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
        beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

        # Pick alignments compatible with the shapes of the tensors actually being run
        alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a)
        alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b)
        alignment_c = self.possible_operations.find_alignment(C.shape, self._layout_c)
        self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
                     alignment_C=alignment_c, print_module=print_module)

        # Problem size is (M, N, K), read from the first two dimensions of A and B
        problem_size = cutlass_bindings.gemm.GemmCoord(A.shape[0], B.shape[1], A.shape[1])

        if batch_count == 1:
            mode = cutlass_bindings.gemm.Mode.Gemm
            kwargs = {'split_k_slices': 1}
        else:
            mode = cutlass_bindings.gemm.Mode.Batched
            kwargs = {'batch': batch_count}

        arguments = GemmArguments(
            operation=self.operation, problem_size=problem_size,
            A=A, B=B, C=C, D=D,
            output_op=self.operation.epilogue_type(alpha, beta),
            gemm_mode=mode,
            **kwargs
        )

        self.operation.run(arguments)

        if sync:
            arguments.sync()

        return arguments