Skip to content

Commit 11259a8

Browse files
committed
fix training mode bug in TL tutorial
1 parent 250c671 commit 11259a8

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

beginner_source/transfer_learning_tutorial.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ def train_model(model, criterion, optim_scheduler, num_epochs=25):
151151
for phase in ['train', 'val']:
152152
if phase == 'train':
153153
optimizer = optim_scheduler(model, epoch)
154+
model.train(True) # Set model to training mode
155+
else:
156+
model.train(False) # Set model to evaluate mode
154157

155158
running_loss = 0.0
156159
running_corrects = 0

0 commit comments

Comments
 (0)