Commit 255da51 (parent 63295e8)

GPUDirect Storage prototype tutorial

File tree: 3 files changed, +133 −0 lines

.jenkins/validate_tutorials_built.py

Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@
     "prototype_source/vmap_recipe",
     "prototype_source/torchscript_freezing",
     "prototype_source/nestedtensor",
+    "prototype_source/gpu_direct_storage",  # requires specific filesystem + GPUDirect Storage to be set up
     "recipes_source/recipes/saving_and_loading_models_for_inference",
     "recipes_source/recipes/saving_multiple_models_in_one_file",
     "recipes_source/recipes/tensorboard_with_pytorch",
prototype_source/gpu_direct_storage.py

Lines changed: 124 additions & 0 deletions

@@ -0,0 +1,124 @@
"""
(prototype) Using GPUDirect Storage
====================================

GPUDirect Storage enables a direct data path for direct memory access (DMA) transfers
between GPU memory and storage, avoiding a bounce buffer through the CPU.

In version ``2.7``, we introduced some prototype APIs to ``torch.cuda.gds`` that serve as thin wrappers around
the `cuFile APIs <https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api>`_
and that can be used with ``torch.Tensor``.

In this tutorial, we will demonstrate how to use the ``torch.cuda.gds`` APIs in conjunction with
checkpoints generated by ``torch.save`` and ``torch.load`` on a local filesystem.

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * Understand how to use the ``torch.cuda.gds`` APIs in conjunction with
         checkpoints generated by ``torch.save`` and ``torch.load`` on a local filesystem

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch v2.7.0 or later
       * GPUDirect Storage must be installed per
         `the documentation <https://docs.nvidia.com/gpudirect-storage/troubleshooting-guide/contents.html>`_
       * Ensure that the filesystem that you are saving/loading to supports GPUDirect Storage.
"""

################################################################################
# Using GPUDirect Storage with ``torch.save`` and ``torch.load``
# ==============================================================
# GPUDirect Storage requires a storage alignment of 4KB. One can configure
# this using ``torch.utils.serialization.config.save.storage_alignment``:

import torch
from torch.utils.serialization import config as serialization_config

serialization_config.save.storage_alignment = 4096

################################################################################
# Given a state dictionary of tensors that are on the GPU, one can use the
# ``torch.serialization.skip_data`` context manager to save a checkpoint that
# contains all relevant metadata except the storage bytes. For each
# ``torch.Storage`` in the state dictionary, space will be reserved within the
# checkpoint for the storage bytes.

import torch.nn as nn

m = nn.Linear(5, 10, device='cuda')
sd = m.state_dict()

with torch.serialization.skip_data():
    torch.save(sd, "checkpoint.pt")
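
################################################################################
# As a quick sanity check that space was indeed reserved, the checkpoint on
# disk should be at least as large as the raw storage bytes of the state
# dictionary, even though no data was written. This is a rough sketch only;
# the exact file size also depends on serialized metadata and the 4KB storage
# alignment configured above.

import os

# Total number of bytes backing the tensors in the state dictionary.
reserved = sum(v.untyped_storage().nbytes() for v in sd.values())
assert os.path.getsize("checkpoint.pt") >= reserved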

################################################################################
# We can get the offsets that each storage should be written to within the
# checkpoint by loading under a ``FakeTensorMode``. A FakeTensor is a tensor
# that has metadata (such as sizes, strides, dtype, and device) about the
# tensor but does not have any storage bytes. The following snippet will not
# materialize any data, but will tag each ``FakeTensor`` with the offset within
# the checkpoint that corresponds to the tensor.

import os
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode() as mode:
    fake_sd = torch.load("checkpoint.pt")

for k, v in fake_sd.items():
    print(f"key={k}, offset={v.untyped_storage()._checkpoint_offset}")

# Open the checkpoint for reading and writing, then write each GPU storage
# directly to its reserved offset.
f = torch.cuda.gds.GdsFile("checkpoint.pt", os.O_RDWR)

for k, v in sd.items():
    offset = fake_sd[k].untyped_storage()._checkpoint_offset
    f.save_storage(v.untyped_storage(), offset)

################################################################################
# We verify the correctness of the saved checkpoint by loading it with
# ``torch.load`` and comparing each tensor against the original state
# dictionary.

sd_loaded = torch.load("checkpoint.pt")
for k, v in sd_loaded.items():
    assert torch.equal(v, sd[k])

################################################################################
# The loading flow is the inverse: we can ``torch.load`` under the
# ``torch.serialization.skip_data`` context manager to load everything except
# the storage bytes. This means that any tensors in the checkpoint will be
# created but their storages will be empty (i.e. the tensors will be created
# via ``torch.empty``). If the tensors being loaded into are persistent, one
# can use the ``torch.cuda.gds.gds_register_buffer`` API to register the
# storages as GDS buffers, as shown in the Buffer Registration section below.

with torch.serialization.skip_data():
    sd_loaded = torch.load("checkpoint.pt")

################################################################################
# We reuse the checkpoint offsets obtained under ``FakeTensorMode`` above to
# load the storage bytes directly into the tensors, and then ascertain that
# the loaded checkpoint matches the saved one.

for k, v in sd_loaded.items():
    assert not torch.equal(v, sd[k])
    offset = fake_sd[k].untyped_storage()._checkpoint_offset
    f.load_storage(v.untyped_storage(), offset)
    assert torch.equal(v, sd[k])

del f

################################################################################
# Buffer Registration
# ===================
# We also provide ``torch.cuda.gds.gds_register_buffer`` to register the
# tensor storages as GPUDirect Storage buffers. See `here
# <https://docs.nvidia.com/gpudirect-storage/best-practices-guide/index.html#cufile-bufregister-fileread-filewrite>`_
# for when one should do this.

for v in sd.values():
    torch.cuda.gds.gds_register_buffer(v.untyped_storage())
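
################################################################################
# A registered buffer should be deregistered once it is no longer needed. The
# following sketch assumes a matching ``torch.cuda.gds.gds_deregister_buffer``
# entry point in the same prototype module; consult the ``torch.cuda.gds``
# documentation for the exact API in your version.

for v in sd.values():
    # Assumed counterpart to gds_register_buffer above.
    torch.cuda.gds.gds_deregister_buffer(v.untyped_storage())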

################################################################################
# Summary
# =======
#
# In this tutorial we have demonstrated how to use the prototype
# ``torch.cuda.gds`` APIs in conjunction with ``torch.save`` and
# ``torch.load`` on a local filesystem. Please file an issue in the PyTorch
# GitHub repo if you have any feedback.

prototype_source/prototype_index.rst

Lines changed: 8 additions & 0 deletions

@@ -247,6 +247,14 @@ Prototype features are not available as part of binary distributions like PyPI o
    :link: ../prototype/python_extension_autoload.html
    :tags: Extending-PyTorch, Frontend-APIs
 
+.. GPUDirect Storage
+.. customcarditem::
+   :header: (prototype) Using GPUDirect Storage
+   :card_description: Learn how to use GPUDirect Storage in PyTorch.
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../prototype/gpudirect_storage.html
+   :tags: GPUDirect-Storage
+
 .. End of tutorial card section
 
 .. raw:: html
