diff --git a/beginner_source/transformer_tutorial.py b/beginner_source/transformer_tutorial.py index 90c8b902d37..aab564391e5 100644 --- a/beginner_source/transformer_tutorial.py +++ b/beginner_source/transformer_tutorial.py @@ -55,7 +55,6 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): super(TransformerModel, self).__init__() from torch.nn import TransformerEncoder, TransformerEncoderLayer self.model_type = 'Transformer' - self.src_mask = None self.pos_encoder = PositionalEncoding(ninp, dropout) encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout) self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) @@ -65,7 +64,7 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): self.init_weights() - def _generate_square_subsequent_mask(self, sz): + def generate_square_subsequent_mask(self, sz): mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) return mask @@ -76,15 +75,10 @@ def init_weights(self): self.decoder.bias.data.zero_() self.decoder.weight.data.uniform_(-initrange, initrange) - def forward(self, src): - if self.src_mask is None or self.src_mask.size(0) != src.size(0): - device = src.device - mask = self._generate_square_subsequent_mask(src.size(0)).to(device) - self.src_mask = mask - + def forward(self, src, src_mask): src = self.encoder(src) * math.sqrt(self.ninp) src = self.pos_encoder(src) - output = self.transformer_encoder(src, self.src_mask) + output = self.transformer_encoder(src, src_mask) output = self.decoder(output) return output @@ -253,10 +247,13 @@ def train(): total_loss = 0. 
start_time = time.time() ntokens = len(TEXT.vocab.stoi) + src_mask = model.generate_square_subsequent_mask(bptt).to(device) for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)): data, targets = get_batch(train_data, i) optimizer.zero_grad() - output = model(data) + if data.size(0) != bptt: + src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device) + output = model(data, src_mask) loss = criterion(output.view(-1, ntokens), targets) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) @@ -280,10 +277,13 @@ def evaluate(eval_model, data_source): eval_model.eval() # Turn on the evaluation mode total_loss = 0. ntokens = len(TEXT.vocab.stoi) + src_mask = eval_model.generate_square_subsequent_mask(bptt).to(device) with torch.no_grad(): for i in range(0, data_source.size(0) - 1, bptt): data, targets = get_batch(data_source, i) - output = eval_model(data) + if data.size(0) != bptt: + src_mask = eval_model.generate_square_subsequent_mask(data.size(0)).to(device) + output = eval_model(data, src_mask) output_flat = output.view(-1, ntokens) total_loss += len(data) * criterion(output_flat, targets).item() return total_loss / (len(data_source) - 1)