Description
🐛 Describe the bug
I need to fuse a residual add pattern into a single layer:
y = some_quantized_tensor
x = some_quantized_tensor
x = bn(conv(x))
x = x + y
x = relu(x)
to
x = conv_with_add_relu(x, y)
When I use the following test code:
from collections import OrderedDict
import contextlib
import operator
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.ao.quantization.fx.match_utils import (
MatchAllNode,
)
from torch.ao.quantization.quantize_fx import (
fuse_fx,
)
from torch.ao.quantization.backend_config import (
get_qnnpack_backend_config,
BackendConfig,
BackendPatternConfig,
DTypeConfig,
ObservationType,
get_fbgemm_backend_config
)
from torch.ao.quantization import get_default_qconfig_mapping
import torch.ao.quantization.quantize_fx as qfx
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.bn = torch.nn.BatchNorm2d(3)
self.relu = torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(3)
self.iden = nn.Identity()
def forward(self, x):
y = x
y = self.iden(x)
x = self.conv(x)
x = self.bn(x)
x = torch.add(x, y)
x = self.relu(x)
return x
# Instantiate the repro module in eval mode (BatchNorm uses running stats).
m = M().eval()
def fuse_conv_bn_relu(is_qat, relu, add_pattern):
    """Fuser method for the matched relu(add(bn(conv(x)), extra)) pattern.

    ``add_pattern`` is ``(add_op, (bn, conv), extra_input)``; the fused
    replacement is simply the conv module.
    """
    bn_module, conv_module = add_pattern[1]
    return conv_module
def conv_bn_res_relu_root_node_getter(pattern):
    """Return the conv node, which acts as the root of the fused pattern.

    ``pattern`` has the shape ``(relu, (add, (bn, conv), extra_input))``.
    """
    _relu, (_add, (_bn, conv), _extra) = pattern
    return conv
def conv_bn_res_relu_extra_inputs_getter(pattern):
    """Return the pattern's extra inputs (the residual branch).

    Inputs of the root node are copied over to the fused node automatically;
    only the extra (non-root) input of the add needs to be reported here.
    """
    _relu, (_add, _bn_pattern, extra_input) = pattern
    return [extra_input]
# int8 dtype configuration mirroring fbgemm's weighted-op settings:
# quint8 activations in and out, qint8 weights, float bias.
fbgemm_weighted_op_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    output_dtype=torch.quint8,
)
# for pytorch <= 1.13
# conv_bn_res_relu_config = BackendPatternConfig((nn.ReLU, (operator.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
#     .set_fuser_method(fuse_conv_bn_relu) \
#     ._set_root_node_getter(conv_bn_res_relu_root_node_getter) \
#     ._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter)
# for pytorch master
# Pattern in "reversed nested tuple" complex format:
# (relu, (add, (bn, conv), MatchAllNode)) matches relu(add(bn(conv(x)), y)).
conv_bn_res_relu_config = BackendPatternConfig() \
    ._set_pattern_complex_format((nn.ReLU, (torch.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
    .set_fuser_method(fuse_conv_bn_relu) \
    ._set_root_node_getter(conv_bn_res_relu_root_node_getter) \
    ._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter) \
    .set_dtype_configs([fbgemm_weighted_op_int8_dtype_config])  # FIX: set_dtype_configs expects a List[DTypeConfig], not a bare DTypeConfig
# Extend the stock fbgemm backend config with the custom fusion pattern.
backend_config = get_fbgemm_backend_config().set_backend_pattern_config(conv_bn_res_relu_config)
# m = fuse_fx(m, backend_config=backend_config)
# Prepare and convert with the custom backend config, then print the result.
qmapping = get_default_qconfig_mapping()
# NOTE(review): example_inputs is passed as an empty tuple here — prepare_fx
# normally expects representative example inputs; confirm this is intentional.
prepared_model = qfx.prepare_fx(m, qmapping, (), backend_config=backend_config)
converted_model = qfx.convert_fx(prepared_model, qconfig_mapping=qmapping, backend_config=backend_config)
converted_model.print_readable()
I found that the second input of the fused conv-add in the converted graph is a dequantized tensor, which causes an error in my project:
class GraphModule(torch.nn.Module):
def forward(self, x):
# No stacktrace found for following nodes
iden_input_scale_0 = self.iden_input_scale_0
iden_input_zero_point_0 = self.iden_input_zero_point_0
quantize_per_tensor = torch.quantize_per_tensor(x, iden_input_scale_0, iden_input_zero_point_0, torch.quint8); x = iden_input_scale_0 = iden_input_zero_point_0 = None
# File: /home/yy/anaconda3/envs/cpudev/lib/python3.8/site-packages/torch/ao/quantization/fx/tracer.py:103, code: return super().call_module(m, forward, args, kwargs)
iden = self.iden(quantize_per_tensor)
# No stacktrace found for following nodes
dequantize_1 = iden.dequantize(); iden = None
# File: /home/yy/anaconda3/envs/cpudev/lib/python3.8/site-packages/torch/ao/quantization/fx/tracer.py:103, code: return super().call_module(m, forward, args, kwargs)
conv = self.conv(quantize_per_tensor, dequantize_1); quantize_per_tensor = dequantize_1 = None
# No stacktrace found for following nodes
dequantize_2 = conv.dequantize(); conv = None
return dequantize_2
The `dequantize_1` value in the GraphModule code should be a quantized tensor.
Versions
Collecting environment information...
PyTorch version: 2.0.0.dev20221231
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 11.0.0 (https://github.com/llvm/llvm-project.git 0160ad802e899c2922bc9b29564080c22eb0908c)
CMake version: Could not collect
Libc version: glibc-2.31
Python version: 3.8.15 (default, Nov 24 2022, 15:19:38) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-56-generic-x86_64-with-glibc2.17
Is CUDA available: False
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3080 Laptop GPU
Nvidia driver version: 525.60.11
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.0.dev20221231
[pip3] torchaudio==2.0.0.dev20221231
[pip3] torchvision==0.15.0.dev20221231
[conda] blas 1.0 mkl
[conda] cpuonly 2.0 0 pytorch-nightly
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.23.5 py38h14f4228_0
[conda] numpy-base 1.23.5 py38h31eccc5_0
[conda] pytorch 2.0.0.dev20221231 py3.8_cpu_0 pytorch-nightly
[conda] pytorch-mutex 1.0 cpu pytorch-nightly
[conda] torchaudio 2.0.0.dev20221231 py38_cpu pytorch-nightly
[conda] torchvision 0.15.0.dev20221231 py38_cpu pytorch-nightly
cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel