Skip to content

Commit eb732ef

Browse files
Fix bugs in pipeline tutorial with respect to batch size. (#1461)
Summary: As described in pytorch/pytorch#55036, certain modules were not handling batch size correctly. Co-authored-by: pritam <pritam.damania@fb.com>
1 parent 53d0fd7 commit eb732ef

File tree

1 file changed

+8
-14
lines changed

1 file changed

+8
-14
lines changed

advanced_source/ddp_pipeline.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def forward(self, x):
8989
class Encoder(nn.Module):
9090
def __init__(self, ntoken, ninp, dropout=0.5):
9191
super(Encoder, self).__init__()
92-
self.src_mask = None
9392
self.pos_encoder = PositionalEncoding(ninp, dropout)
9493
self.encoder = nn.Embedding(ntoken, ninp)
9594
self.ninp = ninp
@@ -99,17 +98,9 @@ def init_weights(self):
9998
initrange = 0.1
10099
self.encoder.weight.data.uniform_(-initrange, initrange)
101100

102-
def _generate_square_subsequent_mask(self, sz):
103-
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
104-
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
105-
return mask
106-
107101
def forward(self, src):
108-
if self.src_mask is None or self.src_mask.size(0) != src.size(0):
109-
device = src.device
110-
mask = self._generate_square_subsequent_mask(src.size(0)).to(device)
111-
self.src_mask = mask
112-
102+
# Need (S, N) format for encoder.
103+
src = src.t()
113104
src = self.encoder(src) * math.sqrt(self.ninp)
114105
return self.pos_encoder(src)
115106

@@ -125,7 +116,8 @@ def init_weights(self):
125116
self.decoder.weight.data.uniform_(-initrange, initrange)
126117

127118
def forward(self, inp):
128-
return self.decoder(inp)
119+
# Need batch dimension first for output of pipeline.
120+
return self.decoder(inp).permute(1, 0, 2)
129121

130122
######################################################################
131123
# Start multiple processes for training
@@ -245,7 +237,8 @@ def get_batch(source, i):
245237
seq_len = min(bptt, len(source) - 1 - i)
246238
data = source[i:i+seq_len]
247239
target = source[i+1:i+1+seq_len].view(-1)
248-
return data, target
240+
# Need batch dimension first for pipeline parallelism.
241+
return data.t(), target
249242

250243
######################################################################
251244
# Model scale and Pipe initialization
@@ -318,8 +311,9 @@ def get_batch(source, i):
318311
# Need to use 'checkpoint=never' since as of PyTorch 1.8, Pipe checkpointing
319312
# doesn't work with DDP.
320313
from torch.distributed.pipeline.sync import Pipe
314+
chunks = 8
321315
model = Pipe(torch.nn.Sequential(
322-
*module_list), chunks = 8, checkpoint="never")
316+
*module_list), chunks = chunks, checkpoint="never")
323317

324318
# Initialize process group and wrap model in DDP.
325319
from torch.nn.parallel import DistributedDataParallel

0 commit comments

Comments
 (0)