Skip to content

Commit 9c4b2c1

Browse files
committed
matplotlib ion in nerual style
1 parent 653f902 commit 9c4b2c1

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

advanced_source/neural_style_tutorial.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -223,21 +223,22 @@ def image_loader(image_name):
223223

224224
unloader = transforms.ToPILImage() # reconvert into PIL image
225225

226+
plt.ion()
226227

227-
def imshow(tensor):
228+
def imshow(tensor, title=None):
228229
image = tensor.clone().cpu() # we clone the tensor to not do changes on it
229230
image = image.view(3, imsize, imsize) # remove the fake batch dimension
230231
image = unloader(image)
231232
plt.imshow(image)
233+
if title is not None:
234+
plt.title(title)
232235

233236

234-
fig = plt.figure()
237+
plt.figure()
238+
imshow(style_img.data, title='Style Image')
235239

236-
plt.subplot(221)
237-
imshow(style_img.data)
238-
239-
plt.subplot(222)
240-
imshow(content_img.data)
240+
plt.figure()
241+
imshow(content_img.data, title='Content Image')
241242

242243

243244
######################################################################
@@ -497,8 +498,8 @@ def get_style_model_and_losses(cnn, style_img, content_img,
497498
# input_img = Variable(torch.randn(content_img.data.size())).type(dtype)
498499

499500
# add the original input image to the figure:
500-
plt.subplot(223)
501-
imshow(input_img.data)
501+
plt.figure()
502+
imshow(input_img.data, title='Input Image')
502503

503504

504505
######################################################################
@@ -547,10 +548,12 @@ def get_input_param_optimizer(input_img):
547548
def run_style_transfer(cnn, content_img, style_img, input_img, num_steps=300,
548549
style_weight=1000, content_weight=1):
549550
"""Run the style transfer."""
551+
print('Building the style transfer model..')
550552
model, style_losses, content_losses = get_style_model_and_losses(cnn,
551553
style_img, content_img, style_weight, content_weight)
552554
input_param, optimizer = get_input_param_optimizer(input_img)
553555

556+
print('Optimizing..')
554557
run = [0]
555558
while run[0] <= num_steps:
556559

@@ -589,6 +592,8 @@ def closure():
589592

590593
output = run_style_transfer(cnn, content_img, style_img, input_img)
591594

592-
plt.subplot(224)
593-
imshow(output)
595+
plt.figure()
596+
imshow(output, title='Output Image')
597+
598+
plt.ioff()
594599
plt.show()

0 commit comments

Comments
 (0)