Commit f096067

Add QAT as a standalone section

1 parent 8b46ad5

1 file changed: prototype_source/pt2e_quant_x86_inductor.rst (79 additions, 9 deletions)

@@ -1,17 +1,17 @@
 PyTorch 2 Export Quantization with X86 Backend through Inductor
-========================================================================================
+==================================================================
 
 **Author**: `Leslie Fang <https://github.com/leslie-fang-intel>`_, `Weiwen Xia <https://github.com/Xia-Weiwen>`_, `Jiong Gong <https://github.com/jgong5>`_, `Jerry Zhang <https://github.com/jerryzh168>`_
 
 Prerequisites
-^^^^^^^^^^^^^^^
+---------------
 
 - `PyTorch 2 Export Post Training Quantization <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html>`_
 - `PyTorch 2 Export Quantization-Aware Training <https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html>`_
 - `TorchInductor and torch.compile concepts in PyTorch <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
 
 Introduction
-^^^^^^^^^^^^^^
+--------------
 
 This tutorial introduces the steps for utilizing the PyTorch 2 Export Quantization flow to generate a quantized model customized
 for the x86 inductor backend and explains how to lower the quantized model into the inductor.
@@ -62,10 +62,14 @@ and outstanding out-of-box performance with the compiler backend. Especially on
 further boost the models' performance by leveraging the
 `advanced-matrix-extensions <https://www.intel.com/content/www/us/en/products/docs/accelerator-engines/advanced-matrix-extensions/overview.html>`_ feature.
 
-Now, we will walk you through a step-by-step tutorial for how to use it with `torchvision resnet18 model <https://download.pytorch.org/models/resnet18-f37072fd.pth>`_.
+Post Training Quantization
+----------------------------
+
+Now, we will walk you through a step-by-step tutorial on how to use it with the `torchvision resnet18 model <https://download.pytorch.org/models/resnet18-f37072fd.pth>`_
+for post training quantization.
 
 1. Capture FX Graph
---------------------
+^^^^^^^^^^^^^^^^^^^^^
 
 We will start by performing the necessary imports, capturing the FX Graph from the eager module.
 
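The capture code itself is elided from this hunk. For context, here is a minimal sketch of what this first step could look like for the resnet18 model linked above, assuming the same ``capture_pre_autograd_graph`` API that the QAT example added below uses (the weight loading and input shapes here are illustrative, not taken from the diff):

.. code:: python

    import torch
    import torchvision.models as models
    from torch._export import capture_pre_autograd_graph

    # Instantiate resnet18 in eval mode; the real tutorial would load the
    # pretrained weights from the checkpoint linked above
    model = models.resnet18().eval()

    # Example inputs matching resnet18's expected (N, 3, 224, 224) input shape
    example_inputs = (torch.randn(1, 3, 224, 224),)

    # Capture the FX Graph (ATen ops) from the eager module
    exported_model = capture_pre_autograd_graph(model, example_inputs)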
@@ -112,7 +116,7 @@ We will start by performing the necessary imports, capturing the FX Graph from t
 Next, we will have the FX Module to be quantized.
 
 2. Apply Quantization
-----------------------------
+^^^^^^^^^^^^^^^^^^^^^^^
 
 After we capture the FX Module to be quantized, we will import the Backend Quantizer for X86 CPU and configure how to
 quantize the model.
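The quantizer configuration code is also elided from this hunk. Continuing the sketch above, and assuming the PTQ flow mirrors the QAT code added below, with ``prepare_pt2e`` in place of ``prepare_qat_pt2e`` and the default (non-QAT) config:

.. code:: python

    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e

    # Configure the X86 CPU backend quantizer with its default static PTQ config
    quantizer = X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())

    # Insert observers, run calibration, then convert to a quantized model
    prepared_model = prepare_pt2e(exported_model, quantizer)
    prepared_model(*example_inputs)  # calibration; use representative data in practice
    converted_model = convert_pt2e(prepared_model)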
@@ -161,7 +165,7 @@ After these steps, we finished running the quantization flow and we will get the
 
 
 3. Lower into Inductor
------------------------
+^^^^^^^^^^^^^^^^^^^^^^^^
 
 After we get the quantized model, we will further lower it to the inductor backend.
 
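The lowering code is likewise elided here; it presumably matches the last lines of the QAT example below, compiling the converted model with TorchInductor and running it once to trigger compilation:

.. code:: python

    import torch

    # torch.compile routes the quantized model through the Inductor backend;
    # run with TORCHINDUCTOR_FREEZING=1 (see below) so weights are constant-folded
    with torch.no_grad():
        optimized_model = torch.compile(converted_model)
        _ = optimized_model(*example_inputs)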
@@ -212,8 +216,74 @@ With PyTorch 2.1 release, all CNN models from TorchBench test suite have been me
 to `this document <https://dev-discuss.pytorch.org/t/torchinductor-update-6-cpu-backend-performance-update-and-new-features-in-pytorch-2-1/1514#int8-inference-with-post-training-static-quantization-3>`_
 for detail benchmark number.
 
-4. Conclusion
----------------
+Quantization Aware Training
+-----------------------------
+
+The PyTorch 2 Export Quantization-Aware Training (QAT) is now supported on X86 CPU using X86InductorQuantizer,
+followed by the subsequent lowering of the quantized model into Inductor.
+For a more in-depth understanding of PT2 Export Quantization-Aware Training,
+we recommend referring to the dedicated `PyTorch 2 Export Quantization-Aware Training <https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html>`_ tutorial.
+
+The PyTorch 2 Export QAT flow is largely similar to the PTQ flow:
+
+.. code:: python
+
+    import torch
+    from torch._export import capture_pre_autograd_graph
+    from torch.ao.quantization.quantize_pt2e import (
+        prepare_qat_pt2e,
+        convert_pt2e,
+    )
+    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
+    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
+
+    class M(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(1024, 1000)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    example_inputs = (torch.randn(1, 1024),)
+    m = M()
+
+    # Step 1. program capture
+    # NOTE: this API will be updated to the torch.export API in the future, but the captured
+    # result should mostly stay the same
+    exported_model = capture_pre_autograd_graph(m, example_inputs)
+    # we get a model with aten ops
+
+    # Step 2. quantization-aware training
+    # Use the Backend Quantizer for X86 CPU
+    quantizer = X86InductorQuantizer()
+    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_qat=True))
+    prepared_model = prepare_qat_pt2e(exported_model, quantizer)
+
+    # training loop omitted
+
+    converted_model = convert_pt2e(prepared_model)
+    # we have a model with aten ops doing integer computations when possible
+
+    # move the quantized model to eval mode, equivalent to `m.eval()`
+    torch.ao.quantization.move_exported_model_to_eval(converted_model)
+
+    # Lower the model into Inductor
+    with torch.no_grad():
+        optimized_model = torch.compile(converted_model)
+        _ = optimized_model(*example_inputs)
+
+Please note that the Inductor ``freeze`` feature is not enabled by default.
+To use this feature, you need to run the example code with ``TORCHINDUCTOR_FREEZING=1``.
+
+For example:
+
+::
+
+    TORCHINDUCTOR_FREEZING=1 python example_x86inductorquantizer_qat.py
+
+Conclusion
+------------
 
 With this tutorial, we introduce how to use Inductor with X86 CPU in PyTorch 2 Quantization. Users can learn about
 how to use ``X86InductorQuantizer`` to quantize a model and lower it into the inductor with X86 CPU devices.
