Commit c77e969

Add QAT Tutorial

1 file changed: +78 −0
PyTorch 2 Export Quantization-Aware Training (QAT) with X86 Backend through Inductor
========================================================================================

**Author**: `Leslie Fang <https://github.com/leslie-fang-intel>`_, `Jiong Gong <https://github.com/jgong5>`_

Prerequisites
^^^^^^^^^^^^^^^

- `PyTorch 2 Export Quantization-Aware Training tutorial <https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html>`_
- `PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor tutorial <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_x86_inductor.html>`_
- `TorchInductor and torch.compile concepts in PyTorch <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_

This tutorial demonstrates the process of performing PT2 export quantization-aware training (QAT) on X86 CPU
with X86InductorQuantizer, and subsequently lowering the quantized model into Inductor.
For more comprehensive details about PyTorch 2 Export Quantization-Aware Training in general, please refer to the
dedicated tutorial on `PyTorch 2 Export Quantization-Aware Training <https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html>`_.
For a deeper understanding of X86InductorQuantizer, please consult the tutorial on
`PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_x86_inductor.html>`_.

The PyTorch 2 Export QAT flow looks like the following; it is similar
to the post training quantization (PTQ) flow for the most part:

.. code:: python

    import torch
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import (
        prepare_qat_pt2e,
        convert_pt2e,
    )
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(1024, 1000)

        def forward(self, x):
            return self.linear(x)

    example_inputs = (torch.randn(1, 1024),)
    m = M()

    # Step 1. program capture
    # NOTE: this API will be updated to the torch.export API in the future, but the captured
    # result should mostly stay the same
    exported_model = capture_pre_autograd_graph(m, example_inputs)
    # we get a model with aten ops

    # Step 2. quantization-aware training
    # Use the backend quantizer for X86 CPU
    quantizer = X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_qat=True))
    prepared_model = prepare_qat_pt2e(exported_model, quantizer)

    # train omitted (a minimal training-loop sketch follows after this code block)

    converted_model = convert_pt2e(prepared_model)
    # we have a model with aten ops doing integer computations when possible

    # move the quantized model to eval mode, equivalent to `m.eval()`
    torch.ao.quantization.move_exported_model_to_eval(converted_model)

    # Lower the model into Inductor
    with torch.no_grad():
        optimized_model = torch.compile(converted_model)
        _ = optimized_model(*example_inputs)

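The ``# train omitted`` step above is a regular PyTorch training loop run on ``prepared_model``, which carries fake-quantization observers so that training adapts the weights to quantization. The sketch below is purely illustrative; the optimizer, loss function, synthetic data, and step count are assumptions for demonstration, not part of the original tutorial.

.. code:: python

    # Minimal illustrative training loop for the "train omitted" step.
    # The optimizer, loss function, synthetic data, and step count are
    # placeholder assumptions; substitute your real dataset and training recipe.
    optimizer = torch.optim.SGD(prepared_model.parameters(), lr=0.01)
    loss_fn = torch.nn.CrossEntropyLoss()

    for step in range(10):
        inputs = torch.randn(16, 1024)            # stand-in for a DataLoader batch
        targets = torch.randint(0, 1000, (16,))   # stand-in labels

        optimizer.zero_grad()
        loss = loss_fn(prepared_model(inputs), targets)
        loss.backward()
        optimizer.step()
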
Please note that since the Inductor ``freeze`` feature is not turned on by default yet, you need to run the
example code with ``TORCHINDUCTOR_FREEZING=1``.

For example:

::

    TORCHINDUCTOR_FREEZING=1 python example_x86inductorquantizer_qat.py
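
If you prefer to enable freezing from Python rather than via the environment variable, the same behavior can typically be obtained through the Inductor config. Note that ``torch._inductor.config`` is a private module, so the ``freezing`` flag used below is an assumption that may change between PyTorch releases; the environment variable shown above is the documented route.

.. code:: python

    # Assumption: torch._inductor.config.freezing mirrors TORCHINDUCTOR_FREEZING=1.
    # This is a private, unstable config knob; prefer the environment variable if unsure.
    import torch._inductor.config as inductor_config
    inductor_config.freezing = True

    with torch.no_grad():
        optimized_model = torch.compile(converted_model)
        _ = optimized_model(*example_inputs)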
