@@ -55,7 +55,6 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        from torch.nn import TransformerEncoder, TransformerEncoderLayer
        self.model_type = 'Transformer'
-        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
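
In short: this commit stops caching the causal mask on the module (the deleted `self.src_mask` above) and instead has callers build the mask and pass it into `forward` explicitly; the remaining hunks thread that argument through the training and evaluation loops.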
@@ -65,7 +64,7 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):

        self.init_weights()

-    def _generate_square_subsequent_mask(self, sz):
+    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
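
To see what `generate_square_subsequent_mask` produces, here is a standalone sketch (not part of the commit) running the same two lines for sz = 4. Position i may attend to positions <= i (0.0), while future positions get -inf so they vanish in the attention softmax:

import torch

sz = 4
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])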
@@ -76,15 +75,10 @@ def init_weights(self):
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

-    def forward(self, src):
-        if self.src_mask is None or self.src_mask.size(0) != src.size(0):
-            device = src.device
-            mask = self._generate_square_subsequent_mask(src.size(0)).to(device)
-            self.src_mask = mask
-
+    def forward(self, src, src_mask):
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
-        output = self.transformer_encoder(src, self.src_mask)
+        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output

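With the new signature the caller owns the mask. A minimal usage sketch, assuming the tutorial's `model`, `ntokens`, and `device` are already set up; the random `data` tensor is a stand-in for a real batch:

import torch

seq_len, batch_size = 35, 20   # bptt and batch size used in the tutorial
data = torch.randint(0, ntokens, (seq_len, batch_size)).to(device)
src_mask = model.generate_square_subsequent_mask(seq_len).to(device)
output = model(data, src_mask)   # shape: (seq_len, batch_size, ntokens)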
@@ -223,7 +217,6 @@ def get_batch(source, i):
dropout = 0.2 # the dropout value
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)

-
######################################################################
# Run the model
# -------------
@@ -253,10 +246,13 @@ def train():
    total_loss = 0.
    start_time = time.time()
    ntokens = len(TEXT.vocab.stoi)
+    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        optimizer.zero_grad()
-        output = model(data)
+        if data.size(0) != bptt:
+            src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
+        output = model(data, src_mask)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
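
The `data.size(0) != bptt` branch exists because the last slice of an epoch is usually shorter than `bptt`, and the attention mask must stay square with side equal to the actual sequence length. The shorter batch comes from the tutorial's `get_batch` (reproduced here for context; it clips the final slice):

def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)   # final batch may be shorter than bptt
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len].view(-1)
    return data, target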
@@ -280,10 +276,13 @@ def evaluate(eval_model, data_source):
    eval_model.eval() # Turn on the evaluation mode
    total_loss = 0.
    ntokens = len(TEXT.vocab.stoi)
+    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
-            output = eval_model(data)
+            if data.size(0) != bptt:
+                src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
+            output = eval_model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)
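
One subtlety: `evaluate` builds its masks through the module-level `model` rather than the `eval_model` argument. Both refer to the same `TransformerModel` in this tutorial, so behavior is unchanged, but `eval_model.generate_square_subsequent_mask(...)` would be the more self-contained choice. The returned value is an average per-position loss, which the tutorial reports as perplexity; a minimal sketch, assuming `val_data` was batchified as elsewhere in the file:

import math

val_loss = evaluate(model, val_data)
print('valid loss {:5.2f} | valid ppl {:8.2f}'.format(val_loss, math.exp(val_loss)))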