@@ -246,10 +246,13 @@ def forward(self, x):
246
246
247
247
correct = 0
248
248
total = 0
249
+ # since we're not training, we don't need to calculate the gradients for our outputs
249
250
with torch .no_grad ():
250
251
for data in testloader :
251
252
images , labels = data
253
+ # calculate outputs by running images through the network
252
254
outputs = net (images )
255
+ # the class with the highest energy is what we choose as prediction
253
256
_ , predicted = torch .max (outputs .data , 1 )
254
257
total += labels .size (0 )
255
258
correct += (predicted == labels ).sum ().item ()
@@ -265,23 +268,28 @@ def forward(self, x):
265
268
# Hmmm, what are the classes that performed well, and the classes that did
266
269
# not perform well:
267
270
268
- class_correct = list (0. for i in range (10 ))
269
- class_total = list (0. for i in range (10 ))
271
+ # prepare to count predictions for each class
272
+ correct_pred = {classname : 0 for classname in classes }
273
+ total_pred = {classname : 0 for classname in classes }
274
+
275
+ # again no gradients needed
270
276
with torch .no_grad ():
271
277
for data in testloader :
272
- images , labels = data
273
- outputs = net (images )
274
- _ , predicted = torch .max (outputs , 1 )
275
- c = (predicted == labels ).squeeze ()
276
- for i in range (4 ):
277
- label = labels [i ]
278
- class_correct [label ] += c [i ].item ()
279
- class_total [label ] += 1
280
-
281
-
282
- for i in range (10 ):
283
- print ('Accuracy of %5s : %2d %%' % (
284
- classes [i ], 100 * class_correct [i ] / class_total [i ]))
278
+ images , labels = data
279
+ outputs = net (images )
280
+ _ , predictions = torch .max (outputs , 1 )
281
+ # collect the correct predictions for each class
282
+ for label , prediction in zip (labels , predictions ):
283
+ if label == prediction :
284
+ correct_pred [classes [label ]] += 1
285
+ total_pred [classes [label ]] += 1
286
+
287
+
288
+ # print accuracy for each class
289
+ for classname , correct_count in correct_pred .items ():
290
+ accuracy = 100 * float (correct_count ) / total_pred [classname ]
291
+ print ("Accuracy for class {:5s} is: {:.1f} %" .format (classname ,
292
+ accuracy ))
285
293
286
294
########################################################################
287
295
# Okay, so what next?
0 commit comments