Skip to content

Commit fac78f0

Browse files
authored
Merge branch 'main' into main
2 parents 9ceb5c4 + f2b930e commit fac78f0

9 files changed

+80
-124
lines changed
Loading
Loading
Loading
Loading

intermediate_source/compiled_autograd_tutorial.rst

Lines changed: 10 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -97,62 +97,10 @@ Run the script with the ``TORCH_LOGS`` environment variables:
9797
Rerun the snippet above, the compiled autograd graph should now be logged to ``stderr``. Certain graph nodes will have names that are prefixed by ``aot0_``,
9898
these correspond to the nodes previously compiled ahead of time in AOTAutograd backward graph 0, for example, ``aot0_view_2`` corresponds to ``view_2`` of the AOT backward graph with id=0.
9999

100+
In the image below, the red box encapsulates the AOT backward graph that is captured by ``torch.compile`` without Compiled Autograd.
100101

101-
.. code:: python
102102

103-
stderr_output = """
104-
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
105-
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
106-
===== Compiled autograd graph =====
107-
<eval_with_key>.4 class CompiledAutograd(torch.nn.Module):
108-
def forward(self, inputs, sizes, scalars, hooks):
109-
# No stacktrace found for following nodes
110-
aot0_tangents_1: "f32[][]cpu" = inputs[0]
111-
aot0_primals_3: "f32[10][1]cpu" = inputs[1]
112-
getitem_2: "f32[10][1]cpu" = inputs[2]
113-
getitem_3: "f32[10, 10][10, 1]cpu" = inputs[3]; inputs = None
114-
115-
# File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 1)
116-
aot0_expand: "f32[10][0]cpu" = torch.ops.aten.expand.default(aot0_tangents_1, [10]); aot0_tangents_1 = None
117-
aot0_view_2: "f32[1, 10][0, 0]cpu" = torch.ops.aten.view.default(aot0_expand, [1, 10]); aot0_expand = None
118-
aot0_permute_2: "f32[10, 1][0, 0]cpu" = torch.ops.aten.permute.default(aot0_view_2, [1, 0])
119-
aot0_select: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 0)
120-
aot0_view: "f32[1, 10][10, 1]cpu" = torch.ops.aten.view.default(aot0_primals_3, [1, 10]); aot0_primals_3 = None
121-
aot0_mul_3: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select, aot0_view); aot0_select = None
122-
aot0_select_1: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 1)
123-
aot0_mul_4: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_1, aot0_view); aot0_select_1 = None
124-
aot0_select_2: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 2)
125-
aot0_mul_5: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_2, aot0_view); aot0_select_2 = None
126-
aot0_select_3: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 3)
127-
aot0_mul_6: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_3, aot0_view); aot0_select_3 = None
128-
aot0_select_4: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 4)
129-
aot0_mul_7: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_4, aot0_view); aot0_select_4 = None
130-
aot0_select_5: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 5)
131-
aot0_mul_8: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_5, aot0_view); aot0_select_5 = None
132-
aot0_select_6: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 6)
133-
aot0_mul_9: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_6, aot0_view); aot0_select_6 = None
134-
aot0_select_7: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 7)
135-
aot0_mul_10: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_7, aot0_view); aot0_select_7 = None
136-
aot0_select_8: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 8)
137-
aot0_mul_11: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_8, aot0_view); aot0_select_8 = None
138-
aot0_select_9: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 9); aot0_permute_2 = None
139-
aot0_mul_12: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_9, aot0_view); aot0_select_9 = aot0_view = None
140-
aot0_cat: "f32[10, 10][10, 1]cpu" = torch.ops.aten.cat.default([aot0_mul_3, aot0_mul_4, aot0_mul_5, aot0_mul_6, aot0_mul_7, aot0_mul_8, aot0_mul_9, aot0_mul_10, aot0_mul_11, aot0_mul_12]); aot0_mul_3 = aot0_mul_4 = aot0_mul_5 = aot0_mul_6 = aot0_mul_7 = aot0_mul_8 = aot0_mul_9 = aot0_mul_10 = aot0_mul_11 = aot0_mul_12 = None
141-
aot0_permute_3: "f32[10, 10][1, 10]cpu" = torch.ops.aten.permute.default(aot0_cat, [1, 0]); aot0_cat = None
142-
aot0_sum_3: "f32[1, 10][10, 1]cpu" = torch.ops.aten.sum.dim_IntList(aot0_view_2, [0], True); aot0_view_2 = None
143-
aot0_view_3: "f32[10][1]cpu" = torch.ops.aten.view.default(aot0_sum_3, [10]); aot0_sum_3 = None
144-
145-
# File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 2)
146-
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_2, aot0_view_3); getitem_2 = aot0_view_3 = accumulate_grad_ = None
147-
148-
# File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 1)
149-
aot0_permute_4: "f32[10, 10][10, 1]cpu" = torch.ops.aten.permute.default(aot0_permute_3, [1, 0]); aot0_permute_3 = None
150-
151-
# File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 3)
152-
accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_3, aot0_permute_4); getitem_3 = aot0_permute_4 = accumulate_grad__1 = None
153-
_exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
154-
return []
155-
"""
103+
.. image:: ../_static/img/compiled_autograd/entire_verbose_log.png
156104

