Hello there! This tutorial aims to showcase one way of reducing the
memory footprint of a training loop by reducing the memory taken by
the *gradients*. Say you have a model and you're interested in ways to
optimize memory to avoid OOMing or simply to ooze more out of your GPU.
Well, you _might_ be in luck (if gradients take up a portion of your
memory and you do not need to do gradient accumulation)! We will explore

1. What takes up memory during your training or finetuning loop,
2. How to capture and visualize memory snapshots to determine the bottleneck,
3. The new `Tensor.register_post_accumulate_grad_hook(hook)` API, and finally,
4. How everything fits together in 10 lines to achieve memory savings
The ingredients and tools required:
# We are about to look at some memory snapshots, so we should be prepared to
# analyze them properly. People normally consider training memory to consist of
#
#  1. Model parameters (size P)
#  2. Activations (size A)
#  3. Gradients, which are the same size as the model parameters, so size G = P
#  4. Optimizer state, which is usually a relation to the model parameters. In
#     this case, Adam state requires 2x the model parameters, so size O = 2P
#  5. Intermediate tensors, which are allocated throughout the compute. We will
#     not worry about them for now as they are usually small and ephemeral.
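#
# To make these sizes concrete for the model used in this tutorial, here is a
# rough sketch (assuming float32 parameters, i.e., 4 bytes per element) of P, G,
# and O from the breakdown above. Activations and intermediates depend on the
# inputs, so they are easiest to read off the memory snapshot instead.
P_gb = sum(p.numel() for p in model.parameters()) * 4 / 1e9
print(f"params ~{P_gb:.1f}GB, grads ~{P_gb:.1f}GB, Adam state ~{2 * P_gb:.1f}GB")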
#
# Capturing and visualizing memory snapshots
# """"""""""""""""""""""""""""""""""""""""""
# Let's get us a memory snapshot! As your code runs, consider what you may expect
# the CUDA memory timeline to look like.

# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history()

# train 3 steps
for _ in range(3):
    train(model, optimizer)

# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open("snapshot.pickle", "wb") as f:
    dump(s, f)

# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
###############################################################################
# Now open up the snapshot in Zach Devito's [CUDA Memory Visualizer](
# https://zdevito.github.io/assets/viz/) by dragging the snapshot.pickle file.
# Does the memory timeline match your expectations?
#
# .. figure:: /_static/img/optim_step_in_bwd/snapshot.jpg
#    :alt: snapshot.jpg loaded into CUDA Memory Visualizer
#
# The model parameters have already been loaded in memory before the training
# step, so we see a chunk of memory devoted to the weights right off the bat.
# As we start our forward, memory is allocated gradually for the activations, or
# the tensors we are saving to be able to compute gradients in the backward.
# Once we start the backward, the activations are gradually freed while memory
# of the gradients starts building up.
#
# Lastly, as the optimizer kicks in, its state will be lazily initialized, so we
# should see the optimizer state memory gradually increase during the optimizer
# step of the first training loop only. In future loops, the optimizer memory
# will remain and be updated in-place. The memory for the gradients is then
# freed accordingly at the end of every training loop when `zero_grad` is called.
#
# Where is the memory bottleneck in this training loop? Or, in other words,
# where is the peak memory?
#
# The peak memory usage is during the optimizer step! Note the memory then
# consists of ~1.2GB of params, ~1.2GB of gradients, and ~2.4GB=2*1.2GB of
# the optimizer state, as expected. The last ~1.2GB comes from the Adam optimizer
# requiring memory for intermediates, totaling ~6GB of peak memory.
# Technically, you can remove the need for the last 1.2GB of optimizer
# intermediates if you set `Adam(model.parameters(), foreach=False)`, which
# would trade off runtime for memory. If switching off the foreach runtime
# optimization is sufficient in memory savings for you, nice, but please
# read on if you're curious how this tutorial can help you do better!
# With the technique we will soon introduce, we will reduce peak memory by
# removing the need for the ~1.2GB of **gradients memory** as well as **optimizer
# intermediates memory**. Now do the math: what would the new peak memory be?
# The answer will be revealed in the next snapshot :D.
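#
# As an aside, that `foreach=False` change is a single line. We construct (and
# immediately discard) such an optimizer here purely for illustration; the rest
# of this tutorial keeps using the original `optimizer` defined earlier.
adam_without_foreach = torch.optim.Adam(model.parameters(), foreach=False)
del adam_without_foreach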
#
# DISCLAIMER: yes, it is time for the disclaimer.
#
# Is this technique applicable to you? Only if gradients take up sizable memory.
# """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# The technique of fusing the optimizer step into the backward only targets
# reducing *gradient* memory (and, as a side effect, also optimizer intermediates
# memory). In our example above, the gradients eat up 20% of the memory pie, which
# is quite sizable!
#
# This may not be the case for you: if your weights are already tiny (say, due
# to applying LoRA), then the gradients do not take much space in your training
# loop and the wins are way less exciting. In that case, you should first try
# other techniques like activation checkpointing, distributed training,
# quantization, or reducing the batch size. Then, when the gradients are part of
# the bottleneck again, come back to this tutorial!
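#
# A quick way to gauge whether gradients are worth targeting in *your* setup is
# to count only the trainable parameters (with LoRA, for instance, most parameters
# are frozen and contribute no gradients). A rough sketch, assuming float32
# gradients at 4 bytes per element:
trainable_grad_gb = sum(p.numel() for p in model.parameters() if p.requires_grad) * 4 / 1e9
print(f"estimated gradient memory: {trainable_grad_gb:.2f} GB")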
#
# Still here? Cool, let's introduce our new `register_post_accumulate_grad_hook(hook)`
# API on Tensor.
#
# `Tensor.register_post_accumulate_grad_hook(hook)` API and our technique
# """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# Our technique relies on not having to save the gradients during `backward()`. Instead,
# once a gradient has been accumulated, we will immediately apply the optimizer to
# the corresponding parameter and drop that gradient entirely! This removes the need
# for holding onto a big buffer of gradients until the optimizer step.
#
# So how can we unlock the behavior of applying the optimizer more eagerly? Well,
# in 2.1, we've added a new API [`Tensor.register_post_accumulate_grad_hook(hook)`](
# https://pytorch.org/docs/main/generated/torch.Tensor.register_post_accumulate_grad_hook.html#torch.Tensor.register_post_accumulate_grad_hook)
# that allows us to add a hook onto a Tensor once its `.grad` field has been
# accumulated. We will encapsulate the optimizer step into this hook. How?
#
# How everything fits together in 10 lines
# """"""""""""""""""""""""""""""""""""""""
# Remember our model and optimizer setup from the beginning? I'll leave them commented
# out below so we don't spend resources rerunning the code.

# model = models.vit_l_16(weights='DEFAULT').cuda()
# optimizer = torch.optim.Adam(model.parameters())

# Instead of having just *one* optimizer, we will have a Dict of optimizers
# for every parameter, so we can reference them in our hook.
optimizer_dict = {p: torch.optim.Adam([p]) for p in model.parameters()}

# Define our hook, which will call the optimizer `step()` and `zero_grad()`
def optimizer_hook(parameter) -> None:
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad()

# Register the hook onto every parameter
for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

# Now remember our previous `train()` function? Since the optimizer has been
# fused into the backward, we can remove the optimizer step and zero_grad calls.
def train(model):
    # create our fake image input: tensor shape is batch_size, channels, height, width
    fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()

    # call our forward and backward
    loss = model.forward(fake_image)
    loss.sum().backward()

    # optimizer update --> no longer needed!
    # optimizer.step()
    # optimizer.zero_grad()

########################################################################
# I believe that was about 10 lines of changes in our sample model. I do
# recognize that it could be a fairly intrusive change to switch out the
# optimizer for an optimizer dictionary, especially for those who use
# `LRScheduler`s or manipulate optimizer configuration throughout the
# training epochs. Working out this API with those changes will be more
# involved and likely requires moving more configuration into global
# state, but it should not be impossible. That said, a next step for us is
# to make this API easier to adopt with LRSchedulers and other features
# you are already used to.
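#
# In the meantime, if you wanted a learning rate schedule with this setup, one
# possible workaround (a rough sketch under this tutorial's setup, not a polished
# API) is to keep a matching dict of schedulers and step them once per epoch,
# after all of the backward passes for that epoch have run:
scheduler_dict = {
    p: torch.optim.lr_scheduler.StepLR(optimizer_dict[p], step_size=1, gamma=0.9)
    for p in model.parameters()
}

def step_schedulers() -> None:
    # call once per epoch; each per-parameter optimizer gets its LR updated
    for scheduler in scheduler_dict.values():
        scheduler.step()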
#
# But let me get back to convincing you that this technique is worth it.
# We will consult our friend, the memory snapshot.

# delete optimizer memory from before to get a clean slate for the next
# memory snapshot
del optimizer

# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history()

# train 3 steps. note that we no longer pass the optimizer into train()
for _ in range(3):
    train(model)

# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open("snapshot-opt-in-bwd.pickle", "wb") as f:
    dump(s, f)

# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
###############################################################################
# Yes, take some time to drag your snapshot into the CUDA Memory Visualizer.
#
# .. figure:: /_static/img/optim_step_in_bwd/snapshot_opt_in_bwd.jpg
#    :alt: snapshot_opt_in_bwd.jpg loaded into CUDA Memory Visualizer
#
# Several major observations:
#
# 1. There is no more optimizer step! Right...we fused that into the backward.
# 2. Likewise, the backward drags longer and there are more random allocations
#    for intermediates. This is expected, as the optimizer step requires
#    intermediates.
# 3. Most importantly! The peak memory is lower! It is now ~4GB (which I
#    hope is close to your answer from earlier!).
#
# Note that there is no longer any big chunk of memory allocated for the gradients
# compared to before, accounting for ~1.2GB of memory savings. Instead, we've freed
# each gradient very quickly after it has been computed by moving the optimizer
# step as far ahead as we can. Woohoo! By the way, the other ~1.2GB of memory savings
# comes from breaking apart the optimizer into per-parameter optimizers, so the
# intermediates have proportionally shrunk. I'd like to stress that this detail is
# less important than the gradient memory savings, as you can get the optimizer
# intermediates savings from just setting `foreach=False` without this technique.
#
# You may be correctly wondering: if we saved 2.4GB of memory, why is the peak memory
# NOT 6GB - 2.4GB = 3.6GB? Well, the peak has moved! The peak is now near the start
# of the backward step, when we still have activations in memory, whereas before, the
# peak was during the optimizer step, when the activations had already been freed. The
# ~0.4GB difference between ~4.0GB and ~3.6GB is thus due to the activations memory.
# One can then imagine that this technique can be coupled with activation checkpointing
# for more memory wins.
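#
# As a rough illustration of that coupling (shown commented out, with a
# hypothetical `submodule` standing in for whichever part of your model is
# activation-heavy), activation checkpointing recomputes activations during
# the backward pass instead of keeping them alive until then:
#
# from torch.utils.checkpoint import checkpoint
#
# def forward_with_checkpointing(submodule, inputs):
#     # activations inside `submodule` are recomputed in backward, not stored
#     return checkpoint(submodule, inputs, use_reentrant=False)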
#
# Recap
# """""
# In this tutorial, we learned about the memory saving technique of
# fusing the optimizer into the backward step through the new
# `Tensor.register_post_accumulate_grad_hook()` API and *when* to apply this
# technique (when gradients memory is significant). Along the way, we also learned
# about memory snapshots, which are generally useful in memory optimization.