
Commit b4cb387

Merge branch 'main' into add_device_mesh_recipe
2 parents b85d479 + eec8d56 commit b4cb387

File tree

5 files changed: +297 -12 lines

prototype_source/prototype_index.rst

Lines changed: 6 additions & 0 deletions
@@ -89,6 +89,12 @@ Prototype features are not available as part of binary distributions like PyPI o
    :link: ../prototype/pt2e_quant_qat.html
    :tags: Quantization

+.. customcarditem::
+   :header: PyTorch 2 Export Quantization with X86 Backend through Inductor
+   :card_description: Learn how to use PT2 Export Quantization with X86 Backend through Inductor.
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../prototype/pt2e_quant_x86_inductor.html
+   :tags: Quantization

 .. Sparsity

prototype_source/pt2e_quant_ptq_x86_inductor.rst renamed to prototype_source/pt2e_quant_x86_inductor.rst

Lines changed: 84 additions & 12 deletions
@@ -1,29 +1,31 @@
-PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor
-========================================================================================
+PyTorch 2 Export Quantization with X86 Backend through Inductor
+==================================================================

 **Author**: `Leslie Fang <https://github.com/leslie-fang-intel>`_, `Weiwen Xia <https://github.com/Xia-Weiwen>`_, `Jiong Gong <https://github.com/jgong5>`_, `Jerry Zhang <https://github.com/jerryzh168>`_

 Prerequisites
-^^^^^^^^^^^^^^^
+---------------

 - `PyTorch 2 Export Post Training Quantization <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html>`_
+- `PyTorch 2 Export Quantization-Aware Training <https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html>`_
 - `TorchInductor and torch.compile concepts in PyTorch <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
 - `Inductor C++ Wrapper concepts <https://pytorch.org/tutorials/prototype/inductor_cpp_wrapper_tutorial.html>`_

 Introduction
-^^^^^^^^^^^^^^
+--------------

 This tutorial introduces the steps for utilizing the PyTorch 2 Export Quantization flow to generate a quantized model customized
 for the x86 inductor backend and explains how to lower the quantized model into the inductor.

-The new quantization 2 flow uses the PT2 Export to capture the model into a graph and perform quantization transformations on top of the ATen graph. This approach is expected to have significantly higher model coverage, better programmability, and a simplified UX.
+The PyTorch 2 export quantization flow uses ``torch.export`` to capture the model into a graph and performs quantization transformations on top of the ATen graph.
+This approach is expected to have significantly higher model coverage, better programmability, and a simplified UX.
 TorchInductor is the new compiler backend that compiles the FX Graphs generated by TorchDynamo into optimized C++/Triton kernels.

 This flow of quantization 2 with Inductor mainly includes three steps:

 - Step 1: Capture the FX Graph from the eager Model based on the `torch export mechanism <https://pytorch.org/docs/main/export.html>`_.
 - Step 2: Apply the Quantization flow based on the captured FX Graph, including defining the backend-specific quantizer, generating the prepared model with observers,
-  performing the prepared model's calibration, and converting the prepared model into the quantized model.
+  performing the prepared model's calibration or quantization-aware training, and converting the prepared model into the quantized model.
 - Step 3: Lower the quantized model into inductor with the API ``torch.compile``.

 The high-level architecture of this flow could look like this:
@@ -61,10 +63,14 @@ and outstanding out-of-box performance with the compiler backend. Especially on
 further boost the models' performance by leveraging the
 `advanced-matrix-extensions <https://www.intel.com/content/www/us/en/products/docs/accelerator-engines/advanced-matrix-extensions/overview.html>`_ feature.

-Now, we will walk you through a step-by-step tutorial for how to use it with `torchvision resnet18 model <https://download.pytorch.org/models/resnet18-f37072fd.pth>`_.
+Post Training Quantization
+----------------------------
+
+Now, we will walk you through a step-by-step tutorial for how to use it with `torchvision resnet18 model <https://download.pytorch.org/models/resnet18-f37072fd.pth>`_
+for post training quantization.

 1. Capture FX Graph
----------------------
+^^^^^^^^^^^^^^^^^^^^^

 We will start by performing the necessary imports, capturing the FX Graph from the eager module.
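The body of this section is unchanged by the commit and therefore not shown in the diff. For orientation only, a minimal sketch of the capture step, assuming the same ``capture_pre_autograd_graph`` API used in the QAT example later in this file (weight loading and preprocessing are omitted):

.. code:: python

    import torch
    import torchvision.models as models
    from torch._export import capture_pre_autograd_graph

    # Eager-mode torchvision resnet18 in inference mode
    model = models.resnet18().eval()
    example_inputs = (torch.randn(1, 3, 224, 224),)

    # Capture the FX Graph that will be quantized
    with torch.no_grad():
        exported_model = capture_pre_autograd_graph(model, example_inputs)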

@@ -111,7 +117,7 @@ We will start by performing the necessary imports, capturing the FX Graph from t
 Next, we will have the FX Module to be quantized.

 2. Apply Quantization
-----------------------------
+^^^^^^^^^^^^^^^^^^^^^^^

 After we capture the FX Module to be quantized, we will import the Backend Quantizer for X86 CPU and configure how to
 quantize the model.
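The body of this section is also outside the hunk. As a rough sketch only, continuing from the capture sketch above and reusing the ``X86InductorQuantizer`` APIs that appear in the QAT example below (the random batches stand in for a real calibration set):

.. code:: python

    import torch
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

    # Configure the backend-specific quantizer for X86 CPU
    quantizer = X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())

    # Insert observers into the captured FX Graph
    prepared_model = prepare_pt2e(exported_model, quantizer)

    # Calibrate the prepared model on representative data
    with torch.no_grad():
        for _ in range(10):
            prepared_model(torch.randn(1, 3, 224, 224))

    # Convert the calibrated model into the quantized model
    quantized_model = convert_pt2e(prepared_model)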
@@ -160,7 +166,7 @@ After these steps, we finished running the quantization flow and we will get the


 3. Lower into Inductor
-------------------------
+^^^^^^^^^^^^^^^^^^^^^^^^

 After we get the quantized model, we will further lower it to the inductor backend. The default Inductor wrapper
 generates Python code to invoke both generated kernels and external kernels. Additionally, Inductor supports
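Continuing the sketch, the lowering step itself is a single ``torch.compile`` call; the ``TORCHINDUCTOR_FREEZING=1`` environment variable mentioned later in this diff applies to this step as well:

.. code:: python

    import torch

    # Lower the quantized model into Inductor; the first call triggers compilation
    with torch.no_grad():
        optimized_model = torch.compile(quantized_model)
        _ = optimized_model(*example_inputs)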
@@ -222,8 +228,74 @@ With PyTorch 2.1 release, all CNN models from TorchBench test suite have been me
 to `this document <https://dev-discuss.pytorch.org/t/torchinductor-update-6-cpu-backend-performance-update-and-new-features-in-pytorch-2-1/1514#int8-inference-with-post-training-static-quantization-3>`_
 for detail benchmark number.

-4. Conclusion
----------------
+Quantization Aware Training
+-----------------------------
+
+The PyTorch 2 Export Quantization-Aware Training (QAT) is now supported on X86 CPU using ``X86InductorQuantizer``,
+followed by the subsequent lowering of the quantized model into Inductor.
+For a more in-depth understanding of PT2 Export Quantization-Aware Training,
+we recommend referring to the dedicated `PyTorch 2 Export Quantization-Aware Training <https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html>`_ tutorial.
+
+The PyTorch 2 Export QAT flow is largely similar to the PTQ flow:
+
+.. code:: python
+
+    import torch
+    from torch._export import capture_pre_autograd_graph
+    from torch.ao.quantization.quantize_pt2e import (
+        prepare_qat_pt2e,
+        convert_pt2e,
+    )
+    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
+    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
+
+    class M(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(1024, 1000)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    example_inputs = (torch.randn(1, 1024),)
+    m = M()
+
+    # Step 1. program capture
+    # NOTE: this API will be updated to the torch.export API in the future, but the captured
+    # result should mostly stay the same
+    exported_model = capture_pre_autograd_graph(m, example_inputs)
+    # we get a model with aten ops
+
+    # Step 2. quantization-aware training
+    # Use the Backend Quantizer for X86 CPU
+    quantizer = X86InductorQuantizer()
+    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_qat=True))
+    prepared_model = prepare_qat_pt2e(exported_model, quantizer)
+
+    # train omitted
+
+    converted_model = convert_pt2e(prepared_model)
+    # we have a model with aten ops doing integer computations when possible
+
+    # move the quantized model to eval mode, equivalent to `m.eval()`
+    torch.ao.quantization.move_exported_model_to_eval(converted_model)
+
+    # Lower the model into Inductor
+    with torch.no_grad():
+        optimized_model = torch.compile(converted_model)
+        _ = optimized_model(*example_inputs)
+
+Please note that the Inductor ``freeze`` feature is not enabled by default.
+To use this feature, you need to run example code with ``TORCHINDUCTOR_FREEZING=1``.
+
+For example:
+
+::
+
+    TORCHINDUCTOR_FREEZING=1 python example_x86inductorquantizer_qat.py
+
+Conclusion
+------------

 With this tutorial, we introduce how to use Inductor with X86 CPU in PyTorch 2 Quantization. Users can learn about
 how to use ``X86InductorQuantizer`` to quantize a model and lower it into the inductor with X86 CPU devices.
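The QAT example above marks its training loop as ``# train omitted``. A minimal sketch of what could go in that gap; the optimizer, loss, random data, and step count below are purely illustrative placeholders for a real fine-tuning setup:

.. code:: python

    import torch

    optimizer = torch.optim.SGD(prepared_model.parameters(), lr=0.01)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Short fake-quantized fine-tuning loop on random data (input and target shapes
    # match the Linear(1024, 1000) toy model used in the QAT example)
    for _ in range(10):
        inputs = torch.randn(8, 1024)
        targets = torch.randint(0, 1000, (8,))
        optimizer.zero_grad()
        loss = loss_fn(prepared_model(inputs), targets)
        loss.backward()
        optimizer.step()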
Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+(beta) Compiling the optimizer with torch.compile
+==========================================================================================
+
+**Author:** `Michael Lazos <https://github.com/mlazos>`_
+
+The optimizer is a key algorithm for training any deep learning model.
+Since it is responsible for updating every model parameter, it can often
+become the bottleneck in training performance for large models. In this recipe,
+we will apply ``torch.compile`` to the optimizer to observe the GPU performance
+improvement.
+
+.. note::
+
+   This tutorial requires PyTorch 2.2.0 or later.
+
+Model Setup
+~~~~~~~~~~~~~~~~~~~~~
+For this example, we'll use a simple sequence of linear layers.
+Since we are only benchmarking the optimizer, the choice of model doesn't matter
+because optimizer performance is a function of the number of parameters.
+
+Depending on what machine you are using, your exact results may vary.
+
+.. code-block:: python
+
+   import torch
+
+   model = torch.nn.Sequential(
+       *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
+   )
+   input = torch.rand(1024, device="cuda")
+   output = model(input)
+   output.sum().backward()
+
+Setting up and running the optimizer benchmark
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+In this example, we'll use the Adam optimizer
+and create a helper function to wrap the step()
+in ``torch.compile()``.
+
+.. note::
+
+   ``torch.compile`` is only supported on cuda devices with compute capability >= 7.0
+
+.. code-block:: python
+
+   # exit cleanly if we are on a device that doesn't support torch.compile
+   if torch.cuda.get_device_capability() < (7, 0):
+       print("Exiting because torch.compile is not supported on this device.")
+       import sys
+       sys.exit(0)
+
+
+   opt = torch.optim.Adam(model.parameters(), lr=0.01)
+
+
+   @torch.compile(fullgraph=False)
+   def fn():
+       opt.step()
+
+
+   # Let's define a helpful benchmarking function:
+   import torch.utils.benchmark as benchmark
+
+
+   def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
+       t0 = benchmark.Timer(
+           stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
+       )
+       return t0.blocked_autorange().mean * 1e6
+
+
+   # Warmup runs to compile the function
+   for _ in range(5):
+       fn()
+
+   eager_runtime = benchmark_torch_function_in_microseconds(opt.step)
+   compiled_runtime = benchmark_torch_function_in_microseconds(fn)
+
+   assert eager_runtime > compiled_runtime
+
+   print(f"eager runtime: {eager_runtime}us")
+   print(f"compiled runtime: {compiled_runtime}us")
+
+Sample Results:
+
+* Eager runtime: 747.2437149845064us
+* Compiled runtime: 392.07384741178us
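Since the recipe attributes optimizer cost to the number of parameters, a small sketch (not part of the recipe) that reuses its benchmark helper to compare eager and compiled ``Adam.step()`` at a few illustrative model sizes:

.. code-block:: python

    import torch
    import torch.utils.benchmark as benchmark


    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6


    for num_layers in (2, 10, 20):
        model = torch.nn.Sequential(
            *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(num_layers)]
        )
        model(torch.rand(1024, device="cuda")).sum().backward()
        opt = torch.optim.Adam(model.parameters(), lr=0.01)

        @torch.compile(fullgraph=False)
        def step_fn():
            opt.step()

        for _ in range(5):  # warmup runs compile the function
            step_fn()

        eager_us = benchmark_torch_function_in_microseconds(opt.step)
        compiled_us = benchmark_torch_function_in_microseconds(step_fn)
        print(f"{num_layers} layers: eager {eager_us:.1f}us, compiled {compiled_us:.1f}us")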

recipes_source/recipes_index.rst

Lines changed: 19 additions & 0 deletions
@@ -144,6 +144,14 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/recipes/module_load_state_dict_tips.html
    :tags: Basics

+.. customcarditem::
+   :header: (beta) Using TORCH_LOGS to observe torch.compile
+   :card_description: Learn how to use the torch logging APIs to observe the compilation process.
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/torch_logs.html
+   :tags: Basics
+
+
 .. Interpretability

 .. customcarditem::

@@ -276,6 +284,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/amx.html
    :tags: Model-Optimization

+.. (beta) Compiling the Optimizer with torch.compile
+
+.. customcarditem::
+   :header: (beta) Compiling the Optimizer with torch.compile
+   :card_description: Speed up the optimizer using torch.compile
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/compiling_optimizer.html
+   :tags: Model-Optimization
+
 .. Intel(R) Extension for PyTorch*

 .. customcarditem::

@@ -360,6 +377,7 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu

    /recipes/recipes/loading_data_recipe
    /recipes/recipes/defining_a_neural_network
+   /recipes/torch_logs
    /recipes/recipes/what_is_state_dict
    /recipes/recipes/saving_and_loading_models_for_inference
    /recipes/recipes/saving_and_loading_a_general_checkpoint

@@ -375,6 +393,7 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    /recipes/recipes/amp_recipe
    /recipes/recipes/tuning_guide
    /recipes/recipes/intel_extension_for_pytorch
+   /recipes/compiling_optimizer
    /recipes/torch_compile_backend_ipex
    /recipes/torchscript_inference
    /recipes/deployment_with_flask

recipes_source/torch_logs.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+"""
+(beta) Using TORCH_LOGS python API with torch.compile
+==========================================================================================
+**Author:** `Michael Lazos <https://github.com/mlazos>`_
+"""
+
+import logging
+
+######################################################################
+#
+# This tutorial introduces the ``TORCH_LOGS`` environment variable, as well as the Python API, and
+# demonstrates how to apply it to observe the phases of ``torch.compile``.
+#
+# .. note::
+#
+#   This tutorial requires PyTorch 2.2.0 or later.
+#
+#
+
+
+######################################################################
+# Setup
+# ~~~~~~~~~~~~~~~~~~~~~
+# In this example, we'll set up a simple Python function which performs an elementwise
+# add and observe the compilation process with ``TORCH_LOGS`` Python API.
+#
+# .. note::
+#
+#   There is also an environment variable ``TORCH_LOGS``, which can be used to
+#   change logging settings at the command line. The equivalent environment
+#   variable setting is shown for each example.
+
+import torch
+
+# exit cleanly if we are on a device that doesn't support torch.compile
+if torch.cuda.get_device_capability() < (7, 0):
+    print("Exiting because torch.compile is not supported on this device.")
+    import sys
+
+    sys.exit(0)
+
+
+@torch.compile()
+def fn(x, y):
+    z = x + y
+    return z + 2
+
+
+inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda"))
+
+
+# print separator and reset dynamo
+# between each example
+def separator(name):
+    print(f"==================={name}=========================")
+    torch._dynamo.reset()
+
+
+separator("Dynamo Tracing")
+# View dynamo tracing
+# TORCH_LOGS="+dynamo"
+torch._logging.set_logs(dynamo=logging.DEBUG)
+fn(*inputs)
+
+separator("Traced Graph")
+# View traced graph
+# TORCH_LOGS="graph"
+torch._logging.set_logs(graph=True)
+fn(*inputs)
+
+separator("Fusion Decisions")
+# View fusion decisions
+# TORCH_LOGS="fusion"
+torch._logging.set_logs(fusion=True)
+fn(*inputs)
+
+separator("Output Code")
+# View output code generated by inductor
+# TORCH_LOGS="output_code"
+torch._logging.set_logs(output_code=True)
+fn(*inputs)
+
+separator("")
+
+######################################################################
+# Conclusion
+# ~~~~~~~~~~
+#
+# In this tutorial we introduced the TORCH_LOGS environment variable and python API
+# by experimenting with a small number of the available logging options.
+# To view descriptions of all available options, run any python script
+# which imports torch and set TORCH_LOGS to "help".
+#
+# Alternatively, you can view the `torch._logging documentation`_ to see
+# descriptions of all available logging options.
+#
+# For more information on torch.compile, see the `torch.compile tutorial`_.
+#
+# .. _torch._logging documentation: https://pytorch.org/docs/main/logging.html
+# .. _torch.compile tutorial: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
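Several of these logging options can also be enabled together; a short sketch assuming, as with the calls above, that ``torch._logging.set_logs`` accepts multiple options in one call (the comma-separated form being the ``TORCH_LOGS`` equivalent):

.. code-block:: python

    import logging

    import torch

    # TORCH_LOGS="+dynamo,graph,output_code"
    torch._logging.set_logs(dynamo=logging.DEBUG, graph=True, output_code=True)


    @torch.compile()
    def add_fn(x, y):
        return x + y + 2


    add_fn(torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda"))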