157105
.. note:: This is the graph on which we will call ``torch.compile``, **NOT** the optimized graph. Compiled Autograd essentially generates some unoptimized Python code to represent the entire C++ autograd execution.
158106

@@ -181,7 +129,7 @@ Or you can use the context manager, which will apply to all autograd calls withi
181129
182130
Compiled Autograd addresses certain limitations of AOTAutograd
183131
--------------------------------------------------------------
184-
1. Graph breaks in the forward pass lead to graph breaks in the backward pass:
132+
1. Graph breaks in the forward pass no longer necessarily lead to graph breaks in the backward pass:
185133

186134
.. code:: python
187135
@@ -216,7 +164,10 @@ Compiled Autograd addresses certain limitations of AOTAutograd
216164
In the first ``torch.compile`` case, we see that 3 backward graphs were produced due to the 2 graph breaks in the compiled function ``fn``.
217165
Whereas in the second ``torch.compile`` with compiled autograd case, we see that a full backward graph was traced despite the graph breaks.
218166

219-
2. Backward hooks are not captured
167+
.. note:: It is still possible for the Dynamo to graph break when tracing backward hooks captured by Compiled Autograd.
168+
169+
170+
2. Backward hooks can now be captured
220171

221172
.. code:: python
222173
@@ -233,19 +184,7 @@ Whereas in the second ``torch.compile`` with compiled autograd case, we see that
233184
234185
There should be a ``call_hook`` node in the graph, which dynamo will later inline into the following:
235186

236-
.. code:: python
237-
238-
stderr_output = """
239-
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
240-
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
241-
===== Compiled autograd graph =====
242-
<eval_with_key>.2 class CompiledAutograd(torch.nn.Module):
243-
def forward(self, inputs, sizes, scalars, hooks):
244-
...
245-
getitem_2 = hooks[0]; hooks = None
246-
call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
247-
...
248-
"""
187+
.. image:: ../_static/img/compiled_autograd/call_hook_node.png
249188

250189
Common recompilation reasons for Compiled Autograd
251190
--------------------------------------------------
@@ -261,18 +200,7 @@ Common recompilation reasons for Compiled Autograd
261200
262201
In the example above, we call a different operator on each iteration, leading to ``loss`` tracking a different autograd history each time. You should see some recompile messages: **Cache miss due to new autograd node**.
263202

264-
.. code:: python
265-
266-
stderr_output = """
267-
Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
268-
...
269-
Cache miss due to new autograd node: SubBackward0 (NodeCall 2) with key size 56, previous key sizes=[]
270-
...
271-
Cache miss due to new autograd node: MulBackward0 (NodeCall 2) with key size 71, previous key sizes=[]
272-
...
273-
Cache miss due to new autograd node: DivBackward0 (NodeCall 2) with key size 70, previous key sizes=[]
274-
...
275-
"""
203+
.. image:: ../_static/img/compiled_autograd/recompile_due_to_node.png
276204

277205
2. Due to tensors changing shapes:
278206

@@ -286,16 +214,7 @@ In the example above, we call a different operator on each iteration, leading to
286214
287215
In the example above, ``x`` changes shapes, and compiled autograd will mark ``x`` as a dynamic shape tensor after the first change. You should see recompiles messages: **Cache miss due to changed shapes**.
288216

