
Commit ec59881

Address comments
1 parent 854ca11 commit ec59881

1 file changed: +41 -32

recipes_source/recipes/module_load_state_dict_tips.py

Lines changed: 41 additions & 32 deletions
@@ -15,8 +15,8 @@
"""


-########################################
-# Let us consider a simple ``nn.Module``
+###############################################################################
+# Let us consider a simple ``nn.Module`` that contains a list of Linear layers:
import torch
from torch import nn
import time
@@ -33,35 +33,41 @@ def forward(self, x):
m = SomeModule(1000)
torch.save(m.state_dict(), 'checkpoint.pth')

-#################################################################
-# The follow snippet demonstrates the use of the three utilities.
+##############################################################################
+# The following snippet demonstrates the use of the ``mmap`` keyword argument
+# to ``torch.load``, the ``torch.device()`` context manager and the ``assign``
+# keyword argument to ``nn.Module.load_state_dict()``.

state_dict = torch.load('checkpoint.pth', mmap=True)
with torch.device('meta'):
    meta_m = SomeModule(1000)
meta_m.load_state_dict(state_dict, assign=True)

#############################################################################
-# Taking a step back, let us inspect the following more vanilla code snippet
-# that does not use any of the features listed above:
+# Compare the snippet below to the one above:

state_dict = torch.load('checkpoint.pth')
m = SomeModule(1000)
m.load_state_dict(state_dict)

-#################################################################################
+# The second example does not use any of the features listed above and will be
+# less compute and memory efficient for loading a checkpoint. In the following
+# sections, we will discuss each of the features in further detail.
+
+#####################################################################################
# Using ``torch.load(mmap=True)``
# -------------------------------
-# First let us consider what happens when we ``torch.load`` the checkpoint.
-# At ``torch.save`` time, tensor storages are tagged with the device they are
-# saved on. At ``torch.load`` time, tensor storages will be loaded to the device
+# First, let us consider what happens when we load the checkpoint with ``torch.load``.
+# When we save a checkpoint with ``torch.save``, tensor storages are tagged with the device they are
+# saved on. With ``torch.load``, tensor storages will be loaded to the device
# they were tagged with (unless this behavior is overridden using the
# ``map_location`` flag). For ease of explanation, let us assume that the tensors
# were saved on CPU. This means that on the first line all tensor storages will be
-# loaded into CPU RAM, which can be undesirable when
-# 1. CPU RAM is smaller than the size of the checkpoint
-# 2. Waiting for the entire checkpoint to be loaded into RAM before
-#    doing for example some per-tensor processing
+# loaded into CPU RAM, which can be undesirable when:
+#
+# * CPU RAM is smaller than the size of the checkpoint.
+# * Waiting for the entire checkpoint to be loaded into RAM before
+#   performing, for example, some per-tensor processing.

start_time = time.time()
state_dict = torch.load('checkpoint.pth')
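As an aside, here is a minimal sketch, not part of this commit, of how the ``mmap`` argument described above can be combined with the ``map_location`` flag; it assumes the ``checkpoint.pth`` file saved earlier in the tutorial is present:

import time
import torch

start = time.time()
# Memory-map the checkpoint so tensor storages are paged in lazily rather than
# read into CPU RAM upfront, and override the tagged device so every storage
# lands on CPU.
state_dict = torch.load('checkpoint.pth', mmap=True, map_location='cpu')
print(f"load time with mmap: {time.time() - start:.4f} s")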
@@ -83,7 +89,7 @@ def forward(self, x):

######################################################################################
# As mentioned above, one can use this argument to do per-tensor processing on a
-# checkpoint without loading all tensor storages into CPU memory upfront. For example,
+# checkpoint without loading all tensor storages into CPU memory upfront. For example:
def my_special_routine(t, device):
    # this could be a much fancier operation
    return t.to(dtype=torch.bfloat16, device=device)
@@ -92,35 +98,35 @@ def my_processing_function(key, device):
    t = state_dict[key]
    processed_t = my_special_routine(t, device)
    del t
-    return processed_t
+    state_dict[key] = processed_t

for key in state_dict.keys():
-    device = torch.device('cuda:' + str(int(key.lstrip("linears.")[0]) % 8))
-    state_dict[key] = my_processing_function(key, device)
+    device = torch.device('cuda')
+    my_processing_function(key, device)

-##############################################
+##################################################
# Using ``torch.device('meta')``
# ------------------------------
-# Next, we consider the creation of the module.
+# Next, let's consider the creation of the module.
m = SomeModule(1000)

#######################################################################################################
# This allocates memory for all parameters/buffers and initializes them per
# the default initialization schemes defined in ``SomeModule.__init__()``, which
-# is wasteful when we want to load a checkpoint as
-# 1. The result of the initialization kernels will be overwritten by ``load_state_dict()``
-#    without ever being used, so initialization is wasteful.
-# 2. We are allocating memory for these parameters/buffers in RAM while ``torch.load`` of
-#    the saved state dictionary also allocates memory for the parameters/buffers in the checkpoint.
+# is wasteful when we want to load a checkpoint for the following reasons:
+# * The result of the initialization kernels will be overwritten by ``load_state_dict()``
+#   without ever being used, so initialization is wasteful.
+# * We are allocating memory for these parameters/buffers in RAM while ``torch.load`` of
+#   the saved state dictionary also allocates memory in RAM for the parameters/buffers in the checkpoint.
#
# In order to solve these two problems, we can use the ``torch.device()``
# context manager with ``device='meta'`` when we instantiate the ``nn.Module()``.
#
# The `torch.device() <https://pytorch.org/docs/main/tensor_attributes.html#torch-device>`_
# context manager makes sure that factory calls will be performed as if they
# were passed the specified ``device`` as an argument. Tensors on ``torch.device('meta')`` do not
-# carry data. However, they possess all other metadata a tensor carries such as ``.size()``, ``.stride()``
-# and ``.requires_grad`` etc.
+# carry data. However, they possess all other metadata a tensor carries such as ``.size()``, ``.stride()``,
+# ``.requires_grad``, and others.
with torch.device('meta'):
    new_m = SomeModule(1000)

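A small illustrative sketch, not part of this commit, of the point above that tensors on the ``meta`` device carry metadata but no data:

import torch

with torch.device('meta'):
    t = torch.empty(1000, 1000)

print(t.device)         # meta
print(t.size())         # torch.Size([1000, 1000])
print(t.stride())       # (1000, 1)
print(t.requires_grad)  # False
# There is no backing storage, so no memory is allocated for the values.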
@@ -131,11 +137,11 @@ def my_processing_function(key, device):

m.load_state_dict(state_dict)

-###############################################################################
+######################################################################################
# ``nn.Module.load_state_dict()`` is usually implemented via an in-place
-# ``param_in_model.copy_(param_in_state_dict)`` (i.e. a copy from the
-# parameter/buffer with the corresponding key in the state dictionary into
-# the parameter/buffer in the ``nn.Module``).
+# ``param_in_model.copy_(param_in_state_dict)``. This means that the parameter/buffer
+# with the corresponding key in the state dictionary is copied into the
+# parameter/buffer in the ``nn.Module``.
#
# However, an in-place copy into a tensor on the ``meta`` device is a no-op.
# In order to avoid this, we can pass the ``assign=True`` keyword argument to
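For reference, a condensed sketch, not part of this commit, of the ``assign=True`` flow the surrounding text describes; it reuses the tutorial's ``SomeModule`` class and ``checkpoint.pth`` file:

import torch

state_dict = torch.load('checkpoint.pth', mmap=True)

with torch.device('meta'):
    new_m = SomeModule(1000)   # SomeModule is defined earlier in the tutorial

# assign=True swaps the meta-device parameters/buffers for the tensors in the
# state dict instead of copying into them (a copy into meta would be a no-op).
new_m.load_state_dict(state_dict, assign=True)

# Since the parameter objects were replaced, create the optimizer afterwards so
# it holds references to the loaded parameters.
opt = torch.optim.SGD(new_m.parameters(), lr=1e-3)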
@@ -150,7 +156,10 @@ def my_processing_function(key, device):
opt = torch.optim.SGD(new_m.parameters(), lr=1e-3)

###############################################################################
+# Conclusion
+# -------------
+#
# To recap, in this tutorial we learned about ``torch.load(mmap=True)``, the
-# ``torch.device()`` context manager with ``device=meta`` and
+# ``torch.device()`` context manager with ``device=meta``, and
# ``nn.Module.load_state_dict(assign=True)`` as well as how these tools could
# be used to aid when loading a model from a checkpoint.

0 commit comments