Commit f52c85d

Merge pull request #436 from mrshenli/mp
Adding model parallel tutorial
2 parents: bca73b7 + 43ed1b9

5 files changed (+336, -0)
index.rst (5 additions, 0 deletions)

@@ -211,6 +211,11 @@ Production Usage
     :description: :doc:`/intermediate/dist_tuto`
     :figure: _static/img/distributed/DistPyTorch.jpg
 
+.. customgalleryitem::
+    :tooltip: Train large models with multiple GPUs using model parallel
+    :description: :doc:`/intermediate/model_parallel_tutorial`
+    :figure: _static/img/distributed/DistPyTorch.jpg
+
 .. customgalleryitem::
     :tooltip: PyTorch distributed trainer with Amazon AWS
     :description: :doc:`/beginner/aws_distributed_training_tutorial`

model_parallel_tutorial.py (331 additions, 0 deletions)

@@ -0,0 +1,331 @@
# -*- coding: utf-8 -*-
"""
Model Parallel Best Practices
*************************************************************
**Author**: `Shen Li <https://mrshenli.github.io/>`_

Data parallel and model parallel are widely-used distributed training
techniques. Previous posts have explained how to use
`DataParallel <https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html>`_
to train a neural network on multiple GPUs. ``DataParallel`` replicates the
same model to all GPUs, where each GPU consumes a different partition of the
input data. Although it can significantly accelerate the training process, it
does not work for some use cases where the model is too large to fit into a
single GPU. This post shows how to solve that problem by using model parallel,
and also shares some insights on how to speed up model parallel training.
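
For reference, ``DataParallel`` usage looks roughly like the following minimal
sketch, which replicates one module across the visible GPUs and splits the
input batch along its first dimension::

    import torch
    import torch.nn as nn

    model = nn.DataParallel(nn.Linear(10, 5)).to('cuda:0')
    output = model(torch.randn(20, 10).to('cuda:0'))  # the 20 rows are scattered across GPUs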

The high-level idea of model parallel is to place different sub-networks of a
model onto different devices, and implement the ``forward`` method accordingly
to move intermediate outputs across devices. As only part of a model operates
on any individual device, a set of devices can collectively serve a larger
model. In this post, we will not try to construct huge models and squeeze them
into a limited number of GPUs. Instead, this post focuses on showing the idea
of model parallel. It is up to the readers to apply the ideas to real-world
applications.

Let us start with a toy model that contains two linear layers. To run this
model on two GPUs, simply put each linear layer on a different GPU, and move
inputs and intermediate outputs to match the layer devices accordingly.
"""

import torch
import torch.nn as nn
import torch.optim as optim


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10).to('cuda:0')
        self.relu = torch.nn.ReLU().to('cuda:0')
        self.net2 = torch.nn.Linear(10, 5).to('cuda:1')

    def forward(self, x):
        return self.net2(self.net1(x.to('cuda:0')).to('cuda:1'))

######################################################################
# Note that the above ``ToyModel`` looks very similar to how one would
# implement it on a single GPU, except for the five ``to(device)`` calls which
# place the linear layers and tensors on the proper devices. That is the only
# place in the model that requires changes. ``backward()`` and ``torch.optim``
# will automatically take care of gradients as if the model were on one GPU.
# You only need to make sure that the labels are on the same device as the
# outputs when calling the loss function.


model = ToyModel()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

optimizer.zero_grad()
outputs = model(torch.randn(20, 10))
labels = torch.randn(20, 5).to('cuda:1')
loss_fn(outputs, labels).backward()
optimizer.step()

######################################################################
# Apply Model Parallel to Existing Modules
# ========================================
#
# It is also possible to run an existing single-GPU module on multiple GPUs
# with just a few lines of changes. The code below shows how to decompose
# ``torchvision.models.resnet50()`` to two GPUs. The idea is to inherit from
# the existing ``ResNet`` module, and split the layers across two GPUs during
# construction. Then, override the ``forward`` method to stitch the two
# sub-networks together by moving the intermediate outputs accordingly.


from torchvision.models.resnet import ResNet, Bottleneck

num_classes = 1000


class ModelParallelResNet50(ResNet):
    def __init__(self, *args, **kwargs):
        super(ModelParallelResNet50, self).__init__(
            Bottleneck, [3, 4, 6, 3], num_classes=num_classes, *args, **kwargs)

        self.seq1 = nn.Sequential(
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,

            self.layer1,
            self.layer2
        ).to('cuda:0')

        self.seq2 = nn.Sequential(
            self.layer3,
            self.layer4,
            self.avgpool,
        ).to('cuda:1')

        self.fc.to('cuda:1')

    def forward(self, x):
        x = self.seq2(self.seq1(x).to('cuda:1'))
        return self.fc(x.view(x.size(0), -1))

######################################################################
# The above implementation solves the problem for cases where the model is too
# large to fit into a single GPU. However, you might have already noticed that
# it will be slower than running on a single GPU if your model fits. This is
# because, at any point in time, only one of the two GPUs is working, while
# the other one is sitting there doing nothing. The performance further
# deteriorates as the intermediate outputs need to be copied from ``cuda:0`` to
# ``cuda:1`` between ``layer2`` and ``layer3``.
#
# Let us run an experiment to get a more quantitative view of the execution
# time. In this experiment, we train ``ModelParallelResNet50`` and the existing
# ``torchvision.models.resnet50()`` by running random inputs and labels through
# them. After the training, the models will not produce any useful predictions,
# but we can get a reasonable understanding of the execution times.


import torchvision.models as models

num_batches = 3
batch_size = 120
image_w = 128
image_h = 128


def train(model):
    model.train(True)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    one_hot_indices = torch.LongTensor(batch_size) \
                           .random_(0, num_classes) \
                           .view(batch_size, 1)

    for _ in range(num_batches):
        # generate random inputs and labels
        inputs = torch.randn(batch_size, 3, image_w, image_h)
        labels = torch.zeros(batch_size, num_classes) \
                      .scatter_(1, one_hot_indices, 1)

        # run forward pass
        optimizer.zero_grad()
        outputs = model(inputs.to('cuda:0'))

        # run backward pass
        labels = labels.to(outputs.device)
        loss_fn(outputs, labels).backward()
        optimizer.step()

######################################################################
# The ``train(model)`` function above uses ``nn.MSELoss`` as the loss function,
# and ``optim.SGD`` as the optimizer. It mimics training on ``128 x 128``
# images which are organized into 3 batches where each batch contains 120
# images. Then, we use ``timeit`` to run the ``train(model)`` function 10 times
# and plot the execution times with standard deviations.


import matplotlib.pyplot as plt
plt.switch_backend('agg')
import numpy as np
import timeit

num_repeat = 10

stmt = "train(model)"

setup = "model = ModelParallelResNet50()"
# globals arg is only available in Python 3. In Python 2, use the following
# import __builtin__
# __builtin__.__dict__.update(locals())
mp_run_times = timeit.repeat(
    stmt, setup, number=1, repeat=num_repeat, globals=globals())
mp_mean, mp_std = np.mean(mp_run_times), np.std(mp_run_times)

setup = "import torchvision.models as models;" + \
        "model = models.resnet50(num_classes=num_classes).to('cuda:0')"
rn_run_times = timeit.repeat(
    stmt, setup, number=1, repeat=num_repeat, globals=globals())
rn_mean, rn_std = np.mean(rn_run_times), np.std(rn_run_times)


def plot(means, stds, labels, fig_name):
    fig, ax = plt.subplots()
    ax.bar(np.arange(len(means)), means, yerr=stds,
           align='center', alpha=0.5, ecolor='red', capsize=10, width=0.6)
    ax.set_ylabel('ResNet50 Execution Time (Second)')
    ax.set_xticks(np.arange(len(means)))
    ax.set_xticklabels(labels)
    ax.yaxis.grid(True)
    plt.tight_layout()
    plt.savefig(fig_name)


plot([mp_mean, rn_mean],
     [mp_std, rn_std],
     ['Model Parallel', 'Single GPU'],
     'mp_vs_rn.png')

######################################################################
#
# .. figure:: /_static/img/model-parallel-images/mp_vs_rn.png
#    :alt:
#
# The result shows that the execution time of the model parallel implementation
# is ``4.02/3.75-1=7%`` longer than the existing single-GPU implementation. So
# we can conclude there is roughly 7% overhead in copying tensors back and
# forth across the GPUs. There is room for improvement, as we know that one of
# the two GPUs is sitting idle throughout the execution. One option is to
# further divide each batch into a pipeline of splits, such that when one split
# reaches the second sub-network, the following split can be fed into the first
# sub-network. In this way, two consecutive splits can run concurrently on two
# GPUs.

######################################################################
# Speed Up by Pipelining Inputs
# =============================
#
# In the following experiments, we further divide each 120-image batch into
# 20-image splits. As PyTorch launches CUDA operations asynchronously, the
# implementation does not need to spawn multiple threads to achieve
# concurrency.


class PipelineParallelResNet50(ModelParallelResNet50):
    def __init__(self, split_size=20, *args, **kwargs):
        super(PipelineParallelResNet50, self).__init__(*args, **kwargs)
        self.split_size = split_size

    def forward(self, x):
        splits = iter(x.split(self.split_size, dim=0))
        s_next = next(splits)
        s_prev = self.seq1(s_next).to('cuda:1')
        ret = []

        for s_next in splits:
            # A. s_prev runs on cuda:1
            s_prev = self.seq2(s_prev)
            ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))

            # B. s_next runs on cuda:0, which can run concurrently with A
            s_prev = self.seq1(s_next).to('cuda:1')

        s_prev = self.seq2(s_prev)
        ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))

        return torch.cat(ret)


setup = "model = PipelineParallelResNet50()"
pp_run_times = timeit.repeat(
    stmt, setup, number=1, repeat=num_repeat, globals=globals())
pp_mean, pp_std = np.mean(pp_run_times), np.std(pp_run_times)

plot([mp_mean, rn_mean, pp_mean],
     [mp_std, rn_std, pp_std],
     ['Model Parallel', 'Single GPU', 'Pipelining Model Parallel'],
     'mp_vs_rn_vs_pp.png')

######################################################################
# Please note that device-to-device tensor copy operations are synchronized on
# the current streams of the source and the destination devices. If you create
# multiple streams, you have to make sure that copy operations are properly
# synchronized. Writing the source tensor or reading/writing the destination
# tensor before finishing the copy operation can lead to undefined behavior.
# The above implementation only uses default streams on both the source and
# the destination devices, hence it is not necessary to enforce additional
# synchronizations.
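#
# As an illustration only (assuming two GPUs, like the rest of this tutorial),
# the snippet below copies a tensor to ``cuda:1`` on a dedicated stream and
# then makes the default stream on ``cuda:1`` wait for that copy before
# consuming the result. The stream and tensor names here are not part of the
# models above.

dev1 = torch.device('cuda:1')
copy_stream = torch.cuda.Stream(device=dev1)
src = torch.randn(20, 10, device='cuda:0')

with torch.cuda.stream(copy_stream):
    # the copy is associated with copy_stream on the destination side
    dst = src.to(dev1, non_blocking=True)

# block cuda:1's default stream until work queued on copy_stream has finished
torch.cuda.current_stream(dev1).wait_stream(copy_stream)
out = dst * 2  # now safe to read ``dst`` on the default stream

######################################################################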
#
# .. figure:: /_static/img/model-parallel-images/mp_vs_rn_vs_pp.png
#    :alt:
#
# The experiment result shows that pipelining inputs to model parallel
# ResNet50 speeds up the training process by roughly ``3.75/2.51-1=49%``. It is
# still quite far away from the ideal 100% speedup. As we have introduced a new
# parameter ``split_size`` in our pipeline parallel implementation, it is
# unclear how the new parameter affects the overall training time. Intuitively
# speaking, using a small ``split_size`` leads to many tiny CUDA kernel
# launches, while using a large ``split_size`` results in relatively long idle
# times during the first and last splits. Neither is optimal. There might be an
# optimal ``split_size`` configuration for this specific experiment. Let us try
# to find it by running experiments using several different ``split_size``
# values.

means = []
stds = []
split_sizes = [1, 3, 5, 8, 10, 12, 20, 40, 60]

for split_size in split_sizes:
    setup = "model = PipelineParallelResNet50(split_size=%d)" % split_size
    pp_run_times = timeit.repeat(
        stmt, setup, number=1, repeat=num_repeat, globals=globals())
    means.append(np.mean(pp_run_times))
    stds.append(np.std(pp_run_times))

fig, ax = plt.subplots()
ax.plot(split_sizes, means)
ax.errorbar(split_sizes, means, yerr=stds, ecolor='red', fmt='ro')
ax.set_ylabel('ResNet50 Execution Time (Second)')
ax.set_xlabel('Pipeline Split Size')
ax.set_xticks(split_sizes)
ax.yaxis.grid(True)
plt.tight_layout()
plt.savefig("split_size_tradeoff.png")

######################################################################
#
# .. figure:: /_static/img/model-parallel-images/split_size_tradeoff.png
#    :alt:
#
# The result shows that setting ``split_size`` to 12 achieves the fastest
# training speed, which leads to a ``3.75/2.43-1=54%`` speedup. There are
# still opportunities to further accelerate the training process. For example,
# all operations on ``cuda:0`` are placed on its default stream. This means
# that computation on the next split cannot overlap with the copy operation
# of the previous split. However, as the previous and next splits are different
# tensors, there is no problem overlapping one's computation with the other's
# copy. The implementation needs to use multiple streams on both GPUs, and
# different sub-network structures require different stream management
# strategies. As no general multi-stream solution works for all model parallel
# use cases, we will not discuss it in this tutorial.
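
######################################################################
# Purely to illustrate the overlap idea on two standalone tensors (this is not
# a change to ``PipelineParallelResNet50``, and the names below exist only for
# this sketch), a copy can be issued on a side stream so that the default
# stream on ``cuda:0`` stays free to compute on a different tensor.

side_stream = torch.cuda.Stream(device=torch.device('cuda:0'))
a = torch.randn(2000, 2000, device='cuda:0')
b = torch.randn(2000, 2000, device='cuda:0')

# the side stream must wait until ``a`` has been produced on the default stream
side_stream.wait_stream(torch.cuda.current_stream(torch.device('cuda:0')))
with torch.cuda.stream(side_stream):
    a_on_1 = a.to('cuda:1', non_blocking=True)
    # keep ``a``'s memory from being reused until work on side_stream completes
    a.record_stream(side_stream)

c = torch.mm(b, b)  # enqueued on the default stream; may overlap with the copy

# before reading ``a_on_1`` on cuda:1, wait for the copy to complete
torch.cuda.current_stream(torch.device('cuda:1')).wait_stream(side_stream)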
