Metadata-Version: 2.3
Name: jax-flash-attn2
Version: 0.0.3
Summary: Flash Attention implementation with multiple backend support and sharding. This module provides a flexible implementation of Flash Attention with support for different backends (GPU, TPU, CPU) and platforms (Triton, Pallas, JAX).
License: Apache-2.0
Keywords: JAX,Deep Learning,Machine Learning,XLA
Author: Erfan Zare Chavoshi
Author-email: erfanzare810@gmail.com
Requires-Python: >=3.10,<3.14
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Dist: eformer (==0.0.15)
Requires-Dist: einops (>=0.8.0,<0.9.0)
Requires-Dist: jax (>=0.4.36)
Requires-Dist: jaxlib (>=0.4.36)
Requires-Dist: triton (>=3.2.0,<3.3.0)
Project-URL: Documentation, https://erfanzar.github.io/jax-flash-attn2
Project-URL: Homepage, https://github.com/erfanzar/jax-flash-attn2
Project-URL: Repository, https://github.com/erfanzar/jax-flash-attn2
Description-Content-Type: text/markdown

# JAX-Flash-Attention2

A flexible and efficient implementation of Flash Attention 2.0 for JAX, supporting multiple backends (GPU/TPU/CPU) and platforms (Triton/Pallas/JAX).
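
The platform/backend combinations exercised in this README are summarized below; the pure-JAX platform is the portable fallback:

| Platform | Backend shown here | Notes                 |
|----------|--------------------|-----------------------|
| `TRITON` | GPU                | NVIDIA GPUs only      |
| `PALLAS` | TPU                |                       |
| `JAX`    | CPU                | Works on any hardware |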

## Installation

```bash
pip install jax-flash-attn2
```

## Basic Usage

```python
import jax
import jax.numpy as jnp
import jax_flash_attn2 as jfa

# Initialize the FlashAttention module with desired configuration
flash_attention = jfa.FlashAttention(
 jfa.AttentionConfig(
  platform=jfa.Platform.TRITON,  # Options: TRITON, PALLAS, JAX
  backend=jfa.Backend.GPU,       # Options: GPU, TPU, CPU
 )
)

# Create sample inputs
batch_size, num_heads, seq_len, head_dim = 2, 4, 512, 64
query = jax.random.normal(jax.random.PRNGKey(0), (batch_size, num_heads, seq_len, head_dim), jnp.float16)
key = jax.random.normal(jax.random.PRNGKey(1), (batch_size, num_heads, seq_len, head_dim), jnp.float16)
value = jax.random.normal(jax.random.PRNGKey(2), (batch_size, num_heads, seq_len, head_dim), jnp.float16)

# Compute attention
output = flash_attention(
 query=query,
 key=key,
 value=value,
 causal=True  # Enable causal masking for decoder-only models
)

# output shape: (batch_size, num_heads, seq_len, head_dim)
```
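
To sanity-check the kernel, you can compare its output against a plain-`jnp` reference implementation. This is a sketch that assumes the `(batch, num_heads, seq_len, head_dim)` layout used above; expect only approximate agreement when the inputs are half precision:

```python
def reference_attention(q, k, v, causal=True):
 # Naive O(seq_len^2) softmax attention, for verification only
 q, k, v = q.astype(jnp.float32), k.astype(jnp.float32), v.astype(jnp.float32)
 scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
 if causal:
  causal_mask = jnp.tril(jnp.ones((q.shape[2], k.shape[2]), dtype=bool))
  scores = jnp.where(causal_mask, scores, jnp.finfo(scores.dtype).min)
 return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)

expected = reference_attention(query, key, value, causal=True)
print(jnp.max(jnp.abs(expected - output.astype(jnp.float32))))  # should be small
```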

## Advanced Usage

### With Attention Mask

```python
# Create an attention mask (1 = attend, 0 = mask)
attention_mask = jnp.ones((batch_size, 1, seq_len, seq_len))  # Allow full attention
# For example, prevent the first 100 query positions from attending to the last 100 key positions
attention_mask = attention_mask.at[:, :, :100, -100:].set(0)

output = flash_attention(
 query=query,
 key=key,
 value=value,
 attention_mask=attention_mask,
 causal=False  # Using explicit mask instead of causal
)
```
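
Padding masks follow the same 1 = attend convention. Here is a sketch that blocks attention to and from padded positions given per-example valid lengths (the lengths are illustrative):

```python
# Per-example valid lengths (illustrative)
lengths = jnp.array([512, 300])
valid = jnp.arange(seq_len)[None, :] < lengths[:, None]  # (batch, seq_len), True = real token
# A position pair is allowed only if both query and key positions are valid
padding_mask = (valid[:, None, :, None] & valid[:, None, None, :]).astype(jnp.float32)

output = flash_attention(
 query=query,
 key=key,
 value=value,
 attention_mask=padding_mask,
)
```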

### With Attention Bias

```python
# Create a position-dependent additive bias: closer positions get a larger bias
positions = jnp.arange(seq_len)
distance = jnp.abs(positions[:, None] - positions[None, :])  # (seq_len, seq_len)
bias = 1.0 / (1.0 + distance.astype(jnp.float32))            # decays with |i - j|
bias = jnp.broadcast_to(bias, (batch_size, 1, seq_len, seq_len))

output = flash_attention(
 query=query,
 key=key,
 value=value,
 bias=bias
)
```
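
A common real-world use of `bias` is an ALiBi-style linear distance penalty. The sketch below uses illustrative per-head slopes (not the exact slopes from the ALiBi paper) and assumes the bias broadcasts over the batch dimension:

```python
# ALiBi-style bias: each head penalizes distant positions at a different rate
positions = jnp.arange(seq_len)
distance = jnp.abs(positions[:, None] - positions[None, :]).astype(jnp.float32)
slopes = 1.0 / (2.0 ** jnp.arange(1, num_heads + 1))  # (num_heads,), illustrative
alibi_bias = -slopes[None, :, None, None] * distance[None, None, :, :]

output = flash_attention(
 query=query,
 key=key,
 value=value,
 bias=alibi_bias,
 causal=True,
)
```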

### With Dropout

```python
output = flash_attention(
 query=query,
 key=key,
 value=value,
 dropout_prob=0.1,
 dropout_seed=42,
 causal=True
)
```
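
Dropout is normally active only during training; a minimal sketch that switches it off at evaluation time:

```python
def attention_with_dropout(q, k, v, train: bool):
 return flash_attention(
  query=q,
  key=k,
  value=v,
  dropout_prob=0.1 if train else 0.0,  # disable dropout at eval time
  dropout_seed=42,
  causal=True,
 )

train_out = attention_with_dropout(query, key, value, train=True)
eval_out = attention_with_dropout(query, key, value, train=False)
```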

## Flax Modules with JFA2

Here's an example of integrating jax-flash-attn2 into a Transformer attention block implemented with Flax NNX:

```python
from functools import partial

import chex
import flax.nnx as nn
import jax
import jax.numpy as jnp

import jax_flash_attn2 as jfa


class JFAttention2(nn.Module):
 def __init__(
  self,
  hidden_size: int,
  head_dim: int,
  num_attention_heads: int,
  num_key_value_heads: int,
  dtype: jnp.dtype = jnp.float32,
  param_dtype: jnp.dtype = jnp.float32,
  precision: jax.lax.PrecisionLike = None,
  *,
  rngs: nn.Rngs = None,
 ):
  if rngs is None:
   rngs = nn.Rngs(0)
  self.dtype = dtype
  self.param_dtype = param_dtype
  self.precision = precision
  self.rngs = rngs

  self.hidden_size = hidden_size
  self.head_dim = head_dim
  self.num_attention_heads = num_attention_heads
  self.num_key_value_heads = num_key_value_heads

  self.num_key_value_groups = num_attention_heads // num_key_value_heads

  assert num_attention_heads % num_key_value_heads == 0, "num_attention_heads must be divisible by num_key_value_heads"

  linear_class = partial(
   nn.Linear,
   dtype=dtype,
   param_dtype=param_dtype,
   use_bias=False,
   kernel_init=jax.nn.initializers.normal(0.02),
   precision=precision,
   rngs=rngs,
  )
  self.q_proj = linear_class(hidden_size, num_attention_heads * self.head_dim)
  self.k_proj = linear_class(hidden_size, num_key_value_heads * self.head_dim)
  self.v_proj = linear_class(hidden_size, num_key_value_heads * self.head_dim)
  self.o_proj = linear_class(num_attention_heads * self.head_dim, hidden_size)

  config = jfa.AttentionConfig(platform=jfa.Platform.TRITON, backend=jfa.Backend.GPU)

  self.jfa2 = jfa.FlashAttention(config)

 def __call__(
  self,
  hidden_states: chex.Array,
  attention_mask: chex.Array,
  causal: bool = True,
 ) -> chex.Array:
  batch_size, sequence_length = hidden_states.shape[:2]
  query_states, key_states, value_states = (
   self.q_proj(hidden_states),
   self.k_proj(hidden_states),
   self.v_proj(hidden_states),
  )
  qshape = (
   batch_size,
   sequence_length,
   self.num_attention_heads,
   self.head_dim,
  )
  kv_shape = (
   batch_size,
   sequence_length,
   self.num_key_value_heads,
   self.head_dim,
  )
  query_states = query_states.reshape(qshape)
  key_states = key_states.reshape(kv_shape)
  value_states = value_states.reshape(kv_shape)
  # Convert the 0/1 attention mask into an additive bias (0 = attend,
  # large negative = masked) and run the fused kernel in bfloat16
  attn_output = self.jfa2.forward(
   query_states.astype(jnp.bfloat16),
   key_states.astype(jnp.bfloat16),
   value_states.astype(jnp.bfloat16),
   jnp.where(attention_mask, 0, jnp.finfo(query_states.dtype).min).astype(jnp.bfloat16),
   causal=causal,
  )
  attn_output = jnp.reshape(attn_output, (batch_size, sequence_length, -1))
  attn_output = self.o_proj(attn_output)
  return attn_output
```
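
A minimal usage sketch for the module above (sizes and mask are illustrative):

```python
layer = JFAttention2(
 hidden_size=512,
 head_dim=64,
 num_attention_heads=8,
 num_key_value_heads=2,
 rngs=nn.Rngs(0),
)
hidden = jnp.ones((2, 128, 512), dtype=jnp.float32)  # (batch, seq_len, hidden_size)
mask = jnp.ones((2, 1, 128, 128), dtype=bool)        # True = attend
out = layer(hidden, mask, causal=True)               # (batch, seq_len, hidden_size)
```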

## Platform-Specific Examples

### Using the JAX Platform

```python
jax_flash_attn = jfa.FlashAttention(
 jfa.AttentionConfig(
  platform=jfa.Platform.JAX,
  backend=jfa.Backend.CPU,  # Works on any hardware
 )
)

output = jax_flash_attn(query, key, value)
```

### Using Pallas for TPU

```python
tpu_flash_attn = jfa.FlashAttention(
 jfa.AttentionConfig(
  platform=jfa.Platform.PALLAS,
  backend=jfa.Backend.TPU,
 )
)

output = tpu_flash_attn(query, key, value)
```
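
Inputs are ordinary JAX arrays, so they can be sharded with the standard `jax.sharding` API before the call. A sketch assuming a simple data-parallel mesh over the batch axis:

```python
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))
batch_sharding = NamedSharding(mesh, P("dp", None, None, None))

q_sharded = jax.device_put(query, batch_sharding)
k_sharded = jax.device_put(key, batch_sharding)
v_sharded = jax.device_put(value, batch_sharding)

output = tpu_flash_attn(q_sharded, k_sharded, v_sharded)
```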

## Integration with JAX Transformations

```python
@jax.jit
def attention_forward(q, k, v, mask=None):
 return flash_attention(
  query=q,
  key=k,
  value=v,
  attention_mask=mask,
  causal=True
 )

# Call the JIT-compiled function
attn_out = attention_forward(query, key, value)

# With gradient computation
def loss_fn(q, k, v):
 attn_output = flash_attention(q, k, v, causal=True)
 return jnp.mean(attn_output)

grads = jax.grad(loss_fn)(query, key, value)
```
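
`jax.value_and_grad` composes the same way when you want the loss and gradients in one pass:

```python
# Loss and gradients together
loss, grads = jax.value_and_grad(loss_fn)(query, key, value)

# Gradients with respect to all three inputs
grads_qkv = jax.grad(loss_fn, argnums=(0, 1, 2))(query, key, value)
```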

## Limitations

- Triton platform is only available on NVIDIA GPUs.
- Some platform-backend combinations are not supported (see the table above); the `JAX` platform is the portable fallback.
- Custom attention masks may not be supported on every platform; converting the mask to an additive bias (as in the sketch below) is the most portable approach.
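
A 0/1 mask can be expressed as an additive bias, 0 where attention is allowed and a large negative value where it is blocked, mirroring the conversion used in the Flax example above:

```python
mask = jnp.ones((batch_size, 1, seq_len, seq_len), dtype=bool)
bias_from_mask = jnp.where(mask, 0.0, jnp.finfo(jnp.float32).min)

output = flash_attention(
 query=query,
 key=key,
 value=value,
 bias=bias_from_mask,
)
```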

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## Citation

If you use this implementation in your research, please cite:

```bibtex
@software{jax_flash_attn2,
    title = {JAX Flash Attention 2.0},
    author = {Erfan Zare Chavoshi},
    year = {2024},
    url = {https://github.com/erfanzar/jax-flash-attn2}
}
```

### Reference Citations

```bibtex
@inproceedings{dao2022flashattention,
  title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
  author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2022}
}
@inproceedings{dao2023flashattention2,
  title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
  author={Dao, Tri},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2024}
}
```

## Acknowledgments and References

1. All kernels are taken from [`EasyDeL`](https://github.com/erfanzar/Easydel).