289-
.. code:: python
290-
291-
stderr_output = """
292-
...
293-
Cache miss due to changed shapes: marking size idx 0 of torch::autograd::GraphRoot (NodeCall 0) as dynamic
294-
Cache miss due to changed shapes: marking size idx 1 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
295-
Cache miss due to changed shapes: marking size idx 2 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
296-
Cache miss due to changed shapes: marking size idx 3 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
297-
...
298-
"""
217+
.. image:: ../_static/img/compiled_autograd/recompile_due_to_dynamic.png
299218

300219
Conclusion
301220
----------

prototype_source/pt2e_quant_ptq.rst

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ The PyTorch 2 export quantization API looks like this:
5151
.. code:: python
5252
5353
import torch
54-
from torch._export import capture_pre_autograd_graph
5554
class M(torch.nn.Module):
5655
def __init__(self):
5756
super().__init__()
@@ -65,9 +64,9 @@ The PyTorch 2 export quantization API looks like this:
6564
m = M().eval()
6665
6766
# Step 1. program capture
68-
# NOTE: this API will be updated to torch.export API in the future, but the captured
69-
# result shoud mostly stay the same
70-
m = capture_pre_autograd_graph(m, *example_inputs)
67+
# This is available for pytorch 2.5+, for more details on lower pytorch versions
68+
# please check `Export the model with torch.export` section
69+
m = torch.export.export_for_training(m, example_inputs).module()
7170
# we get a model with aten ops
7271
7372
@@ -77,7 +76,7 @@ The PyTorch 2 export quantization API looks like this:
7776
convert_pt2e,
7877
)
7978
80-
from torch.ao.quantization.quantizer import (
79+
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
8180
XNNPACKQuantizer,
8281
get_symmetric_quantization_config,
8382
)
@@ -280,10 +279,7 @@ and rename it to ``data/resnet18_pretrained_float.pth``.
280279
return model
281280
282281
def print_size_of_model(model):
283-
if isinstance(model, torch.jit.RecursiveScriptModule):
284-
torch.jit.save(model, "temp.p")
285-
else:
286-
torch.jit.save(torch.jit.script(model), "temp.p")
282+
torch.save(model.state_dict(), "temp.p")
287283
print("Size (MB):", os.path.getsize("temp.p")/1e6)
288284
os.remove("temp.p")
289285
@@ -351,18 +347,28 @@ Here is how you can use ``torch.export`` to export the model:
351347

