@@ -55,7 +55,6 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
55
55
super (TransformerModel , self ).__init__ ()
56
56
from torch .nn import TransformerEncoder , TransformerEncoderLayer
57
57
self .model_type = 'Transformer'
58
- self .src_mask = None
59
58
self .pos_encoder = PositionalEncoding (ninp , dropout )
60
59
encoder_layers = TransformerEncoderLayer (ninp , nhead , nhid , dropout )
61
60
self .transformer_encoder = TransformerEncoder (encoder_layers , nlayers )
@@ -65,7 +64,7 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
65
64
66
65
self .init_weights ()
67
66
68
- def _generate_square_subsequent_mask (self , sz ):
67
+ def generate_square_subsequent_mask (self , sz ):
69
68
mask = (torch .triu (torch .ones (sz , sz )) == 1 ).transpose (0 , 1 )
70
69
mask = mask .float ().masked_fill (mask == 0 , float ('-inf' )).masked_fill (mask == 1 , float (0.0 ))
71
70
return mask
@@ -76,15 +75,10 @@ def init_weights(self):
76
75
self .decoder .bias .data .zero_ ()
77
76
self .decoder .weight .data .uniform_ (- initrange , initrange )
78
77
79
- def forward (self , src ):
80
- if self .src_mask is None or self .src_mask .size (0 ) != src .size (0 ):
81
- device = src .device
82
- mask = self ._generate_square_subsequent_mask (src .size (0 )).to (device )
83
- self .src_mask = mask
84
-
78
+ def forward (self , src , src_mask ):
85
79
src = self .encoder (src ) * math .sqrt (self .ninp )
86
80
src = self .pos_encoder (src )
87
- output = self .transformer_encoder (src , self . src_mask )
81
+ output = self .transformer_encoder (src , src_mask )
88
82
output = self .decoder (output )
89
83
return output
90
84
@@ -253,10 +247,13 @@ def train():
253
247
total_loss = 0.
254
248
start_time = time .time ()
255
249
ntokens = len (TEXT .vocab .stoi )
250
+ src_mask = model .generate_square_subsequent_mask (bptt ).to (device )
256
251
for batch , i in enumerate (range (0 , train_data .size (0 ) - 1 , bptt )):
257
252
data , targets = get_batch (train_data , i )
258
253
optimizer .zero_grad ()
259
- output = model (data )
254
+ if data .size (0 ) != bptt :
255
+ src_mask = model .generate_square_subsequent_mask (data .size (0 )).to (device )
256
+ output = model (data , src_mask )
260
257
loss = criterion (output .view (- 1 , ntokens ), targets )
261
258
loss .backward ()
262
259
torch .nn .utils .clip_grad_norm_ (model .parameters (), 0.5 )
@@ -280,10 +277,13 @@ def evaluate(eval_model, data_source):
280
277
eval_model .eval () # Turn on the evaluation mode
281
278
total_loss = 0.
282
279
ntokens = len (TEXT .vocab .stoi )
280
+ src_mask = model .generate_square_subsequent_mask (bptt ).to (device )
283
281
with torch .no_grad ():
284
282
for i in range (0 , data_source .size (0 ) - 1 , bptt ):
285
283
data , targets = get_batch (data_source , i )
286
- output = eval_model (data )
284
+ if data .size (0 ) != bptt :
285
+ src_mask = model .generate_square_subsequent_mask (data .size (0 )).to (device )
286
+ output = eval_model (data , src_mask )
287
287
output_flat = output .view (- 1 , ntokens )
288
288
total_loss += len (data ) * criterion (output_flat , targets ).item ()
289
289
return total_loss / (len (data_source ) - 1 )
0 commit comments