Remove legacy torchtext code from translation tutorial #1250


Merged
merged 9 commits into from Dec 3, 2020
Changes from 4 commits
183 changes: 90 additions & 93 deletions beginner_source/torchtext_translation_tutorial.py
@@ -2,36 +2,25 @@
Language Translation with TorchText
===================================

This tutorial shows how to use several convenience classes of ``torchtext`` to preprocess
This tutorial shows how to use ``torchtext`` to preprocess
data from a well-known dataset containing sentences in both English and German and use it to
train a sequence-to-sequence model with attention that can translate German sentences
into English.

It is based off of
`this tutorial <https://github.com/bentrevett/pytorch-seq2seq/blob/master/3%20-%20Neural%20Machine%20Translation%20by%20Jointly%20Learning%20to%20Align%20and%20Translate.ipynb>`__
from PyTorch community member `Ben Trevett <https://github.com/bentrevett>`__
and was created by `Seth Weidman <https://github.com/SethHWeidman/>`__ with Ben's permission.
with Ben's permission. We update the tutorials by removing some legecy code.
Contributor:
nit: legacy instead of legecy


By the end of this tutorial, you will be able to:

- Preprocess sentences into a commonly-used format for NLP modeling using the following ``torchtext`` convenience classes:
- `TranslationDataset <https://torchtext.readthedocs.io/en/latest/datasets.html#torchtext.datasets.TranslationDataset>`__
- `Field <https://torchtext.readthedocs.io/en/latest/data.html#torchtext.data.Field>`__
- `BucketIterator <https://torchtext.readthedocs.io/en/latest/data.html#torchtext.data.BucketIterator>`__
By the end of this tutorial, you will be able to preprocess sentences into tensors for NLP modeling and use `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader>`__ for training and validating the model.
"""

######################################################################
# `Field` and `TranslationDataset`
# Data Processing
# ----------------
# ``torchtext`` has utilities for creating datasets that can be easily
# iterated through for the purposes of creating a language translation
# model. One key class is a
# `Field <https://github.com/pytorch/text/blob/master/torchtext/data/field.py#L64>`__,
# which specifies the way each sentence should be preprocessed, and another is the
# `TranslationDataset`; ``torchtext``
# has several such datasets; in this tutorial we'll use the
# `Multi30k dataset <https://github.com/multi30k/dataset>`__, which contains about
# 30,000 sentences (averaging about 13 words in length) in both English and German.
# model. In this example, we show how to tokenize a raw text sentence, build the vocabulary, and numericalize tokens into tensors.
#
# Note: the tokenization in this tutorial requires `Spacy <https://spacy.io>`__.
# We use Spacy because it provides strong support for tokenization in languages
@@ -48,79 +37,91 @@
#
# python -m spacy download en
# python -m spacy download de
#
# With Spacy installed, the following code will tokenize each of the sentences
# in the ``TranslationDataset`` based on the tokenizer defined in the ``Field``

from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator

SRC = Field(tokenize = "spacy",
            tokenizer_language="de",
            init_token = '<sos>',
            eos_token = '<eos>',
            lower = True)

TRG = Field(tokenize = "spacy",
            tokenizer_language="en",
            init_token = '<sos>',
            eos_token = '<eos>',
            lower = True)

train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'),
                                                    fields = (SRC, TRG))

######################################################################
# Now that we've defined ``train_data``, we can see an extremely useful
# feature of ``torchtext``'s ``Field``: the ``build_vocab`` method
# now allows us to create the vocabulary associated with each language

SRC.build_vocab(train_data, min_freq = 2)
TRG.build_vocab(train_data, min_freq = 2)

######################################################################
# Once these lines of code have been run, ``SRC.vocab.stoi`` will be a
# dictionary with the tokens in the vocabulary as keys and their
# corresponding indices as values; ``SRC.vocab.itos`` will be the same
# dictionary with the keys and values swapped. We won't make extensive
# use of this fact in this tutorial, but this will likely be useful in
# other NLP tasks you'll encounter.
import torchtext
import torch
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab
from torchtext.utils import download_from_url, extract_archive
import io

url_base = 'https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/'
train_urls = ('train.de.gz', 'train.en.gz')
val_urls = ('val.de.gz', 'val.en.gz')
test_urls = ('test_2016_flickr.de.gz', 'test_2016_flickr.en.gz')

train_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in train_urls]
val_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in val_urls]
test_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in test_urls]

de_tokenizer = get_tokenizer('spacy', language='de')
en_tokenizer = get_tokenizer('spacy', language='en')
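
######################################################################
# As a quick sanity check (the sample sentence below is our own, not
# from the dataset), the spacy-backed tokenizer splits raw text into a
# list of string tokens:

print(de_tokenizer("Zwei Männer stehen am Herd."))
# e.g. ['Zwei', 'Männer', 'stehen', 'am', 'Herd', '.']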

def build_vocab(filepath, tokenizer):
    counter = Counter()
    with io.open(filepath, encoding="utf8") as f:
        for string_ in f:
            counter.update(tokenizer(string_))
    return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])

de_vocab = build_vocab(train_filepaths[0], de_tokenizer)
en_vocab = build_vocab(train_filepaths[1], en_tokenizer)
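
######################################################################
# A brief look at what ``build_vocab`` produced: ``de_vocab.stoi`` maps
# tokens to indices and ``de_vocab.itos`` maps indices back to tokens,
# with the specials occupying the first slots in the order given above:

print(de_vocab.stoi['<unk>'], de_vocab.stoi['<pad>'],
      de_vocab.stoi['<bos>'], de_vocab.stoi['<eos>'])  # 0 1 2 3
print(de_vocab.itos[:6])  # specials first, then the most frequent tokens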

def data_process(filepaths):
    raw_de_iter = iter(io.open(filepaths[0], encoding="utf8"))
    raw_en_iter = iter(io.open(filepaths[1], encoding="utf8"))
    data = []
    for (raw_de, raw_en) in zip(raw_de_iter, raw_en_iter):
        de_tensor_ = torch.tensor([de_vocab[token] for token in de_tokenizer(raw_de)],
                                  dtype=torch.long)
        en_tensor_ = torch.tensor([en_vocab[token] for token in en_tokenizer(raw_en)],
                                  dtype=torch.long)
        data.append((de_tensor_, en_tensor_))
    return data

train_data = data_process(train_filepaths)
val_data = data_process(val_filepaths)
test_data = data_process(test_filepaths)
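
######################################################################
# Each element of ``train_data`` is now a pair of 1-D ``torch.long``
# tensors holding the token indices of one German/English sentence pair
# (``<bos>``/``<eos>`` are added later, in the batching step below):

de_sample, en_sample = train_data[0]
print(de_sample.shape, en_sample.shape)  # lengths differ pair by pair
print([de_vocab.itos[idx] for idx in de_sample])  # recover the German tokens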

######################################################################
# ``BucketIterator``
# ``DataLoader``
# ----------------
# The last ``torchtext`` specific feature we'll use is the ``BucketIterator``,
# which is easy to use since it takes a ``TranslationDataset`` as its
# The last ``torch``-specific feature we'll use is the ``DataLoader``,
# which is easy to use since it takes the data as its
# first argument. Specifically, as the docs say:
# Defines an iterator that batches examples of similar lengths together.
# Minimizes amount of padding needed while producing freshly shuffled
# batches for each new epoch. See pool for the bucketing procedure used.
# ``DataLoader`` combines a dataset and a sampler, and provides an iterable over the given dataset. The ``DataLoader`` supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.
#
# Please pay attention to ``collate_fn`` (optional), which merges a list of samples to form a mini-batch of Tensor(s). It is used for batched loading from a map-style dataset.
#

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 128
PAD_IDX = de_vocab['<pad>']
BOS_IDX = de_vocab['<bos>']
EOS_IDX = de_vocab['<eos>']

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def generate_batch(data_batch):
    de_batch, en_batch = [], []
    for (de_item, en_item) in data_batch:
        de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0))
        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
    de_batch = pad_sequence(de_batch, padding_value=PAD_IDX)
    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
    return de_batch, en_batch

train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
test_iter = DataLoader(test_data, batch_size=BATCH_SIZE,
                       shuffle=True, collate_fn=generate_batch)
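
######################################################################
# Pulling one batch from ``train_iter`` shows what ``generate_batch``
# returns: both tensors have shape (sequence length, batch size), padded
# with ``PAD_IDX`` to the longest sentence in the batch:

src_sample, trg_sample = next(iter(train_iter))
print(src_sample.shape, trg_sample.shape)  # e.g. torch.Size([29, 128]) ...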

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size = BATCH_SIZE,
    device = device)

######################################################################
# These iterators can be called just like ``DataLoader``s; below, in
# the ``train`` and ``evaluate`` functions, they are called simply with:
#
# ::
#
# for i, batch in enumerate(iterator):
#
# Each ``batch`` then has ``src`` and ``trg`` attributes:
#
# ::
#
# src = batch.src
# trg = batch.trg

######################################################################
# Defining our ``nn.Module`` and ``Optimizer``
@@ -329,8 +330,8 @@ def forward(self,
        return outputs


INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
INPUT_DIM = len(de_vocab)
OUTPUT_DIM = len(en_vocab)
# ENC_EMB_DIM = 256
# DEC_EMB_DIM = 256
# ENC_HID_DIM = 512
@@ -380,7 +381,7 @@ def count_parameters(model: nn.Module):
# particular, we have to tell the ``nn.CrossEntropyLoss`` function to
# ignore the indices where the target is simply padding.

PAD_IDX = TRG.vocab.stoi['<pad>']
PAD_IDX = en_vocab.stoi['<pad>']

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
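
######################################################################
# A toy illustration of ``ignore_index`` (the numbers here are made up,
# unrelated to the dataset): positions whose target equals ``PAD_IDX``
# are skipped, so the loss is averaged over real tokens only and padding
# does not dilute the gradient signal.

toy_logits = torch.randn(3, 10)              # 3 positions, 10-class toy vocab
toy_targets = torch.tensor([4, PAD_IDX, 7])  # middle position is padding
print(nn.CrossEntropyLoss(ignore_index=PAD_IDX)(toy_logits, toy_targets))
# same value as averaging the per-position losses of positions 0 and 2 only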

@@ -392,7 +393,7 @@ def count_parameters(model: nn.Module):


def train(model: nn.Module,
          iterator: BucketIterator,
          iterator: torch.utils.data.DataLoader,
          optimizer: optim.Optimizer,
          criterion: nn.Module,
          clip: float):
@@ -401,10 +402,8 @@ def train(model: nn.Module,

    epoch_loss = 0

    for _, batch in enumerate(iterator):

        src = batch.src
        trg = batch.trg
    for _, (src, trg) in enumerate(iterator):
        src, trg = src.to(device), trg.to(device)

        optimizer.zero_grad()

@@ -427,7 +426,7 @@


def evaluate(model: nn.Module,
             iterator: BucketIterator,
             iterator: torch.utils.data.DataLoader,
             criterion: nn.Module):

    model.eval()
@@ -436,10 +435,8 @@

    with torch.no_grad():

        for _, batch in enumerate(iterator):

            src = batch.src
            trg = batch.trg
        for _, (src, trg) in enumerate(iterator):
            src, trg = src.to(device), trg.to(device)

            output = model(src, trg, 0)  # turn off teacher forcing

@@ -470,8 +467,8 @@ def epoch_time(start_time: int,

    start_time = time.time()

    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)

    end_time = time.time()

@@ -481,7 +478,7 @@ def epoch_time(start_time: int,
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f}')

test_loss = evaluate(model, test_iterator, criterion)
test_loss = evaluate(model, test_iter, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')
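
######################################################################
# Finally, a rough greedy-decoding sketch. This helper is our own
# addition, not part of the tutorial's diff: it assumes the
# ``model(src, trg, 0)`` interface used in ``evaluate`` (with teacher
# forcing off, ``trg`` only supplies the ``<bos>`` start token and the
# maximum output length), and it relies on the German and English
# specials sharing indices, which holds here because both vocabs were
# built with the same specials list.

def translate_sentence(sentence: str, max_len: int = 50) -> str:
    model.eval()
    token_ids = [de_vocab[token] for token in de_tokenizer(sentence)]
    src = torch.tensor([BOS_IDX] + token_ids + [EOS_IDX],
                       dtype=torch.long).unsqueeze(1).to(device)
    # dummy target: <bos> then padding; only its length and first row matter
    trg = torch.full((max_len, 1), PAD_IDX, dtype=torch.long, device=device)
    trg[0, 0] = BOS_IDX
    with torch.no_grad():
        output = model(src, trg, 0)  # (max_len, 1, OUTPUT_DIM)
    words = []
    for idx in output.argmax(-1).squeeze(1).tolist()[1:]:
        if idx == en_vocab.stoi['<eos>']:
            break
        words.append(en_vocab.itos[idx])
    return ' '.join(words)

# print(translate_sentence('Zwei Männer stehen am Herd.'))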
