Metadata-Version: 2.3
Name: ejkernel
Version: 0.0.80.3
Summary: Accelerate, Optimize performance with streamlined training and serving options with JAX.
Keywords: Deep Learning,Machine Learning,JAX,CUDA,XLA,Triton,Pallas
Author: Erfan Zare Chavoshi
Author-email: Erfan Zare Chavoshi <Erfanzare810@gmail.com>
License: Apache-2.0
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Requires-Dist: beartype>=0.22.2
Requires-Dist: chex>=0.1.91
Requires-Dist: einops>=0.8.1
Requires-Dist: jax~=0.10.0
Requires-Dist: jaxlib~=0.10.0
Requires-Dist: jaxtyping>=0.3.2
Requires-Dist: pydantic>=2.11.10
Requires-Dist: tqdm>=4.67.1
Requires-Dist: jax[cuda13]~=0.10.0 ; extra == 'cuda'
Requires-Dist: jax-cuda13-plugin[with-cuda]~=0.10.0 ; extra == 'cuda'
Requires-Dist: jax-cuda13-pjrt~=0.10.0 ; extra == 'cuda'
Requires-Dist: triton==3.6.0 ; extra == 'cuda'
Requires-Dist: nvidia-cutlass-dsl[cu13]==4.5.2 ; extra == 'cuda'
Requires-Dist: nvidia-cutlass-dsl-libs-base==4.5.2 ; extra == 'cuda'
Requires-Dist: nvidia-cutlass-dsl-libs-cu13==4.5.2 ; extra == 'cuda'
Requires-Dist: jax-tvm-ffi==0.1.2 ; extra == 'cuda'
Requires-Dist: apache-tvm-ffi==0.1.8.post2 ; extra == 'cuda'
Requires-Dist: eformer ; extra == 'dev'
Requires-Dist: xprof>=2.20.6 ; extra == 'profile'
Requires-Dist: tb-nightly>=2.21.0a20250820 ; extra == 'profile'
Requires-Dist: xprof-nightly>=2.21.6a20250820 ; extra == 'profile'
Requires-Dist: tilelang==0.1.9 ; extra == 'tilelang'
Requires-Dist: apache-tvm-ffi==0.1.8.post2 ; extra == 'tilelang'
Requires-Dist: nvidia-cudnn-cu12==9.12.0.46 ; extra == 'tilelang'
Requires-Dist: jax[tpu]~=0.10.0 ; extra == 'tpu'
Requires-Dist: triton==3.6.0 ; extra == 'triton'
Requires-Python: >=3.11, <3.14
Project-URL: Documentation, https://ejkernel.readthedocs.io/en/latest/
Project-URL: Homepage, https://github.com/erfanzar/ejkernel
Project-URL: Repository, https://github.com/erfanzar/ejkernel
Provides-Extra: cuda
Provides-Extra: dev
Provides-Extra: profile
Provides-Extra: tilelang
Provides-Extra: tpu
Provides-Extra: triton
Description-Content-Type: text/markdown

# ejKernel: High-Performance JAX Kernels for Deep Learning

> _"The best optimization is the one you don't have to think about."_

