######################################################################
- # Load Data (section not needed as it is covered in the original tutorial)
+ # Load Data
# ------------------------------------------------------------------------
#
+ # .. note:: This section is identical to the original transfer learning tutorial.
+ #
# We will use ``torchvision`` and ``torch.utils.data`` packages to load
# the data.
#
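For readers who have not gone through the original tutorial, a minimal sketch of the data loading this section refers to, assuming the same hymenoptera_data (ants vs. bees) folder layout as the original transfer learning tutorial; the path, batch size, and augmentations here are illustrative, not necessarily the tutorial's exact values:

import torch
from torchvision import datasets, transforms

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

data_dir = 'data/hymenoptera_data'  # assumed dataset location
image_datasets = {x: datasets.ImageFolder(f'{data_dir}/{x}', data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=8,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
class_names = image_datasets['train'].classes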
@@ -360,7 +362,7 @@ def visualize_model(model, rows=3, cols=3):
# **Notice that when isolating the feature extractor from a quantized
# model, you have to place the quantizer in the beginning and in the end
# of it.**
- #
+ # We write a helper function to create a model with a custom head.

from torch import nn

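To make the quantizer placement concrete, here is a hedged sketch of what such a helper might look like for a quantizable ResNet-style backbone; the attribute names (model_fe.quant, model_fe.layer1, ..., model_fe.dequant), the feature width, and the two-class head are assumptions for illustration, not necessarily the file's exact implementation:

from torch import nn

def create_combined_model(model_fe, num_ftrs=512, num_classes=2):
    # Keep the quantization stub at the start and the dequantization stub
    # at the end of the isolated (quantized) feature extractor.
    features = nn.Sequential(
        model_fe.quant,       # quantize the incoming float tensor
        model_fe.conv1,
        model_fe.bn1,
        model_fe.relu,
        model_fe.maxpool,
        model_fe.layer1,
        model_fe.layer2,
        model_fe.layer3,
        model_fe.layer4,
        model_fe.avgpool,
        model_fe.dequant,     # dequantize before the floating-point head
    )
    # A small floating-point head on top of the dequantized features.
    head = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(num_ftrs, num_classes),
    )
    new_model = nn.Sequential(features, nn.Flatten(1), head)
    return new_model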
@@ -394,8 +396,6 @@ def create_combined_model(model_fe):
    )
    return new_model

- new_model = create_combined_model(model_fe)
-

######################################################################
# .. warning:: Currently the quantized models can only be run on CPU.
@@ -404,6 +404,7 @@ def create_combined_model(model_fe):
#

import torch.optim as optim
+ new_model = create_combined_model(model_fe)
new_model = new_model.to('cpu')

criterion = nn.CrossEntropyLoss()
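What usually follows this setup (outside the hunk shown here) is an optimizer and a learning-rate scheduler for training the new head; a brief sketch with illustrative hyperparameters, not necessarily the tutorial's exact values:

import torch.optim as optim
from torch.optim import lr_scheduler

# Train on CPU, per the warning above; in practice only the new head's
# parameters receive meaningful gradient updates.
optimizer_ft = optim.SGD(new_model.parameters(), lr=0.01, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)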
@@ -431,7 +432,7 @@ def create_combined_model(model_fe):


######################################################################
- # ** Part 2. Finetuning the quantizable model**
+ # Part 2. Finetuning the quantizable model
#
# In this part, we fine tune the feature extractor used for transfer
# learning, and quantize the feature extractor. Note that in both part 1
@@ -446,18 +447,21 @@ def create_combined_model(model_fe):
# datasets.
#
# The pretrained feature extractor must be quantizable, i.e we need to do
- # the following: 1. Fuse (Conv, BN, ReLU), (Conv, BN) and (Conv, ReLU)
- # using torch.quantization.fuse_modules. 2. Connect the feature extractor
- # with a custom head. This requires dequantizing the output of the feature
- # extractor. 3. Insert fake-quantization modules at appropriate locations
- # in the feature extractor to mimic quantization during training.
+ # the following:
+ # 1. Fuse (Conv, BN, ReLU), (Conv, BN) and (Conv, ReLU)
+ #    using torch.quantization.fuse_modules.
+ # 2. Connect the feature extractor
+ #    with a custom head. This requires dequantizing the output of the feature
+ #    extractor.
+ # 3. Insert fake-quantization modules at appropriate locations
+ #    in the feature extractor to mimic quantization during training.
#
# For step (1), we use models from torchvision/models/quantization, which
# support a member method fuse_model, which fuses all the conv, bn, and
# relu modules. In general, this would require calling the
# torch.quantization.fuse_modules API with the list of modules to fuse.
#
- # Step (2) is done by the function create_custom_model function that we
+ # Step (2) is done by the create_combined_model function that we
# used in the previous section.
#
# Step (3) is achieved by using torch.quantization.prepare_qat, which
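Putting steps (1) through (3) together, a rough sketch of how the quantizable backbone can be prepared for quantization-aware training, assuming a quantizable ResNet-18 from torchvision; the qconfig choice and the exact call sequence are illustrative assumptions (newer torchvision releases also replace pretrained= with weights=):

import torch
from torchvision.models import quantization as qmodels

# Step (1): load a quantizable backbone and fuse Conv+BN(+ReLU) modules.
model_fe = qmodels.resnet18(pretrained=True, progress=True, quantize=False)
num_ftrs = model_fe.fc.in_features
model_fe.train()       # fusion/QAT preparation expects training mode
model_fe.fuse_model()

# Step (3): attach a QAT qconfig and insert fake-quantization observers.
model_fe.qconfig = torch.quantization.default_qat_qconfig
torch.quantization.prepare_qat(model_fe, inplace=True)

# Step (2): wrap the (de)quantized feature extractor with the custom head.
new_model = create_combined_model(model_fe)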
@@ -534,4 +538,3 @@ def create_combined_model(model_fe):
plt.ioff()
plt.tight_layout()
plt.show()
-