Commit 16aa287

Add QAT as a standalone section
1 parent 2ca1927 commit 16aa287


prototype_source/pt2e_quant_x86_inductor.rst

Lines changed: 79 additions & 9 deletions
@@ -1,18 +1,18 @@
 PyTorch 2 Export Quantization with X86 Backend through Inductor
-========================================================================================
+==================================================================

 **Author**: `Leslie Fang <https://github.com/leslie-fang-intel>`_, `Weiwen Xia <https://github.com/Xia-Weiwen>`_, `Jiong Gong <https://github.com/jgong5>`_, `Jerry Zhang <https://github.com/jerryzh168>`_

 Prerequisites
-^^^^^^^^^^^^^^^
+---------------

 - `PyTorch 2 Export Post Training Quantization <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html>`_
 - `PyTorch 2 Export Quantization-Aware Training <https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html>`_
 - `TorchInductor and torch.compile concepts in PyTorch <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
 - `Inductor C++ Wrapper concepts <https://pytorch.org/tutorials/prototype/inductor_cpp_wrapper_tutorial.html>`_

 Introduction
-^^^^^^^^^^^^^^
+--------------

 This tutorial introduces the steps for utilizing the PyTorch 2 Export Quantization flow to generate a quantized model customized
 for the x86 inductor backend and explains how to lower the quantized model into the inductor.
@@ -63,10 +63,14 @@ and outstanding out-of-box performance with the compiler backend. Especially on
 further boost the models' performance by leveraging the
 `advanced-matrix-extensions <https://www.intel.com/content/www/us/en/products/docs/accelerator-engines/advanced-matrix-extensions/overview.html>`_ feature.

-Now, we will walk you through a step-by-step tutorial for how to use it with `torchvision resnet18 model <https://download.pytorch.org/models/resnet18-f37072fd.pth>`_.
+Post Training Quantization
+----------------------------
+
+Now, we will walk you through a step-by-step tutorial for how to use it with `torchvision resnet18 model <https://download.pytorch.org/models/resnet18-f37072fd.pth>`_
+for post training quantization.

 1. Capture FX Graph
----------------------
+^^^^^^^^^^^^^^^^^^^^^

 We will start by performing the necessary imports, capturing the FX Graph from the eager module.

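The capture code for this step is unchanged by the commit and therefore not shown in the hunk. As a rough sketch of what it looks like (not the tutorial's exact code; it assumes an untrained ``resnet18`` from ``torchvision`` for brevity and the ``capture_pre_autograd_graph`` API that also appears in the QAT example below):

.. code:: python

    import torch
    import torchvision.models as models
    from torch._export import capture_pre_autograd_graph

    # Eager-mode model and example inputs for the walkthrough
    model = models.resnet18().eval()
    example_inputs = (torch.randn(1, 3, 224, 224),)

    # Capture the FX Graph to be quantized
    with torch.no_grad():
        exported_model = capture_pre_autograd_graph(model, example_inputs)
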
@@ -113,7 +117,7 @@ We will start by performing the necessary imports, capturing the FX Graph from t
 Next, we will have the FX Module to be quantized.

 2. Apply Quantization
------------------------------
+^^^^^^^^^^^^^^^^^^^^^^^

 After we capture the FX Module to be quantized, we will import the Backend Quantizer for X86 CPU and configure how to
 quantize the model.
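
The quantization code for this step is likewise elided from the diff. A minimal sketch of the PTQ flow, continuing from the ``exported_model`` above and using ``prepare_pt2e``/``convert_pt2e``, the post-training counterparts of the QAT APIs added later in this commit:

.. code:: python

    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

    # Configure the backend quantizer for X86 CPU
    quantizer = X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())

    # Insert observers, calibrate on representative data, then convert
    prepared_model = prepare_pt2e(exported_model, quantizer)
    prepared_model(*example_inputs)  # calibration; use real data in practice
    converted_model = convert_pt2e(prepared_model)
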
@@ -162,7 +166,7 @@ After these steps, we finished running the quantization flow and we will get the


 3. Lower into Inductor
--------------------------
+^^^^^^^^^^^^^^^^^^^^^^^^

 After we get the quantized model, we will further lower it to the inductor backend. The default Inductor wrapper
 generates Python code to invoke both generated kernels and external kernels. Additionally, Inductor supports
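
The lowering code itself is not part of this hunk; a short sketch, mirroring the QAT lowering lines added later in this commit (run with ``TORCHINDUCTOR_FREEZING=1`` to enable the freezing path noted below):

.. code:: python

    # Lower the converted model into Inductor and run it once
    with torch.no_grad():
        optimized_model = torch.compile(converted_model)
        _ = optimized_model(*example_inputs)
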
@@ -224,8 +228,74 @@ With PyTorch 2.1 release, all CNN models from TorchBench test suite have been me
 to `this document <https://dev-discuss.pytorch.org/t/torchinductor-update-6-cpu-backend-performance-update-and-new-features-in-pytorch-2-1/1514#int8-inference-with-post-training-static-quantization-3>`_
 for detail benchmark number.

-4. Conclusion
----------------
+Quantization Aware Training
+-----------------------------
+
+The PyTorch 2 Export Quantization-Aware Training (QAT) is now supported on X86 CPU using X86InductorQuantizer,
+followed by the subsequent lowering of the quantized model into Inductor.
+For a more in-depth understanding of PT2 Export Quantization-Aware Training,
+we recommend referring to the dedicated `PyTorch 2 Export Quantization-Aware Training <https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html>`_ tutorial.
+
+The PyTorch 2 Export QAT flow is largely similar to the PTQ flow:
+
+.. code:: python
+
+    import torch
+    from torch._export import capture_pre_autograd_graph
+    from torch.ao.quantization.quantize_pt2e import (
+        prepare_qat_pt2e,
+        convert_pt2e,
+    )
+    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
+    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
+
+    class M(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(1024, 1000)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    example_inputs = (torch.randn(1, 1024),)
+    m = M()
+
+    # Step 1. program capture
+    # NOTE: this API will be updated to torch.export API in the future, but the captured
+    # result should mostly stay the same
+    exported_model = capture_pre_autograd_graph(m, example_inputs)
+    # we get a model with aten ops
+
+    # Step 2. quantization-aware training
+    # Use Backend Quantizer for X86 CPU
+    quantizer = X86InductorQuantizer()
+    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_qat=True))
+    prepared_model = prepare_qat_pt2e(exported_model, quantizer)
+
+    # training loop omitted
+
+    converted_model = convert_pt2e(prepared_model)
+    # we have a model with aten ops doing integer computations when possible
+
+    # move the quantized model to eval mode, equivalent to `m.eval()`
+    torch.ao.quantization.move_exported_model_to_eval(converted_model)
+
+    # Lower the model into Inductor
+    with torch.no_grad():
+        optimized_model = torch.compile(converted_model)
+        _ = optimized_model(*example_inputs)
+
+Please note that the Inductor ``freeze`` feature is not enabled by default.
+To use this feature, you need to run the example code with ``TORCHINDUCTOR_FREEZING=1``.
+
+For example:
+
+::
+
+    TORCHINDUCTOR_FREEZING=1 python example_x86inductorquantizer_qat.py
+
+Conclusion
+------------

 With this tutorial, we introduce how to use Inductor with X86 CPU in PyTorch 2 Quantization. Users can learn about
 how to use ``X86InductorQuantizer`` to quantize a model and lower it into the inductor with X86 CPU devices.