[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/)
[![JAX](https://img.shields.io/badge/JAX-0.9.0+-orange.svg)](https://github.com/google/jax)
[![Documentation](https://img.shields.io/badge/docs-readthedocs-green.svg)](https://ejkernel.readthedocs.io/en/latest/)

ejKernel is a production-grade kernel library for JAX that provides highly optimized implementations of deep learning operations with automatic multi-backend support. The library features a sophisticated configuration management system with autotuning, comprehensive type safety, and seamless execution across GPUs, TPUs, and CPUs.

> [!NOTE]
> eJkernel contains **no AI-generated code**. All kernels, modules, and core logic are manually designed and implemented by human developers.
> AI tooling (Opus 4.5) is used **exclusively for documentation**, which may therefore contain minor inaccuracies. There is no “vibe coding” or automated code generation anywhere in the codebase.

## Table of Contents

- [Key Features](#key-features)
- [Installation](#installation)
- [Quick Start](#quick-start)
- [Architecture Overview](#architecture-overview)
- [Supported Operations](#supported-operations)
- [Advanced Usage](#advanced-usage)
- [Development](#development)
- [Testing](#testing)
- [Contributing](#contributing)
- [Citation](#citation)
- [License](#license)

## Key Features

### Intelligent Kernel Management

- **7-Tier Configuration System**: Override → Overlay → Memory Cache → Persistent Cache → Autotune → Heuristics → Error
- **Automatic Platform Detection**: Seamlessly selects optimal implementation based on hardware
- **Priority-Based Registry**: Multi-backend support with intelligent fallback mechanisms
- **Device Fingerprinting**: Hardware-specific configuration caching for optimal performance

### State-of-the-Art Operations

- **30+ Deep Learning Operations**: Flash Attention v2, Flash MLA, Ring Attention, Page Attention, Block Sparse, GLA, Lightning, Gated Delta Rule, Quantized MatMul, State Space Models (Mamba), RWKV (v4/v6/v7), and more
- **Memory Efficiency**: Custom VJP implementations with O(N) memory complexity for attention
- **Distributed Support**: Full shard_map integration for model and data parallelism
- **Mixed Precision**: Comprehensive dtype support with automatic gradient conversion

### Production-Ready Infrastructure

- **Type Safety**: Full jaxtyping annotations with runtime validation via beartype
- **Comprehensive Testing**: Cross-backend validation, performance benchmarks, integration tests
- **Atomic Persistence**: Thread-safe configuration storage with automatic optimization
- **Profiling Integration**: Built-in support for JAX profiling and performance monitoring

## Installation

### Basic Installation

```bash
pip install ejkernel
```

### Platform-Specific Installation

```bash
# GPU Support (CUDA)
pip install ejkernel[cuda]

# TPU Support
pip install ejkernel[tpu]

# Development Installation
git clone https://github.com/erfanzar/ejkernel.git
cd ejkernel
pip install -e ".[dev]"
```

### Dependencies

- Python 3.11-3.13
- JAX >= 0.9.0
- Triton == 3.6.0 (for GPU)
- nvidia-cutlass-dsl >= 4.4.0 (optional, for CuTe DSL kernels)
- jax-tvm-ffi == 0.1.2 (optional, for CuTe TVM-FFI primitive path)
- jaxtyping >= 0.3.2
- beartype >= 0.22.2
- pydantic >= 2.11.10

## Quick Start

### Simple API with Automatic Optimization

```python
import jax.numpy as jnp
from ejkernel.modules import flash_attention

# Basic usage - automatic configuration selection
output = flash_attention(
    query, key, value,
    causal=True,
    dropout_prob=0.1
)

# With advanced features
output = flash_attention(
    query, key, value,
    causal=True,
    sliding_window=128,        # Local attention window
    logits_soft_cap=30.0,      # Gemma-2 style soft capping
    attention_mask=mask,       # Custom attention pattern
)
```

### Custom Configuration

```python
from ejkernel.modules import FlashAttentionConfig
from ejkernel.ops.utils.datacarrier import FwdParams, BwdParams

# Create optimized configuration
config = FlashAttentionConfig(
    fwd_params=FwdParams(
        q_blocksize=256,
        kv_blocksize=256,
        num_warps=8,
        num_stages=2
    ),
    bwd_params=BwdParams(
        q_blocksize=128,
        kv_blocksize=128,
        num_warps=4
    ),
    platform="triton",  # Force specific backend
    backend="gpu"
)

output = flash_attention(query, key, value, cfg=config)
```

### Direct Kernel Registry Access

```python
from ejkernel import kernel_registry, Platform, Backend

# Get specific implementation
kernel = kernel_registry.get(
    algorithm="flash_attention",
    platform=Platform.TRITON,
    backend=Backend.GPU
)

# Direct execution
output = kernel(query, key, value, causal=True)
```

### Distributed Execution

```python
import jax
from jax.sharding import Mesh, PartitionSpec as P
from ejkernel.modules import flash_attention

# Setup mesh for distributed execution
devices = jax.devices()
mesh = Mesh(devices, axis_names=("data", "model"))

# Run distributed attention
output = flash_attention(
    query, key, value,
    causal=True,
    mesh=mesh,
    in_specs=(P("data", None), P("data", None), P("data", None)),
    out_specs=P("data", None)
)
```

## Architecture Overview

### System Design

ejKernel employs a sophisticated layered architecture that separates concerns while maintaining high performance:

```md
┌─────────────────────────────────────────────────────┐
│ Public API (modules/) │
│ Simple functions with sensible defaults │
├─────────────────────────────────────────────────────┤
│ Operations Layer (ops/) │
│ Configuration management, autotuning, caching │
├─────────────────────────────────────────────────────┤
│ Kernel Registry (kernels/) │
│ Platform routing, signature validation │
├─────────────────────────────────────────────────────┤
│ Backend Implementations (kernels/\_\*) │
│ Triton, CuTe, Pallas, XLA, CUDA kernels │
└─────────────────────────────────────────────────────┘
```

### Core Components

#### Kernel Registry

The registry provides automatic platform-specific kernel selection:

```python
@kernel_registry.register("my_operation", Platform.TRITON, Backend.GPU, priority=100)
def my_operation_gpu(x, y):
    # GPU-optimized implementation
    pass

@kernel_registry.register("my_operation", Platform.XLA, Backend.ANY, priority=50)
def my_operation_fallback(x, y):
    # Universal fallback
    pass

# Automatic selection based on available hardware
impl = kernel_registry.get("my_operation")
```

#### Configuration Management

Multi-tier configuration system with intelligent fallback:

```python
class ConfigSelectorChain:
    """
    Selection hierarchy:
    1. Override - Explicit user configuration
    2. Overlay - Temporary context overrides
    3. Memory Cache - In-memory lookup
    4. Persistent Cache - Disk-based storage
    5. Autotune - Performance benchmarking
    6. Heuristics - Intelligent defaults
    7. Error - Clear failure message
    """
```

#### Custom VJP System

All performance-critical kernels implement memory-efficient gradients:

```python
@jax.custom_vjp
def kernel_with_custom_grad(inputs):
    return forward(inputs)

def kernel_fwd(inputs):
    output, residuals = forward_with_residuals(inputs)
    return output, residuals

def kernel_bwd(residuals, grad_output):
    return efficient_backward(residuals, grad_output)

kernel_with_custom_grad.defvjp(kernel_fwd, kernel_bwd)
```

## Supported Operations

### Attention Mechanisms

| Algorithm                        | Description                      | Memory | Key Features                                                            |
| -------------------------------- | -------------------------------- | ------ | ----------------------------------------------------------------------- |
| **Flash Attention v2**           | Memory-efficient exact attention | O(N)   | Causal masking, dropout, sliding windows, soft capping                  |
| **Ring Attention**               | Distributed sequence parallelism | O(N/P) | Ultra-long sequences, communication overlap, XLA single-device fallback |
| **Page Attention**               | KV-cache optimized inference     | O(N)   | Block-wise memory, continuous batching                                  |
| **Block Sparse Attention**       | Configurable sparse patterns     | O(N√N) | Local+global, custom patterns                                           |
| **GLA**                          | Gated Linear Attention           | O(N)   | Linear complexity, gated updates                                        |
| **Lightning Attention**          | Layer-dependent decay            | O(N)   | Exponential moving average                                              |
| **MLA**                          | Multi-head Latent Attention      | O(N)   | Compressed KV representation                                            |
| **Ragged Page Attention v2**     | Variable-length paged attention  | O(N)   | Ragged sequences with page caching                                      |
| **Ragged Page Attention v3**     | Enhanced ragged page attention   | O(N)   | Attention sinks support, improved handling                              |
| **Ragged Decode Attention**      | Variable-length decoding         | O(N)   | Efficient batched inference                                             |
| **Gated Delta Rule (GDR)**       | Gated delta-rule recurrence      | O(N)   | Chunked + recurrent + single-step, custom VJP, Qwen3Next               |
| **Ragged GDR**                   | Packed continuous-batching GDR   | O(N)   | Variable-length sequences, Pallas TPU decode (3.6x speedup)            |
| **Kernel Delta Attention**       | Delta-rule linear attention      | O(N)   | Linear complexity, delta updates, decay control                         |
| **Unified Attention**            | vLLM-style paged attention       | O(N)   | Segmented 3D decode kernel                                              |
| **Prefill Page Attention**       | Page attention prefill phase     | O(N)   | Separate prefill handling                                               |
| **Decode Attention**             | Single-token decode attention    | O(N)   | Optimized single-step decoding                                          |
| **Chunked Prefill Paged Decode** | Combined prefill + decode        | O(N)   | Chunked prefill with paged KV cache decode                              |
| **Flash MLA**                    | Multi-head Latent Attention      | O(N)   | Low-rank KV compression, memory-efficient inference                     |
| **Scaled Dot-Product Attention** | Standard attention               | O(N²)  | Basic reference implementation                                          |

### Recurrent Linear Attention (RWKV)

| Operation      | Description                           | Key Features                                       |
| -------------- | ------------------------------------- | -------------------------------------------------- |
| **RWKV-4**     | Time-mix recurrence                   | Numerically stable (α,β,ε) state, O(N) memory      |
| **RWKV-6**     | Multi-head linear attention           | Variable-length packing, reverse mode, O(N) memory |
| **RWKV-7**     | DPLR (Diagonal + Low-Rank) recurrence | (a,b) parameterization, state-space inspired       |
| **RWKV-7 Mul** | Multiplicative RWKV-7 variant         | (kk,a) reparameterization for optimized kernels    |

### Other Operations

| Operation             | Description                                                 | Use Case                  |
| --------------------- | ----------------------------------------------------------- | ------------------------- |
| **Grouped MatMul**    | Efficient batched matrix operations                         | Expert models, MoE        |
| **Grouped MatMul v2** | Enhanced with shard_map support                             | Distributed expert models |
| **Mean Pooling**      | Variable-length sequence aggregation                        | Sentence embeddings       |
| **Recurrent**         | Optimized RNN/LSTM/GRU operations                           | Sequential modeling       |
| **Native Sparse**     | Block-sparse matrix computations                            | Sparse attention patterns |
| **Quantized MatMul**  | Multi-mode quantized matmul (affine, NF4, MXFP4/8, NVFP4/8) | Low-bit inference         |

### State Space Models

| Operation          | Description      | Key Features                                                               |
| ------------------ | ---------------- | -------------------------------------------------------------------------- |
| **State Space v1** | Mamba1-style SSM | 2D A matrix, separate dt_proj, custom VJP for memory efficiency            |
| **State Space v2** | Mamba2-style SSM | Per-head scalar A, n_groups for parameter grouping, optional gated RMSNorm |

### Platform Support Matrix

| Operation                    | Triton (GPU) | CUTE (GPU) | CUDA (GPU) | Pallas (TPU) | XLA (Universal) |
| ---------------------------- | ------------ | ---------- | ---------- | ------------ | --------------- |
| Flash Attention v2           | ✅           | ✅         | ✅         | ✅           | ✅              |
| Flash MLA                    | ✅           | -          | -          | -            | ✅              |
| Ring Attention               | ✅           | -          | -          | ✅           | ✅              |
| Page Attention               | ✅           | -          | -          | ✅           | ✅              |
| Block Sparse Attention       | ✅           | -          | ✅         | ✅           | ✅              |
| Decode Attention             | ✅           | -          | -          | -            | ✅              |
| Chunked Prefill Paged Decode | ✅           | ✅         | -          | -            | ✅              |
| Ragged Page Attention v2     | ✅           | -          | -          | ✅           | ✅              |
| Ragged Page Attention v3     | ✅           | -          | ✅         | ✅           | ✅              |
| Ragged Decode Attention      | ✅           | -          | -          | ✅           | ✅              |
| GLA                          | ✅           | -          | -          | -            | ✅              |
| Lightning Attention          | ✅           | -          | -          | -            | ✅              |
| Recurrent                    | ✅           | -          | -          | -            | ✅              |
| Mean Pooling                 | ✅           | -          | -          | -            | ✅              |
| Grouped MatMul               | -            | -          | -          | ✅           | ✅              |
| Grouped MatMul v2            | -            | -          | -          | ✅           | -               |
| Native Sparse Attention      | ✅           | -          | -          | -            | ✅              |
| Quantized MatMul             | ✅           | ✅         | ✅         | ✅           | ✅              |
| Gated Delta Rule             | -            | -          | -          | ✅           | ✅              |
| Ragged Gated Delta Rule     | -            | -          | -          | ✅           | ✅              |
| Kernel Delta Attention       | -            | -          | -          | -            | ✅              |
| Unified Attention            | ✅           | ✅         | ✅         | -            | ✅              |
| Prefill Page Attention       | -            | -          | -          | ✅           | ✅              |
| Scaled Dot-Product Attention | -            | -          | -          | -            | ✅              |
| State Space v1               | -            | -          | -          | -            | ✅              |
| State Space v2               | -            | -          | -          | -            | ✅              |
| RWKV-4                       | ✅           | -          | -          | -            | ✅              |
| RWKV-6                       | ✅           | -          | -          | -            | ✅              |
| RWKV-7                       | ✅           | -          | -          | -            | ✅              |
| RWKV-7 Mul                   | ✅           | -          | -          | -            | ✅              |

✅ = Production ready | - = Not available

\* CuTe backend uses TVM-FFI primitive path with fused kernels. \* Quantized MatMul on TPU uses hybrid dispatch (packed Pallas / predecode / XLA fallback). \* Distributed matmul ops (`all_gather_matmul`, `reduce_scatter_matmul`) intentionally do not perform runtime fallback between distributed backends; choose `platform`/`cfg.platform` explicitly.

## Advanced Usage

### Page Attention for KV-Cache Inference

```python
from ejkernel.modules import page_attention, PageAttentionConfig

# Configure paged attention for inference
config = PageAttentionConfig(
    platform="auto",
    backend="gpu"
)

output = page_attention(
    query=q,
    key_cache=k_cache,
    value_cache=v_cache,
    block_table=block_table,
    cache_seqlens=cache_seqlens,
    cfg=config
)
```

### Ragged Page Attention for Variable-Length Batches

```python
from ejkernel.modules import ragged_page_attention_v3, RaggedPageAttentionv3Config

# For variable-length sequences with attention sinks
config = RaggedPageAttentionv3Config(
    platform="pallas",
    backend="tpu"
)

output = ragged_page_attention_v3(
    query=q,
    key_pages=k_pages,
    value_pages=v_pages,
    lengths=seq_lengths,
    page_indices=page_indices,
    cfg=config
)
```

### Performance Optimization

```python
# Force autotuning for optimal configuration
import os
os.environ["EJKERNEL_AUTOTUNE_POLICY"] = "autotune"
os.environ["EJKERNEL_LOG_AUTOTUNE"] = "1"

# Enable profiling
os.environ["EJKERNEL_OPS_STAMP"] = "json"  # Detailed metadata
os.environ["EJKERNEL_OPS_RECORD"] = "1"    # Record invocations
```

### Custom Kernel Development

```python
from ejkernel.ops.core import Kernel
from ejkernel.modules.operations.configs import BaseOperationConfig
from dataclasses import dataclass

@dataclass
class MyConfig(BaseOperationConfig):
    param1: int = 128
    param2: float = 0.1

class MyKernel(Kernel[MyConfig, Array]):
    def __init__(self):
        super().__init__(op_id="my_kernel")

    def run(self, x, cfg: MyConfig):
        impl = kernel_registry.get("my_kernel", cfg.platform)
        return impl(x, param1=cfg.param1, param2=cfg.param2)

    def heuristic_cfg(self, inv):
        # Return default configuration
        return MyConfig(param1=256)

    def candidate_cfgs(self, inv):
        # Return autotuning candidates
        return [MyConfig(param1=p) for p in [64, 128, 256]]
```

### Integration with Flax Models

```python
import flax.linen as nn
from ejkernel.modules import flash_attention

class TransformerBlock(nn.Module):
    num_heads: int = 8
    head_dim: int = 64

    @nn.compact
    def __call__(self, x, mask=None):
        # Project to Q, K, V
        q = nn.Dense(self.num_heads * self.head_dim)(x)
        k = nn.Dense(self.num_heads * self.head_dim)(x)
        v = nn.Dense(self.num_heads * self.head_dim)(x)

        # Reshape for attention
        shape = (x.shape[0], x.shape[1], self.num_heads, self.head_dim)
        q, k, v = map(lambda t: t.reshape(shape), (q, k, v))

        # Apply ejKernel Flash Attention
        attn_output = flash_attention(
            q, k, v,
            causal=True,
            attention_mask=mask
        )

        # Project output
        return nn.Dense(x.shape[-1])(attn_output.reshape(x.shape))
```

## Development

### Setting Up Development Environment

```bash
# Clone repository
git clone https://github.com/erfanzar/ejkernel.git
cd ejkernel

# Create virtual environment
python -m venv .venv
source .venv/bin/activate  # On Windows: .venv\Scripts\activate

# Install in development mode
pip install -e ".[dev]"

# Install pre-commit hooks
pre-commit install
```

### Code Style

The project uses:

- **black** for code formatting (line length: 121)
- **ruff** for linting
- **mypy/pyright** for type checking
- **pre-commit** for automated checks

### Adding New Kernels

1. **Implement the kernel** in appropriate backend directory:

```python
# ejkernel/kernels/_triton/my_kernel/_interface.py
@kernel_registry.register("my_kernel", Platform.TRITON, Backend.GPU)
def my_kernel_triton(x, config):
    # Implementation
    pass
```

1. **Create module wrapper**:

```python
# ejkernel/modules/operations/my_kernel.py
class MyKernel(Kernel[MyKernelConfig, Array]):
    # Module implementation
    pass
```

1. **Add tests**:

```python
# test/kernels/_triton/test_my_kernel.py
class TestMyKernel(unittest.TestCase):
    # Test implementation
    pass
```

1. **Update documentation**

## Testing

### Running Tests

```bash
# Run all tests
pytest test/

# Platform-specific tests
pytest test/kernels/_xla/          # XLA implementations
pytest test/kernels/_triton/       # Triton implementations
pytest test/kernels/_pallas/       # Pallas implementations

# Specific test patterns
pytest -k "flash_attention"
pytest --verbose --failfast

# Module operations tests
pytest test/modules/operations
```

### Test Categories

- **Unit Tests**: Individual component testing
- **Integration Tests**: End-to-end workflows
- **Comparison Tests**: Cross-backend consistency
- **Performance Tests**: Regression detection

## Benchmarking

Run benchmarks to compare performance across backends:

```bash
# General attention benchmarks
python benchmarks/benchmark_attention.py

# Flash attention benchmarks
python benchmarks/benchmark_flash_attention.py

# Ragged page attention benchmarks
python benchmarks/benchmark_ragged_page_attention_v3.py
```

## Contributing

We welcome contributions!

### Priority Areas

- TPU/Pallas implementations for existing algorithms
- CUDA native kernels for maximum performance
- New attention mechanisms from recent papers
- Performance optimizations and kernel fusion
- Documentation and examples

### Contribution Process

1. Fork the repository
1. Create a feature branch
1. Implement your changes with tests
1. Ensure all tests pass
1. Submit a pull request

## Documentation

Comprehensive documentation available at [ejkernel.readthedocs.io](https://ejkernel.readthedocs.io/en/latest/)

- **[API Reference](https://ejkernel.readthedocs.io/en/latest/api/)**: Complete API documentation
- **[Tutorials](https://ejkernel.readthedocs.io/en/latest/tutorials/)**: Step-by-step guides
- **[Architecture](https://ejkernel.readthedocs.io/en/latest/architecture/)**: Design documentation
- **[Benchmarks](https://ejkernel.readthedocs.io/en/latest/benchmarks/)**: Performance analysis

## Citation

If you use ejKernel in your research, please cite:

```bibtex
@software{ejkernel2025,
  author = {Erfan Zare Chavoshi},
  title = {ejKernel: High-Performance JAX Kernels for Deep Learning},
  year = {2025},
  url = {https://github.com/erfanzar/ejkernel},
  note = {Production-grade kernel library with multi-backend support}
}
```

## License

ejKernel is licensed under the Apache License 2.0. See [LICENSE](LICENSE) for details.

## Acknowledgments

ejKernel builds upon excellent work from:

- [JAX](https://github.com/google/jax) - Composable transformations of Python+NumPy programs
- [Triton](https://github.com/openai/triton) - GPU kernel programming language
- [Pallas](https://github.com/google/jax/tree/main/jax/experimental/pallas) - JAX kernel language
- [Flash Attention](https://github.com/Dao-AILab/flash-attention) - Memory-efficient attention
- [EasyDeL](https://github.com/erfanzar/EasyDeL) - Parent framework for JAX deep learning

## Community

- **GitHub Issues**: [Bug reports and feature requests](https://github.com/erfanzar/ejkernel/issues)
- **Discussions**: [Community forum](https://github.com/erfanzar/ejkernel/discussions)
- **Email**: <Erfanzare810@gmail.com>

---

**ejKernel** - Production-grade kernels for JAX deep learning
