Commit 248324a
Author: Svetlana Karslioglu
Merge branch 'master' into distr-tutorial-landing
2 parents 2200ff1 + 6b3af99

File tree: 11 files changed, +727 −21 lines changed

_static/css/custom.css

Lines changed: 11 additions & 4 deletions

@@ -2,11 +2,12 @@
 */
 
 :root {
+    --sd-color-info: #ee4c2c;
     --sd-color-primary: #6c6c6d;
     --sd-color-primary-highlight: #f3f4f7;
     --sd-color-card-border-hover: #ee4c2c;
     --sd-color-card-border: #f3f4f7;
-    --sd-color-card-background: #f3f4f7;
+    --sd-color-card-background: #fff;
     --sd-color-card-text: inherit;
     --sd-color-card-header: transparent;
     --sd-color-card-footer: transparent;
@@ -20,13 +21,19 @@
     --sd-color-tabs-underline: rgb(222, 222, 222);
 }
 
+.sd-text-info {
+    color: #ee4c2c;
+}
+
+
 .sd-card {
     position: relative;
-    background-color: #f3f4f7;
-    opacity: 0.5;
+    background-color: #fff;
+    opacity: 1.0;
     border-radius: 0px;
     width: 30%;
-    border: none
+    border: none;
+    padding-bottom: 0px;
 }

advanced_source/ddp_pipeline.py

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ def __init__(self, d_model, dropout=0.1, max_len=5000):
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
         pe = pe.unsqueeze(0).transpose(0, 1)
-        self.register_buffer('pe', pe)
+        self.pe = nn.Parameter(pe, requires_grad=False)
 
     def forward(self, x):
         x = x + self.pe[:x.size(0), :]
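
(A note on this change, as an inference from the diff rather than a stated rationale: an ``nn.Parameter`` with ``requires_grad=False`` behaves like the old buffer in ``forward`` and still lands in the ``state_dict``, but unlike a buffer it is visible to ``model.parameters()``, which can matter for utilities that partition or move a model by its parameters.)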
beginner_source/ddp_series_fault_tolerance.rst

Lines changed: 212 additions & 0 deletions

`Introduction <ddp_series_intro.html>`__ \|\| `What is DDP <ddp_series_theory.html>`__ \|\| `Single-Node
Multi-GPU Training <ddp_series_multigpu.html>`__ \|\| **Fault
Tolerance** \|\| `Multi-Node
training <../intermediate/ddp_series_multinode.html>`__ \|\| `minGPT Training <../intermediate/ddp_series_minGPT.html>`__

Fault-tolerant Distributed Training with ``torchrun``
=====================================================

Authors: `Suraj Subramanian <https://github.com/suraj813>`__

.. grid:: 2

   .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
      :margin: 0

      - Launching multi-GPU training jobs with ``torchrun``
      - Saving and loading snapshots of your training job
      - Structuring your training script for graceful restarts

.. grid:: 1

   .. grid-item::

      :octicon:`code-square;1.0em;` View the code used in this tutorial on `GitHub <https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multigpu_torchrun.py>`__

   .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
      :margin: 0

      * High-level `overview <ddp_series_theory.html>`__ of DDP
      * Familiarity with `DDP code <ddp_series_multigpu.html>`__
      * A machine with multiple GPUs (this tutorial uses an AWS p3.8xlarge instance)
      * PyTorch `installed <https://pytorch.org/get-started/locally/>`__ with CUDA

Follow along with the video below or on `youtube <https://www.youtube.com/watch/9kIvQOiwYzg>`__.

.. raw:: html

   <div style="margin-top:10px; margin-bottom:10px;">
     <iframe width="560" height="315" src="https://www.youtube.com/embed/9kIvQOiwYzg" frameborder="0" allow="accelerometer; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
   </div>

In distributed training, a single process failure can
disrupt the entire training job. Since the susceptibility to failure is higher in this setting, making your
training script robust is particularly important. You might also prefer your training job to be *elastic*,
i.e., resilient to compute resources joining and leaving the job dynamically over its lifetime.

PyTorch offers a utility called ``torchrun`` that provides fault tolerance and
elastic training. When a failure occurs, ``torchrun`` logs the errors and
attempts to automatically restart all the processes from the last saved
“snapshot” of the training job.

The snapshot saves more than just the model state; it can include
details about the number of epochs run, optimizer states, or any other
stateful attribute of the training job necessary for its continuity.
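
The diffs later in this tutorial save only the model state and the epoch count; for illustration, a snapshot that also carries optimizer state might look like the sketch below (the ``save_snapshot`` helper and the ``OPTIMIZER_STATE`` key are illustrative, not part of the tutorial code):

.. code:: python

   import torch

   def save_snapshot(model, optimizer, epoch, path="snapshot.pt"):
       # One dictionary holds every stateful attribute needed to resume.
       snapshot = {
           "MODEL_STATE": model.state_dict(),
           "OPTIMIZER_STATE": optimizer.state_dict(),  # illustrative extra key
           "EPOCHS_RUN": epoch,
       }
       torch.save(snapshot, path)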

Why use ``torchrun``
~~~~~~~~~~~~~~~~~~~~

``torchrun`` handles the minutiae of distributed training so that you
don't need to. For instance,

- You don't need to set environment variables or explicitly pass the ``rank`` and ``world_size``; ``torchrun`` assigns these along with several other `environment variables <https://pytorch.org/docs/stable/elastic/run.html#environment-variables>`__.
- No need to call ``mp.spawn`` in your script; you only need a generic ``main()`` entrypoint, and launch the script with ``torchrun``. This way, the same script can run in non-distributed as well as single-node and multi-node setups.
- Gracefully restarting training from the last saved training snapshot.

Graceful restarts
~~~~~~~~~~~~~~~~~

For graceful restarts, you should structure your training script like this:

.. code:: python

   def main():
       load_snapshot(snapshot_path)
       initialize()
       train()

   def train():
       for batch in iter(dataset):
           train_step(batch)

           if should_checkpoint:
               save_snapshot(snapshot_path)
If a failure occurs, ``torchrun`` will terminate all the processes and restart them.
Each process entrypoint first loads and initializes from the last saved snapshot, and continues training from there.
So at any failure, you only lose the training progress made since the last saved snapshot.

In elastic training, whenever there is a membership change (nodes being added or removed), ``torchrun`` terminates and respawns processes
on the available devices. Structuring your script this way ensures your training job can continue without manual intervention.
Diff for `multigpu.py <https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multigpu.py>`__ vs. `multigpu_torchrun.py <https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multigpu_torchrun.py>`__
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Process group initialization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- ``torchrun`` assigns ``RANK`` and ``WORLD_SIZE`` automatically,
  amongst `other environment
  variables <https://pytorch.org/docs/stable/elastic/run.html#environment-variables>`__.

.. code:: diff

   - def ddp_setup(rank, world_size):
   + def ddp_setup():
   -     """
   -     Args:
   -         rank: Unique identifier of each process
   -         world_size: Total number of processes
   -     """
   -     os.environ["MASTER_ADDR"] = "localhost"
   -     os.environ["MASTER_PORT"] = "12355"
   -     init_process_group(backend="nccl", rank=rank, world_size=world_size)
   +     init_process_group(backend="nccl")
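Putting this together, a minimal sketch of the new setup function could look as follows; the ``torch.cuda.set_device`` call is a common companion step on multi-GPU nodes and is an addition of this sketch, not part of the diff above:

.. code:: python

   import os
   import torch
   from torch.distributed import init_process_group

   def ddp_setup():
       # torchrun exports MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE and
       # LOCAL_RANK, so init_process_group can read everything it needs
       # from the environment.
       init_process_group(backend="nccl")
       # Bind this process to its own GPU on the node (this sketch
       # assumes the usual one-process-per-GPU layout).
       torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))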
Use torchrun-provided environment variables
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: diff

   - self.gpu_id = gpu_id
   + self.gpu_id = int(os.environ["LOCAL_RANK"])
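Note that ``LOCAL_RANK`` is the process's index on its own node, while ``RANK`` is its global index across all nodes; on a single machine the two coincide.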
Saving and loading snapshots
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Regularly storing all the relevant information in snapshots allows our
training job to seamlessly resume after an interruption.

.. code:: diff

   + def _save_snapshot(self, epoch):
   +     snapshot = {}
   +     snapshot["MODEL_STATE"] = self.model.module.state_dict()
   +     snapshot["EPOCHS_RUN"] = epoch
   +     torch.save(snapshot, "snapshot.pt")
   +     print(f"Epoch {epoch} | Training snapshot saved at snapshot.pt")

   + def _load_snapshot(self, snapshot_path):
   +     snapshot = torch.load(snapshot_path)
   +     self.model.load_state_dict(snapshot["MODEL_STATE"])
   +     self.epochs_run = snapshot["EPOCHS_RUN"]
   +     print(f"Resuming training from snapshot at Epoch {self.epochs_run}")
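One variant worth considering (an addition of this sketch, not part of the diff above): passing ``map_location`` to ``torch.load`` maps the snapshot's tensors onto this process's own GPU, so that every rank doesn't deserialize onto the device the snapshot was saved from:

.. code:: python

   def _load_snapshot(self, snapshot_path):
       # Hypothetical variant of the method above, assuming the Trainer
       # stores its device index in self.gpu_id.
       loc = f"cuda:{self.gpu_id}"
       snapshot = torch.load(snapshot_path, map_location=loc)
       self.model.load_state_dict(snapshot["MODEL_STATE"])
       self.epochs_run = snapshot["EPOCHS_RUN"]
       print(f"Resuming training from snapshot at Epoch {self.epochs_run}")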
Loading a snapshot in the Trainer constructor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When restarting an interrupted training job, your script will first try
to load a snapshot to resume training from.

.. code:: diff

   class Trainer:
       def __init__(self, snapshot_path, ...):
           ...
   +       if os.path.exists(snapshot_path):
   +           self._load_snapshot(snapshot_path)
           ...
Resuming training
~~~~~~~~~~~~~~~~~

Training can resume from the last epoch run, instead of starting all
over from scratch.

.. code:: diff

   def train(self, max_epochs: int):
   -     for epoch in range(max_epochs):
   +     for epoch in range(self.epochs_run, max_epochs):
             self._run_epoch(epoch)
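For a fresh run with no snapshot on disk, this assumes ``self.epochs_run`` was initialized to ``0`` in the constructor, so the loop starts from the first epoch.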
Running the script
~~~~~~~~~~~~~~~~~~

Simply call your entrypoint function as you would for a non-multiprocessing script; ``torchrun`` automatically
spawns the processes.

.. code:: diff

   if __name__ == "__main__":
       import sys
       total_epochs = int(sys.argv[1])
       save_every = int(sys.argv[2])
   -   world_size = torch.cuda.device_count()
   -   mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)
   +   main(save_every, total_epochs)

.. code:: diff

   - python multigpu.py 50 10
   + torchrun --standalone --nproc_per_node=4 multigpu_torchrun.py 50 10
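Here ``--standalone`` runs a self-contained single-node job without an external rendezvous endpoint, and ``--nproc_per_node=4`` launches four worker processes, one per GPU. The positional arguments ``50 10`` remain the script's own inputs: ``total_epochs`` and ``save_every``.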
Further Reading
---------------

- `Multi-Node training with DDP <../intermediate/ddp_series_multinode.html>`__ (next tutorial in this series)
- `Multi-GPU Training with DDP <ddp_series_multigpu.html>`__ (previous tutorial in this series)
- `torchrun <https://pytorch.org/docs/stable/elastic/run.html>`__
- `Torchrun launch options <https://github.com/pytorch/pytorch/blob/bbe803cb35948df77b46a2d38372910c96693dcd/torch/distributed/run.py#L401>`__
- `Migrating from torch.distributed.launch to torchrun <https://pytorch.org/docs/stable/elastic/train_script.html#elastic-train-script>`__

beginner_source/ddp_series_intro.rst

Lines changed: 55 additions & 0 deletions

**Introduction** \|\| `What is DDP <ddp_series_theory.html>`__ \|\| `Single-Node
Multi-GPU Training <ddp_series_multigpu.html>`__ \|\| `Fault
Tolerance <ddp_series_fault_tolerance.html>`__ \|\| `Multi-Node
training <../intermediate/ddp_series_multinode.html>`__ \|\| `minGPT Training <../intermediate/ddp_series_minGPT.html>`__

Distributed Data Parallel in PyTorch - Video Tutorials
======================================================

Authors: `Suraj Subramanian <https://github.com/suraj813>`__

Follow along with the video below or on `youtube <https://www.youtube.com/watch/-K3bZYHYHEA>`__.

.. raw:: html

   <div style="margin-top:10px; margin-bottom:10px;">
     <iframe width="560" height="315" src="https://www.youtube.com/embed/-K3bZYHYHEA" frameborder="0" allow="accelerometer; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
   </div>

This series of video tutorials walks you through distributed training in
PyTorch via DDP.

The series starts with a simple non-distributed training job and ends
with deploying a training job across several machines in a cluster.
Along the way, you will also learn about
`torchrun <https://pytorch.org/docs/stable/elastic/run.html>`__ for
fault-tolerant distributed training.

The tutorial assumes a basic familiarity with model training in PyTorch.

Running the code
----------------

You will need multiple CUDA GPUs to run the tutorial code. Typically,
this can be done on a cloud instance with multiple GPUs (the tutorials
use an Amazon EC2 P3 instance with 4 GPUs).

The tutorial code is hosted in this `GitHub
repo <https://github.com/pytorch/examples/tree/main/distributed/ddp-tutorial-series>`__. Clone the repository and
follow along!

Tutorial sections
-----------------

0. Introduction (this page)
1. `What is DDP? <ddp_series_theory.html>`__ Gently introduces what DDP is doing
   under the hood
2. `Single-Node Multi-GPU Training <ddp_series_multigpu.html>`__ Training models
   using multiple GPUs on a single machine
3. `Fault-tolerant distributed training <ddp_series_fault_tolerance.html>`__
   Making your distributed training job robust with ``torchrun``
4. `Multi-Node training <../intermediate/ddp_series_multinode.html>`__ Training models using
   multiple GPUs on multiple machines
5. `Training a GPT model with DDP <../intermediate/ddp_series_minGPT.html>`__ A “real-world”
   example of training a `minGPT <https://github.com/karpathy/minGPT>`__
   model with DDP
0 commit comments

Comments
 (0)