To run this tutorial, you will need:
* PyTorch 2.1.0 or newer with ``torchvision``
- * 1 CUDA GPU
+ * 1 CUDA GPU if you'd like to run the memory visualizations locally.
+   Otherwise, this technique offers similar benefits on any device.
Let us start by importing the required modules and models. We will use a
vision transformer model from torchvision, but feel free to substitute
@@ -60,9 +61,9 @@ def train(model, optimizer):
# analyze them properly. Typically, training memory consists of:
#
# * Model parameters (size P)
- # * Activations (size A)
+ # * Activations that are saved for the backward pass (size A)
# * Gradients, which are the same size as the model parameters, so size G = P.
- # * Optimizer state, which is usually a relation to the model parameters. In
+ # * Optimizer state, which is proportional to the size of the parameters. In
# this case, the state for Adam requires 2x the model parameters, so size O = 2P.
# * Intermediate tensors, which are allocated throughout the compute. We will
# not worry about them for now as they are usually small and ephemeral.
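#
# As a rough back-of-the-envelope check of the sizes above, P, G, and O can be
# estimated directly from the model. This is only a sketch; it assumes the
# ``model`` defined earlier in the tutorial.

param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"P ~ {param_bytes / 1e9:.2f}GB of parameters")
print(f"G ~ {param_bytes / 1e9:.2f}GB of gradients (same size as the parameters)")
print(f"O ~ {2 * param_bytes / 1e9:.2f}GB of Adam state (exp_avg + exp_avg_sq)")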
@@ -88,8 +89,8 @@ def train(model, optimizer):
torch.cuda.memory._record_memory_history(enabled=None)
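# For reference, a sketch of the capture pattern that produces the ``snapshot.pickle``
# file referenced below; the ``max_entries`` value is an assumption rather than the
# tutorial's exact code:
#
#     torch.cuda.memory._record_memory_history(max_entries=100000)
#     train(model, optimizer)
#     torch.cuda.memory._dump_snapshot("snapshot.pickle")
#     torch.cuda.memory._record_memory_history(enabled=None)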
###############################################################################
- # Now open up the snapshot in Zach Devito's [CUDA Memory Visualizer](
- # https://zdevito.github.io/assets/viz/) by dragging and dropping the
+ # Now open up the snapshot in the CUDA Memory Visualizer at
+ # https://pytorch.org/memory_viz by dragging and dropping the
# ``snapshot.pickle`` file. Does the memory timeline match your expectations?
#
# .. figure:: /_static/img/optim_step_in_bwd/snapshot.jpg
@@ -114,7 +115,7 @@ def train(model, optimizer):
# The peak memory usage is during the optimizer step! Note the memory then
# consists of ~1.2GB of params, ~1.2GB of gradients, and ~2.4GB=2*1.2GB of
# the optimizer state as expected. The last ~1.2GB comes from Adam optimizer
- # requiring memory for intermediates, totalling to ~6GB of peak memory.
+ # requiring memory for intermediates, totaling to ~6GB of peak memory.
# Technically, you can remove the need for the last 1.2GB for optimizer
# intermediates if you set ``Adam(model.parameters(), foreach=False)`` which
# would trade off runtime for memory. If switching off the ``foreach`` runtime
@@ -168,7 +169,7 @@ def train(model, optimizer):
# Instead of having just *one* optimizer, we will have a ``dict`` of optimizers
# for every parameter so we could reference them in our hook.
- optimizer_dict = {p: torch.optim.Adam([p]) for p in model.parameters()}
+ optimizer_dict = {p: torch.optim.Adam([p], foreach=False) for p in model.parameters()}
# Define our hook, which will call the optimizer ``step()`` and ``zero_grad()``
def optimizer_hook(parameter) -> None:
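    # A sketch of the body implied by the comment above: run this parameter's
    # own optimizer, then immediately free its ``.grad``.
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad()

# Attach the hook to every parameter. This sketch assumes
# ``Tensor.register_post_accumulate_grad_hook`` (available since PyTorch 2.1),
# which fires once a parameter's gradient has finished accumulating in backward.
for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)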
@@ -200,14 +201,14 @@ def train(model):
# ``LRScheduler``s or manipulate optimizer configuration throughout the
# training epochs. Working out this API with those changes will be more
# involved and will likely require moving more configuration into global
- # state but should not be impossible. That said, a next step for us is
- # to make this API easier to adopt with LRSchedulers and other features
+ # state but should not be impossible. That said, a next step for PyTorch
+ # is to make this API easier to adopt with LRSchedulers and other features
# you are already used to.
#
# But let me get back to convincing you that this technique is worth it.
# We will consult our friend, the memory snapshot.
- # del optimizer memory from before to get a clean slate for the next
+ # delete optimizer memory from before to get a clean slate for the next
# memory snapshot
del optimizer
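
# In addition to the snapshot, a quick numeric sanity check is to compare peak
# allocated memory before and after the change. This is only a sketch; it
# assumes the hook-based ``train(model)`` defined above and a single CUDA device.
torch.cuda.reset_peak_memory_stats()
train(model)
print(f"peak memory with optimizer-in-backward: "
      f"{torch.cuda.max_memory_allocated() / 1e9:.2f}GB")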
@@ -243,7 +244,7 @@ def train(model):
# Note that there is no longer any big chunk of memory allocated for the gradients
# compared to before, accounting for ~1.2GB of memory savings. Instead, we've freed
# each gradient very quickly after they've been computed by moving the optimizer
- # step as far ahead as we can. Woo-hoo! By the way, the other ~1.2GB of memory savings
+ # step as far ahead as we can. Woohoo! By the way, the other ~1.2GB of memory savings
# comes from breaking apart the optimizer into per-parameter optimizers, so the
# intermediates have proportionally shrunk. This detail is `less important` than
# the gradient memory savings, as you can get optimizer intermediates savings