
Commit b4c030b

Author: Guanheng Zhang
Commit message: checkpoint
1 parent 73b6cd9 commit b4c030b

1 file changed: +94 / -95 lines changed

beginner_source/torchtext_translation_tutorial.py

Lines changed: 94 additions & 95 deletions
@@ -2,36 +2,25 @@
 Language Translation with TorchText
 ===================================

-This tutorial shows how to use several convenience classes of ``torchtext`` to preprocess
+This tutorial shows how to use ``torchtext`` to preprocess
 data from a well-known dataset containing sentences in both English and German and use it to
 train a sequence-to-sequence model with attention that can translate German sentences
 into English.

 It is based off of
 `this tutorial <https://github.com/bentrevett/pytorch-seq2seq/blob/master/3%20-%20Neural%20Machine%20Translation%20by%20Jointly%20Learning%20to%20Align%20and%20Translate.ipynb>`__
 from PyTorch community member `Ben Trevett <https://github.com/bentrevett>`__
-and was created by `Seth Weidman <https://github.com/SethHWeidman/>`__ with Ben's permission.
+with Ben's permission. We have updated the tutorial by removing some legacy code.

-By the end of this tutorial, you will be able to:
-
-- Preprocess sentences into a commonly-used format for NLP modeling using the following ``torchtext`` convenience classes:
-    - `TranslationDataset <https://torchtext.readthedocs.io/en/latest/datasets.html#torchtext.datasets.TranslationDataset>`__
-    - `Field <https://torchtext.readthedocs.io/en/latest/data.html#torchtext.data.Field>`__
-    - `BucketIterator <https://torchtext.readthedocs.io/en/latest/data.html#torchtext.data.BucketIterator>`__
+By the end of this tutorial, you will be able to preprocess sentences into tensors for NLP modeling and use `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader>`__ for training and validating the model.
 """

 ######################################################################
-# `Field` and `TranslationDataset`
+# Data Processing
 # ----------------
 # ``torchtext`` has utilities for creating datasets that can be easily
 # iterated through for the purposes of creating a language translation
-# model. One key class is a
-# `Field <https://github.com/pytorch/text/blob/master/torchtext/data/field.py#L64>`__,
-# which specifies the way each sentence should be preprocessed, and another is the
-# `TranslationDataset` ; ``torchtext``
-# has several such datasets; in this tutorial we'll use the
-# `Multi30k dataset <https://github.com/multi30k/dataset>`__, which contains about
-# 30,000 sentences (averaging about 13 words in length) in both English and German.
+# model. In this example, we show how to tokenize a raw text sentence, build the vocabulary, and numericalize the tokens into a tensor.
 #
 # Note: the tokenization in this tutorial requires `Spacy <https://spacy.io>`__
 # We use Spacy because it provides strong support for tokenization in languages
@@ -46,81 +35,95 @@
 #
 # ::
 #
-# python -m spacy download en
-# python -m spacy download de
-#
-# With Spacy installed, the following code will tokenize each of the sentences
-# in the ``TranslationDataset`` based on the tokenizer defined in the ``Field``
-
-from torchtext.datasets import Multi30k
-from torchtext.data import Field, BucketIterator
-
-SRC = Field(tokenize = "spacy",
-            tokenizer_language="de",
-            init_token = '<sos>',
-            eos_token = '<eos>',
-            lower = True)
-
-TRG = Field(tokenize = "spacy",
-            tokenizer_language="en",
-            init_token = '<sos>',
-            eos_token = '<eos>',
-            lower = True)
-
-train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'),
-                                                    fields = (SRC, TRG))
+# python -m spacy download en
+# python -m spacy download de

-######################################################################
-# Now that we've defined ``train_data``, we can see an extremely useful
-# feature of ``torchtext``'s ``Field``: the ``build_vocab`` method
-# now allows us to create the vocabulary associated with each language
-
-SRC.build_vocab(train_data, min_freq = 2)
-TRG.build_vocab(train_data, min_freq = 2)
-
-######################################################################
-# Once these lines of code have been run, ``SRC.vocab.stoi`` will be a
-# dictionary with the tokens in the vocabulary as keys and their
-# corresponding indices as values; ``SRC.vocab.itos`` will be the same
-# dictionary with the keys and values swapped. We won't make extensive
-# use of this fact in this tutorial, but this will likely be useful in
-# other NLP tasks you'll encounter.
+import torchtext
+import torch
+from torchtext.data.utils import get_tokenizer
+from collections import Counter
+from torchtext.vocab import Vocab
+from torchtext.utils import download_from_url, extract_archive
+import io
+
+url_base = 'https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/'
+train_urls = ('train.de.gz', 'train.en.gz')
+val_urls = ('val.de.gz', 'val.en.gz')
+test_urls = ('test_2016_flickr.de.gz', 'test_2016_flickr.en.gz')
+
+train_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in train_urls]
+val_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in val_urls]
+test_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in test_urls]
+
+de_tokenizer = get_tokenizer('spacy', language='de')
+en_tokenizer = get_tokenizer('spacy', language='en')
+
+def build_vocab(filepath, tokenizer):
+    counter = Counter()
+    with io.open(filepath, encoding="utf8") as f:
+        for string_ in f:
+            counter.update(tokenizer(string_))
+    return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
+
+de_vocab = build_vocab(train_filepaths[0], de_tokenizer)
+en_vocab = build_vocab(train_filepaths[1], en_tokenizer)
+
+def data_process(raw_de_iter, raw_en_iter):
+    data_ = []
+    for (raw_de, raw_en) in zip(raw_de_iter, raw_en_iter):
+        de_tensor_ = torch.tensor([de_vocab[token] for token in de_tokenizer(raw_de)],
+                                  dtype=torch.long)
+        en_tensor_ = torch.tensor([en_vocab[token] for token in en_tokenizer(raw_en)],
+                                  dtype=torch.long)
+        data_.append((de_tensor_, en_tensor_))
+    return data_
+
+train_data = data_process(iter(io.open(train_filepaths[0])),
+                          iter(io.open(train_filepaths[1])))
+val_data = data_process(iter(io.open(val_filepaths[0])),
+                        iter(io.open(val_filepaths[1])))
+test_data = data_process(iter(io.open(test_filepaths[0])),
+                         iter(io.open(test_filepaths[1])))

 ######################################################################
-# ``BucketIterator``
+# ``DataLoader``
 # ----------------
-# The last ``torchtext`` specific feature we'll use is the ``BucketIterator``,
-# which is easy to use since it takes a ``TranslationDataset`` as its
+# The last ``torch``-specific feature we'll use is the ``DataLoader``,
+# which is easy to use since it takes the data as its
 # first argument. Specifically, as the docs say:
-# Defines an iterator that batches examples of similar lengths together.
-# Minimizes amount of padding needed while producing freshly shuffled
-# batches for each new epoch. See pool for the bucketing procedure used.
+# ``DataLoader`` combines a dataset and a sampler, and provides an iterable over the given dataset. The ``DataLoader`` supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.
+#
+# Please pay attention to ``collate_fn`` (optional), which merges a list of samples to form a mini-batch of Tensor(s). It is used when batched loading is done from a map-style dataset.
+#

 import torch

 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

 BATCH_SIZE = 128
+PAD_IDX = de_vocab['<pad>']
+BOS_IDX = de_vocab['<bos>']
+EOS_IDX = de_vocab['<eos>']
+
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader
+
+def generate_batch(data_batch):
+    de_batch, en_batch = [], []
+    for (de_item, en_item) in data_batch:
+        de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0))
+        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
+    de_batch = pad_sequence(de_batch, padding_value=PAD_IDX)
+    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
+    return de_batch, en_batch
+
+train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
+                        shuffle=True, collate_fn=generate_batch)
+valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE,
+                        shuffle=True, collate_fn=generate_batch)
+test_iter = DataLoader(test_data, batch_size=BATCH_SIZE,
+                       shuffle=True, collate_fn=generate_batch)

-train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
-    (train_data, valid_data, test_data),
-    batch_size = BATCH_SIZE,
-    device = device)
-
-######################################################################
-# These iterators can be called just like ``DataLoader``s; below, in
-# the ``train`` and ``evaluate`` functions, they are called simply with:
-#
-# ::
-#
-#     for i, batch in enumerate(iterator):
-#
-# Each ``batch`` then has ``src`` and ``trg`` attributes:
-#
-# ::
-#
-#     src = batch.src
-#     trg = batch.trg

 ######################################################################
 # Defining our ``nn.Module`` and ``Optimizer``
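
To make the new data-processing path concrete, here is a minimal usage sketch, assuming the names added in the hunk above (``de_tokenizer``, ``de_vocab``, ``torch``) are in scope. The sample sentence and variable names are illustrative only, not part of the diff.

sample_sentence = "Zwei Hunde spielen im Schnee."        # made-up German input
tokens = de_tokenizer(sample_sentence)                   # Spacy tokenization -> list of token strings
indices = [de_vocab[token] for token in tokens]          # numericalize, as data_process does per line
sample_tensor = torch.tensor(indices, dtype=torch.long)  # 1-D LongTensor of vocabulary indices

This is the same per-sentence work that ``data_process`` performs for every line of the Multi30k files before ``generate_batch`` adds ``<bos>``/``<eos>`` and pads the batch.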
@@ -329,8 +332,8 @@ def forward(self,
         return outputs


-INPUT_DIM = len(SRC.vocab)
-OUTPUT_DIM = len(TRG.vocab)
+INPUT_DIM = len(de_vocab)
+OUTPUT_DIM = len(en_vocab)
 # ENC_EMB_DIM = 256
 # DEC_EMB_DIM = 256
 # ENC_HID_DIM = 512
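
``INPUT_DIM`` and ``OUTPUT_DIM`` now come straight from the vocabulary sizes, so every token index produced by the pipeline has to fall below them. A quick sanity sketch, assuming ``train_iter``, ``INPUT_DIM``, and ``OUTPUT_DIM`` defined above; the check itself is illustrative, not part of the diff.

de_batch, en_batch = next(iter(train_iter))   # one collated batch, shape [seq_len, BATCH_SIZE]
assert de_batch.max().item() < INPUT_DIM      # source indices fit the encoder embedding
assert en_batch.max().item() < OUTPUT_DIM     # target indices fit the decoder embedding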
@@ -380,7 +383,7 @@ def count_parameters(model: nn.Module):
 # particular, we have to tell the ``nn.CrossEntropyLoss`` function to
 # ignore the indices where the target is simply padding.

-PAD_IDX = TRG.vocab.stoi['<pad>']
+PAD_IDX = en_vocab.stoi['<pad>']

 criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

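As a small illustration of what ``ignore_index=PAD_IDX`` buys us (the tensors below are invented for the example, not part of the diff): target positions equal to ``PAD_IDX`` contribute nothing to the averaged loss.

toy_logits = torch.randn(4, OUTPUT_DIM)                # scores for four target positions
toy_targets = torch.tensor([5, 7, PAD_IDX, PAD_IDX])   # the last two positions are padding
toy_loss = criterion(toy_logits, toy_targets)          # averaged over the two non-padding positions only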
@@ -392,7 +395,7 @@ def count_parameters(model: nn.Module):


 def train(model: nn.Module,
-          iterator: BucketIterator,
+          iterator,
           optimizer: optim.Optimizer,
           criterion: nn.Module,
           clip: float):
@@ -401,10 +404,8 @@ def train(model: nn.Module,

     epoch_loss = 0

-    for _, batch in enumerate(iterator):
-
-        src = batch.src
-        trg = batch.trg
+    for _, (src, trg) in enumerate(iterator):
+        src, trg = src.to(device), trg.to(device)

         optimizer.zero_grad()

@@ -427,7 +428,7 @@ def train(model: nn.Module,


 def evaluate(model: nn.Module,
-             iterator: BucketIterator,
+             iterator,
              criterion: nn.Module):

     model.eval()
@@ -436,10 +437,8 @@ def evaluate(model: nn.Module,

     with torch.no_grad():

-        for _, batch in enumerate(iterator):
-
-            src = batch.src
-            trg = batch.trg
+        for _, (src, trg) in enumerate(iterator):
+            src, trg = src.to(device), trg.to(device)

             output = model(src, trg, 0) #turn off teacher forcing

@@ -470,8 +469,8 @@ def epoch_time(start_time: int,

     start_time = time.time()

-    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
-    valid_loss = evaluate(model, valid_iterator, criterion)
+    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
+    valid_loss = evaluate(model, valid_iter, criterion)

     end_time = time.time()

@@ -481,7 +480,7 @@ def epoch_time(start_time: int,
     print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
     print(f'\t Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f}')

-test_loss = evaluate(model, test_iterator, criterion)
+test_loss = evaluate(model, test_iter, criterion)

 print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')
