
Commit 3d6454d

Bullet point rendering fixes
1 parent 65946fe commit 3d6454d

File tree

1 file changed: +8, -7 lines


recipes_source/recipes/module_load_state_dict_tips.py

Lines changed: 8 additions & 7 deletions
@@ -50,14 +50,15 @@ def forward(self, x):
 m = SomeModule(1000)
 m.load_state_dict(state_dict)
 
+#############################################################################
 # The second example does not use any of the features listed above and will be
 # less compute and memory efficient for loading a checkpoint. In the following
 # sections, we will discuss each of the features in further detail.
 
 #####################################################################################
 # Using ``torch.load(mmap=True)``
 # -------------------------------
-# First, let us consider what happens when we load the checkpoint with``torch.load``.
+# First, let us consider what happens when we load the checkpoint with ``torch.load``.
 # When we save a checkpoint with ``torch.save``, tensor storages are tagged with the device they are
 # saved on. With ``torch.load``, tensor storages will be loaded to the device
 # they were tagged with (unless this behavior is overridden using the
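
The comment block touched by this hunk describes how ``torch.save`` tags each tensor storage with the device it was saved from, and how a plain ``torch.load`` brings the storages back onto that tagged device. A minimal sketch of that behavior, reusing the tutorial's ``checkpoint.pth`` filename and treating ``map_location`` as the override the truncated comment refers to:

import torch
import torch.nn as nn

# Saving tags each tensor storage with its current device (CPU here).
m = nn.Linear(10, 10)
torch.save(m.state_dict(), 'checkpoint.pth')

# A plain load materializes every storage on the device it was tagged with.
state_dict = torch.load('checkpoint.pth')

# map_location overrides the tag, e.g. forcing everything onto the CPU
# regardless of where the checkpoint was saved.
state_dict_cpu = torch.load('checkpoint.pth', map_location='cpu')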
@@ -66,8 +67,7 @@ def forward(self, x):
 # loaded into CPU RAM, which can be undesirable when:
 #
 # * CPU RAM is smaller than the size of the checkpoint.
-# * Waiting for the entire checkpoint to be loaded into RAM before
-#   performing, for example, some per-tensor processing.
+# * Waiting for the entire checkpoint to be loaded into RAM before performing, for example, some per-tensor processing.
 
 start_time = time.time()
 state_dict = torch.load('checkpoint.pth')
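
The rewrapped bullet in this hunk lists when fully materializing the checkpoint in CPU RAM is undesirable, and the section it belongs to is titled ``torch.load(mmap=True)``. A rough sketch of the comparison that title implies, reusing the ``checkpoint.pth`` and timing code visible in the context lines (the ``mmap=True`` call is my addition here, not part of this diff):

import time
import torch

# Baseline: read the whole checkpoint into CPU RAM up front.
start_time = time.time()
state_dict = torch.load('checkpoint.pth')
print(f"regular load: {time.time() - start_time:.3f}s")

# With mmap=True the file is memory-mapped, so tensor storages are paged in
# lazily rather than all being copied into RAM before load() returns.
start_time = time.time()
state_dict_mmap = torch.load('checkpoint.pth', mmap=True)
print(f"mmap load: {time.time() - start_time:.3f}s")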
@@ -114,10 +114,11 @@ def my_processing_function(key, device):
 # This allocates memory for all parameters/buffers and initializes them per
 # the default initialization schemes defined in ``SomeModule.__init__()``, which
 # is wasteful when we want to load a checkpoint for the following reasons:
-# * The result of the initialization kernels will be overwritten by ``load_state_dict()``
-#   without ever being used, so initialization is wasteful.
-# * We are allocating memory for these parameters/buffers in RAM while ``torch.load`` of
-#   the saved state dictionary also allocates memory in RAM for the parameters/buffers in the checkpoint.
+#
+# * The result of the initialization kernels will be overwritten by ``load_state_dict()`` without ever being used, so
+#   initialization is wasteful.
+# * We are allocating memory for these parameters/buffers in RAM while ``torch.load`` of the saved state dictionary also
+#   allocates memory in RAM for the parameters/buffers in the checkpoint.
 #
 # In order to solve these two problems, we can use the ``torch.device()``
 # context manager with ``device='meta'`` when we instantiate the ``nn.Module()``.
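
The reflowed bullets in this hunk explain why instantiating the module normally and then loading a checkpoint wastes both initialization work and RAM, and the following comment points to ``torch.device('meta')`` as the fix. A sketch of how that context manager is typically combined with the memory-mapped load (the ``assign=True`` keyword is my assumption about the follow-up step; it is not shown in this diff):

import torch
import torch.nn as nn

class SomeModule(nn.Module):
    # Stand-in with the same name as the tutorial's module; the real
    # definition appears earlier in module_load_state_dict_tips.py.
    def __init__(self, size):
        super().__init__()
        self.linear = nn.Linear(size, size)

    def forward(self, x):
        return self.linear(x)

# Under the meta device no real memory is allocated and the default
# initialization kernels never run, which addresses both bullets above.
with torch.device('meta'):
    meta_m = SomeModule(1000)

# Memory-map the checkpoint so only the checkpoint's tensors occupy RAM.
state_dict = torch.load('checkpoint.pth', mmap=True)

# assign=True swaps the loaded tensors into the module instead of copying
# them into the (unmaterialized) meta parameters.
meta_m.load_state_dict(state_dict, assign=True)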
