- Access to the raw data as an iterator
- Build data processing pipeline to convert the raw text strings into ``torch.Tensor`` that can be used to train the model
- Shuffle and iterate the data with `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader>`__


Prerequisites:
   - | A recent 2.x version of the portalocker package must be installed before running this tutorial.
     | For example, in a Colab environment this can be done by adding the following line at the top of the script:
     | `!pip install -U portalocker>=2.0.0`
     | (More details: https://github.com/pytorch/tutorials/issues/1993)

"""

######################################################################
# Access to the raw dataset iterators
# -----------------------------------
#
# The torchtext library provides a few raw dataset iterators, which yield the raw text strings. For example, the ``AG_NEWS`` dataset iterators yield the raw data as a tuple of label and text.
#
# To access torchtext datasets, please install torchdata following the instructions at https://github.com/pytorch/data.

import torch
from torchtext.datasets import AG_NEWS

train_iter = iter(AG_NEWS(split="train"))
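
# As a quick sanity check (note this consumes one sample from the
# iterator), each item comes back as a ``(label, text)`` tuple with an
# integer label from 1 to 4 -- a minimal sketch; the exact headline
# printed depends on the downloaded data:
print(next(train_iter))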

######################################################################
# Prepare data processing pipelines
# ---------------------------------

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer("basic_english")
train_iter = AG_NEWS(split="train")


def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)


vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])
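
######################################################################
# The vocabulary converts a list of tokens into integer ids. Later
# cells refer to ``text_pipeline`` and ``label_pipeline``; here is a
# minimal sketch of both, assuming the standard definitions implied by
# how ``collate_batch`` and ``predict`` use them below (labels are
# shifted to start from 0):

text_pipeline = lambda x: vocab(tokenizer(x))  # raw string -> list of token ids
label_pipeline = lambda x: int(x) - 1  # label "1".."4" -> 0..3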

######################################################################
# Generate data batch and iterator
# --------------------------------

from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for _label, _text in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)


train_iter = AG_NEWS(split="train")
dataloader = DataLoader(
    train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch
)
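
######################################################################
# ``offsets`` marks where each entry starts in the flattened token
# tensor, which is the layout ``nn.EmbeddingBag`` expects. A toy
# illustration with made-up token ids (hypothetical values):

toy_texts = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]  # lengths 3 and 2
flat = torch.cat(toy_texts)  # tensor([1, 2, 3, 4, 5])
starts = torch.tensor([0, 3, 2][:-1]).cumsum(dim=0)  # tensor([0, 3])
print(flat, starts)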

######################################################################
# Define the model
# ----------------
#
# The model is composed of the ``nn.EmbeddingBag`` layer plus a linear layer for the classification purpose.

from torch import nn


class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        # linear classifier head over the pooled embedding
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, text, offsets):
        # EmbeddingBag averages the embeddings of each bag of token ids
        # delimited by offsets (mode="mean" is the default)
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)
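
######################################################################
# A quick shape check with toy sizes (hypothetical numbers): two bags
# of token ids yield one row of class logits each.

toy_model = TextClassificationModel(vocab_size=100, embed_dim=8, num_class=4)
toy_logits = toy_model(torch.tensor([1, 2, 3, 4, 5]), torch.tensor([0, 3]))
print(toy_logits.shape)  # torch.Size([2, 4])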

######################################################################
# Initiate an instance
# --------------------
#
# We build a model with the embedding dimension of 64. The vocab size is equal to the length of the vocabulary instance. The number of classes is equal to the number of labels.
#

train_iter = AG_NEWS(split="train")
num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
emsize = 64
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)

######################################################################
# Define functions to train the model and evaluate results.
# ----------------------------------------------------------

import time


def train(dataloader):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        # clip gradients to stabilize training with the large learning rate
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches "
                "| accuracy {:8.3f}".format(
                    epoch, idx, len(dataloader), total_acc / total_count
                )
            )
            total_acc, total_count = 0, 0
            start_time = time.time()


def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count

######################################################################
# Split the dataset and run the model
# -----------------------------------

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

# Hyperparameters
EPOCHS = 10  # epoch
LR = 5  # learning rate
BATCH_SIZE = 64  # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None
train_iter, test_iter = AG_NEWS()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(
    train_dataset, [num_train, len(train_dataset) - num_train]
)

train_dataloader = DataLoader(
    split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)
valid_dataloader = DataLoader(
    split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)
test_dataloader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    accu_val = evaluate(valid_dataloader)
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()
    else:
        total_accu = accu_val
    print("-" * 59)
    print(
        "| end of epoch {:3d} | time: {:5.2f}s | "
        "valid accuracy {:8.3f} ".format(
            epoch, time.time() - epoch_start_time, accu_val
        )
    )
    print("-" * 59)

######################################################################
# Evaluate the model with test dataset
# ------------------------------------
#
# Checking the results of the test dataset.

print("Checking the results of test dataset.")
accu_test = evaluate(test_dataloader)
print("test accuracy {:8.3f}".format(accu_test))

######################################################################
# Test on a random news
# ---------------------
#
# Use the best model so far and test a golf news.

ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"}


def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item() + 1


ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
    enduring the season’s worst weather conditions on Sunday at The \
    Open on his way to a closing 75 at Royal Portrush, which \
    considering the wind and the rain was a respectable showing. \
    Thursday’s first round at the WGC-FedEx St. Jude Invitational \
    was another story. With temperatures in the mid-80s and hardly any \
    wind, the Spaniard was 13 strokes better in a flawless round. \
    Thanks to his best putting performance on the PGA Tour, Rahm \
    finished with an 8-under 62 for a three-stroke lead, which \
    was even more impressive considering he’d never played the \
    front nine at TPC Southwind."

model = model.to("cpu")

print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])