diff --git a/beginner_source/transfer_learning_tutorial.py b/beginner_source/transfer_learning_tutorial.py index 116de36587b..de15552282f 100644 --- a/beginner_source/transfer_learning_tutorial.py +++ b/beginner_source/transfer_learning_tutorial.py @@ -297,7 +297,8 @@ def optim_scheduler_ft(model, epoch, init_lr=0.001, lr_decay_epoch=7): param.requires_grad = False # Parameters of newly constructed modules have requires_grad=True by default -model.fc = nn.Linear(512, 100) +num_ftrs = model.fc.in_features +model.fc = nn.Linear(num_ftrs, 2) if use_gpu: model = model.cuda()