Skip to content

Commit 84ef498

Browse files
committed
add code comments and test code readability++
1 parent e8055d7 commit 84ef498

File tree

1 file changed

+23
-15
lines changed

1 file changed

+23
-15
lines changed

beginner_source/blitz/cifar10_tutorial.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,13 @@ def forward(self, x):
242242

243243
correct = 0
244244
total = 0
245+
# since we're not training, we don't need to calculate the gradients for our outputs
245246
with torch.no_grad():
246247
for data in testloader:
247248
images, labels = data
249+
# calculate outputs by running images through the network
248250
outputs = net(images)
251+
# the class with the highest energy is what we choose as prediction
249252
_, predicted = torch.max(outputs.data, 1)
250253
total += labels.size(0)
251254
correct += (predicted == labels).sum().item()
@@ -261,23 +264,28 @@ def forward(self, x):
261264
# Hmmm, what are the classes that performed well, and the classes that did
262265
# not perform well:
263266

264-
class_correct = list(0. for i in range(10))
265-
class_total = list(0. for i in range(10))
267+
# prepare to count predictions for each class
268+
correct_pred = {classname: 0 for classname in classes}
269+
total_pred = {classname: 0 for classname in classes}
270+
271+
# again no gradients needed
266272
with torch.no_grad():
267273
for data in testloader:
268-
images, labels = data
269-
outputs = net(images)
270-
_, predicted = torch.max(outputs, 1)
271-
c = (predicted == labels).squeeze()
272-
for i in range(4):
273-
label = labels[i]
274-
class_correct[label] += c[i].item()
275-
class_total[label] += 1
276-
277-
278-
for i in range(10):
279-
print('Accuracy of %5s : %2d %%' % (
280-
classes[i], 100 * class_correct[i] / class_total[i]))
274+
images, labels = data
275+
outputs = net(images)
276+
_, predictions = torch.max(outputs, 1)
277+
# collect the correct predictions for each class
278+
for label, prediction in zip(labels, predictions):
279+
if label == prediction:
280+
correct_pred[classes[label]] += 1
281+
total_pred[classes[label]] += 1
282+
283+
284+
# print accuracy for each class
285+
for classname, correct_count in correct_pred.items():
286+
accuracy = 100 * float(correct_count) / total_pred[classname]
287+
print("Accuracy for class {:5s} is: {:.1f} %".format(classname,
288+
accuracy))
281289

282290
########################################################################
283291
# Okay, so what next?

0 commit comments

Comments
 (0)