Hello there! This tutorial aims to showcase one way of reducing the
memory footprint of a training loop by reducing the memory taken by
the *gradients*. Say you have a model and you're interested in ways to
- optimize memory to avoid OOMing or simply to ooze more out of your GPU.
- Well, you _might_ be in luck (if gradients take up a portion of your
- memory and you do not need to do gradient accumulation)! We will explore
+ optimize memory to avoid ``Out of Memory`` (OOM) errors or simply to ooze
+ more out of your GPU. Well, you _might_ be in luck (if gradients take up
+ a portion of your memory and you do not need to do gradient accumulation).
+ We will explore the following:

1. What takes up memory during your training or finetuning loop,
2. How to capture and visualize memory snapshots to determine the bottleneck,
- 3. The new `Tensor.register_post_accumulate_grad_hook(hook)` API, and finally,
- 4. How everything fits together in 10 lines to achieve memory savings
+ 3. The new ``Tensor.register_post_accumulate_grad_hook(hook)`` API, and finally,
+ 4. How everything fits together in 10 lines to achieve memory savings.

- The ingredients and tools required:
- 1. PyTorch 2.1.0 or newer with torchvision
- 2. A CUDA GPU
+ To run this tutorial, you will need:
+ * PyTorch 2.1.0 or newer with ``torchvision``
+ * 1 CUDA GPU

Let us start by importing the required modules and models. We will use a
- vision transformer model from torchvision, but feel free to substitute with
- your own model. We will also use `torch.optim.Adam` as our optimizer, but,
- again, feel free to substitute with your own optimizer.
+ vision transformer model from torchvision, but feel free to substitute
+ with your own model. We will also use ``torch.optim.Adam`` as our optimizer,
+ but, again, feel free to substitute with your own optimizer.

"""
###############################################################################
# Now let's define our typical training loop. You should use real images when
# training, but for the purposes of this tutorial, we are passing in fake
- # inputs and not worrying about loading actual data.
+ # inputs and not worrying about loading any actual data.

IMAGE_SIZE = 224
@@ -53,18 +54,18 @@ def train(model, optimizer):
    optimizer.zero_grad()

###############################################################################
- # So what comprises the memory usage during training?
- # """""""""""""""""""""""""""""""""""""""""""""""""""
+ # Memory usage during training
+ # """"""""""""""""""""""""""""
# We are about to look at some memory snapshots, so we should be prepared to
- # analyze them properly. People normally consider training memory to consist of
+ # analyze them properly. Typically, training memory consists of the following
+ # (a rough size estimate is sketched right after this list):
#
- # 1. Model parameters (size P)
- # 2. Activations (size A)
- # 3. Gradients, which are the same size as the model parameters, so size G = P
- # 4. Optimizer state, which is usually a relation to the model parameters. In
- #    this case, Adam state requires 2x the model parameters, so size O = 2P
- # 5. Intermediate tensors, which are allocated throughout the compute. We will
- #    not worry about them for now as they are usually small and ephemeral.
+ # * Model parameters (size P)
+ # * Activations (size A)
+ # * Gradients, which are the same size as the model parameters, so size G = P.
+ # * Optimizer state, which is usually proportional in size to the model parameters.
+ #   In this case, the state for Adam requires 2x the model parameters, so size O = 2P.
+ # * Intermediate tensors, which are allocated throughout the compute. We will
+ #   not worry about them for now as they are usually small and ephemeral.
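+ #
+ # (As a back-of-envelope sketch, illustrative only and assuming fp32 parameters for
+ # the ``vit_l_16`` model used in this tutorial, you can estimate P, G, and O yourself:)
+ #
+ # .. code-block:: python
+ #
+ #    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
+ #    P = param_bytes / 1e9   # parameters, in GB (~1.2GB for vit_l_16 in fp32)
+ #    G = P                   # gradients match the parameters in size
+ #    O = 2 * P               # Adam keeps exp_avg and exp_avg_sq per parameter
+ #    print(f"P ~ {P:.1f} GB, G ~ {G:.1f} GB, O ~ {O:.1f} GB")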
#
# Capturing and visualizing memory snapshots
# """"""""""""""""""""""""""""""""""""""""""
@@ -79,33 +80,33 @@ def train(model, optimizer):
train(model, optimizer)

# save a snapshot of the memory allocations
- # s = torch.cuda.memory._snapshot()
- # with open(f"snapshot.pickle", "wb") as f:
- #     dump(s, f)
+ s = torch.cuda.memory._snapshot()
+ with open(f"snapshot.pickle", "wb") as f:
+     dump(s, f)

# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
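
# (Note: for context, since the call sits outside the lines shown in this diff,
# recording would have been enabled before training with something like
# ``torch.cuda.memory._record_memory_history()``, optionally passing ``max_entries``
# to cap how much allocation history is kept.)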

###############################################################################
# Now open up the snapshot in Zach Devito's [CUDA Memory Visualizer](
- # https://zdevito.github.io/assets/viz/) by dragging the snapshot.pickle file.
- # Does the memory timeline match your expectations?
+ # https://zdevito.github.io/assets/viz/) by dragging and dropping the
+ # ``snapshot.pickle`` file. Does the memory timeline match your expectations?
#
# .. figure:: /_static/img/optim_step_in_bwd/snapshot.jpg
#    :alt: snapshot.png loaded into CUDA Memory Visualizer
#
# The model parameters have already been loaded in memory before the training
# step, so we see a chunk of memory devoted to the weights right off the bat.
- # As we start our forward, memory is allocated gradually for the activations, or
- # the tensors we are saving to be able to compute gradients in the backward.
- # Once we start the backward, the activations are gradually freed while memory
- # of the gradients start building up.
+ # As we start our forward pass, memory is allocated gradually for the activations,
+ # or the tensors we are saving to be able to compute gradients in the backward pass.
+ # Once we start the backward pass, the activations are gradually freed while memory
+ # of the gradients starts building up.
#
# Lastly, as the optimizer kicks in, its state will be lazily initialized, so we
# should see the optimizer state memory gradually increase during the optimizer
# step of the first training loop only. In future loops, the optimizer memory
- # will remain and be updated in-place. The memory for the gradients are then
- # freed accordingly at the end of every training loop when zero_grad is called.
+ # will remain and be updated in-place. The memory for the gradients is then
+ # freed accordingly at the end of every training loop when ``zero_grad`` is called.
#
# Where is the memory bottleneck in this training loop? Or, in other words,
# where is the peak memory?
@@ -115,23 +116,23 @@ def train(model, optimizer):
# the optimizer state as expected. The last ~1.2GB comes from Adam optimizer
# requiring memory for intermediates, totalling to ~6GB of peak memory.
# Technically, you can remove the need for the last 1.2GB for optimizer
- # intermediates if you set `Adam(model.parameters(), foreach=False)` which
- # would trade off runtime for memory. If switching off the foreach runtime
+ # intermediates if you set ``Adam(model.parameters(), foreach=False)``, which
+ # would trade off runtime for memory. If switching off the ``foreach`` runtime
# optimization is sufficient in memory savings for you, nice, but please
# read on if you're curious how this tutorial can help you do better!
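#
# (For instance, a minimal sketch of that trade-off; this snippet is illustrative
# and not a line from this tutorial's code:)
#
# .. code-block:: python
#
#    # foreach=False updates parameters one at a time: a slower step, but no
#    # large fused intermediate tensors are allocated by the optimizer.
#    optimizer = torch.optim.Adam(model.parameters(), foreach=False)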
# With the technique we will soon introduce, we will reduce peak memory by
# removing the need for the ~1.2GB of **gradients memory** as well as **optimizer
- # intermediates memory**. Now do the math-- what would the new peak memory be?
- # The answer will be revealed in the next snapshot :D .
+ # intermediates memory**. Now, what would you expect the new peak memory to be?
+ # The answer will be revealed in the `next` snapshot.
#
- # DISCLAIMER. Yes. Is the time for the disclaimer:
- #
- # Is this technique applicable to you? Only if gradients take up sizable memory.
- # """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
- # The technique of fusing the optimizer step into the backward only targets
- # reducing *gradient* memory (and as a side effect also optimizer intermediates
- # memory). In our example above, the gradients eat up 20% of the memory pie, which
- # is quite sizable!
+ # DISCLAIMER: This technique is **not** for all
+ # """""""""""""""""""""""""""""""""""""""""""""
+ # Before we get too excited, we have to consider whether this technique is
+ # applicable for `your` use case. This is NOT a silver bullet! The technique of
+ # fusing the optimizer step into the backward only targets reducing *gradient*
+ # memory (and, as a side effect, also optimizer intermediates memory). Thus, the
+ # more sizable the memory taken up by the gradients, the more significant the
+ # memory reduction. In our example above, the gradients eat up 20% of the memory
+ # pie, which is quite sizable!
#
# This may not be the case for you: for example, if your weights are already tiny
# (say, due to applying LoRA), then the gradients do not take much space in your
@@ -140,35 +141,36 @@ def train(model, optimizer):
# training, quantization, or reducing the batch size. Then, when the gradients
# are part of the bottleneck again, come back to this tutorial!
#
- # Still here? Cool, let's introduce our new `register_post_accumulate_grad_hook(hook)`
+ # Still here? Cool, let's introduce our new ``register_post_accumulate_grad_hook(hook)``
# API on Tensor.
#
- # `Tensor.register_post_accumulate_grad_hook(hook)` API and our technique
+ # ``Tensor.register_post_accumulate_grad_hook(hook)`` API and our technique
# """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
- # Our technique relies on not having to save the gradients during `backward()`. Instead,
+ # Our technique relies on not having to save the gradients during ``backward()``. Instead,
# once a gradient has been accumulated, we will immediately apply the optimizer to
# the corresponding parameter and drop that gradient entirely! This removes the need
# for holding onto a big buffer of gradients until the optimizer step.
#
- # So how can we unlock the behavior of applying the optimizer more eagerly? Well,
- # in 2.1, we've added a new API [`Tensor.register_post_accumulate_grad_hook(hook)`](
- # https://pytorch.org/docs/main/generated/torch.Tensor.register_post_accumulate_grad_hook.html#torch.Tensor.register_post_accumulate_grad_hook)
- # that would allow us to add a hook onto a Tensor once its `.grad` field has been
+ # So how can we unlock the behavior of applying the optimizer more eagerly? In our 2.1
+ # release, we've added a new API :func:`torch.Tensor.register_post_accumulate_grad_hook`
+ # that would allow us to add a hook onto a Tensor once its ``.grad`` field has been
# accumulated. We will encapsulate the optimizer step into this hook. How?
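#
# (Here is a minimal sketch of the raw API on its own; the tensor ``t`` and the
# printing hook are illustrative, not part of this tutorial's code:)
#
# .. code-block:: python
#
#    t = torch.randn(3, requires_grad=True)
#
#    def print_hook(tensor) -> None:
#        # runs once tensor.grad has been fully accumulated during backward
#        print(f"grad accumulated, norm = {tensor.grad.norm().item():.3f}")
#
#    handle = t.register_post_accumulate_grad_hook(print_hook)
#    t.sum().backward()   # triggers the hook
#    handle.remove()      # detach the hook once it is no longer needed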
#
# How everything fits together in 10 lines
# """"""""""""""""""""""""""""""""""""""""
# Remember our model and optimizer setup from the beginning? I'll leave them commented
# out below so we don't spend resources rerunning the code.
+ #
+ # .. code-block:: python
+ #
+ #    model = models.vit_l_16(weights='DEFAULT').cuda()
+ #    optimizer = torch.optim.Adam(model.parameters())

- # model = models.vit_l_16(weights='DEFAULT').cuda()
- # optimizer = torch.optim.Adam(model.parameters())
-
- # Instead of having just *one* optimizer, we will have a Dict of optimizers
+ # Instead of having just *one* optimizer, we will have a ``dict`` of optimizers
# for every parameter so we could reference them in our hook.
optimizer_dict = {p: torch.optim.Adam([p]) for p in model.parameters()}

- # Define our hook, which will call the optimizer `step()` and `zero_grad()`
+ # Define our hook, which will call the optimizer ``step()`` and ``zero_grad()``
def optimizer_hook(parameter) -> None:
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad()
@@ -177,7 +179,7 @@ def optimizer_hook(parameter) -> None:
for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

- # Now remember our previous `train()` function? Since the optimizer has been
+ # Now remember our previous ``train()`` function? Since the optimizer has been
# fused into the backward, we can remove the optimizer step and zero_grad calls.
def train(model):
    # create our fake image input: tensor shape is batch_size, channels, height, width
@@ -192,20 +194,21 @@ def train(model):
    # optimizer.zero_grad()

########################################################################
- # I believe that was about 10 lines of changes in our sample model. I do
- # recognize that it could be a fairly intrusive change to switch out the
- # optimizer for a optimizer dictionary, especially for those who use
- # `LRScheduler`s or manipulate optimizer configuration throughout the
+ # That took about 10 lines of changes in our sample model, which is neat.
+ # However, for real models, it could be a fairly intrusive change to switch
+ # out the optimizer for an optimizer dictionary, especially for those who use
+ # ``LRScheduler``\ s or manipulate optimizer configuration throughout the
# training epochs. Working out this API with those changes will be more
- # involved and likely requires moving more configuration into global
+ # involved and will likely require moving more configuration into global
# state but should not be impossible. That said, a next step for us is
# to make this API easier to adopt with LRSchedulers and other features
# you are already used to.
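#
# (If you want to experiment in that direction today, here is one possible sketch.
# It is purely illustrative, not an API this tutorial provides: mirror the
# per-parameter optimizers with per-parameter schedulers.)
#
# .. code-block:: python
#
#    scheduler_dict = {
#        p: torch.optim.lr_scheduler.StepLR(optimizer_dict[p], step_size=1000)
#        for p in model.parameters()
#    }
#
#    def optimizer_hook(parameter) -> None:
#        optimizer_dict[parameter].step()
#        optimizer_dict[parameter].zero_grad()
#        # note: this steps the scheduler every iteration rather than every epoch
#        scheduler_dict[parameter].step()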
#
# But let me get back to convincing you that this technique is worth it.
# We will consult our friend, the memory snapshot.

- # del optimizer memory from before to get a clean slate
+ # del optimizer memory from before to get a clean slate for the next
+ # memory snapshot
del optimizer

# tell CUDA to start recording memory allocations
@@ -230,21 +233,21 @@ def train(model):
#    :alt: snapshot.png loaded into CUDA Memory Visualizer
#
# Several major observations:
- #  1. There is no more optimizer step! Right...we fused that into the backward.
- #  2. Likewise, the backward drags longer and there are more random allocations
- #     for intermediates. This is expected, as the optimizer step requires
- #     intermediates.
- #  3. Most importantly! The peak memory is lower! It is now ~4GB (which I
- #     hope is close to your answer earlier :).)
+ #  1. There is no more optimizer step! Right...we fused that into the backward.
+ #  2. Likewise, the backward drags longer and there are more random allocations
+ #     for intermediates. This is expected, as the optimizer step requires
+ #     intermediates.
+ #  3. Most importantly! The peak memory is lower! It is now ~4GB (which I
+ #     hope maps closely to your earlier expectation).
#
# Note that there is no longer any big chunk of memory allocated for the gradients
# compared to before, accounting for ~1.2GB of memory savings. Instead, we've freed
# each gradient very quickly after they've been computed by moving the optimizer
- # step as far ahead as we can. Woohoo! By the way, the other ~1.2GB of memory savings
+ # step as far ahead as we can. Woo-hoo! By the way, the other ~1.2GB of memory savings
# comes from breaking apart the optimizer into per-parameter optimizers, so the
- # intermediates have proportionally shrunk. I'd like to stress this detail is less
- # important than the gradient memory savings, as you can get optimizer intermediates
- # savings from just turning `foreach=False` without this technique.
+ # intermediates have proportionally shrunk. This detail is `less important` than
+ # the gradient memory savings, as you can get optimizer intermediates savings
+ # from just setting ``foreach=False`` without this technique.
#
# You may be correctly wondering: if we saved 2.4GB of memory, why is the peak memory
# NOT 6GB - 2.4GB = 3.6GB? Well, the peak has moved! The peak is now near the start
@@ -254,10 +257,10 @@ def train(model):
# imagine that this technique can be coupled with activations checkpointing for more
# memory wins.
#
- # Recap
- # """""
+ # Conclusion
+ # """"""""""
# In this tutorial, we learned about the memory saving technique of
# fusing the optimizer into the backward step through the new
- # `Tensor.register_post_accumulate_grad_hook()` API and *when* to apply this
+ # ``Tensor.register_post_accumulate_grad_hook()`` API and *when* to apply this
# technique (when gradients memory is significant). Along the way, we also learned
# about memory snapshots, which are generally useful in memory optimization.