Commit f557ee0

Author: Jessica Lin
Merge pull request #973 from mrshenli/ddp
Update DDP Tutorial to remove Single-Process Multi-Device Use Case
2 parents e765195 + 5ca59d7 commit f557ee0

intermediate_source/ddp_tutorial.rst

Lines changed: 80 additions & 85 deletions
@@ -2,46 +2,54 @@ Getting Started with Distributed Data Parallel
 =================================================
 **Author**: `Shen Li <https://mrshenli.github.io/>`_
 
-`DistributedDataParallel <https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html>`__
-(DDP) implements data parallelism at the module level. It uses communication
-collectives in the `torch.distributed <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
-package to synchronize gradients, parameters, and buffers. Parallelism is
-available both within a process and across processes. Within a process, DDP
-replicates the input module to devices specified in ``device_ids``, scatters
-inputs along the batch dimension accordingly, and gathers outputs to the
-``output_device``, which is similar to
-`DataParallel <https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html>`__.
-Across processes, DDP inserts necessary parameter synchronizations in forward
-passes and gradient synchronizations in backward passes. It is up to users to
-map processes to available resources, as long as processes do not share GPU
-devices. The recommended (usually fastest) approach is to create a process for
-every module replica, i.e., no module replication within a process. The code in
-this tutorial runs on an 8-GPU server, but it can be easily generalized to
-other environments.
+`DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__
+(DDP) implements data parallelism at the module level which can run across
+multiple machines. Applications using DDP should spawn multiple processes and
+create a single DDP instance per process. DDP uses collective communications in the
+`torch.distributed <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
+package to synchronize gradients and buffers. More specifically, DDP registers
+an autograd hook for each parameter given by ``model.parameters()`` and the
+hook will fire when the corresponding gradient is computed in the backward
+pass. Then DDP uses that signal to trigger gradient synchronization across
+processes. Please refer to the
+`DDP design note <https://pytorch.org/docs/master/notes/ddp.html>`__ for more details.
+
+
+The recommended way to use DDP is to spawn one process for each model replica,
+where a model replica can span multiple devices. DDP processes can be
+placed on the same machine or across machines, but GPU devices cannot be
+shared across processes. This tutorial starts from a basic DDP use case and
+then demonstrates more advanced use cases including checkpointing models and
+combining DDP with model parallel.
+
+
+.. note::
+    The code in this tutorial runs on an 8-GPU server, but it can be easily
+    generalized to other environments.
+
 
 Comparison between ``DataParallel`` and ``DistributedDataParallel``
 -------------------------------------------------------------------
 
 Before we dive in, let's clarify why, despite the added complexity, you would
 consider using ``DistributedDataParallel`` over ``DataParallel``:
 
-- First, recall from the
+- First, ``DataParallel`` is single-process, multi-thread, and only works on a
+  single machine, while ``DistributedDataParallel`` is multi-process and works
+  for both single- and multi-machine training. ``DataParallel`` is usually
+  slower than ``DistributedDataParallel`` even on a single machine due to GIL
+  contention across threads, the per-iteration replicated model, and the additional
+  overhead introduced by scattering inputs and gathering outputs.
+- Recall from the
   `prior tutorial <https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html>`__
   that if your model is too large to fit on a single GPU, you must use **model parallel**
   to split it across multiple GPUs. ``DistributedDataParallel`` works with
-  **model parallel**; ``DataParallel`` does not at this time.
-- ``DataParallel`` is single-process, multi-thread, and only works on a single
-  machine, while ``DistributedDataParallel`` is multi-process and works for both
-  single- and multi- machine training. Thus, even for single machine training,
-  where your **data** is small enough to fit on a single machine, ``DistributedDataParallel``
-  is expected to be faster than ``DataParallel``. ``DistributedDataParallel``
-  also replicates models upfront instead of on each iteration and gets Global
-  Interpreter Lock out of the way.
-- If both your data is too large to fit on one machine **and** your
-  model is too large to fit on a single GPU, you can combine model parallel
-  (splitting a single model across multiple GPUs) with ``DistributedDataParallel``.
-  Under this regime, each ``DistributedDataParallel`` process could use model parallel,
-  and all processes collectively would use data parallel.
+  **model parallel**; ``DataParallel`` does not at this time. When DDP is combined
+  with model parallel, each DDP process would use model parallel, and all processes
+  collectively would use data parallel.
+- If your model needs to span multiple machines or if your use case does not fit
+  into the data parallelism paradigm, please see `the RPC API <https://pytorch.org/docs/stable/rpc.html>`__
+  for more generic distributed training support.
 
 Basic Use Case
 --------------
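
The rewritten introduction above describes DDP registering an autograd hook per parameter and using it to trigger gradient synchronization. As a rough sketch of the *effect* of those hooks (not DDP's actual implementation, which buckets gradients and overlaps communication with the backward computation), one could average gradients by hand after ``backward()``:

.. code:: python

    import torch.distributed as dist

    def average_gradients(model, world_size):
        # Hand-written stand-in for what DDP's hooks achieve: after backward(),
        # every process ends up holding the average of all processes' gradients.
        for param in model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                param.grad /= world_size
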
@@ -70,18 +78,14 @@ be found in
         # initialize the process group
         dist.init_process_group("gloo", rank=rank, world_size=world_size)
 
-        # Explicitly setting seed to make sure that models created in two processes
-        # start from same random weights and biases.
-        torch.manual_seed(42)
-
 
     def cleanup():
         dist.destroy_process_group()
 
 Now, let's create a toy module, wrap it with DDP, and feed it with some dummy
-input data. Please note, if training starts from random parameters, you might
-want to make sure that all DDP processes use the same initial values.
-Otherwise, global gradient synchronizes will not make sense.
+input data. Please note that, since DDP broadcasts model states from the rank 0 process
+to all other processes in the DDP constructor, you don't need to worry about different
+DDP processes starting from different initial model parameter values.
 
 .. code:: python
 
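
The hunk above shows only the tail of ``setup()``; a minimal self-contained version of the two helpers, with an illustrative single-machine rendezvous address and port, could look like this:

.. code:: python

    import os
    import torch.distributed as dist

    def setup(rank, world_size):
        # Illustrative rendezvous settings for a single-machine run;
        # any free port works.
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # initialize the process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

    def cleanup():
        dist.destroy_process_group()
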
@@ -97,24 +101,19 @@ Otherwise, global gradient synchronizes will not make sense.
 
 
     def demo_basic(rank, world_size):
+        print(f"Running basic DDP example on rank {rank}.")
         setup(rank, world_size)
 
-        # setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
-        # rank 2 uses GPUs [4, 5, 6, 7].
-        n = torch.cuda.device_count() // world_size
-        device_ids = list(range(rank * n, (rank + 1) * n))
-
-        # create model and move it to device_ids[0]
-        model = ToyModel().to(device_ids[0])
-        # output_device defaults to device_ids[0]
-        ddp_model = DDP(model, device_ids=device_ids)
+        # create model and move it to GPU with id rank
+        model = ToyModel().to(rank)
+        ddp_model = DDP(model, device_ids=[rank])
 
         loss_fn = nn.MSELoss()
         optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
 
         optimizer.zero_grad()
         outputs = ddp_model(torch.randn(20, 10))
-        labels = torch.randn(20, 5).to(device_ids[0])
+        labels = torch.randn(20, 5).to(rank)
         loss_fn(outputs, labels).backward()
         optimizer.step()
 
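
``demo_basic`` is launched by a ``run_demo`` helper that appears only as context in the next hunk; a sketch of how it spawns one process per rank, assuming the standard ``torch.multiprocessing`` API:

.. code:: python

    import torch.multiprocessing as mp

    def run_demo(demo_fn, world_size):
        # mp.spawn invokes demo_fn(rank, *args) in world_size subprocesses,
        # with rank running from 0 to world_size - 1.
        mp.spawn(demo_fn,
                 args=(world_size,),
                 nprocs=world_size,
                 join=True)
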
@@ -127,23 +126,27 @@ Otherwise, global gradient synchronizes will not make sense.
                  nprocs=world_size,
                  join=True)
 
-As you can see, DDP wraps lower level distributed communication details, and
-provides a clean API as if it is a local model. For basic use cases, DDP only
+As you can see, DDP wraps lower-level distributed communication details and
+provides a clean API as if it were a local model. Gradient synchronization
+communications take place during the backward pass and overlap with the
+backward computation. When ``backward()`` returns, ``param.grad`` already
+contains the synchronized gradient tensor. For basic use cases, DDP only
 requires a few more LoCs to set up the process group. When applying DDP to more
-advanced use cases, there are some caveats that require cautions.
+advanced use cases, some caveats require caution.
 
 Skewed Processing Speeds
 ------------------------
 
-In DDP, constructor, forward method, and differentiation of the outputs are
-distributed synchronization points. Different processes are expected to reach
-synchronization points in the same order and enter each synchronization point
-at roughly the same time. Otherwise, fast processes might arrive early and
-timeout on waiting for stragglers. Hence, users are responsible for balancing
-workloads distributions across processes. Sometimes, skewed processing speeds
-are inevitable due to, e.g., network delays, resource contentions,
-unpredictable workload spikes. To avoid timeouts in these situations, make
-sure that you pass a sufficiently large ``timeout`` value when calling
+In DDP, the constructor, the forward pass, and the backward pass are
+distributed synchronization points. Different processes are expected to launch
+the same number of synchronizations, reach these synchronization points in
+the same order, and enter each synchronization point at roughly the same time.
+Otherwise, fast processes might arrive early and time out while waiting for
+stragglers. Hence, users are responsible for balancing workload distribution
+across processes. Sometimes, skewed processing speeds are inevitable due to,
+e.g., network delays, resource contention, or unpredictable workload spikes. To
+avoid timeouts in these situations, make sure that you pass a sufficiently
+large ``timeout`` value when calling
 `init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__.
 
 Save and Load Checkpoints
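
If stragglers are expected, the ``timeout`` mentioned above is an argument of ``init_process_group`` and accepts a ``datetime.timedelta``; a hedged variant of the tutorial's ``setup()`` (the helper name and the 60-minute value are illustrative):

.. code:: python

    from datetime import timedelta
    import torch.distributed as dist

    def setup_with_timeout(rank, world_size):
        # Give slow ranks up to 60 minutes to reach each collective before
        # the process group reports a timeout.
        dist.init_process_group("gloo",
                                rank=rank,
                                world_size=world_size,
                                timeout=timedelta(minutes=60))
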
@@ -156,27 +159,23 @@ for more details. When using DDP, one optimization is to save the model in
 only one process and then load it to all processes, reducing write overhead.
 This is correct because all processes start from the same parameters and
 gradients are synchronized in backward passes, and hence optimizers should keep
-setting parameters to same values. If you use this optimization, make sure all
+setting parameters to the same values. If you use this optimization, make sure all
 processes do not start loading before the saving is finished. Besides, when
 loading the module, you need to provide an appropriate ``map_location``
 argument to prevent a process from stepping into others' devices. If ``map_location``
 is missing, ``torch.load`` will first load the module to CPU and then copy each
 parameter to where it was saved, which would result in all processes on the
-same machine using the same set of devices.
+same machine using the same set of devices. For more advanced failure recovery
+and elasticity support, please refer to `TorchElastic <https://pytorch.org/elastic>`__.
 
 .. code:: python
 
     def demo_checkpoint(rank, world_size):
+        print(f"Running DDP checkpoint example on rank {rank}.")
         setup(rank, world_size)
 
-        # setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
-        # rank 2 uses GPUs [4, 5, 6, 7].
-        n = torch.cuda.device_count() // world_size
-        device_ids = list(range(rank * n, (rank + 1) * n))
-
-        model = ToyModel().to(device_ids[0])
-        # output_device defaults to device_ids[0]
-        ddp_model = DDP(model, device_ids=device_ids)
+        model = ToyModel().to(rank)
+        ddp_model = DDP(model, device_ids=[rank])
 
         loss_fn = nn.MSELoss()
         optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
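
Between this hunk and the next, the unchanged body of ``demo_checkpoint`` saves the model from one rank and then blocks on a barrier; a sketch of that pattern (``CHECKPOINT_PATH`` and the helper name are illustrative):

.. code:: python

    import tempfile
    import torch
    import torch.distributed as dist

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"

    def save_on_rank0(ddp_model, rank):
        # Only rank 0 writes the checkpoint; every rank then waits at the
        # barrier so that no process starts loading before saving finishes.
        if rank == 0:
            torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)
        dist.barrier()
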
@@ -192,15 +191,13 @@ same machine using the same set of devices.
         # 0 saves it.
         dist.barrier()
         # configure map_location properly
-        rank0_devices = [x - rank * len(device_ids) for x in device_ids]
-        device_pairs = zip(rank0_devices, device_ids)
-        map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
+        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
         ddp_model.load_state_dict(
             torch.load(CHECKPOINT_PATH, map_location=map_location))
 
         optimizer.zero_grad()
         outputs = ddp_model(torch.randn(20, 10))
-        labels = torch.randn(20, 5).to(device_ids[0])
+        labels = torch.randn(20, 5).to(rank)
         loss_fn = nn.MSELoss()
         loss_fn(outputs, labels).backward()
         optimizer.step()
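
The new one-line ``map_location`` remaps every tensor that was saved from ``cuda:0`` (rank 0's device) onto the loading process's own GPU. As a standalone illustration of the same mechanism (the rank value and path are placeholders):

.. code:: python

    import torch

    rank = 3  # placeholder: this process's rank
    # Storages saved on cuda:0 are loaded onto cuda:3 instead of cuda:0.
    map_location = {'cuda:0': 'cuda:%d' % rank}
    state_dict = torch.load("/tmp/model.checkpoint", map_location=map_location)
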
@@ -217,13 +214,8 @@ same machine using the same set of devices.
 Combine DDP with Model Parallelism
 ----------------------------------
 
-DDP also works with multi-GPU models, but replications within a process are not
-supported. You need to create one process per module replica, which usually
-leads to better performance compared to multiple replicas per process. DDP
-wrapping multi-GPU models is especially helpful when training large models with
-a huge amount of data. When using this feature, the multi-GPU model needs to be
-carefully implemented to avoid hard-coded devices, because different model
-replicas will be placed to different devices.
+DDP also works with multi-GPU models. DDP wrapping multi-GPU models is especially
+helpful when training large models with a huge amount of data.
 
 .. code:: python
 
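
The multi-GPU module used by ``demo_model_parallel`` is defined in an unchanged part of the file and is not visible in this diff; a sketch of such a two-device toy model (names and layer sizes are illustrative):

.. code:: python

    import torch.nn as nn

    class ToyMpModel(nn.Module):
        # A toy model-parallel module: the first linear layer lives on dev0,
        # the second on dev1, and forward() moves activations between them.
        def __init__(self, dev0, dev1):
            super().__init__()
            self.dev0 = dev0
            self.dev1 = dev1
            self.net1 = nn.Linear(10, 10).to(dev0)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5).to(dev1)

        def forward(self, x):
            x = self.relu(self.net1(x.to(self.dev0)))
            return self.net2(x.to(self.dev1))
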
@@ -249,6 +241,7 @@ either the application or the model ``forward()`` method.
 .. code:: python
 
     def demo_model_parallel(rank, world_size):
+        print(f"Running DDP with model parallel example on rank {rank}.")
         setup(rank, world_size)
 
         # setup mp_model and devices for this process
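
The lines that follow ``# setup mp_model and devices for this process`` are outside this hunk; a sketch of how each rank could claim its own pair of GPUs and wrap the multi-device model (the helper name is illustrative), with ``device_ids`` left unset because the wrapped module itself spans multiple devices:

.. code:: python

    from torch.nn.parallel import DistributedDataParallel as DDP

    def make_mp_model(rank):
        # With two GPUs per replica, process `rank` uses GPUs 2*rank and 2*rank + 1.
        dev0 = rank * 2
        dev1 = rank * 2 + 1
        mp_model = ToyMpModel(dev0, dev1)
        # device_ids is omitted: DDP requires it to stay unset for multi-device modules.
        return DDP(mp_model)
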
@@ -271,8 +264,10 @@ either the application or the model ``forward()`` method.
 
 
     if __name__ == "__main__":
-        run_demo(demo_basic, 2)
-        run_demo(demo_checkpoint, 2)
-
-        if torch.cuda.device_count() >= 8:
-            run_demo(demo_model_parallel, 4)
+        n_gpus = torch.cuda.device_count()
+        if n_gpus < 8:
+            print(f"Requires at least 8 GPUs to run, but got {n_gpus}.")
+        else:
+            run_demo(demo_basic, 8)
+            run_demo(demo_checkpoint, 8)
+            run_demo(demo_model_parallel, 4)
