intermediate_source/compiled_autograd_tutorial.rst
Lines changed: 10 additions & 91 deletions
@@ -97,62 +97,10 @@ Run the script with the ``TORCH_LOGS`` environment variables:
Rerun the snippet above; the compiled autograd graph should now be logged to ``stderr``. Certain graph nodes will have names prefixed by ``aot0_``;
these correspond to the nodes previously compiled ahead of time in AOTAutograd backward graph 0. For example, ``aot0_view_2`` corresponds to ``view_2`` of the AOT backward graph with id=0.
+ In the image below, the red box encapsulates the AOT backward graph that is captured by ``torch.compile`` without Compiled Autograd.
- .. code:: python
-
-    stderr_output = """
-    DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
.. note:: This is the graph on which we will call ``torch.compile``, **NOT** the optimized graph. Compiled Autograd essentially generates some unoptimized Python code to represent the entire C++ autograd execution.
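For orientation, here is a minimal sketch (not part of this diff) of how the verbose logging described above can be reproduced. It assumes the ``torch._dynamo.config.compiled_autograd`` flag used elsewhere in this tutorial and a run under ``TORCH_LOGS="compiled_autograd_verbose"``:

.. code:: python

   # Run as: TORCH_LOGS="compiled_autograd_verbose" python script.py
   # The compiled autograd graph and any cache-miss messages are printed to stderr.
   import torch

   torch._dynamo.config.compiled_autograd = True

   @torch.compile
   def train(model, x):
       loss = model(x).sum()
       loss.backward()  # backward is captured inside the compiled region

   model = torch.nn.Linear(10, 10)
   x = torch.randn(10)
   train(model, x)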
@@ -181,7 +129,7 @@ Or you can use the context manager, which will apply to all autograd calls within
Compiled Autograd addresses certain limitations of AOTAutograd
- 1. Graph breaks in the forward pass lead to graph breaks in the backward pass:
+ 1. Graph breaks in the forward pass no longer necessarily lead to graph breaks in the backward pass:
.. code:: python
@@ -216,7 +164,10 @@ Compiled Autograd addresses certain limitations of AOTAutograd
In the first ``torch.compile`` case, we see that 3 backward graphs were produced due to the 2 graph breaks in the compiled function ``fn``.
In the second ``torch.compile`` with compiled autograd case, by contrast, we see that a full backward graph was traced despite the graph breaks.
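As a rough illustration (an assumed sketch, not the tutorial's elided example), here is a forward with two explicit graph breaks whose backward can still be captured as one graph when compiled autograd is enabled:

.. code:: python

   # Sketch: two graph breaks split the forward into three graphs, but with
   # compiled autograd enabled the backward is traced as one full graph.
   import torch

   torch._dynamo.config.compiled_autograd = True

   @torch.compile
   def fn(x):
       temp = x + 10
       torch._dynamo.graph_break()
       temp = temp + 10
       torch._dynamo.graph_break()
       return temp.sum()

   x = torch.randn(10, 10, requires_grad=True)
   loss = fn(x)
   loss.backward()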
- 2. Backward hooks are not captured
+ .. note:: It is still possible for Dynamo to graph break when tracing backward hooks captured by Compiled Autograd.
+
+ 2. Backward hooks can now be captured
.. code:: python
@@ -233,19 +184,7 @@ Whereas in the second ``torch.compile`` with compiled autograd case, we see that
There should be a ``call_hook`` node in the graph, which dynamo will later inline into the following:
- .. code:: python
-
-    stderr_output = """
-    DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
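For context, an assumed sketch of the backward-hook pattern this section refers to; with compiled autograd enabled, the registered hook is what shows up as the ``call_hook`` node mentioned above:

.. code:: python

   # Sketch: a tensor hook registered on the input; compiled autograd records it
   # as a call_hook node that Dynamo later inlines.
   import torch

   torch._dynamo.config.compiled_autograd = True

   @torch.compile
   def fn(x):
       return x.sum()

   x = torch.randn(10, requires_grad=True)
   x.register_hook(lambda grad: grad * 2)
   loss = fn(x)
   loss.backward()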
@@ -261,18 +200,7 @@ Common recompilation reasons for Compiled Autograd
In the example above, we call a different operator on each iteration, leading to ``loss`` tracking a different autograd history each time. You should see some recompile messages: **Cache miss due to new autograd node**.
- .. code:: python
-
-    stderr_output = """
-    Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
-    ...
-    Cache miss due to new autograd node: SubBackward0 (NodeCall 2) with key size 56, previous key sizes=[]
-    ...
-    Cache miss due to new autograd node: MulBackward0 (NodeCall 2) with key size 71, previous key sizes=[]
-    ...
-    Cache miss due to new autograd node: DivBackward0 (NodeCall 2) with key size 70, previous key sizes=[]
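A short assumed sketch of the pattern described above: each iteration applies a different operator, so ``loss`` carries a different autograd history and each backward misses the cache:

.. code:: python

   # Sketch: a new autograd node type per iteration triggers
   # "Cache miss due to new autograd node" recompiles.
   import torch

   torch._dynamo.config.compiled_autograd = True

   x = torch.randn(10, requires_grad=True)

   @torch.compile
   def run_backward(loss):
       loss.backward()

   for op in (torch.sub, torch.mul, torch.div):
       loss = op(x, 2).sum()
       run_backward(loss)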
@@ -286,16 +214,7 @@ In the example above, we call a different operator on each iteration, leading to
In the example above, ``x`` changes shape, and compiled autograd will mark ``x`` as a dynamic shape tensor after the first change. You should see recompile messages: **Cache miss due to changed shapes**.
- .. code:: python
-
-    stderr_output = """
-    ...
-    Cache miss due to changed shapes: marking size idx 0 of torch::autograd::GraphRoot (NodeCall 0) as dynamic
-    Cache miss due to changed shapes: marking size idx 1 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
-    Cache miss due to changed shapes: marking size idx 2 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
-    Cache miss due to changed shapes: marking size idx 3 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
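Similarly, an assumed sketch of the shape-change case: once the input's size changes, compiled autograd marks those sizes as dynamic and recompiles:

.. code:: python

   # Sketch: the input size changes across iterations, producing
   # "Cache miss due to changed shapes" followed by a dynamic-shape recompile.
   import torch

   torch._dynamo.config.compiled_autograd = True

   @torch.compile
   def step(x):
       loss = (x * 2).sum()
       loss.backward()

   for n in (10, 100, 10):
       x = torch.randn(n, requires_grad=True)
       step(x)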