Skip to content

Commit 3e6d1f4

Browse files
Format imports according to PEP8 (#1014)
Co-authored-by: holly1238 <77758406+holly1238@users.noreply.github.com>
1 parent a075e3d commit 3e6d1f4

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

beginner_source/transformer_tutorial.py

Lines changed: 4 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -45,15 +45,16 @@
45 45
#
46 46

47 47
import math
48+
48 49
import torch
49 50
import torch.nn as nn
50 51
import torch.nn.functional as F
52+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
51 53

52 54
class TransformerModel(nn.Module):
53 55

54 56
def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
55 57
super(TransformerModel, self).__init__()
56-
from torch.nn import TransformerEncoder, TransformerEncoderLayer
57 58
self.model_type = 'Transformer'
58 59
self.pos_encoder = PositionalEncoding(ninp, dropout)
59 60
encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
@@ -251,12 +252,13 @@ def get_batch(source, i):
251 252
# function to scale all the gradient together to prevent exploding.
252 253
#
253 254

255+
import time
256+
254 257
criterion = nn.CrossEntropyLoss()
255 258
lr = 5.0 # learning rate
256 259
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
257 260
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
258 261

259-
import time
260 262
def train():
261 263
model.train() # Turn on the train mode
262 264
total_loss = 0.

0 commit comments

Comments (0)