
Commit 6e8d0f7

Merge pull request #62 from chsasank/master
Update transfer learning tutorial
2 parents 8f60a6e + 238dcd4 commit 6e8d0f7

File tree

1 file changed: +57, -54 lines

beginner_source/transfer_learning_tutorial.py

Lines changed: 57 additions & 54 deletions
@@ -33,6 +33,8 @@
 # License: BSD
 # Author: Sasank Chilamkurthy

+from __future__ import print_function, division
+
 import torch
 import torch.nn as nn
 import torch.optim as optim
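
The newly added ``from __future__ import print_function, division`` keeps the script behaving identically under Python 2 and Python 3: ``print`` becomes a function and ``/`` always performs true division. A minimal illustration (not part of the diff):

    from __future__ import print_function, division

    print(7 / 2)    # 3.5 under both Python 2 and 3 (true division)
    print(7 // 2)   # 3, floor division -- the operator used by the lr decay schedule below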
@@ -134,13 +136,11 @@ def imshow(inp, title=None):
 # - Scheduling the learning rate
 # - Saving (deep copying) the best model
 #
-# In the following, ``optim_scheduler`` is a function which returns an ``optim.SGD``
-# object when called as ``optim_scheduler(model, epoch)``. This is useful
-# when we want to change the learning rate or restrict the parameters we
-# want to optimize.
-#
+# In the following, the ``lr_scheduler(optimizer, epoch)`` parameter is a
+# function which modifies ``optimizer`` so that the learning rate is
+# changed according to the desired schedule.

-def train_model(model, criterion, optim_scheduler, num_epochs=25):
+def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=25):
     since = time.time()

     best_model = model
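
The new ``train_model`` signature expects ``lr_scheduler`` to be a callable of the form ``lr_scheduler(optimizer, epoch)`` that returns the (possibly modified) optimizer. A minimal sketch of a callable satisfying that contract (the name ``constant_lr_scheduler`` is illustrative; the tutorial's real scheduler, ``exp_lr_scheduler``, is added further down):

    def constant_lr_scheduler(optimizer, epoch):
        # Hypothetical, simplest possible scheduler: leave the learning rate
        # untouched and hand the optimizer back to train_model.
        return optimizer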
@@ -153,7 +153,7 @@ def train_model(model, criterion, optim_scheduler, num_epochs=25):
         # Each epoch has a training and validation phase
         for phase in ['train', 'val']:
             if phase == 'train':
-                optimizer = optim_scheduler(model, epoch)
+                optimizer = lr_scheduler(optimizer, epoch)
                 model.train(True)  # Set model to training mode
             else:
                 model.train(False)  # Set model to evaluate mode
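
Note that ``model.train(True)`` / ``model.train(False)`` only switches layers such as batch normalization and dropout between training and evaluation behaviour; it does not enable or disable gradient computation. A small illustrative sketch (not from the diff):

    import torch.nn as nn

    drop = nn.Dropout(p=0.5)
    drop.train(True)    # training mode: activations are randomly zeroed
    drop.train(False)   # evaluation mode: dropout acts as a no-op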
@@ -209,6 +209,24 @@ def train_model(model, criterion, optim_scheduler, num_epochs=25):
     print('Best val Acc: {:4f}'.format(best_acc))
     return best_model

+######################################################################
+# Learning rate scheduler
+# ^^^^^^^^^^^^^^^^^^^^^^^
+# Let's create our learning rate scheduler. We will exponentially
+# decrease the learning rate once every few epochs.
+
+def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=7):
+    """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs."""
+    lr = init_lr * (0.1**(epoch // lr_decay_epoch))
+
+    if epoch % lr_decay_epoch == 0:
+        print('LR is set to {}'.format(lr))
+
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+    return optimizer
+

 ######################################################################
 # Visualizing the model predictions
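
With the defaults ``init_lr=0.001`` and ``lr_decay_epoch=7``, the ``exp_lr_scheduler`` added above yields 0.001 for epochs 0-6, 1e-4 for epochs 7-13, 1e-5 for epochs 14-20, and so on. A standalone check of the formula (no optimizer needed):

    init_lr, lr_decay_epoch = 0.001, 7
    for epoch in [0, 6, 7, 13, 14]:
        lr = init_lr * (0.1 ** (epoch // lr_decay_epoch))
        print(epoch, lr)   # prints 0.001, 0.001, 0.0001, 0.0001, 1e-05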
@@ -217,7 +235,10 @@ def train_model(model, criterion, optim_scheduler, num_epochs=25):
 # Generic function to display predictions for a few images
 #

-def visualize_model(model, num_images=5):
+def visualize_model(model, num_images=6):
+    images_so_far = 0
+    fig = plt.figure()
+
     for i, data in enumerate(dset_loaders['val']):
         inputs, labels = data
         if use_gpu:
@@ -228,45 +249,34 @@ def visualize_model(model, num_images=5):
         outputs = model(inputs)
         _, preds = torch.max(outputs.data, 1)

-        plt.figure()
-        imshow(inputs.cpu().data[0],
-               title='pred: {}'.format(dset_classes[labels.data[0]]))
-
-        if i == num_images - 1:
-            break
+        for j in range(inputs.size()[0]):
+            images_so_far += 1
+            ax = plt.subplot(num_images//2, 2, images_so_far)
+            ax.axis('off')
+            ax.set_title('predicted: {}'.format(dset_classes[labels.data[j]]))
+            imshow(inputs.cpu().data[j])

+            if images_so_far == num_images:
+                return

 ######################################################################
 # Finetuning the convnet
 # ----------------------
 #
-# First, let's create our learning rate scheduler. We will exponentially
-# decrease the learning rate once every few epochs.
-#
-
-def optim_scheduler_ft(model, epoch, init_lr=0.001, lr_decay_epoch=7):
-    lr = init_lr * (0.1**(epoch // lr_decay_epoch))
-
-    if epoch % lr_decay_epoch == 0:
-        print('LR is set to {}'.format(lr))
-
-    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
-    return optimizer
-
-
-######################################################################
 # Load a pretrained model and reset final fully connected layer.
 #

-model = models.resnet18(pretrained=True)
-num_ftrs = model.fc.in_features
-model.fc = nn.Linear(num_ftrs, 2)
+model_ft = models.resnet18(pretrained=True)
+num_ftrs = model_ft.fc.in_features
+model_ft.fc = nn.Linear(num_ftrs, 2)

 if use_gpu:
-    model = model.cuda()
+    model_ft = model_ft.cuda()

 criterion = nn.CrossEntropyLoss()

+# Observe that all parameters are being optimized
+optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

 ######################################################################
 # Train and evaluate
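
ResNet-18's final fully connected layer takes 512 input features, so the replaced head ``nn.Linear(num_ftrs, 2)`` maps those 512 features onto the dataset's two classes. An optional sanity check (illustrative only, using a throwaway model ``m``):

    from torchvision import models
    import torch.nn as nn

    m = models.resnet18(pretrained=True)
    print(m.fc.in_features)                  # 512 for ResNet-18
    m.fc = nn.Linear(m.fc.in_features, 2)    # new head: 512 -> 2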
@@ -276,12 +286,13 @@ def optim_scheduler_ft(model, epoch, init_lr=0.001, lr_decay_epoch=7):
 # minute.
 #

-model = train_model(model, criterion, optim_scheduler_ft, num_epochs=25)
+model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
+                       num_epochs=25)

 ######################################################################
 #

-visualize_model(model)
+visualize_model(model_ft)


 ######################################################################
@@ -296,31 +307,22 @@ def optim_scheduler_ft(model, epoch, init_lr=0.001, lr_decay_epoch=7):
 # `here <http://pytorch.org/docs/notes/autograd.html#excluding-subgraphs-from-backward>`__.
 #

-model = torchvision.models.resnet18(pretrained=True)
-for param in model.parameters():
+model_conv = torchvision.models.resnet18(pretrained=True)
+for param in model_conv.parameters():
     param.requires_grad = False

 # Parameters of newly constructed modules have requires_grad=True by default
-num_ftrs = model.fc.in_features
-model.fc = nn.Linear(num_ftrs, 2)
+num_ftrs = model_conv.fc.in_features
+model_conv.fc = nn.Linear(num_ftrs, 2)

 if use_gpu:
-    model = model.cuda()
+    model_conv = model_conv.cuda()

 criterion = nn.CrossEntropyLoss()
-######################################################################
-# Let's write ``optim_scheduler``. We will use previous lr scheduler. Also
-# we need to optimize only the parameters of final FC layer.
-#
-
-def optim_scheduler_conv(model, epoch, init_lr=0.001, lr_decay_epoch=7):
-    lr = init_lr * (0.1**(epoch // lr_decay_epoch))

-    if epoch % lr_decay_epoch == 0:
-        print('LR is set to {}'.format(lr))
-
-    optimizer = optim.SGD(model.fc.parameters(), lr=lr, momentum=0.9)
-    return optimizer
+# Observe that only parameters of the final layer are being optimized, as
+# opposed to before.
+optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)


 ######################################################################
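
After the loop that sets ``requires_grad = False``, every pre-trained parameter is frozen; only the freshly constructed ``fc`` layer (whose parameters have ``requires_grad=True`` by default) will receive gradients, which is why only ``model_conv.fc.parameters()`` is passed to the optimizer. An optional check (illustrative, not part of the tutorial):

    # Backbone parameters were frozen above; the new head was not.
    print(model_conv.conv1.weight.requires_grad)   # False -- frozen pre-trained layer
    print(model_conv.fc.weight.requires_grad)      # True  -- new layer, still trained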
@@ -332,12 +334,13 @@ def optim_scheduler_conv(model, epoch, init_lr=0.001, lr_decay_epoch=7):
 # network. However, forward does need to be computed.
 #

-model = train_model(model, criterion, optim_scheduler_conv)
+model_conv = train_model(model_conv, criterion, optimizer_conv,
+                         exp_lr_scheduler, num_epochs=25)

 ######################################################################
 #

-visualize_model(model)
+visualize_model(model_conv)

 plt.ioff()
 plt.show()
