
Commit f9b012f

Finish my tutorial
1 parent 9be1e01 commit f9b012f

File tree

3 files changed: +178 -70 lines changed

intermediate_source/optimizer_step_in_backward_tutorial.py

Lines changed: 178 additions & 70 deletions
@@ -5,13 +5,14 @@
 
 Hello there! This tutorial aims to showcase one way of reducing the
 memory footprint of a training loop by reducing the memory taken by
-the gradients. Say you have a model and you're interested in ways to
+the *gradients*. Say you have a model and you're interested in ways to
 optimize memory to avoid OOMing or simply to ooze more out of your GPU.
-Well, you _might_ be in luck! We will explore
+Well, you _might_ be in luck (if gradients take up a portion of your
+memory and you do not need to do gradient accumulation)! We will explore
 
 1. What takes up memory during your training or finetuning loop,
-2. Capturing and visualizing memory snapshots to determine the memory bottleneck,
-3. The new `tensor.post_accumulate_grad_hook(hook)` API, and finally, if relevant,
+2. How to capture and visualize memory snapshots to determine the bottleneck,
+3. The new `Tensor.register_post_accumulate_grad_hook(hook)` API, and finally,
 4. How everything fits together in 10 lines to achieve memory savings
 
 The ingredients and tools required:
@@ -57,99 +58,206 @@ def train(model, optimizer):
 # We are about to look at some memory snapshots, so we should be prepared to
 # analyze them properly. People normally consider training memory to consist of
 #
-# 1. Model parameters (size P)
-# 2. Activations (size A)
-# 3. Gradients, which are the same size as the model parameters, so size G = P
-# 4. Optimizer state, which is usually a relation to the model parameters. In
-#    this case, Adam state requires 2x the model parameters, so size O = 2P
-# 5. Intermediate tensors, which are allocated throughout the compute. We will
-#    not worry about them for now as they are usually small and ephemeral.
+# 1. Model parameters (size P)
+# 2. Activations (size A)
+# 3. Gradients, which are the same size as the model parameters, so size G = P
+# 4. Optimizer state, which is usually a relation to the model parameters. In
+#    this case, Adam state requires 2x the model parameters, so size O = 2P
+# 5. Intermediate tensors, which are allocated throughout the compute. We will
+#    not worry about them for now as they are usually small and ephemeral.
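+#
+# As a quick back-of-the-envelope check (a rough sketch assuming the fp32
+# ViT-L/16 `model` from the setup above), we can compute P directly and derive
+# G and O from it:
+
+param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
+grad_bytes = param_bytes            # G = P
+adam_state_bytes = 2 * param_bytes  # O = 2P (Adam keeps exp_avg and exp_avg_sq)
+print(f"P ~ {param_bytes / 1e9:.1f}GB, G ~ {grad_bytes / 1e9:.1f}GB, "
+      f"O ~ {adam_state_bytes / 1e9:.1f}GB")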
 #
+# Capturing and visualizing memory snapshots
+# """"""""""""""""""""""""""""""""""""""""""
 # Let's get us a memory snapshot! As your code runs, consider what you may expect
 # the CUDA memory timeline to look like.
 
 # tell CUDA to start recording memory allocations
 torch.cuda.memory._record_memory_history()
 
 # train 3 steps
-train(model, optimizer)
+for _ in range(3):
+    train(model, optimizer)
 
 # save a snapshot of the memory allocations
-s = torch.cuda.memory._snapshot()
-with open(f"snapshot.pickle", "wb") as f:
-    dump(s, f)
+# s = torch.cuda.memory._snapshot()
+# with open(f"snapshot.pickle", "wb") as f:
+#     dump(s, f)
 
-raise RuntimeError("Stop here and open up the snapshot in Zach Devito's CUDA Memory Visualizer")
+# tell CUDA to stop recording memory allocations now
+torch.cuda.memory._record_memory_history(enabled=None)
 
 ###############################################################################
 # Now open up the snapshot in Zach Devito's [CUDA Memory Visualizer](
 # https://zdevito.github.io/assets/viz/) by dragging the snapshot.pickle file.
 # Does the memory timeline match your expectations?
 #
+# .. figure:: /_static/img/optim_step_in_bwd/snapshot.jpg
+#    :alt: snapshot.png loaded into CUDA Memory Visualizer
+#
 # The model parameters have already been loaded in memory before the training
-# step, so we anticipate seeing a chunk of memory devoted to the weights right
-# off the bat. As we start our forward, memory should be allocated gradually
-# for the activations, or the tensors we are saving to be able to compute gradients
-# in the backward. Once we start the backward, the activations should be gradually
-# freed while memory of the gradients start building up.
+# step, so we see a chunk of memory devoted to the weights right off the bat.
+# As we start our forward, memory is allocated gradually for the activations, or
+# the tensors we are saving to be able to compute gradients in the backward.
+# Once we start the backward, the activations are gradually freed while memory
+# of the gradients starts building up.
 #
 # Lastly, as the optimizer kicks in, its state will be lazily initialized, so we
-# should see the optimizer state memory gradually increase during the end of the
-# first training loop only. In future loops, the optimizer memory will remain and
-# be inplace updated. The memory for the gradients should be freed accordingly
-# by the end of every training loop.
+# should see the optimizer state memory gradually increase during the optimizer
+# step of the first training loop only. In future loops, the optimizer memory
+# will remain and be updated in-place. The memory for the gradients is then
+# freed accordingly at the end of every training loop when zero_grad is called.
+#
+# Where is the memory bottleneck in this training loop? Or, in other words,
+# where is the peak memory?
+#
+# The peak memory usage is during the optimizer step! Note the memory then
+# consists of ~1.2GB of params, ~1.2GB of gradients, and ~2.4GB=2*1.2GB of
+# the optimizer state, as expected. The last ~1.2GB comes from the Adam optimizer
+# requiring memory for intermediates, totalling to ~6GB of peak memory.
+# Technically, you can remove the need for the last 1.2GB of optimizer
+# intermediates if you set `Adam(model.parameters(), foreach=False)`, which
+# would trade off runtime for memory. If switching off the foreach runtime
+# optimization gives you sufficient memory savings, nice, but please
+# read on if you're curious how this tutorial can help you do better!
+# With the technique we will soon introduce, we will reduce peak memory by
+# removing the need for the ~1.2GB of **gradients memory** as well as **optimizer
+# intermediates memory**. Now do the math: what would the new peak memory be?
+# The answer will be revealed in the next snapshot :D.
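+#
+# Before moving on, as a quick cross-check of the ~6GB figure above (a small
+# sketch; exact numbers vary with hardware and library versions), you can ask
+# the allocator for the peak it has observed so far:
+
+print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1e9:.1f}GB")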
+#
+# DISCLAIMER. Yes, it is time for the disclaimer:
+#
+# Is this technique applicable to you? Only if gradients take up sizable memory.
+# """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+# The technique of fusing the optimizer step into the backward only targets
+# reducing *gradient* memory (and as a side effect also optimizer intermediates
+# memory). In our example above, the gradients eat up 20% of the memory pie, which
+# is quite sizable!
+#
+# This may not be the case for you: for example, if your weights are already tiny
+# (say, due to applying LoRA), then the gradients do not take much space in your
+# training loop and the wins are way less exciting. In that case, you should
+# first try other techniques like activations checkpointing, distributed
+# training, quantization, or reducing the batch size. Then, when the gradients
+# are part of the bottleneck again, come back to this tutorial!
+#
+# Still here? Cool, let's introduce our new `register_post_accumulate_grad_hook(hook)`
+# API on Tensor.
+#
+# `Tensor.register_post_accumulate_grad_hook(hook)` API and our technique
+# """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+# Our technique relies on not having to save the gradients during `backward()`. Instead,
+# once a gradient has been accumulated, we will immediately apply the optimizer to
+# the corresponding parameter and drop that gradient entirely! This removes the need
+# for holding onto a big buffer of gradients until the optimizer step.
+#
+# So how can we unlock the behavior of applying the optimizer more eagerly? Well,
+# in PyTorch 2.1, we've added a new API [`Tensor.register_post_accumulate_grad_hook(hook)`](
+# https://pytorch.org/docs/main/generated/torch.Tensor.register_post_accumulate_grad_hook.html#torch.Tensor.register_post_accumulate_grad_hook)
+# that allows us to add a hook onto a Tensor once its `.grad` field has been
+# accumulated. We will encapsulate the optimizer step into this hook. How?
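+#
+# To get a feel for the API in isolation first (a tiny sketch, separate from
+# the training loop below): the hook receives the tensor itself and fires only
+# after that tensor's `.grad` has been fully accumulated.
+
+toy_weight = torch.randn(3, requires_grad=True)
+
+def print_grad_hook(tensor) -> None:
+    # by the time the hook runs, tensor.grad holds the freshly accumulated gradient
+    print(f"grad accumulated, norm = {tensor.grad.norm():.3f}")
+
+toy_handle = toy_weight.register_post_accumulate_grad_hook(print_grad_hook)
+(toy_weight * 2).sum().backward()  # triggers the hook once .grad is ready
+toy_handle.remove()  # detach the toy hook so it does not linger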
+#
+# How everything fits together in 10 lines
+# """"""""""""""""""""""""""""""""""""""""
+# Remember our model and optimizer setup from the beginning? I'll leave them commented
+# out below so we don't spend resources rerunning the code.
 
-m = SomeModule()
+# model = models.vit_l_16(weights='DEFAULT').cuda()
+# optimizer = torch.optim.Adam(model.parameters())
 
-###############################################################################
-# This allocates memory for all parameters/buffers and initializes them per
-# the default initialization schemes defined in `SomeModule.__init__()`, which
-# is wasteful when we want to load a checkpoint as
-# 1. We are running the initialization kernels where the results are
-#    immediately overwritten by `load_state_dict()`
-# 2. We are allocating memory for these parameters/buffers in RAM while
-#    `torch.load` of the saved state dictionary also allocates memory for
-#    the parameters/buffers in the checkpoint.
-#
-# In order to solve these two problems, we can use the `torch.device()`
-# context manager with `device='meta'` when we instantiate the `nn.Module()`.
+# Instead of having just *one* optimizer, we will have a Dict of optimizers
+# for every parameter so we could reference them in our hook.
+optimizer_dict = {p: torch.optim.Adam([p]) for p in model.parameters()}
 
-with torch.device('meta'):
-    meta_m = SomeModule()
+# Define our hook, which will call the optimizer `step()` and `zero_grad()`
+def optimizer_hook(parameter) -> None:
+    optimizer_dict[parameter].step()
+    optimizer_dict[parameter].zero_grad()
 
-###############################################################################
-# The [`torch.device()`](https://pytorch.org/docs/main/tensor_attributes.html#torch-device)
-# context manager makes sure that factory calls will be performed as if they
-# were passed device as an argument. However, it does not affect factory
-# function calls which are called with an explicit device argument.
-#
-# Tensors on the `meta` device do not carry data. However, they possess all
-# other metadata a tensor carries such as `.size()` and `.stride()`,
-# `.requires_grad` etc.
-#
-# Next, we consider the loading of the state dictionary.
+# Register the hook onto every parameter
+for p in model.parameters():
+    p.register_post_accumulate_grad_hook(optimizer_hook)
 
-m.load_state_dict(state_dict)
+# Now remember our previous `train()` function? Since the optimizer has been
+# fused into the backward, we can remove the optimizer step and zero_grad calls.
+def train(model):
+    # create our fake image input: tensor shape is batch_size, channels, height, width
+    fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()
 
-###############################################################################
-# `nn.Module.load_state_dict()` is usually implemented via an in-place
-# `param_in_model.copy_(param_in_state_dict)` (i.e. a copy from the
-# parameter/buffer with the corresponding key in the state dictionary into
-# the parameters/buffers in the `nn.Module`).
-#
-# However, an in-place copy into a tensor on the `meta` device is a no-op.
-# In order to avoid this, we can pass the `assign=True` keyword argument to
-# `load_state_dict()`.
+    # call our forward and backward
+    loss = model.forward(fake_image)
+    loss.sum().backward()
 
-meta_m.load_state_dict(state_dict, assign=True)
+    # optimizer update --> no longer needed!
+    # optimizer.step()
+    # optimizer.zero_grad()
 
-###############################################################################
-# Another caveat here is that since optimizers hold a reference to
-# `nn.Module.parameters()`, the optimizer must be initialized after the module
-# is loaded from state dict if `assign=True` is passed.
+########################################################################
+# I believe that was about 10 lines of changes in our sample model. I do
+# recognize that it could be a fairly intrusive change to switch out the
+# optimizer for an optimizer dictionary, especially for those who use
+# `LRScheduler`s or manipulate optimizer configuration throughout the
+# training epochs. Working out this API with those changes will be more
+# involved and likely requires moving more configuration into global
+# state, but it should not be impossible. That said, a next step for us is
+# to make this API easier to adopt with LRSchedulers and other features
+# you are already used to.
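+#
+# If you want to experiment in that direction today, one rough sketch (an
+# illustration only, with an arbitrary `step_size`) is to keep a per-parameter
+# scheduler dict next to `optimizer_dict` and step it from the same hook. We
+# define the variant below but do not register it for the rest of this tutorial.
+
+scheduler_dict = {p: torch.optim.lr_scheduler.StepLR(optimizer_dict[p], step_size=100)
+                  for p in model.parameters()}
+
+def optimizer_hook_with_lr(parameter) -> None:
+    optimizer_dict[parameter].step()
+    scheduler_dict[parameter].step()  # advance this parameter's LR schedule too
+    optimizer_dict[parameter].zero_grad()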
+#
+# But let me get back to convincing you that this technique is worth it.
+# We will consult our friend, the memory snapshot.
+
+# del optimizer memory from before to get a clean slate
+del optimizer
+
+# tell CUDA to start recording memory allocations
+torch.cuda.memory._record_memory_history()
+
+# train 3 steps. note that we no longer pass the optimizer into train()
+for _ in range(3):
+    train(model)
+
+# save a snapshot of the memory allocations
+s = torch.cuda.memory._snapshot()
+with open(f"snapshot-opt-in-bwd.pickle", "wb") as f:
+    dump(s, f)
+
+# tell CUDA to stop recording memory allocations now
+torch.cuda.memory._record_memory_history(enabled=None)
 
 ###############################################################################
-# To recap, in this tutorial, we learned about `torch.load(mmap=True)`, the
-# `torch.device()` context manager with `device=meta` and the
-# `nn.Module.load_state_dict(assign=True)` and how these tools could be used
-# to aid when loading a model from a checkpoint.
+# Yes, take some time to drag your snapshot into the CUDA Memory Visualizer.
+#
+# .. figure:: /_static/img/optim_step_in_bwd/snapshot_opt_in_bwd.jpg
+#    :alt: snapshot.png loaded into CUDA Memory Visualizer
+#
+# Several major observations:
+#
+# 1. There is no more optimizer step! Right...we fused that into the backward.
+# 2. Likewise, the backward drags longer and there are more random allocations
+#    for intermediates. This is expected, as the optimizer step requires
+#    intermediates.
+# 3. Most importantly! The peak memory is lower! It is now ~4GB (which I
+#    hope is close to your answer from earlier :) ).
+#
+# Note that there is no longer any big chunk of memory allocated for the gradients
+# compared to before, accounting for ~1.2GB of memory savings. Instead, we've freed
+# each gradient very quickly after it has been computed by moving the optimizer
+# step as far ahead as we can. Woohoo! By the way, the other ~1.2GB of memory savings
+# comes from breaking apart the optimizer into per-parameter optimizers, so the
+# intermediates have proportionally shrunk. I'd like to stress that this detail is less
+# important than the gradient memory savings, as you can get the optimizer intermediates
+# savings by just setting `foreach=False`, without this technique.
+#
+# You may be correctly wondering: if we saved 2.4GB of memory, why is the peak memory
+# NOT 6GB - 2.4GB = 3.6GB? Well, the peak has moved! The peak is now near the start
+# of the backward step, when we still have activations in memory, whereas before, the peak
+# was during the optimizer step, when the activations had been freed. The ~0.4GB difference
+# between ~4.0GB and ~3.6GB is thus due to the activations memory. One can then
+# imagine that this technique can be coupled with activations checkpointing for more
+# memory wins.
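+#
+# As a rough illustration of that last idea (a toy sketch, not applied to the
+# ViT above): `torch.utils.checkpoint` trades compute for memory by recomputing
+# activations during the backward instead of storing them.
+
+from torch.utils.checkpoint import checkpoint
+
+toy_block = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.GELU()).cuda()
+toy_input = torch.randn(8, 1024, device="cuda")
+# activations inside toy_block are recomputed in the backward rather than saved
+toy_out = checkpoint(toy_block, toy_input, use_reentrant=False)
+toy_out.sum().backward()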
+#
+# Recap
+# """""
+# In this tutorial, we learned about the memory saving technique of
+# fusing the optimizer into the backward step through the new
+# `Tensor.register_post_accumulate_grad_hook()` API and *when* to apply this
+# technique (when gradients memory is significant). Along the way, we also learned
+# about memory snapshots, which are generally useful in memory optimization.
