Skip to content

Commit e18d233

Browse files
wAuner and holly1238 authored
add code comments and test code readability++ (#849)
Co-authored-by: holly1238 <77758406+holly1238@users.noreply.github.com>
1 parent 8dfdda4 commit e18d233

File tree

1 file changed

+23
-15
lines changed

1 file changed

+23
-15
lines changed

beginner_source/blitz/cifar10_tutorial.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -246,10 +246,13 @@ def forward(self, x):
246246

247247
correct = 0
248248
total = 0
249+
# since we're not training, we don't need to calculate the gradients for our outputs
249250
with torch.no_grad():
250251
for data in testloader:
251252
images, labels = data
253+
# calculate outputs by running images through the network
252254
outputs = net(images)
255+
# the class with the highest energy is what we choose as prediction
253256
_, predicted = torch.max(outputs.data, 1)
254257
total += labels.size(0)
255258
correct += (predicted == labels).sum().item()
@@ -265,23 +268,28 @@ def forward(self, x):
265268
# Hmmm, what are the classes that performed well, and the classes that did
266269
# not perform well:
267270

268-
class_correct = list(0. for i in range(10))
269-
class_total = list(0. for i in range(10))
271+
# prepare to count predictions for each class
272+
correct_pred = {classname: 0 for classname in classes}
273+
total_pred = {classname: 0 for classname in classes}
274+
275+
# again no gradients needed
270276
with torch.no_grad():
271277
for data in testloader:
272-
images, labels = data
273-
outputs = net(images)
274-
_, predicted = torch.max(outputs, 1)
275-
c = (predicted == labels).squeeze()
276-
for i in range(4):
277-
label = labels[i]
278-
class_correct[label] += c[i].item()
279-
class_total[label] += 1
280-
281-
282-
for i in range(10):
283-
print('Accuracy of %5s : %2d %%' % (
284-
classes[i], 100 * class_correct[i] / class_total[i]))
278+
images, labels = data
279+
outputs = net(images)
280+
_, predictions = torch.max(outputs, 1)
281+
# collect the correct predictions for each class
282+
for label, prediction in zip(labels, predictions):
283+
if label == prediction:
284+
correct_pred[classes[label]] += 1
285+
total_pred[classes[label]] += 1
286+
287+
288+
# print accuracy for each class
289+
for classname, correct_count in correct_pred.items():
290+
accuracy = 100 * float(correct_count) / total_pred[classname]
291+
print("Accuracy for class {:5s} is: {:.1f} %".format(classname,
292+
accuracy))
285293

286294
########################################################################
287295
# Okay, so what next?

0 commit comments

Comments (0)