 # License: BSD
 # Author: Sasank Chilamkurthy

+from __future__ import print_function, division
+
 import torch
 import torch.nn as nn
 import torch.optim as optim
@@ -134,13 +136,11 @@ def imshow(inp, title=None):
 # - Scheduling the learning rate
 # - Saving (deep copying) the best model
 #
-# In the following, ``optim_scheduler`` is a function which returns an ``optim.SGD``
-# object when called as ``optim_scheduler(model, epoch)``. This is useful
-# when we want to change the learning rate or restrict the parameters we
-# want to optimize.
-#
+# In the following, the parameter ``lr_scheduler(optimizer, epoch)``
+# is a function which modifies ``optimizer`` so that the learning
+# rate is changed according to the desired schedule.

-def train_model(model, criterion, optim_scheduler, num_epochs=25):
+def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=25):
     since = time.time()

     best_model = model
@@ -153,7 +153,7 @@ def train_model(model, criterion, optim_scheduler, num_epochs=25):
         # Each epoch has a training and validation phase
         for phase in ['train', 'val']:
             if phase == 'train':
-                optimizer = optim_scheduler(model, epoch)
+                optimizer = lr_scheduler(optimizer, epoch)
                 model.train(True)  # Set model to training mode
             else:
                 model.train(False)  # Set model to evaluate mode
@@ -209,6 +209,24 @@ def train_model(model, criterion, optim_scheduler, num_epochs=25):
     print('Best val Acc: {:4f}'.format(best_acc))
     return best_model

+######################################################################
+# Learning rate scheduler
+# ^^^^^^^^^^^^^^^^^^^^^^^
+# Let's create our learning rate scheduler. We will exponentially
+# decrease the learning rate once every few epochs.
+
+def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=7):
+    """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs."""
+    lr = init_lr * (0.1**(epoch // lr_decay_epoch))
+
+    if epoch % lr_decay_epoch == 0:
+        print('LR is set to {}'.format(lr))
+
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+    return optimizer
+
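The new ``exp_lr_scheduler`` mutates the optimizer's ``param_groups`` in place and returns it, which is why ``train_model`` can simply reassign ``optimizer = lr_scheduler(optimizer, epoch)`` at the start of each training epoch. A minimal sketch of the resulting schedule (not part of the commit; it reuses the tutorial's imports, and ``dummy_params`` is a throwaway parameter list used only for illustration):

    # Illustrative only: watch the learning rate drop by 10x every 7 epochs.
    dummy_params = [nn.Parameter(torch.zeros(1))]
    optimizer = optim.SGD(dummy_params, lr=0.001, momentum=0.9)
    for epoch in [0, 7, 14]:
        optimizer = exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=7)
        print(epoch, optimizer.param_groups[0]['lr'])  # 0.001, then 0.0001, then 1e-05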

 ######################################################################
 # Visualizing the model predictions
@@ -217,7 +235,10 @@ def train_model(model, criterion, optim_scheduler, num_epochs=25):
 # Generic function to display predictions for a few images
 #

-def visualize_model(model, num_images=5):
+def visualize_model(model, num_images=6):
+    images_so_far = 0
+    fig = plt.figure()
+
     for i, data in enumerate(dset_loaders['val']):
         inputs, labels = data
         if use_gpu:
@@ -228,45 +249,34 @@ def visualize_model(model, num_images=5):
         outputs = model(inputs)
         _, preds = torch.max(outputs.data, 1)

-        plt.figure()
-        imshow(inputs.cpu().data[0],
-               title='pred: {}'.format(dset_classes[labels.data[0]]))
-
-        if i == num_images - 1:
-            break
+        for j in range(inputs.size()[0]):
+            images_so_far += 1
+            ax = plt.subplot(num_images // 2, 2, images_so_far)
+            ax.axis('off')
+            ax.set_title('predicted: {}'.format(dset_classes[labels.data[j]]))
+            imshow(inputs.cpu().data[j])

+            if images_so_far == num_images:
+                return

 ######################################################################
 # Finetuning the convnet
 # ----------------------
 #
-# First, let's create our learning rate scheduler. We will exponentially
-# decrease the learning rate once every few epochs.
-#
-
-def optim_scheduler_ft(model, epoch, init_lr=0.001, lr_decay_epoch=7):
-    lr = init_lr * (0.1**(epoch // lr_decay_epoch))
-
-    if epoch % lr_decay_epoch == 0:
-        print('LR is set to {}'.format(lr))
-
-    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
-    return optimizer
-
-
-######################################################################
 # Load a pretrained model and reset final fully connected layer.
 #

-model = models.resnet18(pretrained=True)
-num_ftrs = model.fc.in_features
-model.fc = nn.Linear(num_ftrs, 2)
+model_ft = models.resnet18(pretrained=True)
+num_ftrs = model_ft.fc.in_features
+model_ft.fc = nn.Linear(num_ftrs, 2)

 if use_gpu:
-    model = model.cuda()
+    model_ft = model_ft.cuda()

 criterion = nn.CrossEntropyLoss()

+# Observe that all parameters are being optimized
+optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
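Since ``optimizer_ft`` is built from ``model_ft.parameters()``, every layer of the network is updated during fine-tuning. A quick sanity check along these lines (an illustrative sketch, not part of the commit):

    # The single param group should hold exactly the parameters the model exposes.
    n_model_params = len(list(model_ft.parameters()))
    n_optim_params = sum(len(group['params']) for group in optimizer_ft.param_groups)
    print('model parameters: {}, optimizer parameters: {}'.format(n_model_params, n_optim_params))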

 ######################################################################
 # Train and evaluate
@@ -276,12 +286,13 @@ def optim_scheduler_ft(model, epoch, init_lr=0.001, lr_decay_epoch=7):
 # minute.
 #

-model = train_model(model, criterion, optim_scheduler_ft, num_epochs=25)
+model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
+                       num_epochs=25)

 ######################################################################
 #

-visualize_model(model)
+visualize_model(model_ft)


 ######################################################################
@@ -296,31 +307,22 @@ def optim_scheduler_ft(model, epoch, init_lr=0.001, lr_decay_epoch=7):
 # `here <http://pytorch.org/docs/notes/autograd.html#excluding-subgraphs-from-backward>`__.
 #

-model = torchvision.models.resnet18(pretrained=True)
-for param in model.parameters():
+model_conv = torchvision.models.resnet18(pretrained=True)
+for param in model_conv.parameters():
     param.requires_grad = False

 # Parameters of newly constructed modules have requires_grad=True by default
-num_ftrs = model.fc.in_features
-model.fc = nn.Linear(num_ftrs, 2)
+num_ftrs = model_conv.fc.in_features
+model_conv.fc = nn.Linear(num_ftrs, 2)

 if use_gpu:
-    model = model.cuda()
+    model_conv = model_conv.cuda()

 criterion = nn.CrossEntropyLoss()
-######################################################################
-# Let's write ``optim_scheduler``. We will use previous lr scheduler. Also
-# we need to optimize only the parameters of final FC layer.
-#
-
-def optim_scheduler_conv(model, epoch, init_lr=0.001, lr_decay_epoch=7):
-    lr = init_lr * (0.1**(epoch // lr_decay_epoch))

-    if epoch % lr_decay_epoch == 0:
-        print('LR is set to {}'.format(lr))
-
-    optimizer = optim.SGD(model.fc.parameters(), lr=lr, momentum=0.9)
-    return optimizer
+# Observe that only parameters of the final layer are being optimized, as
+# opposed to before.
+optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
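Because the backbone parameters were frozen with ``requires_grad = False``, only the weight and bias of the new ``fc`` layer remain trainable. A hypothetical check (not in the commit) that the freeze took effect:

    # Count trainable vs. total parameters after freezing the backbone.
    n_total = sum(p.data.numel() for p in model_conv.parameters())
    n_trainable = sum(p.data.numel() for p in model_conv.parameters() if p.requires_grad)
    print('trainable / total parameters: {} / {}'.format(n_trainable, n_total))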


 ######################################################################
@@ -332,12 +334,13 @@ def optim_scheduler_conv(model, epoch, init_lr=0.001, lr_decay_epoch=7):
 # network. However, forward does need to be computed.
 #

-model = train_model(model, criterion, optim_scheduler_conv)
+model_conv = train_model(model_conv, criterion, optimizer_conv,
+                         exp_lr_scheduler, num_epochs=25)

 ######################################################################
 #

-visualize_model(model)
+visualize_model(model_conv)

 plt.ioff()
 plt.show()