@@ -1,23 +1,23 @@
 """
-Text Classification Tutorial
-============================
+Text Classification with TorchText
+==================================
 
-This tutorial shows how to use the text classification datasets,
-including
+This tutorial shows how to use the text classification datasets
+in ``torchtext``, including
 
 ::
 
     - AG_NEWS,
-    - SogouNews,
-    - DBpedia,
+    - SogouNews,
+    - DBpedia,
     - YelpReviewPolarity,
-    - YelpReviewFull,
-    - YahooAnswers,
+    - YelpReviewFull,
+    - YahooAnswers,
     - AmazonReviewPolarity,
     - AmazonReviewFull
 
-This example shows the application of ``TextClassification`` Dataset for
-supervised learning analysis.
+This example shows how to train a supervised learning algorithm for
+classification using one of these ``TextClassification`` datasets.
 
 Load data with ngrams
 ---------------------
@@ -54,20 +54,20 @@
 ######################################################################
 # Define the model
 # ----------------
-#
+#
 # The model is composed of the
 # `EmbeddingBag <https://pytorch.org/docs/stable/nn.html?highlight=embeddingbag#torch.nn.EmbeddingBag>`__
 # layer and the linear layer (see the figure below). ``nn.EmbeddingBag``
 # computes the mean value of a “bag” of embeddings. The text entries here
 # have different lengths. ``nn.EmbeddingBag`` requires no padding here
 # since the text lengths are saved in offsets.
-#
+#
 # Additionally, since ``nn.EmbeddingBag`` accumulates the average across
 # the embeddings on the fly, ``nn.EmbeddingBag`` can enhance the
 # performance and memory efficiency to process a sequence of tensors.
-#
+#
 # .. image:: ../_static/img/text_sentiment_ngrams_model.png
-#
+#
 
 import torch.nn as nn
 import torch.nn.functional as F
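As a quick aside (not part of the changed file), a minimal sketch of how ``nn.EmbeddingBag`` consumes one flat index tensor plus per-entry offsets instead of padded sequences::

    import torch
    import torch.nn as nn

    # Two text entries of lengths 4 and 3, concatenated into one flat tensor.
    text = torch.tensor([1, 2, 4, 5, 4, 3, 2])
    offsets = torch.tensor([0, 4])          # entry 0 starts at index 0, entry 1 at index 4

    bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="mean")
    out = bag(text, offsets)                # one averaged embedding per entry
    print(out.shape)                        # torch.Size([2, 4])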
@@ -83,7 +83,7 @@ def init_weights(self):
         self.embedding.weight.data.uniform_(-initrange, initrange)
         self.fc.weight.data.uniform_(-initrange, initrange)
         self.fc.bias.data.zero_()
-
+
     def forward(self, text, offsets):
         embedded = self.embedding(text, offsets)
         return self.fc(embedded)
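For orientation, a sketch of the module these methods belong to (the class name, constructor arguments, and ``initrange`` value are assumptions for illustration, not taken from the diff)::

    import torch.nn as nn

    class TextSentiment(nn.Module):
        """EmbeddingBag followed by a linear layer, as described above."""
        def __init__(self, vocab_size, embed_dim, num_class):
            super().__init__()
            self.embedding = nn.EmbeddingBag(vocab_size, embed_dim)
            self.fc = nn.Linear(embed_dim, num_class)
            self.init_weights()

        def init_weights(self):
            initrange = 0.5
            self.embedding.weight.data.uniform_(-initrange, initrange)
            self.fc.weight.data.uniform_(-initrange, initrange)
            self.fc.bias.data.zero_()

        def forward(self, text, offsets):
            embedded = self.embedding(text, offsets)
            return self.fc(embedded)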
@@ -92,21 +92,21 @@ def forward(self, text, offsets):
 ######################################################################
 # Initiate an instance
 # --------------------
-#
+#
 # The AG_NEWS dataset has four labels and therefore the number of classes
 # is four.
-#
+#
 # ::
-#
+#
 #    1 : World
 #    2 : Sports
 #    3 : Business
 #    4 : Sci/Tec
-#
+#
 # The vocab size is equal to the length of vocab (including single word
 # and ngrams). The number of classes is equal to the number of labels,
 # which is four in AG_NEWS case.
-#
+#
 
 VOCAB_SIZE = len(train_dataset.get_vocab())
 EMBED_DIM = 32
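A sketch of how these constants would typically feed the model constructor (``NUM_CLASS``, ``device``, and the ``TextSentiment`` name are assumptions for illustration)::

    import torch

    NUM_CLASS = len(train_dataset.get_labels())     # 4 for AG_NEWS
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)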
@@ -117,7 +117,7 @@ def forward(self, text, offsets):
 ######################################################################
 # Functions used to generate batch
 # --------------------------------
-#
+#
 
 
 ######################################################################
@@ -129,13 +129,13 @@ def forward(self, text, offsets):
 # mini-batch. Pay attention here and make sure that ``collate_fn`` is
 # declared as a top level def. This ensures that the function is available
 # in each worker.
-#
+#
 # The text entries in the original data batch input are packed into a list
 # and concatenated as a single tensor as the input of ``nn.EmbeddingBag``.
 # The offsets is a tensor of delimiters to represent the beginning index
 # of the individual sequence in the text tensor. Label is a tensor saving
 # the labels of individual text entries.
-#
+#
 
 def generate_batch(batch):
     label = torch.tensor([entry[0] for entry in batch])
@@ -144,7 +144,7 @@ def generate_batch(batch):
     # torch.Tensor.cumsum returns the cumulative sum
     # of elements in the dimension dim.
     # torch.Tensor([1.0, 2.0, 3.0]).cumsum(dim=0)
-
+
     offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
     text = torch.cat(text)
     return text, offsets, label
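A minimal sketch of what the ``cumsum`` step computes: the per-entry lengths become the start index of each entry in the concatenated text tensor::

    import torch

    lengths = [3, 2, 4]                              # tokens per entry in a batch
    offsets = torch.tensor([0] + lengths[:-1]).cumsum(dim=0)
    print(offsets)                                   # tensor([0, 3, 5])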
@@ -153,7 +153,7 @@ def generate_batch(batch):
 ######################################################################
 # Define functions to train the model and evaluate results.
 # ---------------------------------------------------------
-#
+#
 
 
 ######################################################################
@@ -163,7 +163,7 @@ def generate_batch(batch):
 # `here <https://pytorch.org/tutorials/beginner/data_loading_tutorial.html>`__).
 # We use ``DataLoader`` here to load AG_NEWS datasets and send it to the
 # model for training/validation.
-#
+#
 
 from torch.utils.data import DataLoader
 
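A sketch of how ``DataLoader`` is typically wired up with ``generate_batch`` (``BATCH_SIZE`` and the loop variable names are assumptions)::

    BATCH_SIZE = 16
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    for text, offsets, cls in data:
        ...   # each batch already matches the model's (text, offsets) inputs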
@@ -186,7 +186,7 @@ def train_func(sub_train_):
 
     # Adjust the learning rate
     scheduler.step()
-
+
     return train_loss / len(sub_train_), train_acc / len(sub_train_)
 
 def test(data_):
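Only ``scheduler.step()`` and the return statement of ``train_func`` appear in this hunk; a sketch of the surrounding training function, with the loop body filled in as an assumption for illustration::

    def train_func(sub_train_):
        train_loss = 0
        train_acc = 0
        data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=generate_batch)
        for text, offsets, cls in data:
            optimizer.zero_grad()
            text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
            output = model(text, offsets)
            loss = criterion(output, cls)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()
            train_acc += (output.argmax(1) == cls).sum().item()

        # Adjust the learning rate
        scheduler.step()

        return train_loss / len(sub_train_), train_acc / len(sub_train_)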
@@ -207,13 +207,13 @@ def test(data_):
 ######################################################################
 # Split the dataset and run the model
 # -----------------------------------
-#
+#
 # Since the original AG_NEWS has no valid dataset, we split the training
 # dataset into train/valid sets with a split ratio of 0.95 (train) and
 # 0.05 (valid). Here we use
 # `torch.utils.data.dataset.random_split <https://pytorch.org/docs/stable/data.html?highlight=random_split#torch.utils.data.random_split>`__
 # function in PyTorch core library.
-#
+#
 # `CrossEntropyLoss <https://pytorch.org/docs/stable/nn.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss>`__
 # criterion combines nn.LogSoftmax() and nn.NLLLoss() in a single class.
 # It is useful when training a classification problem with C classes.
@@ -222,7 +222,7 @@ def test(data_):
 # learning rate is set to 4.0.
 # `StepLR <https://pytorch.org/docs/master/_modules/torch/optim/lr_scheduler.html#StepLR>`__
 # is used here to adjust the learning rate through epochs.
-#
+#
 
 import time
 from torch.utils.data.dataset import random_split
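A sketch of the setup the paragraph above describes: the 0.95/0.05 split with ``random_split``, ``CrossEntropyLoss``, SGD with a learning rate of 4.0, and a ``StepLR`` scheduler (``N_EPOCHS`` and the ``step_size``/``gamma`` values are assumptions)::

    N_EPOCHS = 5

    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

    train_len = int(len(train_dataset) * 0.95)
    sub_train_, sub_valid_ = random_split(
        train_dataset, [train_len, len(train_dataset) - train_len])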
@@ -250,56 +250,56 @@ def test(data_):
     print('Epoch: %d' % (epoch + 1), " | time in %d minutes, %d seconds" % (mins, secs))
     print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
     print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')
-
+
 
 ######################################################################
 # Running the model on GPU with the following information:
-#
+#
 # Epoch: 1 \| time in 0 minutes, 11 seconds
-#
+#
 # ::
-#
+#
 #    Loss: 0.0263(train) | Acc: 84.5%(train)
 #    Loss: 0.0001(valid) | Acc: 89.0%(valid)
-#
-#
+#
+#
 # Epoch: 2 \| time in 0 minutes, 10 seconds
-#
+#
 # ::
-#
+#
 #    Loss: 0.0119(train) | Acc: 93.6%(train)
 #    Loss: 0.0000(valid) | Acc: 89.6%(valid)
-#
-#
+#
+#
 # Epoch: 3 \| time in 0 minutes, 9 seconds
-#
+#
 # ::
-#
+#
 #    Loss: 0.0069(train) | Acc: 96.4%(train)
 #    Loss: 0.0000(valid) | Acc: 90.5%(valid)
-#
-#
+#
+#
 # Epoch: 4 \| time in 0 minutes, 11 seconds
-#
+#
 # ::
-#
+#
 #    Loss: 0.0038(train) | Acc: 98.2%(train)
 #    Loss: 0.0000(valid) | Acc: 90.4%(valid)
-#
-#
+#
+#
 # Epoch: 5 \| time in 0 minutes, 11 seconds
-#
+#
 # ::
-#
+#
 #    Loss: 0.0022(train) | Acc: 99.0%(train)
-#    Loss: 0.0000(valid) | Acc: 91.0%(valid)
+#    Loss: 0.0000(valid) | Acc: 91.0%(valid)
 
 
 ######################################################################
 # Evaluate the model with test dataset
 # ------------------------------------
-#
+#
 
 print('Checking the results of test dataset...')
 test_loss, test_acc = test(test_dataset)
@@ -308,21 +308,21 @@ def test(data_):
 
 ######################################################################
 # Checking the results of test dataset…
-#
+#
 # ::
-#
+#
 #    Loss: 0.0237(test) | Acc: 90.5%(test)
-#
+#
 
 
 ######################################################################
 # Test on a random news
 # ---------------------
-#
+#
 # Use the best model so far and test a golf news. The label information is
 # available
 # `here <https://pytorch.org/text/datasets.html?highlight=ag_news#torchtext.datasets.AG_NEWS>`__.
-#
+#
 
 import re
 from torchtext.data.utils import ngrams_iterator
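A sketch of the ``predict`` helper named in the next hunk, plus the label mapping from the docstring above (the tokenizer choice, ngrams value, and example sentence are assumptions for illustration)::

    from torchtext.data.utils import get_tokenizer

    ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"}

    def predict(text, model, vocab, ngrams):
        tokenizer = get_tokenizer("basic_english")
        with torch.no_grad():
            text = torch.tensor([vocab[token]
                                 for token in ngrams_iterator(tokenizer(text), ngrams)])
            output = model(text, torch.tensor([0]))
            return output.argmax(1).item() + 1   # dataset labels are 1-based

    ex_text_str = "Jon Rahm birdied the last hole to win the golf tournament."
    vocab = train_dataset.get_vocab()
    model = model.to("cpu")                      # run the example on CPU
    print("This is a %s news" % ag_news_label[predict(ex_text_str, model, vocab, 2)])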
@@ -360,10 +360,10 @@ def predict(text, model, vocab, ngrams):
 
 ######################################################################
 # This is a Sports news
-#
+#
 
 
 ######################################################################
 # You can find the code examples displayed in this note
 # `here <https://github.com/pytorch/text/tree/master/examples/text_classification>`__.
-#
+#