
Commit 4902c7a

Update DDP Tutorial to remove Single-Process Multi-Device Use Case
1 parent e765195 commit 4902c7a

1 file changed

intermediate_source/ddp_tutorial.rst

Lines changed: 76 additions & 82 deletions
@@ -2,46 +2,53 @@ Getting Started with Distributed Data Parallel
 =================================================
 **Author**: `Shen Li <https://mrshenli.github.io/>`_
 
-`DistributedDataParallel <https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html>`__
-(DDP) implements data parallelism at the module level. It uses communication
-collectives in the `torch.distributed <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
-package to synchronize gradients, parameters, and buffers. Parallelism is
-available both within a process and across processes. Within a process, DDP
-replicates the input module to devices specified in ``device_ids``, scatters
-inputs along the batch dimension accordingly, and gathers outputs to the
-``output_device``, which is similar to
-`DataParallel <https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html>`__.
-Across processes, DDP inserts necessary parameter synchronizations in forward
-passes and gradient synchronizations in backward passes. It is up to users to
-map processes to available resources, as long as processes do not share GPU
-devices. The recommended (usually fastest) approach is to create a process for
-every module replica, i.e., no module replication within a process. The code in
-this tutorial runs on an 8-GPU server, but it can be easily generalized to
-other environments.
+`DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__
+(DDP) implements data parallelism at the module level which can run across
+multiple machines. Applications using DDP should spawn multiple processes and
+create a DDP instance in each process. DDP uses collective communications in the
+`torch.distributed <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
+package to synchronize gradients and buffers. More specifically, DDP inserts
+an autograd hook for each model parameter which will fire when the
+corresponding gradient is computed in the backward pass. Then DDP uses that
+signal to trigger gradient synchronization across processes. Please refer to
+`DDP design note <https://pytorch.org/docs/master/notes/ddp.html>`__ for more details.
+
+
+The recommended way to use DDP is to spawn one process for each model replica,
+where a model replica can span multiple devices. DDP processes can be
+placed on the same machine or across machines, but GPU devices cannot be
+shared across processes. This tutorial starts from a basic DDP use case and
+then demonstrates more advanced use cases including checkpointing models and
+combining DDP with model parallel.
+
+
+.. note::
+  The code in this tutorial runs on an 8-GPU server, but it can be easily
+  generalized to other environments.
+
 
 Comparison between ``DataParallel`` and ``DistributedDataParallel``
 -------------------------------------------------------------------
 
 Before we dive in, let's clarify why, despite the added complexity, you would
 consider using ``DistributedDataParallel`` over ``DataParallel``:
 
-- First, recall from the
+- First, ``DataParallel`` is single-process, multi-thread, and only works on a
+  single machine, while ``DistributedDataParallel`` is multi-process and works
+  for both single- and multi- machine training. ``DataParallel`` is usually
+  slower than ``DistributedDataParallel`` even on a single machine due to GIL
+  contention across threads, per-iteration replicated model, and additional
+  overhead introduced by scattering inputs and gathering outputs.
+- Recall from the
   `prior tutorial <https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html>`__
   that if your model is too large to fit on a single GPU, you must use **model parallel**
   to split it across multiple GPUs. ``DistributedDataParallel`` works with
-  **model parallel**; ``DataParallel`` does not at this time.
-- ``DataParallel`` is single-process, multi-thread, and only works on a single
-  machine, while ``DistributedDataParallel`` is multi-process and works for both
-  single- and multi- machine training. Thus, even for single machine training,
-  where your **data** is small enough to fit on a single machine, ``DistributedDataParallel``
-  is expected to be faster than ``DataParallel``. ``DistributedDataParallel``
-  also replicates models upfront instead of on each iteration and gets Global
-  Interpreter Lock out of the way.
-- If both your data is too large to fit on one machine **and** your
-  model is too large to fit on a single GPU, you can combine model parallel
-  (splitting a single model across multiple GPUs) with ``DistributedDataParallel``.
-  Under this regime, each ``DistributedDataParallel`` process could use model parallel,
-  and all processes collectively would use data parallel.
+  **model parallel**; ``DataParallel`` does not at this time. When DDP is combined
+  with model parallel, each DDP process would use model parallel, and all processes
+  collectively would use data parallel.
+- If your model needs to span multiple machines or if your use case does not fit
+  into data parallelism paradigm, please see `the RPC API <https://pytorch.org/docs/stable/rpc.html>`__
+  for more generic distributed training support.
 
 Basic Use Case
 --------------
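Putting the pieces of the updated basic use case together, a condensed, self-contained sketch of the recommended one-process-per-replica pattern might look as follows. It assumes a single machine with at least two GPUs and the NCCL backend; the helper name ``worker`` and the address/port values are illustrative and not taken from the diff.

.. code:: python

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    import torch.optim as optim
    from torch.nn.parallel import DistributedDataParallel as DDP

    def worker(rank, world_size):
        # one process per model replica; this replica lives on GPU `rank`
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

        model = nn.Linear(10, 5).to(rank)
        ddp_model = DDP(model, device_ids=[rank])
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        loss = ddp_model(torch.randn(20, 10).to(rank)).sum()
        loss.backward()      # gradients are averaged across processes here
        optimizer.step()

        dist.destroy_process_group()

    if __name__ == "__main__":
        world_size = 2       # assumes at least 2 GPUs are visible
        mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)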
@@ -70,18 +77,14 @@ be found in
         # initialize the process group
         dist.init_process_group("gloo", rank=rank, world_size=world_size)
 
-        # Explicitly setting seed to make sure that models created in two processes
-        # start from same random weights and biases.
-        torch.manual_seed(42)
-
 
     def cleanup():
         dist.destroy_process_group()
 
 Now, let's create a toy module, wrap it with DDP, and feed it with some dummy
-input data. Please note, if training starts from random parameters, you might
-want to make sure that all DDP processes use the same initial values.
-Otherwise, global gradient synchronizes will not make sense.
+input data. Please note, as DDP broadcasts model states from rank 0 process to
+all other processes in the DDP constructor, you don't need to worry about
+different DDP processes start from different model parameter initial values.
 
 .. code:: python
 
@@ -97,24 +100,19 @@ Otherwise, global gradient synchronizes will not make sense.
 
 
     def demo_basic(rank, world_size):
+        print("Running basic DDP example.")
         setup(rank, world_size)
 
-        # setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
-        # rank 2 uses GPUs [4, 5, 6, 7].
-        n = torch.cuda.device_count() // world_size
-        device_ids = list(range(rank * n, (rank + 1) * n))
-
-        # create model and move it to device_ids[0]
-        model = ToyModel().to(device_ids[0])
-        # output_device defaults to device_ids[0]
-        ddp_model = DDP(model, device_ids=device_ids)
+        # create model and move it to GPU with id rank
+        model = ToyModel().to(rank)
+        ddp_model = DDP(model, device_ids=[rank])
 
         loss_fn = nn.MSELoss()
         optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
 
         optimizer.zero_grad()
         outputs = ddp_model(torch.randn(20, 10))
-        labels = torch.randn(20, 5).to(device_ids[0])
+        labels = torch.randn(20, 5).to(rank)
         loss_fn(outputs, labels).backward()
         optimizer.step()
 
@@ -128,22 +126,26 @@ Otherwise, global gradient synchronizes will not make sense.
                  join=True)
 
 As you can see, DDP wraps lower level distributed communication details, and
-provides a clean API as if it is a local model. For basic use cases, DDP only
+provides a clean API as if it is a local model. Gradient synchronization
+communications take place during the backward pass and overlap with the
+backward computation. When the ``backward()`` returns, ``param.grad`` already
+contains the synchronized gradient tensor. For basic use cases, DDP only
 requires a few more LoCs to set up the process group. When applying DDP to more
 advanced use cases, there are some caveats that require cautions.
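To see that the synchronized gradients really are available in ``param.grad`` once ``backward()`` returns, a small self-contained check can be run on CPU with the Gloo backend; the function name ``check_grads`` and the port value are illustrative. Every rank feeds different input data, yet all ranks print the same averaged gradient.

.. code:: python

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def check_grads(rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29501"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # a CPU module works with the gloo backend; the DDP constructor
        # broadcasts rank 0's parameters to all other ranks
        ddp_model = DDP(nn.Linear(10, 1))

        # each rank uses different input data, so local gradients would differ
        torch.manual_seed(rank)
        ddp_model(torch.randn(8, 10)).sum().backward()

        # after backward(), param.grad is already averaged across ranks,
        # so every rank prints the same number
        print(rank, ddp_model.module.weight.grad.sum().item())
        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(check_grads, args=(2,), nprocs=2, join=True)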
 
 Skewed Processing Speeds
 ------------------------
 
-In DDP, constructor, forward method, and differentiation of the outputs are
-distributed synchronization points. Different processes are expected to reach
-synchronization points in the same order and enter each synchronization point
-at roughly the same time. Otherwise, fast processes might arrive early and
-timeout on waiting for stragglers. Hence, users are responsible for balancing
-workloads distributions across processes. Sometimes, skewed processing speeds
-are inevitable due to, e.g., network delays, resource contentions,
-unpredictable workload spikes. To avoid timeouts in these situations, make
-sure that you pass a sufficiently large ``timeout`` value when calling
+In DDP, the constructor and the backward pass are distributed synchronization
+points. Different processes are expected to launch the same number of
+synchronizations and reach these synchronization points in the same order and
+enter each synchronization point at roughly the same time. Otherwise, fast
+processes might arrive early and timeout on waiting for stragglers. Hence,
+users are responsible for balancing workloads distributions across processes.
+Sometimes, skewed processing speeds are inevitable due to, e.g., network
+delays, resource contentions, unpredictable workload spikes. To avoid timeouts
+in these situations, make sure that you pass a sufficiently large ``timeout``
+value when calling
 `init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__.
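For example, the ``setup`` helper from the basic use case could pass a larger ``timeout`` (the default is 30 minutes); the 60-minute value and the address/port below are only illustrative.

.. code:: python

    import os
    from datetime import timedelta

    import torch.distributed as dist

    def setup(rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"
        # give stragglers more time before collectives give up
        dist.init_process_group("gloo", rank=rank, world_size=world_size,
                                timeout=timedelta(minutes=60))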
 
 Save and Load Checkpoints
@@ -162,21 +164,17 @@ loading the module, you need to provide an appropriate ``map_location``
 argument to prevent a process to step into others' devices. If ``map_location``
 is missing, ``torch.load`` will first load the module to CPU and then copy each
 parameter to where it was saved, which would result in all processes on the
-same machine using the same set of devices.
+same machine using the same set of devices. For more advanced failure recovery
+and elasticity support, please refer to `TorchElastic <https://github.com/pytorch/elastic>`__.
 
 .. code:: python
 
     def demo_checkpoint(rank, world_size):
+        print("Running DDP checkpoint example.")
         setup(rank, world_size)
 
-        # setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
-        # rank 2 uses GPUs [4, 5, 6, 7].
-        n = torch.cuda.device_count() // world_size
-        device_ids = list(range(rank * n, (rank + 1) * n))
-
-        model = ToyModel().to(device_ids[0])
-        # output_device defaults to device_ids[0]
-        ddp_model = DDP(model, device_ids=device_ids)
+        model = ToyModel().to(rank)
+        ddp_model = DDP(model, device_ids=[rank])
 
         loss_fn = nn.MSELoss()
         optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
@@ -192,15 +190,13 @@ same machine using the same set of devices.
         # 0 saves it.
         dist.barrier()
         # configure map_location properly
-        rank0_devices = [x - rank * len(device_ids) for x in device_ids]
-        device_pairs = zip(rank0_devices, device_ids)
-        map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
+        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
         ddp_model.load_state_dict(
             torch.load(CHECKPOINT_PATH, map_location=map_location))
 
         optimizer.zero_grad()
         outputs = ddp_model(torch.randn(20, 10))
-        labels = torch.randn(20, 5).to(device_ids[0])
+        labels = torch.randn(20, 5).to(rank)
         loss_fn = nn.MSELoss()
         loss_fn(outputs, labels).backward()
         optimizer.step()
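The ``map_location`` dictionary used above simply remaps the device tags stored in the checkpoint, so tensors saved from ``cuda:0`` are loaded directly onto ``cuda:rank``. A standalone illustration, assuming at least two GPUs and an arbitrary file name:

.. code:: python

    import torch

    t = torch.randn(4, device="cuda:0")
    torch.save(t, "tensor.pt")

    # same form as {'cuda:%d' % 0: 'cuda:%d' % rank} with rank == 1
    loaded = torch.load("tensor.pt", map_location={"cuda:0": "cuda:1"})
    print(loaded.device)  # cuda:1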
@@ -217,13 +213,8 @@ same machine using the same set of devices.
 Combine DDP with Model Parallelism
 ----------------------------------
 
-DDP also works with multi-GPU models, but replications within a process are not
-supported. You need to create one process per module replica, which usually
-leads to better performance compared to multiple replicas per process. DDP
-wrapping multi-GPU models is especially helpful when training large models with
-a huge amount of data. When using this feature, the multi-GPU model needs to be
-carefully implemented to avoid hard-coded devices, because different model
-replicas will be placed to different devices.
+DDP also works with multi-GPU models. DDP wrapping multi-GPU models is especially
+helpful when training large models with a huge amount of data.
 
 .. code:: python
 
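A multi-GPU model of the kind this section refers to could be sketched as follows; the class name ``TwoDeviceModel`` and the layer sizes are illustrative rather than taken from the tutorial. Because each DDP process may place its replica on a different pair of GPUs, the devices are passed in rather than hard-coded, and ``device_ids`` is left unset when wrapping a multi-device module.

.. code:: python

    import torch
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    class TwoDeviceModel(nn.Module):
        def __init__(self, dev0, dev1):
            super(TwoDeviceModel, self).__init__()
            self.dev0 = dev0
            self.dev1 = dev1
            # first half of the model on dev0, second half on dev1
            self.net1 = nn.Linear(10, 10).to(dev0)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5).to(dev1)

        def forward(self, x):
            x = self.relu(self.net1(x.to(self.dev0)))
            return self.net2(x.to(self.dev1))

    # inside a DDP process that owns dev0 and dev1, after the process group
    # has been initialized:
    # ddp_mp_model = DDP(TwoDeviceModel(dev0, dev1))  # note: no device_ids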
@@ -249,6 +240,7 @@ either the application or the model ``forward()`` method.
 .. code:: python
 
     def demo_model_parallel(rank, world_size):
+        print("Running DDP with model parallel example.")
         setup(rank, world_size)
 
         # setup mp_model and devices for this process
@@ -271,8 +263,10 @@ either the application or the model ``forward()`` method.
 
 
     if __name__ == "__main__":
-        run_demo(demo_basic, 2)
-        run_demo(demo_checkpoint, 2)
-
-        if torch.cuda.device_count() >= 8:
-            run_demo(demo_model_parallel, 4)
+        n_gpus = torch.cuda.device_count()
+        if n_gpus < 8:
+            print(f"Requires at least 8 GPUs to run, but got {n_gpus}.")
+        else:
+            run_demo(demo_basic, 8)
+            run_demo(demo_checkpoint, 8)
+            run_demo(demo_model_parallel, 4)
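For context, the ``run_demo`` helper called here appears earlier in the tutorial and is not touched by this hunk; it is a thin wrapper around ``mp.spawn``, roughly along the lines of the sketch below (consistent with the ``join=True`` context line shown earlier in the diff).

.. code:: python

    import torch.multiprocessing as mp

    def run_demo(demo_fn, world_size):
        # spawn one process per rank; join=True blocks until all of them exit
        mp.spawn(demo_fn,
                 args=(world_size,),
                 nprocs=world_size,
                 join=True)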
