
Commit d862a95

xmfan and svekars authored

Add compiled autograd tutorial (#3026)

* Add compiled autograd tutorial

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent 19fffda commit d862a95

File tree

2 files changed: +310 -0 lines changed

index.rst

Lines changed: 8 additions & 0 deletions
@@ -439,6 +439,13 @@ Welcome to PyTorch Tutorials
    :link: advanced/python_custom_ops.html
    :tags: Extending-PyTorch,Frontend-APIs,C++,CUDA

+.. customcarditem::
+   :header: Compiled Autograd: Capturing a larger backward graph for ``torch.compile``
+   :card_description: Learn how to use compiled autograd to capture a larger backward graph.
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: intermediate/compiled_autograd_tutorial
+   :tags: Model-Optimization,CUDA
+
 .. customcarditem::
    :header: Custom C++ and CUDA Operators
    :card_description: How to extend PyTorch with custom C++ and CUDA operators.
@@ -1132,6 +1139,7 @@ Additional Resources
    intermediate/nvfuser_intro_tutorial
    intermediate/ax_multiobjective_nas_tutorial
    intermediate/torch_compile_tutorial
+   intermediate/compiled_autograd_tutorial
    intermediate/inductor_debug_cpu
    intermediate/scaled_dot_product_attention_tutorial
    beginner/knowledge_distillation_tutorial
Lines changed: 302 additions & 0 deletions
@@ -0,0 +1,302 @@
Compiled Autograd: Capturing a larger backward graph for ``torch.compile``
==========================================================================
**Author:** `Simon Fan <https://github.com/xmfan>`_

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How compiled autograd interacts with ``torch.compile``
       * How to use the compiled autograd API
       * How to inspect logs using ``TORCH_LOGS``

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.4
       * Complete the `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
       * Read through the TorchDynamo and AOTAutograd sections of `Get Started with PyTorch 2.x <https://pytorch.org/get-started/pytorch-2.0/>`_

Overview
--------
Compiled Autograd is a ``torch.compile`` extension introduced in PyTorch 2.4
that allows the capture of a larger backward graph.

While ``torch.compile`` does capture the backward graph, it does so **partially**. The AOTAutograd component captures the backward graph ahead of time, with certain limitations:

* Graph breaks in the forward lead to graph breaks in the backward
* `Backward hooks <https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`_ are not captured

Compiled Autograd addresses these limitations by integrating directly with the autograd engine, allowing
it to capture the full backward graph at runtime. Models with either of these two characteristics should try
Compiled Autograd and may observe better performance.

However, Compiled Autograd introduces its own limitations:

* Added runtime overhead at the start of the backward pass for the cache lookup
* More prone to recompiles and graph breaks in Dynamo due to the larger capture

.. note:: Compiled Autograd is under active development and is not yet compatible with all existing PyTorch features. For the latest status on a particular feature, refer to the `Compiled Autograd Landing Page <https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY>`_.

Setup
-----
In this tutorial, we will base our examples on this simple neural network model.
It takes a 10-dimensional input vector, processes it through a single linear layer, and outputs another 10-dimensional vector.

.. code:: python

    import torch

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x):
            return self.linear(x)

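Before adding any compilation, you can run the model eagerly to confirm the shapes involved. This is a minimal sketch, not part of the tutorial code itself:

.. code:: python

    # Eager sanity check (sketch): these are the parameters whose gradients
    # the backward pass will later compute.
    model = Model()
    x = torch.randn(10)
    print(model(x).shape)             # torch.Size([10])
    print(model.linear.weight.shape)  # torch.Size([10, 10])
    print(model.linear.bias.shape)    # torch.Size([10])
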
Basic usage
-----------
Before calling the ``torch.compile`` API, make sure to set ``torch._dynamo.config.compiled_autograd`` to ``True``:

.. code:: python

    model = Model()
    x = torch.randn(10)

    torch._dynamo.config.compiled_autograd = True
    @torch.compile
    def train(model, x):
        loss = model(x).sum()
        loss.backward()

    train(model, x)

In the code above, we create an instance of the ``Model`` class and generate a random 10-dimensional tensor ``x`` by using ``torch.randn(10)``.
We define the training loop function ``train`` and decorate it with ``@torch.compile`` to optimize its execution.
When ``train(model, x)`` is called:

* The Python interpreter calls Dynamo, since this call was decorated with ``@torch.compile``.
* Dynamo intercepts the Python bytecode, simulates its execution, and records the operations into a graph.
* ``AOTDispatcher`` disables hooks and calls the autograd engine to compute gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph. Using ``torch.autograd.Function``, AOTDispatcher rewrites the forward and backward implementation of ``train``.
* Inductor generates a function corresponding to an optimized implementation of the AOTDispatcher forward and backward.
* Dynamo sets the optimized function to be evaluated next by the Python interpreter.
* The Python interpreter executes the optimized function, which executes ``loss = model(x).sum()``.
* The Python interpreter executes ``loss.backward()``, calling into the autograd engine, which routes to the Compiled Autograd engine since we set ``torch._dynamo.config.compiled_autograd = True``.
* Compiled Autograd computes the gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph, including any hooks it encounters. During this process, it will record the backward previously rewritten by AOTDispatcher. Compiled Autograd then generates a new function which corresponds to a fully traced implementation of ``loss.backward()``, and executes it with ``torch.compile`` in inference mode.
* The same steps apply recursively to the Compiled Autograd graph, but this time AOTDispatcher will not need to partition the graph.

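As a quick sanity check that the backward ran end to end, you can inspect the parameter gradients after ``train(model, x)`` returns. This is a small sketch, not part of the original example; the shapes follow from the ``Linear(10, 10)`` layer defined above:

.. code:: python

    # After train(model, x) has run, the gradients are populated.
    assert model.linear.weight.grad is not None
    assert model.linear.bias.grad is not None
    print(model.linear.weight.grad.shape)  # torch.Size([10, 10])
    print(model.linear.bias.grad.shape)    # torch.Size([10])
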
Inspecting the compiled autograd logs
-------------------------------------
Run the script with the ``TORCH_LOGS`` environment variable:

* To print only the compiled autograd graph, use ``TORCH_LOGS="compiled_autograd" python example.py``
* To print the graph with more tensor metadata and recompile reasons, at the cost of performance, use ``TORCH_LOGS="compiled_autograd_verbose" python example.py``

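If you prefer to launch the run from Python rather than from the shell, the environment variable can be set on a subprocess. The following is a sketch, where ``example.py`` stands for whatever script contains the snippet above:

.. code:: python

    import os
    import subprocess

    env = dict(os.environ, TORCH_LOGS="compiled_autograd_verbose")
    result = subprocess.run(
        ["python", "example.py"], env=env, capture_output=True, text=True
    )
    print(result.stderr)  # the compiled autograd graph is logged to stderr
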
Rerun the snippet above; the compiled autograd graph should now be logged to ``stderr``. Certain graph nodes will have names prefixed with ``aot0_``;
these correspond to the nodes previously compiled ahead of time in AOTAutograd backward graph 0. For example, ``aot0_view_2`` corresponds to ``view_2`` of the AOT backward graph with id=0.

.. code:: python

    stderr_output = """
    DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
    DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
    ===== Compiled autograd graph =====
    <eval_with_key>.4 class CompiledAutograd(torch.nn.Module):
        def forward(self, inputs, sizes, scalars, hooks):
            # No stacktrace found for following nodes
            aot0_tangents_1: "f32[][]cpu" = inputs[0]
            aot0_primals_3: "f32[10][1]cpu" = inputs[1]
            getitem_2: "f32[10][1]cpu" = inputs[2]
            getitem_3: "f32[10, 10][10, 1]cpu" = inputs[3]; inputs = None

            # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 1)
            aot0_expand: "f32[10][0]cpu" = torch.ops.aten.expand.default(aot0_tangents_1, [10]); aot0_tangents_1 = None
            aot0_view_2: "f32[1, 10][0, 0]cpu" = torch.ops.aten.view.default(aot0_expand, [1, 10]); aot0_expand = None
            aot0_permute_2: "f32[10, 1][0, 0]cpu" = torch.ops.aten.permute.default(aot0_view_2, [1, 0])
            aot0_select: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 0)
            aot0_view: "f32[1, 10][10, 1]cpu" = torch.ops.aten.view.default(aot0_primals_3, [1, 10]); aot0_primals_3 = None
            aot0_mul_3: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select, aot0_view); aot0_select = None
            aot0_select_1: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 1)
            aot0_mul_4: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_1, aot0_view); aot0_select_1 = None
            aot0_select_2: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 2)
            aot0_mul_5: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_2, aot0_view); aot0_select_2 = None
            aot0_select_3: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 3)
            aot0_mul_6: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_3, aot0_view); aot0_select_3 = None
            aot0_select_4: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 4)
            aot0_mul_7: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_4, aot0_view); aot0_select_4 = None
            aot0_select_5: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 5)
            aot0_mul_8: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_5, aot0_view); aot0_select_5 = None
            aot0_select_6: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 6)
            aot0_mul_9: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_6, aot0_view); aot0_select_6 = None
            aot0_select_7: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 7)
            aot0_mul_10: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_7, aot0_view); aot0_select_7 = None
            aot0_select_8: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 8)
            aot0_mul_11: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_8, aot0_view); aot0_select_8 = None
            aot0_select_9: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 9); aot0_permute_2 = None
            aot0_mul_12: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_9, aot0_view); aot0_select_9 = aot0_view = None
            aot0_cat: "f32[10, 10][10, 1]cpu" = torch.ops.aten.cat.default([aot0_mul_3, aot0_mul_4, aot0_mul_5, aot0_mul_6, aot0_mul_7, aot0_mul_8, aot0_mul_9, aot0_mul_10, aot0_mul_11, aot0_mul_12]); aot0_mul_3 = aot0_mul_4 = aot0_mul_5 = aot0_mul_6 = aot0_mul_7 = aot0_mul_8 = aot0_mul_9 = aot0_mul_10 = aot0_mul_11 = aot0_mul_12 = None
            aot0_permute_3: "f32[10, 10][1, 10]cpu" = torch.ops.aten.permute.default(aot0_cat, [1, 0]); aot0_cat = None
            aot0_sum_3: "f32[1, 10][10, 1]cpu" = torch.ops.aten.sum.dim_IntList(aot0_view_2, [0], True); aot0_view_2 = None
            aot0_view_3: "f32[10][1]cpu" = torch.ops.aten.view.default(aot0_sum_3, [10]); aot0_sum_3 = None

            # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 2)
            accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_2, aot0_view_3); getitem_2 = aot0_view_3 = accumulate_grad_ = None

            # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 1)
            aot0_permute_4: "f32[10, 10][10, 1]cpu" = torch.ops.aten.permute.default(aot0_permute_3, [1, 0]); aot0_permute_3 = None

            # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 3)
            accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_3, aot0_permute_4); getitem_3 = aot0_permute_4 = accumulate_grad__1 = None
            _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
            return []
    """

.. note:: This is the graph on which we will call ``torch.compile``, **NOT** the optimized graph. Compiled Autograd essentially generates some unoptimized Python code to represent the entire C++ autograd execution.

Compiling the forward and backward pass using different flags
--------------------------------------------------------------
You can use different compiler configurations for the two compilations, for example, the backward may be compiled as a full graph even if there are graph breaks in the forward.

.. code:: python

    def train(model, x):
        model = torch.compile(model)
        loss = model(x).sum()
        torch._dynamo.config.compiled_autograd = True
        torch.compile(lambda: loss.backward(), fullgraph=True)()

Or you can use the context manager, which will apply to all autograd calls within its scope.

.. code:: python

    def train(model, x):
        model = torch.compile(model)
        loss = model(x).sum()
        with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
            loss.backward()

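Putting the pieces together, here is a hedged end-to-end sketch that applies the context manager variant to the ``Model`` from the Setup section. With ``fullgraph=True`` on the backward compile, a graph break during the backward trace would raise an error instead of silently splitting the graph:

.. code:: python

    # End-to-end sketch reusing the Model class defined earlier.
    model = Model()
    x = torch.randn(10)

    compiled_model = torch.compile(model)   # forward: graph breaks are tolerated
    loss = compiled_model(x).sum()
    with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
        loss.backward()                     # backward: traced as a single full graph
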
Compiled Autograd addresses certain limitations of AOTAutograd
--------------------------------------------------------------
1. Graph breaks in the forward pass lead to graph breaks in the backward pass:

.. code:: python

    @torch.compile(backend="aot_eager")
    def fn(x):
        # 1st graph
        temp = x + 10
        torch._dynamo.graph_break()
        # 2nd graph
        temp = temp + 10
        torch._dynamo.graph_break()
        # 3rd graph
        return temp.sum()

    x = torch.randn(10, 10, requires_grad=True)
    torch._dynamo.utils.counters.clear()
    loss = fn(x)

    # 1. base torch.compile
    loss.backward(retain_graph=True)
    assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 3)
    torch._dynamo.utils.counters.clear()

    # 2. torch.compile with compiled autograd
    with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
        loss.backward()

    # single graph for the backward
    assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 1)

In the first ``torch.compile`` case, we see that 3 backward graphs were produced due to the 2 graph breaks in the compiled function ``fn``.
In the second case, with compiled autograd, a single full backward graph was traced despite the graph breaks.

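If you prefer to inspect the raw counter values rather than asserting on them, a small sketch (the exact set of keys under ``counters["stats"]`` can vary across PyTorch versions):

.. code:: python

    from torch._dynamo.utils import counters
    print(dict(counters["stats"]))  # look at "unique_graphs" and related counters
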
2. Backward hooks are not captured

.. code:: python

    @torch.compile(backend="aot_eager")
    def fn(x):
        return x.sum()

    x = torch.randn(10, 10, requires_grad=True)
    x.register_hook(lambda grad: grad + 10)
    loss = fn(x)

    with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
        loss.backward()

There should be a ``call_hook`` node in the graph, which Dynamo will later inline into the following:

.. code:: python

    stderr_output = """
    DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
    DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
    ===== Compiled autograd graph =====
    <eval_with_key>.2 class CompiledAutograd(torch.nn.Module):
        def forward(self, inputs, sizes, scalars, hooks):
            ...
            getitem_2 = hooks[0]; hooks = None
            call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
            ...
    """

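To confirm that the hook actually ran inside the compiled backward, you can check the resulting gradient: since ``loss = x.sum()``, the incoming gradient for ``x`` is all ones, and the hook adds 10, so ``x.grad`` should be all 11s. A minimal check, written as a sketch:

.. code:: python

    # The hook added 10 to an all-ones incoming gradient.
    assert torch.allclose(x.grad, torch.full((10, 10), 11.0))
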
Common recompilation reasons for Compiled Autograd
--------------------------------------------------
1. Due to changes in the autograd structure of the loss value:

.. code:: python

    torch._dynamo.config.compiled_autograd = True
    x = torch.randn(10, requires_grad=True)
    for op in [torch.add, torch.sub, torch.mul, torch.div]:
        loss = op(x, x).sum()
        torch.compile(lambda: loss.backward(), backend="eager")()

In the example above, we call a different operator on each iteration, so ``loss`` tracks a different autograd history each time. You should see some recompile messages: **Cache miss due to new autograd node**.

.. code:: python

    stderr_output = """
    Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
    ...
    Cache miss due to new autograd node: SubBackward0 (NodeCall 2) with key size 56, previous key sizes=[]
    ...
    Cache miss due to new autograd node: MulBackward0 (NodeCall 2) with key size 71, previous key sizes=[]
    ...
    Cache miss due to new autograd node: DivBackward0 (NodeCall 2) with key size 70, previous key sizes=[]
    ...
    """

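Conversely, if the autograd structure does not change between iterations, you would expect the cached backward to be reused; with verbose logs enabled, only the first iteration should report a cache miss. A sketch, assuming the cache key depends only on the autograd graph structure as described above:

.. code:: python

    torch._dynamo.config.compiled_autograd = True
    x = torch.randn(10, requires_grad=True)
    for _ in range(2):
        # Same operator, same shapes: the autograd structure is identical each time.
        loss = torch.add(x, x).sum()
        torch.compile(lambda: loss.backward(), backend="eager")()
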
2. Due to tensors changing shapes:

.. code:: python

    torch._dynamo.config.compiled_autograd = True
    for i in [10, 100, 10]:
        x = torch.randn(i, i, requires_grad=True)
        loss = x.sum()
        torch.compile(lambda: loss.backward(), backend="eager")()

In the example above, ``x`` changes shape, and compiled autograd marks ``x`` as a dynamic shape tensor after the first change. You should see recompile messages: **Cache miss due to changed shapes**.

.. code:: python

    stderr_output = """
    ...
    Cache miss due to changed shapes: marking size idx 0 of torch::autograd::GraphRoot (NodeCall 0) as dynamic
    Cache miss due to changed shapes: marking size idx 1 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
    Cache miss due to changed shapes: marking size idx 2 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
    Cache miss due to changed shapes: marking size idx 3 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
    ...
    """

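If you know up front which input dimensions will vary, you can hint Dynamo with ``torch._dynamo.mark_dynamic`` before compiling. Whether this also avoids the shape-related compiled autograd cache miss may depend on your PyTorch version, so treat the following as an experiment rather than a guarantee:

.. code:: python

    torch._dynamo.config.compiled_autograd = True
    for i in [10, 100, 10]:
        x = torch.randn(i, i, requires_grad=True)
        torch._dynamo.mark_dynamic(x, 0)  # hint: dimension 0 varies across runs
        torch._dynamo.mark_dynamic(x, 1)  # hint: dimension 1 varies across runs
        loss = x.sum()
        torch.compile(lambda: loss.backward(), backend="eager")()
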
Conclusion
----------
In this tutorial, we went over the high-level ecosystem of ``torch.compile`` with compiled autograd, the basics of compiled autograd, and a few common recompilation reasons. Stay tuned for deep dives on `dev-discuss <https://dev-discuss.pytorch.org/>`_.
