Skip to content

Update transformer tutorial #1197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 27, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions beginner_source/transformer_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
super(TransformerModel, self).__init__()
from torch.nn import TransformerEncoder, TransformerEncoderLayer
self.model_type = 'Transformer'
self.src_mask = None
self.pos_encoder = PositionalEncoding(ninp, dropout)
encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
Expand All @@ -65,7 +64,7 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):

self.init_weights()

def _generate_square_subsequent_mask(self, sz):
def generate_square_subsequent_mask(self, sz):
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
Expand All @@ -76,15 +75,10 @@ def init_weights(self):
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-initrange, initrange)

def forward(self, src):
if self.src_mask is None or self.src_mask.size(0) != src.size(0):
device = src.device
mask = self._generate_square_subsequent_mask(src.size(0)).to(device)
self.src_mask = mask

def forward(self, src, src_mask):
src = self.encoder(src) * math.sqrt(self.ninp)
src = self.pos_encoder(src)
output = self.transformer_encoder(src, self.src_mask)
output = self.transformer_encoder(src, src_mask)
output = self.decoder(output)
return output

Expand Down Expand Up @@ -253,10 +247,13 @@ def train():
total_loss = 0.
start_time = time.time()
ntokens = len(TEXT.vocab.stoi)
src_mask = model.generate_square_subsequent_mask(bptt).to(device)
for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
data, targets = get_batch(train_data, i)
optimizer.zero_grad()
output = model(data)
if data.size(0) != bptt:
src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
output = model(data, src_mask)
loss = criterion(output.view(-1, ntokens), targets)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
Expand All @@ -280,10 +277,13 @@ def evaluate(eval_model, data_source):
eval_model.eval() # Turn on the evaluation mode
total_loss = 0.
ntokens = len(TEXT.vocab.stoi)
src_mask = model.generate_square_subsequent_mask(bptt).to(device)
with torch.no_grad():
for i in range(0, data_source.size(0) - 1, bptt):
data, targets = get_batch(data_source, i)
output = eval_model(data)
if data.size(0) != bptt:
src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
output = eval_model(data, src_mask)
output_flat = output.view(-1, ntokens)
total_loss += len(data) * criterion(output_flat, targets).item()
return total_loss / (len(data_source) - 1)
Expand Down