Update DDP Tutorial to remove Single-Process Multi-Device Use Case #973

Merged: 1 commit, Apr 29, 2020
165 changes: 80 additions & 85 deletions intermediate_source/ddp_tutorial.rst
@@ -2,46 +2,54 @@ Getting Started with Distributed Data Parallel
=================================================
**Author**: `Shen Li <https://mrshenli.github.io/>`_

`DistributedDataParallel <https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html>`__
(DDP) implements data parallelism at the module level. It uses communication
collectives in the `torch.distributed <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
package to synchronize gradients, parameters, and buffers. Parallelism is
available both within a process and across processes. Within a process, DDP
replicates the input module to devices specified in ``device_ids``, scatters
inputs along the batch dimension accordingly, and gathers outputs to the
``output_device``, which is similar to
`DataParallel <https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html>`__.
Across processes, DDP inserts necessary parameter synchronizations in forward
passes and gradient synchronizations in backward passes. It is up to users to
map processes to available resources, as long as processes do not share GPU
devices. The recommended (usually fastest) approach is to create a process for
every module replica, i.e., no module replication within a process. The code in
this tutorial runs on an 8-GPU server, but it can be easily generalized to
other environments.
`DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__
(DDP) implements data parallelism at the module level, which can run across
multiple machines. Applications using DDP should spawn multiple processes and
create a single DDP instance per process. DDP uses collective communications in the
`torch.distributed <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
package to synchronize gradients and buffers. More specifically, DDP registers
an autograd hook for each parameter given by ``model.parameters()`` and the
hook will fire when the corresponding gradient is computed in the backward
pass. Then DDP uses that signal to trigger gradient synchronization across
processes. Please refer to
`DDP design note <https://pytorch.org/docs/master/notes/ddp.html>`__ for more details.
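
As a rough illustration of the hook mechanism described above (a hypothetical
sketch using plain autograd hooks, not DDP's actual internals), the snippet
below registers one hook per parameter; each hook fires as soon as that
parameter's gradient is produced during ``backward()``, which is the kind of
per-parameter signal DDP uses to trigger gradient synchronization.

.. code:: python

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 5)
    for name, param in model.named_parameters():
        # Fires when this parameter's gradient is computed in the backward pass.
        param.register_hook(lambda grad, name=name: print(f"grad ready for {name}"))

    loss = model(torch.randn(20, 10)).sum()
    loss.backward()  # hooks fire here, one per parameter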


The recommended way to use DDP is to spawn one process for each model replica,
where a model replica can span multiple devices. DDP processes can be
placed on the same machine or across machines, but GPU devices cannot be
shared across processes. This tutorial starts from a basic DDP use case and
then demonstrates more advanced use cases including checkpointing models and
combining DDP with model parallel.


.. note::
The code in this tutorial runs on an 8-GPU server, but it can be easily
generalized to other environments.


Comparison between ``DataParallel`` and ``DistributedDataParallel``
-------------------------------------------------------------------

Before we dive in, let's clarify why, despite the added complexity, you would
consider using ``DistributedDataParallel`` over ``DataParallel``:

- First, recall from the
- First, ``DataParallel`` is single-process and multi-threaded, and it only
works on a single machine, while ``DistributedDataParallel`` is multi-process
and works for both single- and multi-machine training. ``DataParallel`` is
usually slower than ``DistributedDataParallel`` even on a single machine due
to GIL contention across threads, the model being replicated at each
iteration, and the additional overhead of scattering inputs and gathering
outputs (see the short API sketch after this list).
- Recall from the
`prior tutorial <https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html>`__
that if your model is too large to fit on a single GPU, you must use **model parallel**
to split it across multiple GPUs. ``DistributedDataParallel`` works with
**model parallel**; ``DataParallel`` does not at this time.
- ``DataParallel`` is single-process, multi-thread, and only works on a single
machine, while ``DistributedDataParallel`` is multi-process and works for both
single- and multi- machine training. Thus, even for single machine training,
where your **data** is small enough to fit on a single machine, ``DistributedDataParallel``
is expected to be faster than ``DataParallel``. ``DistributedDataParallel``
also replicates models upfront instead of on each iteration and gets Global
Interpreter Lock out of the way.
- If both your data is too large to fit on one machine **and** your
model is too large to fit on a single GPU, you can combine model parallel
(splitting a single model across multiple GPUs) with ``DistributedDataParallel``.
Under this regime, each ``DistributedDataParallel`` process could use model parallel,
and all processes collectively would use data parallel.
**model parallel**; ``DataParallel`` does not at this time. When DDP is combined
with model parallel, each DDP process would use model parallel, and all processes
collectively would use data parallel.
- If your model needs to span multiple machines or if your use case does not fit
into the data parallelism paradigm, please see `the RPC API <https://pytorch.org/docs/stable/rpc.html>`__
for more generic distributed training support.
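
To make the first point concrete, here is a minimal sketch (not part of the
tutorial's runnable examples) contrasting how the two wrappers are applied;
``rank`` is assumed to come from the per-process setup shown in the Basic Use
Case below.

.. code:: python

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    # DataParallel: a single process fans work out over device_ids via threads.
    dp_model = nn.DataParallel(nn.Linear(10, 5).to(0), device_ids=[0, 1])

    # DistributedDataParallel: one process per GPU, each wrapping its own
    # replica after init_process_group has been called.
    # ddp_model = DDP(nn.Linear(10, 5).to(rank), device_ids=[rank])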

Basic Use Case
--------------
@@ -70,18 +78,14 @@ be found in
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)

# Explicitly setting seed to make sure that models created in two processes
# start from same random weights and biases.
torch.manual_seed(42)


def cleanup():
dist.destroy_process_group()

Now, let's create a toy module, wrap it with DDP, and feed it with some dummy
input data. Please note, if training starts from random parameters, you might
want to make sure that all DDP processes use the same initial values.
Otherwise, global gradient synchronizes will not make sense.
input data. Please note that, as DDP broadcasts model states from the rank 0
process to all other processes in the DDP constructor, you don't need to worry
about different DDP processes starting from different initial model parameter
values.

.. code:: python

@@ -97,24 +101,19 @@ Otherwise, global gradient synchronizes will not make sense.


def demo_basic(rank, world_size):
print(f"Running basic DDP example on rank {rank}.")
setup(rank, world_size)

# setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
# rank 2 uses GPUs [4, 5, 6, 7].
n = torch.cuda.device_count() // world_size
device_ids = list(range(rank * n, (rank + 1) * n))

# create model and move it to device_ids[0]
model = ToyModel().to(device_ids[0])
# output_device defaults to device_ids[0]
ddp_model = DDP(model, device_ids=device_ids)
# create model and move it to GPU with id rank
model = ToyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])

loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(device_ids[0])
labels = torch.randn(20, 5).to(rank)
loss_fn(outputs, labels).backward()
optimizer.step()

@@ -127,23 +126,27 @@ Otherwise, global gradient synchronizes will not make sense.
nprocs=world_size,
join=True)

As you can see, DDP wraps lower level distributed communication details, and
provides a clean API as if it is a local model. For basic use cases, DDP only
As you can see, DDP wraps lower-level distributed communication details and
provides a clean API as if it is a local model. Gradient synchronization
communications take place during the backward pass and overlap with the
backward computation. When the ``backward()`` returns, ``param.grad`` already
contains the synchronized gradient tensor. For basic use cases, DDP only
requires a few more LoCs to set up the process group. When applying DDP to more
advanced use cases, there are some caveats that require cautions.
advanced use cases, some caveats require caution.
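
To illustrate that claim, a hypothetical sanity check (not part of the
tutorial's training loop) could gather each parameter's gradient from every
rank after ``backward()`` and confirm they are identical; it assumes ``torch``
and ``torch.distributed as dist`` are imported as above.

.. code:: python

    def check_grads_synced(ddp_model):
        world_size = dist.get_world_size()
        for param in ddp_model.parameters():
            if param.grad is None:
                continue
            # Collect this parameter's gradient from all ranks (CPU copies
            # so the check also works with the gloo backend).
            grad = param.grad.detach().cpu()
            gathered = [torch.zeros_like(grad) for _ in range(world_size)]
            dist.all_gather(gathered, grad)
            # Every rank should already hold the same averaged gradient.
            assert all(torch.allclose(g, gathered[0]) for g in gathered)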

Skewed Processing Speeds
------------------------

In DDP, constructor, forward method, and differentiation of the outputs are
distributed synchronization points. Different processes are expected to reach
synchronization points in the same order and enter each synchronization point
at roughly the same time. Otherwise, fast processes might arrive early and
timeout on waiting for stragglers. Hence, users are responsible for balancing
workloads distributions across processes. Sometimes, skewed processing speeds
are inevitable due to, e.g., network delays, resource contentions,
unpredictable workload spikes. To avoid timeouts in these situations, make
sure that you pass a sufficiently large ``timeout`` value when calling
In DDP, the constructor, the forward pass, and the backward pass are
distributed synchronization points. Different processes are expected to launch
the same number of synchronizations and reach these synchronization points in
the same order and enter each synchronization point at roughly the same time.
Otherwise, fast processes might arrive early and time out while waiting for
stragglers. Hence, users are responsible for balancing workload distributions
across processes. Sometimes, skewed processing speeds are inevitable due to,
e.g., network delays, resource contention, or unpredictable workload spikes. To
avoid timeouts in these situations, make sure that you pass a sufficiently
large ``timeout`` value when calling
`init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__.
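
As a minimal sketch (the backend and duration below are placeholders, not a
recommendation), the timeout can be raised when the process group is created:

.. code:: python

    from datetime import timedelta

    # Default is 30 minutes; give stragglers up to 60 minutes before
    # collectives time out.
    dist.init_process_group("gloo", rank=rank, world_size=world_size,
                            timeout=timedelta(minutes=60))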

Save and Load Checkpoints
@@ -156,27 +159,23 @@ for more details. When using DDP, one optimization is to save the model in
only one process and then load it to all processes, reducing write overhead.
This is correct because all processes start from the same parameters and
gradients are synchronized in backward passes, and hence optimizers should keep
setting parameters to same values. If you use this optimization, make sure all
setting parameters to the same values. If you use this optimization, make sure all
processes do not start loading before the saving is finished. Besides, when
loading the module, you need to provide an appropriate ``map_location``
argument to prevent a process from stepping into others' devices. If ``map_location``
is missing, ``torch.load`` will first load the module to CPU and then copy each
parameter to where it was saved, which would result in all processes on the
same machine using the same set of devices.
same machine using the same set of devices. For more advanced failure recovery
and elasticity support, please refer to `TorchElastic <https://pytorch.org/elastic>`__.

.. code:: python

def demo_checkpoint(rank, world_size):
print(f"Running DDP checkpoint example on rank {rank}.")
setup(rank, world_size)

# setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
# rank 2 uses GPUs [4, 5, 6, 7].
n = torch.cuda.device_count() // world_size
device_ids = list(range(rank * n, (rank + 1) * n))

model = ToyModel().to(device_ids[0])
# output_device defaults to device_ids[0]
ddp_model = DDP(model, device_ids=device_ids)
model = ToyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])

loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
@@ -192,15 +191,13 @@ same machine using the same set of devices.
# 0 saves it.
dist.barrier()
# configure map_location properly
rank0_devices = [x - rank * len(device_ids) for x in device_ids]
device_pairs = zip(rank0_devices, device_ids)
map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
ddp_model.load_state_dict(
torch.load(CHECKPOINT_PATH, map_location=map_location))

optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(device_ids[0])
labels = torch.randn(20, 5).to(rank)
loss_fn = nn.MSELoss()
loss_fn(outputs, labels).backward()
optimizer.step()
@@ -217,13 +214,8 @@ same machine using the same set of devices.
Combine DDP with Model Parallelism
----------------------------------

DDP also works with multi-GPU models, but replications within a process are not
supported. You need to create one process per module replica, which usually
leads to better performance compared to multiple replicas per process. DDP
wrapping multi-GPU models is especially helpful when training large models with
a huge amount of data. When using this feature, the multi-GPU model needs to be
carefully implemented to avoid hard-coded devices, because different model
replicas will be placed to different devices.
DDP also works with multi-GPU models. DDP wrapping multi-GPU models is especially
helpful when training large models with a huge amount of data.

.. code:: python

@@ -249,6 +241,7 @@ either the application or the model ``forward()`` method.
.. code:: python

def demo_model_parallel(rank, world_size):
print(f"Running DDP with model parallel example on rank {rank}.")
setup(rank, world_size)

# setup mp_model and devices for this process
@@ -271,8 +264,10 @@ either the application or the model ``forward()`` method.


if __name__ == "__main__":
run_demo(demo_basic, 2)
run_demo(demo_checkpoint, 2)

if torch.cuda.device_count() >= 8:
run_demo(demo_model_parallel, 4)
n_gpus = torch.cuda.device_count()
if n_gpus < 8:
print(f"Requires at least 8 GPUs to run, but got {n_gpus}.")
else:
run_demo(demo_basic, 8)
run_demo(demo_checkpoint, 8)
run_demo(demo_model_parallel, 4)