"""

###############################################################################
# Let us consider a simple ``nn.Module`` that contains a list of Linear layers:

import torch
from torch import nn
import time
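
###############################################################################
# A minimal definition of ``SomeModule`` consistent with the snippets below;
# the exact number of layers is arbitrary here:

class SomeModule(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        # keys in the saved state_dict will look like ``linears.0.weight``
        self.linears = nn.ModuleList([nn.Linear(size, size) for _ in range(10)])

    def forward(self, x):
        for linear in self.linears:
            x = linear(x)
        return x
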
m = SomeModule(1000)
torch.save(m.state_dict(), 'checkpoint.pth')

##############################################################################
# The following snippet demonstrates the use of the ``mmap`` keyword argument
# to ``torch.load``, the ``torch.device()`` context manager, and the ``assign``
# keyword argument to ``nn.Module.load_state_dict()``.

state_dict = torch.load('checkpoint.pth', mmap=True)
with torch.device('meta'):
    meta_m = SomeModule(1000)

meta_m.load_state_dict(state_dict, assign=True)

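##############################################################################
# A quick check (using the ``linears`` naming from the definition above) shows
# that after ``load_state_dict(assign=True)`` the module's parameters are the
# tensors from the checkpoint rather than ``meta`` tensors:

print(meta_m.linears[0].weight.device)  # cpu
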
#############################################################################
# Compare the snippet below to the one above:

state_dict = torch.load('checkpoint.pth')
m = SomeModule(1000)
m.load_state_dict(state_dict)

#################################################################################
# The second example does not use any of the features listed above and will be
# less compute- and memory-efficient for loading a checkpoint. In the following
# sections, we will discuss each of the features in further detail.

#####################################################################################
# Using ``torch.load(mmap=True)``
# -------------------------------
# First, let us consider what happens when we load the checkpoint with ``torch.load``.
# When we save a checkpoint with ``torch.save``, tensor storages are tagged with the
# device they are saved on. With ``torch.load``, tensor storages will be loaded to the device
# they were tagged with (unless this behavior is overridden using the
# ``map_location`` flag). For ease of explanation, let us assume that the tensors
# were saved on CPU. This means that on the first line all tensor storages will be
# loaded into CPU RAM, which can be undesirable when:
#
# * CPU RAM is smaller than the size of the checkpoint.
# * We want to avoid waiting for the entire checkpoint to be loaded into RAM
#   before performing, for example, some per-tensor processing.

start_time = time.time()
state_dict = torch.load('checkpoint.pth')
end_time = time.time()
print(f"loading time without mmap={end_time - start_time}")
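
#################################################################################
# The ``mmap`` keyword argument to ``torch.load`` memory-maps the checkpoint
# file instead, so tensor storages are materialized lazily rather than all
# being read into CPU RAM upfront. A sketch that times the same load with
# ``mmap=True`` for comparison:

start_time = time.time()
state_dict = torch.load('checkpoint.pth', mmap=True)
end_time = time.time()
print(f"loading time with mmap={end_time - start_time}")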
######################################################################################
# As mentioned above, one can use this argument to do per-tensor processing on a
# checkpoint without loading all tensor storages into CPU memory upfront. For example:

def my_special_routine(t, device):
    # this could be a much fancier operation
    return t.to(dtype=torch.bfloat16, device=device)

def my_processing_function(key, device):
    t = state_dict[key]
    processed_t = my_special_routine(t, device)
    del t
    state_dict[key] = processed_t

for key in state_dict.keys():
    device = torch.device('cuda')
    my_processing_function(key, device)
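
##################################################
# A quick sanity check (assuming the ``linears.<i>.weight`` key layout from
# ``SomeModule`` above) confirms that the entries were converted in place:

print(state_dict['linears.0.weight'].dtype)   # torch.bfloat16
print(state_dict['linears.0.weight'].device)  # cuda:0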

##################################################
# Using ``torch.device('meta')``
# ------------------------------
# Next, let's consider the creation of the module.

m = SomeModule(1000)

#######################################################################################################
# This allocates memory for all parameters/buffers and initializes them per
# the default initialization schemes defined in ``SomeModule.__init__()``, which
# is wasteful when we want to load a checkpoint for the following reasons:
#
# * The result of the initialization kernels will be overwritten by
#   ``load_state_dict()`` without ever being used, so initialization is wasteful.
# * We are allocating memory for these parameters/buffers in RAM while
#   ``torch.load`` of the saved state dictionary also allocates memory in RAM
#   for the parameters/buffers in the checkpoint.
#
# In order to solve these two problems, we can use the ``torch.device()``
# context manager with ``device='meta'`` when we instantiate the ``nn.Module()``.
#
# The `torch.device() <https://pytorch.org/docs/main/tensor_attributes.html#torch-device>`_
# context manager makes sure that factory calls will be performed as if they
# were passed the specified ``device`` as an argument. Tensors on ``torch.device('meta')`` do not
# carry data. However, they possess all other metadata a tensor carries, such as
# ``.size()``, ``.stride()``, ``.requires_grad``, and others.

with torch.device('meta'):
    new_m = SomeModule(1000)
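
#######################################################################################################
# A short check of what a ``meta`` tensor looks like (again using the
# ``linears`` naming from above): the parameter has a size, stride, and
# ``requires_grad`` flag, but it lives on the ``meta`` device and holds no data.

p = new_m.linears[0].weight
print(p.device)         # meta
print(p.size())         # torch.Size([1000, 1000])
print(p.requires_grad)  # True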

#############################################################################
# Using ``load_state_dict(assign=True)``
# --------------------------------------
# Next, we consider the loading of the state dictionary.

m.load_state_dict(state_dict)

######################################################################################
# ``nn.Module.load_state_dict()`` is usually implemented via an in-place
# ``param_in_model.copy_(param_in_state_dict)``. This means that the parameter/buffer
# with the corresponding key in the state dictionary is copied into the
# parameter/buffer in the ``nn.Module``.
#
# However, an in-place copy into a tensor on the ``meta`` device is a no-op.
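# For example:

t_meta = torch.empty(3, device='meta')
t_meta.copy_(torch.ones(3))  # a no-op: ``t_meta`` still carries no data
print(t_meta)

###############################################################################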
# In order to avoid this, we can pass the ``assign=True`` keyword argument to
# ``nn.Module.load_state_dict()``. One caveat is that, because ``assign=True``
# assigns the tensors from the state dictionary to the module rather than
# copying them into the module's preallocated tensors, the optimizer must be
# created after the module has been loaded with ``assign=True``:

new_m.load_state_dict(state_dict, assign=True)

opt = torch.optim.SGD(new_m.parameters(), lr=1e-3)

###############################################################################
# Conclusion
# ----------
#
# To recap, in this tutorial we learned about ``torch.load(mmap=True)``, the
# ``torch.device()`` context manager with ``device=meta``, and
# ``nn.Module.load_state_dict(assign=True)`` as well as how these tools could
156
165
# be used to aid when loading a model from a checkpoint.