diff --git a/advanced_source/static_quantization_tutorial.rst b/advanced_source/static_quantization_tutorial.rst
index aaceb9f9a40..967dc8ed769 100644
--- a/advanced_source/static_quantization_tutorial.rst
+++ b/advanced_source/static_quantization_tutorial.rst
@@ -20,19 +20,20 @@ We'll start by doing the necessary imports:
 
 .. code:: python
 
-    import numpy as np
-    import torch
-    import torch.nn as nn
-    import torchvision
-    from torch.utils.data import DataLoader
-    from torchvision import datasets
-    import torchvision.transforms as transforms
-    import os
-    import time
-    import sys
-    import torch.quantization
-
-    # # Setup warnings
+    import os
+    import sys
+    import time
+    import numpy as np
+
+    import torch
+    import torch.nn as nn
+    from torch.utils.data import DataLoader
+
+    import torchvision
+    from torchvision import datasets
+    import torchvision.transforms as transforms
+
+    # Set up warnings
     import warnings
     warnings.filterwarnings(
         action='ignore',
@@ -41,7 +42,7 @@ We'll start by doing the necessary imports:
     )
     warnings.filterwarnings(
         action='default',
-        module=r'torch.quantization'
+        module=r'torch.ao.quantization'
     )
 
     # Specify random seed for repeatable results
@@ -62,7 +63,7 @@ Note: this code is taken from
 
 .. code:: python
 
-    from torch.quantization import QuantStub, DeQuantStub
+    from torch.ao.quantization import QuantStub, DeQuantStub
 
     def _make_divisible(v, divisor, min_value=None):
         """
@@ -196,9 +197,7 @@ Note: this code is taken from
                 nn.init.zeros_(m.bias)
 
     def forward(self, x):
-
         x = self.quant(x)
-
         x = self.features(x)
         x = x.mean([2, 3])
         x = self.classifier(x)
@@ -210,11 +209,11 @@ Note: this code is taken from
     def fuse_model(self):
         for m in self.modules():
             if type(m) == ConvBNReLU:
-                torch.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
+                torch.ao.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
             if type(m) == InvertedResidual:
                 for idx in range(len(m.conv)):
                     if type(m.conv[idx]) == nn.Conv2d:
-                        torch.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
+                        torch.ao.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
 
 2. Helper functions
 -------------------
@@ -314,25 +313,22 @@ in this data. These functions mostly come from
 
 .. code:: python
 
     def prepare_data_loaders(data_path):
-
         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         dataset = torchvision.datasets.ImageNet(
-            data_path, split="train",
-            transforms.Compose([
-                transforms.RandomResizedCrop(224),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                normalize,
-            ]))
+            data_path, split="train", transform=transforms.Compose([
+                transforms.RandomResizedCrop(224),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                normalize,
+            ]))
         dataset_test = torchvision.datasets.ImageNet(
-            data_path, split="val",
-            transforms.Compose([
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                normalize,
-            ]))
+            data_path, split="val", transform=transforms.Compose([
+                transforms.Resize(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                normalize,
+            ]))
 
         train_sampler = torch.utils.data.RandomSampler(dataset)
         test_sampler = torch.utils.data.SequentialSampler(dataset_test)
@@ -348,8 +344,8 @@ in this data. These functions mostly come from
 
     return data_loader, data_loader_test
 
-Next, we'll load in the pre-trained MobileNetV2 model. We provide the URL to download the data from in ``torchvision``
-`here <https://download.pytorch.org/models/mobilenet_v2-b0353104.pth>`_.
+Next, we'll load in the pre-trained MobileNetV2 model. We provide the URL to download the model
+`here <https://download.pytorch.org/models/mobilenet_v2-b0353104.pth>`_.
 
 .. code:: python
 
@@ -424,9 +420,9 @@ values to floats - and then back to ints - between every operation, resulting in
 
     # Specify quantization configuration
     # Start with simple min/max range estimation and per-tensor quantization of weights
-    myModel.qconfig = torch.quantization.default_qconfig
+    myModel.qconfig = torch.ao.quantization.default_qconfig
     print(myModel.qconfig)
-    torch.quantization.prepare(myModel, inplace=True)
+    torch.ao.quantization.prepare(myModel, inplace=True)
 
     # Calibrate first
     print('Post Training Quantization Prepare: Inserting Observers')
@@ -437,7 +433,7 @@ values to floats - and then back to ints - between every operation, resulting in
     print('Post Training Quantization: Calibration done')
 
     # Convert to quantized model
-    torch.quantization.convert(myModel, inplace=True)
+    torch.ao.quantization.convert(myModel, inplace=True)
     print('Post Training Quantization: Convert done')
 
     print('\n Inverted Residual Block: After fusion and quantization, note fused modules: \n\n',myModel.features[1].conv)
@@ -462,12 +458,12 @@ quantizing for x86 architectures. This configuration does the following:
 
     per_channel_quantized_model = load_model(saved_model_dir + float_model_file)
     per_channel_quantized_model.eval()
     per_channel_quantized_model.fuse_model()
-    per_channel_quantized_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
+    per_channel_quantized_model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
     print(per_channel_quantized_model.qconfig)
 
-    torch.quantization.prepare(per_channel_quantized_model, inplace=True)
+    torch.ao.quantization.prepare(per_channel_quantized_model, inplace=True)
     evaluate(per_channel_quantized_model,criterion, data_loader, num_calibration_batches)
-    torch.quantization.convert(per_channel_quantized_model, inplace=True)
+    torch.ao.quantization.convert(per_channel_quantized_model, inplace=True)
     top1, top5 = evaluate(per_channel_quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
     print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
     torch.jit.save(torch.jit.script(per_channel_quantized_model), saved_model_dir + scripted_quantized_model_file)
@@ -539,13 +535,13 @@ We fuse modules as before
     qat_model.fuse_model()
 
     optimizer = torch.optim.SGD(qat_model.parameters(), lr = 0.0001)
-    qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
+    qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
 
 Finally, ``prepare_qat`` performs the "fake quantization", preparing the model for quantization-aware training
 
 .. code:: python
 
-    torch.quantization.prepare_qat(qat_model, inplace=True)
+    torch.ao.quantization.prepare_qat(qat_model, inplace=True)
     print('Inverted Residual Block: After preparation for QAT, note fake-quantization modules \n',qat_model.features[1].conv)
 
 Training a quantized model with high accuracy requires accurate modeling of numerics at
@@ -565,13 +561,13 @@ inference. For quantization aware training, therefore, we modify the training lo
         train_one_epoch(qat_model, criterion, optimizer, data_loader, torch.device('cpu'), num_train_batches)
         if nepoch > 3:
             # Freeze quantizer parameters
-            qat_model.apply(torch.quantization.disable_observer)
+            qat_model.apply(torch.ao.quantization.disable_observer)
         if nepoch > 2:
             # Freeze batch norm mean and variance estimates
             qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
 
         # Check the accuracy after each epoch
-        quantized_model = torch.quantization.convert(qat_model.eval(), inplace=False)
+        quantized_model = torch.ao.quantization.convert(qat_model.eval(), inplace=False)
         quantized_model.eval()
         top1, top5 = evaluate(quantized_model,criterion, data_loader_test, neval_batches=num_eval_batches)
         print('Epoch %d :Evaluation accuracy on %d images, %2.2f'%(nepoch, num_eval_batches * eval_batch_size, top1.avg))
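
The hunks above only rename the namespace, but it helps to see the eager-mode flow they touch in one place. Below is a minimal, self-contained sketch of that flow written against ``torch.ao.quantization``; the toy ``M`` module and the random calibration tensors are illustrative assumptions, not code from the tutorial or from this patch:

.. code:: python

    import torch
    import torch.nn as nn
    from torch.ao.quantization import QuantStub, DeQuantStub

    class M(nn.Module):
        """Toy float model with quant/dequant stubs at the boundaries."""
        def __init__(self):
            super().__init__()
            self.quant = QuantStub()      # converts incoming float tensors to quantized
            self.conv = nn.Conv2d(3, 8, 3)
            self.bn = nn.BatchNorm2d(8)
            self.relu = nn.ReLU()
            self.dequant = DeQuantStub()  # converts outgoing quantized tensors back to float

        def forward(self, x):
            x = self.quant(x)
            x = self.relu(self.bn(self.conv(x)))
            return self.dequant(x)

    model = M().eval()
    # Fuse conv + bn + relu into a single module, as the tutorial's fuse_model() does
    torch.ao.quantization.fuse_modules(model, [['conv', 'bn', 'relu']], inplace=True)
    # Attach a qconfig and insert observers
    model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
    torch.ao.quantization.prepare(model, inplace=True)
    # Calibrate on representative data (random tensors here, purely for illustration)
    with torch.no_grad():
        for _ in range(4):
            model(torch.randn(1, 3, 32, 32))
    # Replace observed modules with quantized equivalents
    torch.ao.quantization.convert(model, inplace=True)
    print(model)

The call sequence mirrors the patched tutorial exactly: fuse, attach a qconfig, ``prepare`` to insert observers, calibrate, then ``convert``.
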
diff --git a/prototype_source/fx_graph_mode_ptq_static.rst b/prototype_source/fx_graph_mode_ptq_static.rst
index 2fc872b7d98..812c9d23f4d 100644
--- a/prototype_source/fx_graph_mode_ptq_static.rst
+++ b/prototype_source/fx_graph_mode_ptq_static.rst
@@ -13,9 +13,8 @@ tldr; The FX Graph Mode API looks like the following:
 .. code:: python
 
     import torch
-    from torch.quantization import get_default_qconfig
-    # Note that this is temporary, we'll expose these functions to torch.quantization after official releasee
-    from torch.quantization.quantize_fx import prepare_fx, convert_fx
+    from torch.ao.quantization import get_default_qconfig
+    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
     float_model.eval()
     qconfig = get_default_qconfig("fbgemm")
     qconfig_dict = {"": qconfig}
@@ -58,24 +57,28 @@ These steps are identitcal to `Static Quantization with Eager Mode in PyTorch
 `_. Unzip the downloaded file into the 'data_path' folder.
 
-Download the `torchvision resnet18 model <https://download.pytorch.org/models/resnet18-5c106cde.pth>`_ and rename it to
+Download the `torchvision resnet18 model <https://download.pytorch.org/models/resnet18-f37072fd.pth>`_ and rename it to
 ``data/resnet18_pretrained_float.pth``.
 
 .. code:: python
 
-    import numpy as np
-    import torch
-    import torch.nn as nn
-    import torchvision
-    from torch.utils.data import DataLoader
-    from torchvision import datasets
-    import torchvision.transforms as transforms
-    import os
-    import time
-    import sys
-    import torch.quantization
-
-    # Setup warnings
+    import os
+    import sys
+    import time
+    import numpy as np
+
+    import torch
+    from torch.ao.quantization import get_default_qconfig
+    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, fuse_fx
+    import torch.nn as nn
+    from torch.utils.data import DataLoader
+
+    import torchvision
+    from torchvision import datasets
+    from torchvision.models.resnet import resnet18
+    import torchvision.transforms as transforms
+
+    # Set up warnings
     import warnings
     warnings.filterwarnings(
         action='ignore',
@@ -84,16 +87,13 @@ Download the `torchvision resnet18 model `_ and rename it to
-Utility functions related to ``qconfig`` can be found in the `qconfig <https://github.com/pytorch/pytorch/blob/master/torch/quantization/qconfig.py>`_ file.
+Utility functions related to ``qconfig`` can be found in the `qconfig <https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/qconfig.py>`_ file.
 
 .. code:: python
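
As a companion to the second file, here is a similarly minimal sketch of the FX graph mode flow through the migrated ``torch.ao.quantization.quantize_fx`` entry points. It follows the ``qconfig_dict`` style shown in the tutorial's tldr; this API was still a prototype at the time of this change, and later PyTorch releases instead take a ``QConfigMapping`` plus a required ``example_inputs`` argument in ``prepare_fx``. The tiny ``nn.Sequential`` model and the random calibration tensors are illustrative assumptions standing in for resnet18 and ImageNet data:

.. code:: python

    import torch
    import torch.nn as nn
    from torch.ao.quantization import get_default_qconfig
    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

    # A tiny, symbolically traceable float model
    float_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()

    qconfig = get_default_qconfig('fbgemm')
    qconfig_dict = {"": qconfig}  # the empty key applies this qconfig to the whole model

    # Trace the model and insert observers (module fusion happens automatically)
    prepared_model = prepare_fx(float_model, qconfig_dict)

    # Calibrate on representative inputs (random tensors here, purely for illustration)
    with torch.no_grad():
        for _ in range(4):
            prepared_model(torch.randn(1, 3, 32, 32))

    # Lower the observed graph to quantized operations
    quantized_model = convert_fx(prepared_model)
    print(quantized_model)

Unlike the eager-mode tutorial, no ``QuantStub``/``DeQuantStub`` or manual ``fuse_model()`` is needed: FX rewrites the traced graph to place quantize/dequantize operations and fused modules itself.
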