Skip to content

Commit a7be5f0

Browse files
Format imports according to PEP 8
1 parent f7d7360 commit a7be5f0

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

beginner_source/transformer_tutorial.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,16 @@
4545
#
4646

4747
import math
48+
4849
import torch
4950
import torch.nn as nn
5051
import torch.nn.functional as F
52+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
5153

5254
class TransformerModel(nn.Module):
5355

5456
def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
5557
super(TransformerModel, self).__init__()
56-
from torch.nn import TransformerEncoder, TransformerEncoderLayer
5758
self.model_type = 'Transformer'
5859
self.src_mask = None
5960
self.pos_encoder = PositionalEncoding(ninp, dropout)
@@ -151,6 +152,7 @@ def forward(self, x):
151152

152153
import torchtext
153154
from torchtext.data.utils import get_tokenizer
155+
154156
TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"),
155157
init_token='<sos>',
156158
eos_token='<eos>',
@@ -242,12 +244,13 @@ def get_batch(source, i):
242244
# function to scale all the gradient together to prevent exploding.
243245
#
244246

247+
import time
248+
245249
criterion = nn.CrossEntropyLoss()
246250
lr = 5.0 # learning rate
247251
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
248252
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
249253

250-
import time
251254
def train():
252255
model.train() # Turn on the train mode
253256
total_loss = 0.

0 commit comments

Comments
 (0)