
Add Tutorial of QAT with X86InductorQuantizer #2717


Merged
6 changes: 6 additions & 0 deletions prototype_source/prototype_index.rst
@@ -89,6 +89,12 @@
:link: ../prototype/pt2e_quant_qat.html
:tags: Quantization

.. customcarditem::
:header: PyTorch 2 Export Quantization with X86 Backend through Inductor
:card_description: Learn how to use PT2 Export Quantization with X86 Backend through Inductor.
:image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
:link: ../prototype/pt2e_quant_x86_inductor.html
:tags: Quantization

.. Sparsity

@@ -1,29 +1,31 @@
PyTorch 2 Export Quantization with X86 Backend through Inductor
==================================================================

**Author**: `Leslie Fang <https://github.com/leslie-fang-intel>`_, `Weiwen Xia <https://github.com/Xia-Weiwen>`_, `Jiong Gong <https://github.com/jgong5>`_, `Jerry Zhang <https://github.com/jerryzh168>`_

Prerequisites
---------------

- `PyTorch 2 Export Post Training Quantization <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html>`_
- `PyTorch 2 Export Quantization-Aware Training <https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html>`_
- `TorchInductor and torch.compile concepts in PyTorch <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
- `Inductor C++ Wrapper concepts <https://pytorch.org/tutorials/prototype/inductor_cpp_wrapper_tutorial.html>`_

Introduction
--------------

This tutorial introduces the steps required to use the PyTorch 2 Export Quantization flow to generate a quantized model customized
for the X86 Inductor backend, and explains how to lower the quantized model into Inductor.

The PyTorch 2 export quantization flow uses ``torch.export`` to capture the model into a graph and performs quantization transformations on top of the ATen graph.
This approach is expected to have significantly higher model coverage, better programmability, and a simplified UX.
TorchInductor is the new compiler backend that compiles the FX Graphs generated by TorchDynamo into optimized C++/Triton kernels.

This quantization flow with Inductor mainly includes three steps:

- Step 1: Capture the FX Graph from the eager model based on the `torch export mechanism <https://pytorch.org/docs/main/export.html>`_.
- Step 2: Apply the quantization flow based on the captured FX Graph, including defining the backend-specific quantizer, generating the prepared model with observers,
  performing the prepared model's calibration or quantization-aware training, and converting the prepared model into the quantized model.
- Step 3: Lower the quantized model into Inductor with the API ``torch.compile``.

The high-level architecture of this flow could look like this:
@@ -61,10 +63,14 @@
further boost the models' performance by leveraging the
`advanced-matrix-extensions <https://www.intel.com/content/www/us/en/products/docs/accelerator-engines/advanced-matrix-extensions/overview.html>`_ feature.

Post Training Quantization
----------------------------

Now, we will walk you through a step-by-step tutorial on how to use it with the `torchvision resnet18 model <https://download.pytorch.org/models/resnet18-f37072fd.pth>`_
for post training quantization.

1. Capture FX Graph
^^^^^^^^^^^^^^^^^^^^^

We will start by performing the necessary imports and capturing the FX Graph from the eager module.
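
A minimal sketch of this capture step (assuming the torchvision ``resnet18`` model mentioned above and the same ``capture_pre_autograd_graph`` API used in the QAT example later in this tutorial) could look roughly like this:

.. code:: python

    import torch
    import torchvision.models as models
    from torch._export import capture_pre_autograd_graph

    # Instantiate the eager model and switch it to evaluation mode
    model = models.resnet18().eval()

    # Example inputs drive the graph capture; real inference inputs should share this shape
    example_inputs = (torch.randn(1, 3, 224, 224),)

    # Capture the ATen-level FX Graph from the eager model
    with torch.no_grad():
        exported_model = capture_pre_autograd_graph(model, example_inputs)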

@@ -111,7 +117,7 @@
Next, we will have the FX Module to be quantized.

2. Apply Quantization
^^^^^^^^^^^^^^^^^^^^^^^

After we capture the FX Module to be quantized, we will import the Backend Quantizer for X86 CPU and configure how to
quantize the model.
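
A minimal sketch of this quantization step (assuming the ``exported_model`` and ``example_inputs`` from the previous step, and using the example inputs as stand-in calibration data) could look like this:

.. code:: python

    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

    # Configure the backend quantizer with the default X86 Inductor quantization config
    quantizer = X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())

    # Insert observers into the captured FX Graph
    prepared_model = prepare_pt2e(exported_model, quantizer)

    # Calibration: run representative data through the prepared model so the observers
    # can record activation statistics (a real calibration dataset should be used here)
    with torch.no_grad():
        prepared_model(*example_inputs)

    # Convert the observed model into the quantized model
    converted_model = convert_pt2e(prepared_model)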
@@ -160,7 +166,7 @@


3. Lower into Inductor
^^^^^^^^^^^^^^^^^^^^^^^^

After we get the quantized model, we will further lower it to the Inductor backend. The default Inductor wrapper
generates Python code to invoke both generated kernels and external kernels. Additionally, Inductor supports
@@ -222,8 +228,74 @@
to `this document <https://dev-discuss.pytorch.org/t/torchinductor-update-6-cpu-backend-performance-update-and-new-features-in-pytorch-2-1/1514#int8-inference-with-post-training-static-quantization-3>`_
for detailed benchmark numbers.
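
To round out the post training quantization walkthrough, a minimal sketch of the lowering step (assuming the ``converted_model`` and ``example_inputs`` from the previous steps) could look like this:

.. code:: python

    # Lower the quantized model into Inductor; the first call triggers compilation
    with torch.no_grad():
        optimized_model = torch.compile(converted_model)
        _ = optimized_model(*example_inputs)

As noted later in this tutorial, running with ``TORCHINDUCTOR_FREEZING=1`` enables the Inductor ``freeze`` feature, which is typically recommended when lowering quantized models.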

Quantization-Aware Training
-----------------------------

PyTorch 2 Export Quantization-Aware Training (QAT) is now supported on X86 CPUs using ``X86InductorQuantizer``, and the quantized model can subsequently be lowered into Inductor.
For a more in-depth understanding of PT2 Export Quantization-Aware Training,
we recommend referring to the dedicated `PyTorch 2 Export Quantization-Aware Training <https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html>`_ tutorial.

The PyTorch 2 Export QAT flow is largely similar to the PTQ flow:

.. code:: python

    import torch
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import (
        prepare_qat_pt2e,
        convert_pt2e,
    )
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(1024, 1000)

        def forward(self, x):
            return self.linear(x)

    example_inputs = (torch.randn(1, 1024),)
    m = M()

    # Step 1. program capture
    # NOTE: this API will be updated to the torch.export API in the future, but the captured
    # result should mostly stay the same
    exported_model = capture_pre_autograd_graph(m, example_inputs)
    # we get a model with aten ops

    # Step 2. quantization-aware training
    # Use the backend quantizer for X86 CPU with the QAT configuration
    quantizer = X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_qat=True))
    prepared_model = prepare_qat_pt2e(exported_model, quantizer)

    # training loop omitted

    converted_model = convert_pt2e(prepared_model)
    # we have a model with aten ops doing integer computations when possible

    # move the quantized model to eval mode, equivalent to `m.eval()`
    torch.ao.quantization.move_exported_model_to_eval(converted_model)

    # Step 3. lower the model into Inductor
    with torch.no_grad():
        optimized_model = torch.compile(converted_model)
        _ = optimized_model(*example_inputs)

Please note that the Inductor ``freeze`` feature is not enabled by default.
To enable it, run the example code with the environment variable ``TORCHINDUCTOR_FREEZING=1``.

For example:

::

TORCHINDUCTOR_FREEZING=1 python example_x86inductorquantizer_qat.py

Conclusion
------------

In this tutorial, we introduced how to use Inductor with X86 CPUs in the PyTorch 2 Export Quantization flow. Users can learn
how to use ``X86InductorQuantizer`` to quantize a model and lower it into Inductor on X86 CPU devices.