45
45
import copy
46
46
import os
47
47
48
+ plt .ion () # interactive mode
49
+
48
50
######################################################################
49
51
# Load Data
50
52
# ---------
101
103
# Let's visualize a few training images so as to understand the data
102
104
# augmentations.
103
105
104
- def imshow (inp ):
106
+ def imshow (inp , title = None ):
105
107
"""Imshow for Tensor."""
106
108
inp = inp .numpy ().transpose ((1 , 2 , 0 ))
107
109
mean = np .array ([0.485 , 0.456 , 0.406 ])
108
110
std = np .array ([0.229 , 0.224 , 0.225 ])
109
111
inp = std * inp + mean
110
112
plt .imshow (inp )
113
+ if title is not None :
114
+ plt .title (title )
115
+ plt .pause (0.001 ) # pause a bit so that plots are updated
111
116
112
117
113
118
# Get a batch of training data
@@ -116,9 +121,7 @@ def imshow(inp):
116
121
# Make a grid from batch
117
122
out = torchvision .utils .make_grid (inputs )
118
123
119
- imshow (out )
120
- plt .title ([dset_classes [x ] for x in classes ])
121
- plt .show ()
124
+ imshow (out , title = [dset_classes [x ] for x in classes ])
122
125
123
126
124
127
######################################################################
@@ -222,14 +225,12 @@ def visualize_model(model, num_images=5):
222
225
else :
223
226
inputs , labels = Variable (inputs ), Variable (labels )
224
227
225
-
226
228
outputs = model (inputs )
227
229
_ , preds = torch .max (outputs .data , 1 )
228
-
230
+
229
231
plt .figure ()
230
- imshow (inputs .cpu ().data [0 ])
231
- plt .title ('pred: {}' .format (dset_classes [labels .data [0 ]]))
232
- plt .show ()
232
+ imshow (inputs .cpu ().data [0 ],
233
+ title = 'pred: {}' .format (dset_classes [labels .data [0 ]]))
233
234
234
235
if i == num_images - 1 :
235
236
break
@@ -337,3 +338,6 @@ def optim_scheduler_conv(model, epoch, init_lr=0.001, lr_decay_epoch=7):
337
338
#
338
339
339
340
visualize_model (model )
341
+
342
+ plt .ioff ()
343
+ plt .show ()
0 commit comments