
Commit bfedc3c

fix int8/fp8 constant folding issue (#3543)
1 parent: 853cc0b · commit: bfedc3c

File tree

5 files changed (+114, -20 lines)

.github/workflows/build_wheels_linux_aarch64.yml

Lines changed: 4 additions & 0 deletions

@@ -257,6 +257,10 @@ jobs:
       export PYTORCH_VERSION="$(${CONDA_RUN} pip show torch | grep ^Version: | sed 's/Version: *//' | sed 's/+.\+//')"
       ${CONDA_RUN} python setup.py clean
       echo "Successfully ran `python setup.py clean`"
+      if [[ "$BUILD_VERSION" != *"+"${CU_VERSION} ]]; then
+        BUILD_VERSION="${BUILD_VERSION}+${CU_VERSION}"
+      fi
+      echo "BUILD_VERSION=$BUILD_VERSION"
       if [[ ${{ inputs.is-jetpack }} == false ]]; then
         ${CONDA_RUN} python setup.py bdist_wheel
       else
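Note: the added step only appends the CUDA suffix when BUILD_VERSION does not already end with it, so the tagging is idempotent across retries. A minimal Python sketch of the same guard (variable names mirror the workflow; this snippet is illustrative, not part of the commit):

def with_cu_suffix(build_version: str, cu_version: str) -> str:
    # Append "+<cu_version>" only if the version is not already tagged with it.
    suffix = "+" + cu_version
    if not build_version.endswith(suffix):
        build_version += suffix
    return build_version

assert with_cu_suffix("2.8.0.dev0", "cu128") == "2.8.0.dev0+cu128"
assert with_cu_suffix("2.8.0.dev0+cu128", "cu128") == "2.8.0.dev0+cu128"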

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 3 additions & 1 deletion

@@ -597,7 +597,9 @@ def aten_ops_neg(
     )
 else:

-    @dynamo_tensorrt_converter(torch.ops.tensorrt.quantize_op.default)
+    @dynamo_tensorrt_converter(
+        torch.ops.tensorrt.quantize_op.default, supports_dynamic_shapes=True
+    )
     def aten_ops_quantize_op(
         ctx: ConversionContext,
         target: Target,
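Note: supports_dynamic_shapes=True marks the converter as usable when the exported graph carries symbolic (dynamic) dimensions; without the flag, quantize_op would not be claimed for TensorRT conversion in the dynamic-shape test added below. A hedged sketch of the same registration pattern for an unrelated op (illustrative only, not meant to be added to this file; the body is omitted):

@dynamo_tensorrt_converter(
    torch.ops.aten.relu.default, supports_dynamic_shapes=True
)
def example_relu_converter(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # A real converter would build TensorRT layers here via ctx.net / the impl helpers.
    ...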

py/torch_tensorrt/dynamo/conversion/impl/quantize.py

Lines changed: 49 additions & 19 deletions

@@ -28,42 +28,72 @@ def quantize(
     """

     with unset_fake_temporarily():
-        if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in (
-            trt.float32,
-            trt.float16,
-        ):
-            raise ValueError(
-                f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16"
-            )
+        if isinstance(input_tensor, (torch.Tensor, TRTTensor)):
+            if input_tensor.dtype not in (
+                trt.float32,
+                trt.float16,
+                trt.bfloat16,
+                torch.bfloat16,
+                torch.float16,
+                torch.float32,
+            ):
+                raise ValueError(
+                    f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16 | bfloat16"
+                )
             if num_bits != 8 or exponent_bits not in (0, 4):
                 raise ValueError(
                     f"quantize converter currently only accept INT8 or FP8 based quantize, got {num_bits=}, {exponent_bits=}"
                 )
+        else:
+            raise ValueError(
+                f"quantize converter received an input of {type(input_tensor)} type. Supported types: torch.Tensor | TRTTensor"
+            )
+
         if num_bits == 8 and exponent_bits == 0:
+            dtype = trt.DataType.INT8
             max_bound = 127
         elif num_bits == 8 and exponent_bits == 4:
+            dtype = trt.DataType.FP8
             max_bound = 448

         amax = to_torch(amax, None)
+        axis = None
+        # int8 weight quantization is per-channel quantization(it can have one or multiple amax values)
+        if dtype == trt.DataType.INT8 and amax.numel() > 1:
+            # if the amax has more than one element, calculate the axis, otherwise axis value will be ignored
+            amax_init_shape = amax.shape
+            amax = amax.squeeze().data
+            assert (
+                len(amax.shape) == 1
+            ), f"TensorRT does not support multi-axis quantization. {name=} {amax_init_shape=} {amax.shape=} "
+            axis = list(amax_init_shape).index(list(amax.shape)[0])
+            assert (
+                axis == 0
+            ), f"{name=} {amax=} is per-channel quantization, expected axis to be 0, but got {axis=}"
+        else:
+            # int8 activation and fp8 weight/activation quantization is per-tensor quantization, it can only have single amax value
+            assert (
+                amax.numel() == 1
+            ), f"{name=} is per-tensor quantization, expected amax is a singular value, but got {amax.shape=}"
         scale = torch.divide(amax, max_bound)
+        scale.masked_fill_(scale == 0, 1.0)
         scale = get_trt_tensor(ctx, scale, name + "_scale")
-        # Add Q node
-        quantize_layer = ctx.net.add_quantize(input_tensor, scale)
-        if num_bits == 8 and exponent_bits == 0:
-            quantize_layer.set_output_type(0, trt.DataType.INT8)
-        elif num_bits == 8 and exponent_bits == 4:
-            quantize_layer.set_output_type(0, trt.DataType.FP8)
+        input_tensor = get_trt_tensor(ctx, input_tensor, name)

+        # Add Q node
+        quantize_layer = ctx.net.add_quantize(input_tensor, scale, dtype)
+        if axis is not None:
+            quantize_layer.axis = axis
         set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
         q_output = quantize_layer.get_output(0)
         # Add DQ node
-        dequantize_layer = ctx.net.add_dequantize(q_output, scale)
+        dequantize_layer = ctx.net.add_dequantize(
+            q_output, scale, output_type=input_tensor.dtype
+        )
+        dequantize_layer.to_type = input_tensor.dtype
+        if axis is not None:
+            dequantize_layer.axis = axis
         set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
-        if num_bits == 8 and exponent_bits == 0:
-            dequantize_layer.precision = trt.DataType.INT8
-        elif num_bits == 8 and exponent_bits == 4:
-            # Set DQ layer precision to FP8
-            dequantize_layer.precision = trt.DataType.FP8
         dq_output = dequantize_layer.get_output(0)

         return dq_output
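Note: the new per-channel branch turns a multi-element amax into a 1-D tensor, derives the quantization axis, and masks zero scales to 1.0 before handing the scale to TensorRT. A minimal sketch of that arithmetic in plain PyTorch (the weight/amax shapes are assumptions for illustration; no TensorRT calls):

import torch

# Assumed per-channel amax for a (4, 3) weight, stored with keepdim as shape (4, 1).
amax = torch.tensor([[2.54], [0.0], [1.27], [5.08]])
max_bound = 127  # INT8

amax_init_shape = amax.shape
amax = amax.squeeze()                # -> shape (4,), a single quantization axis
assert len(amax.shape) == 1          # multi-axis quantization is unsupported

# The axis is the dimension of the original shape that survived the squeeze.
axis = list(amax_init_shape).index(list(amax.shape)[0])
assert axis == 0                     # per-channel quantization along dim 0

scale = torch.divide(amax, max_bound)
scale.masked_fill_(scale == 0, 1.0)  # avoid a zero scale for all-zero channels
print(axis, scale)                   # 0  tensor([0.0200, 1.0000, 0.0100, 0.0400])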

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 2 additions & 0 deletions

@@ -101,4 +101,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

     # TODO: Update this function when quantization is added
     def is_impure(self, node: torch.fx.node.Node) -> bool:
+        if node.target in (torch.ops.tensorrt.quantize_op.default,):
+            return True
         return False
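Note: this is the core of the constant-folding fix. The constant-folding pass asks is_impure() before pre-computing a node whose inputs are all constants; by reporting quantize_op as impure, the Q/DQ nodes inserted for INT8/FP8 weights survive lowering instead of being folded into plain constant tensors. A rough, self-contained sketch of that folding rule (simplified; not the actual _TorchTensorRTConstantFolder logic):

import torch

def is_impure(node: torch.fx.Node) -> bool:
    # Mirrors the added check: quantize_op exists only once torch_tensorrt/modelopt has
    # registered it, and it must never be folded even when its inputs are constant.
    return node.target in (torch.ops.tensorrt.quantize_op.default,)

def can_fold(node: torch.fx.Node) -> bool:
    # Simplified rule: fold only pure call_function nodes whose inputs are all constants.
    if node.op != "call_function" or is_impure(node):
        return False
    return all(arg.op == "get_attr" for arg in node.all_input_nodes)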

tests/py/dynamo/models/test_models_export.py

Lines changed: 56 additions & 0 deletions

@@ -302,3 +302,59 @@ def calibrate_loop(model):
             )
             outputs_trt = trt_model(input_tensor)
             assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)
+
+
+@unittest.skipIf(
+    platform.system() != "Linux"
+    or not importlib.util.find_spec("modelopt")
+    or Version(metadata.version("nvidia-modelopt")) < Version("0.17.0"),
+    "modelopt 0.17.0 or later is required, Int8 quantization is supported in modelopt since 0.17.0 or later for linux",
+)
+@pytest.mark.unit
+def test_base_int8_dynamic_shape(ir):
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+
+    dtype = torch.bfloat16
+
+    class SimpleNetwork(torch.nn.Module):
+        def __init__(self):
+            super(SimpleNetwork, self).__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3, dtype=dtype)
+            self.linear = torch.nn.Linear(222, 222, dtype=dtype)
+
+        def forward(self, x):
+            return self.linear(self.conv(x))
+
+    def calibrate_loop(model):
+        """Simple calibration function for testing."""
+        model(input_tensor)
+
+    BATCH_SIZE = torch.export.Dim("BATCH_SIZE", min=2, max=16)
+    batch_size = 8
+    input_tensor = torch.randn(batch_size, 3, 224, 224, dtype=dtype).cuda()
+    model = SimpleNetwork().eval().cuda()
+
+    quant_cfg = mtq.INT8_DEFAULT_CFG
+    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+
+    # model has INT8 qdq nodes at this point
+    output_pyt = model(input_tensor)
+
+    with torch.no_grad():
+        with export_torch_mode():
+            exp_program = torch.export.export(
+                model, (input_tensor,), strict=False, dynamic_shapes=({0: BATCH_SIZE},)
+            )
+            trt_model = torchtrt.dynamo.compile(
+                exp_program,
+                inputs=[input_tensor],
+                enabled_precisions={torch.int8, dtype},
+                min_block_size=1,
+                debug=True,
+                cache_built_engines=False,
+                reuse_cached_engines=False,
+                truncate_double=True,
+            )
+            outputs_trt = trt_model(input_tensor)
+            assert torch.allclose(output_pyt, outputs_trt, rtol=5e-2, atol=5e-2)
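Note: because the export passes dynamic_shapes=({0: BATCH_SIZE},) with the batch dimension ranging from 2 to 16, the compiled engine should also accept batch sizes other than the calibration batch of 8. A hypothetical follow-up assertion (not part of the committed test) that would exercise this inside the same with-blocks:

other_input = torch.randn(4, 3, 224, 224, dtype=dtype).cuda()
assert torch.allclose(model(other_input), trt_model(other_input), rtol=5e-2, atol=5e-2)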
