Skip to content

Commit 250c671

Browse files
authored
Merge pull request #58 from darkstar112358/neural-style-patch
Use __call__ instead of forward
2 parents 1c780e8 + 7d624d2 commit 250c671

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

advanced_source/neural_style_tutorial.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def __init__(self, target, weight):
286286
self.criterion = nn.MSELoss()
287287

288288
def forward(self, input):
289-
self.loss = self.criterion.forward(input * self.weight, self.target)
289+
self.loss = self.criterion(input * self.weight, self.target)
290290
self.output = input
291291
return self.output
292292

@@ -357,9 +357,9 @@ def __init__(self, target, weight):
357357

358358
def forward(self, input):
359359
self.output = input.clone()
360-
self.G = self.gram.forward(input)
360+
self.G = self.gram(input)
361361
self.G.mul_(self.weight)
362-
self.loss = self.criterion.forward(self.G, self.target)
362+
self.loss = self.criterion(self.G, self.target)
363363
return self.output
364364

365365
def backward(self, retain_variables=True):
@@ -430,15 +430,15 @@ def get_style_model_and_losses(cnn, style_img, content_img,
430430

431431
if name in content_layers:
432432
# add content loss:
433-
target = model.forward(content_img).clone()
433+
target = model(content_img).clone()
434434
content_loss = ContentLoss(target, content_weight)
435435
model.add_module("content_loss_" + str(i), content_loss)
436436
content_losses.append(content_loss)
437437

438438
if name in style_layers:
439439
# add style loss:
440-
target_feature = model.forward(style_img).clone()
441-
target_feature_gram = gram.forward(target_feature)
440+
target_feature = model(style_img).clone()
441+
target_feature_gram = gram(target_feature)
442442
style_loss = StyleLoss(target_feature_gram, style_weight)
443443
model.add_module("style_loss_" + str(i), style_loss)
444444
style_losses.append(style_loss)
@@ -449,15 +449,15 @@ def get_style_model_and_losses(cnn, style_img, content_img,
449449

450450
if name in content_layers:
451451
# add content loss:
452-
target = model.forward(content_img).clone()
452+
target = model(content_img).clone()
453453
content_loss = ContentLoss(target, content_weight)
454454
model.add_module("content_loss_" + str(i), content_loss)
455455
content_losses.append(content_loss)
456456

457457
if name in style_layers:
458458
# add style loss:
459-
target_feature = model.forward(style_img).clone()
460-
target_feature_gram = gram.forward(target_feature)
459+
target_feature = model(style_img).clone()
460+
target_feature_gram = gram(target_feature)
461461
style_loss = StyleLoss(target_feature_gram, style_weight)
462462
model.add_module("style_loss_" + str(i), style_loss)
463463
style_losses.append(style_loss)
@@ -564,7 +564,7 @@ def closure():
564564
input_param.data.clamp_(0, 1)
565565

566566
optimizer.zero_grad()
567-
model.forward(input_param)
567+
model(input_param)
568568
style_score = 0
569569
content_score = 0
570570

0 commit comments

Comments
 (0)