|
PyTorch 2 Export Quantization-Aware Training (QAT) with X86 Backend through Inductor
=====================================================================================

**Author**: `Leslie Fang <https://github.com/leslie-fang-intel>`_, `Jiong Gong <https://github.com/jgong5>`_

Prerequisites
^^^^^^^^^^^^^^^

- `PyTorch 2 Export Quantization-Aware Training tutorial <https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html>`_
- `PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor tutorial <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_x86_inductor.html>`_
- `TorchInductor and torch.compile concepts in PyTorch <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_

This tutorial demonstrates the process of performing PT2 export quantization-aware training (QAT) on an X86 CPU
with X86InductorQuantizer, and subsequently lowering the quantized model into Inductor.
For more comprehensive details about PyTorch 2 Export Quantization-Aware Training in general, please refer to the
dedicated tutorial on `PyTorch 2 Export Quantization-Aware Training <https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html>`_.
For a deeper understanding of X86InductorQuantizer, please consult the tutorial on
`PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_x86_inductor.html>`_.

The PyTorch 2 Export QAT flow looks like the following—it is similar
to the post training quantization (PTQ) flow for the most part:

.. code:: python

    import torch
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import (
        prepare_qat_pt2e,
        convert_pt2e,
    )
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(1024, 1000)

        def forward(self, x):
            return self.linear(x)


    example_inputs = (torch.randn(1, 1024),)
    m = M()

    # Step 1. program capture
    # NOTE: this API will be updated to the torch.export API in the future, but the captured
    # result should mostly stay the same
    exported_model = capture_pre_autograd_graph(m, example_inputs)
    # we get a model with aten ops

    # Step 2. quantization-aware training
    # Use Backend Quantizer for X86 CPU
    quantizer = X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_qat=True))
    prepared_model = prepare_qat_pt2e(exported_model, quantizer)

    # train omitted
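    # For illustration, the omitted training step could look like the sketch
    # below. The optimizer, loss function, step count, and random data are
    # placeholders standing in for a real fine-tuning setup.
    optimizer = torch.optim.SGD(prepared_model.parameters(), lr=0.01)
    loss_fn = torch.nn.CrossEntropyLoss()
    for _ in range(10):
        data = torch.randn(1, 1024)            # stand-in for a real input batch
        target = torch.randint(0, 1000, (1,))  # stand-in for real labels
        optimizer.zero_grad()
        loss = loss_fn(prepared_model(data), target)
        loss.backward()
        optimizer.step()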

    converted_model = convert_pt2e(prepared_model)
    # we have a model with aten ops doing integer computations when possible

    # move the quantized model to eval mode, equivalent to `m.eval()`
    torch.ao.quantization.move_exported_model_to_eval(converted_model)

    # Lower the model into Inductor
    with torch.no_grad():
        optimized_model = torch.compile(converted_model)
        _ = optimized_model(*example_inputs)

Please note that since the Inductor ``freeze`` feature is not turned on by default yet, you need to run the example code with the ``TORCHINDUCTOR_FREEZING=1`` environment variable.

For example:

::

    TORCHINDUCTOR_FREEZING=1 python example_x86inductorquantizer_qat.py
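
Alternatively, the same behavior can usually be enabled from Python through the Inductor config before compiling. This is a minimal sketch, assuming your PyTorch build exposes the ``torch._inductor.config.freezing`` flag (the Python counterpart of the environment variable):

.. code:: python

    import torch._inductor.config as inductor_config

    # Assumption: this flag exists in your PyTorch version and mirrors
    # TORCHINDUCTOR_FREEZING=1; set it before calling torch.compile.
    inductor_config.freezing = True

    with torch.no_grad():
        optimized_model = torch.compile(converted_model)
        _ = optimized_model(*example_inputs)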