From 5086d075d97d3e273b71f59f2c8af6fe738b59f3 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Fri, 10 Nov 2023 15:03:27 +0800
Subject: [PATCH 1/3] Add the example code for int8-mixed-bf16 quantization in
 X86InductorQuantizer

---
 .../pt2e_quant_ptq_x86_inductor.rst          | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/prototype_source/pt2e_quant_ptq_x86_inductor.rst b/prototype_source/pt2e_quant_ptq_x86_inductor.rst
index 1a6e152c996..9da90c355f2 100644
--- a/prototype_source/pt2e_quant_ptq_x86_inductor.rst
+++ b/prototype_source/pt2e_quant_ptq_x86_inductor.rst
@@ -165,11 +165,25 @@ After we get the quantized model, we will further lower it to the inductor backe
 
 ::
 
-    optimized_model = torch.compile(converted_model)
+    with torch.no_grad():
+        optimized_model = torch.compile(converted_model)
+
+        # Running some benchmark
+        optimized_model(*example_inputs)
+
+In a more advanced scenario, int8-mixed-bf16 quantization comes into play. In this case,
+a Convolution or GEMM operator produces a BFloat16 output data type instead of Float32 in the absence
+of a subsequent quantization node. The BFloat16 tensor then propagates seamlessly through the
+following pointwise operators, reducing memory usage and potentially improving performance.
+
+::
 
-    # Running some benchmark
-    optimized_model(*example_inputs)
+    with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True), torch.no_grad():
+        # Turn on Autocast to use int8-mixed-bf16 quantization
+        optimized_model = torch.compile(converted_model)
 
+        # Running some benchmark
+        optimized_model(*example_inputs)
 
 Put all these codes together, we will have the toy example code.
 Please note that since the Inductor ``freeze`` feature does not turn on by default yet,
 run your example code with ``TORCHINDUCTOR_FREEZING=1``.

From c7398d2255deec5ec465d038627411d4e37bb0db Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Sat, 11 Nov 2023 09:25:44 +0800
Subject: [PATCH 2/3] Specify the input, output, and compute precision of the
 quantized operators

---
 prototype_source/pt2e_quant_ptq_x86_inductor.rst | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/prototype_source/pt2e_quant_ptq_x86_inductor.rst b/prototype_source/pt2e_quant_ptq_x86_inductor.rst
index 9da90c355f2..dbc83e12ee3 100644
--- a/prototype_source/pt2e_quant_ptq_x86_inductor.rst
+++ b/prototype_source/pt2e_quant_ptq_x86_inductor.rst
@@ -179,7 +179,17 @@ following pointwise operators, reducing memory usage and potentially improving p
 ::
 
     with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True), torch.no_grad():
-        # Turn on Autocast to use int8-mixed-bf16 quantization
+        # Turn on Autocast to use int8-mixed-bf16 quantization. After lowering into the Inductor CPP
+        # Backend, the operators behave as follows. For operators such as QConvolution and QLinear:
+        # * The input data type is always int8, because a pair of quantization and dequantization
+        #   nodes is inserted at the input.
+        # * The computation precision remains int8.
+        # * The output data type may be either int8 or BFloat16, depending on whether a pair of
+        #   quantization and dequantization nodes is present at the output.
+        # Non-quantizable pointwise operators inherit their data type from the previous node,
+        #   potentially resulting in a BFloat16 data type in this scenario.
+        # Quantizable pointwise operators such as QMaxpool2D continue to operate with the int8
+        #   data type for both input and output.
         optimized_model = torch.compile(converted_model)
 
         # Running some benchmark

From e7e252593bdf30d65a52b7dcc9a2bcb732d35794 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Tue, 14 Nov 2023 13:43:10 +0800
Subject: [PATCH 3/3] Highlight that the usage is the same as regular BF16
 Autocast

---
 prototype_source/pt2e_quant_ptq_x86_inductor.rst | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/prototype_source/pt2e_quant_ptq_x86_inductor.rst b/prototype_source/pt2e_quant_ptq_x86_inductor.rst
index dbc83e12ee3..f2cabe88949 100644
--- a/prototype_source/pt2e_quant_ptq_x86_inductor.rst
+++ b/prototype_source/pt2e_quant_ptq_x86_inductor.rst
@@ -175,6 +175,8 @@ In a more advanced scenario, int8-mixed-bf16 quantization comes into play. In th
 a Convolution or GEMM operator produces a BFloat16 output data type instead of Float32 in the absence
 of a subsequent quantization node. The BFloat16 tensor then propagates seamlessly through the
 following pointwise operators, reducing memory usage and potentially improving performance.
+Using this feature is the same as using regular BFloat16 Autocast: simply wrap the script
+within the BFloat16 Autocast context.
 
 ::
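For reference, the snippets in these patches start from a ``converted_model`` produced earlier in the
tutorial. A minimal end-to-end sketch of the int8-mixed-bf16 flow they describe could look as follows,
assuming the standard PT2E quantization APIs (``capture_pre_autograd_graph``, ``prepare_pt2e``,
``convert_pt2e``, ``X86InductorQuantizer``) and an illustrative torchvision ResNet18 model, input shape,
and single-pass calibration that are not part of the patches themselves; run it with
``TORCHINDUCTOR_FREEZING=1``::

    import torch
    import torchvision.models as models
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

    # Illustrative eager-mode model and example input (assumptions for this sketch).
    model = models.resnet18().eval()
    example_inputs = (torch.randn(1, 3, 224, 224),)

    # Export the model into an FX graph for PT2E quantization.
    exported_model = capture_pre_autograd_graph(model, example_inputs)

    # Insert observers using the default X86 Inductor quantization config.
    quantizer = X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
    prepared_model = prepare_pt2e(exported_model, quantizer)

    # Calibrate (a single forward pass here for brevity).
    prepared_model(*example_inputs)

    # Convert the calibrated model into a quantized model.
    converted_model = convert_pt2e(prepared_model)

    # Lower through Inductor; wrapping in Autocast enables int8-mixed-bf16, so a
    # Convolution/GEMM output with no following quantization node stays in BFloat16.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True), torch.no_grad():
        optimized_model = torch.compile(converted_model)
        optimized_model(*example_inputs)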