@@ -50,14 +50,15 @@ def forward(self, x):
m = SomeModule(1000)
m.load_state_dict(state_dict)

+ #############################################################################
# The second example does not use any of the features listed above and will be
# less compute and memory efficient for loading a checkpoint. In the following
# sections, we will discuss each of the features in further detail.

#####################################################################################
# Using ``torch.load(mmap=True)``
# -------------------------------
- # First, let us consider what happens when we load the checkpoint with``torch.load``.
+ # First, let us consider what happens when we load the checkpoint with ``torch.load``.
# When we save a checkpoint with ``torch.save``, tensor storages are tagged with the device they are
# saved on. With ``torch.load``, tensor storages will be loaded to the device
# they were tagged with (unless this behavior is overridden using the
@@ -66,8 +67,7 @@ def forward(self, x):
# loaded into CPU RAM, which can be undesirable when:
#
# * CPU RAM is smaller than the size of the checkpoint.
- # * Waiting for the entire checkpoint to be loaded into RAM before
- #   performing, for example, some per-tensor processing.
+ # * Waiting for the entire checkpoint to be loaded into RAM before performing, for example, some per-tensor processing.

start_time = time.time()
state_dict = torch.load('checkpoint.pth')
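
# A minimal sketch of the ``torch.load(mmap=True)`` variant named in the section
# title above, assuming the same ``checkpoint.pth`` saved earlier in the tutorial;
# ``state_dict_mmap`` and the per-tensor loop below are illustrative only.

start_time = time.time()
# With ``mmap=True``, tensor storages stay backed by the checkpoint file on disk
# and are paged into RAM lazily as each tensor is accessed, instead of being read
# in full up front.
state_dict_mmap = torch.load('checkpoint.pth', mmap=True)
print(f"loading with mmap=True took {time.time() - start_time} seconds")

# Per-tensor processing can start immediately, one tensor at a time, without
# waiting for the entire checkpoint to be materialized in CPU RAM.
for key, tensor in state_dict_mmap.items():
    _ = tensor.sum()  # placeholder for real per-tensor work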
@@ -114,10 +114,11 @@ def my_processing_function(key, device):
# This allocates memory for all parameters/buffers and initializes them per
# the default initialization schemes defined in ``SomeModule.__init__()``, which
# is wasteful when we want to load a checkpoint for the following reasons:
- # * The result of the initialization kernels will be overwritten by ``load_state_dict()``
- #   without ever being used, so initialization is wasteful.
- # * We are allocating memory for these parameters/buffers in RAM while ``torch.load`` of
- #   the saved state dictionary also allocates memory in RAM for the parameters/buffers in the checkpoint.
+ #
+ # * The result of the initialization kernels will be overwritten by ``load_state_dict()`` without ever being used, so
+ #   initialization is wasteful.
+ # * We are allocating memory for these parameters/buffers in RAM while ``torch.load`` of the saved state dictionary also
+ #   allocates memory in RAM for the parameters/buffers in the checkpoint.
#
# In order to solve these two problems, we can use the ``torch.device()``
# context manager with ``device='meta'`` when we instantiate the ``nn.Module()``.
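
# A minimal sketch of the ``device='meta'`` approach described above; ``meta_m`` is
# an illustrative name, and the ``assign=True`` keyword is an assumption not shown
# in this diff (it makes ``load_state_dict`` swap in the checkpoint tensors instead
# of copying into the uninitialized meta-device parameters).

with torch.device('meta'):
    meta_m = SomeModule(1000)  # parameters/buffers are created on the meta device,
                               # so no initialization kernels run and no RAM is
                               # allocated for their data

meta_m.load_state_dict(state_dict, assign=True)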