Remove legacy torchtext code from translation tutorial #1250
@@ -2,36 +2,25 @@
 Language Translation with TorchText
 ===================================

-This tutorial shows how to use several convenience classes of ``torchtext`` to preprocess
+This tutorial shows how to use ``torchtext`` to preprocess
 data from a well-known dataset containing sentences in both English and German and use it to
 train a sequence-to-sequence model with attention that can translate German sentences
 into English.

 It is based off of
 `this tutorial <https://github.com/bentrevett/pytorch-seq2seq/blob/master/3%20-%20Neural%20Machine%20Translation%20by%20Jointly%20Learning%20to%20Align%20and%20Translate.ipynb>`__
 from PyTorch community member `Ben Trevett <https://github.com/bentrevett>`__
-and was created by `Seth Weidman <https://github.com/SethHWeidman/>`__ with Ben's permission.
+with Ben's permission. We update the tutorials by removing some legecy code.

-By the end of this tutorial, you will be able to:
-
-- Preprocess sentences into a commonly-used format for NLP modeling using the following ``torchtext`` convenience classes:
-    - `TranslationDataset <https://torchtext.readthedocs.io/en/latest/datasets.html#torchtext.datasets.TranslationDataset>`__
-    - `Field <https://torchtext.readthedocs.io/en/latest/data.html#torchtext.data.Field>`__
-    - `BucketIterator <https://torchtext.readthedocs.io/en/latest/data.html#torchtext.data.BucketIterator>`__
+By the end of this tutorial, you will be able to preprocess sentences into tensors for NLP modeling and use `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader>`__ for training and validating the model.
 """

 ######################################################################
-# `Field` and `TranslationDataset`
+# Data Processing
 # ----------------
 # ``torchtext`` has utilities for creating datasets that can be easily
 # iterated through for the purposes of creating a language translation
-# model. One key class is a
-# `Field <https://github.com/pytorch/text/blob/master/torchtext/data/field.py#L64>`__,
-# which specifies the way each sentence should be preprocessed, and another is the
-# `TranslationDataset` ; ``torchtext``
-# has several such datasets; in this tutorial we'll use the
-# `Multi30k dataset <https://github.com/multi30k/dataset>`__, which contains about
-# 30,000 sentences (averaging about 13 words in length) in both English and German.
+# model. In this example, we show how to tokenize a raw text sentence, build vocabulary, and numericalize tokens into tensor.
 #
 # Note: the tokenization in this tutorial requires `Spacy <https://spacy.io>`__
 # We use Spacy because it provides strong support for tokenization in languages
@@ -48,79 +37,93 @@
 #
 #    python -m spacy download en
 #    python -m spacy download de
 #
-# With Spacy installed, the following code will tokenize each of the sentences
-# in the ``TranslationDataset`` based on the tokenizer defined in the ``Field``
-
-from torchtext.datasets import Multi30k
-from torchtext.data import Field, BucketIterator
-
-SRC = Field(tokenize = "spacy",
-            tokenizer_language="de",
-            init_token = '<sos>',
-            eos_token = '<eos>',
-            lower = True)
-
-TRG = Field(tokenize = "spacy",
-            tokenizer_language="en",
-            init_token = '<sos>',
-            eos_token = '<eos>',
-            lower = True)
-
-train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'),
-                                                    fields = (SRC, TRG))
-
-######################################################################
-# Now that we've defined ``train_data``, we can see an extremely useful
-# feature of ``torchtext``'s ``Field``: the ``build_vocab`` method
-# now allows us to create the vocabulary associated with each language
-
-SRC.build_vocab(train_data, min_freq = 2)
-TRG.build_vocab(train_data, min_freq = 2)
-
-######################################################################
-# Once these lines of code have been run, ``SRC.vocab.stoi`` will be a
-# dictionary with the tokens in the vocabulary as keys and their
-# corresponding indices as values; ``SRC.vocab.itos`` will be the same
-# dictionary with the keys and values swapped. We won't make extensive
-# use of this fact in this tutorial, but this will likely be useful in
-# other NLP tasks you'll encounter.
+import torchtext
+import torch
+from torchtext.data.utils import get_tokenizer
+from collections import Counter
+from torchtext.vocab import Vocab
+from torchtext.utils import download_from_url, extract_archive
+import io
+
+url_base = 'https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/'
+train_urls = ('train.de.gz', 'train.en.gz')
+val_urls = ('val.de.gz', 'val.en.gz')
+test_urls = ('test_2016_flickr.de.gz', 'test_2016_flickr.en.gz')
+
+train_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in train_urls]
+val_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in val_urls]
+test_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in test_urls]
+
+de_tokenizer = get_tokenizer('spacy', language='de')
+en_tokenizer = get_tokenizer('spacy', language='en')
+
+def build_vocab(filepath, tokenizer):
+  counter = Counter()
+  with io.open(filepath, encoding="utf8") as f:
+    for string_ in f:
+      counter.update(tokenizer(string_))
+  return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
+
+de_vocab = build_vocab(train_filepaths[0], de_tokenizer)
+en_vocab = build_vocab(train_filepaths[1], en_tokenizer)
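As a quick illustration of the two pieces just defined (a minimal sketch, assuming the Spacy German model is installed and the code above has run; the sample sentence and printed values are illustrative, not from the PR)::

    tokens = de_tokenizer("Zwei Männer stehen am Herd.")
    print(tokens)                         # e.g. ['Zwei', 'Männer', 'stehen', 'am', 'Herd', '.']
    print([de_vocab[t] for t in tokens])  # token indices; out-of-vocabulary tokens map to de_vocab['<unk>']
    print(de_vocab.stoi['<pad>'])         # specials come first, so this prints 1 here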
+def data_process(raw_de_iter, raw_en_iter):
+  data_ = []
+  for (raw_de, raw_en) in zip(raw_de_iter, raw_en_iter):
+    de_tensor_ = torch.tensor([de_vocab[token] for token in de_tokenizer(raw_de)],
+                              dtype=torch.long)
+    en_tensor_ = torch.tensor([en_vocab[token] for token in en_tokenizer(raw_en)],
+                              dtype=torch.long)
+    data_.append((de_tensor_, en_tensor_))
+  return data_
+
+train_data = data_process(iter(io.open(train_filepaths[0])),
+                          iter(io.open(train_filepaths[1])))
+val_data = data_process(iter(io.open(val_filepaths[0])),
+                        iter(io.open(val_filepaths[1])))
+test_data = data_process(iter(io.open(test_filepaths[0])),
+                         iter(io.open(test_filepaths[1])))

[Review comment on ``def data_process``] How long does this function take? How long does the entire preprocessing step take? Should we tokenize asynchronously while training instead?
[Reply] Offline discussion: we will have a follow-up PR, since this PR doesn't focus on the dataloader.
[Review comment on ``data_ = []``] nit: remove the trailing underscores
[Review comment on ``train_data = data_process(...)``] I think merging this into the data_process function and then only passing _filepaths cleans this up a bit
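A sketch of the refactor suggested in the review comments above, folding the file handling into ``data_process`` and dropping the trailing underscores (hypothetical code, not part of the PR)::

    def data_process(filepaths):
      raw_de_iter = iter(io.open(filepaths[0], encoding="utf8"))
      raw_en_iter = iter(io.open(filepaths[1], encoding="utf8"))
      data = []
      for (raw_de, raw_en) in zip(raw_de_iter, raw_en_iter):
        # numericalize each sentence pair into long tensors of token indices
        de_tensor = torch.tensor([de_vocab[token] for token in de_tokenizer(raw_de)], dtype=torch.long)
        en_tensor = torch.tensor([en_vocab[token] for token in en_tokenizer(raw_en)], dtype=torch.long)
        data.append((de_tensor, en_tensor))
      return data

    train_data = data_process(train_filepaths)
    val_data = data_process(val_filepaths)
    test_data = data_process(test_filepaths)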
 ######################################################################
-# ``BucketIterator``
+# ``DataLoader``
 # ----------------
-# The last ``torchtext`` specific feature we'll use is the ``BucketIterator``,
-# which is easy to use since it takes a ``TranslationDataset`` as its
+# The last ``torch`` specific feature we'll use is the ``DataLoader``,
+# which is easy to use since it takes the data as its
 # first argument. Specifically, as the docs say:
-# Defines an iterator that batches examples of similar lengths together.
-# Minimizes amount of padding needed while producing freshly shuffled
-# batches for each new epoch. See pool for the bucketing procedure used.
+# ``DataLoader`` combines a dataset and a sampler, and provides an iterable over the given dataset. The ``DataLoader`` supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.
+#
+# Please pay attention to ``collate_fn`` (optional) that merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
+#
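To see what ``pad_sequence`` (used in ``generate_batch`` below) does, a tiny standalone sketch with made-up values::

    import torch
    from torch.nn.utils.rnn import pad_sequence

    a = torch.tensor([4, 5, 6])   # length-3 sequence
    b = torch.tensor([7])         # length-1 sequence
    print(pad_sequence([a, b], padding_value=1))
    # tensor([[4, 7],
    #         [5, 1],
    #         [6, 1]])  -- shape (max_len, batch); batch_first=False by default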
+import torch
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+BATCH_SIZE = 128
+PAD_IDX = de_vocab['<pad>']
+BOS_IDX = de_vocab['<bos>']
+EOS_IDX = de_vocab['<eos>']
+
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader
+
+def generate_batch(data_batch):
+  de_batch, en_batch = [], []
+  for (de_item, en_item) in data_batch:
+    de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0))
+    en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
+  de_batch = pad_sequence(de_batch, padding_value=PAD_IDX)
+  en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
+  return de_batch, en_batch
+
+train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
+                        shuffle=True, collate_fn=generate_batch)
+valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE,
+                        shuffle=True, collate_fn=generate_batch)
+test_iter = DataLoader(test_data, batch_size=BATCH_SIZE,
+                       shuffle=True, collate_fn=generate_batch)
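A quick sanity check on the resulting batches (a sketch, assuming the code above has run; the exact max length varies per batch because of shuffling)::

    src_batch, trg_batch = next(iter(train_iter))
    print(src_batch.shape)  # (max_src_len_in_batch + 2, BATCH_SIZE), +2 for <bos>/<eos>
    print(trg_batch.shape)  # (max_trg_len_in_batch + 2, BATCH_SIZE)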
-train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
-    (train_data, valid_data, test_data),
-    batch_size = BATCH_SIZE,
-    device = device)
-######################################################################
-# These iterators can be called just like ``DataLoader``s; below, in
-# the ``train`` and ``evaluate`` functions, they are called simply with:
-#
-# ::
-#
-#    for i, batch in enumerate(iterator):
-#
-# Each ``batch`` then has ``src`` and ``trg`` attributes:
-#
-# ::
-#
-#    src = batch.src
-#    trg = batch.trg

 ######################################################################
 # Defining our ``nn.Module`` and ``Optimizer``

@@ -329,8 +332,8 @@ def forward(self,
     return outputs


-INPUT_DIM = len(SRC.vocab)
-OUTPUT_DIM = len(TRG.vocab)
+INPUT_DIM = len(de_vocab)
+OUTPUT_DIM = len(en_vocab)
 # ENC_EMB_DIM = 256
 # DEC_EMB_DIM = 256
 # ENC_HID_DIM = 512
@@ -380,7 +383,7 @@ def count_parameters(model: nn.Module):
 # particular, we have to tell the ``nn.CrossEntropyLoss`` function to
 # ignore the indices where the target is simply padding.

-PAD_IDX = TRG.vocab.stoi['<pad>']
+PAD_IDX = en_vocab.stoi['<pad>']

 criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
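To confirm what ``ignore_index`` buys us, a toy example with made-up numbers (``1`` stands in for the pad index here)::

    import torch
    import torch.nn as nn

    logits = torch.randn(4, 10)           # 4 target positions, vocabulary of size 10
    targets = torch.tensor([3, 1, 7, 1])  # positions holding 1 are padding
    criterion = nn.CrossEntropyLoss(ignore_index=1)
    loss = criterion(logits, targets)     # averages over the two non-pad positions only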
@@ -392,7 +395,7 @@ def count_parameters(model: nn.Module):


 def train(model: nn.Module,
-          iterator: BucketIterator,
+          iterator: torch.utils.data.DataLoader,
           optimizer: optim.Optimizer,
           criterion: nn.Module,
           clip: float):
@@ -401,10 +404,8 @@ def train(model: nn.Module,

     epoch_loss = 0

-    for _, batch in enumerate(iterator):
-
-        src = batch.src
-        trg = batch.trg
+    for _, (src, trg) in enumerate(iterator):
+        src, trg = src.to(device), trg.to(device)

         optimizer.zero_grad()
@@ -427,7 +428,7 @@ def train(model: nn.Module,


 def evaluate(model: nn.Module,
-             iterator: BucketIterator,
+             iterator: torch.utils.data.DataLoader,
              criterion: nn.Module):

     model.eval()
@@ -436,10 +437,8 @@ def evaluate(model: nn.Module,

     with torch.no_grad():

-        for _, batch in enumerate(iterator):
-
-            src = batch.src
-            trg = batch.trg
+        for _, (src, trg) in enumerate(iterator):
+            src, trg = src.to(device), trg.to(device)

             output = model(src, trg, 0)  # turn off teacher forcing
@@ -470,8 +469,8 @@ def epoch_time(start_time: int,

     start_time = time.time()

-    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
-    valid_loss = evaluate(model, valid_iterator, criterion)
+    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
+    valid_loss = evaluate(model, valid_iter, criterion)

     end_time = time.time()
@@ -481,7 +480,7 @@ def epoch_time(start_time: int,
     print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
     print(f'\t Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f}')

-test_loss = evaluate(model, test_iterator, criterion)
+test_loss = evaluate(model, test_iter, criterion)

 print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')
[Review comment] nit: legacy instead of legecy