Mention prerequisites for running the tutorial based on observations #2461

Merged · 5 commits · Jun 12, 2023 · Changes from all commits
122 changes: 74 additions & 48 deletions beginner_source/text_sentiment_ngrams_tutorial.py
- Access to the raw data as an iterator
- Build data processing pipeline to convert the raw text strings into ``torch.Tensor`` that can be used to train the model
- Shuffle and iterate the data with `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader>`__


Prerequisites
~~~~~~~~~~~~~~~~

A recent 2.x version of the ``portalocker`` package must be installed before running the tutorial.
In the Colab environment, for example, this can be done by adding the following line at the top of the script:

.. code-block:: bash

   !pip install -U 'portalocker>=2.0.0'

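To double-check that the requirement is met (an optional sanity check that
uses only the standard library):

.. code-block:: python

   from importlib.metadata import version

   print(version("portalocker"))  # expect a 2.x release
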
"""


#
# The torchtext library provides a few raw dataset iterators, which yield the raw text strings. For example, the ``AG_NEWS`` dataset iterators yield the raw data as a tuple of label and text.
#
# To access torchtext datasets, please install torchdata following instructions at https://github.com/pytorch/data.
#

import torch
from torchtext.datasets import AG_NEWS

train_iter = iter(AG_NEWS(split="train"))
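
######################################################################
# As a quick check (the exact first sample depends on the dataset
# version), we can peek at one raw ``(label, text)`` pair:

first_label, first_text = next(train_iter)
print(first_label, first_text[:60])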

######################################################################
# Prepare data processing pipelines
# ---------------------------------
#
# The first step is to build a vocabulary from the raw training dataset,
# using torchtext's tokenizer and ``build_vocab_from_iterator``.
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer("basic_english")
train_iter = AG_NEWS(split="train")


def yield_tokens(data_iter):
for _, text in data_iter:
yield tokenizer(text)


vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

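######################################################################
# The unchanged part of the file is collapsed in this diff; it defines the
# conversion pipelines used by the collate function below. A minimal
# sketch consistent with that usage:

text_pipeline = lambda x: vocab(tokenizer(x))  # raw string -> list of token ids
label_pipeline = lambda x: int(x) - 1  # labels "1".."4" -> 0..3
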



######################################################################
# Generate data batch and iterator
# --------------------------------
#
# ``torch.utils.data.DataLoader`` is used with a ``collate_fn`` to batch the
# samples: labels are collected into a tensor, the token ids of all texts
# in the batch are concatenated into a single tensor, and ``offsets``
# records the starting index of each sequence, which is the input format
# expected by ``nn.EmbeddingBag``.


from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def collate_batch(batch):
label_list, text_list, offsets = [], [], [0]
for _label, _text in batch:
label_list.append(label_pipeline(_label))
processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
text_list.append(processed_text)
offsets.append(processed_text.size(0))
label_list = torch.tensor(label_list, dtype=torch.int64)
offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
text_list = torch.cat(text_list)
return label_list.to(device), text_list.to(device), offsets.to(device)


train_iter = AG_NEWS(split="train")
dataloader = DataLoader(
train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch
)

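######################################################################
# To see what ``collate_batch`` produces, we can pull one batch (shapes
# vary with the sampled texts): a tensor of 8 labels, one long tensor of
# concatenated token ids, and 8 starting offsets.

batch_labels, batch_tokens, batch_offsets = next(iter(dataloader))
print(batch_labels.shape, batch_tokens.shape, batch_offsets.shape)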

######################################################################
# Define the model
# ----------------
#
# The model is composed of an ``nn.EmbeddingBag`` layer plus a linear layer
# for classification. ``nn.EmbeddingBag`` with the default mode of "mean"
# computes the mean value of a "bag" of embeddings, so the text entries can
# have different lengths; offsets mark the sequence boundaries.

from torch import nn


class TextClassificationModel(nn.Module):
def __init__(self, vocab_size, embed_dim, num_class):
super(TextClassificationModel, self).__init__()
self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
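        # The rest of the class is unchanged and collapsed in this diff;
        # a sketch of the elided lines, consistent with the published
        # tutorial:
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)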


######################################################################
# Initiate an instance
# --------------------
#
# We build a model with an embedding dimension of 64. The vocabulary size
# equals the length of the vocabulary instance, and the number of classes
# equals the number of labels.
#

train_iter = AG_NEWS(split="train")
num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
emsize = 64
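# The instantiation line itself is unchanged and not shown in the diff;
# for completeness (as in the published tutorial):
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)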

######################################################################
# Define functions to train the model and evaluate results
# ---------------------------------------------------------
#

import time


def train(dataloader):
model.train()
total_acc, total_count = 0, 0
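    # The training-step block is unchanged and collapsed in this diff; a
    # sketch of the elided lines (log interval, timer, and the usual
    # forward/backward/step), consistent with the published tutorial:
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()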
total_count += label.size(0)
if idx % log_interval == 0 and idx > 0:
elapsed = time.time() - start_time
print(
"| epoch {:3d} | {:5d}/{:5d} batches "
"| accuracy {:8.3f}".format(
epoch, idx, len(dataloader), total_acc / total_count
)
)
total_acc, total_count = 0, 0
start_time = time.time()


def evaluate(dataloader):
model.eval()
total_acc, total_count = 0, 0
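    # The opening of the evaluation loop is unchanged and collapsed in
    # this diff; a sketch of the elided lines:
    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)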
loss = criterion(predicted_label, label)
total_acc += (predicted_label.argmax(1) == label).sum().item()
total_count += label.size(0)
return total_acc / total_count


######################################################################
# Split the dataset and run the model
# -----------------------------------
#
# Since the original ``AG_NEWS`` dataset has no validation split, we split
# the training dataset into training and validation sets with a ratio of
# 0.95 / 0.05. We use ``CrossEntropyLoss`` as the criterion, SGD as the
# optimizer, and ``StepLR`` to decay the learning rate through epochs.

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

# Hyperparameters
EPOCHS = 10  # epoch
LR = 5  # learning rate
BATCH_SIZE = 64  # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
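# The next few unchanged lines are collapsed in the diff; a sketch,
# matching the published tutorial: the LR scheduler, the accuracy
# tracker, and the raw train/test splits.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None
train_iter, test_iter = AG_NEWS()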
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(
train_dataset, [num_train, len(train_dataset) - num_train]
)

train_dataloader = DataLoader(
split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)
valid_dataloader = DataLoader(
split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)
test_dataloader = DataLoader(
test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)

for epoch in range(1, EPOCHS + 1):
epoch_start_time = time.time()
train(train_dataloader)
accu_val = evaluate(valid_dataloader)
if total_accu is not None and total_accu > accu_val:
scheduler.step()
else:
total_accu = accu_val
print("-" * 59)
print(
"| end of epoch {:3d} | time: {:5.2f}s | "
"valid accuracy {:8.3f} ".format(
epoch, time.time() - epoch_start_time, accu_val
)
)
print("-" * 59)


######################################################################
# Evaluate the model with the test dataset
# ----------------------------------------
#



######################################################################
# Checking the results of the test dataset…

print("Checking the results of test dataset.")
accu_test = evaluate(test_dataloader)
print("test accuracy {:8.3f}".format(accu_test))


######################################################################
# Test on a random news item
# --------------------------
#
# Use the best model so far and test it on a golf news story.
#


ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"}


def predict(text, text_pipeline):
with torch.no_grad():
text = torch.tensor(text_pipeline(text))
output = model(text, torch.tensor([0]))
return output.argmax(1).item() + 1


ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
enduring the season’s worst weather conditions on Sunday at The \
Open on his way to a closing 75 at Royal Portrush, which \
    considering the wind and the rain was a respectable showing. \
    Thursday’s first round at the WGC-FedEx St. Jude Invitational \
    was another story. With temperatures in the mid-80s and hardly any \
    wind, the Spaniard was 13 strokes better in a flawless round. \
    Thanks to his best putting performance on the PGA Tour, Rahm \
    finished with an 8-under 62 for a three-stroke lead, which \
    was even more impressive considering he’d never played the \
    front nine at TPC Southwind."

model = model.to("cpu")

print("This is a %s news" %ag_news_label[predict(ex_text_str, text_pipeline)])
print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])