
Commit 88f9ac2

Merge pull request #678 from pietern/dist-tuto-update
Updates to distributed tutorial
2 parents e0791cf + 8434f3d commit 88f9ac2

File tree: 1 file changed (+48, −99 lines)


intermediate_source/dist_tuto.rst

Lines changed: 48 additions & 99 deletions
@@ -2,7 +2,10 @@ Writing Distributed Applications with PyTorch
 =============================================
 **Author**: `Séb Arnold <https://seba1511.com>`_

-In this short tutorial, we will be going over the distributed package of PyTorch. We'll see how to set up the distributed setting, use the different communication strategies, and go over some the internals of the package.
+In this short tutorial, we will be going over the distributed package
+of PyTorch. We'll see how to set up the distributed setting, use the
+different communication strategies, and go over some of the internals of
+the package.

 Setup
 -----
@@ -17,7 +20,7 @@ Setup
 The distributed package included in PyTorch (i.e.,
 ``torch.distributed``) enables researchers and practitioners to easily
 parallelize their computations across processes and clusters of
-machines. To do so, it leverages the messaging passing semantics
+machines. To do so, it leverages message passing semantics
 allowing each process to communicate data to any of the other processes.
 As opposed to the multiprocessing (``torch.multiprocessing``) package,
 processes can use different communication backends and are not
@@ -45,7 +48,7 @@ the following template.
     """ Distributed function to be implemented later. """
     pass

-def init_processes(rank, size, fn, backend='tcp'):
+def init_process(rank, size, fn, backend='gloo'):
     """ Initialize the distributed environment. """
     os.environ['MASTER_ADDR'] = '127.0.0.1'
     os.environ['MASTER_PORT'] = '29500'
@@ -57,7 +60,7 @@ the following template.
     size = 2
     processes = []
     for rank in range(size):
-        p = Process(target=init_processes, args=(rank, size, run))
+        p = Process(target=init_process, args=(rank, size, run))
         p.start()
         processes.append(p)

@@ -69,12 +72,10 @@ distributed environment, initialize the process group
 (``dist.init_process_group``), and finally execute the given ``run``
 function.

-Let's have a look at the ``init_processes`` function. It ensures that
+Let's have a look at the ``init_process`` function. It ensures that
 every process will be able to coordinate through a master, using the
-same ip address and port. Note that we used the TCP backend, but we
-could have used
-`MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`__ or
-`Gloo <https://github.com/facebookincubator/gloo>`__ instead. (c.f.
+same IP address and port. Note that we used the ``gloo`` backend, but
+other backends are available. (c.f.
 `Section 5.1 <#communication-backends>`__) We will go over the magic
 happening in ``dist.init_process_group`` at the end of this tutorial,
 but it essentially allows processes to communicate with each other by
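
For readers of this commit, the fragments touched by the hunks above assemble into roughly the following template. This is a sketch rather than a verbatim copy of the file: the imports and the final join loop are not shown in the diff and are assumed here.

.. code:: python

    import os
    import torch
    import torch.distributed as dist
    from torch.multiprocessing import Process

    def run(rank, size):
        """ Distributed function to be implemented later. """
        pass

    def init_process(rank, size, fn, backend='gloo'):
        """ Initialize the distributed environment. """
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
        dist.init_process_group(backend, rank=rank, world_size=size)
        fn(rank, size)

    if __name__ == "__main__":
        size = 2
        processes = []
        for rank in range(size):
            # Spawn one process per rank; each process joins the process
            # group and then runs the distributed function.
            p = Process(target=init_process, args=(rank, size, run))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()
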
@@ -119,7 +120,7 @@ order to store the data it will receive.
 Also notice that ``send``/``recv`` are **blocking**: both processes stop
 until the communication is completed. On the other hand immediates are
 **non-blocking**; the script continues its execution and the methods
-return a ``DistributedRequest`` object upon which we can choose to
+return a ``Work`` object upon which we can choose to
 ``wait()``.

 .. code:: python
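
To make the ``Work`` object concrete, here is a minimal sketch of the non-blocking point-to-point pattern described above. It assumes the imports and process-spawning template shown earlier; ``dist.isend``/``dist.irecv`` are the immediate counterparts of ``send``/``recv``.

.. code:: python

    def run(rank, size):
        """ Non-blocking point-to-point communication. """
        tensor = torch.zeros(1)
        if rank == 0:
            tensor += 1
            # Start sending the tensor to process 1; returns immediately.
            req = dist.isend(tensor=tensor, dst=1)
        else:
            # Start receiving the tensor from process 0; returns immediately.
            req = dist.irecv(tensor=tensor, src=0)
        # Block until the transfer behind the returned object has completed.
        req.wait()
        print('Rank ', rank, ' has data ', tensor[0])
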
@@ -257,7 +258,7 @@ something useful with it. Our goal will be to replicate the
 functionality of
 `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__.
 Of course, this will be a didactic example and in a real-world
-situtation you should use the official, well-tested and well-optimized
+situation you should use the official, well-tested and well-optimized
 version linked above.

 Quite simply we want to implement a distributed version of stochastic
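
The distributed SGD sketched in this part of the tutorial boils down to averaging gradients across ranks after each backward pass. A minimal version of such a helper could look like the sketch below, where ``model`` is any ``torch.nn.Module`` whose gradients have just been computed.

.. code:: python

    def average_gradients(model):
        """ Average the gradients of the model across all processes. """
        size = float(dist.get_world_size())
        for param in model.parameters():
            # Sum each gradient over all ranks, then divide by the world size.
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= size
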
@@ -443,43 +444,27 @@ Communication Backends

 One of the most elegant aspects of ``torch.distributed`` is its ability
 to abstract and build on top of different backends. As mentioned before,
-there are currently three backends implemented in PyTorch: TCP, MPI, and
-Gloo. They each have different specifications and tradeoffs, depending
-on the desired use-case. A comparative table of supported functions can
+there are currently three backends implemented in PyTorch: Gloo, NCCL, and
+MPI. They each have different specifications and tradeoffs, depending
+on the desired use case. A comparative table of supported functions can
 be found
-`here <https://pytorch.org/docs/stable/distributed.html#module-torch.distributed>`__. Note that a fourth backend, NCCL, has been added since the creation of this tutorial. See `this section <https://pytorch.org/docs/stable/distributed.html#multi-gpu-collective-functions>`__ of the ``torch.distributed`` docs for more information about its use and value.
-
-**TCP Backend**
-
-So far we have made extensive usage of the TCP backend. It is quite
-handy as a development platform, as it is guaranteed to work on most
-machines and operating systems. It also supports all point-to-point and
-collective functions on CPU. However, there is no support for GPUs and
-its communication routines are not as optimized as the MPI one.
+`here <https://pytorch.org/docs/stable/distributed.html#module-torch.distributed>`__.

 **Gloo Backend**

-The `Gloo backend <https://github.com/facebookincubator/gloo>`__
-provides an optimized implementation of *collective* communication
-procedures, both for CPUs and GPUs. It particularly shines on GPUs as it
-can perform communication without transferring data to the CPU's memory
-using `GPUDirect <https://developer.nvidia.com/gpudirect>`__. It is also
-capable of using `NCCL <https://github.com/NVIDIA/nccl>`__ to perform
-fast intra-node communication and implements its `own
-algorithms <https://github.com/facebookincubator/gloo/blob/master/docs/algorithms.md>`__
-for inter-node routines.
-
-Since version 0.2.0, the Gloo backend is automatically included with the
-pre-compiled binaries of PyTorch. As you have surely noticed, our
+So far we have made extensive usage of the `Gloo backend <https://github.com/facebookincubator/gloo>`__.
+It is quite handy as a development platform, as it is included in
+the pre-compiled PyTorch binaries and works on both Linux (since 0.2)
+and macOS (since 1.3). It supports all point-to-point and collective
+operations on CPU, and all collective operations on GPU. The
+implementation of the collective operations for CUDA tensors is not as
+optimized as the ones provided by the NCCL backend.
+
+As you have surely noticed, our
 distributed SGD example does not work if you put ``model`` on the GPU.
-Let's fix it by first replacing ``backend='gloo'`` in
-``init_processes(rank, size, fn, backend='tcp')``. At this point, the
-script will still run on CPU but uses the Gloo backend behind the
-scenes. In order to use multiple GPUs, let us also do the following
+In order to use multiple GPUs, let us also do the following
 modifications:

-0. ``init_processes(rank, size, fn, backend='tcp')`` :math:`\rightarrow`
-   ``init_processes(rank, size, fn, backend='gloo')``
 1. Use ``device = torch.device("cuda:{}".format(rank))``
 2. ``model = Net()`` :math:`\rightarrow` ``model = Net().to(device)``
 3. Use ``data, target = data.to(device), target.to(device)``
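
Applied to the tutorial's training loop, those three modifications might look roughly as follows. This is a sketch only: ``Net``, ``partition_dataset`` and ``average_gradients`` are helpers defined elsewhere in the tutorial, ``optim`` and ``F`` stand for ``torch.optim`` and ``torch.nn.functional``, and mapping rank ``i`` to ``cuda:i`` assumes one GPU per process.

.. code:: python

    def run(rank, size):
        """ Distributed synchronous SGD, one CUDA device per rank. """
        device = torch.device("cuda:{}".format(rank))           # modification 1
        train_set, bsz = partition_dataset()
        model = Net().to(device)                                # modification 2
        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

        for data, target in train_set:
            data, target = data.to(device), target.to(device)   # modification 3
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            average_gradients(model)
            optimizer.step()
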
@@ -526,7 +511,7 @@ In order to test our newly installed backend, a few modifications are
 required.

 1. Replace the content under ``if __name__ == '__main__':`` with
-   ``init_processes(0, 0, run, backend='mpi')``.
+   ``init_process(0, 0, run, backend='mpi')``.
 2. Run ``mpirun -n 4 python myscript.py``.

 The reason for these changes is that MPI needs to create its own
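
In code, the modified entry point collapses to a single call. The sketch below assumes that ``mpirun`` spawns the processes, so rank and world size come from the MPI runtime rather than from the arguments passed here.

.. code:: python

    if __name__ == "__main__":
        # mpirun launches the processes and assigns ranks, so we no longer
        # spawn them ourselves; the MPI backend discovers rank and world
        # size from its own runtime.
        init_process(0, 0, run, backend='mpi')
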
@@ -541,6 +526,14 @@ more <https://www.open-mpi.org/faq/?category=running#mpirun-hostfile>`__)
 Doing so, you should obtain the same familiar output as with the other
 communication backends.

+**NCCL Backend**
+
+The `NCCL backend <https://github.com/nvidia/nccl>`__ provides an
+optimized implementation of collective operations against CUDA
+tensors. If you only use CUDA tensors for your collective operations,
+consider using this backend for the best in class performance. The
+NCCL backend is included in the pre-built binaries with CUDA support.
+
 Initialization Methods
 ~~~~~~~~~~~~~~~~~~~~~~
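
As an illustration of the section added above (this example is not part of the commit), switching the earlier collective example to NCCL only requires passing ``backend='nccl'`` and keeping every tensor on a CUDA device; one GPU per rank is assumed.

.. code:: python

    def run(rank, size):
        """ All-reduce on CUDA tensors using the NCCL backend. """
        device = torch.device("cuda:{}".format(rank))
        tensor = torch.ones(1, device=device) * rank
        # NCCL operates on CUDA tensors only.
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        print('Rank ', rank, ' has data ', tensor.item())

    # Spawned exactly as before, but with:
    #     init_process(rank, size, run, backend='nccl')
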

@@ -554,33 +547,6 @@ naturally more suitable than the others. In addition to the following
 sections, you should also have a look at the `official
 documentation <https://pytorch.org/docs/stable/distributed.html#initialization>`__.

-Before diving into the initialization methods, let's have a quick look
-at what happens behind ``init_process_group`` from the C/C++
-perspective.
-
-1. First, the arguments are parsed and validated.
-2. The backend is resolved via the ``name2channel.at()`` function. A
-   ``Channel`` class is returned, and will be used to perform the data
-   transmission.
-3. The GIL is dropped, and ``THDProcessGroupInit()`` is called. This
-   instantiates the channel and adds the address of the master node.
-4. The process with rank 0 will execute the ``master`` procedure, while
-   all other ranks will be ``workers``.
-5. The master
-
-   a. Creates sockets for all workers.
-   b. Waits for all workers to connect.
-   c. Sends them information about the location of the other processes.
-
-6. Each worker
-
-   a. Creates a socket to the master.
-   b. Sends their own location information.
-   c. Receives information about the other workers.
-   d. Opens a socket and handshakes with all other workers.
-
-7. The initialization is done, and everyone is connected to everyone.
-
 **Environment Variable**

 We have been using the environment variable initialization method
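
Concretely, this method amounts to exporting four variables before calling ``init_process_group``, mirroring the ``init_process`` helper from earlier in the tutorial. The address, port, rank, and world size below are example values.

.. code:: python

    import os
    import torch.distributed as dist

    os.environ['MASTER_ADDR'] = '10.1.1.20'   # address of the rank 0 machine
    os.environ['MASTER_PORT'] = '29500'       # a free port on that machine
    os.environ['RANK'] = '0'                  # rank of this process
    os.environ['WORLD_SIZE'] = '2'            # total number of processes

    # With init_method left at its default ('env://'), the values above
    # are read from the environment.
    dist.init_process_group(backend='gloo')
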
@@ -606,44 +572,27 @@ that each process will open the file, write its information, and wait
 until everybody did so. After what all required information will be
 readily available to all processes. In order to avoid race conditions,
 the file system must support locking through
-`fcntl <http://man7.org/linux/man-pages/man2/fcntl.2.html>`__. Note that
-you can specify ranks manually or let the processes figure it out by
-themselves. Be defining a unique ``groupname`` per job you can use the
-same file path for multiple jobs and safely avoid collision.
+`fcntl <http://man7.org/linux/man-pages/man2/fcntl.2.html>`__.

 .. code:: python

-    dist.init_process_group(init_method='file:///mnt/nfs/sharedfile', world_size=4,
-                            group_name='mygroup')
-
-**TCP Init & Multicast**
-
-Initializing via TCP can be achieved in two different ways:
-
-1. By providing the IP address of the process with rank 0 and the world
-   size.
-2. By providing *any* valid IP `multicast
-   address <https://en.wikipedia.org/wiki/Multicast_address>`__ and the
-   world size.
-
-In the first case, all workers will be able to connect to the process
-with rank 0 and follow the procedure described above.
-
-.. code:: python
+    dist.init_process_group(
+        init_method='file:///mnt/nfs/sharedfile',
+        rank=args.rank,
+        world_size=4)

-    dist.init_process_group(init_method='tcp://10.1.1.20:23456', rank=args.rank, world_size=4)
+**TCP**

-In the second case, the multicast address specifies the group of nodes
-who might potentially be active and the coordination can be handled by
-allowing each process to have an initial handshake before following the
-above procedure. In addition TCP multicast initialization also supports
-a ``group_name`` argument (as with the shared file method) allowing
-multiple jobs to be scheduled on the same cluster.
+Initializing via TCP can be achieved by providing the IP address of the process with rank 0 and a reachable port number.
+Here, all workers will be able to connect to the process
+with rank 0 and exchange information on how to reach each other.

 .. code:: python

-    dist.init_process_group(init_method='tcp://[ff15:1e18:5d4c:4cf0:d02d:b659:53ba:b0a7]:23456',
-                            world_size=4)
+    dist.init_process_group(
+        init_method='tcp://10.1.1.20:23456',
+        rank=args.rank,
+        world_size=4)

 .. raw:: html