From 84ef49820c682fc2f1bab911561ab5777eeb288d Mon Sep 17 00:00:00 2001
From: wAuner <18383180+wAuner@users.noreply.github.com>
Date: Thu, 13 Feb 2020 10:22:23 +0100
Subject: [PATCH] add code comments and test code readability++

---
 beginner_source/blitz/cifar10_tutorial.py | 38 +++++++++++++++++++---------------
 1 file changed, 23 insertions(+), 15 deletions(-)

diff --git a/beginner_source/blitz/cifar10_tutorial.py b/beginner_source/blitz/cifar10_tutorial.py
index 730bf6ac986..4a430d738ca 100644
--- a/beginner_source/blitz/cifar10_tutorial.py
+++ b/beginner_source/blitz/cifar10_tutorial.py
@@ -242,10 +242,13 @@ def forward(self, x):
 correct = 0
 total = 0
+# since we're not training, we don't need to calculate the gradients for our outputs
 with torch.no_grad():
     for data in testloader:
         images, labels = data
+        # calculate outputs by running images through the network
         outputs = net(images)
+        # the class with the highest energy is what we choose as prediction
         _, predicted = torch.max(outputs.data, 1)
         total += labels.size(0)
         correct += (predicted == labels).sum().item()
 
@@ -261,23 +264,28 @@ def forward(self, x):
 # Hmmm, what are the classes that performed well, and the classes that did
 # not perform well:
 
-class_correct = list(0. for i in range(10))
-class_total = list(0. for i in range(10))
+# prepare to count predictions for each class
+correct_pred = {classname: 0 for classname in classes}
+total_pred = {classname: 0 for classname in classes}
+
+# again no gradients needed
 with torch.no_grad():
     for data in testloader:
-        images, labels = data
-        outputs = net(images)
-        _, predicted = torch.max(outputs, 1)
-        c = (predicted == labels).squeeze()
-        for i in range(4):
-            label = labels[i]
-            class_correct[label] += c[i].item()
-            class_total[label] += 1
-
-
-for i in range(10):
-    print('Accuracy of %5s : %2d %%' % (
-        classes[i], 100 * class_correct[i] / class_total[i]))
+        images, labels = data
+        outputs = net(images)
+        _, predictions = torch.max(outputs, 1)
+        # collect the correct predictions for each class
+        for label, prediction in zip(labels, predictions):
+            if label == prediction:
+                correct_pred[classes[label]] += 1
+            total_pred[classes[label]] += 1
+
+
+# print accuracy for each class
+for classname, correct_count in correct_pred.items():
+    accuracy = 100 * float(correct_count) / total_pred[classname]
+    print("Accuracy for class {:5s} is: {:.1f} %".format(classname,
+                                                         accuracy))
 
 ########################################################################
 # Okay, so what next?
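
Review note: the "highest energy" comment added in the first hunk refers to
torch.max along dim 1, which returns a (values, indices) pair; the indices are
the predicted class labels. A minimal sketch with made-up scores, not part of
the patch, illustrating the behavior:

    import torch

    # two sample outputs over three classes (made-up numbers)
    outputs = torch.tensor([[0.1, 2.0, -1.0],
                            [1.5, 0.3,  0.2]])
    # torch.max over dim 1 returns (max values, argmax indices);
    # the indices are the predicted class labels
    _, predicted = torch.max(outputs, 1)
    print(predicted)  # tensor([1, 0])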
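
Review note: the second hunk replaces the index-based class_correct/class_total
lists, and the hard-coded batch size of 4, with dictionaries keyed by class
name, iterating over each batch with zip. A minimal sketch of the counting
pattern with synthetic tensors (the two-class tuple and the label/prediction
values are made up for illustration):

    import torch

    classes = ('cat', 'dog')  # made-up class names for the sketch
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}

    labels = torch.tensor([0, 1, 1, 0])       # synthetic ground truth
    predictions = torch.tensor([0, 1, 0, 0])  # synthetic predictions
    for label, prediction in zip(labels, predictions):
        if label == prediction:
            correct_pred[classes[label]] += 1
        total_pred[classes[label]] += 1

    print(correct_pred)  # {'cat': 2, 'dog': 1}
    print(total_pred)    # {'cat': 2, 'dog': 2}

Unlike the old loop, this counts every sample in the batch regardless of batch
size, and the final report reads by class name rather than by index.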