# -*- coding: utf-8 -*-

"""
Compiled Autograd: Capturing a larger backward graph for ``torch.compile``
==========================================================================

"""

######################################################################
# Compiled Autograd is a torch.compile extension introduced in PyTorch 2.4
# that allows the capture of a larger backward graph. It is highly recommended
# to familiarize yourself with `torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_.
#
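######################################################################
# Since Compiled Autograd ships with PyTorch 2.4, it can be convenient to fail fast on older
# builds. The guard below is a minimal sketch, assuming ``torch.__version__`` follows the usual
# "major.minor.patch" form; adapt it to your environment as needed.

import torch

# Parse the major/minor components of the installed PyTorch version.
_major, _minor = (int(part) for part in torch.__version__.split(".")[:2])
if (_major, _minor) < (2, 4):
    raise RuntimeError("Compiled Autograd requires PyTorch 2.4 or newer")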
######################################################################
# Doesn't torch.compile already capture the backward graph?
# -----------------------------------------------------------
# Partially. AOTAutograd captures the backward graph ahead-of-time, but with certain limitations:
#
# - Graph breaks in the forward lead to graph breaks in the backward
# - `Backward hooks <https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`_ are not captured
#
# Compiled Autograd addresses these limitations by integrating directly with the autograd engine,
# allowing it to capture the full backward graph at runtime. Models affected by either of these
# limitations should try Compiled Autograd and may observe better performance.
#
# However, Compiled Autograd has its own limitations:
#
# - Dynamic autograd structure leads to recompiles
#
######################################################################
# Basic Usage
# ------------
#

import torch

# NOTE: The flag must be enabled before using the torch.compile decorator
torch._dynamo.config.compiled_autograd = True
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

@torch.compile
def train(model, x):
    loss = model(x).sum()
    loss.backward()

model = Model()
x = torch.randn(10)
train(model, x)

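######################################################################
# Beyond this minimal example, the same pattern drops into a training loop. The sketch below is
# illustrative only: the names ``train_step`` and ``opt`` are made up for this example, and the
# optimizer update is left in eager mode. The forward and backward are captured together by
# torch.compile with Compiled Autograd enabled, as above.

import torch

torch._dynamo.config.compiled_autograd = True

sketch_model = torch.nn.Linear(10, 10)
opt = torch.optim.SGD(sketch_model.parameters(), lr=0.01)

@torch.compile
def train_step(x):
    # forward + backward both run inside the compiled region
    loss = sketch_model(x).sum()
    loss.backward()
    return loss

for _ in range(3):
    opt.zero_grad()
    train_step(torch.randn(10))
    opt.step()  # the parameter update stays in eager mode in this sketch
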
######################################################################
# Inspecting the compiled autograd logs
# --------------------------------------
# Run the script with the ``TORCH_LOGS`` environment variable:
#
"""
Prints the graph:
TORCH_LOGS="compiled_autograd" python example.py

Degrades performance, but prints a more verbose graph and recompile reasons:
TORCH_LOGS="compiled_autograd_verbose" python example.py
"""

######################################################################
# Or with the private ``set_logs`` API:
#

# NOTE: The flag must be enabled before wrapping with torch.compile
torch._logging._internal.set_logs(compiled_autograd=True)

@torch.compile
def train(model, x):
    loss = model(x).sum()
    loss.backward()

train(model, x)

######################################################################
# The compiled autograd graph should now be logged to stdout. Certain graph nodes will have names
# prefixed with ``aot0_``; these correspond to the nodes previously compiled ahead-of-time in
# AOTAutograd backward graph 0.
#
# NOTE: This is the graph that we will call torch.compile on, NOT the optimized graph. Compiled
# Autograd essentially generates Python code to represent the entire C++ autograd engine execution.
#
| 89 | +""" |
| 90 | +INFO:torch._dynamo.compiled_autograd.__compiled_autograd:TRACED GRAPH |
| 91 | + ===== Compiled autograd graph ===== |
| 92 | + <eval_with_key>.4 class CompiledAutograd(torch.nn.Module): |
| 93 | + def forward(self, inputs, sizes, scalars, hooks): |
| 94 | + # No stacktrace found for following nodes |
| 95 | + aot0_tangents_1: "f32[][]cpu" = inputs[0] |
| 96 | + aot0_primals_3: "f32[10][1]cpu" = inputs[1] |
| 97 | + getitem_2: "f32[10][1]cpu" = inputs[2] |
| 98 | + getitem_3: "f32[10, 10][10, 1]cpu" = inputs[3]; inputs = None |
| 99 | + |
| 100 | + # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 1) |
| 101 | + aot0_expand: "f32[10][0]cpu" = torch.ops.aten.expand.default(aot0_tangents_1, [10]); aot0_tangents_1 = None |
| 102 | + aot0_view_2: "f32[1, 10][0, 0]cpu" = torch.ops.aten.view.default(aot0_expand, [1, 10]); aot0_expand = None |
| 103 | + aot0_permute_2: "f32[10, 1][0, 0]cpu" = torch.ops.aten.permute.default(aot0_view_2, [1, 0]) |
| 104 | + aot0_select: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 0) |
| 105 | + aot0_view: "f32[1, 10][10, 1]cpu" = torch.ops.aten.view.default(aot0_primals_3, [1, 10]); aot0_primals_3 = None |
| 106 | + aot0_mul_3: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select, aot0_view); aot0_select = None |
| 107 | + aot0_select_1: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 1) |
| 108 | + aot0_mul_4: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_1, aot0_view); aot0_select_1 = None |
| 109 | + aot0_select_2: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 2) |
| 110 | + aot0_mul_5: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_2, aot0_view); aot0_select_2 = None |
| 111 | + aot0_select_3: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 3) |
| 112 | + aot0_mul_6: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_3, aot0_view); aot0_select_3 = None |
| 113 | + aot0_select_4: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 4) |
| 114 | + aot0_mul_7: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_4, aot0_view); aot0_select_4 = None |
| 115 | + aot0_select_5: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 5) |
| 116 | + aot0_mul_8: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_5, aot0_view); aot0_select_5 = None |
| 117 | + aot0_select_6: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 6) |
| 118 | + aot0_mul_9: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_6, aot0_view); aot0_select_6 = None |
| 119 | + aot0_select_7: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 7) |
| 120 | + aot0_mul_10: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_7, aot0_view); aot0_select_7 = None |
| 121 | + aot0_select_8: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 8) |
| 122 | + aot0_mul_11: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_8, aot0_view); aot0_select_8 = None |
| 123 | + aot0_select_9: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 9); aot0_permute_2 = None |
| 124 | + aot0_mul_12: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_9, aot0_view); aot0_select_9 = aot0_view = None |
| 125 | + aot0_cat: "f32[10, 10][10, 1]cpu" = torch.ops.aten.cat.default([aot0_mul_3, aot0_mul_4, aot0_mul_5, aot0_mul_6, aot0_mul_7, aot0_mul_8, aot0_mul_9, aot0_mul_10, aot0_mul_11, aot0_mul_12]); aot0_mul_3 = aot0_mul_4 = aot0_mul_5 = aot0_mul_6 = aot0_mul_7 = aot0_mul_8 = aot0_mul_9 = aot0_mul_10 = aot0_mul_11 = aot0_mul_12 = None |
| 126 | + aot0_permute_3: "f32[10, 10][1, 10]cpu" = torch.ops.aten.permute.default(aot0_cat, [1, 0]); aot0_cat = None |
| 127 | + aot0_sum_3: "f32[1, 10][10, 1]cpu" = torch.ops.aten.sum.dim_IntList(aot0_view_2, [0], True); aot0_view_2 = None |
| 128 | + aot0_view_3: "f32[10][1]cpu" = torch.ops.aten.view.default(aot0_sum_3, [10]); aot0_sum_3 = None |
| 129 | + |
| 130 | + # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 2) |
| 131 | + accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_2, aot0_view_3); getitem_2 = aot0_view_3 = accumulate_grad_ = None |
| 132 | + |
| 133 | + # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 1) |
| 134 | + aot0_permute_4: "f32[10, 10][10, 1]cpu" = torch.ops.aten.permute.default(aot0_permute_3, [1, 0]); aot0_permute_3 = None |
| 135 | + |
| 136 | + # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 3) |
| 137 | + accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_3, aot0_permute_4); getitem_3 = aot0_permute_4 = accumulate_grad__1 = None |
| 138 | + _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None |
| 139 | + return [] |
| 140 | +""" |
| 141 | + |
######################################################################
# Compiling the forward and backward pass using different flags
# ---------------------------------------------------------------
#

def train(model, x):
    model = torch.compile(model)
    loss = model(x).sum()
    torch.compile(lambda: loss.backward(), fullgraph=True)()

######################################################################
# Or you can use the context manager, which will apply to all autograd calls within its scope:
#

def train(model, x):
    model = torch.compile(model)
    loss = model(x).sum()
    with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
        loss.backward()

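######################################################################
# Either variant of ``train`` can then be invoked as before. This is an illustrative usage
# sketch that simply reuses the ``model`` and ``x`` created in the Basic Usage section above.

train(model, x)
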
######################################################################
# Demonstrating the limitations of AOTAutograd addressed by Compiled Autograd
# -----------------------------------------------------------------------------
# 1. Graph breaks in the forward lead to graph breaks in the backward
#

@torch.compile(backend="aot_eager")
def fn(x):
    # 1st graph
    temp = x + 10
    torch._dynamo.graph_break()
    # 2nd graph
    temp = temp + 10
    torch._dynamo.graph_break()
    # 3rd graph
    return temp.sum()

x = torch.randn(10, 10, requires_grad=True)
loss = fn(x)

# 1. base torch.compile
loss.backward(retain_graph=True)
assert torch._dynamo.utils.counters["stats"]["unique_graphs"] == 3
torch._dynamo.utils.counters.clear()

# 2. torch.compile with compiled autograd
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
    loss.backward()

# single graph for the backward
assert torch._dynamo.utils.counters["stats"]["unique_graphs"] == 1


######################################################################
# 2. Backward hooks are not captured
#

@torch.compile(backend="aot_eager")
def fn(x):
    return x.sum()

x = torch.randn(10, 10, requires_grad=True)
x.register_hook(lambda grad: grad + 10)
loss = fn(x)

torch._logging._internal.set_logs(compiled_autograd=True)
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
    loss.backward()

######################################################################
# There is a ``call_hook`` node in the graph, which Dynamo will inline when compiling it:
#

"""
INFO:torch._dynamo.compiled_autograd.__compiled_autograd:TRACED GRAPH
 ===== Compiled autograd graph =====
 <eval_with_key>.2 class CompiledAutograd(torch.nn.Module):
    def forward(self, inputs, sizes, scalars, hooks):
        ...
        getitem_2 = hooks[0]; hooks = None
        call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
        ...
"""

######################################################################
# Understanding recompilation reasons for Compiled Autograd
# -----------------------------------------------------------
# 1. Due to a change in the autograd structure

torch._logging._internal.set_logs(compiled_autograd_verbose=True)
torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
    loss = op(x, x).sum()
    torch.compile(lambda: loss.backward(), backend="eager")()

######################################################################
# You should see some cache miss logs (recompiles):
#
# Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
# ...
# Cache miss due to new autograd node: SubBackward0 (NodeCall 2) with key size 56, previous key sizes=[]
# ...
# Cache miss due to new autograd node: MulBackward0 (NodeCall 2) with key size 71, previous key sizes=[]
# ...
# Cache miss due to new autograd node: DivBackward0 (NodeCall 2) with key size 70, previous key sizes=[]
# ...

######################################################################
# 2. Due to dynamic shapes
#

torch._logging._internal.set_logs(compiled_autograd_verbose=True)
torch._dynamo.config.compiled_autograd = True
for i in [10, 100, 10]:
    x = torch.randn(i, i, requires_grad=True)
    loss = x.sum()
    torch.compile(lambda: loss.backward(), backend="eager")()

######################################################################
# You should see some cache miss logs (recompiles):
# ...
# Cache miss due to changed shapes: marking size idx 0 of torch::autograd::GraphRoot (NodeCall 0) as dynamic
# Cache miss due to changed shapes: marking size idx 1 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
# Cache miss due to changed shapes: marking size idx 2 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
# Cache miss due to changed shapes: marking size idx 3 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
# ...

######################################################################
# Compatibility and rough edges
# ------------------------------
#
# Compiled Autograd is under active development and is not yet compatible with all existing PyTorch features.
# For the latest status on a particular feature, refer to:
# https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY.