Metadata-Version: 2.3
Name: quantizeutils
Version: 0.1.1
Summary: Quantization utility modules to bridge torch FX and PT2E quantized models, as well as ONNX and other formats. Inspired by methods in mmdeploy, without its outdated dependencies and with some features not found in it.
License: GPL-3.0-or-later
Author: Elisa Aleman
Author-email: elisa.claire.aleman.carreon@gmail.com
Requires-Python: >=3.12
Classifier: License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Requires-Dist: ai-edge-torch (==0.4.0)
Requires-Dist: onnx (>=1.18.0,<2.0.0)
Requires-Dist: onnxruntime-gpu (>=1.22.0,<2.0.0)
Requires-Dist: tensorflow (==2.19.1)
Requires-Dist: torch (==2.6.0)
Requires-Dist: torchvision (==0.21.0)
Description-Content-Type: text/markdown

# quantizeutils


Quantization utility modules I used in my [About Quantization](https://github.com/elisa-aleman/ai_python_dev_reference/blob/main/docs/ai_development/About-Quantization.md) guide.

## Installation

```sh
# @ shell

pip install quantizeutils

# or

poetry add quantizeutils
```

## Usage

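### Pre- and post-processing FX-traced models for QAT

- **`quantizeutils.fx.utils.pre_process.propagate_split_share_qparams_pre_process()`**
    - When `torch.split()` is present in the graph, FX `prepare()` gives each split output its own observer instead of reusing the quantization parameters of the input, even though splitting does not change the value range. This function makes those outputs share the input's observer. In the example below, the two split outputs end up calling `activation_post_process_1`, and the two `torch.squeeze()` outputs end up calling `activation_post_process_7`.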

```python
import torch
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.observer import MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver
from torch.ao.quantization.fx.tracer import QuantizationTracer
from torch.fx import GraphModule
from torch.ao.quantization.fx import prepare
from torch.ao.quantization.backend_config import get_native_backend_config

from quantizeutils.fx.utils.pre_process import propagate_split_share_qparams_pre_process

class ExampleModel(torch.nn.Module):
    '''
    Expects mnist input of shape (batch,1,28,28)
    '''
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1,30,1,1)
        self.spconv1 = torch.nn.Conv2d(15,5,1,1)
        self.spconv2 = torch.nn.Conv2d(15,5,1,1)
        # self.pool = torch.nn.AdaptiveMaxPool2d((1,1))
        # not supported once quantized, so replace with manual MaxPool2d
        self.pool = torch.nn.MaxPool2d(28,28)
        self.fc = torch.nn.Linear(10, 10)
    def forward(self,x):
        x = self.conv(x)
        y, z = torch.split(x, 15, dim=1)  # split the 30 channels into two 15-channel halves
        y = self.spconv1(y)
        z = self.spconv2(z)
        x = torch.cat([y,z], dim=1)
        x = self.pool(x)
        x = torch.squeeze(x, dim=2)
        x = torch.squeeze(x, dim=2)
        x = self.fc(x)
        return x


model = ExampleModel()

# define the qconfig_mapping
qconfig_mapping = QConfigMapping().set_global(
        QConfig(
            activation=MovingAverageMinMaxObserver.with_args(
                dtype=torch.quint8,
                qscheme=torch.per_tensor_affine,
                ),
            weight=MovingAveragePerChannelMinMaxObserver.with_args(
                dtype=torch.qint8,
                qscheme=torch.per_channel_symmetric,
                ),
            )
        )

# FX trace the model
tracer = QuantizationTracer(skipped_module_names=[], skipped_module_classes=[])
graph = tracer.trace(model)
traced_fx = GraphModule(tracer.root, graph, 'ExampleModel')
print(traced_fx)
'''
ExampleModel(
  (conv): Conv2d(1, 30, kernel_size=(1, 1), stride=(1, 1))
  (spconv1): Conv2d(15, 5, kernel_size=(1, 1), stride=(1, 1))
  (spconv2): Conv2d(15, 5, kernel_size=(1, 1), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=28, stride=28, padding=0, dilation=1, ceil_mode=False)
  (fc): Linear(in_features=10, out_features=10, bias=True)
)



def forward(self, x):
    conv = self.conv(x);  x = None
    split = torch.functional.split(conv, 15, dim = 1);  conv = None
    getitem = split[0]
    getitem_1 = split[1];  split = None
    spconv1 = self.spconv1(getitem);  getitem = None
    spconv2 = self.spconv2(getitem_1);  getitem_1 = None
    cat = torch.cat([spconv1, spconv2], dim = 1);  spconv1 = spconv2 = None
    pool = self.pool(cat);  cat = None
    squeeze = torch.squeeze(pool, dim = 2);  pool = None
    squeeze_1 = torch.squeeze(squeeze, dim = 2);  squeeze = None
    fc = self.fc(squeeze_1);  squeeze_1 = None
    return fc

# To see more debug info, please use `graph_module.print_readable()`
'''

# FX prepare the quantization nodes
example_inputs = (torch.randn(1,1,28,28),)
backend_config = get_native_backend_config()
prepared_fx = prepare(
    traced_fx,
    qconfig_mapping=qconfig_mapping,
    node_name_to_scope=tracer.node_name_to_scope,
    is_qat=True, # convenient even if not QAT
    example_inputs=example_inputs,
    backend_config=backend_config,
    )
print(prepared_fx)
'''
GraphModule(
  (activation_post_process_0): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (conv): Conv2d(
    1, 30, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process_1): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (activation_post_process_2): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (activation_post_process_4): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (spconv1): Conv2d(
    15, 5, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process_3): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (spconv2): Conv2d(
    15, 5, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process_5): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (activation_post_process_6): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (pool): MaxPool2d(kernel_size=28, stride=28, padding=0, dilation=1, ceil_mode=False)
  (activation_post_process_7): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (activation_post_process_8): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (activation_post_process_9): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (fc): Linear(
    in_features=10, out_features=10, bias=True
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process_10): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
)



def forward(self, x):
    activation_post_process_0 = self.activation_post_process_0(x);  x = None
    conv = self.conv(activation_post_process_0);  activation_post_process_0 = None
    activation_post_process_1 = self.activation_post_process_1(conv);  conv = None
    split = torch.functional.split(activation_post_process_1, 15, dim = 1);  activation_post_process_1 = None
    getitem = split[0]
    activation_post_process_2 = self.activation_post_process_2(getitem);  getitem = None
    getitem_1 = split[1];  split = None
    activation_post_process_4 = self.activation_post_process_4(getitem_1);  getitem_1 = None
    spconv1 = self.spconv1(activation_post_process_2);  activation_post_process_2 = None
    activation_post_process_3 = self.activation_post_process_3(spconv1);  spconv1 = None
    spconv2 = self.spconv2(activation_post_process_4);  activation_post_process_4 = None
    activation_post_process_5 = self.activation_post_process_5(spconv2);  spconv2 = None
    cat = torch.cat([activation_post_process_3, activation_post_process_5], dim = 1);  activation_post_process_3 = activation_post_process_5 = None
    activation_post_process_6 = self.activation_post_process_6(cat);  cat = None
    pool = self.pool(activation_post_process_6);  activation_post_process_6 = None
    activation_post_process_7 = self.activation_post_process_7(pool);  pool = None
    squeeze = torch.squeeze(activation_post_process_7, dim = 2);  activation_post_process_7 = None
    activation_post_process_8 = self.activation_post_process_8(squeeze);  squeeze = None
    squeeze_1 = torch.squeeze(activation_post_process_8, dim = 2);  activation_post_process_8 = None
    activation_post_process_9 = self.activation_post_process_9(squeeze_1);  squeeze_1 = None
    fc = self.fc(activation_post_process_9);  activation_post_process_9 = None
    activation_post_process_10 = self.activation_post_process_10(fc);  fc = None
    return activation_post_process_10

# To see more debug info, please use `graph_module.print_readable()`
'''

propagate_split_share_qparams_pre_process(prepared_fx, backend_config)
print(prepared_fx)
'''
GraphModule(
  (activation_post_process_0): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (conv): Conv2d(
    1, 30, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process_1): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (spconv1): Conv2d(
    15, 5, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process_3): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (spconv2): Conv2d(
    15, 5, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process_5): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (activation_post_process_6): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (pool): MaxPool2d(kernel_size=28, stride=28, padding=0, dilation=1, ceil_mode=False)
  (activation_post_process_7): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (activation_post_process_8): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (activation_post_process_9): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  (fc): Linear(
    in_features=10, out_features=10, bias=True
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process_10): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
)



def forward(self, x):
    activation_post_process_0 = self.activation_post_process_0(x);  x = None
    conv = self.conv(activation_post_process_0);  activation_post_process_0 = None
    activation_post_process_1 = self.activation_post_process_1(conv);  conv = None
    split = torch.functional.split(activation_post_process_1, 15, dim = 1);  activation_post_process_1 = None
    getitem = split[0]
    activation_post_process_2 = self.activation_post_process_1(getitem);  getitem = None
    getitem_1 = split[1];  split = None
    activation_post_process_4 = self.activation_post_process_1(getitem_1);  getitem_1 = None
    spconv1 = self.spconv1(activation_post_process_2);  activation_post_process_2 = None
    activation_post_process_3 = self.activation_post_process_3(spconv1);  spconv1 = None
    spconv2 = self.spconv2(activation_post_process_4);  activation_post_process_4 = None
    activation_post_process_5 = self.activation_post_process_5(spconv2);  spconv2 = None
    cat = torch.cat([activation_post_process_3, activation_post_process_5], dim = 1);  activation_post_process_3 = activation_post_process_5 = None
    activation_post_process_6 = self.activation_post_process_6(cat);  cat = None
    pool = self.pool(activation_post_process_6);  activation_post_process_6 = None
    activation_post_process_7 = self.activation_post_process_7(pool);  pool = None
    squeeze = torch.squeeze(activation_post_process_7, dim = 2);  activation_post_process_7 = None
    activation_post_process_8 = self.activation_post_process_7(squeeze);  squeeze = None
    squeeze_1 = torch.squeeze(activation_post_process_8, dim = 2);  activation_post_process_8 = None
    activation_post_process_9 = self.activation_post_process_7(squeeze_1);  squeeze_1 = None
    fc = self.fc(activation_post_process_9);  activation_post_process_9 = None
    activation_post_process_10 = self.activation_post_process_10(fc);  fc = None
    return activation_post_process_10

# To see more debug info, please use `graph_module.print_readable()`
'''
```
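
One way to confirm which observers end up shared is to scan the prepared graph for observer submodules that are called from more than one node. A minimal sketch: `shared_observers` is a hypothetical helper, not part of quantizeutils, and it assumes the observers keep the `activation_post_process` naming seen in the printouts above. It only reads the standard `torch.fx` attributes `Graph.nodes`, `Node.op`, `Node.target` and `Node.name`.

```python
from collections import defaultdict
from typing import Any


def shared_observers(graph_module: Any) -> dict[str, list[str]]:
    """Map every observer submodule that is called by more than one graph node
    to the names of the nodes calling it (hypothetical helper)."""
    callers: dict[str, list[str]] = defaultdict(list)
    for node in graph_module.graph.nodes:
        # Observers inserted by prepare() are submodules named activation_post_process_*
        if node.op == "call_module" and str(node.target).startswith("activation_post_process"):
            callers[str(node.target)].append(node.name)
    # Keep only observers that are reused by several nodes
    return {target: names for target, names in callers.items() if len(names) > 1}
```

Applied to `prepared_fx` after `propagate_split_share_qparams_pre_process()`, the printed `forward` above suggests it should group the split outputs under `activation_post_process_1` and the squeeze outputs under `activation_post_process_7`. Before the call, every observer is used by a single node, so it should return an empty dict.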

- **`quantizeutils.fx.utils.pre_process.relu_clamp_backend_config_unshare_observers()`**
    - The torch native backend config (the default) makes ReLU and `torch.clamp` share the observer of their input. This widens their quantization range unnecessarily: a ReLU output, for example, keeps the negative minimum of its input even though it is never below 0, wasting part of the quantized range. This function returns a modified backend config in which these ops get their own observers; apply it to the backend config before calling `prepare()`.

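To see which ops a backend config sets up to share their input's observer, the config can be listed pattern by pattern. A minimal sketch: `shared_observer_patterns` is a hypothetical helper, not part of quantizeutils. It relies on the `BackendConfig.configs` property and the `pattern` / `observation_type` attributes of each `BackendPatternConfig` from `torch.ao.quantization.backend_config`, and skips configs that only use the complex pattern format.

```python
from typing import Any


def shared_observer_patterns(backend_config: Any) -> list[str]:
    """Names of the patterns marked as sharing the observer of their input
    (ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT). Hypothetical helper."""
    names = []
    for config in backend_config.configs:  # list of BackendPatternConfig
        # ObservationType is an Enum, so compare by member name
        observation = getattr(config.observation_type, "name", str(config.observation_type))
        if observation != "OUTPUT_SHARE_OBSERVER_WITH_INPUT" or config.pattern is None:
            continue  # other observation types, or complex-format patterns
        # Functions and classes have __name__; string patterns (method names) are kept as-is
        names.append(getattr(config.pattern, "__name__", str(config.pattern)))
    return names
```

For example, `shared_observer_patterns(get_native_backend_config())` should list entries such as `relu`, `ReLU` and `clamp`. Comparing that list with the one for the config returned by `relu_clamp_backend_config_unshare_observers()` is one way to see what the function changed.
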

```python
import torch
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.observer import MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver
from torch.ao.quantization.fx.tracer import QuantizationTracer
from torch.fx import GraphModule
from torch.ao.quantization.fx import prepare
from torch.ao.quantization.backend_config import get_native_backend_config

from quantizeutils.fx.utils.pre_process import relu_clamp_backend_config_unshare_observers

class ExampleModel(torch.nn.Module):
    '''
    Expects mnist input of shape (batch,1,28,28)
    '''
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1,32,1,1)
        self.conv2 = torch.nn.Conv2d(32,10,1,1)
        # self.pool = torch.nn.AdaptiveMaxPool2d((1,1))
        # not supported once quantized, so replace with manual MaxPool2d
        self.pool = torch.nn.MaxPool2d(28,28)
        self.fc = torch.nn.Linear(10, 10)
    def forward(self,x):
        x = self.conv1(x)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x)
        x = torch.nn.functional.relu(x)
        x = self.pool(x)
        x = torch.squeeze(x, dim=2)
        x = torch.squeeze(x, dim=2)
        x = self.fc(x)
        return x

model = ExampleModel()

# define the qconfig_mapping
qconfig_mapping = QConfigMapping().set_global(
        QConfig(
            activation=MovingAverageMinMaxObserver.with_args(
                dtype=torch.quint8,
                qscheme=torch.per_tensor_affine,
                ),
            weight=MovingAveragePerChannelMinMaxObserver.with_args(
                dtype=torch.qint8,
                qscheme=torch.per_channel_symmetric,
                ),
            )
        )

# FX trace the model
tracer = QuantizationTracer(skipped_module_names=[], skipped_module_classes=[])
graph = tracer.trace(model)
traced_fx = GraphModule(tracer.root, graph, 'ExampleModel')
print(traced_fx)
'''
ExampleModel(
  (conv1): Conv2d(1, 32, kernel_size=(1, 1), stride=(1, 1))
  (conv2): Conv2d(32, 10, kernel_size=(1, 1), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=28, stride=28, padding=0, dilation=1, ceil_mode=False)
  (fc): Linear(in_features=10, out_features=10, bias=True)
)



def forward(self, x):
    conv1 = self.conv1(x);  x = None
    relu = torch.nn.functional.relu(conv1, inplace = False);  conv1 = None
    conv2 = self.conv2(relu);  relu = None
    relu_1 = torch.nn.functional.relu(conv2, inplace = False);  conv2 = None
    pool = self.pool(relu_1);  relu_1 = None
    squeeze = torch.squeeze(pool, dim = 2);  pool = None
    squeeze_1 = torch.squeeze(squeeze, dim = 2);  squeeze = None
    fc = self.fc(squeeze_1);  squeeze_1 = None
    return fc

# To see more debug info, please use `graph_module.print_readable()`
'''

# Fusion is skipped on purpose, to show what happens to ReLU nodes that are not fused with the preceding op
# FX prepare the quantization nodes
example_inputs = (torch.randn(1,1,28,28),)
backend_config = get_native_backend_config()
backend_config = relu_clamp_backend_config_unshare_observers(backend_config)

prepared_fx = prepare(
    traced_fx,
    qconfig_mapping=qconfig_mapping,
    node_name_to_scope=tracer.node_name_to_scope,
    is_qat=True, # convenient even if not QAT
    example_inputs=example_inputs,
    backend_config=backend_config,
    )
# run one batch through the model so the observers record their min/max ranges
prepared_fx(example_inputs[0])
print(prepared_fx)
'''
GraphModule(
  (activation_post_process_0): MovingAverageMinMaxObserver(min_val=-2.873756170272827, max_val=3.512624740600586)
  (conv1): Conv2d(
    1, 32, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.2244,  0.1366,  0.5535,  0.4781,  0.9987, -0.8585,  0.2149,  0.2896,
              -0.3051, -0.2861,  0.4546,  0.6375, -0.9563, -0.2443,  0.9397,  0.4525,
              -0.8703,  0.0118,  0.7989,  0.4656,  0.8642, -0.8372, -0.6900, -0.2179,
              -0.9575,  0.1994,  0.9602,  0.8782,  0.1776, -0.9443, -0.2989,  0.3896]), max_val=tensor([-0.2244,  0.1366,  0.5535,  0.4781,  0.9987, -0.8585,  0.2149,  0.2896,
              -0.3051, -0.2861,  0.4546,  0.6375, -0.9563, -0.2443,  0.9397,  0.4525,
              -0.8703,  0.0118,  0.7989,  0.4656,  0.8642, -0.8372, -0.6900, -0.2179,
              -0.9575,  0.1994,  0.9602,  0.8782,  0.1776, -0.9443, -0.2989,  0.3896])
    )
  )
  (activation_post_process_1): MovingAverageMinMaxObserver(min_val=-3.461930274963379, max_val=3.918088436126709)
  (activation_post_process_2): MovingAverageMinMaxObserver(min_val=0.0, max_val=3.918088436126709)
  (conv2): Conv2d(
    32, 10, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.1602, -0.1757, -0.1724, -0.1654, -0.1460, -0.1729, -0.1722, -0.1635,
              -0.1721, -0.1758]), max_val=tensor([0.1669, 0.1637, 0.1765, 0.1631, 0.1765, 0.1700, 0.1761, 0.1572, 0.1726,
              0.1767])
    )
  )
  (activation_post_process_3): MovingAverageMinMaxObserver(min_val=-2.4537484645843506, max_val=1.3463588953018188)
  (activation_post_process_4): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.3463588953018188)
  (pool): MaxPool2d(kernel_size=28, stride=28, padding=0, dilation=1, ceil_mode=False)
  (activation_post_process_5): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.3463588953018188)
  (activation_post_process_6): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.3463588953018188)
  (activation_post_process_7): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.3463588953018188)
  (fc): Linear(
    in_features=10, out_features=10, bias=True
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.2888, -0.2640, -0.3119, -0.3112, -0.2534, -0.3156, -0.2953, -0.3095,
              -0.2894, -0.2766]), max_val=tensor([0.2896, 0.2684, 0.2644, 0.2437, 0.3056, 0.3141, 0.3016, 0.2660, 0.2929,
              0.3111])
    )
  )
  (activation_post_process_8): MovingAverageMinMaxObserver(min_val=-0.6559169888496399, max_val=1.2133853435516357)
)



def forward(self, x):
    activation_post_process_0 = self.activation_post_process_0(x);  x = None
    conv1 = self.conv1(activation_post_process_0);  activation_post_process_0 = None
    activation_post_process_1 = self.activation_post_process_1(conv1);  conv1 = None
    relu = torch.nn.functional.relu(activation_post_process_1, inplace = False);  activation_post_process_1 = None
    activation_post_process_2 = self.activation_post_process_2(relu);  relu = None
    conv2 = self.conv2(activation_post_process_2);  activation_post_process_2 = None
    activation_post_process_3 = self.activation_post_process_3(conv2);  conv2 = None
    relu_1 = torch.nn.functional.relu(activation_post_process_3, inplace = False);  activation_post_process_3 = None
    activation_post_process_4 = self.activation_post_process_4(relu_1);  relu_1 = None
    pool = self.pool(activation_post_process_4);  activation_post_process_4 = None
    activation_post_process_5 = self.activation_post_process_5(pool);  pool = None
    squeeze = torch.squeeze(activation_post_process_5, dim = 2);  activation_post_process_5 = None
    activation_post_process_6 = self.activation_post_process_6(squeeze);  squeeze = None
    squeeze_1 = torch.squeeze(activation_post_process_6, dim = 2);  activation_post_process_6 = None
    activation_post_process_7 = self.activation_post_process_7(squeeze_1);  squeeze_1 = None
    fc = self.fc(activation_post_process_7);  activation_post_process_7 = None
    activation_post_process_8 = self.activation_post_process_8(fc);  fc = None
    return activation_post_process_8

# To see more debug info, please use `graph_module.print_readable()`
'''
```
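
The cost of sharing can be estimated from the observer values printed above. The sketch below re-implements the per-tensor affine `(scale, zero_point)` computation in plain Python, modelled on PyTorch's `MinMaxObserver` (which widens the observed range to include 0) and assuming the default `quint8` range of 0 to 255 without `reduce_range`. It is an illustration, not quantizeutils code; `affine_qparams` and `codes_spanned` are hypothetical names.

```python
FLOAT32_EPS = 1.1920928955078125e-07  # same value as torch.finfo(torch.float32).eps


def affine_qparams(min_val: float, max_val: float, qmin: int = 0, qmax: int = 255) -> tuple[float, int]:
    """Per-tensor affine (scale, zero_point): widen the range to include 0,
    then map it linearly onto the integer range [qmin, qmax]."""
    min_neg = min(min_val, 0.0)
    max_pos = max(max_val, 0.0)
    scale = max((max_pos - min_neg) / (qmax - qmin), FLOAT32_EPS)
    zero_point = qmin - round(min_neg / scale)
    return scale, max(qmin, min(qmax, zero_point))


def codes_spanned(scale: float, zero_point: int, lo: float, hi: float, qmin: int = 0, qmax: int = 255) -> int:
    """Number of integer codes available to real values in [lo, hi]."""
    q_lo = max(qmin, min(qmax, round(lo / scale) + zero_point))
    q_hi = max(qmin, min(qmax, round(hi / scale) + zero_point))
    return q_hi - q_lo + 1


# Ranges copied from the printout above: conv1 output vs. ReLU output
conv1_range = (-3.461930274963379, 3.918088436126709)
relu_range = (0.0, 3.918088436126709)

shared = affine_qparams(*conv1_range)   # ReLU reuses conv1's observer
unshared = affine_qparams(*relu_range)  # ReLU has its own observer

for label, (scale, zp) in [("shared", shared), ("unshared", unshared)]:
    codes = codes_spanned(scale, zp, *relu_range)
    print(f"{label:>8}: scale={scale:.6f} zero_point={zp} codes for the ReLU output={codes}")
```

With these numbers the shared observer gives a zero point of 120, so the non-negative ReLU output can only use 136 of the 256 codes, with a step size roughly 1.9 times larger than with its own observer.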


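- **`quantizeutils.fx.utils.post_process.fuse_qat_bn_post_process()`**
    - Folds the batch normalization still held inside QAT fused modules (for example `ConvBnReLU2d`) into the preceding weights, leaving modules without BN (for example `ConvReLU2d`), so the model is ready to export to ONNX. The example below uses an eager-mode QAT model, `prepared_model` (note the `QuantStub`/`DeQuantStub`), prepared elsewhere.
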
```python
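import torch
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.observer import MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver

from quantizeutils.fx.utils.post_process import fuse_qat_bn_post_process

# `prepared_model` is an eager-mode QAT model (note the QuantStub/DeQuantStub)
# that was fused and prepared elsewhere, e.g. with torch.ao.quantization.prepare_qat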

print(prepared_model)
'''
ExampleModel(
  (quant): QuantStub(
    (activation_post_process): MovingAverageMinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (dequant): DeQuantStub()
  (conv1): ConvBnReLU2d(
    1, 32, kernel_size=(1, 1), stride=(1, 1)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([ 0.9805,  0.9878, -0.9778, -0.9910, -0.9607,  1.0020,  0.9854, -0.9932,
               0.9920, -1.0192,  0.9573,  0.9959,  1.0661, -0.9587, -0.4509,  0.8841,
               0.9185,  0.9696, -0.9722, -1.0099,  1.0207, -1.0131,  0.9228,  0.9731,
              -0.8032, -0.9803, -0.9691,  1.0209, -0.9520,  1.0132,  1.0179, -1.0628],
             device='cuda:0'), max_val=tensor([ 0.9805,  0.9878, -0.9778, -0.9910, -0.9607,  1.0020,  0.9854, -0.9932,
               0.9920, -1.0192,  0.9573,  0.9959,  1.0661, -0.9587, -0.4509,  0.8841,
               0.9185,  0.9696, -0.9722, -1.0099,  1.0207, -1.0131,  0.9228,  0.9731,
              -0.8032, -0.9803, -0.9691,  1.0209, -0.9520,  1.0132,  1.0179, -1.0628],
             device='cuda:0')
    )
    (activation_post_process): MovingAverageMinMaxObserver(min_val=0.0, max_val=3.486288547515869)
  )
  (bn1): Identity()
  (act1): Identity()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): ConvBnReLU2d(
    32, 64, kernel_size=(1, 1), stride=(1, 1)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-2.1960, -0.2812, -1.7402, -0.8100, -3.3399, -1.0657, -0.1177, -0.6993,
              -0.2671, -4.9234, -0.5415, -3.2926, -0.1593, -0.3427, -0.4231, -3.2403,
              -0.7386, -2.6828, -5.1171, -3.6965, -0.2139, -9.6442, -0.3108, -2.1643,
              -0.7892, -2.2819, -2.0034, -0.4834, -0.3995, -4.6650, -1.5611, -1.3696,
              -3.9522, -0.3022, -1.5632, -0.4557, -5.8931, -0.4400, -1.2626, -0.5098,
              -3.3187, -0.3899, -0.4554, -2.8338, -0.4487, -0.2008, -1.1349, -5.3991,
              -4.4046, -0.4110, -1.2552, -0.5631, -0.3380, -2.7315, -2.2920, -2.1396,
              -0.4084, -6.2974, -1.1824, -0.1679, -2.1181, -0.8331, -1.1392, -5.4736],
             device='cuda:0'), max_val=tensor([2.1027, 0.2706, 2.1971, 0.7181, 2.7690, 0.9632, 0.1437, 0.6662, 0.2961,
              3.6605, 0.5448, 2.8305, 0.1597, 0.3670, 0.4135, 4.4315, 0.7591, 3.2193,
              4.7957, 4.5160, 0.2183, 8.5344, 0.3707, 2.4349, 1.2237, 2.4649, 2.5588,
              0.4111, 0.5481, 5.3878, 1.8818, 0.9792, 3.3811, 0.3097, 0.9306, 0.5589,
              5.2913, 0.3896, 0.8146, 0.8410, 2.5406, 0.2423, 0.4662, 3.2510, 0.4216,
              0.2331, 0.7295, 5.5472, 3.2039, 0.3767, 1.2058, 0.4847, 0.2246, 2.3556,
              2.5913, 1.7738, 0.4303, 6.8638, 0.9290, 0.2386, 2.2515, 1.1236, 0.7875,
              4.1441], device='cuda:0')
    )
    (activation_post_process): MovingAverageMinMaxObserver(min_val=0.0, max_val=7.873152732849121)
  )
  (bn2): Identity()
  (act2): Identity()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fl1): Flatten(start_dim=1, end_dim=-1)
  (fc1): LinearReLU(
    in_features=3136, out_features=10, bias=True
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.1110, -0.1682, -0.1074, -0.0716, -0.1544, -0.1686, -0.1512, -0.2297,
              -0.1243, -0.1656], device='cuda:0'), max_val=tensor([0.1568, 0.1229, 0.1314, 0.0929, 0.1936, 0.1594, 0.1684, 0.1468, 0.1751,
              0.1768], device='cuda:0')
    )
    (activation_post_process): MovingAverageMinMaxObserver(min_val=0.0, max_val=52.048946380615234)
  )
  (fc1act): Identity()
  (fc2): Linear(
    in_features=10, out_features=10, bias=True
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.4732, -0.7787, -0.3168, -0.1518, -0.6355, -0.3758, -0.5703, -0.3268,
              -0.3168, -0.5046], device='cuda:0'), max_val=tensor([0.3203, 0.4310, 0.3799, 0.3980, 0.3096, 0.2491, 0.4260, 0.2836, 0.3210,
              0.3214], device='cuda:0')
    )
    (activation_post_process): MovingAverageMinMaxObserver(min_val=-21.06275177001953, max_val=15.43133544921875)
  )
)
'''

qconfig = QConfig(
    activation=MovingAverageMinMaxObserver.with_args(
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine,
        ),
    weight=MovingAveragePerChannelMinMaxObserver.with_args(
        dtype=torch.qint8,
        qscheme=torch.per_channel_symmetric,
        ),
    )
device='cuda:0'

fuse_qat_bn_post_process(
    prepared_model,
    qconfig,
    device,
    update_weight_with_fakequant=False,
    keep_w_fake_quant=True)
print(prepared_model)
'''
ExampleModel(
  (quant): QuantStub(
    (activation_post_process): MovingAverageMinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (dequant): DeQuantStub()
  (conv1): ConvReLU2d(
    1, 32, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([ 0.9773,  0.9846, -0.9746, -0.9878, -0.9576,  0.9987,  0.9822, -0.9899,
               0.9888, -1.0159,  0.9542,  0.9927,  1.0627, -0.9556, -0.4495,  0.8815,
               0.9156,  0.9664, -0.9691, -1.0067,  1.0174, -1.0098,  0.9199,  0.9699,
              -0.8006, -0.9771, -0.9660,  1.0176, -0.9489,  1.0099,  1.0146, -1.0593],
             device='cuda:0'), max_val=tensor([ 0.9773,  0.9846, -0.9746, -0.9878, -0.9576,  0.9987,  0.9822, -0.9899,
               0.9888, -1.0159,  0.9542,  0.9927,  1.0627, -0.9556, -0.4495,  0.8815,
               0.9156,  0.9664, -0.9691, -1.0067,  1.0174, -1.0098,  0.9199,  0.9699,
              -0.8006, -0.9771, -0.9660,  1.0176, -0.9489,  1.0099,  1.0146, -1.0593],
             device='cuda:0')
    )
  )
  (bn1): Identity()
  (act1): Identity()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): ConvReLU2d(
    32, 64, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-2.2004, -0.2815, -1.7259, -0.8108, -3.3225, -1.0663, -0.1178, -0.6998,
              -0.2673, -4.9116, -0.5416, -3.3015, -0.1594, -0.3430, -0.4228, -3.2384,
              -0.7388, -2.6793, -5.0133, -3.6836, -0.2141, -9.6123, -0.3111, -2.1644,
              -0.7847, -2.2629, -1.9944, -0.4836, -0.3984, -4.6541, -1.5561, -1.3597,
              -3.9317, -0.3023, -1.5552, -0.4559, -5.8850, -0.4401, -1.2523, -0.5099,
              -3.3166, -0.3899, -0.4554, -2.8402, -0.4490, -0.2010, -1.1246, -5.3832,
              -4.3736, -0.4115, -1.2560, -0.5640, -0.3383, -2.7325, -2.2894, -2.1422,
              -0.4086, -6.1985, -1.1782, -0.1680, -2.1212, -0.8303, -1.1334, -5.4791],
             device='cuda:0'), max_val=tensor([2.1069, 0.2708, 2.1790, 0.7187, 2.7546, 0.9638, 0.1438, 0.6667, 0.2964,
              3.6516, 0.5449, 2.8382, 0.1598, 0.3673, 0.4133, 4.4289, 0.7594, 3.2151,
              4.6985, 4.5002, 0.2185, 8.5062, 0.3711, 2.4350, 1.2167, 2.4443, 2.5474,
              0.4113, 0.5465, 5.3751, 1.8757, 0.9722, 3.3635, 0.3099, 0.9258, 0.5592,
              5.2839, 0.3897, 0.8079, 0.8413, 2.5389, 0.2423, 0.4662, 3.2584, 0.4219,
              0.2332, 0.7228, 5.5310, 3.1814, 0.3771, 1.2065, 0.4855, 0.2248, 2.3564,
              2.5883, 1.7760, 0.4306, 6.7560, 0.9257, 0.2388, 2.2547, 1.1199, 0.7835,
              4.1483], device='cuda:0')
    )
  )
  (bn2): Identity()
  (act2): Identity()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fl1): Flatten(start_dim=1, end_dim=-1)
  (fc1): LinearReLU(
    in_features=3136, out_features=10, bias=True
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.1110, -0.1682, -0.1074, -0.0716, -0.1544, -0.1686, -0.1512, -0.2297,
              -0.1243, -0.1656], device='cuda:0'), max_val=tensor([0.1568, 0.1229, 0.1314, 0.0929, 0.1936, 0.1594, 0.1684, 0.1468, 0.1751,
              0.1768], device='cuda:0')
    )
  )
  (fc1act): Identity()
  (fc2): Linear(
    in_features=10, out_features=10, bias=True
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.4732, -0.7787, -0.3168, -0.1518, -0.6355, -0.3758, -0.5703, -0.3268,
              -0.3168, -0.5046], device='cuda:0'), max_val=tensor([0.3203, 0.4310, 0.3799, 0.3980, 0.3096, 0.2491, 0.4260, 0.2836, 0.3210,
              0.3214], device='cuda:0')
    )
  )
)
'''
```
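
Removing the `bn` submodule without changing the model's output amounts to folding the BatchNorm statistics into the preceding layer: per output channel, `W' = W * gamma / sqrt(var + eps)` and `b' = (b - mean) * gamma / sqrt(var + eps) + beta`. Below is a plain-Python sketch of that standard folding formula on a toy 2-input, 2-output layer with made-up statistics. It is not the quantizeutils implementation, which works on PyTorch QAT modules and may differ in detail; `fold_bn`, `linear` and `batch_norm` are hypothetical names.

```python
import math


def fold_bn(weight, bias, gamma, beta, running_mean, running_var, eps=1e-5):
    """Fold inference-mode BatchNorm into the preceding linear / 1x1 conv layer.

    weight holds one row of input weights per output channel, bias one value
    per output channel. Returns (folded_weight, folded_bias)."""
    folded_w, folded_b = [], []
    for row, b, g, bt, m, v in zip(weight, bias, gamma, beta, running_mean, running_var):
        factor = g / math.sqrt(v + eps)  # per-channel BN scale
        folded_w.append([w * factor for w in row])
        folded_b.append((b - m) * factor + bt)
    return folded_w, folded_b


def linear(weight, bias, x):
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batch_norm(y, gamma, beta, mean, var, eps):
    return [(yi - m) / math.sqrt(v + eps) * g + bt for yi, g, bt, m, v in zip(y, gamma, beta, mean, var)]


# Toy layer and made-up BN statistics
weight, bias = [[1.0, 2.0], [0.5, -1.0]], [0.5, 0.0]
gamma, beta, mean, var, eps = [2.0, 3.0], [0.25, -0.5], [0.5, 1.0], [3.0, 0.0], 1.0
x = [1.0, 2.0]

reference = batch_norm(linear(weight, bias, x), gamma, beta, mean, var, eps)
folded_weight, folded_bias = fold_bn(weight, bias, gamma, beta, mean, var, eps)
folded = linear(folded_weight, folded_bias, x)
print(reference, folded)
```

Both paths produce the same output, which is why the `bn` submodule can be dropped once its statistics are frozen.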


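- **`quantizeutils.fx.utils.post_process.merge_relu_clamp_to_qparams_post_process()`**
    - Patterns such as Conv+ReLU are fused automatically when converting with the native backend, but remain separate nodes when exported to ONNX or other backends. This function merges ReLU and `torch.clamp` nodes into the activation observer of the previous node, so the clamping is expressed through that observer's quantization range (q_min and q_max) instead of a separate node. The example below continues from the `prepared_fx` of the `relu_clamp_backend_config_unshare_observers()` example.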

```python
from quantizeutils.fx.utils.post_process import merge_relu_clamp_to_qparams_post_process

merge_relu_clamp_to_qparams_post_process(prepared_fx)
print(prepared_fx)
'''
GraphModule(
  (activation_post_process_0): MovingAverageMinMaxObserver(min_val=-2.873756170272827, max_val=3.512624740600586)
  (conv1): Conv2d(
    1, 32, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.2244,  0.1366,  0.5535,  0.4781,  0.9987, -0.8585,  0.2149,  0.2896,
              -0.3051, -0.2861,  0.4546,  0.6375, -0.9563, -0.2443,  0.9397,  0.4525,
              -0.8703,  0.0118,  0.7989,  0.4656,  0.8642, -0.8372, -0.6900, -0.2179,
              -0.9575,  0.1994,  0.9602,  0.8782,  0.1776, -0.9443, -0.2989,  0.3896]), max_val=tensor([-0.2244,  0.1366,  0.5535,  0.4781,  0.9987, -0.8585,  0.2149,  0.2896,
              -0.3051, -0.2861,  0.4546,  0.6375, -0.9563, -0.2443,  0.9397,  0.4525,
              -0.8703,  0.0118,  0.7989,  0.4656,  0.8642, -0.8372, -0.6900, -0.2179,
              -0.9575,  0.1994,  0.9602,  0.8782,  0.1776, -0.9443, -0.2989,  0.3896])
    )
  )
  (activation_post_process_2): MovingAverageMinMaxObserver(min_val=0.0, max_val=3.918088436126709)
  (conv2): Conv2d(
    32, 10, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.1602, -0.1757, -0.1724, -0.1654, -0.1460, -0.1729, -0.1722, -0.1635,
              -0.1721, -0.1758]), max_val=tensor([0.1669, 0.1637, 0.1765, 0.1631, 0.1765, 0.1700, 0.1761, 0.1572, 0.1726,
              0.1767])
    )
  )
  (activation_post_process_4): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.3463588953018188)
  (pool): MaxPool2d(kernel_size=28, stride=28, padding=0, dilation=1, ceil_mode=False)
  (activation_post_process_5): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.3463588953018188)
  (activation_post_process_6): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.3463588953018188)
  (activation_post_process_7): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.3463588953018188)
  (fc): Linear(
    in_features=10, out_features=10, bias=True
    (weight_fake_quant): MovingAveragePerChannelMinMaxObserver(
      min_val=tensor([-0.2888, -0.2640, -0.3119, -0.3112, -0.2534, -0.3156, -0.2953, -0.3095,
              -0.2894, -0.2766]), max_val=tensor([0.2896, 0.2684, 0.2644, 0.2437, 0.3056, 0.3141, 0.3016, 0.2660, 0.2929,
              0.3111])
    )
  )
  (activation_post_process_8): MovingAverageMinMaxObserver(min_val=-0.6559169888496399, max_val=1.2133853435516357)
)



def forward(self, x):
    activation_post_process_0 = self.activation_post_process_0(x);  x = None
    conv1 = self.conv1(activation_post_process_0);  activation_post_process_0 = None
    activation_post_process_2 = self.activation_post_process_2(conv1);  conv1 = None
    conv2 = self.conv2(activation_post_process_2);  activation_post_process_2 = None
    activation_post_process_4 = self.activation_post_process_4(conv2);  conv2 = None
    pool = self.pool(activation_post_process_4);  activation_post_process_4 = None
    activation_post_process_5 = self.activation_post_process_5(pool);  pool = None
    squeeze = torch.squeeze(activation_post_process_5, dim = 2);  activation_post_process_5 = None
    activation_post_process_6 = self.activation_post_process_6(squeeze);  squeeze = None
    squeeze_1 = torch.squeeze(activation_post_process_6, dim = 2);  activation_post_process_6 = None
    activation_post_process_7 = self.activation_post_process_7(squeeze_1);  squeeze_1 = None
    fc = self.fc(activation_post_process_7);  activation_post_process_7 = None
    activation_post_process_8 = self.activation_post_process_8(fc);  fc = None
    return activation_post_process_8

# To see more debug info, please use `graph_module.print_readable()`
'''

```
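
This illustrates why an observer can stand in for a ReLU: fake quantization clamps to the integer range, so once the observer range starts at 0 (zero point 0), every negative input maps to code 0, i.e. to 0.0. Below is a plain-Python sketch of per-tensor affine fake quantization (the quantize-then-dequantize formula that `torch.fake_quantize_per_tensor_affine` applies), compared with a ReLU6-style clamp. The `[0, 6]` range and the sample inputs are illustrative assumptions, and this is not the quantizeutils implementation.

```python
def fake_quantize(x: float, scale: float, zero_point: int, qmin: int = 0, qmax: int = 255) -> float:
    """Per-tensor affine fake quantization of one value: quantize, clamp, dequantize."""
    q = round(x / scale) + zero_point
    q = max(qmin, min(qmax, q))  # clamping to the integer range does the ReLU work
    return (q - zero_point) * scale


# An observer that only saw values in [0, 6] yields zero_point 0 and scale 6 / 255
scale, zero_point = 6.0 / 255, 0
for x in [-2.0, -0.1, 0.0, 1.2, 3.3, 6.0, 9.0]:
    clamped = min(max(x, 0.0), 6.0)
    print(f"{x:5.1f} -> fake_quant {fake_quantize(x, scale, zero_point):.4f}, clamp {clamped:.4f}")
```

Every fake-quantized value stays within half a quantization step of the clamped value, so a separate ReLU node adds nothing beyond rounding once the observer range is `[0, max]`.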


### FX Backend for AIEdgeTorch export

[AIEdgeTorch](https://github.com/google-ai-edge/ai-edge-torch) is a powerful (but still volatile) tool to convert PyTorch models to TensorFlow Lite through PT2E. Since some models can currently only be quantized in FX graph mode, I wrote an FX backend configuration that may make FX-quantized models exportable with ai_edge_torch. More in my [About Quantization](https://github.com/elisa-aleman/ai_python_dev_reference/blob/main/docs/ai_development/About-Quantization.md) guide.

The backend configuration lives in `quantizeutils.fx.backend_config.ai_edge_backend`.
