Commit ba7f4e8

address comments
1 parent 255da51 commit ba7f4e8

1 file changed: +17 -15 lines changed

prototype_source/gpu_direct_storage.py

Lines changed: 17 additions & 15 deletions
@@ -41,6 +41,11 @@
 serialization_config.save.storage_alignment = 4096

 ################################################################################
+# The steps involved in the process are as follows:
+# * Write the checkpoint file without any actual data. This reserves the space on disk.
+# * Read the offsets for the storage associated with each tensor in the checkpoint using ``FakeTensor``.
+# * Use ``GDSFile`` to write the appropriate data at these offsets.
+#
 # Given a state dictionary of tensors that are on the GPU, one can use the ``torch.serialization.skip_data`` context
 # manager to save a checkpoint that contains all relevant metadata except the storage bytes. For each ``torch.Storage``
 # in the state dictionary, space will be reserved within the checkpoint for the storage bytes.
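
As a minimal sketch of this first save step (the state dictionary here is illustrative, not from the tutorial):

import torch

sd = {"weight": torch.randn(128, 128, device="cuda")}

# Reserves space in the checkpoint for each storage's bytes
# without actually writing them.
with torch.serialization.skip_data():
    torch.save(sd, "checkpoint.pt")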
@@ -59,6 +64,12 @@
 # information about the tensor but does not have any storage bytes. The following snippet will not materialize
 # any data but will tag each ``FakeTensor`` with the offset within the checkpoint that
 # corresponds to the tensor.
+#
+# If you are continuously saving the same state dictionary during training, you
+# would only need to obtain the offsets once; the same offsets can be re-used. Similarly, if a tensor is going
+# to be loaded into repeatedly, one can use ``torch.cuda.gds.gds_register_buffer``, which wraps
+# ``cuFileBufRegister``, to register the storages as GDS buffers.
+

 import os
 from torch._subclasses.fake_tensor import FakeTensorMode
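
The offset-reading snippet that this comment introduces looks roughly like the following (a sketch; ``_checkpoint_offset`` is the private attribute named in the surrounding code):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Loading under FakeTensorMode materializes no storage bytes; each
# FakeTensor is tagged with its storage's offset within the checkpoint.
with FakeTensorMode():
    fake_sd = torch.load("checkpoint.pt")

for k, v in fake_sd.items():
    print(f"key={k}, offset={v.untyped_storage()._checkpoint_offset}")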
@@ -73,8 +84,10 @@

 for k, v in sd.items():
     offset = fake_sd[k].untyped_storage()._checkpoint_offset
+    # save_storage is a wrapper around `cuFileWrite`
     f.save_storage(v.untyped_storage(), offset)

+
 ################################################################################
 # We verify correctness of the saved checkpoint by ``torch.load`` and comparing.

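
The verification alluded to in this context is, in sketch form, an ordinary load-and-compare:

# A plain torch.load materializes the bytes written via GDS, so the
# round-tripped tensors should match the originals exactly.
sd_loaded = torch.load("checkpoint.pt")
for k, v in sd_loaded.items():
    assert torch.equal(v, sd[k])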
@@ -85,9 +98,7 @@
 ################################################################################
 # The loading flow is the inverse: we can ``torch.load`` under the ``torch.serialization.skip_data`` context
 # manager to load everything except the storage bytes. This means that any tensors in the checkpoint will be
-# created but their storages will be empty (i.e. the tensors will be created via ``torch.empty``). If the
-# tensors to be loaded to are persistent, one can use the ``torch.cuda.gds.gds_register_buffer`` API to register
-# the storages as gds buffers.
+# created but their storages will be empty (i.e. the tensors will be created via ``torch.empty``).

 with torch.serialization.skip_data():
     sd_loaded = torch.load("checkpoint.pt")
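
The ``f`` used in these hunks is a GDS file handle opened earlier in the tutorial; creating one looks roughly like this (a sketch; the class is exposed as ``torch.cuda.gds.GdsFile``, which the prose above spells ``GDSFile``):

import os
import torch

# Opens the checkpoint for GPUDirect Storage reads and writes.
f = torch.cuda.gds.GdsFile("checkpoint.pt", os.O_RDWR)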
@@ -99,23 +110,14 @@
 for k, v in sd_loaded.items():
     assert not torch.equal(v, sd[k])
     offset = fake_sd[k].untyped_storage()._checkpoint_offset
+    # load_storage is a wrapper around `cuFileRead`
     f.load_storage(v.untyped_storage(), offset)
+
+for k, v in sd_loaded.items():
     assert torch.equal(v, sd[k])

 del f

-
-################################################################################
-# Buffer Registration
-# ===================
-# We also provide ``torch.cuda.gds.gds_register_buffer`` to register the
-# tensor storages as GPUDirect Storage buffers. See `here
-# <https://docs.nvidia.com/gpudirect-storage/best-practices-guide/index.html#cufile-bufregister-fileread-filewrite>`_
-# for when one should do this.
-
-for v in sd.values():
-    torch.cuda.gds.gds_register_buffer(v.untyped_storage())
-
 # Summary
 # =======
 #
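
For the repeated-load case mentioned in the comment this commit adds, registration mirrors the snippet the commit removes; as a sketch (``sd`` is the state dictionary from above):

# Wraps cuFileBufRegister; register once, then re-use the storages as
# targets of repeated GDS reads or writes.
for v in sd.values():
    torch.cuda.gds.gds_register_buffer(v.untyped_storage())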
