
Commit 0050004

SleepyDeveloper and Svetlana Karslioglu authored
Update translation tutorial. (#2141)
* Update translation tutorial.
* spelling & grammar updates

Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
1 parent 4a06e05 commit 0050004

File tree

1 file changed (+33 / -33 lines)

beginner_source/translation_transformer.py

Lines changed: 33 additions & 33 deletions
@@ -3,8 +3,8 @@
 ======================================================
 
 This tutorial shows:
-    - How to train a translation model from scratch using Transformer.
-    - Use tochtext library to access `Multi30k <http://www.statmt.org/wmt16/multimodal-task.html#task1>`__ dataset to train a German to English translation model.
+    - How to train a translation model from scratch using Transformer.
+    - Use torchtext library to access `Multi30k <http://www.statmt.org/wmt16/multimodal-task.html#task1>`__ dataset to train a German to English translation model.
 """
@@ -14,12 +14,12 @@
 #
 # `torchtext library <https://pytorch.org/text/stable/>`__ has utilities for creating datasets that can be easily
 # iterated through for the purposes of creating a language translation
-# model. In this example, we show how to use torchtext's inbuilt datasets,
+# model. In this example, we show how to use torchtext's inbuilt datasets,
 # tokenize a raw text sentence, build vocabulary, and numericalize tokens into tensor. We will use
 # `Multi30k dataset from torchtext library <https://pytorch.org/text/stable/datasets.html#multi30k>`__
-# that yields a pair of source-target raw sentences.
+# that yields a pair of source-target raw sentences.
 #
-# To access torchtext datasets, please install torchdata following instructions at https://github.com/pytorch/data.
+# To access torchtext datasets, please install torchdata following instructions at https://github.com/pytorch/data.
 #
 
 from torchtext.data.utils import get_tokenizer
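
As an aside (not part of this commit), a minimal sketch of what the data-sourcing step described above amounts to, using the same Multi30k call that appears later in the diff; the printed values are simply whatever the dataset yields first:

from torchtext.datasets import Multi30k

# The Multi30k iterator yields raw (German, English) sentence pairs as plain strings.
train_iter = Multi30k(split='train', language_pair=('de', 'en'))
src_raw, tgt_raw = next(iter(train_iter))
print(src_raw)   # a raw German sentence
print(tgt_raw)   # its English translation
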
@@ -61,18 +61,18 @@ def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
 UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
 # Make sure the tokens are in order of their indices to properly insert them in vocab
 special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
-
+
 for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
-    # Training data Iterator
+    # Training data Iterator
     train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
-    # Create torchtext's Vocab object
+    # Create torchtext's Vocab object
     vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter, ln),
                                                     min_freq=1,
                                                     specials=special_symbols,
                                                     special_first=True)
 
-# Set UNK_IDX as the default index. This index is returned when the token is not found.
-# If not set, it throws RuntimeError when the queried token is not found in the Vocabulary.
+# Set UNK_IDX as the default index. This index is returned when the token is not found.
+# If not set, it throws RuntimeError when the queried token is not found in the Vocabulary.
 for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
     vocab_transform[ln].set_default_index(UNK_IDX)
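
A quick illustration (not part of the diff) of why set_default_index matters for the vocabularies built above; the token strings are made up:

# A known token maps to its vocabulary index; an out-of-vocabulary token falls back to
# UNK_IDX because of set_default_index(UNK_IDX). Without it, the lookup raises a RuntimeError.
print(vocab_transform[SRC_LANGUAGE](['Ein', 'Mann', 'xyzzy-not-in-vocab']))
# e.g. [index of 'Ein', index of 'Mann', 0]   (0 == UNK_IDX)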

@@ -82,14 +82,14 @@ def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
 #
 # Transformer is a Seq2Seq model introduced in `“Attention is all you
 # need” <https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>`__
-# paper for solving machine translation tasks.
+# paper for solving machine translation tasks.
 # Below, we will create a Seq2Seq network that uses Transformer. The network
 # consists of three parts. First part is the embedding layer. This layer converts tensor of input indices
 # into corresponding tensor of input embeddings. These embedding are further augmented with positional
-# encodings to provide position information of input tokens to the model. The second part is the
-# actual `Transformer <https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`__ model.
-# Finally, the output of Transformer model is passed through linear layer
-# that give un-normalized probabilities for each token in the target language.
+# encodings to provide position information of input tokens to the model. The second part is the
+# actual `Transformer <https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`__ model.
+# Finally, the output of the Transformer model is passed through linear layer
+# that gives un-normalized probabilities for each token in the target language.
 #
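
The positional encodings mentioned above are typically the sinusoidal ones from the "Attention is all you need" paper; a self-contained sketch of that computation (illustrative, not the tutorial's exact PositionalEncoding module):

import math
import torch

def sinusoidal_positional_encoding(max_len: int, emb_size: int) -> torch.Tensor:
    # One row per position; even columns use sine, odd columns use cosine, with
    # wavelengths growing geometrically across the embedding dimensions.
    pos = torch.arange(max_len).unsqueeze(1)                                     # (max_len, 1)
    den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)  # (emb_size / 2,)
    pe = torch.zeros(max_len, emb_size)
    pe[:, 0::2] = torch.sin(pos * den)
    pe[:, 1::2] = torch.cos(pos * den)
    return pe   # added to the scaled token embeddings before the Transformer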

@@ -130,7 +130,7 @@ def __init__(self, vocab_size: int, emb_size):
     def forward(self, tokens: Tensor):
         return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
 
-# Seq2Seq Network
+# Seq2Seq Network
 class Seq2SeqTransformer(nn.Module):
     def __init__(self,
                  num_encoder_layers: int,
@@ -164,7 +164,7 @@ def forward(self,
                 memory_key_padding_mask: Tensor):
         src_emb = self.positional_encoding(self.src_tok_emb(src))
         tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
-        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
+        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                 src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
         return self.generator(outs)
@@ -179,9 +179,9 @@ def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
 
 
 ######################################################################
-# During training, we need a subsequent word mask that will prevent model to look into
+# During training, we need a subsequent word mask that will prevent the model from looking into
 # the future words when making predictions. We will also need masks to hide
-# source and target padding tokens. Below, let's define a function that will take care of both.
+# source and target padding tokens. Below, let's define a function that will take care of both.
 #
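
To make the mask concrete, a tiny standalone example (illustrative; the tutorial's create_mask builds the real masks) of a 4-position subsequent-word mask:

import torch

sz = 4
upper = torch.triu(torch.ones(sz, sz, dtype=torch.bool), diagonal=1)
tgt_mask = torch.zeros(sz, sz).masked_fill(upper, float('-inf'))
print(tgt_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
# Row i is used when predicting position i: attention to future positions is blocked with -inf.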

@@ -204,7 +204,7 @@ def create_mask(src, tgt):
 
 
 ######################################################################
-# Let's now define the parameters of our model and instantiate the same. Below, we also
+# Let's now define the parameters of our model and instantiate the same. Below, we also
 # define our loss function which is the cross-entropy loss and the optmizer used for training.
 #
 torch.manual_seed(0)
@@ -218,7 +218,7 @@ def create_mask(src, tgt):
 NUM_ENCODER_LAYERS = 3
 NUM_DECODER_LAYERS = 3
 
-transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
+transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                  NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
 
 for p in transformer.parameters():
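
The cross-entropy loss and optimizer referenced in the comment above are set up roughly as follows; the Adam hyperparameters shown here are plausible choices for Transformer training, not necessarily the tutorial's exact values:

import torch

# Cross-entropy over predicted target tokens, ignoring positions that are only padding.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# Adam with a small learning rate.
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
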
@@ -234,11 +234,11 @@ def create_mask(src, tgt):
 ######################################################################
 # Collation
 # ---------
-#
-# As seen in the ``Data Sourcing and Processing`` section, our data iterator yields a pair of raw strings.
-# We need to convert these string pairs into the batched tensors that can be processed by our ``Seq2Seq`` network
-# defined previously. Below we define our collate function that convert batch of raw strings into batch tensors that
-# can be fed directly into our model.
+#
+# As seen in the ``Data Sourcing and Processing`` section, our data iterator yields a pair of raw strings.
+# We need to convert these string pairs into the batched tensors that can be processed by our ``Seq2Seq`` network
+# defined previously. Below we define our collate function that converts a batch of raw strings into batch tensors that
+# can be fed directly into our model.
 #
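
The batching described above rests on pad_sequence; a small standalone example of its behavior (values are made up, PAD_IDX = 1 as defined earlier):

import torch
from torch.nn.utils.rnn import pad_sequence

# Two <bos> ... <eos> sequences of different lengths.
batch = [torch.tensor([2, 11, 12, 3]), torch.tensor([2, 21, 3])]
print(pad_sequence(batch, padding_value=1))
# tensor([[ 2,  2],
#         [11, 21],
#         [12,  3],
#         [ 3,  1]])
# The result is (max_len, batch_size); the shorter sequence is padded with PAD_IDX.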

@@ -254,8 +254,8 @@ def func(txt_input):
 
 # function to add BOS/EOS and create tensor for input sequence indices
 def tensor_transform(token_ids: List[int]):
-    return torch.cat((torch.tensor([BOS_IDX]),
-                      torch.tensor(token_ids),
+    return torch.cat((torch.tensor([BOS_IDX]),
+                      torch.tensor(token_ids),
                       torch.tensor([EOS_IDX])))
 
 # src and tgt language text transforms to convert raw strings into tensors indices
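
For illustration, tensor_transform applied to made-up token indices (BOS_IDX = 2, EOS_IDX = 3 as defined earlier):

print(tensor_transform([21, 85, 257]))
# tensor([  2,  21,  85, 257,   3])
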
@@ -276,9 +276,9 @@ def collate_fn(batch):
     src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
     tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
     return src_batch, tgt_batch
-
+
 ######################################################################
-# Let's define training and evaluation loop that will be called for each
+# Let's define training and evaluation loop that will be called for each
 # epoch.
 #
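
For orientation, the driver that calls these two loops looks roughly like this; train_epoch and evaluate are the functions this diff touches below, while NUM_EPOCHS and the timeit import are assumptions inferred from the print statement visible in the last hunk:

from timeit import default_timer as timer

NUM_EPOCHS = 18   # illustrative value

for epoch in range(1, NUM_EPOCHS + 1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer)
    end_time = timer()
    val_loss = evaluate(transformer)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))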

@@ -289,7 +289,7 @@ def train_epoch(model, optimizer):
     losses = 0
     train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
     train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
-
+
     for src, tgt in train_dataloader:
         src = src.to(DEVICE)
         tgt = tgt.to(DEVICE)
@@ -328,7 +328,7 @@ def evaluate(model):
         src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
 
         logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)
-
+
         tgt_out = tgt[1:, :]
         loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
         losses += loss.item()
@@ -350,7 +350,7 @@ def evaluate(model):
     print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))
 
 
-# function to generate output sequence using greedy algorithm
+# function to generate output sequence using greedy algorithm
 def greedy_decode(model, src, src_mask, max_len, start_symbol):
     src = src.to(DEVICE)
     src_mask = src_mask.to(DEVICE)
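
The hunk above shows only the start of greedy_decode; a rough sketch of how such a greedy loop usually proceeds, reusing the encode/decode/generator methods and the generate_square_subsequent_mask helper defined earlier in the tutorial (details may differ from the actual tutorial body):

import torch

def greedy_decode_sketch(model, src, src_mask, max_len, start_symbol):
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)
    memory = model.encode(src, src_mask)                  # run the encoder once
    ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=DEVICE)
    for _ in range(max_len - 1):
        tgt_mask = generate_square_subsequent_mask(ys.size(0)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)          # (cur_len, 1, emb_size)
        prob = model.generator(out[-1])                   # un-normalized scores for the next token
        next_word = prob.argmax(dim=1).item()
        ys = torch.cat([ys, torch.full((1, 1), next_word, dtype=torch.long, device=DEVICE)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys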
