Compiled Autograd: Capturing a larger backward graph for ``torch.compile``
==========================================================================
**Author:** `Simon Fan <https://github.com/xmfan>`_

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
        :class-card: card-prerequisites

        * How compiled autograd interacts with ``torch.compile``
        * How to use the compiled autograd API
        * How to inspect logs using ``TORCH_LOGS``

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
        :class-card: card-prerequisites

        * PyTorch 2.4
        * Complete the `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
        * Read through the TorchDynamo and AOTAutograd sections of `Get Started with PyTorch 2.x <https://pytorch.org/get-started/pytorch-2.0/>`_

Overview
--------
Compiled Autograd is a ``torch.compile`` extension introduced in PyTorch 2.4
that allows the capture of a larger backward graph.

While ``torch.compile`` does capture the backward graph, it does so **partially**. The AOTAutograd component captures the backward graph ahead-of-time, with certain limitations:

* Graph breaks in the forward lead to graph breaks in the backward
* `Backward hooks <https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`_ are not captured

Compiled Autograd addresses these limitations by integrating directly with the autograd engine, allowing
it to capture the full backward graph at runtime. Models affected by either of these limitations are good
candidates for Compiled Autograd and may observe better performance.

However, Compiled Autograd introduces its own limitations:

* Added runtime overhead at the start of the backward pass for the cache lookup
* Increased likelihood of recompiles and graph breaks in Dynamo due to the larger capture

.. note:: Compiled Autograd is under active development and is not yet compatible with all existing PyTorch features. For the latest status on a particular feature, refer to the `Compiled Autograd Landing Page <https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY>`_.

Setup
-----
In this tutorial, we will base our examples on this simple neural network model.
It takes a 10-dimensional input vector, processes it through a single linear layer, and outputs another 10-dimensional vector.

.. code:: python

    import torch

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x):
            return self.linear(x)

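Before involving ``torch.compile``, you can sanity-check the model eagerly. This snippet is purely illustrative and is not required for the rest of the tutorial; the shapes follow from the ``Linear(10, 10)`` layer defined above.

.. code:: python

    # Illustrative eager-mode check of the model defined above
    model = Model()
    out = model(torch.randn(10))
    print(out.shape)  # torch.Size([10])
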
Basic usage
-----------
Before calling the ``torch.compile`` API, make sure to set ``torch._dynamo.config.compiled_autograd`` to ``True``:

.. code:: python

    model = Model()
    x = torch.randn(10)

    torch._dynamo.config.compiled_autograd = True
    @torch.compile
    def train(model, x):
        loss = model(x).sum()
        loss.backward()

    train(model, x)

In the code above, we create an instance of the ``Model`` class and generate a random 10-dimensional tensor ``x`` by using ``torch.randn(10)``.
We define the training loop function ``train`` and decorate it with ``@torch.compile`` to optimize its execution.
When ``train(model, x)`` is called:

* The Python interpreter calls Dynamo, since this call was decorated with ``@torch.compile``.
* Dynamo intercepts the Python bytecode, simulates its execution, and records the operations into a graph.
* ``AOTDispatcher`` disables hooks and calls the autograd engine to compute gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph. Using ``torch.autograd.Function``, AOTDispatcher rewrites the forward and backward implementation of ``train``.
* Inductor generates a function corresponding to an optimized implementation of the AOTDispatcher forward and backward.
* Dynamo sets the optimized function to be evaluated next by the Python interpreter.
* The Python interpreter executes the optimized function, which executes ``loss = model(x).sum()``.
* The Python interpreter executes ``loss.backward()``, calling into the autograd engine, which routes to the Compiled Autograd engine since we set ``torch._dynamo.config.compiled_autograd = True``.
* Compiled Autograd computes the gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph, including any hooks it encounters. During this process, it will record the backward previously rewritten by AOTDispatcher. Compiled Autograd then generates a new function which corresponds to a fully-traced implementation of ``loss.backward()``, and executes it with ``torch.compile`` in inference mode.
* The same steps recursively apply to the Compiled Autograd graph, but this time AOTDispatcher will not need to partition the graph.

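After ``train(model, x)`` returns, the compiled backward has already run, so the parameter gradients are populated just as they would be in eager mode. A quick, purely illustrative way to confirm this:

.. code:: python

    # Gradients were accumulated by the Compiled Autograd run above
    print(model.linear.weight.grad.shape)  # torch.Size([10, 10])
    print(model.linear.bias.grad.shape)    # torch.Size([10])
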
Inspecting the compiled autograd logs
-------------------------------------
Run the script with the ``TORCH_LOGS`` environment variable (a complete ``example.py``, combining the earlier snippets, is shown after this list):

* To only print the compiled autograd graph, use ``TORCH_LOGS="compiled_autograd" python example.py``
* To print the graph with more tensor metadata and recompile reasons, at the cost of performance, use ``TORCH_LOGS="compiled_autograd_verbose" python example.py``

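For reference, ``example.py`` in the commands above can simply be the setup and basic-usage snippets combined into a single script; the file name itself is arbitrary and only matches the commands above.

.. code:: python

    # example.py -- run with: TORCH_LOGS="compiled_autograd" python example.py
    import torch

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x):
            return self.linear(x)

    model = Model()
    x = torch.randn(10)

    torch._dynamo.config.compiled_autograd = True

    @torch.compile
    def train(model, x):
        loss = model(x).sum()
        loss.backward()

    train(model, x)
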
Rerun the snippet above; the compiled autograd graph should now be logged to ``stderr``. Certain graph nodes will have names prefixed with ``aot0_``;
these correspond to the nodes previously compiled ahead of time in AOTAutograd backward graph 0. For example, ``aot0_view_2`` corresponds to ``view_2`` of the AOT backward graph with id=0.

.. code:: python

    stderr_output = """
    DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
    DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
    ===== Compiled autograd graph =====
    <eval_with_key>.4 class CompiledAutograd(torch.nn.Module):
        def forward(self, inputs, sizes, scalars, hooks):
            # No stacktrace found for following nodes
            aot0_tangents_1: "f32[][]cpu" = inputs[0]
            aot0_primals_3: "f32[10][1]cpu" = inputs[1]
            getitem_2: "f32[10][1]cpu" = inputs[2]
            getitem_3: "f32[10, 10][10, 1]cpu" = inputs[3]; inputs = None

            # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 1)
            aot0_expand: "f32[10][0]cpu" = torch.ops.aten.expand.default(aot0_tangents_1, [10]); aot0_tangents_1 = None
            aot0_view_2: "f32[1, 10][0, 0]cpu" = torch.ops.aten.view.default(aot0_expand, [1, 10]); aot0_expand = None
            aot0_permute_2: "f32[10, 1][0, 0]cpu" = torch.ops.aten.permute.default(aot0_view_2, [1, 0])
            aot0_select: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 0)
            aot0_view: "f32[1, 10][10, 1]cpu" = torch.ops.aten.view.default(aot0_primals_3, [1, 10]); aot0_primals_3 = None
            aot0_mul_3: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select, aot0_view); aot0_select = None
            aot0_select_1: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 1)
            aot0_mul_4: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_1, aot0_view); aot0_select_1 = None
            aot0_select_2: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 2)
            aot0_mul_5: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_2, aot0_view); aot0_select_2 = None
            aot0_select_3: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 3)
            aot0_mul_6: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_3, aot0_view); aot0_select_3 = None
            aot0_select_4: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 4)
            aot0_mul_7: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_4, aot0_view); aot0_select_4 = None
            aot0_select_5: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 5)
            aot0_mul_8: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_5, aot0_view); aot0_select_5 = None
            aot0_select_6: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 6)
            aot0_mul_9: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_6, aot0_view); aot0_select_6 = None
            aot0_select_7: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 7)
            aot0_mul_10: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_7, aot0_view); aot0_select_7 = None
            aot0_select_8: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 8)
            aot0_mul_11: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_8, aot0_view); aot0_select_8 = None
            aot0_select_9: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 9); aot0_permute_2 = None
            aot0_mul_12: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_9, aot0_view); aot0_select_9 = aot0_view = None
            aot0_cat: "f32[10, 10][10, 1]cpu" = torch.ops.aten.cat.default([aot0_mul_3, aot0_mul_4, aot0_mul_5, aot0_mul_6, aot0_mul_7, aot0_mul_8, aot0_mul_9, aot0_mul_10, aot0_mul_11, aot0_mul_12]); aot0_mul_3 = aot0_mul_4 = aot0_mul_5 = aot0_mul_6 = aot0_mul_7 = aot0_mul_8 = aot0_mul_9 = aot0_mul_10 = aot0_mul_11 = aot0_mul_12 = None
            aot0_permute_3: "f32[10, 10][1, 10]cpu" = torch.ops.aten.permute.default(aot0_cat, [1, 0]); aot0_cat = None
            aot0_sum_3: "f32[1, 10][10, 1]cpu" = torch.ops.aten.sum.dim_IntList(aot0_view_2, [0], True); aot0_view_2 = None
            aot0_view_3: "f32[10][1]cpu" = torch.ops.aten.view.default(aot0_sum_3, [10]); aot0_sum_3 = None

            # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 2)
            accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_2, aot0_view_3); getitem_2 = aot0_view_3 = accumulate_grad_ = None

            # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 1)
            aot0_permute_4: "f32[10, 10][10, 1]cpu" = torch.ops.aten.permute.default(aot0_permute_3, [1, 0]); aot0_permute_3 = None

            # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 3)
            accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_3, aot0_permute_4); getitem_3 = aot0_permute_4 = accumulate_grad__1 = None
            _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
            return []
    """

.. note:: This is the graph on which we will call ``torch.compile``, **NOT** the optimized graph. Compiled Autograd essentially generates some unoptimized Python code to represent the entire C++ autograd execution.

Compiling the forward and backward pass using different flags
--------------------------------------------------------------
You can use different compiler configurations for the two compilations. For example, the backward may be compiled as a full graph even if the forward contains graph breaks.

.. code:: python

    def train(model, x):
        model = torch.compile(model)
        loss = model(x).sum()
        torch._dynamo.config.compiled_autograd = True
        torch.compile(lambda: loss.backward(), fullgraph=True)()

Or you can use the context manager, which will apply to all autograd calls within its scope.

.. code:: python

    def train(model, x):
        model = torch.compile(model)
        loss = model(x).sum()
        with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
            loss.backward()

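Either variant can be driven exactly like the basic-usage example. The call below is illustrative; since ``torch.compile`` wraps the module without copying its parameters, the gradients land on the original ``model`` you passed in.

.. code:: python

    model = Model()
    x = torch.randn(10)
    train(model, x)
    # The parameters of the original module still receive their gradients
    print(model.linear.weight.grad is not None)  # True
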
Compiled Autograd addresses certain limitations of AOTAutograd
--------------------------------------------------------------
1. Graph breaks in the forward pass lead to graph breaks in the backward pass:

.. code:: python

    @torch.compile(backend="aot_eager")
    def fn(x):
        # 1st graph
        temp = x + 10
        torch._dynamo.graph_break()
        # 2nd graph
        temp = temp + 10
        torch._dynamo.graph_break()
        # 3rd graph
        return temp.sum()

    x = torch.randn(10, 10, requires_grad=True)
    torch._dynamo.utils.counters.clear()
    loss = fn(x)

    # 1. base torch.compile
    loss.backward(retain_graph=True)
    assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 3)
    torch._dynamo.utils.counters.clear()

    # 2. torch.compile with compiled autograd
    with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
        loss.backward()

    # single graph for the backward
    assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 1)

In the first ``torch.compile`` case, 3 backward graphs were produced due to the 2 graph breaks in the compiled function ``fn``.
In the second case, with compiled autograd enabled, a single full backward graph was traced despite the graph breaks.

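If you prefer to inspect the counts rather than assert on them, you can print the same Dynamo counters used in the snippet above (illustrative only):

.. code:: python

    # Dump the per-category counters that Dynamo populates during compilation;
    # "unique_graphs" reflects how many distinct graphs were compiled.
    from torch._dynamo.utils import counters
    print(dict(counters["stats"]))
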
2. Backward hooks are not captured:

.. code:: python

    @torch.compile(backend="aot_eager")
    def fn(x):
        return x.sum()

    x = torch.randn(10, 10, requires_grad=True)
    x.register_hook(lambda grad: grad + 10)
    loss = fn(x)

    with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
        loss.backward()

There should be a ``call_hook`` node in the graph, which Dynamo will later inline into the following:

.. code:: python

    stderr_output = """
    DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
    DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
    ===== Compiled autograd graph =====
    <eval_with_key>.2 class CompiledAutograd(torch.nn.Module):
        def forward(self, inputs, sizes, scalars, hooks):
            ...
            getitem_2 = hooks[0]; hooks = None
            call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
            ...
    """

Common recompilation reasons for Compiled Autograd
--------------------------------------------------
1. Due to changes in the autograd structure of the loss value:

.. code:: python

    torch._dynamo.config.compiled_autograd = True
    x = torch.randn(10, requires_grad=True)
    for op in [torch.add, torch.sub, torch.mul, torch.div]:
        loss = op(x, x).sum()
        torch.compile(lambda: loss.backward(), backend="eager")()

In the example above, we call a different operator on each iteration, leading to ``loss`` tracking a different autograd history each time. You should see some recompile messages: **Cache miss due to new autograd node**.

.. code:: python

    stderr_output = """
    Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
    ...
    Cache miss due to new autograd node: SubBackward0 (NodeCall 2) with key size 56, previous key sizes=[]
    ...
    Cache miss due to new autograd node: MulBackward0 (NodeCall 2) with key size 71, previous key sizes=[]
    ...
    Cache miss due to new autograd node: DivBackward0 (NodeCall 2) with key size 70, previous key sizes=[]
    ...
    """

2. Due to tensors changing shapes:

.. code:: python

    torch._dynamo.config.compiled_autograd = True
    for i in [10, 100, 10]:
        x = torch.randn(i, i, requires_grad=True)
        loss = x.sum()
        torch.compile(lambda: loss.backward(), backend="eager")()

In the example above, ``x`` changes shapes, and compiled autograd will mark ``x`` as a dynamic shape tensor after the first change. You should see recompile messages: **Cache miss due to changed shapes**.

.. code:: python

    stderr_output = """
    ...
    Cache miss due to changed shapes: marking size idx 0 of torch::autograd::GraphRoot (NodeCall 0) as dynamic
    Cache miss due to changed shapes: marking size idx 1 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
    Cache miss due to changed shapes: marking size idx 2 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
    Cache miss due to changed shapes: marking size idx 3 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
    ...
    """

Conclusion
----------
In this tutorial, we went over the high-level ecosystem of ``torch.compile`` with compiled autograd, the basics of compiled autograd, and a few common recompilation reasons. Stay tuned for deep dives on `dev-discuss <https://dev-discuss.pytorch.org/>`_.