Commit c4c45c6

address comments

1 parent a5f98f1 commit c4c45c6

1 file changed: 23 additions, 17 deletions

prototype_source/gpu_direct_storage.py
@@ -1,13 +1,13 @@
 """
-(prototype) Using GPUDirect Storage
-====================================
+(prototype) Accelerating ``torch.save`` and ``torch.load`` with GPUDirect Storage
+=================================================================================

-GPUDirect Storage enabes a direct data path for direct memeory access transfers
+GPUDirect Storage enables a direct data path for direct memory access transfers
 between GPU memory and storage, avoiding a bounce buffer through the CPU.

-In version ``2.7``, we introduced some prototype APIs to ``torch.cuda.gds`` that serve as thin wrappers around
+In version **2.7**, we introduced new prototype APIs to ``torch.cuda.gds`` that serve as thin wrappers around
 the `cuFile APIs <https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api>`_
-that can be used with ``torch.Tensor``.
+that can be used with ``torch.Tensor`` to achieve improved I/O performance.

 In this tutorial, we will demonstrate how to use the ``torch.cuda.gds`` APIs in conjunction with
 checkpoints generated by ``torch.save`` and ``torch.load`` on local filesystem.
@@ -32,8 +32,8 @@
 ################################################################################
 # Using GPUDirect Storage with ``torch.save`` and ``torch.load``
 # =============================================================
-# GPUDirect Storage requires a storage alignment of 4KB. One can toggle this using
-# ``torch.utils.serialization.config.save.storage_alignment`` to toggle this
+# GPUDirect Storage requires a storage alignment of 4KB. You can toggle this by using
+# ``torch.utils.serialization.config.save.storage_alignment``:

 import torch
 from torch.utils.serialization import config as serialization_config
@@ -60,15 +60,18 @@

 ################################################################################
 # We can get the offsets that each storage should be written to within the checkpoint by loading under
-# a ``FakeTensorMode``. A FakeTensor is a tensor that has metadata (e.g. sizes, strides, dtype, device)
+# a ``FakeTensorMode``. A FakeTensor is a tensor that has metadata (such as sizes, strides, dtype, device)
 # information about the tensor but does not have any storage bytes. The following snippet will not materialize
-# any data but which will tag each ``FakeTensor`` with the offset within the checkpoint that
+# any data but will tag each ``FakeTensor`` with the offset within the checkpoint that
 # corresponds to the tensor.
 #
 # If you are continuously saving the same state dictionary during training, you
 # would only need to obtain the offsets once and the same offsets can be re-used. Similarly if tensor is going to
-# be saved or loaded to repeatedly one can use the ``torch.cuda.gds.gds_register_buffer`` which wraps
-# ``cuFileBufRegister`` to register the storages as gds buffers.
+# be saved or loaded to repeatedly you can use the ``torch.cuda.gds.gds_register_buffer`` which wraps
+# ``cuFileBufRegister`` to register the storages as GDS buffers.
+#
+# Note that ``torch.cuda.gds.GdsFile.save_storage`` binds to the synchronous ``cuFileWrite`` API,
+# so no synchronization is needed afterwards.


 import os
@@ -96,16 +99,19 @@
     assert torch.equal(v, sd[k])

 ################################################################################
-# The loading flow is the inverse, we can ``torch.load`` under the ``torch.serialization.skip_data`` context
+# The loading flow is the inverse: you can use ``torch.load`` with the ``torch.serialization.skip_data`` context
 # manager to load everything except the storage bytes. This means that any tensors in the checkpoint will be
-# created but their storages will be empty (i.e. the tensors will be created via ``torch.empty``).
+# created but their storages will be empty (as if the tensors were created via ``torch.empty``).

 with torch.serialization.skip_data():
     sd_loaded = torch.load("checkpoint.pt")

 ################################################################################
 # We once again use the ``FakeTensorMode`` to get the checkpoint offsets and
 # ascertain that the loaded checkpoint is the same as the saved checkpoint.
+#
+# Similar to ``torch.cuda.gds.GdsFile.save_storage``, ``torch.cuda.gds.GdsFile.load_storage``
+# binds to the synchronous ``cuFileRead`` API, so no synchronization is needed afterwards.

 for k, v in sd_loaded.items():
     assert not torch.equal(v, sd[k])
@@ -118,9 +124,9 @@

 del f

-# Summary
-# =======
+# Conclusion
+# ==========
 #
 # In this tutorial we have demonstrated how to use the prototype ``torch.cuda.gds`` APIs
-# in conjunction with ``torch.save`` and ``torch.load`` on local filesystem. Do
-# file in issue in the PyTorch GitHub repo if you have any feedback.
+# in conjunction with ``torch.save`` and ``torch.load`` on local filesystem. Please
+# file an issue in the PyTorch GitHub repo if you have any feedback.
