From c77e969437f71f0e35f86f5e69318c47be81296e Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Tue, 19 Dec 2023 12:03:15 +0800 Subject: [PATCH 1/8] Add QAT Tutorial --- .../pt2e_quant_qat_x86_inductor.rst | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 prototype_source/pt2e_quant_qat_x86_inductor.rst diff --git a/prototype_source/pt2e_quant_qat_x86_inductor.rst b/prototype_source/pt2e_quant_qat_x86_inductor.rst new file mode 100644 index 00000000000..d929560dfed --- /dev/null +++ b/prototype_source/pt2e_quant_qat_x86_inductor.rst @@ -0,0 +1,78 @@ +PyTorch 2 Export Quantization-Aware Training (QAT) with X86 Backend through Inductor +======================================================================================== + +**Author**: `Leslie Fang `_, `Jiong Gong `_ + +Prerequisites +^^^^^^^^^^^^^^^ + +- `PyTorch 2 Export Quantization-Aware Training tutorial `_ +- `PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor tutorial `_ +- `TorchInductor and torch.compile concepts in PyTorch `_ + + +This tutorial demonstrates the process of performing PT2 export quantization-aware training (QAT) on X86 CPU +with X86InductorQuantizer, and subsequently lowering the quantized model into Inductor. +For more comprehensive details about PyTorch 2 Export Quantization-Aware Training in general, please refer to the +dedicated tutorial on `PyTorch 2 Export Quantization-Aware Training `_. +For a deeper understanding of X86InductorQuantizer, please consult the tutorial of +`PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor `_. + +The PyTorch 2 Export QAT flow looks like the following—it is similar +to the post training quantization (PTQ) flow for the most part: + +.. code:: python + + import torch + from torch._export import capture_pre_autograd_graph + from torch.ao.quantization.quantize_pt2e import ( + prepare_qat_pt2e, + convert_pt2e, + ) + import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq + from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(1024, 1000) + + def forward(self, x): + return self.linear(x) + + + example_inputs = (torch.randn(1, 1024),) + m = M() + + # Step 1. program capture + # NOTE: this API will be updated to torch.export API in the future, but the captured + # result shoud mostly stay the same + exported_model = capture_pre_autograd_graph(m, example_inputs) + # we get a model with aten ops + + # Step 2. quantization-aware training + # Use Backend Quantizer for X86 CPU + quantizer = X86InductorQuantizer() + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_qat=True)) + prepared_model = prepare_qat_pt2e(exported_model, quantizer) + + # train omitted + + converted_model = convert_pt2e(prepared_model) + # we have a model with aten ops doing integer computations when possible + + # move the quantized model to eval mode, equivalent to `m.eval()` + torch.ao.quantization.move_exported_model_to_eval(converted_model) + + # Lower the model into Inductor + with torch.no_grad(): + optimized_model = torch.compile(converted_model) + _ = optimized_model(*example_inputs) + +Please note that since the Inductor ``freeze`` feature does not turn on by default yet, need to run example code with ``TORCHINDUCTOR_FREEZING=1``. + +For example: + +:: + + TORCHINDUCTOR_FREEZING=1 python example_x86inductorquantizer_qat.py From 15520f2281c107c288d154815519c623704d2051 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Wed, 20 Dec 2023 10:39:46 +0800 Subject: [PATCH 2/8] Update prototype_source/pt2e_quant_qat_x86_inductor.rst Co-authored-by: Svetlana Karslioglu --- prototype_source/pt2e_quant_qat_x86_inductor.rst | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/prototype_source/pt2e_quant_qat_x86_inductor.rst b/prototype_source/pt2e_quant_qat_x86_inductor.rst index d929560dfed..4fc39e6f8df 100644 --- a/prototype_source/pt2e_quant_qat_x86_inductor.rst +++ b/prototype_source/pt2e_quant_qat_x86_inductor.rst @@ -11,11 +11,10 @@ Prerequisites - `TorchInductor and torch.compile concepts in PyTorch `_ -This tutorial demonstrates the process of performing PT2 export quantization-aware training (QAT) on X86 CPU -with X86InductorQuantizer, and subsequently lowering the quantized model into Inductor. -For more comprehensive details about PyTorch 2 Export Quantization-Aware Training in general, please refer to the -dedicated tutorial on `PyTorch 2 Export Quantization-Aware Training `_. -For a deeper understanding of X86InductorQuantizer, please consult the tutorial of +This tutorial demonstrates the process of performing PT2 export Quantization-Aware Training (QAT) on X86 CPU +using X86InductorQuantizer and subsequently lowering the quantized model into Inductor. +For a more in-depth understanding of PT2 Export Quantization-Aware Training, we recommend referring to the dedicated `PyTorch 2 Export Quantization-Aware Training `_. +To gain a deeper insight into X86InductorQuantizer, please see the tutorial of `PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor `_. The PyTorch 2 Export QAT flow looks like the following—it is similar From 3febc85e7ad615f5cbf779f69a1e8ffae0972f28 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Wed, 20 Dec 2023 10:39:55 +0800 Subject: [PATCH 3/8] Update prototype_source/pt2e_quant_qat_x86_inductor.rst Co-authored-by: Svetlana Karslioglu --- prototype_source/pt2e_quant_qat_x86_inductor.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototype_source/pt2e_quant_qat_x86_inductor.rst b/prototype_source/pt2e_quant_qat_x86_inductor.rst index 4fc39e6f8df..8e8c76550cb 100644 --- a/prototype_source/pt2e_quant_qat_x86_inductor.rst +++ b/prototype_source/pt2e_quant_qat_x86_inductor.rst @@ -17,7 +17,7 @@ For a more in-depth understanding of PT2 Export Quantization-Aware Training, we To gain a deeper insight into X86InductorQuantizer, please see the tutorial of `PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor `_. -The PyTorch 2 Export QAT flow looks like the following—it is similar +The PyTorch 2 Export QAT flow is largely similar to the to the post training quantization (PTQ) flow for the most part: .. code:: python From 0dc0d03b9b1bb4d31778eae22f7854f44bd130be Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Wed, 20 Dec 2023 10:40:00 +0800 Subject: [PATCH 4/8] Update prototype_source/pt2e_quant_qat_x86_inductor.rst Co-authored-by: Svetlana Karslioglu --- prototype_source/pt2e_quant_qat_x86_inductor.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototype_source/pt2e_quant_qat_x86_inductor.rst b/prototype_source/pt2e_quant_qat_x86_inductor.rst index 8e8c76550cb..14763c4971c 100644 --- a/prototype_source/pt2e_quant_qat_x86_inductor.rst +++ b/prototype_source/pt2e_quant_qat_x86_inductor.rst @@ -18,7 +18,7 @@ To gain a deeper insight into X86InductorQuantizer, please see the tutorial of `PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor `_. The PyTorch 2 Export QAT flow is largely similar to the -to the post training quantization (PTQ) flow for the most part: +to the post-training quantization (PTQ) flow for the most part: .. code:: python From 6857c39e1d8ccc8fe46f4f14dcd39578e249f522 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Wed, 20 Dec 2023 10:40:06 +0800 Subject: [PATCH 5/8] Update prototype_source/pt2e_quant_qat_x86_inductor.rst Co-authored-by: Svetlana Karslioglu --- prototype_source/pt2e_quant_qat_x86_inductor.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototype_source/pt2e_quant_qat_x86_inductor.rst b/prototype_source/pt2e_quant_qat_x86_inductor.rst index 14763c4971c..afa296a845c 100644 --- a/prototype_source/pt2e_quant_qat_x86_inductor.rst +++ b/prototype_source/pt2e_quant_qat_x86_inductor.rst @@ -68,7 +68,7 @@ to the post-training quantization (PTQ) flow for the most part: optimized_model = torch.compile(converted_model) _ = optimized_model(*example_inputs) -Please note that since the Inductor ``freeze`` feature does not turn on by default yet, need to run example code with ``TORCHINDUCTOR_FREEZING=1``. +Please note that the Inductor ``freeze`` feature is not enabled by default. To use this feature, you need to run example code with ``TORCHINDUCTOR_FREEZING=1``. For example: From 2ead9c8f51c924675182e966f6a125c079e3aa7c Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Wed, 20 Dec 2023 12:52:51 +0800 Subject: [PATCH 6/8] Merge PTQ/QAT tutorial for x86Inductor --- prototype_source/prototype_index.rst | 6 ++ .../pt2e_quant_qat_x86_inductor.rst | 77 ------------------- ...ductor.rst => pt2e_quant_x86_inductor.rst} | 35 +++++++-- 3 files changed, 35 insertions(+), 83 deletions(-) delete mode 100644 prototype_source/pt2e_quant_qat_x86_inductor.rst rename prototype_source/{pt2e_quant_ptq_x86_inductor.rst => pt2e_quant_x86_inductor.rst} (88%) diff --git a/prototype_source/prototype_index.rst b/prototype_source/prototype_index.rst index 1f303e7d159..8d965194f88 100644 --- a/prototype_source/prototype_index.rst +++ b/prototype_source/prototype_index.rst @@ -89,6 +89,12 @@ Prototype features are not available as part of binary distributions like PyPI o :link: ../prototype/pt2e_quant_qat.html :tags: Quantization +.. customcarditem:: + :header: PyTorch 2 Export Quantization with X86 Backend through Inductor + :card_description: Learn how to use PT2 Export Quantization with X86 Backend through Inductor. + :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png + :link: ../prototype/pt2e_quant_x86_inductor.html + :tags: Quantization .. Sparsity diff --git a/prototype_source/pt2e_quant_qat_x86_inductor.rst b/prototype_source/pt2e_quant_qat_x86_inductor.rst deleted file mode 100644 index afa296a845c..00000000000 --- a/prototype_source/pt2e_quant_qat_x86_inductor.rst +++ /dev/null @@ -1,77 +0,0 @@ -PyTorch 2 Export Quantization-Aware Training (QAT) with X86 Backend through Inductor -======================================================================================== - -**Author**: `Leslie Fang `_, `Jiong Gong `_ - -Prerequisites -^^^^^^^^^^^^^^^ - -- `PyTorch 2 Export Quantization-Aware Training tutorial `_ -- `PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor tutorial `_ -- `TorchInductor and torch.compile concepts in PyTorch `_ - - -This tutorial demonstrates the process of performing PT2 export Quantization-Aware Training (QAT) on X86 CPU -using X86InductorQuantizer and subsequently lowering the quantized model into Inductor. -For a more in-depth understanding of PT2 Export Quantization-Aware Training, we recommend referring to the dedicated `PyTorch 2 Export Quantization-Aware Training `_. -To gain a deeper insight into X86InductorQuantizer, please see the tutorial of -`PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor `_. - -The PyTorch 2 Export QAT flow is largely similar to the -to the post-training quantization (PTQ) flow for the most part: - -.. code:: python - - import torch - from torch._export import capture_pre_autograd_graph - from torch.ao.quantization.quantize_pt2e import ( - prepare_qat_pt2e, - convert_pt2e, - ) - import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq - from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer - - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(1024, 1000) - - def forward(self, x): - return self.linear(x) - - - example_inputs = (torch.randn(1, 1024),) - m = M() - - # Step 1. program capture - # NOTE: this API will be updated to torch.export API in the future, but the captured - # result shoud mostly stay the same - exported_model = capture_pre_autograd_graph(m, example_inputs) - # we get a model with aten ops - - # Step 2. quantization-aware training - # Use Backend Quantizer for X86 CPU - quantizer = X86InductorQuantizer() - quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_qat=True)) - prepared_model = prepare_qat_pt2e(exported_model, quantizer) - - # train omitted - - converted_model = convert_pt2e(prepared_model) - # we have a model with aten ops doing integer computations when possible - - # move the quantized model to eval mode, equivalent to `m.eval()` - torch.ao.quantization.move_exported_model_to_eval(converted_model) - - # Lower the model into Inductor - with torch.no_grad(): - optimized_model = torch.compile(converted_model) - _ = optimized_model(*example_inputs) - -Please note that the Inductor ``freeze`` feature is not enabled by default. To use this feature, you need to run example code with ``TORCHINDUCTOR_FREEZING=1``. - -For example: - -:: - - TORCHINDUCTOR_FREEZING=1 python example_x86inductorquantizer_qat.py diff --git a/prototype_source/pt2e_quant_ptq_x86_inductor.rst b/prototype_source/pt2e_quant_x86_inductor.rst similarity index 88% rename from prototype_source/pt2e_quant_ptq_x86_inductor.rst rename to prototype_source/pt2e_quant_x86_inductor.rst index 60bd5ffa5a4..a9380d8d650 100644 --- a/prototype_source/pt2e_quant_ptq_x86_inductor.rst +++ b/prototype_source/pt2e_quant_x86_inductor.rst @@ -1,4 +1,4 @@ -PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor +PyTorch 2 Export Quantization with X86 Backend through Inductor ======================================================================================== **Author**: `Leslie Fang `_, `Weiwen Xia `_, `Jiong Gong `_, `Jerry Zhang `_ @@ -7,6 +7,7 @@ Prerequisites ^^^^^^^^^^^^^^^ - `PyTorch 2 Export Post Training Quantization `_ +- `PyTorch 2 Export Quantization-Aware Training tutorial `_ - `TorchInductor and torch.compile concepts in PyTorch `_ - `Inductor C++ Wrapper concepts `_ @@ -16,14 +17,15 @@ Introduction This tutorial introduces the steps for utilizing the PyTorch 2 Export Quantization flow to generate a quantized model customized for the x86 inductor backend and explains how to lower the quantized model into the inductor. -The new quantization 2 flow uses the PT2 Export to capture the model into a graph and perform quantization transformations on top of the ATen graph. This approach is expected to have significantly higher model coverage, better programmability, and a simplified UX. +The new quantization 2 flow uses the PT2 Export to capture the model into a graph and perform quantization transformations on top of the ATen graph. +This approach is expected to have significantly higher model coverage, better programmability, and a simplified UX. TorchInductor is the new compiler backend that compiles the FX Graphs generated by TorchDynamo into optimized C++/Triton kernels. This flow of quantization 2 with Inductor mainly includes three steps: - Step 1: Capture the FX Graph from the eager Model based on the `torch export mechanism `_. - Step 2: Apply the Quantization flow based on the captured FX Graph, including defining the backend-specific quantizer, generating the prepared model with observers, - performing the prepared model's calibration, and converting the prepared model into the quantized model. + performing the prepared model's calibration or quantization-aware training, and converting the prepared model into the quantized model. - Step 3: Lower the quantized model into inductor with the API ``torch.compile``. The high-level architecture of this flow could look like this: @@ -83,6 +85,8 @@ We will start by performing the necessary imports, capturing the FX Graph from t model = models.__dict__[model_name](pretrained=True) # Set the model to eval mode + # Only apply it for post-training static quantization + # Skip this step for quantization-aware training model = model.eval() # Create the data, using the dummy data here as an example @@ -116,11 +120,20 @@ Next, we will have the FX Module to be quantized. After we capture the FX Module to be quantized, we will import the Backend Quantizer for X86 CPU and configure how to quantize the model. +For post-training static quantization: + :: quantizer = X86InductorQuantizer() quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) +For quantization-aware training: + +:: + + quantizer = X86InductorQuantizer() + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_qat=True)) + .. note:: The default quantization configuration in ``X86InductorQuantizer`` uses 8-bits for both activations and weights. @@ -128,14 +141,23 @@ quantize the model. `multiplications are 7-bit x 8-bit `_. In other words, potential numeric saturation and accuracy issue may happen when running on CPU without Vector Neural Network Instruction. -After we import the backend-specific Quantizer, we will prepare the model for post-training quantization. -``prepare_pt2e`` folds BatchNorm operators into preceding Conv2d operators, and inserts observers in appropriate places in the model. +After we import the backend-specific Quantizer, we will prepare the model for post-training quantization or quantization-aware training. + +For post-training static quantization, ``prepare_pt2e`` folds BatchNorm operators into preceding Conv2d operators, and inserts observers in appropriate places in the model. :: prepared_model = prepare_pt2e(exported_model, quantizer) -Now, we will calibrate the ``prepared_model`` after the observers are inserted in the model. +For quantization-aware training: + +:: + + prepared_model = prepare_qat_pt2e(exported_model, quantizer) + + +Now, we will do calibration for post-training static quantization or quantization-aware training. Here is the example code +for post-training static quantization. The example code omits quantization-aware training for simplicity. :: @@ -155,6 +177,7 @@ Finally, we will convert the calibrated Model to a quantized Model. ``convert_pt :: converted_model = convert_pt2e(prepared_model) + torch.ao.quantization.move_exported_model_to_eval(converted_model) After these steps, we finished running the quantization flow and we will get the quantized model. From 2ca1927504f2aa8f96f39fcba566c10373b72f51 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Fri, 22 Dec 2023 09:18:45 +0800 Subject: [PATCH 7/8] change the words --- prototype_source/pt2e_quant_x86_inductor.rst | 31 ++++---------------- 1 file changed, 5 insertions(+), 26 deletions(-) diff --git a/prototype_source/pt2e_quant_x86_inductor.rst b/prototype_source/pt2e_quant_x86_inductor.rst index a9380d8d650..6dca39d6e12 100644 --- a/prototype_source/pt2e_quant_x86_inductor.rst +++ b/prototype_source/pt2e_quant_x86_inductor.rst @@ -7,7 +7,7 @@ Prerequisites ^^^^^^^^^^^^^^^ - `PyTorch 2 Export Post Training Quantization `_ -- `PyTorch 2 Export Quantization-Aware Training tutorial `_ +- `PyTorch 2 Export Quantization-Aware Training `_ - `TorchInductor and torch.compile concepts in PyTorch `_ - `Inductor C++ Wrapper concepts `_ @@ -17,7 +17,7 @@ Introduction This tutorial introduces the steps for utilizing the PyTorch 2 Export Quantization flow to generate a quantized model customized for the x86 inductor backend and explains how to lower the quantized model into the inductor. -The new quantization 2 flow uses the PT2 Export to capture the model into a graph and perform quantization transformations on top of the ATen graph. +The pytorch 2 export quantization flow uses the torch.export to capture the model into a graph and perform quantization transformations on top of the ATen graph. This approach is expected to have significantly higher model coverage, better programmability, and a simplified UX. TorchInductor is the new compiler backend that compiles the FX Graphs generated by TorchDynamo into optimized C++/Triton kernels. @@ -85,8 +85,6 @@ We will start by performing the necessary imports, capturing the FX Graph from t model = models.__dict__[model_name](pretrained=True) # Set the model to eval mode - # Only apply it for post-training static quantization - # Skip this step for quantization-aware training model = model.eval() # Create the data, using the dummy data here as an example @@ -120,20 +118,11 @@ Next, we will have the FX Module to be quantized. After we capture the FX Module to be quantized, we will import the Backend Quantizer for X86 CPU and configure how to quantize the model. -For post-training static quantization: - :: quantizer = X86InductorQuantizer() quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) -For quantization-aware training: - -:: - - quantizer = X86InductorQuantizer() - quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_qat=True)) - .. note:: The default quantization configuration in ``X86InductorQuantizer`` uses 8-bits for both activations and weights. @@ -141,23 +130,14 @@ For quantization-aware training: `multiplications are 7-bit x 8-bit `_. In other words, potential numeric saturation and accuracy issue may happen when running on CPU without Vector Neural Network Instruction. -After we import the backend-specific Quantizer, we will prepare the model for post-training quantization or quantization-aware training. - -For post-training static quantization, ``prepare_pt2e`` folds BatchNorm operators into preceding Conv2d operators, and inserts observers in appropriate places in the model. +After we import the backend-specific Quantizer, we will prepare the model for post-training quantization. +``prepare_pt2e`` folds BatchNorm operators into preceding Conv2d operators, and inserts observers in appropriate places in the model. :: prepared_model = prepare_pt2e(exported_model, quantizer) -For quantization-aware training: - -:: - - prepared_model = prepare_qat_pt2e(exported_model, quantizer) - - -Now, we will do calibration for post-training static quantization or quantization-aware training. Here is the example code -for post-training static quantization. The example code omits quantization-aware training for simplicity. +Now, we will calibrate the ``prepared_model`` after the observers are inserted in the model. :: @@ -177,7 +157,6 @@ Finally, we will convert the calibrated Model to a quantized Model. ``convert_pt :: converted_model = convert_pt2e(prepared_model) - torch.ao.quantization.move_exported_model_to_eval(converted_model) After these steps, we finished running the quantization flow and we will get the quantized model. From 16aa287293929fe1eb7acce2889e1981aca405c4 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Fri, 22 Dec 2023 09:48:52 +0800 Subject: [PATCH 8/8] Add QAT as a standalone section --- prototype_source/pt2e_quant_x86_inductor.rst | 88 ++++++++++++++++++-- 1 file changed, 79 insertions(+), 9 deletions(-) diff --git a/prototype_source/pt2e_quant_x86_inductor.rst b/prototype_source/pt2e_quant_x86_inductor.rst index 6dca39d6e12..80415068cae 100644 --- a/prototype_source/pt2e_quant_x86_inductor.rst +++ b/prototype_source/pt2e_quant_x86_inductor.rst @@ -1,10 +1,10 @@ PyTorch 2 Export Quantization with X86 Backend through Inductor -======================================================================================== +================================================================== **Author**: `Leslie Fang `_, `Weiwen Xia `_, `Jiong Gong `_, `Jerry Zhang `_ Prerequisites -^^^^^^^^^^^^^^^ +--------------- - `PyTorch 2 Export Post Training Quantization `_ - `PyTorch 2 Export Quantization-Aware Training `_ @@ -12,7 +12,7 @@ Prerequisites - `Inductor C++ Wrapper concepts `_ Introduction -^^^^^^^^^^^^^^ +-------------- This tutorial introduces the steps for utilizing the PyTorch 2 Export Quantization flow to generate a quantized model customized for the x86 inductor backend and explains how to lower the quantized model into the inductor. @@ -63,10 +63,14 @@ and outstanding out-of-box performance with the compiler backend. Especially on further boost the models' performance by leveraging the `advanced-matrix-extensions `_ feature. -Now, we will walk you through a step-by-step tutorial for how to use it with `torchvision resnet18 model `_. +Post Training Quantization +---------------------------- + +Now, we will walk you through a step-by-step tutorial for how to use it with `torchvision resnet18 model `_ +for post training quantization. 1. Capture FX Graph ---------------------- +^^^^^^^^^^^^^^^^^^^^^ We will start by performing the necessary imports, capturing the FX Graph from the eager module. @@ -113,7 +117,7 @@ We will start by performing the necessary imports, capturing the FX Graph from t Next, we will have the FX Module to be quantized. 2. Apply Quantization ----------------------------- +^^^^^^^^^^^^^^^^^^^^^^^ After we capture the FX Module to be quantized, we will import the Backend Quantizer for X86 CPU and configure how to quantize the model. @@ -162,7 +166,7 @@ After these steps, we finished running the quantization flow and we will get the 3. Lower into Inductor ------------------------- +^^^^^^^^^^^^^^^^^^^^^^^^ After we get the quantized model, we will further lower it to the inductor backend. The default Inductor wrapper generates Python code to invoke both generated kernels and external kernels. Additionally, Inductor supports @@ -224,8 +228,74 @@ With PyTorch 2.1 release, all CNN models from TorchBench test suite have been me to `this document `_ for detail benchmark number. -4. Conclusion ---------------- +Quantization Aware Training +----------------------------- + +The PyTorch 2 Export Quantization-Aware Training (QAT) is now supported on X86 CPU using X86InductorQuantizer, +followed by the subsequent lowering of the quantized model into Inductor. +For a more in-depth understanding of PT2 Export Quantization-Aware Training, +we recommend referring to the dedicated `PyTorch 2 Export Quantization-Aware Training `_. + +The PyTorch 2 Export QAT flow is largely similar to the PTQ flow: + +.. code:: python + + import torch + from torch._export import capture_pre_autograd_graph + from torch.ao.quantization.quantize_pt2e import ( + prepare_qat_pt2e, + convert_pt2e, + ) + import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq + from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(1024, 1000) + + def forward(self, x): + return self.linear(x) + + example_inputs = (torch.randn(1, 1024),) + m = M() + + # Step 1. program capture + # NOTE: this API will be updated to torch.export API in the future, but the captured + # result shoud mostly stay the same + exported_model = capture_pre_autograd_graph(m, example_inputs) + # we get a model with aten ops + + # Step 2. quantization-aware training + # Use Backend Quantizer for X86 CPU + quantizer = X86InductorQuantizer() + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_qat=True)) + prepared_model = prepare_qat_pt2e(exported_model, quantizer) + + # train omitted + + converted_model = convert_pt2e(prepared_model) + # we have a model with aten ops doing integer computations when possible + + # move the quantized model to eval mode, equivalent to `m.eval()` + torch.ao.quantization.move_exported_model_to_eval(converted_model) + + # Lower the model into Inductor + with torch.no_grad(): + optimized_model = torch.compile(converted_model) + _ = optimized_model(*example_inputs) + +Please note that the Inductor ``freeze`` feature is not enabled by default. +To use this feature, you need to run example code with ``TORCHINDUCTOR_FREEZING=1``. + +For example: + +:: + + TORCHINDUCTOR_FREEZING=1 python example_x86inductorquantizer_qat.py + +Conclusion +------------ With this tutorial, we introduce how to use Inductor with X86 CPU in PyTorch 2 Quantization. Users can learn about how to use ``X86InductorQuantizer`` to quantize a model and lower it into the inductor with X86 CPU devices.