
Commit 309c889

mikaylagawarecki and Svetlana Karslioglu authored
Add Loading nn.Module from checkpoint tutorial (non ghstack) (#2579)
* Add Module load_state_dict tutorial

Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
1 parent f381abf commit 309c889

File tree

2 files changed: +173 -0

recipes_source/recipes/module_load_state_dict_tips.py

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
"""
Tips for Loading an ``nn.Module`` from a Checkpoint
===================================================

If you're loading a checkpoint and want to reduce compute and memory as much as possible,
this tutorial shares some recommended practices. In particular, we will discuss:

1. The ``mmap`` keyword argument on ``torch.load``
2. The ``torch.device()`` context manager
3. The ``assign`` keyword argument on ``nn.Module.load_state_dict()``

.. note::
   This recipe requires PyTorch 2.1.0 or later.
"""

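###############################################################################
# A minimal sketch of a version guard (assuming ``torch.__version__`` compares
# as a ``TorchVersion``): fail early on releases older than 2.1.0, since the
# ``mmap`` and ``assign`` arguments used below require it.
import torch

assert torch.__version__ >= "2.1.0", "This recipe requires PyTorch 2.1.0 or later"
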
###############################################################################
# Let us consider a simple ``nn.Module`` that contains a list of Linear layers:
import torch
from torch import nn
import time

class SomeModule(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(size, size) for i in range(10)])

    def forward(self, x):
        # ``nn.ModuleList`` is not callable, so apply each layer in turn
        for linear in self.linears:
            x = linear(x)
        return x


m = SomeModule(1000)
torch.save(m.state_dict(), 'checkpoint.pth')

#################################################################################
# The following snippet demonstrates the use of the ``mmap`` keyword argument
# to ``torch.load``, the ``torch.device()`` context manager, and the ``assign``
# keyword argument to ``nn.Module.load_state_dict()``.

state_dict = torch.load('checkpoint.pth', mmap=True)
with torch.device('meta'):
    meta_m = SomeModule(1000)
meta_m.load_state_dict(state_dict, assign=True)

#############################################################################
# Compare the snippet below to the one above:

state_dict = torch.load('checkpoint.pth')
m = SomeModule(1000)
m.load_state_dict(state_dict)

#############################################################################
# The second example does not use any of the features listed above and will be
# less compute- and memory-efficient when loading a checkpoint. In the
# following sections, we discuss each of the features in further detail.

#####################################################################################
# Using ``torch.load(mmap=True)``
# -------------------------------
# First, let us consider what happens when we load the checkpoint with ``torch.load``.
# When we save a checkpoint with ``torch.save``, tensor storages are tagged with the device they are
# saved on. With ``torch.load``, tensor storages will be loaded to the device
# they were tagged with (unless this behavior is overridden using the
# ``map_location`` argument; see the short sketch after the timing example below).
# For ease of explanation, let us assume that the tensors were saved on CPU.
# This means that on the first line all tensor storages will be
# loaded into CPU RAM, which can be undesirable when:
#
# * CPU RAM is smaller than the size of the checkpoint.
# * We need to wait for the entire checkpoint to be loaded into RAM before performing, for example, some per-tensor processing.

start_time = time.time()
state_dict = torch.load('checkpoint.pth')
end_time = time.time()
print(f"loading time without mmap={end_time - start_time}")

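#####################################################################################
# As an aside on ``map_location``: a minimal sketch (reusing the ``checkpoint.pth``
# saved above) that remaps every tensor storage onto CPU at load time, regardless
# of the device it was tagged with when saved:

# ``map_location`` accepts a device string, a ``torch.device``, or a function
cpu_state_dict = torch.load('checkpoint.pth', map_location='cpu')
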
#################################################################################
# The ``mmap`` keyword argument to ``torch.load`` attempts to solve the above two
# problems. As its name implies, it makes use of an
# `mmap call <https://man7.org/linux/man-pages/man2/mmap.2.html>`_,
# which maps a file on disk into virtual memory and lets the OS handle loading and
# unloading into physical memory automatically. When this argument is passed,
# tensor storages will be memory-mapped.

start_time = time.time()
state_dict = torch.load('checkpoint.pth', mmap=True)
end_time = time.time()
print(f"loading time with mmap={end_time - start_time}")

######################################################################################
# As mentioned above, one can use this argument to do per-tensor processing on a
# checkpoint without loading all tensor storages into CPU memory upfront. For example:
def my_special_routine(t, device):
    # this could be a much fancier operation
    return t.to(dtype=torch.bfloat16, device=device)

def my_processing_function(key, device):
    t = state_dict[key]
    processed_t = my_special_routine(t, device)
    del t
    state_dict[key] = processed_t

# fall back to CPU so the example also runs on machines without a GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for key in state_dict.keys():
    my_processing_function(key, device)

##################################################
# Using ``torch.device('meta')``
# ------------------------------
# Next, let's consider the creation of the module.
m = SomeModule(1000)

#######################################################################################################
# This allocates memory for all parameters/buffers and initializes them per
# the default initialization schemes defined in ``SomeModule.__init__()``, which
# is wasteful when we want to load a checkpoint for the following reasons:
#
# * The result of the initialization kernels will be overwritten by ``load_state_dict()`` without ever being used, so
#   initialization is wasteful.
# * We are allocating memory for these parameters/buffers in RAM, while ``torch.load`` of the saved state dictionary also
#   allocates memory in RAM for the parameters/buffers in the checkpoint.
#
# In order to solve these two problems, we can use the ``torch.device()``
# context manager with ``device='meta'`` when we instantiate the ``nn.Module``.
#
# The `torch.device() <https://pytorch.org/docs/main/tensor_attributes.html#torch-device>`_
# context manager makes sure that factory calls will be performed as if they
# were passed the specified ``device`` as an argument. Tensors on ``torch.device('meta')`` do not
# carry data. However, they possess all other metadata a tensor carries, such as ``.size()``, ``.stride()``,
# ``.requires_grad``, and others.
with torch.device('meta'):
    new_m = SomeModule(1000)

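#######################################################################################################
# As a quick sanity check (a minimal sketch of what to expect), the parameters
# of ``new_m`` live on the ``meta`` device: they carry metadata such as shape,
# but no actual data.

print(next(new_m.parameters()).device)  # meta
print(next(new_m.parameters()).size())  # torch.Size([1000, 1000])
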
########################################################
# Using ``load_state_dict(assign=True)``
# --------------------------------------
# Next, we consider the loading of the state dictionary.

m.load_state_dict(state_dict)

######################################################################################
# ``nn.Module.load_state_dict()`` is usually implemented via an in-place
# ``param_in_model.copy_(param_in_state_dict)``. This means that the parameter/buffer
# with the corresponding key in the state dictionary is copied into the
# parameter/buffer in the ``nn.Module``.
#
# However, an in-place copy into a tensor on the ``meta`` device is a no-op.
# In order to avoid this, we can pass the ``assign=True`` keyword argument to
# ``load_state_dict()``.
#
# A caveat here is that since optimizers hold a reference to
# ``nn.Module.parameters()``, the optimizer must be initialized after the module
# is loaded from the state dict if ``assign=True`` is passed.
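#
# To make the distinction concrete, here is a minimal sketch of the two
# behaviors (not the actual implementation; ``module_param`` and
# ``checkpoint_param`` are hypothetical names for corresponding entries in the
# module and the state dictionary):
#
# .. code-block:: python
#
#    # assign=False (default): copy checkpoint data into the module's
#    # existing tensor in place, keeping that tensor's properties
#    module_param.detach().copy_(checkpoint_param)
#
#    # assign=True: replace the module's tensor with the checkpoint's tensor
#    module.param = torch.nn.Parameter(checkpoint_param)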

new_m.load_state_dict(state_dict, assign=True)
# This MUST be done AFTER the load_state_dict with assign.
opt = torch.optim.SGD(new_m.parameters(), lr=1e-3)

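######################################################################################
# As a quick check (a minimal sketch; the exact device and dtype depend on the
# per-tensor processing done earlier), the parameters of ``new_m`` now come from
# the loaded state dictionary, so they are no longer on the ``meta`` device:

print(next(new_m.parameters()).device)  # no longer meta
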
###############################################################################
# Conclusion
# ----------
#
# To recap, in this tutorial we learned about ``torch.load(mmap=True)``, the
# ``torch.device()`` context manager with ``device='meta'``, and
# ``nn.Module.load_state_dict(assign=True)``, as well as how these tools can
# help when loading a model from a checkpoint.

recipes_source/recipes_index.rst

Lines changed: 7 additions & 0 deletions

@@ -137,6 +137,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/recipes/reasoning_about_shapes.html
    :tags: Basics

+.. customcarditem::
+   :header: Tips for Loading an nn.Module from a Checkpoint
+   :card_description: Learn tips for loading an nn.Module from a checkpoint.
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/recipes/module_load_state_dict_tips.html
+   :tags: Basics

.. Interpretability

.. customcarditem::