======================================================

This tutorial shows:
- - How to train a translation model from scratch using Transformer.
- - Use tochtext library to access `Multi30k <http://www.statmt.org/wmt16/multimodal-task.html#task1>`__ dataset to train a German to English translation model.
+ - How to train a translation model from scratch using Transformer.
+ - Use torchtext library to access `Multi30k <http://www.statmt.org/wmt16/multimodal-task.html#task1>`__ dataset to train a German to English translation model.
"""

#
# `torchtext library <https://pytorch.org/text/stable/>`__ has utilities for creating datasets that can be easily
# iterated through for the purposes of creating a language translation
- # model. In this example, we show how to use torchtext's inbuilt datasets,
+ # model. In this example, we show how to use torchtext's inbuilt datasets,
# tokenize a raw text sentence, build vocabulary, and numericalize tokens into tensor. We will use
# `Multi30k dataset from torchtext library <https://pytorch.org/text/stable/datasets.html#multi30k>`__
- # that yields a pair of source-target raw sentences.
+ # that yields a pair of source-target raw sentences.
#
- # To access torchtext datasets, please install torchdata following instructions at https://github.com/pytorch/data.
+ # To access torchtext datasets, please install torchdata following instructions at https://github.com/pytorch/data.
#

from torchtext.data.utils import get_tokenizer
@@ -61,18 +61,18 @@ def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
-
+
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
-     # Training data Iterator
+     # Training data Iterator
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
-     # Create torchtext's Vocab object
+     # Create torchtext's Vocab object
    vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter, ln),
                                                    min_freq=1,
                                                    specials=special_symbols,
                                                    special_first=True)

- # Set UNK_IDX as the default index. This index is returned when the token is not found.
- # If not set, it throws RuntimeError when the queried token is not found in the Vocabulary.
+ # Set UNK_IDX as the default index. This index is returned when the token is not found.
+ # If not set, it throws RuntimeError when the queried token is not found in the Vocabulary.

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    vocab_transform[ln].set_default_index(UNK_IDX)

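Not part of this diff, but as a quick reference for what the hunk above sets up: a minimal, self-contained sketch of `build_vocab_from_iterator` plus `set_default_index`. It substitutes a toy in-memory corpus and the `basic_english` tokenizer for Multi30k and the spaCy tokenizers the tutorial configures, so the exact indices will differ.

# Minimal sketch (toy data, basic_english tokenizer assumed) of the vocabulary step.
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

toy_corpus = ["a dog runs", "a cat sleeps"]          # stand-in for Multi30k sentences
tokenizer = get_tokenizer('basic_english')

def toy_yield_tokens(sentences):
    for s in sentences:
        yield tokenizer(s)

vocab = build_vocab_from_iterator(toy_yield_tokens(toy_corpus),
                                  min_freq=1,
                                  specials=['<unk>', '<pad>', '<bos>', '<eos>'],
                                  special_first=True)
vocab.set_default_index(0)            # UNK_IDX

print(vocab(['a', 'dog']))            # known tokens -> their indices
print(vocab(['zebra']))               # unseen token -> [0], thanks to the default index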
@@ -82,14 +82,14 @@ def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
#
# Transformer is a Seq2Seq model introduced in `“Attention is all you
# need” <https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>`__
- # paper for solving machine translation tasks.
+ # paper for solving machine translation tasks.
# Below, we will create a Seq2Seq network that uses Transformer. The network
# consists of three parts. First part is the embedding layer. This layer converts tensor of input indices
# into corresponding tensor of input embeddings. These embedding are further augmented with positional
- # encodings to provide position information of input tokens to the model. The second part is the
- # actual `Transformer <https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`__ model.
- # Finally, the output of Transformer model is passed through linear layer
- # that give un-normalized probabilities for each token in the target language.
+ # encodings to provide position information of input tokens to the model. The second part is the
+ # actual `Transformer <https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`__ model.
+ # Finally, the output of the Transformer model is passed through linear layer
+ # that gives un-normalized probabilities for each token in the target language.
#

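The token-embedding and positional-encoding modules described above sit outside the hunks shown in this diff. For orientation, here is one standard sinusoidal positional-encoding module in the spirit of the paper; it is a sketch and not necessarily line-for-line identical to the class in this file.

import math
import torch
import torch.nn as nn
from torch import Tensor

class SinusoidalPositionalEncoding(nn.Module):
    """Adds fixed sine/cosine position signals to token embeddings (sketch)."""
    def __init__(self, emb_size: int, dropout: float = 0.1, maxlen: int = 5000):
        super().__init__()
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)      # (maxlen, 1, emb_size)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        # token_embedding: (seq_len, batch, emb_size), matching the layout used in this tutorial
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])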
@@ -130,7 +130,7 @@ def __init__(self, vocab_size: int, emb_size):
    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

- # Seq2Seq Network
+ # Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
@@ -164,7 +164,7 @@ def forward(self,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
-         outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
+         outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

@@ -179,9 +179,9 @@ def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):


######################################################################
- # During training, we need a subsequent word mask that will prevent model to look into
+ # During training, we need a subsequent word mask that will prevent the model from looking into
# the future words when making predictions. We will also need masks to hide
- # source and target padding tokens. Below, let's define a function that will take care of both.
+ # source and target padding tokens. Below, let's define a function that will take care of both.
#

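The bodies of the masking helpers are not visible in this diff; the sketch below shows a common way to build them for `nn.Transformer` with the (seq_len, batch) layout used here: a float `-inf` upper-triangular mask for the causal part and boolean padding masks. The `_sketch` names and the `PAD_IDX` value are assumptions for illustration.

import torch

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
PAD_IDX = 1  # assumption: matches the <pad> position in special_symbols above

def square_subsequent_mask_sketch(sz: int):
    # Upper-triangular -inf mask: position i may only attend to positions <= i.
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    return mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))

def create_mask_sketch(src, tgt):
    src_seq_len, tgt_seq_len = src.shape[0], tgt.shape[0]
    tgt_mask = square_subsequent_mask_sketch(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)
    # Padding masks are (batch, seq_len): True wherever the token is <pad>.
    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask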
@@ -204,7 +204,7 @@ def create_mask(src, tgt):


######################################################################
- # Let's now define the parameters of our model and instantiate the same. Below, we also
+ # Let's now define the parameters of our model and instantiate the same. Below, we also
# define our loss function which is the cross-entropy loss and the optmizer used for training.
#
torch.manual_seed(0)
@@ -218,7 +218,7 @@ def create_mask(src, tgt):
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

- transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
+ transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

for p in transformer.parameters():
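The `for p in transformer.parameters():` context line above is the start of the parameter-initialization loop, and the loss/optimizer definitions mentioned in the comment sit just past this hunk. A completion consistent with that description is sketched below; the Xavier-uniform init, `ignore_index=PAD_IDX`, the `DEVICE` move, and the Adam hyperparameters are assumptions rather than lines taken from this diff.

# Continuation of the initialization loop shown above (sketch); relies on the file's imports.
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

# Cross-entropy over target-vocabulary logits; <pad> positions contribute no loss (assumed ignore_index).
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# Adam with beta/eps settings commonly used for Transformer training (assumed values).
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)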
@@ -234,11 +234,11 @@ def create_mask(src, tgt):
######################################################################
# Collation
# ---------
- #
- # As seen in the ``Data Sourcing and Processing`` section, our data iterator yields a pair of raw strings.
- # We need to convert these string pairs into the batched tensors that can be processed by our ``Seq2Seq`` network
- # defined previously. Below we define our collate function that convert batch of raw strings into batch tensors that
- # can be fed directly into our model.
+ #
+ # As seen in the ``Data Sourcing and Processing`` section, our data iterator yields a pair of raw strings.
+ # We need to convert these string pairs into the batched tensors that can be processed by our ``Seq2Seq`` network
+ # defined previously. Below we define our collate function that converts a batch of raw strings into batch tensors that
+ # can be fed directly into our model.
#

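The `def func(txt_input):` in the next hunk header belongs to a small composition helper that chains tokenization, numericalization, and the BOS/EOS tensor conversion into a single per-language `text_transform`. That wiring is not otherwise visible in this diff, so the sketch below is an assumption about its shape; `token_transform` is taken to be the dict of tokenizers set up earlier in the file, and `collate_fn` (partly visible further down) applies `text_transform` to each raw pair before padding with `pad_sequence`.

# Sketch: chain raw string -> token list -> index list -> tensor, one callable per language.
def sequential_transforms_sketch(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms_sketch(token_transform[ln],  # tokenization
                                                      vocab_transform[ln],  # numericalization
                                                      tensor_transform)     # add BOS/EOS, build tensor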
@@ -254,8 +254,8 @@ def func(txt_input):

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
-     return torch.cat((torch.tensor([BOS_IDX]),
-                       torch.tensor(token_ids),
+     return torch.cat((torch.tensor([BOS_IDX]),
+                       torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

# src and tgt language text transforms to convert raw strings into tensors indices
@@ -276,9 +276,9 @@ def collate_fn(batch):
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch
-
+
######################################################################
- # Let's define training and evaluation loop that will be called for each
+ # Let's define training and evaluation loop that will be called for each
# epoch.
#

@@ -289,7 +289,7 @@ def train_epoch(model, optimizer):
    losses = 0
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
-
+
    for src, tgt in train_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)
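The rest of `train_epoch` falls outside this hunk. Judging from the symmetric `evaluate` code in the next hunk, the per-batch training step looks roughly like the sketch below (teacher forcing, mask creation, loss on the shifted target, backprop); treat it as an outline rather than the file's verbatim code.

        # Continuation of the loop over train_dataloader shown above (sketch):
        tgt_input = tgt[:-1, :]                      # decoder input: target shifted right
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()
        tgt_out = tgt[1:, :]                         # prediction target: shifted left
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()
        optimizer.step()
        losses += loss.item()

    return losses / len(list(train_dataloader))      # average loss per batch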
@@ -328,7 +328,7 @@ def evaluate(model):
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
-
+
        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()
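The `print` touched in the next hunk lives in the top-level epoch driver, which this diff does not otherwise show; a driver consistent with that line would look roughly like this (`NUM_EPOCHS` and the use of `timeit.default_timer` are assumed).

from timeit import default_timer as timer

NUM_EPOCHS = 18  # assumed value

for epoch in range(1, NUM_EPOCHS + 1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer)
    end_time = timer()
    val_loss = evaluate(transformer)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "
           f"Epoch time = {(end_time - start_time):.3f}s"))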
@@ -350,7 +350,7 @@ def evaluate(model):
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, " f"Epoch time = {(end_time - start_time):.3f}s"))


- # function to generate output sequence using greedy algorithm
+ # function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)
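`greedy_decode` is cut off at the end of this hunk. The usual continuation encodes the source once and then grows the output one argmax token at a time until `<eos>` or `max_len`, as sketched below; it assumes the model's `encode`/`decode`/`generator` methods shown earlier and a causal-mask helper named `generate_square_subsequent_mask`, so take it as an outline rather than the exact remaining lines.

    # Continuation of greedy_decode (sketch):
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len - 1):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        # Append the chosen token; stop once <eos> is produced.
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys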