Commit 4ba0518

Rephrase and add to index.rst
1 parent 770ffbc commit 4ba0518

File tree: 3 files changed (+90, -78 lines)

- en-wordlist.txt
- index.rst
- intermediate_source/optimizer_step_in_backward_tutorial.py


en-wordlist.txt

Lines changed: 2 additions & 0 deletions
@@ -75,6 +75,7 @@ DataLoaders
 DeepMind
 DeiT
 DenseNet
+Devito
 EOS
 EPS
 Ecker
@@ -209,6 +210,7 @@ VS Code
 Wikitext
 Xeon
 Xcode
+Zach
 accuracies
 activations
 adversarially

index.rst

Lines changed: 8 additions & 1 deletion
@@ -504,7 +504,7 @@ What's new in PyTorch tutorials?

 .. customcarditem::
    :header: Parametrizations Tutorial
-   :card_description: Learn how to use torch.nn.utils.parametrize to put constriants on your parameters (e.g. make them orthogonal, symmetric positive definite, low-rank...)
+   :card_description: Learn how to use torch.nn.utils.parametrize to put constraints on your parameters (e.g. make them orthogonal, symmetric positive definite, low-rank...)
    :image: _static/img/thumbnails/cropped/parametrizations.png
    :link: intermediate/parametrizations.html
    :tags: Model-Optimization,Best-Practice
@@ -516,6 +516,13 @@ What's new in PyTorch tutorials?
    :link: intermediate/pruning_tutorial.html
    :tags: Model-Optimization,Best-Practice

+.. customcarditem::
+   :header: How to save memory by fusing the optimizer step into the backward pass
+   :card_description: Learn a memory-saving technique through fusing the optimizer step into the backward pass using memory snapshots.
+   :image: _static/img/thumbnails/cropped/pytorch-logo.png
+   :link: intermediate/optimizer_step_in_backward_tutorial.html
+   :tags: Model-Optimization,Best-Practice
+
 .. customcarditem::
    :header: (beta) Dynamic Quantization on an LSTM Word Language Model
    :card_description: Apply dynamic quantization, the easiest form of quantization, to a LSTM-based next word prediction model.

intermediate_source/optimizer_step_in_backward_tutorial.py

Lines changed: 80 additions & 77 deletions
@@ -6,23 +6,24 @@
 Hello there! This tutorial aims to showcase one way of reducing the
 memory footprint of a training loop by reducing the memory taken by
 the *gradients*. Say you have a model and you're interested in ways to
-optimize memory to avoid OOMing or simply to ooze more out of your GPU.
-Well, you _might_ be in luck (if gradients take up a portion of your
-memory and you do not need to do gradient accumulation)! We will explore
+optimize memory to avoid ``Out of Memory`` (OOM) errors or simply to ooze
+more out of your GPU. Well, you _might_ be in luck (if gradients take up
+a portion of your memory and you do not need to do gradient accumulation).
+We will explore the following:

 1. What takes up memory during your training or finetuning loop,
 2. How to capture and visualize memory snapshots to determine the bottleneck,
-3. The new `Tensor.register_post_accumulate_grad_hook(hook)` API, and finally,
-4. How everything fits together in 10 lines to achieve memory savings
+3. The new ``Tensor.register_post_accumulate_grad_hook(hook)`` API, and finally,
+4. How everything fits together in 10 lines to achieve memory savings.

-The ingredients and tools required:
-1. PyTorch 2.1.0 or newer with torchvision
-2. A CUDA GPU
+To run this tutorial, you will need:
+* PyTorch 2.1.0 or newer with ``torchvision``
+* 1 CUDA GPU

 Let us start by importing the required modules and models. We will use a
-vision transformer model from torchvision, but feel free to substitute with
-your own model. We will also use `torch.optim.Adam` as our optimizer, but,
-again, feel free to substitute with your own optimizer.
+vision transformer model from torchvision, but feel free to substitute
+with your own model. We will also use ``torch.optim.Adam`` as our optimizer,
+but, again, feel free to substitute with your own optimizer.

 """

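For orientation while reading the hunks that follow: the setup the intro above refers to (and that a later hunk quotes in a ``code-block``) boils down to a few lines. A minimal sketch, assuming ``dump`` is imported from ``pickle`` as the snapshot-saving code further down implies:

    import torch
    from pickle import dump          # later used as dump(s, f) to save the snapshot
    from torchvision import models

    # the tutorial's example model and optimizer; substitute your own if you like
    model = models.vit_l_16(weights='DEFAULT').cuda()
    optimizer = torch.optim.Adam(model.parameters())
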
@@ -36,7 +37,7 @@
 ###############################################################################
 # Now let's define our typical training loop. You should use real images when
 # training, but for the purposes of this tutorial, we are passing in fake
-# inputs and not worrying about loading actual data.
+# inputs and not worrying about loading any actual data.

 IMAGE_SIZE = 224

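The fake inputs this hunk mentions can simply be random tensors of the right shape; a small sketch (the batch size of 8 is an arbitrary illustrative choice, not from the commit):

    # random "images": batch_size x channels x height x width, created on the GPU
    fake_image = torch.rand(8, 3, IMAGE_SIZE, IMAGE_SIZE, device="cuda")
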
@@ -53,18 +54,18 @@ def train(model, optimizer):
     optimizer.zero_grad()

 ###############################################################################
-# So what comprises the memory usage during training?
-# """""""""""""""""""""""""""""""""""""""""""""""""""
+# Memory usage during training
+# """"""""""""""""""""""""""""
 # We are about to look at some memory snapshots, so we should be prepared to
-# analyze them properly. People normally consider training memory to consist of
+# analyze them properly. Typically, training memory consists of:
 #
-# 1. Model parameters (size P)
-# 2. Activations (size A)
-# 3. Gradients, which are the same size as the model parameters, so size G = P
-# 4. Optimizer state, which is usually a relation to the model parameters. In
-#    this case, Adam state requires 2x the model parameters, so size O = 2P
-# 5. Intermediate tensors, which are allocated throughout the compute. We will
-#    not worry about them for now as they are usually small and ephemeral.
+# * Model parameters (size P)
+# * Activations (size A)
+# * Gradients, which are the same size as the model parameters, so size G = P.
+# * Optimizer state, which is usually a relation to the model parameters. In
+#   this case, the state for Adam requires 2x the model parameters, so size O = 2P.
+# * Intermediate tensors, which are allocated throughout the compute. We will
+#   not worry about them for now as they are usually small and ephemeral.
 #
 # Capturing and visualizing memory snapshots
 # """"""""""""""""""""""""""""""""""""""""""
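The P, G, and O bookkeeping in the list above can be sanity-checked directly against a model. A sketch (the helper name is ours, not part of the commit); for the ``vit_l_16`` model used in this tutorial, P comes out to roughly 1.2 GB:

    def memory_budget_gb(model: torch.nn.Module) -> None:
        # P: bytes held by the parameters themselves
        p_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
        print(f"P ~ {p_bytes / 1e9:.2f} GB (parameters)")
        print(f"G ~ {p_bytes / 1e9:.2f} GB (gradients, same size as P)")
        print(f"O ~ {2 * p_bytes / 1e9:.2f} GB (Adam state: exp_avg + exp_avg_sq)")
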
@@ -79,33 +80,33 @@ def train(model, optimizer):
     train(model, optimizer)

 # save a snapshot of the memory allocations
-# s = torch.cuda.memory._snapshot()
-# with open(f"snapshot.pickle", "wb") as f:
-#     dump(s, f)
+s = torch.cuda.memory._snapshot()
+with open(f"snapshot.pickle", "wb") as f:
+    dump(s, f)

 # tell CUDA to stop recording memory allocations now
 torch.cuda.memory._record_memory_history(enabled=None)

 ###############################################################################
 # Now open up the snapshot in Zach Devito's [CUDA Memory Visualizer](
-# https://zdevito.github.io/assets/viz/) by dragging the snapshot.pickle file.
-# Does the memory timeline match your expectations?
+# https://zdevito.github.io/assets/viz/) by dragging and dropping the
+# ``snapshot.pickle`` file. Does the memory timeline match your expectations?
 #
 # .. figure:: /_static/img/optim_step_in_bwd/snapshot.jpg
 #    :alt: snapshot.png loaded into CUDA Memory Visualizer
 #
 # The model parameters have already been loaded in memory before the training
 # step, so we see a chunk of memory devoted to the weights right off the bat.
-# As we start our forward, memory is allocated gradually for the activations, or
-# the tensors we are saving to be able to compute gradients in the backward.
-# Once we start the backward, the activations are gradually freed while memory
-# of the gradients start building up.
+# As we start our forward pass, memory is allocated gradually for the activations,
+# or the tensors we are saving to be able to compute gradients in the backward pass.
+# Once we start the backward pass, the activations are gradually freed while memory
+# of the gradients starts building up.
 #
 # Lastly, as the optimizer kicks in, its state will be lazily initialized, so we
 # should see the optimizer state memory gradually increase during the optimizer
 # step of the first training loop only. In future loops, the optimizer memory
-# will remain and be updated in-place. The memory for the gradients are then
-# freed accordingly at the end of every training loop when zero_grad is called.
+# will remain and be updated in-place. The memory for the gradients is then
+# freed accordingly at the end of every training loop when ``zero_grad`` is called.
 #
 # Where is the memory bottleneck in this training loop? Or, in other words,
 # where is the peak memory?
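The "where is the peak memory?" question can also be cross-checked numerically alongside the visual snapshot; a small sketch (not part of the commit), reusing the tutorial's ``train``, ``model``, and ``optimizer``:

    torch.cuda.reset_peak_memory_stats()
    train(model, optimizer)
    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    print(f"peak allocated memory: {peak_gb:.2f} GB")
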
@@ -115,23 +116,23 @@ def train(model, optimizer):
 # the optimizer state as expected. The last ~1.2GB comes from Adam optimizer
 # requiring memory for intermediates, totalling to ~6GB of peak memory.
 # Technically, you can remove the need for the last 1.2GB for optimizer
-# intermediates if you set `Adam(model.parameters(), foreach=False)` which
-# would trade off runtime for memory. If switching off the foreach runtime
+# intermediates if you set ``Adam(model.parameters(), foreach=False)`` which
+# would trade off runtime for memory. If switching off the ``foreach`` runtime
 # optimization is sufficient in memory savings for you, nice, but please
 # read on if you're curious how this tutorial can help you do better!
 # With the technique we will soon introduce, we will reduce peak memory by
 # removing the need for the ~1.2GB of **gradients memory** as well as **optimizer
-# intermediates memory**. Now do the math--what would the new peak memory be?
-# The answer will be revealed in the next snapshot :D.
+# intermediates memory**. Now, what would you expect the new peak memory to be?
+# The answer will be revealed in the `next` snapshot.
 #
-# DISCLAIMER. Yes. Is the time for the disclaimer:
-#
-# Is this technique applicable to you? Only if gradients take up sizable memory.
-# """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
-# The technique of fusing the optimizer step into the backward only targets
-# reducing *gradient* memory (and as a side effect also optimizer intermediates
-# memory). In our example above, the gradients eat up 20% of the memory pie, which
-# is quite sizable!
+# DISCLAIMER: This technique is **not** for all
+# """""""""""""""""""""""""""""""""""""""""""""
+# Before we get too excited, we have to consider whether this technique is applicable
+# for `your` use case. This is NOT a silver bullet! The technique of fusing the
+# optimizer step into the backward only targets reducing *gradient* memory (and as a side effect also optimizer intermediates
+# memory). Thus, the more sizable the memory taken up by the gradients, the more
+# tantamount the memory reduction. In our example above, the gradients eat up 20%
+# of the memory pie, which is quite sizable!
 #
 # This may not be the case for you, for example, if your weights are already tiny,
 # (say, due to applying LoRa,) then the gradients do not take much space in your
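The disclaimer above hinges on how much of your memory the gradients actually occupy. One quick way to estimate that share before adopting the technique (a sketch, not part of the commit); note that frozen parameters, as in LoRA-style setups, contribute nothing here:

    grad_bytes = sum(
        p.numel() * p.element_size()
        for p in model.parameters()
        if p.requires_grad
    )
    print(f"gradient memory after a full backward: ~{grad_bytes / 1e9:.2f} GB")
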
@@ -140,35 +141,36 @@ def train(model, optimizer):
 # training, quantization, or reducing the batch size. Then, when the gradients
 # are part of the bottleneck again, come back to this tutorial!
 #
-# Still here? Cool, let's introduce our new `register_post_accumulate_grad_hook(hook)`
+# Still here? Cool, let's introduce our new ``register_post_accumulate_grad_hook(hook)``
 # API on Tensor.
 #
-# `Tensor.register_post_accumulate_grad_hook(hook)` API and our technique
+# ``Tensor.register_post_accumulate_grad_hook(hook)`` API and our technique
 # """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
-# Our technique relies on not having to save the gradients during `backward()`. Instead,
+# Our technique relies on not having to save the gradients during ``backward()``. Instead,
 # once a gradient has been accumulated, we will immediately apply the optimizer to
 # the corresponding parameter and drop that gradient entirely! This removes the need
 # for holding onto a big buffer of gradients until the optimizer step.
 #
-# So how can we unlock the behavior of applying the optimizer more eagerly? Well,
-# in 2.1, we've added a new API [`Tensor.register_post_accumulate_grad_hook(hook)`](
-# https://pytorch.org/docs/main/generated/torch.Tensor.register_post_accumulate_grad_hook.html#torch.Tensor.register_post_accumulate_grad_hook)
-# that would allow us to add a hook onto a Tensor once its `.grad` field has been
+# So how can we unlock the behavior of applying the optimizer more eagerly? In our 2.1
+# release, we've added a new API :func:`torch.Tensor.register_post_accumulate_grad_hook`
+# that would allow us to add a hook onto a Tensor once its ``.grad`` field has been
 # accumulated. We will encapsulate the optimizer step into this hook. How?
 #
 # How everything fits together in 10 lines
 # """"""""""""""""""""""""""""""""""""""""
 # Remember our model and optimizer setup from the beginning? I'll leave them commented
 # out below so we don't spend resources rerunning the code.
+#
+# .. code-block:: python
+#
+#    model = models.vit_l_16(weights='DEFAULT').cuda()
+#    optimizer = torch.optim.Adam(model.parameters())

-# model = models.vit_l_16(weights='DEFAULT').cuda()
-# optimizer = torch.optim.Adam(model.parameters())
-
-# Instead of having just *one* optimizer, we will have a Dict of optimizers
+# Instead of having just *one* optimizer, we will have a ``dict`` of optimizers
 # for every parameter so we could reference them in our hook.
 optimizer_dict = {p: torch.optim.Adam([p]) for p in model.parameters()}

-# Define our hook, which will call the optimizer `step()` and `zero_grad()`
+# Define our hook, which will call the optimizer ``step()`` and ``zero_grad()``
 def optimizer_hook(parameter) -> None:
     optimizer_dict[parameter].step()
     optimizer_dict[parameter].zero_grad()
@@ -177,7 +179,7 @@ def optimizer_hook(parameter) -> None:
 for p in model.parameters():
     p.register_post_accumulate_grad_hook(optimizer_hook)

-# Now remember our previous `train()` function? Since the optimizer has been
+# Now remember our previous ``train()`` function? Since the optimizer has been
 # fused into the backward, we can remove the optimizer step and zero_grad calls.
 def train(model):
     # create our fake image input: tensor shape is batch_size, channels, height, width
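Read together, the last two hunks reduce the whole technique to roughly ten lines. A consolidated sketch of the resulting new-file state, assembled from the added lines above (it assumes the ``model`` from the setup sketch earlier):

    # one small Adam instance per parameter, so each can be stepped independently
    optimizer_dict = {p: torch.optim.Adam([p]) for p in model.parameters()}

    # the hook fires once a parameter's .grad field has been fully accumulated
    def optimizer_hook(parameter) -> None:
        optimizer_dict[parameter].step()
        optimizer_dict[parameter].zero_grad()

    # register the hook on every parameter that will receive gradients
    for p in model.parameters():
        p.register_post_accumulate_grad_hook(optimizer_hook)
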
@@ -192,20 +194,21 @@ def train(model):
     # optimizer.zero_grad()

 ########################################################################
-# I believe that was about 10 lines of changes in our sample model. I do
-# recognize that it could be a fairly intrusive change to switch out the
-# optimizer for a optimizer dictionary, especially for those who use
-# `LRScheduler`s or manipulate optimizer configuration throughout the
+# That took about 10 lines of changes in our sample model, which is neat.
+# However, for real models, it could be a fairly intrusive change to switch
+# out the optimizer for an optimizer dictionary, especially for those who use
+# ``LRScheduler``s or manipulate optimizer configuration throughout the
 # training epochs. Working out this API with those changes will be more
-# involved and likely requires moving more configuration into global
+# involved and will likely require moving more configuration into global
 # state but should not be impossible. That said, a next step for us is
 # to make this API easier to adopt with LRSchedulers and other features
 # you are already used to.
 #
 # But let me get back to convincing you that this technique is worth it.
 # We will consult our friend, the memory snapshot.

-# del optimizer memory from before to get a clean slate
+# del optimizer memory from before to get a clean slate for the next
+# memory snapshot
 del optimizer

 # tell CUDA to start recording memory allocations
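The hunk above flags ``LRScheduler`` users as the trickiest case for this API. One possible direction, mirroring the per-parameter optimizers with per-parameter schedulers; this is purely our illustrative sketch, not something the commit proposes:

    # one scheduler per parameter, mirroring optimizer_dict
    scheduler_dict = {
        p: torch.optim.lr_scheduler.StepLR(optimizer_dict[p], step_size=30)
        for p in model.parameters()
    }

    def optimizer_hook(parameter) -> None:
        optimizer_dict[parameter].step()
        optimizer_dict[parameter].zero_grad()
        # schedulers are usually stepped once per epoch; stepping here, after every
        # backward pass, only shows where the call would live and would need gating
        # (for example, a step counter) in real code
        scheduler_dict[parameter].step()
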
@@ -230,21 +233,21 @@ def train(model):
 #    :alt: snapshot.png loaded into CUDA Memory Visualizer
 #
 # Several major observations:
-# 1. There is no more optimizer step! Right...we fused that into the backward.
-# 2. Likewise, the backward drags longer and there are more random allocations
-#    for intermediates. This is expected, as the optimizer step requires
-#    intermediates.
-# 3. Most importantly! The peak memory is lower! It is now ~4GB (which I
-#    hope is close to your answer earlier :).)
+# 1. There is no more optimizer step! Right...we fused that into the backward.
+# 2. Likewise, the backward drags longer and there are more random allocations
+#    for intermediates. This is expected, as the optimizer step requires
+#    intermediates.
+# 3. Most importantly! The peak memory is lower! It is now ~4GB (which I
+#    hope maps closely to your earlier expectation).
 #
 # Note that there is no longer any big chunk of memory allocated for the gradients
 # compared to before, accounting for ~1.2GB of memory savings. Instead, we've freed
 # each gradient very quickly after they've been computed by moving the optimizer
-# step as far ahead as we can. Woohoo! By the way, the other ~1.2GB of memory savings
+# step as far ahead as we can. Woo-hoo! By the way, the other ~1.2GB of memory savings
 # comes from breaking apart the optimizer into per-parameter optimizers, so the
-# intermediates have proportionally shrunk. I'd like to stress this detail is less
-# important than the gradient memory savings, as you can get optimizer intermediates
-# savings from just turning `foreach=False` without this technique.
+# intermediates have proportionally shrunk. This detail is `less important` than
+# the gradient memory savings, as you can get optimizer intermediates savings
+# from just turning ``foreach=False`` without this technique.
 #
 # You may be correctly wondering: if we saved 2.4GB of memory, why is the peak memory
 # NOT 6GB - 2.4GB = 3.6GB? Well, the peak has moved! The peak is now near the start
@@ -254,10 +257,10 @@ def train(model):
 # imagine that this technique can be coupled with activations checkpointing for more
 # memory wins.
 #
-# Recap
-# """""
+# Conclusion
+# """"""""""
 # In this tutorial, we learned about the memory saving technique of
 # fusing the optimizer into the backward step through the new
-# `Tensor.register_post_accumulate_grad_hook()` API and *when* to apply this
+# ``Tensor.register_post_accumulate_grad_hook()`` API and *when* to apply this
 # technique (when gradients memory is significant). Along the way, we also learned
 # about memory snapshots, which are generally useful in memory optimization.