352348
.. code-block:: python
353349
354-
from torch._export import capture_pre_autograd_graph
355-
356350
example_inputs = (torch.rand(2, 3, 224, 224),)
357-
exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)
351+
# for pytorch 2.5+
352+
exported_model = torch.export.export_for_training(model_to_quantize, example_inputs).module()
353+
354+
# for pytorch 2.4 and before
355+
# from torch._export import capture_pre_autograd_graph
356+
# exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)
357+
358358
# or capture with dynamic dimensions
359+
# for pytorch 2.5+
360+
dynamic_shapes = tuple(
361+
{0: torch.export.Dim("dim")} if i == 0 else None
362+
for i in range(len(example_inputs))
363+
)
364+
exported_model = torch.export.export_for_training(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module()
365+
366+
# for pytorch 2.4 and before
367+
# dynamic_shape API may vary as well
359368
# from torch._export import dynamic_dim
360369
# exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs, constraints=[dynamic_dim(example_inputs[0], 0)])
361370
362371
363-
``capture_pre_autograd_graph`` is a short term API, it will be updated to use the offical ``torch.export`` API when that is ready.
364-
365-
366372
Import the Backend Specific Quantizer and Configure how to Quantize the Model
367373
-----------------------------------------------------------------------------
368374

@@ -454,7 +460,7 @@ we offer in the long term might change based on feedback from PyTorch users.
454460
out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
455461
return out_i8
456462
457-
* Reference Quantized Model Representation (available in the nightly build)
463+
* Reference Quantized Model Representation
458464

459465
We will have a special representation for selected ops, for example, quantized linear. Other ops are represented as ``dq -> float32_op -> q`` and ``q/dq`` are decomposed into more primitive operators.
460466
You can get this representation by using ``convert_pt2e(..., use_reference_representation=True)``.
@@ -485,8 +491,6 @@ Now we can compare the size and model accuracy with baseline model.
485491
.. code-block:: python
486492
487493
# Baseline model size and accuracy
488-
scripted_float_model_file = "resnet18_scripted.pth"
489-
490494
print("Size of baseline model")
491495
print_size_of_model(float_model)
492496
@@ -495,6 +499,8 @@ Now we can compare the size and model accuracy with baseline model.
495499
496500
# Quantized model size and accuracy
497501
print("Size of model after quantization")
502+
# export again to remove unused weights
503+
quantized_model = torch.export.export_for_training(quantized_model, example_inputs).module()
498504
print_size_of_model(quantized_model)
499505
500506
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Quantization in PyTorch 2.0 Export Tutorial
2+
===========================================
3+
4+
This tutorial has been moved.
5+
6+
Redirecting in 3 seconds...
7+
8+
.. raw:: html
9+
10+
<meta http-equiv="Refresh" content="3; url='https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html'" />

prototype_source/pt2e_quant_qat.rst

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ to the post training quantization (PTQ) flow for the most part:
1818
prepare_qat_pt2e,
1919
convert_pt2e,
2020
)
21-
from torch.ao.quantization.quantizer import (
21+
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
2222
XNNPACKQuantizer,
2323
get_symmetric_quantization_config,
2424
)
@@ -36,9 +36,9 @@ to the post training quantization (PTQ) flow for the most part:
3636
m = M()
3737
3838
# Step 1. program capture
39-
# NOTE: this API will be updated to torch.export API in the future, but the captured
40-
# result shoud mostly stay the same
41-
m = capture_pre_autograd_graph(m, *example_inputs)
39+
# This is available for pytorch 2.5+, for more details on lower pytorch versions
40+
# please check `Export the model with torch.export` section
41+
m = torch.export.export_for_training(m, example_inputs).module()
4242
# we get a model with aten ops
4343
4444
# Step 2. quantization-aware training
@@ -272,24 +272,35 @@ Here is how you can use ``torch.export`` to export the model:
272272
from torch._export import capture_pre_autograd_graph
273273
274274
example_inputs = (torch.rand(2, 3, 224, 224),)
275-
exported_model = capture_pre_autograd_graph(float_model, example_inputs)
275+
# for pytorch 2.5+
276+
exported_model = torch.export.export_for_training(float_model, example_inputs).module()
277+
# for pytorch 2.4 and before
278+
# from torch._export import capture_pre_autograd_graph
279+
# exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)
276280
277281
278282
.. code:: python
279283
280284
# or, to capture with dynamic dimensions:
281-
from torch._export import dynamic_dim
282285
283-
example_inputs = (torch.rand(2, 3, 224, 224),)
284-
exported_model = capture_pre_autograd_graph(
285-
float_model,
286-
example_inputs,
287-
constraints=[dynamic_dim(example_inputs[0], 0)],
286+
# for pytorch 2.5+
287+
dynamic_shapes = tuple(
288+
{0: torch.export.Dim("dim")} if i == 0 else None
289+
for i in range(len(example_inputs))
288290
)
289-
.. note::
290-
291-
``capture_pre_autograd_graph`` is a short term API, it will be updated to use the offical ``torch.export`` API when that is ready.
292-
291+
exported_model = torch.export.export_for_training(float_model, example_inputs, dynamic_shapes=dynamic_shapes).module()
292+
293+
# for pytorch 2.4 and before
294+
# dynamic_shape API may vary as well
295+
# from torch._export import dynamic_dim
296+
297+
# example_inputs = (torch.rand(2, 3, 224, 224),)
298+
# exported_model = capture_pre_autograd_graph(
299+
# float_model,
300+
# example_inputs,
301+
# constraints=[dynamic_dim(example_inputs[0], 0)],
302+
# )
303+
293304
294305
Import the Backend Specific Quantizer and Configure how to Quantize the Model
295306
-----------------------------------------------------------------------------
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Quantization in PyTorch 2.0 Export Tutorial
2+
===========================================
3+
4+
This tutorial has been moved.
5+
6+
Redirecting in 3 seconds...
7+
8+
.. raw:: html
9+
10+
<meta http-equiv="Refresh" content="3; url='https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html'" />

0 commit comments

Comments
 (0)