"""

How to save memory by fusing the optimizer step into the backward pass
======================================================================

Hello there! This tutorial aims to showcase one way of reducing the
memory footprint of a training loop by reducing the memory taken by
the *gradients*. Say you have a model and you're interested in ways to
optimize memory to avoid ``Out of Memory`` (OOM) errors or simply to
squeeze more out of your GPU. Well, you _might_ be in luck (if gradients
take up a portion of your memory and you do not need to do gradient
accumulation). We will explore the following:

1. What takes up memory during your training or finetuning loop,
2. How to capture and visualize memory snapshots to determine the bottleneck,
3. The new ``Tensor.register_post_accumulate_grad_hook(hook)`` API, and finally,
4. How everything fits together in 10 lines to achieve memory savings.

To run this tutorial, you will need:

*  PyTorch 2.1.0 or newer with ``torchvision``
*  1 CUDA GPU if you'd like to run the memory visualizations locally.
   Otherwise, this technique would benefit similarly on any device.

Let us start by importing the required modules and models. We will use a
vision transformer model from torchvision, but feel free to substitute
with your own model. We will also use ``torch.optim.Adam`` as our optimizer,
but, again, feel free to substitute with your own optimizer.

"""

import torch
from torchvision import models
from pickle import dump

model = models.vit_l_16(weights='DEFAULT').cuda()
optimizer = torch.optim.Adam(model.parameters())

###############################################################################
# Now let's define our typical training loop. You should use real images when
# training, but for the purposes of this tutorial, we are passing in fake
# inputs and not worrying about loading any actual data.

IMAGE_SIZE = 224

def train(model, optimizer):
    # create our fake image input: tensor shape is batch_size, channels, height, width
    fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()

    # call our forward and backward
    loss = model.forward(fake_image)
    loss.sum().backward()

    # optimizer update
    optimizer.step()
    optimizer.zero_grad()

###############################################################################
# Memory usage during training
# """"""""""""""""""""""""""""
# We are about to look at some memory snapshots, so we should be prepared to
# analyze them properly. Typically, training memory consists of:
#
#  * Model parameters (size P)
#  * Activations that are saved for the backward pass (size A)
#  * Gradients, which are the same size as the model parameters, so size G = P.
#  * Optimizer state, which is proportional to the size of the parameters. In
#    this case, the state for Adam requires 2x the model parameters, so size O = 2P.
#  * Intermediate tensors, which are allocated throughout the compute. We will
#    not worry about them for now as they are usually small and ephemeral.
#
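# As a quick back-of-the-envelope check (a sketch added for illustration, not a
# measurement from the original run), you can estimate P, G, and O for this model
# straight from its parameter count, assuming fp32 parameters (4 bytes per element)
# and Adam's two fp32 state tensors per parameter:
#
# .. code-block:: python
#
#    num_params = sum(p.numel() for p in model.parameters())
#    bytes_per_param = 4                  # fp32
#    P = num_params * bytes_per_param     # parameters
#    G = P                                # gradients mirror the parameters
#    O = 2 * P                            # Adam keeps exp_avg and exp_avg_sq per parameter
#    print(f"P ~{P / 1e9:.1f}GB, G ~{G / 1e9:.1f}GB, O ~{O / 1e9:.1f}GB")
#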
72
+ # Capturing and visualizing memory snapshots
73
+ # """"""""""""""""""""""""""""""""""""""""""
74
+ # Let's get us a memory snapshot! As your code runs, consider what you may expect
75
+ # the CUDA memory timeline to look like.
76
+
77
+ # tell CUDA to start recording memory allocations
78
+ torch .cuda .memory ._record_memory_history (enabled = 'all' )
79
+
80
+ # train 3 steps
81
+ for _ in range (3 ):
82
+ train (model , optimizer )
83
+
84
+ # save a snapshot of the memory allocations
85
+ s = torch .cuda .memory ._snapshot ()
86
+ with open (f"snapshot.pickle" , "wb" ) as f :
87
+ dump (s , f )
88
+
89
+ # tell CUDA to stop recording memory allocations now
90
+ torch .cuda .memory ._record_memory_history (enabled = None )
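
# side note (an aside, not part of the original tutorial flow): recent PyTorch
# versions also offer torch.cuda.memory._dump_snapshot("snapshot.pickle"), which
# wraps the _snapshot() and pickle-dump steps above into a single call.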
91
+
92
+ ###############################################################################
93
+ # Now open up the snapshot in the CUDA Memory Visualizer at
94
+ # https://pytorch.org/memory_viz by dragging and dropping the
95
+ # ``snapshot.pickle`` file. Does the memory timeline match your expectations?
96
+ #
97
+ # .. figure:: /_static/img/optim_step_in_bwd/snapshot.jpg
98
+ # :alt: snapshot.png loaded into CUDA Memory Visualizer
99
+ #
100
+ # The model parameters have already been loaded in memory before the training
101
+ # step, so we see a chunk of memory devoted to the weights right off the bat.
102
+ # As we start our forward pass, memory is allocated gradually for the activations,
103
+ # or the tensors we are saving to be able to compute gradients in the backward pass.
104
+ # Once we start the backward pass, the activations are gradually freed while memory
105
+ # of the gradients starts building up.
106
+ #
107
+ # Lastly, as the optimizer kicks in, its state will be lazily initialized, so we
108
+ # should see the optimizer state memory gradually increase during the optimizer
109
+ # step of the first training loop only. In future loops, the optimizer memory
110
+ # will remain and be updated in-place. The memory for the gradients is then
111
+ # freed accordingly at the end of every training loop when ``zero_grad`` is called.
112
+ #
113
+ # Where is the memory bottleneck in this training loop? Or, in other words,
114
+ # where is the peak memory?
115
+ #
116
+ # The peak memory usage is during the optimizer step! Note the memory then
117
+ # consists of ~1.2GB of parameters, ~1.2GB of gradients, and ~2.4GB=2*1.2GB of
118
+ # the optimizer state as expected. The last ~1.2GB comes from Adam optimizer
119
+ # requiring memory for intermediates, totaling to ~6GB of peak memory.
120
+ # Technically, you can remove the need for the last 1.2GB for optimizer
121
+ # intermediates if you set ``Adam(model.parameters(), foreach=False)`` which
122
+ # would trade off runtime for memory. If switching off the ``foreach`` runtime
123
+ # optimization is sufficient in memory savings for you, nice, but please
124
+ # read on if you're curious how this tutorial can help you do better!
125
+ # With the technique we will soon introduce, we will reduce peak memory by
126
+ # removing the need for the ~1.2GB of **gradients memory** as well as **optimizer
127
+ # intermediates memory**. Now, what would you expect the new peak memory to be?
128
+ # The answer will be revealed in the `next` snapshot.
129
+ #
# DISCLAIMER: This technique is **not** for everyone
# """"""""""""""""""""""""""""""""""""""""""""""""""
# Before we get too excited, we have to consider whether this technique is applicable
# to `your` use case. This is NOT a silver bullet! The technique of fusing the
# optimizer step into the backward pass only targets reducing *gradient* memory
# (and, as a side effect, optimizer intermediates memory). Thus, the more sizable
# the memory taken up by the gradients, the more significant the memory reduction.
# In our example above, the gradients eat up 20% of the memory pie, which is
# quite sizable!
#
# This may not be the case for you. For example, if your weights are already tiny
# (say, due to applying LoRA), then the gradients do not take up much space in your
# training loop and the wins are way less exciting. In that case, you should
# first try other techniques like activation checkpointing, distributed
# training, quantization, or reducing the batch size. Then, when the gradients
# are part of the bottleneck again, come back to this tutorial!
145
+ #
146
+ # Still here? Cool, let's introduce our new ``register_post_accumulate_grad_hook(hook)``
147
+ # API on Tensor.
148
+ #
149
+ # ``Tensor.register_post_accumulate_grad_hook(hook)`` API and our technique
150
+ # """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
151
+ # Our technique relies on not having to save the gradients during ``backward()``. Instead,
152
+ # once a gradient has been accumulated, we will immediately apply the optimizer to
153
+ # the corresponding parameter and drop that gradient entirely! This removes the need
154
+ # for holding onto a big buffer of gradients until the optimizer step.
155
+ #
156
+ # So how can we unlock the behavior of applying the optimizer more eagerly? In our 2.1
157
+ # release, we've added a new API :func:`torch.Tensor.register_post_accumulate_grad_hook`
158
+ # that would allow us to add a hook onto a Tensor once its ``.grad`` field has been
159
+ # accumulated. We will encapsulate the optimizer step into this hook. How?
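#
# As a warm-up, here is a minimal sketch of the hook API on a single leaf tensor
# (an illustration only; the names below are arbitrary). The hook receives the
# tensor itself and fires once its ``.grad`` has been accumulated for the current
# backward pass:
#
# .. code-block:: python
#
#    leaf = torch.rand(3, requires_grad=True)
#
#    def print_grad_hook(tensor: torch.Tensor) -> None:
#        # runs after tensor.grad has been accumulated
#        print("accumulated grad:", tensor.grad)
#
#    leaf.register_post_accumulate_grad_hook(print_grad_hook)
#    (leaf * 2).sum().backward()   # prints the accumulated gradient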
160
+ #
161
+ # How everything fits together in 10 lines
162
+ # """"""""""""""""""""""""""""""""""""""""
163
+ # Remember our model and optimizer setup from the beginning? I'll leave them commented
164
+ # out below so we don't spend resources rerunning the code.
165
+ #
166
+ # .. code-block:: python
167
+ #
168
+ # model = models.vit_l_16(weights='DEFAULT').cuda()
169
+ # optimizer = torch.optim.Adam(model.parameters())

# Instead of having just *one* optimizer, we will have a ``dict`` of optimizers
# for every parameter so we could reference them in our hook.
optimizer_dict = {p: torch.optim.Adam([p], foreach=False) for p in model.parameters()}

# Define our hook, which will call the optimizer ``step()`` and ``zero_grad()``
def optimizer_hook(parameter) -> None:
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad()

# Register the hook onto every parameter
for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

# Now remember our previous ``train()`` function? Since the optimizer has been
# fused into the backward, we can remove the optimizer step and zero_grad calls.
def train(model):
    # create our fake image input: tensor shape is batch_size, channels, height, width
    fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()

    # call our forward and backward
    loss = model.forward(fake_image)
    loss.sum().backward()

    # optimizer update --> no longer needed!
    # optimizer.step()
    # optimizer.zero_grad()

########################################################################
# That took about 10 lines of changes in our sample model, which is neat.
# However, for real models, it could be a fairly intrusive change to switch
# out the optimizer for an optimizer dictionary, especially for those who use
# ``LRScheduler``s or manipulate optimizer configuration throughout the
# training epochs. Working out this API with those changes will be more
# involved and will likely require moving more configuration into global
# state, but it should not be impossible. That said, a next step for PyTorch
# is to make this API easier to adopt with LRSchedulers and other features
# you are already used to; one possible direction is sketched below.
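#
# For instance, a minimal sketch (an illustration only, with hypothetical names,
# not an official pattern) could pair each per-parameter optimizer with its own
# scheduler and step the schedulers from the outer training loop:
#
# .. code-block:: python
#
#    scheduler_dict = {
#        p: torch.optim.lr_scheduler.StepLR(optimizer_dict[p], step_size=10)
#        for p in model.parameters()
#    }
#
#    def step_schedulers() -> None:
#        # called once per epoch from the outer loop; this is the kind of
#        # global state the hook itself does not need to know about
#        for scheduler in scheduler_dict.values():
#            scheduler.step()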
#
# But let me get back to convincing you that this technique is worth it.
# We will consult our friend, the memory snapshot.

# delete optimizer memory from before to get a clean slate for the next
# memory snapshot
del optimizer

# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history(enabled='all')

# train 3 steps. note that we no longer pass the optimizer into train()
for _ in range(3):
    train(model)

# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open("snapshot-opt-in-bwd.pickle", "wb") as f:
    dump(s, f)

# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)

###############################################################################
# Yes, take some time to drag your snapshot into the CUDA Memory Visualizer.
#
# .. figure:: /_static/img/optim_step_in_bwd/snapshot_opt_in_bwd.jpg
#    :alt: snapshot.png loaded into CUDA Memory Visualizer
#
# Several major observations:
#
#  1. There is no more optimizer step! Right... we fused that into the backward.
#  2. Likewise, the backward drags longer and there are more random allocations
#     for intermediates. This is expected, as the optimizer step requires
#     intermediates.
#  3. Most importantly! The peak memory is lower! It is now ~4GB (which I
#     hope maps closely to your earlier expectation).
#
# Note that there is no longer any big chunk of memory allocated for the gradients
# compared to before, accounting for ~1.2GB of memory savings. Instead, each
# gradient is freed very quickly after it has been computed, because we moved the
# optimizer step as early as we can. Woohoo! By the way, the other ~1.2GB of memory
# savings comes from breaking apart the optimizer into per-parameter optimizers, so
# the intermediates have proportionally shrunk. This detail is `less important` than
# the gradient memory savings, as you can get the optimizer intermediates savings
# from just setting ``foreach=False`` without this technique.
#
# You may be correctly wondering: if we saved 2.4GB of memory, why is the peak memory
# NOT 6GB - 2.4GB = 3.6GB? Well, the peak has moved! The peak is now near the start
# of the backward step, when we still have activations in memory, whereas before, the
# peak was during the optimizer step, when the activations had already been freed. The
# ~0.4GB difference between ~4.0GB and ~3.6GB is thus due to the activations memory.
# One can then imagine that this technique can be coupled with activation checkpointing
# for more memory wins.
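#
# As a rough illustration of that pairing (a sketch only; ``block`` here stands in
# for any submodule whose activations you would rather recompute than store), the
# relevant building block would be ``torch.utils.checkpoint``:
#
# .. code-block:: python
#
#    from torch.utils.checkpoint import checkpoint
#
#    def forward_with_checkpointing(block, x):
#        # recompute this block's activations in the backward pass instead of saving them
#        return checkpoint(block, x, use_reentrant=False)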
#
# Conclusion
# """"""""""
# In this tutorial, we learned about the memory-saving technique of
# fusing the optimizer into the backward step through the new
# ``Tensor.register_post_accumulate_grad_hook()`` API and *when* to apply this
# technique (when gradients memory is significant). Along the way, we also learned
# about memory snapshots, which are generally useful in memory optimization.