Metadata-Version: 2.4
Name: RaanA
Version: 0.1.0
Summary: RaanA quantization algorithm
Author-email: Yongyi Yang <yongyi@umich.edu>
License: MIT
Project-URL: Homepage, https://github.com/FFTYYY/RaanA
Project-URL: Repository, https://github.com/FFTYYY/RaanA
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch
Requires-Dist: transformers>=4.46.2
Requires-Dist: datasets>=3.1.0
Requires-Dist: tqdm>=4.67.0
Requires-Dist: scipy>=1.14.1
Requires-Dist: scikit-learn>=1.5.2
Requires-Dist: pybind11>=2.6.0
Dynamic: license-file

# RaanA: A Fast, Flexible, and Data-Efficient Post-Training Quantization Algorithm

This repo contains the implementation of the paper "RaanA: A Fast, Flexible, and Data-Efficient Post-Training Quantization Algorithm".

## Installation

1. To install from PyPI: `pip install raana`
2. To install from source:
    ```sh
    pip install build 
    git clone https://github.com/FFTYYY/RaanA
    cd RaanA
    python -m build 
    ```
    The generated `.whl` files will be in `dist/`.

## Quick Start

```python
from transformers import AutoTokenizer, LlamaForCausalLM
from raana import quantize, zeroshot_calibration, trick_centralize, trick_norm_row

# initialize your model
model     = LlamaForCausalLM.from_pretrained(...)
tokenizer = AutoTokenizer.from_pretrained(...)

# quantization
quantized_model = quantize(
    model,                                              # the model to quantize
    b_candidates    = list(range(1,9)),                 # candidate bit-widths
    calibrate_data  = zeroshot_calibration(tokenizer),  # use zero-shot calibration
    avg_bits        = 3.3,                              # average number of bits
)["model"]

# evaluate your model (`evaluate` stands for your own evaluation routine)
evaluate(quantized_model)
```

### A Complete Example
To run the example quantization of Llama 2 on WikiText-2 (and reproduce the results reported in the paper):
```sh
pip install raana
git clone https://github.com/FFTYYY/RaanA
cd RaanA/examples
python wikitext2.py --model=meta-llama/Llama-2-7b-hf --avgbits=3.3
```
See `examples/wikitext2.py` for a complete example usage.

## Detailed Usage

The entry point of `raana` is `raana.quantize`.

```python
from torch.nn   import Module
from torch      import Tensor
from typing     import Callable

from raana.task_adaptor     import TaskAdaptor
from raana.rotations        import RandomRotation, default_rotation
from raana.tricks           import Trick
from raana.select_layers    import default_linear_selector
from raana.quantized_linear import default_weightbias_extractor, default_matmul
from raana.tricks           import trick_centralize, trick_norm_col

quantize(
    model               : Module,
    b_candidates        : list[float],
    calibrate_data      : TaskAdaptor,
    avg_bits            : float,
    linear_selector     : Callable[[Module], bool]          = default_linear_selector,
    rotation_maker      : Callable[[], RandomRotation]      = default_rotation,
    trick_makers        : list[Callable[[], Trick]]         = [trick_centralize, trick_norm_col],
    weightbias_extractor: Callable[[Module], tuple[Tensor, Tensor | None]] = default_weightbias_extractor,
    matmul              : Callable[[Tensor, Tensor, Tensor, int], Tensor]  = default_matmul,
)
```

### Required Arguments

**`model: torch.nn.Module`**
- The pytorch model to be quantized.
  
**`b_candidates: list[float]`**
- Candidate bit-widths allowed for each layer.
- May also contain fractional values between 0 and 1; if so, sub-1-bit quantization is enabled.
- Example: `[0.5, 0.75, 1, 2, 3, 4]`.

**`calibrate_data: raana.task_adaptor.TaskAdaptor`**
- The calibration data used for quantization.
- For language modeling tasks, you can use `raana.task_adaptor.LMAdaptor( data: list[str], tokenizer: PreTrainedTokenizer)`.
- For zero-shot calibration in language modeling, use `raana.zeroshot_calibration(tokenizer)`.
- For other tasks, you can write your own `TaskAdaptor` class.

**`avg_bits: float`**

- Target average number of bits per quantized linear layer. The quantizer will search for the optimal bit allocation under this constraint.
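To make the budget constraint concrete, here is a toy exhaustive search. It is not `raana`'s search procedure: it assumes, purely for illustration, that the average is weighted by each layer's parameter count and that a layer's loss at `b` bits can be modeled as `sensitivity * 4**(-b)`. The helper `allocate_bits` and both modeling choices are hypothetical.

```python
from itertools import product


def allocate_bits(param_counts, sensitivities, b_candidates, avg_bits):
    """Toy bit allocation by exhaustive search (exponential in the number of layers).

    Hypothetical model: minimizes sum_i sensitivity_i * 4**(-b_i) subject to the
    parameter-weighted average bit-width not exceeding avg_bits.
    Returns (bits, loss), or (None, inf) if no allocation fits the budget.
    """
    budget = avg_bits * sum(param_counts)
    best_bits, best_loss = None, float("inf")
    for bits in product(b_candidates, repeat=len(param_counts)):
        used = sum(n * b for n, b in zip(param_counts, bits))
        if used > budget + 1e-9:  # tolerance for fractional bit-widths such as 0.5
            continue
        loss = sum(s * 4.0 ** (-b) for s, b in zip(sensitivities, bits))
        if loss < best_loss:
            best_bits, best_loss = list(bits), loss
    return best_bits, best_loss


# Two equally sized layers; the first is 16x more sensitive, so it receives more bits.
print(allocate_bits([100, 100], [16.0, 1.0], [2, 3, 4], avg_bits=3.0))  # → ([4, 2], 0.125)
```

The search trades bits between layers: lowering the less sensitive layer to 2 bits frees enough budget to give the sensitive layer 4 bits while keeping the average at 3.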

### Optional Arguments

**`linear_selector: Callable[[torch.nn.Module], bool]`**
- A function that decides which sub-modules to quantize: it receives a module and returns `True` if that module should be quantized.
- Different model implementations use different types of linear modules (e.g. some models use `nn.Linear` while others use `nn.Conv1d`), so this function lets the user specify which linear modules to quantize.
- Default: select all `torch.nn.Linear` layers.
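Because `linear_selector` only receives the module object, one way to select layers by name is to precompute the set of modules to keep and return a closure over it. A minimal sketch, assuming you want every `nn.Linear` except those whose qualified name ends in a given suffix; `make_name_selector` and the `"lm_head"` default are hypothetical choices, not part of `raana`:

```python
import torch.nn as nn


def make_name_selector(model: nn.Module, exclude_suffixes=("lm_head",)):
    """Return a Callable[[nn.Module], bool] that selects nn.Linear layers by name.

    Hypothetical helper: layers whose qualified name ends with one of
    `exclude_suffixes` are skipped (e.g. an output head kept in full precision).
    """
    selected = {
        id(module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and not any(name.endswith(s) for s in exclude_suffixes)
    }
    # Identity-based lookup: the selector only sees the module itself, not its name.
    return lambda module: id(module) in selected
```

It would then be passed as `quantize(model, ..., linear_selector=make_name_selector(model))`.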


**`rotation_maker: Callable[[], raana.rotations.RandomRotation]`**
- A function that constructs a random rotation.
- This parameter lets users plug in their own random rotation implementation.
- The default implementation is the randomized Hadamard transform, as described in the paper.
- The Hadamard transform in the default implementation is simply a matrix multiplication with the Hadamard matrix generated by `scipy.linalg.hadamard`. To minimize the dependencies of `raana`, the default implementation does not use any fast GPU Hadamard kernels.
- We encourage users to install a fast Hadamard kernel, such as the one from [Dao-AILab](https://github.com/Dao-AILab/fast-hadamard-transform), and pass it to `raana` through this parameter:
    ```python
    from torch import Tensor
    from fast_hadamard_transform import hadamard_transform
    from raana.rotations import PiecewiseHadamard

    def hadamard(X: Tensor):
        # normalize by sqrt(d) to make it an orthonormal operator.
        return hadamard_transform(X) / (X.size(-1) ** 0.5) 

    quantize(
        ..., 
        rotation_maker = lambda: PiecewiseHadamard( hadamard = hadamard )
    )
    ```
- Default: randomized Hadamard transformation. Uses `scipy.linalg.hadamard` as the implementation of Hadamard Transformation.
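If the CUDA kernel is not available, a pure-PyTorch fast Walsh-Hadamard transform computes the same quantity as the normalized `hadamard` wrapper above in O(d log d) per row. This is our own sketch, not part of `raana`; it assumes the transform is applied along the last dimension and that this dimension is a power of two (we have not checked how `PiecewiseHadamard` splits dimensions).

```python
import torch
from torch import Tensor


def hadamard_torch(X: Tensor) -> Tensor:
    """Orthonormal Walsh-Hadamard transform along the last dimension.

    Equals X @ H / sqrt(d), where H is the Sylvester-ordered Hadamard matrix
    (the ordering produced by scipy.linalg.hadamard). Requires d to be a power of two.
    """
    d = X.size(-1)
    if (d & (d - 1)) != 0:
        raise ValueError(f"last dimension must be a power of two, got {d}")
    shape = X.shape
    Y = X.reshape(-1, d)
    h = 1
    while h < d:
        # View each row as blocks of size 2h and apply the butterfly to their two halves.
        Y = Y.reshape(-1, d // (2 * h), 2, h)
        a, b = Y[:, :, 0, :], Y[:, :, 1, :]
        Y = torch.stack((a + b, a - b), dim=2)
        h *= 2
    return Y.reshape(shape) / d ** 0.5
```

It could be passed the same way as the CUDA kernel, e.g. `rotation_maker = lambda: PiecewiseHadamard(hadamard = hadamard_torch)`. Because the normalized Hadamard matrix is symmetric and orthogonal, applying the transform twice returns the input, which is a quick sanity check.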

**`trick_makers: list[Callable[[], raana.tricks.Trick]]`**
- List of functions that construct tricks. See the paper for the definition of a "trick".
- Four tricks are currently implemented: `trick_centralize`, `trick_pca`, `trick_norm_row`, `trick_norm_col`.
- Default: `[trick_centralize, trick_norm_col]`.

**`weightbias_extractor: Callable[[nn.Module], tuple[Tensor, Tensor | None]]`**
- A function that extracts the weight and bias from a linear module and converts them to the standard shape.
- It should return the `weight` and `bias` of the given layer: `weight` should be a tensor of shape `(d_in, d_out)`, and `bias` should be either `None` or a tensor of shape `(d_out,)`.
- Default: `lambda layer: (layer.weight.t().data, layer.bias.data)`
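For layers without a bias, or for modules that store their weight as `(d_in, d_out)` (such as the GPT-2-style `Conv1D` in `transformers`), a custom extractor may be needed. A minimal sketch; `safe_weightbias_extractor` is a hypothetical name, and it assumes that any module other than `nn.Linear` already stores its weight as `(d_in, d_out)`:

```python
from __future__ import annotations

import torch.nn as nn
from torch import Tensor


def safe_weightbias_extractor(layer: nn.Module) -> tuple[Tensor, Tensor | None]:
    """Return (weight, bias) with weight of shape (d_in, d_out).

    nn.Linear stores its weight as (d_out, d_in), so it is transposed.
    Any other module is assumed (hypothetically) to already use (d_in, d_out).
    A missing bias yields None instead of raising an AttributeError.
    """
    weight = layer.weight.detach()
    if isinstance(layer, nn.Linear):
        weight = weight.t()
    bias = getattr(layer, "bias", None)
    return weight, (bias.detach() if bias is not None else None)
```

It would be passed as `quantize(..., weightbias_extractor = safe_weightbias_extractor)`.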

**`matmul: Callable[[Tensor, Tensor, Tensor, int], Tensor]`**
- A function that performs low-precision matrix multiplication.
- Since there is no official implementation of low-precision uint-float matrix multiplication, and we want to minimize the dependencies of `raana`, we leave the implementation of matrix multiplication to the user.
- The input parameters are `X, qW, rescale, B`: `X` is a float tensor, `qW` is a `B`-bit unsigned integer tensor, and `rescale` is a float tensor of per-column scales. The return value should equal `(X@qW - ((2**B-1)/2.*X.sum(dim=-1)).view(-1,1)) * rescale.view(1,-1)`.
- Default: cast everything to float32 and perform a standard matrix multiplication. The default implementation is shown below.
    ```python
    import torch as tc

    def default_matmul(X: tc.Tensor, qW: tc.Tensor, rescale: tc.Tensor, B: int):
        dtype   = X.dtype
        X       = X.to(tc.float32)
        rescale = rescale.to(tc.float32).view(1, -1)                      # per-column scale, shape (1, d_out)
        q_bias  = (float(2 ** B - 1) / 2. * X.sum(dim = -1)).view(-1, 1)  # offset correction for the unsigned qW
        Z = (X @ qW.to(tc.float32)) * rescale
        Z = Z - q_bias * rescale
        return Z.to(dtype)                                                 # cast back to the input dtype
    ```
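The required output can be rewritten using the identity `X@qW - c*X.sum(-1, keepdim=True) == X @ (qW - c)` with `c = (2**B-1)/2`, so `matmul` is equivalent to multiplying `X` by the dequantized weight `(qW - c) * rescale` (column-wise scaling). The sketch below, a hypothetical `dequant_matmul`, implements the contract this way and can serve as a reference when testing a faster kernel:

```python
import torch
from torch import Tensor


def dequant_matmul(X: Tensor, qW: Tensor, rescale: Tensor, B: int) -> Tensor:
    """Reference `matmul` that dequantizes the weight before multiplying.

    Uses (X @ qW - c * X.sum(-1, keepdim=True)) * rescale == X @ ((qW - c) * rescale),
    with c = (2**B - 1) / 2: column j of qW is shifted by c and scaled by rescale[j].
    """
    c = (2 ** B - 1) / 2.0
    W_hat = (qW.to(torch.float32) - c) * rescale.to(torch.float32).view(1, -1)
    return (X.to(torch.float32) @ W_hat).to(X.dtype)
```

This materializes the full float32 weight, so it saves no memory; its value is as a readable reference against which a genuinely low-bit kernel can be checked.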

### Returns
```python
{
    "model" : torch.nn.Module,  # quantized model
    "bits"  : list[float],      # allocated bit-width per layer
    "losses": list[float]       # calibration loss per calibration data
}
```
