|
41 | 41 | serialization_config.save.storage_alignment = 4096
|
42 | 42 |
|
43 | 43 | ################################################################################
|
| 44 | +# The steps involved in the process are as follows: |
| 45 | +# * Write the checkpoint file without any actual data. This reserves the space on disk. |
| 46 | +# * Read the offsets for the storage associated with each tensor in the checkpoint using ``FakeTensor``. |
| 47 | +# * Use ``GDSFile`` to write the appropriate data at these offsets. |
| 48 | +# |
44 | 49 | # Given a state dictionary of tensors that are on the GPU, one can use the ``torch.serialization.skip_data`` context
|
45 | 50 | # manager to save a checkpoint that contains all relevant metadata except the storage bytes. For each ``torch.Storage``
|
46 | 51 | # in the state dictionary, space will be reserved within the checkpoint for the storage bytes.
|
|
59 | 64 | # information about the tensor but does not have any storage bytes. The following snippet will not materialize
|
60 | 65 | # any data but will tag each ``FakeTensor`` with the offset within the checkpoint that
|
61 | 66 | # corresponds to the tensor.
|
| 67 | +# |
| 68 | +# If you are continuously saving the same state dictionary during training, you |
| 69 | +# would only need to obtain the offsets once and the same offsets can be re-used. Similarly, if a tensor is |
| 70 | +# going to be loaded repeatedly, one can use the ``torch.cuda.gds.gds_register_buffer`` API, which wraps |
| 71 | +# ``cuFileBufRegister`` to register the storages as GDS buffers. |
| 72 | + |
62 | 73 |
|
63 | 74 | import os
|
64 | 75 | from torch._subclasses.fake_tensor import FakeTensorMode
|
|
73 | 84 |
|
74 | 85 | for k, v in sd.items():
|
75 | 86 | offset = fake_sd[k].untyped_storage()._checkpoint_offset
|
| 87 | + # save_storage is a wrapper around `cuFileWrite` |
76 | 88 | f.save_storage(v.untyped_storage(), offset)
|
77 | 89 |
|
| 90 | + |
78 | 91 | ################################################################################
|
79 | 92 | # We verify the correctness of the saved checkpoint by loading it with ``torch.load`` and comparing the tensors.
|
80 | 93 |
|
|
85 | 98 | ################################################################################
|
86 | 99 | # The loading flow is the inverse, we can ``torch.load`` under the ``torch.serialization.skip_data`` context
|
87 | 100 | # manager to load everything except the storage bytes. This means that any tensors in the checkpoint will be
|
88 |
| -# created but their storages will be empty (i.e. the tensors will be created via ``torch.empty``). If the |
89 |
| -# tensors to be loaded to are persistent, one can use the ``torch.cuda.gds.gds_register_buffer`` API to register |
90 |
| -# the storages as gds buffers. |
| 101 | +# created but their storages will be empty (i.e. the tensors will be created via ``torch.empty``). |
91 | 102 |
|
92 | 103 | with torch.serialization.skip_data():
|
93 | 104 | sd_loaded = torch.load("checkpoint.pt")
|
|
99 | 110 | for k, v in sd_loaded.items():
|
100 | 111 | assert not torch.equal(v, sd[k])
|
101 | 112 | offset = fake_sd[k].untyped_storage()._checkpoint_offset
|
| 113 | + # load_storage is a wrapper around `cuFileRead` |
102 | 114 | f.load_storage(v.untyped_storage(), offset)
|
| 115 | + |
| 116 | +for k, v in sd_loaded.items(): |
103 | 117 | assert torch.equal(v, sd[k])
|
104 | 118 |
|
105 | 119 | del f
|
106 | 120 |
|
107 |
| - |
108 |
| -################################################################################ |
109 |
| -# Buffer Registration |
110 |
| -# =================== |
111 |
| -# We also provide ``torch.cuda.gds.gds_register_buffer`` to register the |
112 |
| -# tensor storages as GPUDirect Storage buffers. See `here |
113 |
| -# <https://docs.nvidia.com/gpudirect-storage/best-practices-guide/index.html#cufile-bufregister-fileread-filewrite>`_ |
114 |
| -# for when one should do this. |
115 |
| - |
116 |
| -for v in sd.values(): |
117 |
| - torch.cuda.gds.gds_register_buffer(v.untyped_storage()) |
118 |
| - |
119 | 121 | # Summary
|
120 | 122 | # =======
|
121 | 123 | #
|
|
0 commit comments