
Commit 202ba9a

Fuse optimizer into backward tutorial (#2568)
* Add Optimizer step in backward tutorial
1 parent b69a12a commit 202ba9a

5 files changed: 280 additions & 2 deletions
Binary image files added for the tutorial figures (719 KB): _static/img/optim_step_in_bwd/snapshot.jpg, _static/img/optim_step_in_bwd/snapshot_opt_in_bwd.jpg (previews not shown)

en-wordlist.txt

Lines changed: 4 additions & 1 deletion
@@ -129,7 +129,8 @@ LeNet
 LeakyReLU
 LeakyReLUs
 Lipschitz
-logits
+LoRa
+LRSchedulers
 Lua
 Luong
 MLP
@@ -206,6 +207,7 @@ Unescape
 VGG
 VQA
 VS Code
+Woohoo
 Wikitext
 Xeon
 Xcode
@@ -329,6 +331,7 @@ labelled
 learnable
 learnings
 loadFilename
+logits
 manualSeed
 matmul
 matplotlib

index.rst

Lines changed: 8 additions & 1 deletion
@@ -511,7 +511,7 @@ What's new in PyTorch tutorials?
 
 .. customcarditem::
    :header: Parametrizations Tutorial
-   :card_description: Learn how to use torch.nn.utils.parametrize to put constriants on your parameters (e.g. make them orthogonal, symmetric positive definite, low-rank...)
+   :card_description: Learn how to use torch.nn.utils.parametrize to put constraints on your parameters (e.g. make them orthogonal, symmetric positive definite, low-rank...)
    :image: _static/img/thumbnails/cropped/parametrizations.png
    :link: intermediate/parametrizations.html
    :tags: Model-Optimization,Best-Practice
@@ -523,6 +523,13 @@ What's new in PyTorch tutorials?
    :link: intermediate/pruning_tutorial.html
    :tags: Model-Optimization,Best-Practice
 
+.. customcarditem::
+   :header: How to save memory by fusing the optimizer step into the backward pass
+   :card_description: Learn a memory-saving technique through fusing the optimizer step into the backward pass using memory snapshots.
+   :image: _static/img/thumbnails/cropped/pytorch-logo.png
+   :link: intermediate/optimizer_step_in_backward_tutorial.html
+   :tags: Model-Optimization,Best-Practice,CUDA,Frontend-APIs
+
 .. customcarditem::
    :header: (beta) Dynamic Quantization on an LSTM Word Language Model
    :card_description: Apply dynamic quantization, the easiest form of quantization, to a LSTM-based next word prediction model.
intermediate/optimizer_step_in_backward_tutorial.py

Lines changed: 268 additions & 0 deletions
@@ -0,0 +1,268 @@
"""

How to save memory by fusing the optimizer step into the backward pass
======================================================================

Hello there! This tutorial aims to showcase one way of reducing the
memory footprint of a training loop by reducing the memory taken by
the *gradients*. Say you have a model and you're interested in ways to
optimize memory to avoid ``Out of Memory`` (OOM) errors or simply to
squeeze more out of your GPU. Well, you *might* be in luck (if gradients
take up a portion of your memory and you do not need to do gradient
accumulation). We will explore the following:

1. What takes up memory during your training or finetuning loop,
2. How to capture and visualize memory snapshots to determine the bottleneck,
3. The new ``Tensor.register_post_accumulate_grad_hook(hook)`` API, and finally,
4. How everything fits together in 10 lines to achieve memory savings.

To run this tutorial, you will need:

* PyTorch 2.1.0 or newer with ``torchvision``
* 1 CUDA GPU if you'd like to run the memory visualizations locally.
  Otherwise, the technique would yield similar benefits on any device.

Let us start by importing the required modules and models. We will use a
vision transformer model from torchvision, but feel free to substitute
it with your own model. We will also use ``torch.optim.Adam`` as our
optimizer, but, again, feel free to substitute it with your own optimizer.

"""

import torch
from torchvision import models
from pickle import dump

model = models.vit_l_16(weights='DEFAULT').cuda()
optimizer = torch.optim.Adam(model.parameters())

###############################################################################
# Now let's define our typical training loop. You should use real images when
# training, but for the purposes of this tutorial, we are passing in fake
# inputs and not worrying about loading any actual data.

IMAGE_SIZE = 224

def train(model, optimizer):
    # create our fake image input: tensor shape is batch_size, channels, height, width
    fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()

    # call our forward and backward
    loss = model.forward(fake_image)
    loss.sum().backward()

    # optimizer update
    optimizer.step()
    optimizer.zero_grad()

###############################################################################
# Memory usage during training
# """"""""""""""""""""""""""""
# We are about to look at some memory snapshots, so we should be prepared to
# analyze them properly. Typically, training memory consists of:
#
# * Model parameters (size P)
# * Activations that are saved for the backward pass (size A)
# * Gradients, which are the same size as the model parameters, so size G = P.
# * Optimizer state, which is proportional to the size of the parameters. In
#   this case, the state for Adam requires 2x the model parameters, so size O = 2P.
# * Intermediate tensors, which are allocated throughout the compute. We will
#   not worry about them for now as they are usually small and ephemeral.
#
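# As a rough back-of-the-envelope check of the sizes above, we can estimate P,
# G, and O directly from the model by summing tensor sizes (a quick sketch;
# the snapshots below remain the authoritative measurement). For ``vit_l_16``,
# P and G should each come out to roughly 1.2GB, and O to roughly 2.4GB.

param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"parameters (P): ~{param_bytes / 1e9:.2f} GB")
print(f"gradients  (G): ~{param_bytes / 1e9:.2f} GB")      # G = P
print(f"Adam state (O): ~{2 * param_bytes / 1e9:.2f} GB")  # exp_avg + exp_avg_sq, so O = 2P

###############################################################################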
# Capturing and visualizing memory snapshots
# """"""""""""""""""""""""""""""""""""""""""
# Let's get us a memory snapshot! As your code runs, consider what you may expect
# the CUDA memory timeline to look like.

# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history(enabled='all')

# train 3 steps
for _ in range(3):
    train(model, optimizer)

# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open("snapshot.pickle", "wb") as f:
    dump(s, f)

# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
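
###############################################################################
# As a convenience, recent PyTorch versions can also write the pickle in one
# call (a sketch; the underscore-prefixed memory APIs are private and may
# change between releases):
#
# .. code-block:: python
#
#    torch.cuda.memory._dump_snapshot("snapshot.pickle")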

###############################################################################
# Now open up the snapshot in the CUDA Memory Visualizer at
# https://pytorch.org/memory_viz by dragging and dropping the
# ``snapshot.pickle`` file. Does the memory timeline match your expectations?
#
# .. figure:: /_static/img/optim_step_in_bwd/snapshot.jpg
#    :alt: snapshot.jpg loaded into CUDA Memory Visualizer
#
# The model parameters have already been loaded in memory before the training
# step, so we see a chunk of memory devoted to the weights right off the bat.
# As we start our forward pass, memory is allocated gradually for the activations,
# or the tensors we are saving to be able to compute gradients in the backward pass.
# Once we start the backward pass, the activations are gradually freed while the
# memory taken by the gradients starts building up.
#
# Lastly, as the optimizer kicks in, its state will be lazily initialized, so we
# should see the optimizer state memory gradually increase during the optimizer
# step of the first training loop only. In future loops, the optimizer memory
# will remain and be updated in-place. The memory for the gradients is then
# freed accordingly at the end of every training loop when ``zero_grad`` is called.
#
# Where is the memory bottleneck in this training loop? Or, in other words,
# where is the peak memory?
#
# The peak memory usage is during the optimizer step! Note that the memory then
# consists of ~1.2GB of parameters, ~1.2GB of gradients, and ~2.4GB (2 * 1.2GB) of
# the optimizer state as expected. The last ~1.2GB comes from the Adam optimizer
# requiring memory for intermediates, totaling ~6GB of peak memory.
# Technically, you can remove the need for the last 1.2GB of optimizer
# intermediates if you set ``Adam(model.parameters(), foreach=False)``, which
# would trade off runtime for memory. If switching off the ``foreach`` runtime
# optimization yields sufficient memory savings for you, great, but please
# read on if you're curious how this tutorial can help you do better!
# With the technique we will soon introduce, we will reduce peak memory by
# removing the need for the ~1.2GB of **gradients memory** as well as the **optimizer
# intermediates memory**. Now, what would you expect the new peak memory to be?
# The answer will be revealed in the `next` snapshot.
#
# DISCLAIMER: This technique is **not** for all
# """""""""""""""""""""""""""""""""""""""""""""
# Before we get too excited, we have to consider whether this technique is applicable
# for `your` use case. This is NOT a silver bullet! The technique of fusing the
# optimizer step into the backward pass only targets reducing *gradient* memory (and,
# as a side effect, also optimizer intermediates memory). Thus, the more sizable the
# memory taken up by the gradients, the greater the potential memory reduction.
# In our example above, the gradients eat up 20% of the memory pie, which is
# quite sizable!
#
# This may not be the case for you if, for example, your weights are already tiny
# (say, due to applying LoRa); then the gradients do not take much space in your
# training loop and the wins are way less exciting. In that case, you should
# first try other techniques like activations checkpointing, distributed
# training, quantization, or reducing the batch size. Then, when the gradients
# are part of the bottleneck again, come back to this tutorial!
#
# Still here? Cool, let's introduce our new ``register_post_accumulate_grad_hook(hook)``
# API on Tensor.
#
# ``Tensor.register_post_accumulate_grad_hook(hook)`` API and our technique
# """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# Our technique relies on not having to save the gradients during ``backward()``. Instead,
# once a gradient has been accumulated, we will immediately apply the optimizer to
# the corresponding parameter and drop that gradient entirely! This removes the need
# for holding onto a big buffer of gradients until the optimizer step.
#
# So how can we unlock the behavior of applying the optimizer more eagerly? In our 2.1
# release, we've added a new API :func:`torch.Tensor.register_post_accumulate_grad_hook`
# that allows us to add a hook onto a Tensor once its ``.grad`` field has been
# accumulated. We will encapsulate the optimizer step into this hook. How?
#
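# To get a feel for the API on its own before wiring in the optimizers, here is
# a minimal sketch (the tensor ``t`` and the ``print`` are illustrative only):
# the hook fires once per leaf tensor, right after that tensor's ``.grad`` has
# been accumulated during ``backward()``.

t = torch.rand(3, requires_grad=True)
t.register_post_accumulate_grad_hook(lambda param: print("accumulated grad:", param.grad))
t.sum().backward()  # the hook prints the freshly accumulated gradient of ones

###############################################################################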
# How everything fits together in 10 lines
# """"""""""""""""""""""""""""""""""""""""
# Remember our model and optimizer setup from the beginning? I'll leave them commented
# out below so we don't spend resources rerunning the code.
#
# .. code-block:: python
#
#    model = models.vit_l_16(weights='DEFAULT').cuda()
#    optimizer = torch.optim.Adam(model.parameters())

# Instead of having just *one* optimizer, we will have a ``dict`` of optimizers
# for every parameter so we can reference them in our hook.
optimizer_dict = {p: torch.optim.Adam([p], foreach=False) for p in model.parameters()}

# Define our hook, which will call the optimizer ``step()`` and ``zero_grad()``
def optimizer_hook(parameter) -> None:
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad()

# Register the hook onto every parameter
for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

# Now remember our previous ``train()`` function? Since the optimizer has been
# fused into the backward, we can remove the optimizer step and zero_grad calls.
def train(model):
    # create our fake image input: tensor shape is batch_size, channels, height, width
    fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()

    # call our forward and backward
    loss = model.forward(fake_image)
    loss.sum().backward()

    # optimizer update --> no longer needed!
    # optimizer.step()
    # optimizer.zero_grad()

########################################################################
# That took about 10 lines of changes in our sample model, which is neat.
# However, for real models, it could be a fairly intrusive change to switch
# out the optimizer for an optimizer dictionary, especially for those who use
# ``LRScheduler``\ s or manipulate optimizer configuration throughout the
# training epochs. Working out this API with those changes will be more
# involved and will likely require moving more configuration into global
# state, but it should not be impossible. That said, a next step for PyTorch
# is to make this API easier to adopt with LRSchedulers and other features
# you are already used to.
#
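# For example, one possible way (just a sketch, not an official recipe; the
# ``scheduler_dict`` and ``end_of_epoch`` names here are made up for
# illustration) to keep learning-rate scheduling working with this pattern is
# to mirror the per-parameter ``dict`` with one scheduler per optimizer:
#
# .. code-block:: python
#
#    scheduler_dict = {
#        p: torch.optim.lr_scheduler.StepLR(optimizer_dict[p], step_size=30)
#        for p in model.parameters()
#    }
#
#    def end_of_epoch():
#        for scheduler in scheduler_dict.values():
#            scheduler.step()
#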
# But let me get back to convincing you that this technique is worth it.
# We will consult our friend, the memory snapshot.

# delete optimizer memory from before to get a clean slate for the next
# memory snapshot
del optimizer

# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history(enabled='all')

# train 3 steps. note that we no longer pass the optimizer into train()
for _ in range(3):
    train(model)

# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open("snapshot-opt-in-bwd.pickle", "wb") as f:
    dump(s, f)

# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)

###############################################################################
# Yes, take some time to drag your snapshot into the CUDA Memory Visualizer.
#
# .. figure:: /_static/img/optim_step_in_bwd/snapshot_opt_in_bwd.jpg
#    :alt: snapshot_opt_in_bwd.jpg loaded into CUDA Memory Visualizer
#
# Several major observations:
#
# 1. There is no more optimizer step! Right...we fused that into the backward.
# 2. Likewise, the backward pass drags out longer and there are more random allocations
#    for intermediates. This is expected, as the optimizer step requires
#    intermediates.
# 3. Most importantly! The peak memory is lower! It is now ~4GB (which I
#    hope maps closely to your earlier expectation).
#
# Note that there is no longer any big chunk of memory allocated for the gradients
# compared to before, accounting for ~1.2GB of memory savings. Instead, each
# gradient is freed very quickly after it has been computed by moving the optimizer
# step as far ahead as we can. Woohoo! By the way, the other ~1.2GB of memory savings
# comes from breaking apart the optimizer into per-parameter optimizers, so the
# intermediates have proportionally shrunk. This detail is `less important` than
# the gradient memory savings, as you can get optimizer intermediates savings
# from just setting ``foreach=False`` without this technique.
#
# You may be correctly wondering: if we saved 2.4GB of memory, why is the peak memory
# NOT 6GB - 2.4GB = 3.6GB? Well, the peak has moved! The peak is now near the start
# of the backward step, when we still have activations in memory, whereas before, the
# peak was during the optimizer step, when the activations had been freed. The ~0.4GB
# difference between ~4.0GB and ~3.6GB is thus due to the activations memory. One can
# then imagine that this technique can be coupled with activations checkpointing for
# more memory wins.
#
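# As a cross-check that doesn't require the visualizer, the same peak number
# can be read programmatically (a sketch; reset the peak statistics right
# before the loop you want to measure so that earlier runs do not dominate):
#
# .. code-block:: python
#
#    torch.cuda.reset_peak_memory_stats()
#    for _ in range(3):
#        train(model)
#    print(f"peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
#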
# Conclusion
# """"""""""
# In this tutorial, we learned about the memory-saving technique of
# fusing the optimizer into the backward step through the new
# ``Tensor.register_post_accumulate_grad_hook()`` API and *when* to apply this
# technique (when gradients memory is significant). Along the way, we also learned
# about memory snapshots, which are generally useful in memory optimization.
