1
1
Visualizing Models, Data, and Training with TensorBoard
2
2
=======================================================
3
3
4
- In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html >`_,
4
+ In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html >`_,
5
5
we show you how to load in data,
6
6
feed it through a model we define as a subclass of ``nn.Module ``,
7
7
train this model on training data, and test it on test data.
@@ -348,7 +348,7 @@ In the prior tutorial, we looked at per-class accuracy once the model
348
348
had been trained; here, we'll use TensorBoard to plot precision-recall
349
349
curves (good explanation
350
350
`here <https://www.scikit-yb.org/en/latest/api/classifier/prcurve.html >`__)
351
- for each class.
351
+ for each class.
352
352
353
353
6. Assessing trained models with TensorBoard
354
354
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -359,38 +359,37 @@ for each class.
359
359
# 2. gets the preds in a test_size Tensor
360
360
# takes ~10 seconds to run
361
361
class_probs = []
362
- class_preds = []
362
+ class_label = []
363
363
with torch.no_grad():
364
364
for data in testloader:
365
365
images, labels = data
366
366
output = net(images)
367
367
class_probs_batch = [F.softmax(el, dim = 0 ) for el in output]
368
- _, class_preds_batch = torch.max(output, 1 )
369
368
370
369
class_probs.append(class_probs_batch)
371
- class_preds .append(class_preds_batch )
370
+ class_label .append(labels )
372
371
373
372
test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
374
- test_preds = torch.cat(class_preds )
373
+ test_label = torch.cat(class_label )
375
374
376
375
# helper function
377
- def add_pr_curve_tensorboard (class_index , test_probs , test_preds , global_step = 0 ):
376
+ def add_pr_curve_tensorboard (class_index , test_probs , test_label , global_step = 0 ):
378
377
'''
379
378
Takes in a "class_index" from 0 to 9 and plots the corresponding
380
379
precision-recall curve
381
380
'''
382
- tensorboard_preds = test_preds == class_index
381
+ tensorboard_truth = test_label == class_index
383
382
tensorboard_probs = test_probs[:, class_index]
384
383
385
384
writer.add_pr_curve(classes[class_index],
386
- tensorboard_preds ,
385
+ tensorboard_truth ,
387
386
tensorboard_probs,
388
387
global_step = global_step)
389
388
writer.close()
390
389
391
390
# plot all the pr curves
392
391
for i in range (len (classes)):
393
- add_pr_curve_tensorboard(i, test_probs, test_preds )
392
+ add_pr_curve_tensorboard(i, test_probs, test_label )
394
393
395
394
You will now see a "PR Curves" tab that contains the precision-recall
396
395
curves for each class. Go ahead and poke around; you'll see that on
0 commit comments