Skip to content

Commit f9d5be7

Browse files
authored
Merge pull request #649 from zhangguanheng66/transformer_tutorial
[WIP] Transformer tutorial
2 parents 98636c4 + c5d8f5c commit f9d5be7

File tree

3 files changed

+332
-0
lines changed

3 files changed

+332
-0
lines changed
65.2 KB
Loading
Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
"""
2+
nn.Transformer Tutorial
3+
============================
4+
5+
This is a tutorial to show how to implement `nn.Transformer <https://pytorch.org/docs/master/nn.html?highlight=nn%20transformer#torch.nn.Transformer>`__ module.
6+
7+
PyTorch 1.2 release includes a standard transformer module based on the
8+
paper `Attention is All You
9+
Need <https://arxiv.org/pdf/1706.03762.pdf>`__. The transformer model
10+
has been proved to be superior in quality for many sequence-to-sequence
11+
problems while being more parallelizable. The ``nn.Transformer`` module
12+
relies entirely on an attention mechanism (another module recently
13+
implemented as `nn.MultiheadAttention <https://pytorch.org/docs/master/nn.html?highlight=multiheadattention#torch.nn.MultiheadAttention>`__) to draw global dependencies
14+
between input and output. The ``nn.Transformer`` module is now highly
15+
modularized such that a single component (like `nn.TransformerEncoder <https://pytorch.org/docs/master/nn.html?highlight=nn%20transformerencoder#torch.nn.TransformerEncoder>`__
16+
in this tutorial) can be easily adapted/composed.
17+
18+
.. image:: ../_static/img/transformer_architecture.jpg
19+
20+
"""
21+
22+
######################################################################
23+
# Define the model
24+
# ----------------
25+
#
26+
27+
28+
######################################################################
29+
# In this tutorial, we train ``nn.TransformerEncoder`` model on a
30+
# language modeling task. The language modeling task is to assign a
31+
# probability for the likelihood of a given word (or a sequence of words)
32+
# to follow a sequence of words. A sequence of tokens are passed to the embedding
33+
# layer first, followed by a positional encoding layer to account for the order
34+
# of the word (see the next paragraph for more details). The
35+
# ``nn.TransformerEncoder`` consists of multiple layers of
36+
# `nn.TransformerEncoderLayer <https://pytorch.org/docs/master/nn.html?highlight=transformerencoderlayer#torch.nn.TransformerEncoderLayer>`__. Along with the input sequence, a square
37+
# attention mask is required because the self-attention layers in
38+
# ``nn.TransformerEncoder`` are only allowed to attend the earlier positions in
39+
# the sequence. For the language modeling task, any tokens on the future
40+
# positions should be masked. To have the actual words, the output
41+
# of ``nn.TransformerEncoder`` model is sent to the final Linear
42+
# layer, which is followed by a log-Softmax function.
43+
#
44+
45+
import math
46+
import torch
47+
import torch.nn as nn
48+
import torch.nn.functional as F
49+
50+
class TransformerModel(nn.Module):
51+
52+
def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
53+
super(TransformerModel, self).__init__()
54+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
55+
self.model_type = 'Transformer'
56+
self.src_mask = None
57+
self.pos_encoder = PositionalEncoding(ninp, dropout)
58+
encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
59+
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
60+
self.encoder = nn.Embedding(ntoken, ninp)
61+
self.ninp = ninp
62+
self.decoder = nn.Linear(ninp, ntoken)
63+
64+
self.init_weights()
65+
66+
def _generate_square_subsequent_mask(self, sz):
67+
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
68+
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
69+
return mask
70+
71+
def init_weights(self):
72+
initrange = 0.1
73+
self.encoder.weight.data.uniform_(-initrange, initrange)
74+
self.decoder.bias.data.zero_()
75+
self.decoder.weight.data.uniform_(-initrange, initrange)
76+
77+
def forward(self, src):
78+
if self.src_mask is None or self.src_mask.size(0) != len(src):
79+
device = src.device
80+
mask = self._generate_square_subsequent_mask(len(src)).to(device)
81+
self.src_mask = mask
82+
83+
src = self.encoder(src) * math.sqrt(self.ninp)
84+
src = self.pos_encoder(src)
85+
output = self.transformer_encoder(src, self.src_mask)
86+
output = self.decoder(output)
87+
return F.log_softmax(output, dim=-1)
88+
89+
90+
######################################################################
91+
# ``PositionalEncoding`` module injects some information about the
92+
# relative or absolute position of the tokens in the sequence. The
93+
# positional encodings have the same dimension as the embeddings so that
94+
# the two can be summed. Here, we use ``sine`` and ``cosine`` functions of
95+
# different frequencies.
96+
#
97+
98+
class PositionalEncoding(nn.Module):
99+
100+
def __init__(self, d_model, dropout=0.1, max_len=5000):
101+
super(PositionalEncoding, self).__init__()
102+
self.dropout = nn.Dropout(p=dropout)
103+
104+
pe = torch.zeros(max_len, d_model)
105+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
106+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
107+
pe[:, 0::2] = torch.sin(position * div_term)
108+
pe[:, 1::2] = torch.cos(position * div_term)
109+
pe = pe.unsqueeze(0).transpose(0, 1)
110+
self.register_buffer('pe', pe)
111+
112+
def forward(self, x):
113+
x = x + self.pe[:x.size(0), :]
114+
return self.dropout(x)
115+
116+
117+
######################################################################
118+
# Load and batch data
119+
# -------------------
120+
#
121+
122+
123+
######################################################################
124+
# The training process uses Wikitext-2 dataset from ``torchtext``. The
125+
# vocab object is built based on the train dataset and is used to numericalize
126+
# tokens into tensors. Starting from sequential data, the ``batchify()``
127+
# function arranges the dataset into columns. For instance, with the
128+
# alphabet as the sequence and a batch size of 4, we have the following
129+
# arrangement:
130+
#
131+
# ┌ A G M S ┐
132+
#
133+
# │ B H N T │
134+
#
135+
# │ C I O U |
136+
#
137+
# │ D J P V |
138+
#
139+
# │ E K Q W |
140+
#
141+
# └ F L R X ┘
142+
#
143+
# These columns are treated as independent by the model, which means that
144+
# the dependence of ``G`` and ``F`` can not be learned, but allows more
145+
# efficient batch processing.
146+
#
147+
148+
import torchtext
149+
from torchtext.data.utils import get_tokenizer
150+
TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"),
151+
init_token='<sos>',
152+
eos_token='<eos>',
153+
lower=True)
154+
train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)
155+
TEXT.build_vocab(train_txt)
156+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
157+
158+
def batchify(data, bsz):
159+
data = TEXT.numericalize([data.examples[0].text])
160+
# Divide the dataset into bsz parts.
161+
nbatch = data.size(0) // bsz
162+
# Trim off any extra elements that wouldn't cleanly fit (remainders).
163+
data = data.narrow(0, 0, nbatch * bsz)
164+
# Evenly divide the data across the bsz batches.
165+
data = data.view(bsz, -1).t().contiguous()
166+
return data.to(device)
167+
168+
batch_size = 20
169+
eval_batch_size = 10
170+
train_data = batchify(train_txt, batch_size)
171+
val_data = batchify(val_txt, eval_batch_size)
172+
test_data = batchify(test_txt, eval_batch_size)
173+
174+
175+
######################################################################
176+
# Functions to generate input and target sequence
177+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
178+
#
179+
180+
181+
######################################################################
182+
# ``get_batch()`` function generates the input and target sequence for
183+
# the transformer model. It subdivides the source data into chunks of
184+
# length ``bptt``. For the language modeling task, the model needs the
185+
# following words as ``Target``. For example, with a ``bptt`` value of 2,
186+
# we’d get the following two Variables for ``i`` = 0:
187+
#
188+
# Input | Target
189+
#
190+
# ┌ A G M S ┐ ┌ B H N T ┐
191+
#
192+
# └ B H N T ┘ └ C I O U ┘
193+
#
194+
# It should be noted that the chunks are along dimension 0, consistent
195+
# with the ``S`` dimension in the Transformer model. The batch dimension
196+
# ``N`` is along dimension 1.
197+
#
198+
199+
bptt = 35
200+
def get_batch(source, i):
201+
seq_len = min(bptt, len(source) - 1 - i)
202+
data = source[i:i+seq_len]
203+
target = source[i+1:i+1+seq_len].view(-1)
204+
return data, target
205+
206+
207+
######################################################################
208+
# Initiate an instance
209+
# --------------------
210+
#
211+
212+
213+
######################################################################
214+
# The model is set up with the hyperparameter below. The vocab size is
215+
# equal to the length of the vocab object.
216+
#
217+
218+
ntokens = len(TEXT.vocab.stoi) # the size of vocabulary
219+
emsize = 200 # embedding dimension
220+
nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
221+
nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
222+
nhead = 2 # the number of heads in the multiheadattention models
223+
dropout = 0.2 # the dropout value
224+
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
225+
226+
227+
######################################################################
228+
# Run the model
229+
# -------------
230+
#
231+
232+
233+
######################################################################
234+
# `CrossEntropyLoss <https://pytorch.org/docs/master/nn.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss>`__
235+
# is applied to track the loss and
236+
# `SGD <https://pytorch.org/docs/master/optim.html?highlight=sgd#torch.optim.SGD>`__
237+
# implements stochastic gradient descent method as the optimizer. The initial
238+
# learning rate is set to 5.0. `StepLR <https://pytorch.org/docs/master/optim.html?highlight=steplr#torch.optim.lr_scheduler.StepLR>`__ is
239+
# applied to adjust the learn rate through epochs. During the
240+
# training, we use
241+
# `nn.utils.clip_grad_norm\_ <https://pytorch.org/docs/master/nn.html?highlight=nn%20utils%20clip_grad_norm#torch.nn.utils.clip_grad_norm_>`__
242+
# function to scale all the gradient together to prevent exploding.
243+
#
244+
245+
criterion = nn.CrossEntropyLoss()
246+
lr = 5.0 # learning rate
247+
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
248+
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
249+
250+
import time
251+
def train():
252+
model.train() # Turn on the train mode
253+
total_loss = 0.
254+
start_time = time.time()
255+
ntokens = len(TEXT.vocab.stoi)
256+
for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
257+
data, targets = get_batch(train_data, i)
258+
optimizer.zero_grad()
259+
output = model(data)
260+
loss = criterion(output.view(-1, ntokens), targets)
261+
loss.backward()
262+
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
263+
optimizer.step()
264+
265+
total_loss += loss.item()
266+
log_interval = 200
267+
if batch % log_interval == 0 and batch > 0:
268+
cur_loss = total_loss / log_interval
269+
elapsed = time.time() - start_time
270+
print('| epoch {:3d} | {:5d}/{:5d} batches | '
271+
'lr {:02.2f} | ms/batch {:5.2f} | '
272+
'loss {:5.2f} | ppl {:8.2f}'.format(
273+
epoch, batch, len(train_data) // bptt, scheduler.get_lr()[0],
274+
elapsed * 1000 / log_interval,
275+
cur_loss, math.exp(cur_loss)))
276+
total_loss = 0
277+
start_time = time.time()
278+
279+
def evaluate(eval_model, data_source):
280+
eval_model.eval() # Turn on the evaluation mode
281+
total_loss = 0.
282+
ntokens = len(TEXT.vocab.stoi)
283+
with torch.no_grad():
284+
for i in range(0, data_source.size(0) - 1, bptt):
285+
data, targets = get_batch(data_source, i)
286+
output = eval_model(data)
287+
output_flat = output.view(-1, ntokens)
288+
total_loss += len(data) * criterion(output_flat, targets).item()
289+
return total_loss / (len(data_source) - 1)
290+
291+
######################################################################
292+
# Loop over epochs. Save the model if the validation loss is the best
293+
# we've seen so far. Adjust the learning rate after each epoch.
294+
295+
best_val_loss = float("inf")
296+
epochs = 3 # The number of epochs
297+
best_model = None
298+
299+
for epoch in range(1, epochs + 1):
300+
epoch_start_time = time.time()
301+
train()
302+
val_loss = evaluate(model, val_data)
303+
print('-' * 89)
304+
print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
305+
'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
306+
val_loss, math.exp(val_loss)))
307+
print('-' * 89)
308+
309+
if val_loss < best_val_loss:
310+
best_val_loss = val_loss
311+
best_model = model
312+
313+
scheduler.step()
314+
315+
316+
######################################################################
317+
# Evaluate the model with the test dataset
318+
# -------------------------------------
319+
#
320+
# Apply the best model to check the result with the test dataset.
321+
322+
test_loss = evaluate(best_model, test_data)
323+
print('=' * 89)
324+
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
325+
test_loss, math.exp(test_loss)))
326+
print('=' * 89)

index.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,11 @@ Text
149149
:figure: /_static/img/text_sentiment_ngrams_model.png
150150
:description: :doc:`/beginner/text_sentiment_ngrams_tutorial`
151151

152+
.. customgalleryitem::
153+
:tooltip: Transformer Transformer Tutorial
154+
:figure: /_static/img/transformer_architecture.jpg
155+
:description: :doc:`/beginner/transformer_tutorial`
156+
152157
.. raw:: html
153158

154159
<div style='clear:both'></div>
@@ -315,6 +320,7 @@ PyTorch in Other Languages
315320
beginner/deep_learning_nlp_tutorial
316321
intermediate/seq2seq_translation_tutorial
317322
beginner/text_sentiment_ngrams_tutorial
323+
beginner/transformer_tutorial
318324

319325
.. toctree::
320326
:maxdepth: 2

0 commit comments

Comments
 (0)