diff --git a/beginner_source/blitz/cifar10_tutorial.py b/beginner_source/blitz/cifar10_tutorial.py index c0c3ec7bded..39a43ba3f56 100644 --- a/beginner_source/blitz/cifar10_tutorial.py +++ b/beginner_source/blitz/cifar10_tutorial.py @@ -246,10 +246,13 @@ def forward(self, x): correct = 0 total = 0 +# since we're not training, we don't need to calculate the gradients for our outputs with torch.no_grad(): for data in testloader: images, labels = data + # calculate outputs by running images through the network outputs = net(images) + # the class with the highest energy is what we choose as prediction _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() @@ -265,23 +268,28 @@ def forward(self, x): # Hmmm, what are the classes that performed well, and the classes that did # not perform well: -class_correct = list(0. for i in range(10)) -class_total = list(0. for i in range(10)) +# prepare to count predictions for each class +correct_pred = {classname: 0 for classname in classes} +total_pred = {classname: 0 for classname in classes} + +# again no gradients needed with torch.no_grad(): for data in testloader: - images, labels = data - outputs = net(images) - _, predicted = torch.max(outputs, 1) - c = (predicted == labels).squeeze() - for i in range(4): - label = labels[i] - class_correct[label] += c[i].item() - class_total[label] += 1 - - -for i in range(10): - print('Accuracy of %5s : %2d %%' % ( - classes[i], 100 * class_correct[i] / class_total[i])) + images, labels = data + outputs = net(images) + _, predictions = torch.max(outputs, 1) + # collect the correct predictions for each class + for label, prediction in zip(labels, predictions): + if label == prediction: + correct_pred[classes[label]] += 1 + total_pred[classes[label]] += 1 + + +# print accuracy for each class +for classname, correct_count in correct_pred.items(): + accuracy = 100 * float(correct_count) / total_pred[classname] + print("Accuracy for class {:5s} is: {:.1f} %".format(classname, + accuracy)) ######################################################################## # Okay, so what next?