
Commit aefee66

[WIP] Optimizer step in backward tutorial
1 parent 677c1b6 commit aefee66

1 file changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
"""

How to save memory by fusing the optimizer step into the backward pass
======================================================================

Hello there! This tutorial aims to showcase one way of reducing the
memory footprint of a training loop by reducing the memory taken by
the gradients. Say you have a model and you're interested in ways to
optimize memory to avoid OOMing or simply to squeeze more out of your GPU.
Well, you _might_ be in luck! We will explore

1. What takes up memory during your training or finetuning loop,
2. Capturing and visualizing memory snapshots to determine the memory bottleneck,
3. The new `Tensor.register_post_accumulate_grad_hook(hook)` API, and finally,
4. How everything fits together in 10 lines to achieve memory savings

The ingredients and tools required:
1. PyTorch 2.1.0 or newer with torchvision
2. A CUDA GPU

Let us start by importing the required modules and models. We will use a
vision transformer model from torchvision, but feel free to substitute with
your own model. We will also use `torch.optim.Adam` as our optimizer, but,
again, feel free to substitute with your own optimizer.

"""

import torch
from torchvision import models
from pickle import dump

model = models.vit_l_16(weights='DEFAULT').cuda()
optimizer = torch.optim.Adam(model.parameters())

###############################################################################
# Now let's define our typical training loop. You should use real images when
# training, but for the purposes of this tutorial, we are passing in fake
# inputs and not worrying about loading actual data.

IMAGE_SIZE = 224

def train(model, optimizer):
    # create our fake image input: tensor shape is batch_size, channels, height, width
    fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()

    # call our forward and backward
    loss = model.forward(fake_image)
    loss.sum().backward()

    # optimizer update
    optimizer.step()
    optimizer.zero_grad()

###############################################################################
# So what comprises the memory usage during training?
# """""""""""""""""""""""""""""""""""""""""""""""""""
# We are about to look at some memory snapshots, so we should be prepared to
# analyze them properly. People normally consider training memory to consist of
#
# 1. Model parameters (size P)
# 2. Activations (size A)
# 3. Gradients, which are the same size as the model parameters, so size G = P
# 4. Optimizer state, which is usually proportional to the size of the model
#    parameters. In this case, Adam state requires 2x the model parameters,
#    so size O = 2P
# 5. Intermediate tensors, which are allocated throughout the compute. We will
#    not worry about them for now as they are usually small and ephemeral.
#
# Let's capture a memory snapshot! As your code runs, consider what you may expect
# the CUDA memory timeline to look like.

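###############################################################################
# Before taking the snapshot, a rough back-of-envelope estimate can help us
# sanity check what we are about to see. This snippet is only an illustrative
# addition: it assumes the fp32 `model` defined above and that Adam keeps two
# extra buffers (`exp_avg` and `exp_avg_sq`) per parameter, so O = 2P.

param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
GiB = 1024 ** 3
print(f"parameters P    ~ {param_bytes / GiB:.2f} GiB")
print(f"gradients  G = P ~ {param_bytes / GiB:.2f} GiB")
print(f"Adam state O = 2P ~ {2 * param_bytes / GiB:.2f} GiB")
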
# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history()

# train 3 steps
for _ in range(3):
    train(model, optimizer)

# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open("snapshot.pickle", "wb") as f:
    dump(s, f)

raise RuntimeError("Stop here and open up the snapshot in Zach Devito's CUDA Memory Visualizer")

###############################################################################
# Now open up the snapshot in Zach Devito's [CUDA Memory Visualizer](
# https://zdevito.github.io/assets/viz/) by dragging and dropping the
# snapshot.pickle file. Does the memory timeline match your expectations?
#
# The model parameters have already been loaded in memory before the training
# step, so we anticipate seeing a chunk of memory devoted to the weights right
# off the bat. As we start our forward, memory should be allocated gradually
# for the activations, or the tensors we are saving to be able to compute gradients
# in the backward. Once we start the backward, the activations should be gradually
# freed while memory for the gradients starts building up.
#
# Lastly, as the optimizer kicks in, its state will be lazily initialized, so we
# should see the optimizer state memory gradually increase toward the end of the
# first training loop only. In future loops, the optimizer memory will remain and
# be updated in place. The memory for the gradients should be freed accordingly
# by the end of every training loop, once `optimizer.zero_grad()` is called.

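###############################################################################
# As a preview of where this tutorial is headed, here is a minimal sketch of
# one way the optimizer step can be fused into the backward pass. It assumes
# PyTorch 2.1+, where `Tensor.register_post_accumulate_grad_hook` is available:
# we keep one small optimizer per parameter and let a hook step it and free the
# gradient as soon as that gradient has been accumulated, so the full set of
# gradients never needs to coexist in memory. (In the full recipe, the global
# `optimizer` above would no longer be needed.)

# one optimizer per parameter, so each parameter can be stepped independently
optimizer_dict = {p: torch.optim.Adam([p]) for p in model.parameters()}

def optimizer_hook(parameter) -> None:
    # step and immediately free the gradient for this parameter
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad()

# register the hook on every parameter; it fires once that grad is accumulated
for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

def train_fused(model):
    fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()
    loss = model(fake_image)
    # backward now also applies the optimizer updates through the hooks
    loss.sum().backward()
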
m = SomeModule()

###############################################################################
# This allocates memory for all parameters/buffers and initializes them per
# the default initialization schemes defined in `SomeModule.__init__()`, which
# is wasteful when we want to load a checkpoint, as
#
# 1. We are running the initialization kernels where the results are
#    immediately overwritten by `load_state_dict()`
# 2. We are allocating memory for these parameters/buffers in RAM while
#    `torch.load` of the saved state dictionary also allocates memory for
#    the parameters/buffers in the checkpoint.
#
# In order to solve these two problems, we can use the `torch.device()`
# context manager with `device='meta'` when we instantiate the `nn.Module()`.

with torch.device('meta'):
    meta_m = SomeModule()

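###############################################################################
# To make the effect concrete, here is a small illustrative check (the `demo`
# module below is only for demonstration): factory calls made under the `meta`
# device produce tensors that report shapes and dtypes but allocate no data.

with torch.device('meta'):
    demo = torch.nn.Linear(8, 4)
print(demo.weight.shape, demo.weight.dtype, demo.weight.device)
# e.g. torch.Size([4, 8]) torch.float32 meta
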
###############################################################################
# The [`torch.device()`](https://pytorch.org/docs/main/tensor_attributes.html#torch-device)
# context manager makes sure that factory calls will be performed as if they
# were passed `device` as an argument. However, it does not affect factory
# function calls which are called with an explicit `device` argument.
#
# Tensors on the `meta` device do not carry data. However, they possess all
# other metadata a tensor carries, such as `.size()`, `.stride()`,
# `.requires_grad`, etc.
#
# Next, we consider the loading of the state dictionary.

m.load_state_dict(state_dict)

###############################################################################
# `nn.Module.load_state_dict()` is usually implemented via an in-place
# `param_in_model.copy_(param_in_state_dict)` (i.e. a copy from the
# parameter/buffer with the corresponding key in the state dictionary into
# the parameters/buffers in the `nn.Module`).
#
# However, an in-place copy into a tensor on the `meta` device is a no-op.
# In order to avoid this, we can pass the `assign=True` keyword argument to
# `load_state_dict()`.

meta_m.load_state_dict(state_dict, assign=True)

###############################################################################
# Another caveat here is that, since optimizers hold a reference to
# `nn.Module.parameters()`, the optimizer must be initialized after the module
# is loaded from the state dict if `assign=True` is passed.

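###############################################################################
# Put differently, the safe ordering looks like the sketch below (reusing the
# `meta_m` and `state_dict` assumed to exist above): load with `assign=True`
# first, and only then construct the optimizer from the module's
# now-materialized parameters.

meta_m.load_state_dict(state_dict, assign=True)
# create the optimizer only after loading, so it references the loaded tensors
optimizer_for_meta_m = torch.optim.Adam(meta_m.parameters())
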
###############################################################################
# To recap, in this tutorial we learned about `torch.load(mmap=True)`, the
# `torch.device()` context manager with `device='meta'`, and
# `nn.Module.load_state_dict(assign=True)`, and how these tools can help
# when loading a model from a checkpoint.
