
Commit bb7e0f1

Author: Guanheng Zhang (committed)
Commit message: checkpoint
1 parent 2751bf3, commit bb7e0f1

File tree

1 file changed (+11, -12 lines)


beginner_source/transformer_tutorial.py

Lines changed: 11 additions & 12 deletions
@@ -55,7 +55,6 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
         super(TransformerModel, self).__init__()
         from torch.nn import TransformerEncoder, TransformerEncoderLayer
         self.model_type = 'Transformer'
-        self.src_mask = None
         self.pos_encoder = PositionalEncoding(ninp, dropout)
         encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
         self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
@@ -65,7 +64,7 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
 
         self.init_weights()
 
-    def _generate_square_subsequent_mask(self, sz):
+    def generate_square_subsequent_mask(self, sz):
         mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
         mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
         return mask
@@ -76,15 +75,10 @@ def init_weights(self):
         self.decoder.bias.data.zero_()
         self.decoder.weight.data.uniform_(-initrange, initrange)
 
-    def forward(self, src):
-        if self.src_mask is None or self.src_mask.size(0) != src.size(0):
-            device = src.device
-            mask = self._generate_square_subsequent_mask(src.size(0)).to(device)
-            self.src_mask = mask
-
+    def forward(self, src, src_mask):
         src = self.encoder(src) * math.sqrt(self.ninp)
         src = self.pos_encoder(src)
-        output = self.transformer_encoder(src, self.src_mask)
+        output = self.transformer_encoder(src, src_mask)
         output = self.decoder(output)
         return output
 
@@ -223,7 +217,6 @@ def get_batch(source, i):
 dropout = 0.2 # the dropout value
 model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
 
-
 ######################################################################
 # Run the model
 # -------------
@@ -253,10 +246,13 @@ def train():
     total_loss = 0.
     start_time = time.time()
     ntokens = len(TEXT.vocab.stoi)
+    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
     for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
         data, targets = get_batch(train_data, i)
         optimizer.zero_grad()
-        output = model(data)
+        if data.size(0) != bptt:
+            src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
+        output = model(data, src_mask)
         loss = criterion(output.view(-1, ntokens), targets)
         loss.backward()
         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
@@ -280,10 +276,13 @@ def evaluate(eval_model, data_source):
     eval_model.eval() # Turn on the evaluation mode
     total_loss = 0.
     ntokens = len(TEXT.vocab.stoi)
+    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
     with torch.no_grad():
         for i in range(0, data_source.size(0) - 1, bptt):
             data, targets = get_batch(data_source, i)
-            output = eval_model(data)
+            if data.size(0) != bptt:
+                src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
+            output = eval_model(data, src_mask)
             output_flat = output.view(-1, ntokens)
             total_loss += len(data) * criterion(output_flat, targets).item()
     return total_loss / (len(data_source) - 1)
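For context, a minimal self-contained sketch (not part of the commit) of what the renamed generate_square_subsequent_mask produces and of the caller-side pattern the diff introduces. The free-standing helper, the bptt value of 4, and the commented model call are illustrative assumptions; only torch is required to run it.

import torch

def generate_square_subsequent_mask(sz):
    # Causal attention pattern: position i may attend only to positions <= i.
    # Disallowed positions are filled with -inf, allowed positions with 0.0,
    # matching the method body in the diff above.
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    return mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))

bptt = 4  # illustrative sequence length
src_mask = generate_square_subsequent_mask(bptt)
print(src_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])

# Caller-side pattern after this change: build the mask once per epoch for the
# full bptt length and rebuild it only when the final batch is shorter, e.g.
#
#   if data.size(0) != bptt:
#       src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
#   output = model(data, src_mask)

With the mask passed as an explicit forward argument, the model no longer caches self.src_mask as internal state; the training and evaluation loops own the mask and resize it themselves.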
