diff --git a/intermediate_source/dist_tuto.rst b/intermediate_source/dist_tuto.rst
index 36d07f18849..3a76bb1dd3c 100644
--- a/intermediate_source/dist_tuto.rst
+++ b/intermediate_source/dist_tuto.rst
@@ -2,7 +2,10 @@ Writing Distributed Applications with PyTorch
 =============================================
 **Author**: `Séb Arnold `_

-In this short tutorial, we will be going over the distributed package of PyTorch. We'll see how to set up the distributed setting, use the different communication strategies, and go over some the internals of the package.
+In this short tutorial, we will be going over the distributed package
+of PyTorch. We'll see how to set up the distributed setting, use the
+different communication strategies, and go over some of the internals of
+the package.

 Setup
 -----
@@ -17,7 +20,7 @@ Setup
 The distributed package included in PyTorch (i.e.,
 ``torch.distributed``) enables researchers and practitioners to easily
 parallelize their computations across processes and clusters of
-machines. To do so, it leverages the messaging passing semantics
+machines. To do so, it leverages message passing semantics
 allowing each process to communicate data to any of the other
 processes. As opposed to the multiprocessing (``torch.multiprocessing``)
 package, processes can use different communication backends and are not
@@ -45,7 +48,7 @@ the following template.
         """ Distributed function to be implemented later. """
         pass

-    def init_processes(rank, size, fn, backend='tcp'):
+    def init_process(rank, size, fn, backend='gloo'):
         """ Initialize the distributed environment. """
         os.environ['MASTER_ADDR'] = '127.0.0.1'
         os.environ['MASTER_PORT'] = '29500'
@@ -57,7 +60,7 @@ the following template.
     size = 2
     processes = []
     for rank in range(size):
-        p = Process(target=init_processes, args=(rank, size, run))
+        p = Process(target=init_process, args=(rank, size, run))
         p.start()
         processes.append(p)

@@ -69,12 +72,10 @@ distributed environment, initialize the process group
 (``dist.init_process_group``), and finally execute the given ``run``
 function.

-Let's have a look at the ``init_processes`` function. It ensures that
+Let's have a look at the ``init_process`` function. It ensures that
 every process will be able to coordinate through a master, using the
-same ip address and port. Note that we used the TCP backend, but we
-could have used
-`MPI `__ or
-`Gloo `__ instead. (c.f.
+same IP address and port. Note that we used the ``gloo`` backend, but
+other backends are available. (cf.
 `Section 5.1 <#communication-backends>`__) We will go over the magic
 happening in ``dist.init_process_group`` at the end of this tutorial,
 but it essentially allows processes to communicate with each other by
@@ -119,7 +120,7 @@ order to store the data it will receive. Also notice that
 ``send``/``recv`` are **blocking**: both processes stop until the
 communication is completed. On the other hand immediates are
 **non-blocking**; the script continues its execution and the methods
-return a ``DistributedRequest`` object upon which we can choose to
+return a ``Work`` object upon which we can choose to
 ``wait()``.

 .. code:: python
@@ -257,7 +258,7 @@ something useful with it. Our goal will be to replicate the
 functionality of
 `DistributedDataParallel `__.
 Of course, this will be a didactic example and in a real-world
-situtation you should use the official, well-tested and well-optimized
+situation you should use the official, well-tested and well-optimized
 version linked above.
 Quite simply we want to implement a distributed version of stochastic
@@ -443,43 +444,27 @@ Communication Backends

 One of the most elegant aspects of ``torch.distributed`` is its ability
 to abstract and build on top of different backends. As mentioned before,
-there are currently three backends implemented in PyTorch: TCP, MPI, and
-Gloo. They each have different specifications and tradeoffs, depending
-on the desired use-case. A comparative table of supported functions can be found
-`here `__. Note that a fourth backend, NCCL, has been added since the creation of this tutorial. See `this section `__ of the ``torch.distributed`` docs for more information about its use and value.
-
-**TCP Backend**
-
-So far we have made extensive usage of the TCP backend. It is quite
-handy as a development platform, as it is guaranteed to work on most
-machines and operating systems. It also supports all point-to-point and
-collective functions on CPU. However, there is no support for GPUs and
-its communication routines are not as optimized as the MPI one.
+there are currently three backends implemented in PyTorch: Gloo, NCCL, and
+MPI. They each have different specifications and tradeoffs, depending
+on the desired use case. A comparative table of supported functions can be found
+`here `__.

 **Gloo Backend**

-The `Gloo backend `__
-provides an optimized implementation of *collective* communication
-procedures, both for CPUs and GPUs. It particularly shines on GPUs as it
-can perform communication without transferring data to the CPU's memory
-using `GPUDirect `__. It is also
-capable of using `NCCL `__ to perform
-fast intra-node communication and implements its `own
-algorithms `__
-for inter-node routines.
-
-Since version 0.2.0, the Gloo backend is automatically included with the
-pre-compiled binaries of PyTorch. As you have surely noticed, our
+So far we have made extensive use of the `Gloo backend `__.
+It is quite handy as a development platform, as it is included in
+the pre-compiled PyTorch binaries and works on both Linux (since 0.2)
+and macOS (since 1.3). It supports all point-to-point and collective
+operations on CPU, and all collective operations on GPU. The
+implementation of the collective operations for CUDA tensors is not as
+optimized as the one provided by the NCCL backend.
+
+As you have surely noticed, our
 distributed SGD example does not work if you put ``model`` on the GPU.
-Let's fix it by first replacing ``backend='gloo'`` in
-``init_processes(rank, size, fn, backend='tcp')``. At this point, the
-script will still run on CPU but uses the Gloo backend behind the
-scenes. In order to use multiple GPUs, let us also do the following
+In order to use multiple GPUs, let us also make the following
 modifications:

-0. ``init_processes(rank, size, fn, backend='tcp')`` :math:`\rightarrow`
-   ``init_processes(rank, size, fn, backend='gloo')``
 1. Use ``device = torch.device("cuda:{}".format(rank))``
 2. ``model = Net()`` :math:`\rightarrow` ``model = Net().to(device)``
 3. Use ``data, target = data.to(device), target.to(device)``

@@ -526,7 +511,7 @@ In order to test our newly installed backend, a few modifications are
 required.

 1. Replace the content under ``if __name__ == '__main__':`` with
-   ``init_processes(0, 0, run, backend='mpi')``.
+   ``init_process(0, 0, run, backend='mpi')``.
 2. Run ``mpirun -n 4 python myscript.py``.
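+
+For reference, the modified block could look roughly like the following
+minimal sketch (it reuses the ``run`` and ``init_process`` functions
+defined earlier in this tutorial; under MPI the rank and world size are
+assigned by ``mpirun``, so the values passed here are placeholders):
+
+.. code:: python
+
+    if __name__ == "__main__":
+        # mpirun spawns the processes and supplies the rank and world size,
+        # which is why the arguments below are superfluous.
+        init_process(0, 0, run, backend='mpi')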

 The reason for these changes is that MPI needs to create its own
@@ -541,6 +526,14 @@ more `__)
 Doing so, you should obtain the same familiar output as with the other
 communication backends.

+**NCCL Backend**
+
+The `NCCL backend `__ provides an
+optimized implementation of collective operations against CUDA
+tensors. If you only use CUDA tensors for your collective operations,
+consider using this backend for best-in-class performance. The
+NCCL backend is included in the pre-built binaries with CUDA support.
+
 Initialization Methods
 ~~~~~~~~~~~~~~~~~~~~~~

@@ -554,33 +547,6 @@ naturally more suitable than the others. In addition to the following
 sections, you should also have a look at the `official
 documentation `__.

-Before diving into the initialization methods, let's have a quick look
-at what happens behind ``init_process_group`` from the C/C++
-perspective.
-
-1. First, the arguments are parsed and validated.
-2. The backend is resolved via the ``name2channel.at()`` function. A
-   ``Channel`` class is returned, and will be used to perform the data
-   transmission.
-3. The GIL is dropped, and ``THDProcessGroupInit()`` is called. This
-   instantiates the channel and adds the address of the master node.
-4. The process with rank 0 will execute the ``master`` procedure, while
-   all other ranks will be ``workers``.
-5. The master
-
-   a. Creates sockets for all workers.
-   b. Waits for all workers to connect.
-   c. Sends them information about the location of the other processes.
-
-6. Each worker
-
-   a. Creates a socket to the master.
-   b. Sends their own location information.
-   c. Receives information about the other workers.
-   d. Opens a socket and handshakes with all other workers.
-
-7. The initialization is done, and everyone is connected to everyone.
-
 **Environment Variable**

 We have been using the environment variable initialization method
@@ -606,44 +572,27 @@ that each process will open the file, write its information, and wait
 until everybody did so. After what all required information will be
 readily available to all processes. In order to avoid race conditions,
 the file system must support locking through
-`fcntl `__. Note that
-you can specify ranks manually or let the processes figure it out by
-themselves. Be defining a unique ``groupname`` per job you can use the
-same file path for multiple jobs and safely avoid collision.
+`fcntl `__.

 .. code:: python

-    dist.init_process_group(init_method='file:///mnt/nfs/sharedfile', world_size=4,
-                            group_name='mygroup')
-
-**TCP Init & Multicast**
-
-Initializing via TCP can be achieved in two different ways:
-
-1. By providing the IP address of the process with rank 0 and the world
-   size.
-2. By providing *any* valid IP `multicast
-   address `__ and the
-   world size.
-
-In the first case, all workers will be able to connect to the process
-with rank 0 and follow the procedure described above.
-
-.. code:: python
+    dist.init_process_group(
+        init_method='file:///mnt/nfs/sharedfile',
+        rank=args.rank,
+        world_size=4)

-    dist.init_process_group(init_method='tcp://10.1.1.20:23456', rank=args.rank, world_size=4)
+**TCP**

-In the second case, the multicast address specifies the group of nodes
-who might potentially be active and the coordination can be handled by
-allowing each process to have an initial handshake before following the
-above procedure. In addition TCP multicast initialization also supports
-a ``group_name`` argument (as with the shared file method) allowing
-multiple jobs to be scheduled on the same cluster.
+Initializing via TCP can be achieved by providing the IP address of the process with rank 0 and a reachable port number.
+Here, all workers will be able to connect to the process
+with rank 0 and exchange information on how to reach each other.

 .. code:: python

-    dist.init_process_group(init_method='tcp://[ff15:1e18:5d4c:4cf0:d02d:b659:53ba:b0a7]:23456',
-                            world_size=4)
+    dist.init_process_group(
+        init_method='tcp://10.1.1.20:23456',
+        rank=args.rank,
+        world_size=4)

 .. raw:: html