@@ -223,21 +223,22 @@ def image_loader(image_name):
223
223
224
224
unloader = transforms .ToPILImage () # reconvert into PIL image
225
225
226
+ plt .ion ()
226
227
227
- def imshow (tensor ):
228
+ def imshow (tensor , title = None ):
228
229
image = tensor .clone ().cpu () # we clone the tensor to not do changes on it
229
230
image = image .view (3 , imsize , imsize ) # remove the fake batch dimension
230
231
image = unloader (image )
231
232
plt .imshow (image )
233
+ if title is not None :
234
+ plt .title (title )
232
235
233
236
234
- fig = plt .figure ()
237
+ plt .figure ()
238
+ imshow (style_img .data , title = 'Style Image' )
235
239
236
- plt .subplot (221 )
237
- imshow (style_img .data )
238
-
239
- plt .subplot (222 )
240
- imshow (content_img .data )
240
+ plt .figure ()
241
+ imshow (content_img .data , title = 'Content Image' )
241
242
242
243
243
244
######################################################################
@@ -497,8 +498,8 @@ def get_style_model_and_losses(cnn, style_img, content_img,
497
498
# input_img = Variable(torch.randn(content_img.data.size())).type(dtype)
498
499
499
500
# add the original input image to the figure:
500
- plt .subplot ( 223 )
501
- imshow (input_img .data )
501
+ plt .figure ( )
502
+ imshow (input_img .data , title = 'Input Image' )
502
503
503
504
504
505
######################################################################
@@ -547,10 +548,12 @@ def get_input_param_optimizer(input_img):
547
548
def run_style_transfer (cnn , content_img , style_img , input_img , num_steps = 300 ,
548
549
style_weight = 1000 , content_weight = 1 ):
549
550
"""Run the style transfer."""
551
+ print ('Building the style transfer model..' )
550
552
model , style_losses , content_losses = get_style_model_and_losses (cnn ,
551
553
style_img , content_img , style_weight , content_weight )
552
554
input_param , optimizer = get_input_param_optimizer (input_img )
553
555
556
+ print ('Optimizing..' )
554
557
run = [0 ]
555
558
while run [0 ] <= num_steps :
556
559
@@ -589,6 +592,8 @@ def closure():
589
592
590
593
output = run_style_transfer (cnn , content_img , style_img , input_img )
591
594
592
- plt .subplot (224 )
593
- imshow (output )
595
+ plt .figure ()
596
+ imshow (output , title = 'Output Image' )
597
+
598
+ plt .ioff ()
594
599
plt .show ()
0 commit comments