Skip to content

Fix quantization tutorials (imports, syntax, and style) #1763

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions advanced_source/dynamic_quantization_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ def __len__(self):
class Corpus(object):
def __init__(self, path):
self.dictionary = Dictionary()
self.train = self.tokenize(os.path.join(path, 'train.txt'))
self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
self.test = self.tokenize(os.path.join(path, 'test.txt'))
self.train = self.tokenize(os.path.join(path, 'wiki.train.token'))
self.valid = self.tokenize(os.path.join(path, 'wiki.valid.token'))
self.test = self.tokenize(os.path.join(path, 'wiki.test.token'))

def tokenize(self, path):
"""Tokenizes a text file."""
Expand Down
33 changes: 14 additions & 19 deletions advanced_source/static_quantization_tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ We'll start by doing the necessary imports:
import sys
import torch.quantization

# # Setup warnings
# Set up warnings
import warnings
warnings.filterwarnings(
action='ignore',
Expand Down Expand Up @@ -196,9 +196,7 @@ Note: this code is taken from
nn.init.zeros_(m.bias)

def forward(self, x):

x = self.quant(x)

x = self.features(x)
x = x.mean([2, 3])
x = self.classifier(x)
Expand Down Expand Up @@ -314,25 +312,22 @@ in this data. These functions mostly come from
.. code:: python

def prepare_data_loaders(data_path):

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
dataset = torchvision.datasets.ImageNet(
data_path, split="train",
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
data_path, split="train", transform=transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
dataset_test = torchvision.datasets.ImageNet(
data_path, split="val",
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))
data_path, split="val", transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))

train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
Expand Down Expand Up @@ -630,4 +625,4 @@ and quantization-aware training - describing what they do "under the hood" and h
them in PyTorch.

Thanks for reading! As always, we welcome any feedback, so please create an issue
`here <https://github.com/pytorch/pytorch/issues>`_ if you have any.
`here <https://github.com/pytorch/pytorch/issues>`_ if you have any.
34 changes: 17 additions & 17 deletions prototype_source/fx_graph_mode_ptq_static.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Download the `torchvision resnet18 model <https://github.com/pytorch/vision/blob
import sys
import torch.quantization

# Setup warnings
# Set up warnings
import warnings
warnings.filterwarnings(
action='ignore',
Expand Down Expand Up @@ -168,25 +168,22 @@ Download the `torchvision resnet18 model <https://github.com/pytorch/vision/blob
os.remove("temp.p")

def prepare_data_loaders(data_path):

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
dataset = torchvision.datasets.ImageNet(
data_path, split="train",
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
data_path, split="train", transform=transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
dataset_test = torchvision.datasets.ImageNet(
data_path, split="val",
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))
data_path, split="val", transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))

train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
Expand Down Expand Up @@ -239,7 +236,7 @@ of the observers for activation and weight. ``qconfig_dict`` is a dictionary wit
.. code:: python

qconfig = {
" : qconfig_global,
"" : qconfig_global,
"sub" : qconfig_sub,
"sub.fc" : qconfig_fc,
"sub.conv": None
Expand Down Expand Up @@ -294,6 +291,7 @@ Utility functions related to ``qconfig`` can be found in the `qconfig <https://g

.. code:: python

from torch.ao.quantization.quantize_fx import prepare_fx
prepared_model = prepare_fx(model_to_quantize, qconfig_dict)

prepare_fx folds BatchNorm modules into previous Conv2d modules, and insert observers
Expand Down Expand Up @@ -326,6 +324,7 @@ the statistics of the Tensors and we can later use this information to calculate

.. code:: python

from torch.ao.quantization.quantize_fx import convert_fx
quantized_model = convert_fx(prepared_model)
print(quantized_model)

Expand Down Expand Up @@ -377,6 +376,7 @@ Note that ``fuse_fx`` only works in eval mode.

.. code:: python

from torch.ao.quantization.quantize_fx import fuse_fx
fused = fuse_fx(float_model)

conv1_weight_after_fuse = fused.conv1[0].weight[0]
Expand Down