Skip to content

Commit 40f8027

Browse files
zhangguanheng66 (Guanheng Zhang) and brianjo
authored
Update transformer tutorial (#1197)
* checkpoint
* checkpoint

Co-authored-by: Guanheng Zhang <zhangguanheng@devfair0197.h2.fair>
Co-authored-by: Brian Johnson <brianjo@fb.com>
1 parent 8eb4e3d commit 40f8027

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

beginner_source/transformer_tutorial.py

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,6 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
5555
super(TransformerModel, self).__init__()
5656
from torch.nn import TransformerEncoder, TransformerEncoderLayer
5757
self.model_type = 'Transformer'
58-
self.src_mask = None
5958
self.pos_encoder = PositionalEncoding(ninp, dropout)
6059
encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
6160
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
@@ -65,7 +64,7 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
6564

6665
self.init_weights()
6766

68-
def _generate_square_subsequent_mask(self, sz):
67+
def generate_square_subsequent_mask(self, sz):
6968
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
7069
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
7170
return mask
@@ -76,15 +75,10 @@ def init_weights(self):
7675
self.decoder.bias.data.zero_()
7776
self.decoder.weight.data.uniform_(-initrange, initrange)
7877

79-
def forward(self, src):
80-
if self.src_mask is None or self.src_mask.size(0) != src.size(0):
81-
device = src.device
82-
mask = self._generate_square_subsequent_mask(src.size(0)).to(device)
83-
self.src_mask = mask
84-
78+
def forward(self, src, src_mask):
8579
src = self.encoder(src) * math.sqrt(self.ninp)
8680
src = self.pos_encoder(src)
87-
output = self.transformer_encoder(src, self.src_mask)
81+
output = self.transformer_encoder(src, src_mask)
8882
output = self.decoder(output)
8983
return output
9084

@@ -253,10 +247,13 @@ def train():
253247
total_loss = 0.
254248
start_time = time.time()
255249
ntokens = len(TEXT.vocab.stoi)
250+
src_mask = model.generate_square_subsequent_mask(bptt).to(device)
256251
for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
257252
data, targets = get_batch(train_data, i)
258253
optimizer.zero_grad()
259-
output = model(data)
254+
if data.size(0) != bptt:
255+
src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
256+
output = model(data, src_mask)
260257
loss = criterion(output.view(-1, ntokens), targets)
261258
loss.backward()
262259
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
@@ -280,10 +277,13 @@ def evaluate(eval_model, data_source):
280277
eval_model.eval() # Turn on the evaluation mode
281278
total_loss = 0.
282279
ntokens = len(TEXT.vocab.stoi)
280+
src_mask = model.generate_square_subsequent_mask(bptt).to(device)
283281
with torch.no_grad():
284282
for i in range(0, data_source.size(0) - 1, bptt):
285283
data, targets = get_batch(data_source, i)
286-
output = eval_model(data)
284+
if data.size(0) != bptt:
285+
src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
286+
output = eval_model(data, src_mask)
287287
output_flat = output.view(-1, ntokens)
288288
total_loss += len(data) * criterion(output_flat, targets).item()
289289
return total_loss / (len(data_source) - 1)

0 commit comments

Comments (0)