# -*- coding: utf-8 -*-
"""
Model Parallel Best Practices
*************************************************************
**Author**: `Shen Li <https://mrshenli.github.io/>`_

Data parallel and model parallel are widely used distributed training
techniques. Previous posts have explained how to use
`DataParallel <https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html>`_
to train a neural network on multiple GPUs. ``DataParallel`` replicates the
same model to all GPUs, where each GPU consumes a different partition of the
input data. Although it can significantly accelerate the training process, it
does not work for some use cases where the model is too large to fit into a
single GPU. This post shows how to solve that problem by using model parallel
and also shares some insights on how to speed up model parallel training.

The high-level idea of model parallel is to place different sub-networks of a
model onto different devices, and implement the ``forward`` method accordingly
to move intermediate outputs across devices. As only part of a model operates
on any individual device, a set of devices can collectively serve a larger
model. In this post, we will not try to construct huge models and squeeze them
into a limited number of GPUs. Instead, this post focuses on showing the idea
of model parallel. It is up to the readers to apply the ideas to real-world
applications.

Let us start with a toy model that contains two linear layers. To run this
model on two GPUs, simply put each linear layer on a different GPU, and move
inputs and intermediate outputs to match the layer devices accordingly.
"""

import torch
import torch.nn as nn
import torch.optim as optim


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10).to('cuda:0')
        self.relu = torch.nn.ReLU().to('cuda:0')
        self.net2 = torch.nn.Linear(10, 5).to('cuda:1')

    def forward(self, x):
        # run net1 and the ReLU on cuda:0, then move the activation to cuda:1
        x = self.relu(self.net1(x.to('cuda:0')))
        return self.net2(x.to('cuda:1'))

######################################################################
# Note that the above ``ToyModel`` looks very similar to how one would
# implement it on a single GPU, except for the five ``to(device)`` calls that
# place the linear layers and tensors on the proper devices. Those are the
# only places in the model that require changes. ``backward()`` and
# ``torch.optim`` will automatically take care of gradients as if the model
# were on one GPU. You only need to make sure that the labels are on the same
# device as the outputs when calling the loss function.


model = ToyModel()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

optimizer.zero_grad()
outputs = model(torch.randn(20, 10))
labels = torch.randn(20, 5).to('cuda:1')
loss_fn(outputs, labels).backward()
optimizer.step()
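
######################################################################
# As a small variation on the snippet above (a sketch that is not part of the
# original example), the device strings can also be passed in as constructor
# arguments instead of being hard-coded, so the same module can run on
# machines with different GPU layouts. The ``dev0``/``dev1`` names below are
# purely illustrative.


class ConfigurableToyModel(nn.Module):
    def __init__(self, dev0='cuda:0', dev1='cuda:1'):
        super(ConfigurableToyModel, self).__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = torch.nn.Linear(10, 10).to(dev0)
        self.relu = torch.nn.ReLU().to(dev0)
        self.net2 = torch.nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        # compute on dev0, then move the intermediate output to dev1
        x = self.relu(self.net1(x.to(self.dev0)))
        return self.net2(x.to(self.dev1))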

######################################################################
# Apply Model Parallel to Existing Modules
# ========================================
#
# It is also possible to run an existing single-GPU module on multiple GPUs
# with just a few lines of changes. The code below shows how to decompose
# ``torchvision.models.resnet50()`` across two GPUs. The idea is to inherit
# from the existing ``ResNet`` module, and split the layers across the two
# GPUs during construction. Then, override the ``forward`` method to stitch
# the two sub-networks together by moving the intermediate outputs
# accordingly.


from torchvision.models.resnet import ResNet, Bottleneck

num_classes = 1000


class ModelParallelResNet50(ResNet):
    def __init__(self, *args, **kwargs):
        super(ModelParallelResNet50, self).__init__(
            Bottleneck, [3, 4, 6, 3], num_classes=num_classes, *args, **kwargs)

        self.seq1 = nn.Sequential(
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,

            self.layer1,
            self.layer2
        ).to('cuda:0')

        self.seq2 = nn.Sequential(
            self.layer3,
            self.layer4,
            self.avgpool,
        ).to('cuda:1')

        self.fc.to('cuda:1')

    def forward(self, x):
        x = self.seq2(self.seq1(x).to('cuda:1'))
        return self.fc(x.view(x.size(0), -1))


######################################################################
# The above implementation solves the problem for cases where the model is
# too large to fit into a single GPU. However, you might have already noticed
# that it will be slower than running on a single GPU if your model fits.
# This is because, at any point in time, only one of the two GPUs is working,
# while the other one sits idle. Performance further deteriorates because the
# intermediate outputs need to be copied from ``cuda:0`` to ``cuda:1`` between
# ``layer2`` and ``layer3``.
#
# Let us run an experiment to get a more quantitative view of the execution
# time. In this experiment, we train ``ModelParallelResNet50`` and the
# existing ``torchvision.models.resnet50()`` by running random inputs and
# labels through them. After the training, the models will not produce any
# useful predictions, but we can get a reasonable understanding of the
# execution times.


import torchvision.models as models

num_batches = 3
batch_size = 120
image_w = 128
image_h = 128


def train(model):
    model.train(True)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    one_hot_indices = torch.LongTensor(batch_size) \
                           .random_(0, num_classes) \
                           .view(batch_size, 1)

    for _ in range(num_batches):
        # generate random inputs and labels
        inputs = torch.randn(batch_size, 3, image_w, image_h)
        labels = torch.zeros(batch_size, num_classes) \
                      .scatter_(1, one_hot_indices, 1)

        # run forward pass
        optimizer.zero_grad()
        outputs = model(inputs.to('cuda:0'))

        # run backward pass
        labels = labels.to(outputs.device)
        loss_fn(outputs, labels).backward()
        optimizer.step()


######################################################################
# The ``train(model)`` method above uses ``nn.MSELoss`` as the loss function,
# and ``optim.SGD`` as the optimizer. It mimics training on ``128 X 128``
# images, which are organized into 3 batches where each batch contains 120
# images. Then, we use ``timeit`` to run the ``train(model)`` method 10 times
# and plot the execution times with standard deviations.


import matplotlib.pyplot as plt
plt.switch_backend('agg')
import numpy as np
import timeit

num_repeat = 10

stmt = "train(model)"

setup = "model = ModelParallelResNet50()"
# globals arg is only available in Python 3. In Python 2, use the following
# import __builtin__
# __builtin__.__dict__.update(locals())
mp_run_times = timeit.repeat(
    stmt, setup, number=1, repeat=num_repeat, globals=globals())
mp_mean, mp_std = np.mean(mp_run_times), np.std(mp_run_times)

setup = "import torchvision.models as models;" + \
        "model = models.resnet50(num_classes=num_classes).to('cuda:0')"
rn_run_times = timeit.repeat(
    stmt, setup, number=1, repeat=num_repeat, globals=globals())
rn_mean, rn_std = np.mean(rn_run_times), np.std(rn_run_times)


def plot(means, stds, labels, fig_name):
    fig, ax = plt.subplots()
    ax.bar(np.arange(len(means)), means, yerr=stds,
           align='center', alpha=0.5, ecolor='red', capsize=10, width=0.6)
    ax.set_ylabel('ResNet50 Execution Time (Second)')
    ax.set_xticks(np.arange(len(means)))
    ax.set_xticklabels(labels)
    ax.yaxis.grid(True)
    plt.tight_layout()
    plt.savefig(fig_name)


plot([mp_mean, rn_mean],
     [mp_std, rn_std],
     ['Model Parallel', 'Single GPU'],
     'mp_vs_rn.png')


######################################################################
#
# .. figure:: /_static/img/model-parallel-images/mp_vs_rn.png
#    :alt:
#
# The result shows that the execution time of the model parallel
# implementation is ``4.02/3.75-1=7%`` longer than the existing single-GPU
# implementation. So we can conclude there is roughly 7% overhead in copying
# tensors back and forth across the GPUs. There is room for improvement, as
# we know one of the two GPUs is sitting idle throughout the execution. One
# option is to further divide each batch into a pipeline of splits, such that
# when one split reaches the second sub-network, the following split can be
# fed into the first sub-network. In this way, two consecutive splits can run
# concurrently on two GPUs.

######################################################################
# Speed Up by Pipelining Inputs
# ==============================
#
# In the following experiments, we further divide each 120-image batch into
# 20-image splits. As PyTorch launches CUDA operations asynchronously, the
# implementation does not need to spawn multiple threads to achieve
# concurrency.


class PipelineParallelResNet50(ModelParallelResNet50):
    def __init__(self, split_size=20, *args, **kwargs):
        super(PipelineParallelResNet50, self).__init__(*args, **kwargs)
        self.split_size = split_size

    def forward(self, x):
        splits = iter(x.split(self.split_size, dim=0))
        s_next = next(splits)
        s_prev = self.seq1(s_next).to('cuda:1')
        ret = []

        for s_next in splits:
            # A. s_prev runs on cuda:1
            s_prev = self.seq2(s_prev)
            ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))

            # B. s_next runs on cuda:0, which can run concurrently with A
            s_prev = self.seq1(s_next).to('cuda:1')

        s_prev = self.seq2(s_prev)
        ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))

        return torch.cat(ret)


setup = "model = PipelineParallelResNet50()"
pp_run_times = timeit.repeat(
    stmt, setup, number=1, repeat=num_repeat, globals=globals())
pp_mean, pp_std = np.mean(pp_run_times), np.std(pp_run_times)

plot([mp_mean, rn_mean, pp_mean],
     [mp_std, rn_std, pp_std],
     ['Model Parallel', 'Single GPU', 'Pipelining Model Parallel'],
     'mp_vs_rn_vs_pp.png')

######################################################################
# Please note, device-to-device tensor copy operations are synchronized with
# respect to the current streams on the source and the destination devices.
# If you create multiple streams, you have to make sure that copy operations
# are properly synchronized. Writing to the source tensor, or reading from or
# writing to the destination tensor, before the copy operation finishes can
# lead to undefined behavior. The above implementation only uses default
# streams on both the source and destination devices, so it is not necessary
# to enforce additional synchronizations.
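#
# As a rough illustration only (a sketch under the assumption that you do
# choose to issue the copy on a side stream; it is not part of this
# tutorial's implementation), the usual pattern is to order the side stream
# after the producing stream with ``wait_stream`` and to call
# ``record_stream`` on tensors consumed by that stream. The helper name
# ``copy_on_stream`` below is made up for this example.


def copy_on_stream(src, dst_device, copy_stream):
    # wait for work already queued on the source tensor's current stream
    copy_stream.wait_stream(torch.cuda.current_stream(src.device))
    with torch.cuda.stream(copy_stream):
        dst = src.to(dst_device, non_blocking=True)
        # prevent the caching allocator from reusing src's memory before the
        # copy enqueued on copy_stream has finished
        src.record_stream(copy_stream)
    # consumers must later wait on copy_stream before using dst, e.g.
    # torch.cuda.current_stream(dst_device).wait_stream(copy_stream)
    return dst


######################################################################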
#
# .. figure:: /_static/img/model-parallel-images/mp_vs_rn_vs_pp.png
#    :alt:
#
# The experiment result shows that pipelining inputs to the model parallel
# ResNet50 speeds up the training process by roughly ``3.75/2.51-1=49%``. It
# is still quite far away from the ideal 100% speedup. As we have introduced
# a new parameter ``split_size`` in our pipeline parallel implementation, it
# is unclear how the new parameter affects the overall training time.
# Intuitively speaking, using a small ``split_size`` leads to many tiny CUDA
# kernel launches, while using a large ``split_size`` results in relatively
# long idle times during the first and last splits. Neither is optimal. There
# might be an optimal ``split_size`` configuration for this specific
# experiment. Let us try to find it by running experiments using several
# different ``split_size`` values.


means = []
stds = []
split_sizes = [1, 3, 5, 8, 10, 12, 20, 40, 60]

for split_size in split_sizes:
    setup = "model = PipelineParallelResNet50(split_size=%d)" % split_size
    pp_run_times = timeit.repeat(
        stmt, setup, number=1, repeat=num_repeat, globals=globals())
    means.append(np.mean(pp_run_times))
    stds.append(np.std(pp_run_times))

fig, ax = plt.subplots()
ax.plot(split_sizes, means)
ax.errorbar(split_sizes, means, yerr=stds, ecolor='red', fmt='ro')
ax.set_ylabel('ResNet50 Execution Time (Second)')
ax.set_xlabel('Pipeline Split Size')
ax.set_xticks(split_sizes)
ax.yaxis.grid(True)
plt.tight_layout()
plt.savefig("split_size_tradeoff.png")

######################################################################
#
# .. figure:: /_static/img/model-parallel-images/split_size_tradeoff.png
#    :alt:
#
# The result shows that setting ``split_size`` to 12 achieves the fastest
# training speed, which leads to a ``3.75/2.43-1=54%`` speedup. There are
# still opportunities to further accelerate the training process. For
# example, all operations on ``cuda:0`` are placed on its default stream.
# This means that computations on the next split cannot overlap with the copy
# operation of the previous split. However, as the previous and next splits
# are different tensors, there is no problem with overlapping one's
# computation with the other's copy. The implementation needs to use multiple
# streams on both GPUs, and different sub-network structures require
# different stream management strategies. As no general multi-stream solution
# works for all model parallel use cases, we will not discuss it in this
# tutorial.