@@ -2,46 +2,53 @@ Getting Started with Distributed Data Parallel
=================================================
**Author**: `Shen Li <https://mrshenli.github.io/>`_

- `DistributedDataParallel <https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html>`__
- (DDP) implements data parallelism at the module level. It uses communication
- collectives in the `torch.distributed <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
- package to synchronize gradients, parameters, and buffers. Parallelism is
- available both within a process and across processes. Within a process, DDP
- replicates the input module to devices specified in ``device_ids``, scatters
- inputs along the batch dimension accordingly, and gathers outputs to the
- ``output_device``, which is similar to
- `DataParallel <https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html>`__.
- Across processes, DDP inserts necessary parameter synchronizations in forward
- passes and gradient synchronizations in backward passes. It is up to users to
- map processes to available resources, as long as processes do not share GPU
- devices. The recommended (usually fastest) approach is to create a process for
- every module replica, i.e., no module replication within a process. The code in
- this tutorial runs on an 8-GPU server, but it can be easily generalized to
- other environments.
+ `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__
+ (DDP) implements data parallelism at the module level, which can run across
+ multiple machines. Applications using DDP should spawn multiple processes and
+ create a DDP instance in each process. DDP uses collective communications in the
+ `torch.distributed <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
+ package to synchronize gradients and buffers. More specifically, DDP registers
+ an autograd hook for each model parameter, which fires when the
+ corresponding gradient is computed in the backward pass. DDP then uses that
+ signal to trigger gradient synchronization across processes. Please refer to the
+ `DDP design note <https://pytorch.org/docs/master/notes/ddp.html>`__ for more details.
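
To build intuition for this hook-based mechanism, the sketch below registers a hook on
each parameter that averages its gradient across processes with ``all_reduce`` as soon as
the gradient is produced. This is a simplified illustration only, not DDP's actual
implementation, which buckets gradients and overlaps communication with the remaining
backward computation; the helper name is made up for this example.

.. code:: python

    import torch.distributed as dist

    def attach_naive_grad_averaging(model, world_size):
        # Illustration only: average each parameter's gradient across all
        # processes as soon as autograd produces it. Assumes the process
        # group has already been initialized (see setup() below).
        for param in model.parameters():
            if not param.requires_grad:
                continue

            def hook(grad, world_size=world_size):
                synced = grad.detach().clone()
                dist.all_reduce(synced, op=dist.ReduceOp.SUM)
                return synced / world_size

            param.register_hook(hook)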
+
+
+ The recommended way to use DDP is to spawn one process for each model replica,
+ where a model replica can span multiple devices. DDP processes can be
+ placed on the same machine or across machines, but GPU devices cannot be
+ shared across processes. This tutorial starts from a basic DDP use case and
+ then demonstrates more advanced use cases, including checkpointing models and
+ combining DDP with model parallelism.
+
+
+ .. note::
+   The code in this tutorial runs on an 8-GPU server, but it can be easily
+   generalized to other environments.
+
Comparison between ``DataParallel`` and ``DistributedDataParallel``
-------------------------------------------------------------------

Before we dive in, let's clarify why, despite the added complexity, you would
consider using ``DistributedDataParallel`` over ``DataParallel``:

- - First, recall from the
+ - First, ``DataParallel`` is single-process, multi-thread, and only works on a
+   single machine, while ``DistributedDataParallel`` is multi-process and works
+   for both single- and multi-machine training. ``DataParallel`` is usually
+   slower than ``DistributedDataParallel`` even on a single machine due to GIL
+   contention across threads, the per-iteration replicated model, and the additional
+   overhead introduced by scattering inputs and gathering outputs (see the sketch
+   after this list).
+ - Recall from the
  `prior tutorial <https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html>`__
  that if your model is too large to fit on a single GPU, you must use **model parallel**
  to split it across multiple GPUs. ``DistributedDataParallel`` works with
-   **model parallel**; ``DataParallel`` does not at this time.
- - ``DataParallel`` is single-process, multi-thread, and only works on a single
-   machine, while ``DistributedDataParallel`` is multi-process and works for both
-   single- and multi- machine training. Thus, even for single machine training,
-   where your **data** is small enough to fit on a single machine, ``DistributedDataParallel``
-   is expected to be faster than ``DataParallel``. ``DistributedDataParallel``
-   also replicates models upfront instead of on each iteration and gets Global
-   Interpreter Lock out of the way.
- - If both your data is too large to fit on one machine **and** your
-   model is too large to fit on a single GPU, you can combine model parallel
-   (splitting a single model across multiple GPUs) with ``DistributedDataParallel``.
-   Under this regime, each ``DistributedDataParallel`` process could use model parallel,
-   and all processes collectively would use data parallel.
+   **model parallel**; ``DataParallel`` does not at this time. When DDP is combined
+   with model parallelism, each DDP process uses model parallelism, and all processes
+   collectively use data parallelism.
+ - If your model needs to span multiple machines or if your use case does not fit
+   into the data parallelism paradigm, please see `the RPC API <https://pytorch.org/docs/stable/rpc.html>`__
+   for more generic distributed training support.
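
As a minimal sketch of the difference called out in the first bullet, the snippet below
wraps a toy module both ways. The function names are illustrative, and the DDP variant
assumes it runs inside one of several spawned processes that have already joined a
process group (see ``setup()`` in the basic use case below).

.. code:: python

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def wrap_with_data_parallel():
        # single process, multi-thread: one process drives all visible GPUs
        return nn.DataParallel(nn.Linear(10, 10).to(0))

    def wrap_with_ddp(rank):
        # one process per replica: each process wraps its own copy on the GPU
        # whose index matches its rank; requires an initialized process group
        return DDP(nn.Linear(10, 10).to(rank), device_ids=[rank])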

Basic Use Case
--------------
@@ -70,18 +77,14 @@ be found in
        # initialize the process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

-         # Explicitly setting seed to make sure that models created in two processes
-         # start from same random weights and biases.
-         torch.manual_seed(42)
-

    def cleanup():
        dist.destroy_process_group()

Now, let's create a toy module, wrap it with DDP, and feed it with some dummy
- input data. Please note, if training starts from random parameters, you might
- want to make sure that all DDP processes use the same initial values.
- Otherwise, global gradient synchronizes will not make sense.
+ input data. Please note that, since DDP broadcasts model states from the rank 0 process
+ to all other processes in the DDP constructor, you don't need to worry about
+ different DDP processes starting from different initial parameter values.

.. code:: python

@@ -97,24 +100,19 @@ Otherwise, global gradient synchronizes will not make sense.

    def demo_basic(rank, world_size):
+         print(f"Running basic DDP example on rank {rank}.")
        setup(rank, world_size)

-         # setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
-         # rank 2 uses GPUs [4, 5, 6, 7].
-         n = torch.cuda.device_count() // world_size
-         device_ids = list(range(rank * n, (rank + 1) * n))
-
-         # create model and move it to device_ids[0]
-         model = ToyModel().to(device_ids[0])
-         # output_device defaults to device_ids[0]
-         ddp_model = DDP(model, device_ids=device_ids)
+         # create model and move it to GPU with id rank
+         model = ToyModel().to(rank)
+         ddp_model = DDP(model, device_ids=[rank])

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
-         labels = torch.randn(20, 5).to(device_ids[0])
+         labels = torch.randn(20, 5).to(rank)
        loss_fn(outputs, labels).backward()
        optimizer.step()

@@ -127,23 +125,27 @@ Otherwise, global gradient synchronizes will not make sense.
                 nprocs=world_size,
                 join=True)

- As you can see, DDP wraps lower level distributed communication details, and
- provides a clean API as if it is a local model. For basic use cases, DDP only
+ As you can see, DDP wraps lower-level distributed communication details and
+ provides a clean API as if it were a local model. Gradient synchronization
+ communications take place during the backward pass and overlap with the
+ backward computation. When ``backward()`` returns, ``param.grad`` already
+ contains the synchronized gradient tensor. For basic use cases, DDP only
requires a few more LoCs to set up the process group. When applying DDP to more
- advanced use cases, there are some caveats that require cautions.
+ advanced use cases, some caveats require caution.
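
As an illustrative check (not part of the tutorial), the helper below can be run on every
rank right after ``backward()``: since DDP has already averaged the gradients, re-averaging
a copy across processes should reproduce ``param.grad`` on all ranks.

.. code:: python

    import torch
    import torch.distributed as dist

    def check_grads_synchronized(ddp_model, world_size):
        # every rank should already hold the same averaged gradients, so
        # summing a copy across ranks and dividing by the world size must
        # match param.grad up to floating-point tolerance
        for name, param in ddp_model.named_parameters():
            if param.grad is None:
                continue
            summed = param.grad.detach().clone()
            dist.all_reduce(summed, op=dist.ReduceOp.SUM)
            assert torch.allclose(summed / world_size, param.grad, atol=1e-6), name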

Skewed Processing Speeds
------------------------

- In DDP, constructor, forward method, and differentiation of the outputs are
- distributed synchronization points. Different processes are expected to reach
- synchronization points in the same order and enter each synchronization point
- at roughly the same time. Otherwise, fast processes might arrive early and
- timeout on waiting for stragglers. Hence, users are responsible for balancing
- workloads distributions across processes. Sometimes, skewed processing speeds
- are inevitable due to, e.g., network delays, resource contentions,
- unpredictable workload spikes. To avoid timeouts in these situations, make
- sure that you pass a sufficiently large ``timeout`` value when calling
+ In DDP, the constructor and the backward pass are distributed synchronization
+ points. Different processes are expected to launch the same number of
+ synchronizations, reach these synchronization points in the same order, and
+ enter each synchronization point at roughly the same time. Otherwise, fast
+ processes might arrive early and time out while waiting for stragglers. Hence,
+ users are responsible for balancing workload distributions across processes.
+ Sometimes, skewed processing speeds are inevitable due to, e.g., network
+ delays, resource contention, or unpredictable workload spikes. To avoid timeouts
+ in these situations, make sure that you pass a sufficiently large ``timeout``
+ value when calling
`init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__.
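
For example, a variant of the ``setup()`` helper above might pass an explicit, generous
timeout; the exact value is workload-dependent and purely illustrative.

.. code:: python

    import datetime
    import torch.distributed as dist

    def setup_with_timeout(rank, world_size):
        # assumes MASTER_ADDR and MASTER_PORT are already set, as in setup()
        dist.init_process_group(
            "gloo",
            rank=rank,
            world_size=world_size,
            # give stragglers up to 60 minutes before collectives error out
            timeout=datetime.timedelta(minutes=60),
        )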

Save and Load Checkpoints
@@ -156,27 +158,23 @@ for more details. When using DDP, one optimization is to save the model in
only one process and then load it to all processes, reducing write overhead.
This is correct because all processes start from the same parameters and
gradients are synchronized in backward passes, and hence optimizers should keep
- setting parameters to same values. If you use this optimization, make sure all
+ setting parameters to the same values. If you use this optimization, make sure all
processes do not start loading before the saving is finished. Besides, when
loading the module, you need to provide an appropriate ``map_location``
argument to prevent a process from stepping into others' devices. If ``map_location``
is missing, ``torch.load`` will first load the module to CPU and then copy each
parameter to where it was saved, which would result in all processes on the
- same machine using the same set of devices.
+ same machine using the same set of devices. For more advanced failure recovery
+ and elasticity support, please refer to `TorchElastic <https://github.com/pytorch/elastic>`__.
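
A minimal sketch of the save-then-barrier pattern described above is shown below.
``CHECKPOINT_PATH`` echoes the name used by ``demo_checkpoint``, but the actual path and
helper name here are illustrative; the barrier keeps other ranks from loading before
rank 0 has finished writing.

.. code:: python

    import os
    import tempfile

    import torch
    import torch.distributed as dist

    CHECKPOINT_PATH = os.path.join(tempfile.gettempdir(), "model.checkpoint")

    def save_on_rank0(ddp_model, rank):
        # only one process writes the checkpoint ...
        if rank == 0:
            torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)
        # ... and everyone waits here so that no process starts loading
        # before the file is fully written
        dist.barrier()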

.. code:: python

    def demo_checkpoint(rank, world_size):
+         print(f"Running DDP checkpoint example on rank {rank}.")
        setup(rank, world_size)

-         # setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
-         # rank 2 uses GPUs [4, 5, 6, 7].
-         n = torch.cuda.device_count() // world_size
-         device_ids = list(range(rank * n, (rank + 1) * n))
-
-         model = ToyModel().to(device_ids[0])
-         # output_device defaults to device_ids[0]
-         ddp_model = DDP(model, device_ids=device_ids)
+         model = ToyModel().to(rank)
+         ddp_model = DDP(model, device_ids=[rank])

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
@@ -192,15 +190,13 @@ same machine using the same set of devices.
        # 0 saves it.
        dist.barrier()
        # configure map_location properly
-         rank0_devices = [x - rank * len(device_ids) for x in device_ids]
-         device_pairs = zip(rank0_devices, device_ids)
-         map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
+         map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
        ddp_model.load_state_dict(
            torch.load(CHECKPOINT_PATH, map_location=map_location))

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
-         labels = torch.randn(20, 5).to(device_ids[0])
+         labels = torch.randn(20, 5).to(rank)
        loss_fn = nn.MSELoss()
        loss_fn(outputs, labels).backward()
        optimizer.step()
@@ -217,13 +213,8 @@ same machine using the same set of devices.
Combine DDP with Model Parallelism
----------------------------------

- DDP also works with multi-GPU models, but replications within a process are not
- supported. You need to create one process per module replica, which usually
- leads to better performance compared to multiple replicas per process. DDP
- wrapping multi-GPU models is especially helpful when training large models with
- a huge amount of data. When using this feature, the multi-GPU model needs to be
- carefully implemented to avoid hard-coded devices, because different model
- replicas will be placed to different devices.
+ DDP also works with multi-GPU models. Wrapping a multi-GPU model with DDP is especially
+ helpful when training large models with a huge amount of data.
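
For reference, a multi-GPU model of the kind this section has in mind might look like the
sketch below; the class name, devices, and layer sizes are illustrative. When wrapping such
a module, ``device_ids`` is not passed to DDP, since the module already places its
submodules on specific devices (as ``demo_model_parallel`` below does with ``mp_model``).

.. code:: python

    import torch.nn as nn

    class ToyMpModel(nn.Module):
        def __init__(self, dev0, dev1):
            super(ToyMpModel, self).__init__()
            self.dev0 = dev0
            self.dev1 = dev1
            self.net1 = nn.Linear(10, 10).to(dev0)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5).to(dev1)

        def forward(self, x):
            # move activations between the two devices inside forward()
            x = self.relu(self.net1(x.to(self.dev0)))
            return self.net2(x.to(self.dev1))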

.. code:: python

@@ -249,6 +240,7 @@ either the application or the model ``forward()`` method.
.. code:: python

    def demo_model_parallel(rank, world_size):
+         print(f"Running DDP with model parallel example on rank {rank}.")
        setup(rank, world_size)

        # setup mp_model and devices for this process
@@ -271,8 +263,10 @@ either the application or the model ``forward()`` method.

    if __name__ == "__main__":
-         run_demo(demo_basic, 2)
-         run_demo(demo_checkpoint, 2)
-
-         if torch.cuda.device_count() >= 8:
-             run_demo(demo_model_parallel, 4)
+         n_gpus = torch.cuda.device_count()
+         if n_gpus < 8:
+             print(f"Requires at least 8 GPUs to run, but got {n_gpus}.")
+         else:
+             run_demo(demo_basic, 8)
+             run_demo(demo_checkpoint, 8)
+             run_demo(demo_model_parallel, 4)