diff --git a/_static/img/model-parallel-images/mp_vs_rn.png b/_static/img/model-parallel-images/mp_vs_rn.png
new file mode 100644
index 00000000000..c56ec8adf51
Binary files /dev/null and b/_static/img/model-parallel-images/mp_vs_rn.png differ
diff --git a/_static/img/model-parallel-images/mp_vs_rn_vs_pp.png b/_static/img/model-parallel-images/mp_vs_rn_vs_pp.png
new file mode 100644
index 00000000000..a102c916771
Binary files /dev/null and b/_static/img/model-parallel-images/mp_vs_rn_vs_pp.png differ
diff --git a/_static/img/model-parallel-images/split_size_tradeoff.png b/_static/img/model-parallel-images/split_size_tradeoff.png
new file mode 100644
index 00000000000..f30eba44637
Binary files /dev/null and b/_static/img/model-parallel-images/split_size_tradeoff.png differ
diff --git a/index.rst b/index.rst
index 27dc0302b89..9009cbd08de 100644
--- a/index.rst
+++ b/index.rst
@@ -211,6 +211,11 @@ Production Usage
    :description: :doc:`/intermediate/dist_tuto`
    :figure: _static/img/distributed/DistPyTorch.jpg

+.. customgalleryitem::
+   :tooltip: Train large models with multiple GPUs using model parallel
+   :description: :doc:`/intermediate/model_parallel_tutorial`
+   :figure: _static/img/distributed/DistPyTorch.jpg
+
 .. customgalleryitem::
    :tooltip: PyTorch distributed trainer with Amazon AWS
    :description: :doc:`/beginner/aws_distributed_training_tutorial`
    :figure: _static/img/distributed/DistPyTorch.jpg
diff --git a/intermediate_source/model_parallel_tutorial.py b/intermediate_source/model_parallel_tutorial.py
new file mode 100644
index 00000000000..e0132ce08b4
--- /dev/null
+++ b/intermediate_source/model_parallel_tutorial.py
@@ -0,0 +1,331 @@
+# -*- coding: utf-8 -*-
+"""
+Model Parallel Best Practices
+*************************************************************
+**Author**: `Shen Li `_
+
+Data parallel and model parallel are widely used distributed training
+techniques. Previous posts have explained how to use
+`DataParallel `_
+to train a neural network on multiple GPUs. ``DataParallel`` replicates the
+same model to all GPUs, where each GPU consumes a different partition of the
+input data. Although it can significantly accelerate the training process, it
+does not work for some use cases where the model is too large to fit into a
+single GPU. This post shows how to solve that problem by using model parallel,
+and also shares some insights on how to speed up model parallel training.
+
+The high-level idea of model parallel is to place different sub-networks of a
+model onto different devices, and implement the ``forward`` method accordingly
+to move intermediate outputs across devices. As only part of a model operates
+on any individual device, a set of devices can collectively serve a larger
+model. In this post, we will not try to construct huge models and squeeze them
+into a limited number of GPUs. Instead, this post focuses on showing the idea
+of model parallel. It is up to the readers to apply the ideas to real-world
+applications.
+
+Let us start with a toy model that contains two linear layers. To run this
+model on two GPUs, simply put each linear layer on a different GPU, and move
+inputs and intermediate outputs to match the layer devices accordingly.
+"""
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
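+
+######################################################################
+# All snippets in this tutorial assume that at least two CUDA devices are
+# available. The optional guard below is only a convenience; it fails early
+# with a clear message when the machine exposes fewer GPUs.
+
+assert torch.cuda.device_count() >= 2, \
+    "This tutorial requires at least two CUDA devices"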
+""" + +import torch +import torch.nn as nn +import torch.optim as optim + + +class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + self.net1 = torch.nn.Linear(10, 10).to('cuda:0') + self.relu = torch.nn.ReLU().to('cuda:0') + self.net2 = torch.nn.Linear(10, 5).to('cuda:1') + + def forward(self, x): + return self.net2(self.net1(x.to('cuda:0')).to('cuda:1')) + +###################################################################### +# Note that, the above ``ToyModel`` looks very similar to how one would +# implement it on a single GPU, except the five ``to(device)`` calls which +# place linear layers and tensors on proper devices. That is the only place in +# the model that requires changes. The ``backward()`` and ``torch.optim`` will +# automatically take care of gradients as if the model is on one GPU. You only +# need to make sure that the labels are on the same device as the outputs when +# calling the loss function. + + +model = ToyModel() +loss_fn = nn.MSELoss() +optimizer = optim.SGD(model.parameters(), lr=0.001) + +optimizer.zero_grad() +outputs = model(torch.randn(20, 10)) +labels = torch.randn(20, 5).to('cuda:1') +loss_fn(outputs, labels).backward() +optimizer.step() + +###################################################################### +# Apply Model Parallel to Existing Modules +# ======================= +# +# It is also possible to run an existing single-GPU module on multiple GPUs +# with just a few lines of changes. The code below shows how to decompose +# ``torchvision.models.reset50()`` to two GPUs. The idea is to inherit from +# the existing ``ResNet`` module, and split the layers to two GPUs during +# construction. Then, override the ``forward`` method to stitch two +# sub-networks by moving the intermediate outputs accordingly. + + +from torchvision.models.resnet import ResNet, Bottleneck + +num_classes = 1000 + + +class ModelParallelResNet50(ResNet): + def __init__(self, *args, **kwargs): + super(ModelParallelResNet50, self).__init__( + Bottleneck, [3, 4, 6, 3], num_classes=num_classes, *args, **kwargs) + + self.seq1 = nn.Sequential( + self.conv1, + self.bn1, + self.relu, + self.maxpool, + + self.layer1, + self.layer2 + ).to('cuda:0') + + self.seq2 = nn.Sequential( + self.layer3, + self.layer4, + self.avgpool, + ).to('cuda:1') + + self.fc.to('cuda:1') + + def forward(self, x): + x = self.seq2(self.seq1(x).to('cuda:1')) + return self.fc(x.view(x.size(0), -1)) + + +###################################################################### +# The above implementation solves the problem for cases where the model is too +# large to fit into a single GPU. However, you might have already noticed that +# it will be slower than running it on a single GPU if your model fits. It is +# because, at any point in time, only one of the two GPUs are working, while +# the other one is sitting there doing nothing. The performance further +# deteriorates as the intermediate outputs need to be copied from ``cuda:0`` to +# ``cuda:1`` between ``layer2`` and ``layer3``. +# +# Let us run an experiment to get a more quantitative view of the execution +# time. In this experiment, we train ``ModelParallelResNet50`` and the existing +# ``torchvision.models.reset50()`` by running random inputs and labels through +# them. After the training, the models will not produce any useful predictions, +# but we can get a reasonable understanding of the execution times. 
+
+######################################################################
+# Let us run an experiment to get a more quantitative view of the execution
+# time. In this experiment, we train ``ModelParallelResNet50`` and the existing
+# ``torchvision.models.resnet50()`` by running random inputs and labels through
+# them. After the training, the models will not produce any useful predictions,
+# but we can get a reasonable understanding of the execution times.
+
+
+import torchvision.models as models
+
+num_batches = 3
+batch_size = 120
+image_w = 128
+image_h = 128
+
+
+def train(model):
+    model.train(True)
+    loss_fn = nn.MSELoss()
+    optimizer = optim.SGD(model.parameters(), lr=0.001)
+
+    one_hot_indices = torch.LongTensor(batch_size) \
+                           .random_(0, num_classes) \
+                           .view(batch_size, 1)
+
+    for _ in range(num_batches):
+        # generate random inputs and labels
+        inputs = torch.randn(batch_size, 3, image_w, image_h)
+        labels = torch.zeros(batch_size, num_classes) \
+                      .scatter_(1, one_hot_indices, 1)
+
+        # run forward pass
+        optimizer.zero_grad()
+        outputs = model(inputs.to('cuda:0'))
+
+        # run backward pass
+        labels = labels.to(outputs.device)
+        loss_fn(outputs, labels).backward()
+        optimizer.step()
+
+
+######################################################################
+# The ``train(model)`` method above uses ``nn.MSELoss`` as the loss function,
+# and ``optim.SGD`` as the optimizer. It mimics training on ``128 x 128``
+# images which are organized into 3 batches where each batch contains 120
+# images. Then, we use ``timeit`` to run the ``train(model)`` method 10 times
+# and plot the execution times with standard deviations.
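+
+######################################################################
+# One caveat when timing GPU code: CUDA operations are asynchronous, so a
+# stricter measurement would synchronize both devices before the timer stops.
+# The hypothetical ``train_and_sync`` wrapper below sketches this idea; it is
+# optional and is not used for the numbers reported in the rest of this
+# tutorial.
+
+
+def train_and_sync(model):
+    train(model)
+    # block until all queued kernels and copies on both GPUs have finished
+    torch.cuda.synchronize(torch.device('cuda:0'))
+    torch.cuda.synchronize(torch.device('cuda:1'))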
+
+
+import matplotlib.pyplot as plt
+plt.switch_backend('agg')
+import numpy as np
+import timeit
+
+num_repeat = 10
+
+stmt = "train(model)"
+
+setup = "model = ModelParallelResNet50()"
+# globals arg is only available in Python 3. In Python 2, use the following
+# import __builtin__
+# __builtin__.__dict__.update(locals())
+mp_run_times = timeit.repeat(
+    stmt, setup, number=1, repeat=num_repeat, globals=globals())
+mp_mean, mp_std = np.mean(mp_run_times), np.std(mp_run_times)
+
+setup = "import torchvision.models as models;" + \
+        "model = models.resnet50(num_classes=num_classes).to('cuda:0')"
+rn_run_times = timeit.repeat(
+    stmt, setup, number=1, repeat=num_repeat, globals=globals())
+rn_mean, rn_std = np.mean(rn_run_times), np.std(rn_run_times)
+
+
+def plot(means, stds, labels, fig_name):
+    fig, ax = plt.subplots()
+    ax.bar(np.arange(len(means)), means, yerr=stds,
+           align='center', alpha=0.5, ecolor='red', capsize=10, width=0.6)
+    ax.set_ylabel('ResNet50 Execution Time (Second)')
+    ax.set_xticks(np.arange(len(means)))
+    ax.set_xticklabels(labels)
+    ax.yaxis.grid(True)
+    plt.tight_layout()
+    plt.savefig(fig_name)
+
+
+plot([mp_mean, rn_mean],
+     [mp_std, rn_std],
+     ['Model Parallel', 'Single GPU'],
+     'mp_vs_rn.png')
+
+
+######################################################################
+#
+# .. figure:: /_static/img/model-parallel-images/mp_vs_rn.png
+#    :alt:
+#
+# The result shows that the execution time of the model parallel implementation
+# is ``4.02/3.75-1=7%`` longer than the existing single-GPU implementation. So
+# we can conclude there is roughly 7% overhead in copying tensors back and
+# forth across the GPUs. There is room for improvement, as we know one of the
+# two GPUs is sitting idle throughout the execution. One option is to further
+# divide each batch into a pipeline of splits, such that when one split reaches
+# the second sub-network, the following split can be fed into the first
+# sub-network. In this way, two consecutive splits can run concurrently on two
+# GPUs.
+
+######################################################################
+# Speed Up by Pipelining Inputs
+# ==============================
+#
+# In the following experiments, we further divide each 120-image batch into
+# 20-image splits. As PyTorch launches CUDA operations asynchronously, the
+# implementation does not need to spawn multiple threads to achieve
+# concurrency.
+
+
+class PipelineParallelResNet50(ModelParallelResNet50):
+    def __init__(self, split_size=20, *args, **kwargs):
+        super(PipelineParallelResNet50, self).__init__(*args, **kwargs)
+        self.split_size = split_size
+
+    def forward(self, x):
+        splits = iter(x.split(self.split_size, dim=0))
+        s_next = next(splits)
+        s_prev = self.seq1(s_next).to('cuda:1')
+        ret = []
+
+        for s_next in splits:
+            # A. s_prev runs on cuda:1
+            s_prev = self.seq2(s_prev)
+            ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))
+
+            # B. s_next runs on cuda:0, which can run concurrently with A
+            s_prev = self.seq1(s_next).to('cuda:1')
+
+        s_prev = self.seq2(s_prev)
+        ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))
+
+        return torch.cat(ret)
+
+
+setup = "model = PipelineParallelResNet50()"
+pp_run_times = timeit.repeat(
+    stmt, setup, number=1, repeat=num_repeat, globals=globals())
+pp_mean, pp_std = np.mean(pp_run_times), np.std(pp_run_times)
+
+plot([mp_mean, rn_mean, pp_mean],
+     [mp_std, rn_std, pp_std],
+     ['Model Parallel', 'Single GPU', 'Pipelining Model Parallel'],
+     'mp_vs_rn_vs_pp.png')
+
+######################################################################
+# Please note, device-to-device tensor copy operations are synchronized on the
+# current streams of the source and the destination devices. If you create
+# multiple streams, you have to make sure that copy operations are properly
+# synchronized. Writing to the source tensor, or reading from or writing to the
+# destination tensor, before the copy operation finishes can lead to undefined
+# behavior. The above implementation only uses default streams on both source
+# and destination devices, hence it is not necessary to enforce additional
+# synchronization.
+#
+# .. figure:: /_static/img/model-parallel-images/mp_vs_rn_vs_pp.png
+#    :alt:
+#
+# The experiment result shows that pipelining inputs to the model parallel
+# ResNet50 speeds up the training process by roughly ``3.75/2.51-1=49%``. It is
+# still quite far away from the ideal 100% speedup. As we have introduced a new
+# parameter ``split_size`` in our pipeline parallel implementation, it is
+# unclear how the new parameter affects the overall training time. Intuitively
+# speaking, using a small ``split_size`` leads to many tiny CUDA kernel
+# launches, while using a large ``split_size`` results in relatively long idle
+# times during the first and last splits. Neither is optimal. There might be an
+# optimal ``split_size`` configuration for this specific experiment. Let us try
+# to find it by running experiments with several different ``split_size``
+# values.
+
+
+means = []
+stds = []
+split_sizes = [1, 3, 5, 8, 10, 12, 20, 40, 60]
+
+for split_size in split_sizes:
+    setup = "model = PipelineParallelResNet50(split_size=%d)" % split_size
+    pp_run_times = timeit.repeat(
+        stmt, setup, number=1, repeat=num_repeat, globals=globals())
+    means.append(np.mean(pp_run_times))
+    stds.append(np.std(pp_run_times))
+
+fig, ax = plt.subplots()
+ax.plot(split_sizes, means)
+ax.errorbar(split_sizes, means, yerr=stds, ecolor='red', fmt='ro')
+ax.set_ylabel('ResNet50 Execution Time (Second)')
+ax.set_xlabel('Pipeline Split Size')
+ax.set_xticks(split_sizes)
+ax.yaxis.grid(True)
+plt.tight_layout()
+plt.savefig("split_size_tradeoff.png")
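+
+######################################################################
+# For convenience, we can also report the best-performing configuration from
+# the sweep above directly. This short printout is merely illustrative; the
+# conclusion below is based on the generated plot.
+
+
+best_idx = int(np.argmin(means))
+print("best split_size: %d (mean time %.3f s)"
+      % (split_sizes[best_idx], means[best_idx]))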
+
+######################################################################
+#
+# .. figure:: /_static/img/model-parallel-images/split_size_tradeoff.png
+#    :alt:
+#
+# The result shows that setting ``split_size`` to 12 achieves the fastest
+# training speed, which leads to a ``3.75/2.43-1=54%`` speedup. There are
+# still opportunities to further accelerate the training process. For example,
+# all operations on ``cuda:0`` are placed on its default stream. This means
+# that computations on the next split cannot overlap with the copy operation
+# of the previous split. However, as the previous and next splits are different
+# tensors, there is no problem with overlapping one's computation with the
+# other's copy. The implementation would need to use multiple streams on both
+# GPUs, and different sub-network structures require different stream
+# management strategies. As no general multi-stream solution works for all
+# model parallel use cases, we will not discuss it in this tutorial.
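+
+######################################################################
+# For readers who want to experiment on their own, the snippet below only
+# demonstrates the low-level primitives such an extension would rely on:
+# ``torch.cuda.Stream``, the ``torch.cuda.stream`` context manager, and
+# ``wait_stream``. It copies a single tensor on a side stream and orders the
+# surrounding work explicitly. It is a standalone illustration with made-up
+# tensors, not a drop-in pipelined implementation.
+
+
+# a side stream on cuda:0, used here only to issue the device-to-device copy
+copy_stream = torch.cuda.Stream(device=torch.device('cuda:0'))
+
+x = torch.randn(20, 10, device='cuda:0')
+
+# make the side stream wait until x has been produced on the default stream
+copy_stream.wait_stream(torch.cuda.current_stream(torch.device('cuda:0')))
+with torch.cuda.stream(copy_stream):
+    y = x.to('cuda:1', non_blocking=True)
+
+# make cuda:1's default stream wait for the copy before consuming y
+torch.cuda.current_stream(torch.device('cuda:1')).wait_stream(copy_stream)
+print(y.sum())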