Commit ffa0a81 (parent 748e52b)
Add compiled autograd tutorial

1 file changed: +275, -0
# -*- coding: utf-8 -*-

"""
Compiled Autograd: Capturing a larger backward graph for ``torch.compile``
==========================================================================

"""

######################################################################
# Compiled Autograd is a torch.compile extension introduced in PyTorch 2.4
# that allows the capture of a larger backward graph. It is highly recommended
# to familiarize yourself with `torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_.
#

######################################################################
# Doesn't torch.compile already capture the backward graph?
# ------------
# Partially. AOTAutograd captures the backward graph ahead-of-time, but with certain limitations:
# - Graph breaks in the forward pass lead to graph breaks in the backward pass
# - `Backward hooks <https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`_ are not captured
#
# Compiled Autograd addresses these limitations by integrating directly with the autograd engine, allowing
# it to capture the full backward graph at runtime. Models exhibiting either of these two characteristics should try
# Compiled Autograd and potentially observe better performance. Both limitations are demonstrated later in this tutorial.
#
# However, Compiled Autograd has its own limitations:
# - Dynamic autograd structure leads to recompiles
#

######################################################################
# Basic Usage
# ------------
#

import torch

# NOTE: The flag must be enabled before using the torch.compile decorator
torch._dynamo.config.compiled_autograd = True

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

@torch.compile
def train(model, x):
    loss = model(x).sum()
    loss.backward()

model = Model()
x = torch.randn(10)
train(model, x)

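######################################################################
# A quick sanity check: the compiled backward pass should have populated gradients on
# the model parameters, just like the eager autograd engine would.
#

assert model.linear.weight.grad is not None
assert model.linear.bias.grad is not None
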
######################################################################
# Inspecting the compiled autograd logs
# ------------
# Run the script with one of the following TORCH_LOGS settings:
#

"""
Prints the graph:
TORCH_LOGS="compiled_autograd" python example.py

Performance degrading; additionally prints the verbose graph and recompile reasons:
TORCH_LOGS="compiled_autograd_verbose" python example.py
"""

######################################################################
# Or with the private set_logs API:
#

# NOTE: The flag must be enabled before the function is wrapped with torch.compile
torch._logging._internal.set_logs(compiled_autograd=True)

@torch.compile
def train(model, x):
    loss = model(x).sum()
    loss.backward()

train(model, x)

######################################################################
# The compiled autograd graph should now be logged to stdout. Certain graph nodes will have names prefixed by ``aot0_``;
# these correspond to the nodes previously compiled ahead-of-time in AOTAutograd backward graph 0.
#
# NOTE: This is the graph that we will call torch.compile on, NOT the optimized graph. Compiled Autograd essentially
# generates some Python code to represent the entire C++ autograd execution.
#

"""
90+
INFO:torch._dynamo.compiled_autograd.__compiled_autograd:TRACED GRAPH
91+
===== Compiled autograd graph =====
92+
<eval_with_key>.4 class CompiledAutograd(torch.nn.Module):
93+
def forward(self, inputs, sizes, scalars, hooks):
94+
# No stacktrace found for following nodes
95+
aot0_tangents_1: "f32[][]cpu" = inputs[0]
96+
aot0_primals_3: "f32[10][1]cpu" = inputs[1]
97+
getitem_2: "f32[10][1]cpu" = inputs[2]
98+
getitem_3: "f32[10, 10][10, 1]cpu" = inputs[3]; inputs = None
99+
100+
# File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 1)
101+
aot0_expand: "f32[10][0]cpu" = torch.ops.aten.expand.default(aot0_tangents_1, [10]); aot0_tangents_1 = None
102+
aot0_view_2: "f32[1, 10][0, 0]cpu" = torch.ops.aten.view.default(aot0_expand, [1, 10]); aot0_expand = None
103+
aot0_permute_2: "f32[10, 1][0, 0]cpu" = torch.ops.aten.permute.default(aot0_view_2, [1, 0])
104+
aot0_select: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 0)
105+
aot0_view: "f32[1, 10][10, 1]cpu" = torch.ops.aten.view.default(aot0_primals_3, [1, 10]); aot0_primals_3 = None
106+
aot0_mul_3: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select, aot0_view); aot0_select = None
107+
aot0_select_1: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 1)
108+
aot0_mul_4: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_1, aot0_view); aot0_select_1 = None
109+
aot0_select_2: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 2)
110+
aot0_mul_5: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_2, aot0_view); aot0_select_2 = None
111+
aot0_select_3: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 3)
112+
aot0_mul_6: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_3, aot0_view); aot0_select_3 = None
113+
aot0_select_4: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 4)
114+
aot0_mul_7: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_4, aot0_view); aot0_select_4 = None
115+
aot0_select_5: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 5)
116+
aot0_mul_8: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_5, aot0_view); aot0_select_5 = None
117+
aot0_select_6: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 6)
118+
aot0_mul_9: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_6, aot0_view); aot0_select_6 = None
119+
aot0_select_7: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 7)
120+
aot0_mul_10: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_7, aot0_view); aot0_select_7 = None
121+
aot0_select_8: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 8)
122+
aot0_mul_11: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_8, aot0_view); aot0_select_8 = None
123+
aot0_select_9: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 9); aot0_permute_2 = None
124+
aot0_mul_12: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_9, aot0_view); aot0_select_9 = aot0_view = None
125+
aot0_cat: "f32[10, 10][10, 1]cpu" = torch.ops.aten.cat.default([aot0_mul_3, aot0_mul_4, aot0_mul_5, aot0_mul_6, aot0_mul_7, aot0_mul_8, aot0_mul_9, aot0_mul_10, aot0_mul_11, aot0_mul_12]); aot0_mul_3 = aot0_mul_4 = aot0_mul_5 = aot0_mul_6 = aot0_mul_7 = aot0_mul_8 = aot0_mul_9 = aot0_mul_10 = aot0_mul_11 = aot0_mul_12 = None
126+
aot0_permute_3: "f32[10, 10][1, 10]cpu" = torch.ops.aten.permute.default(aot0_cat, [1, 0]); aot0_cat = None
127+
aot0_sum_3: "f32[1, 10][10, 1]cpu" = torch.ops.aten.sum.dim_IntList(aot0_view_2, [0], True); aot0_view_2 = None
128+
aot0_view_3: "f32[10][1]cpu" = torch.ops.aten.view.default(aot0_sum_3, [10]); aot0_sum_3 = None
129+
130+
# File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 2)
131+
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_2, aot0_view_3); getitem_2 = aot0_view_3 = accumulate_grad_ = None
132+
133+
# File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 1)
134+
aot0_permute_4: "f32[10, 10][10, 1]cpu" = torch.ops.aten.permute.default(aot0_permute_3, [1, 0]); aot0_permute_3 = None
135+
136+
# File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 3)
137+
accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_3, aot0_permute_4); getitem_3 = aot0_permute_4 = accumulate_grad__1 = None
138+
_exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
139+
return []
140+
"""
141+
142+
######################################################################
# Compiling the forward and backward pass using different flags
# ------------
#

def train(model, x):
    model = torch.compile(model)
    loss = model(x).sum()
    torch.compile(lambda: loss.backward(), fullgraph=True)()

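######################################################################
# For example, reusing the ``Model`` defined in the Basic Usage section (this assumes the
# ``compiled_autograd`` flag set earlier in this tutorial is still enabled):
#

model = Model()
x = torch.randn(10)
train(model, x)
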
######################################################################
# Or you can use the context manager, which will apply to all autograd calls within it
#

def train(model, x):
    model = torch.compile(model)
    loss = model(x).sum()
    with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
        loss.backward()

######################################################################
# Demonstrating the limitations of AOTAutograd addressed by Compiled Autograd
# ------------
# 1. Graph breaks in the forward lead to graph breaks in the backward
#

@torch.compile(backend="aot_eager")
def fn(x):
    # 1st graph
    temp = x + 10
    torch._dynamo.graph_break()
    # 2nd graph
    temp = temp + 10
    torch._dynamo.graph_break()
    # 3rd graph
    return temp.sum()

x = torch.randn(10, 10, requires_grad=True)
loss = fn(x)

# 1. base torch.compile
loss.backward(retain_graph=True)
assert torch._dynamo.utils.counters["stats"]["unique_graphs"] == 3
torch._dynamo.utils.counters.clear()

# 2. torch.compile with compiled autograd
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
    loss.backward()

# single graph for the backward
assert torch._dynamo.utils.counters["stats"]["unique_graphs"] == 1

######################################################################
# 2. Backward hooks are not captured
#

@torch.compile(backend="aot_eager")
def fn(x):
    return x.sum()

x = torch.randn(10, 10, requires_grad=True)
x.register_hook(lambda grad: grad + 10)
loss = fn(x)

torch._logging._internal.set_logs(compiled_autograd=True)
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
    loss.backward()

######################################################################
# There is a ``call_hook`` node in the graph, which Dynamo will inline
#

"""
INFO:torch._dynamo.compiled_autograd.__compiled_autograd:TRACED GRAPH
===== Compiled autograd graph =====
<eval_with_key>.2 class CompiledAutograd(torch.nn.Module):
    def forward(self, inputs, sizes, scalars, hooks):
        ...
        getitem_2 = hooks[0]; hooks = None
        call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
        ...
"""

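######################################################################
# A quick way to confirm that the hook actually ran inside the compiled backward: the gradient of
# ``x.sum()`` is all ones, so after the ``grad + 10`` hook the accumulated gradient should be all 11s.
#

assert torch.equal(x.grad, torch.full((10, 10), 11.0))
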
######################################################################
# Understanding recompilation reasons for Compiled Autograd
# ------------
# 1. Due to a change in the autograd structure
#

torch._logging._internal.set_logs(compiled_autograd_verbose=True)
torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
    loss = op(x, x).sum()
    torch.compile(lambda: loss.backward(), backend="eager")()

######################################################################
# You should see some cache miss logs (recompiles):
# Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
# ...
# Cache miss due to new autograd node: SubBackward0 (NodeCall 2) with key size 56, previous key sizes=[]
# ...
# Cache miss due to new autograd node: MulBackward0 (NodeCall 2) with key size 71, previous key sizes=[]
# ...
# Cache miss due to new autograd node: DivBackward0 (NodeCall 2) with key size 70, previous key sizes=[]
# ...

######################################################################
# 2. Due to dynamic shapes
#

torch._logging._internal.set_logs(compiled_autograd_verbose=True)
torch._dynamo.config.compiled_autograd = True
for i in [10, 100, 10]:
    x = torch.randn(i, i, requires_grad=True)
    loss = x.sum()
    torch.compile(lambda: loss.backward(), backend="eager")()

######################################################################
# You should see some cache miss logs (recompiles):
# ...
# Cache miss due to changed shapes: marking size idx 0 of torch::autograd::GraphRoot (NodeCall 0) as dynamic
# Cache miss due to changed shapes: marking size idx 1 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
# Cache miss due to changed shapes: marking size idx 2 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
# Cache miss due to changed shapes: marking size idx 3 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
# ...
# Once the sizes have been marked as dynamic, the final iteration (size 10 again) should not trigger another recompile.

######################################################################
# Compatibility and rough edges
# ------------
#
# Compiled Autograd is under active development and is not yet compatible with all existing PyTorch features.
# For the latest status on a particular feature, refer to: https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY.
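######################################################################
# If you run into one of these incompatibilities, a simple fallback (a sketch based on the flag used in this
# tutorial, not an official recipe) is to turn the config flag off again before compiling the affected code, so
# that its backward pass runs through the regular autograd engine while torch.compile still optimizes the forward:
#

# Disable Compiled Autograd before compiling the code that hits the incompatibility.
torch._dynamo.config.compiled_autograd = False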
