"""
Tips for Loading an ``nn.Module`` from a Checkpoint
===================================================

If you're loading a checkpoint and want to reduce compute and memory as much as possible,
this tutorial shares some recommended practices. In particular, we will discuss

1. The ``mmap`` keyword argument on ``torch.load``
2. The ``torch.device()`` context manager
3. The ``assign`` keyword argument on ``nn.Module.load_state_dict()``

.. note::
   This recipe requires PyTorch 2.1.0 or later.
"""


###############################################################################
# Let us consider a simple ``nn.Module`` that contains a list of Linear layers:
import time

import torch
from torch import nn


class SomeModule(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(size, size) for _ in range(10)])

    def forward(self, x):
        # ``nn.ModuleList`` is not callable, so apply each layer in turn.
        for linear in self.linears:
            x = linear(x)
        return x


m = SomeModule(1000)
torch.save(m.state_dict(), 'checkpoint.pth')

#################################################################################
# The following snippet demonstrates the use of the ``mmap`` keyword argument
# to ``torch.load``, the ``torch.device()`` context manager, and the ``assign``
# keyword argument to ``nn.Module.load_state_dict()``.

state_dict = torch.load('checkpoint.pth', mmap=True)
with torch.device('meta'):
    meta_m = SomeModule(1000)
meta_m.load_state_dict(state_dict, assign=True)

#############################################################################
# Compare the snippet below to the one above:

state_dict = torch.load('checkpoint.pth')
m = SomeModule(1000)
m.load_state_dict(state_dict)

#############################################################################
# The second example does not use any of the features listed above and will be
# less compute- and memory-efficient when loading a checkpoint. In the following
# sections, we discuss each of the features in further detail.

#####################################################################################
# Using ``torch.load(mmap=True)``
# -------------------------------
# First, let us consider what happens when we load the checkpoint with ``torch.load``.
# When we save a checkpoint with ``torch.save``, tensor storages are tagged with the device they are
# saved on. With ``torch.load``, tensor storages will be loaded to the device
# they were tagged with (unless this behavior is overridden using the
# ``map_location`` flag, as sketched below). For ease of explanation, let us assume that the tensors
# were saved on CPU. This means that on the first line all tensor storages will be
# loaded into CPU RAM, which can be undesirable when:
#
# * CPU RAM is smaller than the size of the checkpoint.
# * You want to perform some per-tensor processing without waiting for the entire checkpoint to be loaded into RAM first.
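
###############################################################################
# As an aside on the ``map_location`` flag mentioned above, here is a minimal
# sketch of overriding the tagged device, forcing every tensor storage onto
# CPU regardless of where it was saved (the variable name is ours):

state_dict_on_cpu = torch.load('checkpoint.pth', map_location='cpu')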

start_time = time.time()
state_dict = torch.load('checkpoint.pth')
end_time = time.time()
print(f"loading time without mmap={end_time - start_time}")

#################################################################################
# The ``mmap`` keyword argument to ``torch.load`` attempts to solve the above two
# problems. As its name implies, it makes use of an
# `mmap call <https://man7.org/linux/man-pages/man2/mmap.2.html>`_,
# which maps a file on disk into virtual memory and lets the OS handle loading and
# unloading into physical memory automatically. When this flag is passed, tensor
# storages will be memory-mapped.

start_time = time.time()
state_dict = torch.load('checkpoint.pth', mmap=True)
end_time = time.time()
print(f"loading time with mmap={end_time - start_time}")

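###############################################################################
# As a quick sanity check, the memory-mapped state dictionary holds the same
# values as an eagerly loaded one. This sketch assumes the parameter key
# ``linears.0.weight`` produced by ``SomeModule`` above:

assert torch.equal(state_dict['linears.0.weight'], m.state_dict()['linears.0.weight'])
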
######################################################################################
# As mentioned above, one can use this argument to do per-tensor processing on a
# checkpoint without loading all tensor storages into CPU memory upfront. For example:
def my_special_routine(t, device):
    # this could be a much fancier operation
    return t.to(dtype=torch.bfloat16, device=device)

def my_processing_function(key, device):
    t = state_dict[key]
    processed_t = my_special_routine(t, device)
    del t
    state_dict[key] = processed_t

# Fall back to CPU if no CUDA device is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for key in state_dict.keys():
    my_processing_function(key, device)

##################################################
# Using ``torch.device('meta')``
# ------------------------------
# Next, let's consider the creation of the module.
m = SomeModule(1000)

#######################################################################################################
# This allocates memory for all parameters/buffers and initializes them per
# the default initialization schemes defined in ``SomeModule.__init__()``, which
# is wasteful when we want to load a checkpoint for the following reasons:
#
# * The result of the initialization kernels will be overwritten by ``load_state_dict()`` without ever being used, so
#   initialization is wasteful.
# * We are allocating memory for these parameters/buffers in RAM, while ``torch.load`` of the saved state dictionary also
#   allocates memory in RAM for the parameters/buffers in the checkpoint.
#
# In order to solve these two problems, we can use the ``torch.device()``
# context manager with ``device='meta'`` when we instantiate the ``nn.Module()``.
#
# The `torch.device() <https://pytorch.org/docs/main/tensor_attributes.html#torch-device>`_
# context manager makes sure that factory calls will be performed as if they
# were passed the specified ``device`` as an argument. Tensors on ``torch.device('meta')`` do not
# carry data. However, they possess all other metadata a tensor carries, such as ``.size()``, ``.stride()``,
# and ``.requires_grad``.
with torch.device('meta'):
    new_m = SomeModule(1000)
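
###############################################################################
# As a quick check of the behavior described above, the parameters of ``new_m``
# live on the ``meta`` device: they have sizes and strides but no backing data.

print(new_m.linears[0].weight.device)  # device(type='meta')
print(new_m.linears[0].weight.size())  # torch.Size([1000, 1000])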

########################################################
# Using ``load_state_dict(assign=True)``
# --------------------------------------
# Next, we consider the loading of the state dictionary.

m.load_state_dict(state_dict)

######################################################################################
# ``nn.Module.load_state_dict()`` is usually implemented via an in-place
# ``param_in_model.copy_(param_in_state_dict)``. This means that the parameter/buffer
# with the corresponding key in the state dictionary is copied into the
# parameter/buffer in the ``nn.Module``.
#
# However, an in-place copy into a tensor on the ``meta`` device is a no-op.
# In order to avoid this, we can pass the ``assign=True`` keyword argument to
# ``load_state_dict()``.
#
# A caveat here is that since optimizers hold a reference to
# ``nn.Module.parameters()``, the optimizer must be initialized after the module
# is loaded from the state dict when ``assign=True`` is passed.

new_m.load_state_dict(state_dict, assign=True)
# This MUST be done AFTER the load_state_dict with assign.
opt = torch.optim.SGD(new_m.parameters(), lr=1e-3)
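
###############################################################################
# After loading with ``assign=True``, the parameters of ``new_m`` are no longer
# on the ``meta`` device; they now reference the tensors that were loaded from
# the checkpoint. A quick check of this:

print(new_m.linears[0].weight.device)  # no longer device(type='meta')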

###############################################################################
# Conclusion
# ----------
#
# To recap, in this tutorial we learned about ``torch.load(mmap=True)``, the
# ``torch.device()`` context manager with ``device='meta'``, and
# ``nn.Module.load_state_dict(assign=True)``, as well as how these tools can
# be used to aid in loading a model from a checkpoint.