Skip to content

Commit ee282de

Browse files
committed
interactive plt in TL tutorial
1 parent 11259a8 commit ee282de

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

beginner_source/transfer_learning_tutorial.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
import copy
4646
import os
4747

48+
plt.ion() # interactive mode
49+
4850
######################################################################
4951
# Load Data
5052
# ---------
@@ -101,13 +103,16 @@
101103
# Let's visualize a few training images so as to understand the data
102104
# augmentations.
103105

104-
def imshow(inp):
106+
def imshow(inp, title=None):
105107
"""Imshow for Tensor."""
106108
inp = inp.numpy().transpose((1, 2, 0))
107109
mean = np.array([0.485, 0.456, 0.406])
108110
std = np.array([0.229, 0.224, 0.225])
109111
inp = std * inp + mean
110112
plt.imshow(inp)
113+
if title is not None:
114+
plt.title(title)
115+
plt.pause(0.001) # pause a bit so that plots are updated
111116

112117

113118
# Get a batch of training data
@@ -116,9 +121,7 @@ def imshow(inp):
116121
# Make a grid from batch
117122
out = torchvision.utils.make_grid(inputs)
118123

119-
imshow(out)
120-
plt.title([dset_classes[x] for x in classes])
121-
plt.show()
124+
imshow(out, title=[dset_classes[x] for x in classes])
122125

123126

124127
######################################################################
@@ -222,14 +225,12 @@ def visualize_model(model, num_images=5):
222225
else:
223226
inputs, labels = Variable(inputs), Variable(labels)
224227

225-
226228
outputs = model(inputs)
227229
_, preds = torch.max(outputs.data, 1)
228-
230+
229231
plt.figure()
230-
imshow(inputs.cpu().data[0])
231-
plt.title('pred: {}'.format(dset_classes[labels.data[0]]))
232-
plt.show()
232+
imshow(inputs.cpu().data[0],
233+
title='pred: {}'.format(dset_classes[labels.data[0]]))
233234

234235
if i == num_images - 1:
235236
break
@@ -337,3 +338,6 @@ def optim_scheduler_conv(model, epoch, init_lr=0.001, lr_decay_epoch=7):
337338
#
338339

339340
visualize_model(model)
341+
342+
plt.ioff()
343+
plt.show()

0 commit comments

Comments
 (0)