@@ -242,10 +242,13 @@ def forward(self, x):
242
242
243
243
correct = 0
244
244
total = 0
245
+ # since we're not training, we don't need to calculate the gradients for our outputs
245
246
with torch .no_grad ():
246
247
for data in testloader :
247
248
images , labels = data
249
+ # calculate outputs by running images through the network
248
250
outputs = net (images )
251
+ # the class with the highest energy is what we choose as prediction
249
252
_ , predicted = torch .max (outputs .data , 1 )
250
253
total += labels .size (0 )
251
254
correct += (predicted == labels ).sum ().item ()
@@ -261,23 +264,28 @@ def forward(self, x):
261
264
# Hmmm, what are the classes that performed well, and the classes that did
262
265
# not perform well:
263
266
264
- class_correct = list (0. for i in range (10 ))
265
- class_total = list (0. for i in range (10 ))
267
+ # prepare to count predictions for each class
268
+ correct_pred = {classname : 0 for classname in classes }
269
+ total_pred = {classname : 0 for classname in classes }
270
+
271
+ # again no gradients needed
266
272
with torch .no_grad ():
267
273
for data in testloader :
268
- images , labels = data
269
- outputs = net (images )
270
- _ , predicted = torch .max (outputs , 1 )
271
- c = (predicted == labels ).squeeze ()
272
- for i in range (4 ):
273
- label = labels [i ]
274
- class_correct [label ] += c [i ].item ()
275
- class_total [label ] += 1
276
-
277
-
278
- for i in range (10 ):
279
- print ('Accuracy of %5s : %2d %%' % (
280
- classes [i ], 100 * class_correct [i ] / class_total [i ]))
274
+ images , labels = data
275
+ outputs = net (images )
276
+ _ , predictions = torch .max (outputs , 1 )
277
+ # collect the correct predictions for each class
278
+ for label , prediction in zip (labels , predictions ):
279
+ if label == prediction :
280
+ correct_pred [classes [label ]] += 1
281
+ total_pred [classes [label ]] += 1
282
+
283
+
284
+ # print accuracy for each class
285
+ for classname , correct_count in correct_pred .items ():
286
+ accuracy = 100 * float (correct_count ) / total_pred [classname ]
287
+ print ("Accuracy for class {:5s} is: {:.1f} %" .format (classname ,
288
+ accuracy ))
281
289
282
290
########################################################################
283
291
# Okay, so what next?
0 commit comments