
Commit d0ef61b

Author: Jessica Lin
Merge pull request #749 from raghuramank100/jlin27-quant-tutorials
Fix formatting and clean up tutorial on quantized transfer learning
2 parents: f58558a + a800c77

File tree: 1 file changed (+15, -12 lines)


intermediate_source/quantized_transfer_learning_tutorial.py

Lines changed: 15 additions & 12 deletions
@@ -84,9 +84,11 @@
 
 
 ######################################################################
-# Load Data (section not needed as it is covered in the original tutorial)
+# Load Data
 # ------------------------------------------------------------------------
 #
+# .. note:: This section is identical to the original transfer learning tutorial.
+#
 # We will use ``torchvision`` and ``torch.utils.data`` packages to load
 # the data.
 #
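For readers skimming the diff, the section being retitled loads data roughly along these lines; this is a minimal sketch, assuming the standard ImageFolder layout and ImageNet normalization used by the original transfer learning tutorial (the paths, batch size, and worker count here are assumptions, not lines from the tutorial):

import torch
from torchvision import datasets, transforms

# Assumed layout: data/train/<class>/*.jpg and data/val/<class>/*.jpg
norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm,
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        norm,
    ]),
}
image_datasets = {x: datasets.ImageFolder('data/' + x, data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                              batch_size=16, shuffle=True,
                                              num_workers=4)
               for x in ['train', 'val']}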
@@ -360,7 +362,7 @@ def visualize_model(model, rows=3, cols=3):
 # **Notice that when isolating the feature extractor from a quantized
 # model, you have to place the quantizer in the beginning and in the end
 # of it.**
-#
+# We write a helper function to create a model with a custom head.
 
 from torch import nn
 
@@ -394,8 +396,6 @@ def create_combined_model(model_fe):
   )
   return new_model
 
-new_model = create_combined_model(model_fe)
-
 
 ######################################################################
 # .. warning:: Currently the quantized models can only be run on CPU.
@@ -404,6 +404,7 @@ def create_combined_model(model_fe):
 #
 
 import torch.optim as optim
+new_model = create_combined_model(model_fe)
 new_model = new_model.to('cpu')
 
 criterion = nn.CrossEntropyLoss()
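The hunk above moves the create_combined_model call next to the optimizer setup. A plausible continuation of that setup, sketched with assumed hyperparameters (the learning rate, momentum, and schedule are not part of this diff):

import torch.optim as optim
from torch.optim import lr_scheduler

# Train on CPU, since the warning above notes quantized models are CPU-only.
optimizer_ft = optim.SGD(new_model.parameters(), lr=0.01, momentum=0.9)

# Decay the learning rate by 10x every 7 epochs (assumed schedule).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)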
@@ -431,7 +432,7 @@ def create_combined_model(model_fe):
 
 
 ######################################################################
-# **Part 2. Finetuning the quantizable model**
+# Part 2. Finetuning the quantizable model
 #
 # In this part, we fine tune the feature extractor used for transfer
 # learning, and quantize the feature extractor. Note that in both part 1
@@ -446,18 +447,21 @@ def create_combined_model(model_fe):
 # datasets.
 #
 # The pretrained feature extractor must be quantizable, i.e we need to do
-# the following: 1. Fuse (Conv, BN, ReLU), (Conv, BN) and (Conv, ReLU)
-# using torch.quantization.fuse_modules. 2. Connect the feature extractor
-# with a custom head. This requires dequantizing the output of the feature
-# extractor. 3. Insert fake-quantization modules at appropriate locations
-# in the feature extractor to mimic quantization during training.
+# the following:
+# 1. Fuse (Conv, BN, ReLU), (Conv, BN) and (Conv, ReLU)
+#    using torch.quantization.fuse_modules.
+# 2. Connect the feature extractor
+#    with a custom head. This requires dequantizing the output of the feature
+#    extractor.
+# 3. Insert fake-quantization modules at appropriate locations
+#    in the feature extractor to mimic quantization during training.
 #
 # For step (1), we use models from torchvision/models/quantization, which
 # support a member method fuse_model, which fuses all the conv, bn, and
 # relu modules. In general, this would require calling the
 # torch.quantization.fuse_modules API with the list of modules to fuse.
 #
-# Step (2) is done by the function create_custom_model function that we
+# Step (2) is done by the function create_combined_model function that we
 # used in the previous section.
 #
 # Step (3) is achieved by using torch.quantization.prepare_qat, which
@@ -534,4 +538,3 @@ def create_combined_model(model_fe):
 plt.ioff()
 plt.tight_layout()
 plt.show()
-