@@ -89,7 +89,6 @@ def forward(self, x):
 class Encoder(nn.Module):
     def __init__(self, ntoken, ninp, dropout=0.5):
         super(Encoder, self).__init__()
-        self.src_mask = None
         self.pos_encoder = PositionalEncoding(ninp, dropout)
         self.encoder = nn.Embedding(ntoken, ninp)
         self.ninp = ninp
@@ -99,17 +98,9 @@ def init_weights(self):
         initrange = 0.1
         self.encoder.weight.data.uniform_(-initrange, initrange)
 
-    def _generate_square_subsequent_mask(self, sz):
-        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
-        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
-        return mask
-
     def forward(self, src):
-        if self.src_mask is None or self.src_mask.size(0) != src.size(0):
-            device = src.device
-            mask = self._generate_square_subsequent_mask(src.size(0)).to(device)
-            self.src_mask = mask
-
+        # Need (S, N) format for encoder.
+        src = src.t()
         src = self.encoder(src) * math.sqrt(self.ninp)
         return self.pos_encoder(src)
 
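`Pipe` splits each mini-batch into micro-batches along dimension 0, so the pipeline stages receive batch-first `(N, S)` input, while the transformer layers in PyTorch 1.8 expect sequence-first `(S, N, E)` tensors; hence the `src.t()` at the top of the first stage. A minimal shape sketch of that transpose (the sizes below are illustrative assumptions, not the tutorial's values):

```python
import math
import torch
import torch.nn as nn

# Illustrative sizes only (assumed for this sketch).
ntoken, ninp, bptt, batch = 1000, 16, 35, 20

embed = nn.Embedding(ntoken, ninp)

src = torch.randint(0, ntoken, (batch, bptt))  # (N, S): batch-first, as Pipe delivers it
src = src.t()                                  # (S, N): sequence-first for the encoder layers
out = embed(src) * math.sqrt(ninp)             # (S, N, E)
print(out.shape)                               # torch.Size([35, 20, 16])
```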
@@ -125,7 +116,8 @@ def init_weights(self):
         self.decoder.weight.data.uniform_(-initrange, initrange)
 
     def forward(self, inp):
-        return self.decoder(inp)
+        # Need batch dimension first for output of pipeline.
+        return self.decoder(inp).permute(1, 0, 2)
 
 ######################################################################
 # Start multiple processes for training
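On the way out, the decoder stage mirrors that transpose: `Pipe` stitches the micro-batch outputs back together along dimension 0, so the last stage must return batch-first tensors. A small sketch of what `permute(1, 0, 2)` does here (sizes are again assumptions):

```python
import torch

# Illustrative sizes only (assumed for this sketch).
S, N, ntokens = 35, 20, 1000

dec_out = torch.randn(S, N, ntokens)    # (S, N, ntokens): sequence-first decoder output
batch_first = dec_out.permute(1, 0, 2)  # (N, S, ntokens): batch dimension first
print(batch_first.shape)                # torch.Size([20, 35, 1000])
```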
@@ -245,7 +237,8 @@ def get_batch(source, i):
     seq_len = min(bptt, len(source) - 1 - i)
     data = source[i:i + seq_len]
     target = source[i + 1:i + 1 + seq_len].view(-1)
-    return data, target
+    # Need batch dimension first for pipeline parallelism.
+    return data.t(), target
 
 ######################################################################
 # Model scale and Pipe initialization
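The transpose in `get_batch` is what makes the micro-batching work: `Pipe` chunks its input along dimension 0, so that dimension has to be the batch. A rough stand-in for the split `Pipe` performs internally (`chunks` and the sizes are assumptions for illustration):

```python
import torch

# Illustrative sizes only (assumed for this sketch).
seq_len, batch, chunks = 35, 20, 4

source = torch.randint(0, 1000, (seq_len, batch))  # (S, N), as the corpus is laid out
data = source.t()                                  # (N, S), as get_batch now returns it

# Roughly what Pipe does internally: split along dim 0 into micro-batches.
micro_batches = data.chunk(chunks, dim=0)
print([tuple(mb.shape) for mb in micro_batches])   # [(5, 35), (5, 35), (5, 35), (5, 35)]
```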
@@ -318,8 +311,9 @@ def get_batch(source, i):
 # Need to use 'checkpoint=never' since as of PyTorch 1.8, Pipe checkpointing
 # doesn't work with DDP.
 from torch.distributed.pipeline.sync import Pipe
+chunks = 8
 model = Pipe(torch.nn.Sequential(
-    *module_list), chunks=8, checkpoint="never")
+    *module_list), chunks=chunks, checkpoint="never")
 
 # Initialize process group and wrap model in DDP.
 from torch.nn.parallel import DistributedDataParallel
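Hoisting `chunks` into a variable just names the micro-batch count: each mini-batch of size `N` is run as `chunks` micro-batches of size `N / chunks`. A rough single-process sketch of the `Pipe` setup, under the assumption that PyTorch 1.8's `Pipe` accepts CPU partitions and requires RPC to be initialized first, as the tutorial does earlier (addresses, names, and layer sizes here are made up):

```python
import os
import torch
from torch.distributed import rpc
from torch.distributed.pipeline.sync import Pipe

# Pipe relies on the RPC framework even in a single process.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
rpc.init_rpc("worker", rank=0, world_size=1)

# Two trivial stages stand in for the tutorial's module_list partitions.
chunks = 8
model = Pipe(torch.nn.Sequential(torch.nn.Linear(16, 16),
                                 torch.nn.Linear(16, 16)),
             chunks=chunks, checkpoint="never")

x = torch.randn(32, 16)       # mini-batch of 32 -> micro-batches of 4 along dim 0
out = model(x).local_value()  # forward returns an RRef in PyTorch 1.8
print(out.shape)              # torch.Size([32, 16])

rpc.shutdown()
```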