Skip to content

Commit d763c40

Browse files
hwong557holly1238
andauthored
Fix tensorboard pr curve example. (#1469)
add_pr_curve should accept 1) model probabilities, and 2) ground truth values, namely if a sample is or isn't the specified class index. Previously, the add_pr_curve was coded to accept model probabilities and model predictions. Co-authored-by: holly1238 <77758406+holly1238@users.noreply.github.com>
1 parent faad264 commit d763c40

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

_static/img/tensorboard_pr_curves.png

-190 KB
Loading

intermediate_source/tensorboard_tutorial.rst

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Visualizing Models, Data, and Training with TensorBoard
22
=======================================================
33

4-
In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_,
4+
In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_,
55
we show you how to load in data,
66
feed it through a model we define as a subclass of ``nn.Module``,
77
train this model on training data, and test it on test data.
@@ -348,7 +348,7 @@ In the prior tutorial, we looked at per-class accuracy once the model
348348
had been trained; here, we'll use TensorBoard to plot precision-recall
349349
curves (good explanation
350350
`here <https://www.scikit-yb.org/en/latest/api/classifier/prcurve.html>`__)
351-
for each class.
351+
for each class.
352352

353353
6. Assessing trained models with TensorBoard
354354
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -359,38 +359,37 @@ for each class.
359359
# 2. gets the preds in a test_size Tensor
360360
# takes ~10 seconds to run
361361
class_probs = []
362-
class_preds = []
362+
class_label = []
363363
with torch.no_grad():
364364
for data in testloader:
365365
images, labels = data
366366
output = net(images)
367367
class_probs_batch = [F.softmax(el, dim=0) for el in output]
368-
_, class_preds_batch = torch.max(output, 1)
369368
370369
class_probs.append(class_probs_batch)
371-
class_preds.append(class_preds_batch)
370+
class_label.append(labels)
372371
373372
test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
374-
test_preds = torch.cat(class_preds)
373+
test_label = torch.cat(class_label)
375374
376375
# helper function
377-
def add_pr_curve_tensorboard(class_index, test_probs, test_preds, global_step=0):
376+
def add_pr_curve_tensorboard(class_index, test_probs, test_label, global_step=0):
378377
'''
379378
Takes in a "class_index" from 0 to 9 and plots the corresponding
380379
precision-recall curve
381380
'''
382-
tensorboard_preds = test_preds == class_index
381+
tensorboard_truth = test_label == class_index
383382
tensorboard_probs = test_probs[:, class_index]
384383
385384
writer.add_pr_curve(classes[class_index],
386-
tensorboard_preds,
385+
tensorboard_truth,
387386
tensorboard_probs,
388387
global_step=global_step)
389388
writer.close()
390389
391390
# plot all the pr curves
392391
for i in range(len(classes)):
393-
add_pr_curve_tensorboard(i, test_probs, test_preds)
392+
add_pr_curve_tensorboard(i, test_probs, test_label)
394393
395394
You will now see a "PR Curves" tab that contains the precision-recall
396395
curves for each class. Go ahead and poke around; you'll see that on

0 commit comments

Comments
 (0)