@@ -2,46 +2,54 @@ Getting Started with Distributed Data Parallel
=================================================
**Author**: `Shen Li <https://mrshenli.github.io/>`_

- `DistributedDataParallel <https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html>`__
- (DDP) implements data parallelism at the module level. It uses communication
- collectives in the `torch.distributed <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
- package to synchronize gradients, parameters, and buffers. Parallelism is
- available both within a process and across processes. Within a process, DDP
- replicates the input module to devices specified in ``device_ids``, scatters
- inputs along the batch dimension accordingly, and gathers outputs to the
- ``output_device``, which is similar to
- `DataParallel <https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html>`__.
- Across processes, DDP inserts necessary parameter synchronizations in forward
- passes and gradient synchronizations in backward passes. It is up to users to
- map processes to available resources, as long as processes do not share GPU
- devices. The recommended (usually fastest) approach is to create a process for
- every module replica, i.e., no module replication within a process. The code in
- this tutorial runs on an 8-GPU server, but it can be easily generalized to
- other environments.
+ `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__
+ (DDP) implements data parallelism at the module level, and it can run across
+ multiple machines. Applications using DDP should spawn multiple processes and
+ create a single DDP instance per process. DDP uses collective communications in the
+ `torch.distributed <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
+ package to synchronize gradients and buffers. More specifically, DDP registers
+ an autograd hook for each parameter given by ``model.parameters()``, and the
+ hook fires when the corresponding gradient is computed in the backward
+ pass. DDP then uses that signal to trigger gradient synchronization across
+ processes. Please refer to the
+ `DDP design note <https://pytorch.org/docs/master/notes/ddp.html>`__ for more details.
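
To make the hook mechanism concrete, below is a simplified sketch of the idea
(an illustration only, not DDP's actual implementation, which buckets gradients
and overlaps communication with computation): each parameter gets a hook that
averages its gradient across processes once that gradient has been computed.

.. code:: python

    import torch.distributed as dist

    def attach_naive_grad_averaging_hooks(model):
        # Illustration only: register one hook per parameter; each hook fires
        # when that parameter's gradient becomes available in the backward pass.
        world_size = dist.get_world_size()

        def make_hook():
            def hook(grad):
                # average this gradient across all processes
                synced = grad.clone()
                dist.all_reduce(synced, op=dist.ReduceOp.SUM)
                return synced / world_size
            return hook

        for param in model.parameters():
            if param.requires_grad:
                param.register_hook(make_hook())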
+
+
+ The recommended way to use DDP is to spawn one process for each model replica,
+ where a model replica can span multiple devices. DDP processes can be
+ placed on the same machine or across machines, but GPU devices cannot be
+ shared across processes. This tutorial starts from a basic DDP use case and
+ then demonstrates more advanced use cases, including checkpointing models and
+ combining DDP with model parallelism.
+
+
+ .. note::
+   The code in this tutorial runs on an 8-GPU server, but it can be easily
+   generalized to other environments.
+

Comparison between ``DataParallel`` and ``DistributedDataParallel``
-------------------------------------------------------------------

Before we dive in, let's clarify why, despite the added complexity, you would
consider using ``DistributedDataParallel`` over ``DataParallel``:

- - First, recall from the
+ - First, ``DataParallel`` is single-process, multi-thread, and only works on a
+   single machine, while ``DistributedDataParallel`` is multi-process and works
+   for both single- and multi-machine training. ``DataParallel`` is usually
+   slower than ``DistributedDataParallel`` even on a single machine due to GIL
+   contention across threads, the per-iteration replicated model, and the additional
+   overhead introduced by scattering inputs and gathering outputs (see the sketch
+   after this list).
+ - Recall from the
  `prior tutorial <https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html>`__
  that if your model is too large to fit on a single GPU, you must use **model parallel**
  to split it across multiple GPUs. ``DistributedDataParallel`` works with
-   **model parallel**; ``DataParallel`` does not at this time.
- - ``DataParallel`` is single-process, multi-thread, and only works on a single
-   machine, while ``DistributedDataParallel`` is multi-process and works for both
-   single- and multi- machine training. Thus, even for single machine training,
-   where your **data** is small enough to fit on a single machine, ``DistributedDataParallel``
-   is expected to be faster than ``DataParallel``. ``DistributedDataParallel``
-   also replicates models upfront instead of on each iteration and gets Global
-   Interpreter Lock out of the way.
- - If both your data is too large to fit on one machine **and** your
-   model is too large to fit on a single GPU, you can combine model parallel
-   (splitting a single model across multiple GPUs) with ``DistributedDataParallel``.
-   Under this regime, each ``DistributedDataParallel`` process could use model parallel,
-   and all processes collectively would use data parallel.
+   **model parallel**; ``DataParallel`` does not at this time. When DDP is combined
+   with model parallel, each DDP process would use model parallel, and all processes
+   collectively would use data parallel.
+ - If your model needs to span multiple machines or if your use case does not fit
+   into the data parallelism paradigm, please see `the RPC API <https://pytorch.org/docs/stable/rpc.html>`__
+   for more generic distributed training support.
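
For a concrete feel of the API difference, the sketch below contrasts the two
wrappers. It assumes a ``model`` instance and, for the DDP case, a process
group that has already been initialized.

.. code:: python

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def wrap_with_data_parallel(model):
        # single process, multi-thread; replicates the model across visible GPUs
        return nn.DataParallel(model)

    def wrap_with_ddp(model, rank):
        # one process per model replica; torch.distributed must be initialized
        return DDP(model.to(rank), device_ids=[rank])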

Basic Use Case
--------------
@@ -70,18 +78,14 @@ be found in
    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

-     # Explicitly setting seed to make sure that models created in two processes
-     # start from same random weights and biases.
-     torch.manual_seed(42)
-

def cleanup():
    dist.destroy_process_group()

Now, let's create a toy module, wrap it with DDP, and feed it with some dummy
- input data. Please note, if training starts from random parameters, you might
- want to make sure that all DDP processes use the same initial values.
- Otherwise, global gradient synchronizes will not make sense.
+ input data. Please note, as DDP broadcasts model states from the rank 0 process to
+ all other processes in the DDP constructor, you don't need to worry about
+ different DDP processes starting from different initial model parameter values.
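
One way to convince yourself of this is a small, hypothetical check (not part
of the tutorial code) that broadcasts rank 0's parameters after constructing
DDP and compares them with the local copies:

.. code:: python

    import torch
    import torch.distributed as dist

    def assert_params_in_sync(ddp_model):
        # assumes the process group has already been initialized via setup()
        for param in ddp_model.module.parameters():
            reference = param.detach().clone()
            dist.broadcast(reference, src=0)  # every rank receives rank 0's values
            assert torch.allclose(param.detach(), reference), "parameters differ across ranks"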

.. code:: python

@@ -97,24 +101,19 @@ Otherwise, global gradient synchronizes will not make sense.


def demo_basic(rank, world_size):
+     print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

-     # setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
-     # rank 2 uses GPUs [4, 5, 6, 7].
-     n = torch.cuda.device_count() // world_size
-     device_ids = list(range(rank * n, (rank + 1) * n))
-
-     # create model and move it to device_ids[0]
-     model = ToyModel().to(device_ids[0])
-     # output_device defaults to device_ids[0]
-     ddp_model = DDP(model, device_ids=device_ids)
+     # create model and move it to GPU with id rank
+     model = ToyModel().to(rank)
+     ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
-     labels = torch.randn(20, 5).to(device_ids[0])
+     labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

@@ -127,23 +126,27 @@ Otherwise, global gradient synchronizes will not make sense.
        nprocs=world_size,
        join=True)

- As you can see, DDP wraps lower level distributed communication details, and
- provides a clean API as if it is a local model. For basic use cases, DDP only
+ As you can see, DDP wraps lower-level distributed communication details and
+ provides a clean API as if it were a local model. Gradient synchronization
+ communications take place during the backward pass and overlap with the
+ backward computation. When ``backward()`` returns, ``param.grad`` already
+ contains the synchronized gradient tensor. For basic use cases, DDP only
requires a few more LoCs to set up the process group. When applying DDP to more
- advanced use cases, there are some caveats that require cautions.
+ advanced use cases, some caveats require caution.
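
To illustrate that claim, a hedged sanity check (assuming the process group set
up by ``setup()``) could re-average ``param.grad`` manually and confirm that DDP
already did the same work during ``backward()``:

.. code:: python

    import torch
    import torch.distributed as dist

    def check_grads_already_averaged(ddp_model):
        world_size = dist.get_world_size()
        for param in ddp_model.parameters():
            if param.grad is None:
                continue
            manual = param.grad.detach().clone()
            dist.all_reduce(manual, op=dist.ReduceOp.SUM)
            manual /= world_size
            # DDP synchronized param.grad during backward(), so this should match.
            assert torch.allclose(param.grad, manual)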

Skewed Processing Speeds
------------------------

- In DDP, constructor, forward method, and differentiation of the outputs are
- distributed synchronization points. Different processes are expected to reach
- synchronization points in the same order and enter each synchronization point
- at roughly the same time. Otherwise, fast processes might arrive early and
- timeout on waiting for stragglers. Hence, users are responsible for balancing
- workloads distributions across processes. Sometimes, skewed processing speeds
- are inevitable due to, e.g., network delays, resource contentions,
- unpredictable workload spikes. To avoid timeouts in these situations, make
- sure that you pass a sufficiently large ``timeout`` value when calling
+ In DDP, the constructor, the forward pass, and the backward pass are
+ distributed synchronization points. Different processes are expected to launch
+ the same number of synchronizations, reach these synchronization points in
+ the same order, and enter each synchronization point at roughly the same time.
+ Otherwise, fast processes might arrive early and time out while waiting for
+ stragglers. Hence, users are responsible for balancing workload distributions
+ across processes. Sometimes, skewed processing speeds are inevitable due to,
+ e.g., network delays, resource contention, or unpredictable workload spikes. To
+ avoid timeouts in these situations, make sure that you pass a sufficiently
+ large ``timeout`` value when calling
`init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__.
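
For example, the ``setup()`` helper above could be extended with a larger
``timeout``; the value below is only illustrative, and ``init_process_group``
accepts a ``datetime.timedelta``:

.. code:: python

    from datetime import timedelta

    import torch.distributed as dist

    def setup_with_timeout(rank, world_size):
        # Same idea as setup() above, but give stragglers more time before
        # collectives fail (MASTER_ADDR/MASTER_PORT still need to be set).
        dist.init_process_group(
            "gloo",
            rank=rank,
            world_size=world_size,
            timeout=timedelta(minutes=60),
        )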

Save and Load Checkpoints
@@ -156,27 +159,23 @@ for more details. When using DDP, one optimization is to save the model in
only one process and then load it to all processes, reducing write overhead.
This is correct because all processes start from the same parameters and
gradients are synchronized in backward passes, and hence optimizers should keep
- setting parameters to same values. If you use this optimization, make sure all
+ setting parameters to the same values. If you use this optimization, make sure all
processes do not start loading before the saving is finished. Besides, when
loading the module, you need to provide an appropriate ``map_location``
argument to prevent a process from stepping into others' devices. If ``map_location``
is missing, ``torch.load`` will first load the module to CPU and then copy each
parameter to where it was saved, which would result in all processes on the
- same machine using the same set of devices.
+ same machine using the same set of devices. For more advanced failure recovery
+ and elasticity support, please refer to `TorchElastic <https://pytorch.org/elastic>`__.
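
The pattern described above can be condensed into the following sketch (the
checkpoint path is illustrative; the ``demo_checkpoint`` function below shows
the same steps in context):

.. code:: python

    import torch
    import torch.distributed as dist

    def save_and_reload(ddp_model, rank, path="/tmp/ddp_checkpoint.pt"):
        if rank == 0:
            # only one process writes the checkpoint
            torch.save(ddp_model.state_dict(), path)
        # make sure saving has finished before any process starts loading
        dist.barrier()
        # map tensors saved from rank 0's devices onto this process's device
        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
        ddp_model.load_state_dict(torch.load(path, map_location=map_location))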

.. code:: python

def demo_checkpoint(rank, world_size):
+     print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)

-     # setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
-     # rank 2 uses GPUs [4, 5, 6, 7].
-     n = torch.cuda.device_count() // world_size
-     device_ids = list(range(rank * n, (rank + 1) * n))
-
-     model = ToyModel().to(device_ids[0])
-     # output_device defaults to device_ids[0]
-     ddp_model = DDP(model, device_ids=device_ids)
+     model = ToyModel().to(rank)
+     ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
@@ -192,15 +191,13 @@ same machine using the same set of devices.
    # 0 saves it.
    dist.barrier()
    # configure map_location properly
-     rank0_devices = [x - rank * len(device_ids) for x in device_ids]
-     device_pairs = zip(rank0_devices, device_ids)
-     map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
+     map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
-     labels = torch.randn(20, 5).to(device_ids[0])
+     labels = torch.randn(20, 5).to(rank)
    loss_fn = nn.MSELoss()
    loss_fn(outputs, labels).backward()
    optimizer.step()
@@ -217,13 +214,8 @@ same machine using the same set of devices.
Combine DDP with Model Parallelism
----------------------------------

- DDP also works with multi-GPU models, but replications within a process are not
- supported. You need to create one process per module replica, which usually
- leads to better performance compared to multiple replicas per process. DDP
- wrapping multi-GPU models is especially helpful when training large models with
- a huge amount of data. When using this feature, the multi-GPU model needs to be
- carefully implemented to avoid hard-coded devices, because different model
- replicas will be placed to different devices.
+ DDP also works with multi-GPU models. Wrapping multi-GPU models with DDP is especially
+ helpful when training large models with a huge amount of data.
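
For illustration, a multi-GPU model places its submodules on different devices
and moves intermediate activations between them in ``forward()``. The module
below is a hypothetical example, not necessarily the one used later in the
tutorial. When wrapping such a multi-device module with DDP, ``device_ids``
should not be set; DDP then uses the devices the module's parameters already
live on.

.. code:: python

    import torch.nn as nn

    class ToyTwoGpuModel(nn.Module):
        # Hypothetical example of a model spanning two devices.
        def __init__(self, dev0, dev1):
            super().__init__()
            self.dev0, self.dev1 = dev0, dev1
            self.net1 = nn.Linear(10, 10).to(dev0)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5).to(dev1)

        def forward(self, x):
            x = self.relu(self.net1(x.to(self.dev0)))
            return self.net2(x.to(self.dev1))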

.. code:: python

@@ -249,6 +241,7 @@ either the application or the model ``forward()`` method.
.. code:: python

def demo_model_parallel(rank, world_size):
+     print(f"Running DDP with model parallel example on rank {rank}.")
    setup(rank, world_size)

    # setup mp_model and devices for this process
@@ -271,8 +264,10 @@ either the application or the model ``forward()`` method.


if __name__ == "__main__":
-     run_demo(demo_basic, 2)
-     run_demo(demo_checkpoint, 2)
-
-     if torch.cuda.device_count() >= 8:
-         run_demo(demo_model_parallel, 4)
+     n_gpus = torch.cuda.device_count()
+     if n_gpus < 8:
+         print(f"Requires at least 8 GPUs to run, but got {n_gpus}.")
+     else:
+         run_demo(demo_basic, 8)
+         run_demo(demo_checkpoint, 8)
+         run_demo(demo_model_parallel, 4)
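
The ``run_demo`` helper is not fully shown in this excerpt; judging from the
``nprocs``/``join`` arguments visible earlier, it presumably wraps
``torch.multiprocessing.spawn`` along these lines (a sketch, not necessarily
the exact definition):

.. code:: python

    import torch.multiprocessing as mp

    def run_demo(demo_fn, world_size):
        # launch world_size processes; each one calls demo_fn(rank, world_size)
        mp.spawn(demo_fn,
                 args=(world_size,),
                 nprocs=world_size,
                 join=True)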