Fix quantization tutorials (imports, syntax, and style) #1772

Merged
merged 7 commits, Sep 6, 2022
88 changes: 42 additions & 46 deletions advanced_source/static_quantization_tutorial.rst
@@ -20,19 +20,20 @@ We'll start by doing the necessary imports:

.. code:: python

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import os
import time
import sys
import torch.quantization

# Setup warnings
import os
import sys
import time
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
import torchvision.transforms as transforms

# Set up warnings
import warnings
warnings.filterwarnings(
action='ignore',
@@ -41,7 +42,7 @@ We'll start by doing the necessary imports:
)
warnings.filterwarnings(
action='default',
module=r'torch.quantization'
module=r'torch.ao.quantization'
)

# Specify random seed for repeatable results
@@ -62,7 +63,7 @@ Note: this code is taken from

.. code:: python

from torch.quantization import QuantStub, DeQuantStub
from torch.ao.quantization import QuantStub, DeQuantStub

def _make_divisible(v, divisor, min_value=None):
"""
@@ -196,9 +197,7 @@ Note: this code is taken from
nn.init.zeros_(m.bias)

def forward(self, x):

x = self.quant(x)

x = self.features(x)
x = x.mean([2, 3])
x = self.classifier(x)
@@ -210,11 +209,11 @@ Note: this code is taken from
def fuse_model(self):
for m in self.modules():
if type(m) == ConvBNReLU:
torch.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
torch.ao.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
if type(m) == InvertedResidual:
for idx in range(len(m.conv)):
if type(m.conv[idx]) == nn.Conv2d:
torch.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
torch.ao.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
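
As a quick aside, the fusion call above can be tried on a toy module to see what it produces. A minimal, self-contained sketch (not part of the tutorial; assumes a PyTorch build with the ``torch.ao`` namespace):

.. code:: python

    import torch
    import torch.nn as nn
    import torch.ao.quantization

    # A [Conv2d, BatchNorm2d, ReLU] triple fuses into a single ConvReLU2d,
    # with the BatchNorm folded into the convolution weights.
    # Fusion for inference requires eval mode.
    toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    toy.eval()
    fused = torch.ao.quantization.fuse_modules(toy, [['0', '1', '2']])
    print(fused)  # '0' is now ConvReLU2d; '1' and '2' are Identity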

2. Helper functions
-------------------
@@ -314,25 +313,22 @@ in this data. These functions mostly come from
.. code:: python

def prepare_data_loaders(data_path):

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
dataset = torchvision.datasets.ImageNet(
data_path, split="train",
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
data_path, split="train", transform=transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
dataset_test = torchvision.datasets.ImageNet(
data_path, split="val",
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))
data_path, split="val", transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))

train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
@@ -348,8 +344,8 @@ in this data. These functions mostly come from
return data_loader, data_loader_test


Next, we'll load in the pre-trained MobileNetV2 model. We provide the URL to download the data from in ``torchvision``
`here <https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenet.py#L9>`_.
Next, we'll load in the pre-trained MobileNetV2 model. We provide the URL to download the model
`here <https://download.pytorch.org/models/mobilenet_v2-b0353104.pth>`_.

.. code:: python

@@ -424,9 +420,9 @@ values to floats - and then back to ints - between every operation, resulting in

# Specify quantization configuration
# Start with simple min/max range estimation and per-tensor quantization of weights
myModel.qconfig = torch.quantization.default_qconfig
myModel.qconfig = torch.ao.quantization.default_qconfig
print(myModel.qconfig)
torch.quantization.prepare(myModel, inplace=True)
torch.ao.quantization.prepare(myModel, inplace=True)

# Calibrate first
print('Post Training Quantization Prepare: Inserting Observers')
@@ -437,7 +433,7 @@ values to floats - and then back to ints - between every operation, resulting in
print('Post Training Quantization: Calibration done')

# Convert to quantized model
torch.quantization.convert(myModel, inplace=True)
torch.ao.quantization.convert(myModel, inplace=True)
print('Post Training Quantization: Convert done')
print('\n Inverted Residual Block: After fusion and quantization, note fused modules: \n\n',myModel.features[1].conv)
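
As an aside, ``default_qconfig`` above is just a pair of observer factories; a small hedged sketch for inspecting them (exact observer settings can vary between PyTorch releases):

.. code:: python

    import torch.ao.quantization as tq

    # default_qconfig pairs a per-tensor min/max activation observer with a
    # symmetric qint8 min/max weight observer (details may vary by version).
    qconfig = tq.default_qconfig
    print(qconfig.activation())  # activation observer instance
    print(qconfig.weight())      # weight observer instance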

@@ -462,12 +458,12 @@ quantizing for x86 architectures. This configuration does the following:
per_channel_quantized_model = load_model(saved_model_dir + float_model_file)
per_channel_quantized_model.eval()
per_channel_quantized_model.fuse_model()
per_channel_quantized_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
per_channel_quantized_model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
print(per_channel_quantized_model.qconfig)

torch.quantization.prepare(per_channel_quantized_model, inplace=True)
torch.ao.quantization.prepare(per_channel_quantized_model, inplace=True)
evaluate(per_channel_quantized_model,criterion, data_loader, num_calibration_batches)
torch.quantization.convert(per_channel_quantized_model, inplace=True)
torch.ao.quantization.convert(per_channel_quantized_model, inplace=True)
top1, top5 = evaluate(per_channel_quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
torch.jit.save(torch.jit.script(per_channel_quantized_model), saved_model_dir + scripted_quantized_model_file)
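
Because the quantized model is saved through TorchScript, it can later be reloaded for inference without the Python class definition. A brief hedged sketch, reusing the tutorial's ``saved_model_dir`` and ``scripted_quantized_model_file`` names:

.. code:: python

    # Reload the scripted, quantized model and run a dummy batch on CPU.
    loaded = torch.jit.load(saved_model_dir + scripted_quantized_model_file)
    loaded.eval()
    with torch.no_grad():
        out = loaded(torch.rand(1, 3, 224, 224))  # ImageNet-sized dummy input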
@@ -539,13 +535,13 @@ We fuse modules as before
qat_model.fuse_model()

optimizer = torch.optim.SGD(qat_model.parameters(), lr = 0.0001)
qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')

Finally, ``prepare_qat`` performs the "fake quantization", preparing the model for quantization-aware training

.. code:: python

torch.quantization.prepare_qat(qat_model, inplace=True)
torch.ao.quantization.prepare_qat(qat_model, inplace=True)
print('Inverted Residual Block: After preparation for QAT, note fake-quantization modules \n',qat_model.features[1].conv)

Training a quantized model with high accuracy requires accurate modeling of numerics at
@@ -565,13 +561,13 @@ inference. For quantization aware training, therefore, we modify the training loop
train_one_epoch(qat_model, criterion, optimizer, data_loader, torch.device('cpu'), num_train_batches)
if nepoch > 3:
# Freeze quantizer parameters
qat_model.apply(torch.quantization.disable_observer)
qat_model.apply(torch.ao.quantization.disable_observer)
if nepoch > 2:
# Freeze batch norm mean and variance estimates
qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

# Check the accuracy after each epoch
quantized_model = torch.quantization.convert(qat_model.eval(), inplace=False)
quantized_model = torch.ao.quantization.convert(qat_model.eval(), inplace=False)
quantized_model.eval()
top1, top5 = evaluate(quantized_model,criterion, data_loader_test, neval_batches=num_eval_batches)
print('Epoch %d :Evaluation accuracy on %d images, %2.2f'%(nepoch, num_eval_batches * eval_batch_size, top1.avg))
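
Pulling the pieces together, the QAT flow above reduces to a short sequence of calls; a condensed sketch mirroring the loop (not a replacement for it):

.. code:: python

    # Condensed QAT recipe: configure, prepare, train, convert.
    qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
    torch.ao.quantization.prepare_qat(qat_model, inplace=True)
    # ... several training epochs, optionally freezing observers and BN stats ...
    quantized_model = torch.ao.quantization.convert(qat_model.eval(), inplace=False)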
73 changes: 35 additions & 38 deletions prototype_source/fx_graph_mode_ptq_static.rst
@@ -13,9 +13,8 @@ tldr; The FX Graph Mode API looks like the following:
.. code:: python

import torch
from torch.quantization import get_default_qconfig
# Note that this is temporary, we'll expose these functions to torch.quantization after official release
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
float_model.eval()
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
@@ -58,24 +57,28 @@ These steps are identical to `Static Quantization with Eager Mode in PyTorch <h

To run the code in this tutorial using the entire ImageNet dataset, first download ImageNet by following the instructions here: `ImageNet Data <http://www.image-net.org/download>`_. Unzip the downloaded file into the 'data_path' folder.

Download the `torchvision resnet18 model <https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L12>`_ and rename it to
Download the `torchvision resnet18 model <https://download.pytorch.org/models/resnet18-f37072fd.pth>`_ and rename it to
``data/resnet18_pretrained_float.pth``.

.. code:: python

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import os
import time
import sys
import torch.quantization

# Setup warnings
import os
import sys
import time
import numpy as np

import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, fuse_fx
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
from torchvision.models.resnet import resnet18
import torchvision.transforms as transforms

# Set up warnings
import warnings
warnings.filterwarnings(
action='ignore',
@@ -84,16 +87,13 @@ Download the `torchvision resnet18 model <https://github.com/pytorch/vision/blob
)
warnings.filterwarnings(
action='default',
module=r'torch.quantization'
module=r'torch.ao.quantization'
)

# Specify random seed for repeatable results
_ = torch.manual_seed(191009)


from torchvision.models.resnet import resnet18
from torch.quantization import get_default_qconfig, quantize_jit

class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f'):
@@ -168,25 +168,22 @@ Download the `torchvision resnet18 model <https://github.com/pytorch/vision/blob
os.remove("temp.p")

def prepare_data_loaders(data_path):

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
dataset = torchvision.datasets.ImageNet(
data_path, split="train",
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
data_path, split="train", transform=transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
dataset_test = torchvision.datasets.ImageNet(
data_path, split="val",
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))
data_path, split="val", transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))

train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
@@ -239,7 +236,7 @@ of the observers for activation and weight. ``qconfig_dict`` is a dictionary with
.. code:: python

qconfig = {
" : qconfig_global,
"" : qconfig_global,
"sub" : qconfig_sub,
"sub.fc" : qconfig_fc,
"sub.conv": None
@@ -282,7 +279,7 @@ of the observers for activation and weight. ``qconfig_dict`` is a dictionary with
]
}
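
As a concrete, hedged illustration of such per-module overrides in the FX API of this era: a global qconfig plus a ``module_name`` entry that leaves one submodule unquantized (``sub.conv`` is the placeholder name from the fragment above; later releases moved to ``QConfigMapping``):

.. code:: python

    from torch.ao.quantization import get_default_qconfig
    from torch.ao.quantization.quantize_fx import prepare_fx

    # Hedged sketch: "" sets the global qconfig; a ("sub.conv", None) entry
    # under "module_name" keeps that submodule in floating point.
    qconfig_dict = {
        "": get_default_qconfig("fbgemm"),
        "module_name": [("sub.conv", None)],
    }
    # prepared = prepare_fx(float_model, qconfig_dict)  # float_model: any eval-mode module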

Utility functions related to ``qconfig`` can be found in the `qconfig <https://github.com/pytorch/pytorch/blob/master/torch/quantization/qconfig.py>`_ file.
Utility functions related to ``qconfig`` can be found in the `qconfig <https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/qconfig.py>`_ file.

.. code:: python

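
To close the loop on the ``tldr`` snippet at the top of this file, here is a minimal end-to-end sketch of the FX graph mode flow as of this PR, using random calibration batches as an illustrative stand-in for real ImageNet data (later PyTorch releases also require an ``example_inputs`` argument to ``prepare_fx``):

.. code:: python

    import torch
    from torch.ao.quantization import get_default_qconfig
    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
    from torchvision.models.resnet import resnet18

    float_model = resnet18(pretrained=False).eval()
    qconfig_dict = {"": get_default_qconfig("fbgemm")}
    prepared_model = prepare_fx(float_model, qconfig_dict)   # insert observers
    with torch.no_grad():
        for _ in range(4):                                   # calibration passes
            prepared_model(torch.rand(8, 3, 224, 224))
    quantized_model = convert_fx(prepared_model)             # quantized GraphModule
    print(quantized_model)                                   # inspect the result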