diff --git a/beginner_source/transformer_tutorial.py b/beginner_source/transformer_tutorial.py
index ba4be379072..680e9dc4b62 100644
--- a/beginner_source/transformer_tutorial.py
+++ b/beginner_source/transformer_tutorial.py
@@ -45,15 +45,16 @@
 #
 
 import math
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn import TransformerEncoder, TransformerEncoderLayer
 
 class TransformerModel(nn.Module):
 
     def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
         super(TransformerModel, self).__init__()
-        from torch.nn import TransformerEncoder, TransformerEncoderLayer
         self.model_type = 'Transformer'
         self.pos_encoder = PositionalEncoding(ninp, dropout)
         encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
@@ -251,12 +252,13 @@ def get_batch(source, i):
 # function to scale all the gradient together to prevent exploding.
 #
 
+import time
+
 criterion = nn.CrossEntropyLoss()
 lr = 5.0 # learning rate
 optimizer = torch.optim.SGD(model.parameters(), lr=lr)
 scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
 
-import time
 def train():
     model.train() # Turn on the train mode
     total_loss = 0.
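
Note on the second hunk: the context comment about scaling the gradients to prevent exploding refers to gradient clipping inside train(), which is not visible in this patch. Below is a minimal sketch of how such a clipping step is typically wired into one optimization iteration; the names model, criterion, optimizer, ntokens, data, and targets, and the 0.5 threshold, are assumptions taken from the surrounding tutorial code, not part of this diff.

# Sketch only (not part of this patch): one training step with gradient clipping.
# Assumes `model`, `criterion`, `optimizer`, `ntokens`, `data`, and `targets`
# are defined as in the surrounding tutorial code.
optimizer.zero_grad()
output = model(data)                                      # forward pass
loss = criterion(output.view(-1, ntokens), targets)       # token-level cross entropy
loss.backward()                                           # compute gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)   # rescale all gradients together
optimizer.step()                                          # apply the update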