Commit 6bf2d16

Adding an Overview Page for PyTorch Distributed
1 parent 2f3ab79 commit 6bf2d16

File tree: 3 files changed, +202 -0 lines changed

beginner_source/dist_overview.rst

Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
PyTorch Distributed Overview
============================
**Author**: `Shen Li <https://mrshenli.github.io/>`_


This is the overview page for the ``torch.distributed`` package. As more and
more documents, examples, and tutorials are added in different locations, it
becomes unclear which document or tutorial to consult for a specific problem,
or what the best order is to read them. The goal of this page is to address
this problem by categorizing documents into different topics and briefly
describing each of them. If this is your first time building distributed
training applications using PyTorch, it is recommended that you use this
document to navigate to the technology that can best serve your use case.


Introduction
------------

As of PyTorch v1.6.0, features in ``torch.distributed`` can be categorized into
three main components:

* `Distributed Data-Parallel Training <https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`__
  (DDP) is a widely adopted single-program multiple-data training paradigm. With
  DDP, the model is replicated on every process, and every model replica is fed
  a different set of input data samples. DDP takes care of gradient
  communication to keep model replicas synchronized and overlaps it with the
  gradient computations to speed up training.
* `RPC-Based Distributed Training <https://pytorch.org/docs/master/rpc.html>`__
  (RPC) supports general training structures that do not fit into
  data-parallel training, such as distributed pipeline parallelism, the parameter
  server paradigm, and combinations of DDP with other training paradigms. It
  helps manage remote object lifetimes and extends the autograd engine beyond
  machine boundaries.
* `Collective Communication <https://pytorch.org/docs/stable/distributed.html>`__
  (c10d) is a library that supports sending tensors across processes within a
  group. It offers both collective communication APIs (e.g.,
  `all_reduce <https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_reduce>`__
  and `all_gather <https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather>`__)
  and P2P communication APIs (e.g.,
  `send <https://pytorch.org/docs/stable/distributed.html#torch.distributed.send>`__
  and `isend <https://pytorch.org/docs/stable/distributed.html#torch.distributed.isend>`__).
  DDP and RPC (`ProcessGroup Backend <https://pytorch.org/docs/master/rpc.html#process-group-backend>`__)
  are built on c10d as of v1.6.0, where the former uses collective communications
  and the latter uses P2P communications. Usually, developers do not need to
  use this raw communication API directly, as the DDP and RPC features above can
  serve many distributed training scenarios. However, there are use cases where
  this API is still helpful. One example is distributed parameter averaging,
  where applications compute the average values of all model parameters after
  the backward pass instead of using DDP to communicate gradients. This
  decouples communications from computations and allows finer-grain control over
  what to communicate, but on the other hand, it also gives up the performance
  optimizations offered by DDP. The
  `Writing Distributed Applications with PyTorch <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
  tutorial shows examples of using c10d communication APIs; a minimal sketch of
  parameter averaging follows this list.
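
For illustration, the snippet below is a minimal sketch of the parameter-averaging
idea using the c10d ``all_reduce`` API. It assumes that the process group has
already been initialized (e.g., via ``init_process_group``) and that every
process holds an identical replica of the model; the helper name is hypothetical.

.. code:: python

    import torch.distributed as dist


    def average_parameters(model):
        """Average model parameters across all processes after the backward pass."""
        world_size = dist.get_world_size()
        for param in model.parameters():
            # Sum this parameter across all processes, then divide by the
            # number of processes to obtain the element-wise average.
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data /= world_size
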
Most of the existing documents are written for either DDP or RPC; the remainder
of this page elaborates on the materials for these two components.


Data Parallel Training
----------------------

PyTorch provides several options for data-parallel training. For applications
that gradually grow from simple to complex and from prototype to production, the
common development trajectory would be:

1. Use single-device training if the data and model can fit on one GPU and
   training speed is not a concern.
2. Use single-machine multi-GPU
   `DataParallel <https://pytorch.org/docs/master/generated/torch.nn.DataParallel.html>`__
   if there are multiple GPUs on the server and you would like to speed up
   training with minimal code changes.
3. Use single-machine multi-GPU
   `DistributedDataParallel <https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`__
   if you would like to further speed up training and are willing to write a
   little more code to set it up.
4. Use multi-machine `DistributedDataParallel <https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`__
   and the `launching script <https://github.com/pytorch/examples/blob/master/distributed/ddp/README.md>`__
   if the application needs to scale across machine boundaries.
5. Use `torchelastic <https://pytorch.org/elastic>`__ to launch distributed
   training if errors (e.g., OOM) are expected or if resources can join and
   leave dynamically during training.


``torch.nn.DataParallel``
~~~~~~~~~~~~~~~~~~~~~~~~~

The `DataParallel <https://pytorch.org/docs/master/generated/torch.nn.DataParallel.html>`__
package enables single-machine multi-GPU parallelism with the lowest coding
hurdle. It only requires a one-line change to the application code. The tutorial
`Optional: Data Parallelism <https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html>`__
shows an example. The caveat is that, although ``DataParallel`` is very easy to
use, it usually does not offer the best performance, because the
implementation of ``DataParallel`` replicates the model in every forward pass,
and its single-process multi-thread parallelism naturally suffers from GIL
contention. To get better performance, please consider using
`DistributedDataParallel <https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`__.
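
As a rough sketch of the one-line change described above (the model, input
shapes, and device handling here are placeholders for illustration, and at
least one GPU is assumed to be available):

.. code:: python

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 10)   # placeholder model
    if torch.cuda.device_count() > 1:
        # The one-line change: each forward pass is split across the GPUs
        # available on this machine, and the outputs are gathered on GPU 0.
        model = nn.DataParallel(model)
    model = model.to("cuda")

    output = model(torch.randn(20, 10).to("cuda"))
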

``torch.nn.parallel.DistributedDataParallel``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Compared to `DataParallel <https://pytorch.org/docs/master/generated/torch.nn.DataParallel.html>`__,
`DistributedDataParallel <https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`__
requires one more step to set up, i.e., calling
`init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__.
DDP uses multi-process parallelism, and hence there is no GIL contention across
model replicas. Moreover, the model is broadcast at DDP construction time instead
of in every forward pass, which also helps to speed up training. DDP ships with
several performance optimization techniques. For a more in-depth explanation,
please refer to this `DDP paper <https://arxiv.org/abs/2006.15704>`__ (VLDB'20).
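
A minimal sketch of that setup is shown below, using the ``gloo`` backend on
CPU, a placeholder model, and a hypothetical ``main`` function; the DDP notes
and tutorial linked below cover the details and the GPU case.

.. code:: python

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP


    def main(rank, world_size):
        # The extra setup step: initialize the default process group.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        model = nn.Linear(10, 10)   # placeholder model
        ddp_model = DDP(model)      # gradients are synchronized automatically

        loss = ddp_model(torch.randn(20, 10)).sum()
        loss.backward()

        dist.destroy_process_group()


    if __name__ == "__main__":
        world_size = 2
        mp.spawn(main, args=(world_size,), nprocs=world_size)
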
DDP materials are listed below:

1. `DDP notes <https://pytorch.org/docs/stable/notes/ddp.html>`__
   offer a starter example and some brief descriptions of its design and
   implementation. If this is your first time using DDP, please start from this
   document.
2. `Getting Started with Distributed Data Parallel <../intermediate/ddp_tutorial.html>`__
   explains some common problems with DDP training, including unbalanced
   workload, checkpointing, and multi-device models. Note that DDP can be
   easily combined with single-machine multi-device model parallelism, which is
   described in the
   `Single-Machine Model Parallel Best Practices <../intermediate/model_parallel_tutorial.html>`__
   tutorial.
3. The `Launching and configuring distributed data parallel applications <https://github.com/pytorch/examples/blob/master/distributed/ddp/README.md>`__
   document shows how to use the DDP launching script.
4. `PyTorch Distributed Trainer with Amazon AWS <aws_distributed_training_tutorial.html>`__
   demonstrates how to use DDP on AWS.

TorchElastic
~~~~~~~~~~~~

As application complexity and scale grow, failure recovery becomes an
imperative requirement. It is sometimes inevitable to hit errors such as OOM
when using DDP, but DDP itself cannot recover from those errors, nor can a
basic ``try-except`` block handle them. This is because DDP requires all
processes to operate in a closely synchronized manner, and all ``AllReduce``
communications launched in different processes must match. If one of the
processes in the group throws an OOM exception, it is likely to lead to
desynchronization (mismatched ``AllReduce`` operations), which would then cause
a crash or hang. If you expect failures to occur during training or if resources
might leave and join dynamically, please launch distributed data-parallel
training using `torchelastic <https://pytorch.org/elastic>`__.


General Distributed Training
----------------------------

Many training paradigms do not fit into data parallelism, e.g., the
parameter server paradigm, distributed pipeline parallelism, reinforcement
learning applications with multiple observers or agents, etc. The
`torch.distributed.rpc <https://pytorch.org/docs/master/rpc.html>`__ package
aims at supporting general distributed training scenarios.

The `torch.distributed.rpc <https://pytorch.org/docs/master/rpc.html>`__ package
has four main pillars:

* `RPC <https://pytorch.org/docs/master/rpc.html#rpc>`__ supports running
  a given function on a remote worker.
* `RRef <https://pytorch.org/docs/master/rpc.html#rref>`__ helps to manage the
  lifetime of a remote object. The reference counting protocol is presented in the
  `RRef notes <https://pytorch.org/docs/master/rpc/rref.html#remote-reference-protocol>`__.
* `Distributed Autograd <https://pytorch.org/docs/master/rpc.html#distributed-autograd-framework>`__
  extends the autograd engine beyond machine boundaries. Please refer to
  `Distributed Autograd Design <https://pytorch.org/docs/master/rpc/distributed_autograd.html#distributed-autograd-design>`__
  for more details.
* `Distributed Optimizer <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`__
  automatically reaches out to all participating workers to update parameters
  using gradients computed by the distributed autograd engine.
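
To make the first two pillars concrete, the sketch below shows ``rpc_sync`` and
``remote`` from the perspective of one worker. The worker names, ranks, and the
``add_tensors`` helper are hypothetical placeholders, and a second process is
assumed to call ``init_rpc`` with the name ``"callee"`` and rank 1 before
calling ``shutdown``.

.. code:: python

    import os

    import torch
    import torch.distributed.rpc as rpc


    def add_tensors(x, y):
        return x + y


    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    rpc.init_rpc("caller", rank=0, world_size=2)

    # Run a function on the remote worker and block until the result comes back.
    result = rpc.rpc_sync("callee", add_tensors, args=(torch.ones(2), torch.ones(2)))

    # Or keep the result on the remote worker as an RRef and fetch it later.
    rref = rpc.remote("callee", add_tensors, args=(torch.ones(2), torch.ones(2)))
    print(rref.to_here())

    rpc.shutdown()
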
RPC tutorials are listed below:

1. The `Getting Started with Distributed RPC Framework <../intermediate/rpc_tutorial.html>`__
   tutorial first uses a simple Reinforcement Learning (RL) example to
   demonstrate RPC and RRef. Then, it applies basic distributed model
   parallelism to an RNN example to show how to use distributed autograd and
   the distributed optimizer.
2. The `Implementing a Parameter Server Using Distributed RPC Framework <../intermediate/rpc_param_server_tutorial.html>`__
   tutorial borrows the spirit of
   `HogWild! training <https://people.eecs.berkeley.edu/~brecht/papers/hogwildTR.pdf>`__
   and applies it to an asynchronous parameter server (PS) training application.
3. The `Distributed Pipeline Parallelism Using RPC <../intermediate/dist_pipeline_parallel_tutorial.html>`__
   tutorial extends the single-machine pipeline parallel example (presented in
   `Single-Machine Model Parallel Best Practices <../intermediate/model_parallel_tutorial.html>`__)
   to a distributed environment and shows how to implement it using RPC.
4. The `Implementing Batch RPC Processing Using Asynchronous Executions <../intermediate/rpc_async_execution.html>`__
   tutorial demonstrates how to implement RPC batch processing using the
   `@rpc.functions.async_execution <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.functions.async_execution>`__
   decorator, which can help speed up inference and training. It uses RL and PS
   examples similar to those in tutorials 1 and 2 above.

index.rst

Lines changed: 8 additions & 0 deletions

@@ -297,6 +297,13 @@ Welcome to PyTorch Tutorials

.. Parallel-and-Distributed-Training

.. customcarditem::
   :header: PyTorch Distributed Overview
   :card_description: Get a high-level overview of all the concepts and features in the distributed package, and use it to find the distributed training technology that can best serve your application.
   :image: _static/img/thumbnails/cropped/PyTorch-Distributed-Overview.png
   :link: beginner/dist_overview.html
   :tags: Parallel-and-Distributed-Training

.. customcarditem::
   :header: Single-Machine Model Parallel Best Practices
   :card_description: Learn how to implement model parallel, a distributed training technique which splits a single model onto different GPUs, rather than replicating the entire model on each GPU
@@ -513,6 +520,7 @@ Additional Resources
   :hidden:
   :caption: Parallel and Distributed Training

   beginner/dist_overview
   intermediate/model_parallel_tutorial
   intermediate/ddp_tutorial
   intermediate/dist_tuto