Metadata-Version: 2.1
Name: mlstm_kernels
Version: 0.0.0
Summary: Fast and efficient mLSTM kernels for the parallel, recurrent and chunkwise forms.
Author-email: Maximilian Beck <beck@ml.jku.at>, Korbinian Poeppel <poeppel@ml.jku.at>, Phillip Lippe <phillip.lippe@gmail.com>, Sebastian Boeck <sebastian.boeck@nx-ai.com>
License: NXAI XLSTM 7B COMMUNITY LICENSE AGREEMENT
        
        Preamble
        We are proud to present the xLSTM 7B model, demonstrating the strength of next-generation RNN-based large language models, delivering high-quality performance and fast inference speeds. While xLSTM 7B is freely available for open research and development, we believe that organizations significantly benefiting from our technology should contribute back. Our goal is to support research, small and medium-sized enterprises (SMEs), and open innovation, particularly in Europe, while ensuring that large enterprises who incorporate xLSTM 7B into commercial products or services fairly compensate the creators for their research and development efforts.
        Linz, December 12, 2024
        
        “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the xLSTM Materials set forth herein.
        
        “Documentation” means the specifications, manuals and documentation accompanying NXAI xLSTM 7B distributed by NXAI at https://github.com/NX-AI/.
        
        “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
        
        “NXAI xLSTM 7B” means the foundational large language models and software and
        algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed by NXAI.
        
        “xLSTM Materials” means, collectively, NXAI's proprietary NXAI xLSTM 7B and Documentation (and any portion thereof) made available under this Agreement.
        
        “NXAI” or “we” means NXAI GmbH, Linz, Austria.
        
        By using or distributing any portion or element of the xLSTM Materials, you agree to be bound by this Agreement.
        
        1. License Rights and Redistribution.
        
            a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under NXAI's intellectual property embodied in the xLSTM Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the xLSTM Materials.
            b. Redistribution and Use.
                i. If you distribute or make available the xLSTM Materials (or any derivative works thereof), or a product or service that uses any of them, including another AI model, you shall (A) provide a copy of this Agreement with any such xLSTM Materials; and (B) prominently display “Built with NXAI xLSTM” on a related website, user interface, blogpost, about page, or product documentation. If you use the xLSTM Materials to create, train, fine tune, or otherwise improve an AI model, which is distributed or made available, you shall also include “xLSTM” at the beginning of any such AI model name.
                ii. If you receive xLSTM Materials, or any derivative works thereof, from a Licensee as part of an integrated end user product, then Section 2 of this Agreement will not apply to you.
                iii. You must retain in all copies of the xLSTM Materials that you distribute the following attribution notice within a “Notice” text file distributed as a part of such copies: “NXAI xLSTM 7B is licensed under the NXAI xLSTM 7B Community License, Copyright © NXAI GmbH, All Rights Reserved.”
        
        2. Additional Commercial Terms. If (a) the Licensee, on a consolidated basis (including parent, subsidiaries, and affiliates), exceeds the annual revenue of one hundred million Euros (€100,000,000) or more, and (b) the Licensee incorporates xLSTM Material, in whole or in part, into a Commercial Product or Service, then the Licensee must obtain a commercial license from NXAI, which NXAI may grant to you in its sole discretion, and you are not authorized to exercise any of the rights under this Agreement unless or until NXAI otherwise expressly grants you such rights.
        
        3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE XLSTM MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND NXAI DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE XLSTM MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE XLSTM MATERIALS AND ANY OUTPUT AND RESULTS.
        
        4. Limitation of Liability. IN NO EVENT WILL NXAI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF NXAI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
        
        5. Intellectual Property.
        
        a. No trademark licenses are granted under this Agreement, and in connection with the xLSTM Materials, neither NXAI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the xLSTM Materials or as set forth in this Section 5(a). NXAI hereby grants you a license to use “NXAI xLSTM” (the “Mark”) solely as required to comply with the last sentence of Section 1.b.i. All goodwill arising out of your use of the Mark will inure to the benefit of NXAI.
        
        b. Subject to NXAI's ownership of xLSTM Materials and derivatives made by or for NXAI, with respect to any derivative works and modifications of the xLSTM Materials that are made by you, as between you and NXAI, you are and will be the owner of such derivative works and modifications.
        
        c. If you institute litigation or other proceedings against NXAI or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the xLSTM Materials or NXAI xLSTM 7B outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless NXAI from and against any claim by any third party arising out of or related to your use or distribution of the xLSTM Materials.
        
        6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the xLSTM Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. NXAI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the xLSTM Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
        
        7. Governing Law and Jurisdiction. This Agreement shall be governed by and construed in accordance with the laws of the Republic of Austria, without regard to its conflict of laws principles. The courts located in Linz, Austria shall have exclusive jurisdiction over any disputes arising out of or in connection with this Agreement.
        
Project-URL: Homepage, https://github.com/NX-AI/mlstm_kernels
Project-URL: Issues, https://github.com/NX-AI/mlstm_kernels/issues
Keywords: mLSTM,xLSTM,LSTM,Transformer,Machine Learning,Deep Learning,State Space Models
Classifier: Programming Language :: Python :: 3
Classifier: Operating System :: OS Independent
Requires-Python: >=3.11
Description-Content-Type: text/markdown
License-File: LICENSE
License-File: AUTHORS

# mLSTM Kernels

This library provides fast and efficient mLSTM kernels for the parallel, recurrent, and chunkwise forms. We provide PyTorch and JAX wrappers for our kernels.

Paper coming soon! Stay tuned 📺🎧⏳✨

## Kernel Overview

This library contains three types of kernels:

- `parallel`: Parallel kernels that process the whole sequence in parallel (like attention).
- `chunkwise`: Chunkwise kernels that process chunks of the sequence in parallel and pass a recurrent state from one chunk to the next.
- `recurrent`: Recurrent step kernels for inference.
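To illustrate how the parallel and recurrent forms relate, the following minimal pure-PyTorch sketch computes the outputs of a single mLSTM head once with a step-by-step recurrence and once in parallel form, and checks that both agree. It is not the library's kernel code: the gate parameterization (exponential input gate, log-sigmoid forget gate), the stabilizer state `m`, and the normalizer `max(|q·n|, exp(-m)) + eps` are assumptions based on the stabilized mLSTM from the xLSTM paper, and any `1/sqrt(DHQK)` scaling of queries or keys is omitted. The functions `mlstm_recurrent` and `mlstm_parallel` are hypothetical.

```python
import torch
import torch.nn.functional as F


def mlstm_recurrent(q, k, v, i, f, eps=1e-6):
    """Recurrent form, one head. q, k: (S, DHQK), v: (S, DHHV), i, f: (S,) pre-activations."""
    S, DHQK = q.shape
    C = q.new_zeros(DHQK, v.shape[1])  # matrix memory (DHQK, DHHV)
    n = q.new_zeros(DHQK)  # normalizer state
    m = q.new_tensor(float("-inf"))  # stabilizer; -inf = empty history (matches parallel max)
    log_f = F.logsigmoid(f)  # assumed: sigmoid forget gate, kept in log space
    hs = []
    for t in range(S):
        m_new = torch.maximum(log_f[t] + m, i[t])
        f_t = torch.exp(log_f[t] + m - m_new)  # stabilized forget gate
        i_t = torch.exp(i[t] - m_new)  # stabilized exponential input gate
        C = f_t * C + i_t * torch.outer(k[t], v[t])
        n = f_t * n + i_t * k[t]
        denom = torch.maximum((q[t] @ n).abs(), torch.exp(-m_new)) + eps
        hs.append((q[t] @ C) / denom)
        m = m_new
    return torch.stack(hs)  # (S, DHHV)


def mlstm_parallel(q, k, v, i, f, eps=1e-6):
    """Parallel (attention-like) form, one head, same inputs as mlstm_recurrent."""
    S = q.shape[0]
    log_f_cum = torch.cumsum(F.logsigmoid(f), dim=0)  # sum of log f_r for r <= t
    # log D[t, s] = log_f_cum[t] - log_f_cum[s] + i[s] for s <= t; future positions get -inf
    log_D = log_f_cum[:, None] - log_f_cum[None, :] + i[None, :]
    causal = torch.ones(S, S, dtype=torch.bool, device=q.device).tril()
    log_D = log_D.masked_fill(~causal, float("-inf"))
    m = log_D.max(dim=1).values  # row-wise stabilizer, equals the recurrent m_t
    scores = (q @ k.T) * torch.exp(log_D - m[:, None])
    denom = torch.maximum(scores.sum(dim=1).abs(), torch.exp(-m)) + eps
    return (scores @ v) / denom[:, None]  # (S, DHHV)


torch.manual_seed(0)
S, DHQK, DHHV = 16, 8, 4
opts = dict(dtype=torch.float64)
q, k = torch.randn(S, DHQK, **opts), torch.randn(S, DHQK, **opts)
v = torch.randn(S, DHHV, **opts)
i, f = torch.randn(S, **opts), torch.randn(S, **opts)

h_rec = mlstm_recurrent(q, k, v, i, f)
h_par = mlstm_parallel(q, k, v, i, f)
print(torch.allclose(h_rec, h_par))  # True
```

A chunkwise formulation combines both ideas: it applies the parallel form inside each chunk and carries the `(C, n, m)` state from one chunk to the next with the recurrence.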

## Benchmark

Runtime comparison of our mLSTM chunkwise kernels, `triton_limit_chunk` (triton) and `triton_xl_chunk` (triton XL), against other baselines:

![xLSTM Figure](./res/plot_sequence_length_consttok_nh8_hd512_line.svg)

- **Left**: Forward pass
- **Right**: Forward and backward pass

## Usage PyTorch

### Available Kernels

You can view all available kernels for the mLSTM by calling

```python
from mlstm_kernels.torch import (
    get_available_mlstm_kernels,
    get_available_mlstm_sequence_kernels,
    get_available_mlstm_step_kernels,
)

print(get_available_mlstm_kernels())
print(get_available_mlstm_sequence_kernels())
print(get_available_mlstm_step_kernels())
```

and then use one of

```python
from mlstm_kernels.torch import (
    get_mlstm_kernel,
    get_mlstm_sequence_kernel,
    get_mlstm_step_kernel,
)
```
to access the specific kernel function.

### Direct Import

You can also import a specific kernel directly, for example the chunkwise `triton_limit_chunk` kernel:

```python
from mlstm_kernels.torch.chunkwise import mlstm_chunkwise__limit_chunk
```

### Backend Module

For PyTorch, we provide a backend module for easy integration into existing architectures.

```python
from mlstm_kernels.torch.backend_module import mLSTMBackendConfig, mLSTMBackend
```

### Training Kernel Interface

This is the interface used for the chunkwise and parallel kernels.

```python
import torch


def mlstm_interface(
    q: torch.Tensor,  # (B, NH, S, DHQK)
    k: torch.Tensor,  # (B, NH, S, DHQK)
    v: torch.Tensor,  # (B, NH, S, DHHV)
    i: torch.Tensor,  # (B, NH, S)
    f: torch.Tensor,  # (B, NH, S)
    c_initial: torch.Tensor | None = None,  # (B, NH, DHQK, DHHV)
    n_initial: torch.Tensor | None = None,  # (B, NH, DHQK)
    m_initial: torch.Tensor | None = None,  # (B, NH, 1)
    return_last_states: bool = False,
    eps: float = 1e-6,
    autocast_kernel_dtype: torch.dtype = torch.bfloat16,
    chunk_size: int = 64,
    **kwargs,
) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Returns:
        torch.Tensor: matH outputs of shape (B, NH, S, DHHV) if return_last_states is False
            (no n and m values, no last states).
        tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: matH, (matC_last, vecN_last, scaM_last)
            with shapes (B, NH, S, DHHV), ((B, NH, DHQK, DHHV), (B, NH, DHQK), (B, NH)) if return_last_states is True.
    """
    pass
```
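The shape conventions in the interface comments can be summarized as a small validation helper. `check_mlstm_shapes` is a hypothetical sketch, not part of `mlstm_kernels`; it only checks the shapes listed above and skips initial states that are not passed.

```python
import torch


def check_mlstm_shapes(q, k, v, i, f, c_initial=None, n_initial=None, m_initial=None):
    """Hypothetical helper: check inputs against the documented shapes, return the dims."""
    B, NH, S, DHQK = q.shape
    DHHV = v.shape[-1]
    expected = {
        "k": (k, (B, NH, S, DHQK)),
        "v": (v, (B, NH, S, DHHV)),
        "i": (i, (B, NH, S)),
        "f": (f, (B, NH, S)),
        "c_initial": (c_initial, (B, NH, DHQK, DHHV)),
        "n_initial": (n_initial, (B, NH, DHQK)),
        "m_initial": (m_initial, (B, NH, 1)),
    }
    for name, (tensor, shape) in expected.items():
        # Optional initial states are skipped when they are not given.
        if tensor is not None and tuple(tensor.shape) != shape:
            raise ValueError(f"{name}: expected {shape}, got {tuple(tensor.shape)}")
    return B, NH, S, DHQK, DHHV


B, NH, S, DHQK, DHHV = 2, 4, 128, 64, 128
q, k = torch.randn(B, NH, S, DHQK), torch.randn(B, NH, S, DHQK)
v = torch.randn(B, NH, S, DHHV)
i, f = torch.randn(B, NH, S), torch.randn(B, NH, S)
print(check_mlstm_shapes(q, k, v, i, f))  # (2, 4, 128, 64, 128)

try:  # a trailing singleton dimension on the gates is rejected
    check_mlstm_shapes(q, k, v, i[..., None], f)
except ValueError as err:
    print(err)  # i: expected (2, 4, 128), got (2, 4, 128, 1)
```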

### Step Kernel Interface

This is the interface for the mLSTM step kernels.

```python
import torch


def mlstm_step_interface(
    q: torch.Tensor,  # (B, NH, DHQK)
    k: torch.Tensor,  # (B, NH, DHQK)
    v: torch.Tensor,  # (B, NH, DHHV)
    i: torch.Tensor,  # (B, NH, 1)
    f: torch.Tensor,  # (B, NH, 1)
    c: torch.Tensor,  # (B, NH, DHQK, DHHV)
    n: torch.Tensor,  # (B, NH, DHQK)
    m: torch.Tensor,  # (B, NH, 1)
    eps: float = 1e-6,
    **kwargs,
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Returns:
        tuple: vecH (B, NH, DHHV), (matC_state_new (B, NH, DHQK, DHHV),
            vecN_state_new (B, NH, DHQK), vecM_state_new (B, NH, 1))
    """
    pass
```
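For intuition about what a step kernel computes, the following is a minimal pure-PyTorch sketch of one stabilized mLSTM step using the shapes of `mlstm_step_interface`, followed by a short generation-style loop that carries the `(c, n, m)` state. `mlstm_step_reference` is hypothetical; its gate parameterization (exponential input gate, log-sigmoid forget gate, normalizer `max(|q·n|, exp(-m)) + eps`) and the zero-initialized states are assumptions based on the stabilized mLSTM from the xLSTM paper, not taken from the library's kernels.

```python
import torch
import torch.nn.functional as F


def mlstm_step_reference(q, k, v, i, f, c, n, m, eps=1e-6):
    """Hypothetical single mLSTM step with the shapes of mlstm_step_interface."""
    log_f = F.logsigmoid(f)  # (B, NH, 1), assumed sigmoid forget gate in log space
    m_new = torch.maximum(log_f + m, i)  # (B, NH, 1) stabilizer state
    f_t = torch.exp(log_f + m - m_new)  # stabilized forget gate
    i_t = torch.exp(i - m_new)  # stabilized exponential input gate
    kv = k[..., :, None] * v[..., None, :]  # outer product k v^T: (B, NH, DHQK, DHHV)
    c_new = f_t[..., None] * c + i_t[..., None] * kv
    n_new = f_t * n + i_t * k  # (B, NH, DHQK)
    qn = (q * n_new).sum(dim=-1, keepdim=True)  # (B, NH, 1)
    denom = torch.maximum(qn.abs(), torch.exp(-m_new)) + eps
    h = torch.einsum("bhd,bhdv->bhv", q, c_new) / denom  # (B, NH, DHHV)
    return h, (c_new, n_new, m_new)


torch.manual_seed(0)
B, NH, DHQK, DHHV = 2, 3, 8, 16
c = torch.zeros(B, NH, DHQK, DHHV)
n = torch.zeros(B, NH, DHQK)
m = torch.zeros(B, NH, 1)
for _ in range(5):  # e.g. five generation steps, carrying (c, n, m) forward
    q, k = torch.randn(B, NH, DHQK), torch.randn(B, NH, DHQK)
    v = torch.randn(B, NH, DHHV)
    i, f = torch.randn(B, NH, 1), torch.randn(B, NH, 1)
    h, (c, n, m) = mlstm_step_reference(q, k, v, i, f, c, n, m)

print(h.shape)  # torch.Size([2, 3, 16])
```

Because the whole history is compressed into the fixed-size state `(c, n, m)`, each step costs the same regardless of how many tokens came before.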

## Usage JAX

The JAX module `mlstm_kernels.jax` mirrors the PyTorch module `mlstm_kernels.torch` and can be used in the same way.

We will also provide a backend module for Flax soon.

## Running the unit tests

The unit tests cross-check the different kernel implementations for numerical deviations across different dtypes.
You can run all of them with the following commands:

```bash
pytest -s tests/torch
# make sure you are in a JAX GPU environment
pytest -s tests/jax
```

The `-s` flag disables pytest's output capturing, so you see the results directly on the command line.
Each test logs its outputs to a new folder, named after the current timestamp, in the `test_outputs/` directory.

For example, each test starts with a line such as:

`Test chunkwise-triton_xl_chunk target=triton_chunkwise_xl_chunk vs. baseline=native_parallel_stablef_custbw with S=256, B=1, NH=2, DHQK=64, DHHV=128, DTYPE=torch.float32`

This test compares the chunkwise Triton kernel `triton_chunkwise_xl_chunk` against the `native_parallel_stablef_custbw` baseline and runs `triton_chunkwise_xl_chunk` in float32. It compares the errors against the baseline in the same dtype (here float32) and, if specified, also in float64.
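The idea of measuring deviations against a higher-precision baseline can be illustrated with a small sketch: run the same computation in float32 and float64 and report the largest absolute deviation. `max_abs_error` is a hypothetical helper and the matrix product only stands in for a kernel; the metrics and tolerances used by the actual test suite may differ.

```python
import torch


def max_abs_error(target: torch.Tensor, baseline: torch.Tensor) -> float:
    """Hypothetical helper: largest absolute deviation, computed in float64."""
    return (target.double() - baseline.double()).abs().max().item()


torch.manual_seed(0)
a = torch.randn(256, 64, dtype=torch.float64)
b = torch.randn(64, 128, dtype=torch.float64)

baseline = a @ b  # float64 reference result
target = a.float() @ b.float()  # the same computation run in float32

err = max_abs_error(target, baseline)
print(f"max abs error float32 vs. float64: {err:.2e}")
```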

## Citation

Our paper is currently in preparation and will be announced soon.
In the meantime, if you use this codebase or otherwise find our work valuable, please cite:

```bibtex
@article{beck:25unlocking,
  title  = {Unlocking the Power of Recurrence for Efficient xLSTM Kernels},
  author = {Maximilian Beck and Korbinian Pöppel and Sepp Hochreiter},
  note   = {Under preparation},
  year   = {2025}
}
@software{beck:24mlstmkernels,
  title  = {mLSTM Kernels: A Library for Efficient mLSTM Kernels},
  author = {Maximilian Beck and Korbinian Pöppel and Phillip Lippe},
  url    = {https://github.com/NX-AI/mlstm_kernels},
  month  = dec,
  year   = {2024}
}
```
