"""

How to save memory by fusing the optimizer step into the backward pass
======================================================================

Hello there! This tutorial aims to showcase one way of reducing the
memory footprint of a training loop by reducing the memory taken by
the gradients. Say you have a model and you're interested in ways to
optimize memory to avoid OOMing or simply to squeeze more out of your GPU.
Well, you _might_ be in luck! We will explore

1. What takes up memory during your training or finetuning loop,
2. Capturing and visualizing memory snapshots to determine the memory bottleneck,
3. The new `Tensor.register_post_accumulate_grad_hook(hook)` API, and finally,
4. How everything fits together in 10 lines to achieve memory savings.

The ingredients and tools required:
1. PyTorch 2.1.0 or newer with torchvision
2. A CUDA GPU

Let us start by importing the required modules and models. We will use a
vision transformer model from torchvision, but feel free to substitute it
with your own model. We will also use `torch.optim.Adam` as our optimizer,
but, again, feel free to substitute it with your own optimizer.

"""

import torch
from torchvision import models
from pickle import dump

model = models.vit_l_16(weights='DEFAULT').cuda()
optimizer = torch.optim.Adam(model.parameters())

###############################################################################
# Now let's define our typical training loop. You should use real images when
# training, but for the purposes of this tutorial, we are passing in fake
# inputs and not worrying about loading actual data.

IMAGE_SIZE = 224

def train(model, optimizer):
    # create our fake image input: tensor shape is batch_size, channels, height, width
    fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()

    # call our forward and backward
    loss = model.forward(fake_image)
    loss.sum().backward()

    # optimizer update
    optimizer.step()
    optimizer.zero_grad()

###############################################################################
# So what comprises the memory usage during training?
# """""""""""""""""""""""""""""""""""""""""""""""""""
# We are about to look at some memory snapshots, so we should be prepared to
# analyze them properly. People normally consider training memory to consist of
#
# 1. Model parameters (size P)
# 2. Activations (size A)
# 3. Gradients, which are the same size as the model parameters, so size G = P
# 4. Optimizer state, whose size usually scales with the model parameters. In
#    this case, Adam state requires 2x the model parameters, so size O = 2P
# 5. Intermediate tensors, which are allocated throughout the compute. We will
#    not worry about them for now as they are usually small and ephemeral.
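#
# As a rough sanity check of that arithmetic (a sketch only; activations are
# left out because size A depends on the batch size and architecture), we can
# estimate P, G, and O for our fp32 model up front:

param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
gib = 1024 ** 3
print(f"P (parameters):       {param_bytes / gib:.2f} GiB")
print(f"G (gradients)  = P:   {param_bytes / gib:.2f} GiB")
print(f"O (Adam state) = 2P:  {2 * param_bytes / gib:.2f} GiB")

###############################################################################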
#
# Let's get us a memory snapshot! As your code runs, consider what you may expect
# the CUDA memory timeline to look like.

# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history()

# train 3 steps
for _ in range(3):
    train(model, optimizer)

# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open("snapshot.pickle", "wb") as f:
    dump(s, f)

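# As a quick complement to the full snapshot, the allocator's headline
# statistics report the peak usage in one line (standard torch.cuda APIs,
# shown here as an optional extra check):
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")
print(f"peak reserved:  {torch.cuda.max_memory_reserved() / 1024**3:.2f} GiB")
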
raise RuntimeError("Stop here and open up the snapshot in Zach Devito's CUDA Memory Visualizer")

###############################################################################
# Now open up the snapshot in Zach Devito's [CUDA Memory Visualizer](
# https://zdevito.github.io/assets/viz/) by dragging the snapshot.pickle file.
# Does the memory timeline match your expectations?
#
# The model parameters have already been loaded in memory before the training
# step, so we anticipate seeing a chunk of memory devoted to the weights right
# off the bat. As we start our forward, memory should be allocated gradually
# for the activations, or the tensors we are saving to be able to compute gradients
# in the backward. Once we start the backward, the activations should be gradually
# freed while the memory for the gradients starts building up.
#
# Lastly, as the optimizer kicks in, its state will be lazily initialized, so we
# should see the optimizer state memory gradually increase during the end of the
# first training loop only. In future loops, the optimizer memory will remain and
# be updated in place. The memory for the gradients should be freed accordingly
# by the end of every training loop.

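###############################################################################
# The memory snapshot workflow above helps find where memory goes during
# training. The rest of this tutorial turns to a related source of avoidable
# memory usage: how a model is instantiated when loading it from a checkpoint.
# In what follows, `SomeModule` stands in for whatever `nn.Module` you actually
# want to load, and `state_dict` for a checkpoint of its weights; the
# definitions below are illustrative placeholders only, so that the snippets
# can run end to end.

import torch.nn as nn

class SomeModule(nn.Module):
    # a tiny stand-in module; substitute your real architecture here
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(128, 128)

    def forward(self, x):
        return self.linear(x)

# pretend a checkpoint already exists on disk; `mmap=True` lets `torch.load`
# memory-map the file instead of reading all tensor data into RAM at once
torch.save(SomeModule().state_dict(), "checkpoint.pth")
state_dict = torch.load("checkpoint.pth", mmap=True, map_location="cpu")

# The typical way to instantiate the module is then: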
m = SomeModule()

###############################################################################
# This allocates memory for all parameters/buffers and initializes them per
# the default initialization schemes defined in `SomeModule.__init__()`, which
# is wasteful when we want to load a checkpoint, as:
#
# 1. We are running the initialization kernels, whose results are
#    immediately overwritten by `load_state_dict()`
# 2. We are allocating memory for these parameters/buffers in RAM while
#    `torch.load` of the saved state dictionary also allocates memory for
#    the parameters/buffers in the checkpoint.
#
# In order to solve these two problems, we can use the `torch.device()`
# context manager with `device='meta'` when we instantiate the `nn.Module`.

with torch.device('meta'):
    meta_m = SomeModule()

###############################################################################
# The [`torch.device()`](https://pytorch.org/docs/main/tensor_attributes.html#torch-device)
# context manager makes sure that factory calls will be performed as if they
# were passed device as an argument. However, it does not affect factory
# function calls which are called with an explicit device argument.
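#
# For example (a small illustration of that rule, not part of the loading
# recipe itself):

with torch.device('meta'):
    default_t = torch.ones(2, 3)                 # no device passed -> 'meta'
    explicit_t = torch.ones(2, 3, device='cpu')  # explicit device wins -> 'cpu'
print(default_t.device, explicit_t.device)

###############################################################################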
#
# Tensors on the `meta` device do not carry data. However, they possess all
# the other metadata a tensor carries, such as `.size()`, `.stride()`,
# `.requires_grad`, and so on.
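#
# We can see this on the module we just created (recall that `SomeModule` here
# is only our placeholder module):

meta_param = next(meta_m.parameters())
print(meta_param.device)                       # meta
print(meta_param.size(), meta_param.stride())  # metadata is fully available
print(meta_param.requires_grad)                # True

###############################################################################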
#
# Next, we consider the loading of the state dictionary.

m.load_state_dict(state_dict)

###############################################################################
# `nn.Module.load_state_dict()` is usually implemented via an in-place
# `param_in_model.copy_(param_in_state_dict)` (i.e. a copy from the
# parameter/buffer with the corresponding key in the state dictionary into
# the parameters/buffers in the `nn.Module`).
#
# However, an in-place copy into a tensor on the `meta` device is a no-op.
# In order to avoid this, we can pass the `assign=True` keyword argument to
# `load_state_dict()`.

meta_m.load_state_dict(state_dict, assign=True)

###############################################################################
# Another caveat here is that since optimizers hold a reference to
# `nn.Module.parameters()`, the optimizer must be initialized after the module
# is loaded from the state dict if `assign=True` is passed.

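###############################################################################
# Concretely, that means creating the optimizer only after the `assign=True`
# load above has replaced the meta parameters with real ones (a minimal sketch
# reusing `meta_m` from above):

loaded_optimizer = torch.optim.Adam(meta_m.parameters())
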
###############################################################################
# To recap, in this tutorial we saw how to capture and read CUDA memory
# snapshots to pinpoint the memory bottleneck of a training loop. We also
# learned about `torch.load(mmap=True)`, the `torch.device()` context manager
# with `device='meta'`, and `nn.Module.load_state_dict(assign=True)`, and how
# these tools can aid when loading a model from a checkpoint.