Skip to content

quantization fusion in convert_fx leaves a wrong dequantize node when fusing a multiple-input node #91688

Closed
@FindDefinition

Description

@FindDefinition

🐛 Describe the bug

I need to fuse a residual node into a single layer:

y = some_quantized_tensor
x = some_quantized_tensor
x = bn(conv(x))
x = x + y
x = relu(x)

to

x = conv_with_add_relu(x, y)

When I use following test code:

from collections import OrderedDict
import contextlib
import operator
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.ao.quantization.fx.match_utils import (
    MatchAllNode,
)
from torch.ao.quantization.quantize_fx import (
    fuse_fx,
)
from torch.ao.quantization.backend_config import (
    get_qnnpack_backend_config,
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    ObservationType,
    get_fbgemm_backend_config
)
from torch.ao.quantization import get_default_qconfig_mapping

import torch.ao.quantization.quantize_fx as qfx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(3)
        self.iden = nn.Identity()

    def forward(self, x):
        y = x
        y = self.iden(x)
        x = self.conv(x)
        x = self.bn(x)
        x = torch.add(x, y)
        x = self.relu(x)
        return x


# Instantiate in eval mode: the fusion/quantization below targets inference.
m = M().eval()

def fuse_conv_bn_relu(is_qat, relu, add_pattern):
    """Fuser method: collapse (relu, (add, (bn, conv), extra)) into the conv module.

    The conv module is returned as the replacement for the whole pattern.
    """
    _add, (_bn, conv_module), _extra = add_pattern
    return conv_module

def conv_bn_res_relu_root_node_getter(pattern):
    """Return the conv node as the root of the matched (relu, (add, (bn, conv), extra)) pattern."""
    _relu, (_add, (_bn, conv_node), _extra) = pattern
    return conv_node

def conv_bn_res_relu_extra_inputs_getter(pattern):
    """Return the extra (non-root) inputs of the matched pattern.

    Inputs of the root node (conv) are copied over to the fused node
    automatically, so only the residual input of the add is listed here.
    """
    _relu, (_add, _bn_pattern, residual_input) = pattern
    return [residual_input]
# int8 dtype configuration matching fbgemm weighted ops:
# quint8 activations in/out, qint8 weights, float bias.
fbgemm_weighted_op_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)
# for pytorch <= 1.13
# conv_bn_res_relu_config = BackendPatternConfig((nn.ReLU, (operator.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
#     .set_fuser_method(fuse_conv_bn_relu) \
#     ._set_root_node_getter(conv_bn_res_relu_root_node_getter) \
#     ._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter)
# for pytorch master
# Register the custom fused pattern on top of the stock fbgemm backend config.
conv_bn_res_relu_config = BackendPatternConfig() \
    ._set_pattern_complex_format((nn.ReLU, (torch.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
    .set_fuser_method(fuse_conv_bn_relu) \
    ._set_root_node_getter(conv_bn_res_relu_root_node_getter) \
    ._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter) \
    .set_dtype_configs([fbgemm_weighted_op_int8_dtype_config])  # fix: set_dtype_configs expects a list of DTypeConfig

backend_config = get_fbgemm_backend_config().set_backend_pattern_config(conv_bn_res_relu_config)
# m = fuse_fx(m, backend_config=backend_config)
qmapping = get_default_qconfig_mapping()
# fix: prepare_fx expects example_inputs matching forward's signature; an
# empty tuple gives shape propagation nothing to work with.
example_inputs = (torch.randn(1, 3, 8, 8),)
prepared_model = qfx.prepare_fx(m, qmapping, example_inputs, backend_config=backend_config)
converted_model = qfx.convert_fx(prepared_model, qconfig_mapping=qmapping, backend_config=backend_config)

converted_model.print_readable()

I found that the second input of conv_add in the converted model is a dequantized tensor, which causes an error in my project:

class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        iden_input_scale_0 = self.iden_input_scale_0
        iden_input_zero_point_0 = self.iden_input_zero_point_0
        quantize_per_tensor = torch.quantize_per_tensor(x, iden_input_scale_0, iden_input_zero_point_0, torch.quint8);  x = iden_input_scale_0 = iden_input_zero_point_0 = None
        
        # File: /home/yy/anaconda3/envs/cpudev/lib/python3.8/site-packages/torch/ao/quantization/fx/tracer.py:103, code: return super().call_module(m, forward, args, kwargs)
        iden = self.iden(quantize_per_tensor)
        
        # No stacktrace found for following nodes
        dequantize_1 = iden.dequantize();  iden = None
        
        # File: /home/yy/anaconda3/envs/cpudev/lib/python3.8/site-packages/torch/ao/quantization/fx/tracer.py:103, code: return super().call_module(m, forward, args, kwargs)
        conv = self.conv(quantize_per_tensor, dequantize_1);  quantize_per_tensor = dequantize_1 = None
        
        # No stacktrace found for following nodes
        dequantize_2 = conv.dequantize();  conv = None
        return dequantize_2

The dequantize_1 in the GraphModule code should be a quantized tensor.

Versions

Collecting environment information...
PyTorch version: 2.0.0.dev20221231
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 11.0.0 (https://github.com/llvm/llvm-project.git 0160ad802e899c2922bc9b29564080c22eb0908c)
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.8.15 (default, Nov 24 2022, 15:19:38) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-56-generic-x86_64-with-glibc2.17
Is CUDA available: False
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3080 Laptop GPU
Nvidia driver version: 525.60.11
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.0.dev20221231
[pip3] torchaudio==2.0.0.dev20221231
[pip3] torchvision==0.15.0.dev20221231
[conda] blas 1.0 mkl
[conda] cpuonly 2.0 0 pytorch-nightly
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.23.5 py38h14f4228_0
[conda] numpy-base 1.23.5 py38h31eccc5_0
[conda] pytorch 2.0.0.dev20221231 py3.8_cpu_0 pytorch-nightly
[conda] pytorch-mutex 1.0 cpu pytorch-nightly
[conda] torchaudio 2.0.0.dev20221231 py38_cpu pytorch-nightly
[conda] torchvision 0.15.0.dev20221231 py38_cpu pytorch-nightly

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel

Metadata

Metadata

Labels

oncall: quantization — Quantization support in PyTorch; triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions