From 97f537759767d0db52f6ba3db29edbc071c19297 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Mon, 13 Dec 2021 11:05:17 -0800 Subject: [PATCH] Fix quantization tutorials (imports, syntax, and style) Summary: This commit fixes the quantization tutorials such that they can be run smoothly by the user. Test Plan: Ran the updated tutorials without problem. Reviewers: jerryzh168 Subscribers: jerryzh168, supriyar [ghstack-poisoned] --- .../dynamic_quantization_tutorial.py | 6 ++-- .../static_quantization_tutorial.rst | 33 ++++++++---------- prototype_source/fx_graph_mode_ptq_static.rst | 34 +++++++++---------- 3 files changed, 34 insertions(+), 39 deletions(-) diff --git a/advanced_source/dynamic_quantization_tutorial.py b/advanced_source/dynamic_quantization_tutorial.py index 07609eec853..9e09d792c1a 100644 --- a/advanced_source/dynamic_quantization_tutorial.py +++ b/advanced_source/dynamic_quantization_tutorial.py @@ -98,9 +98,9 @@ def __len__(self): class Corpus(object): def __init__(self, path): self.dictionary = Dictionary() - self.train = self.tokenize(os.path.join(path, 'train.txt')) - self.valid = self.tokenize(os.path.join(path, 'valid.txt')) - self.test = self.tokenize(os.path.join(path, 'test.txt')) + self.train = self.tokenize(os.path.join(path, 'wiki.train.token')) + self.valid = self.tokenize(os.path.join(path, 'wiki.valid.token')) + self.test = self.tokenize(os.path.join(path, 'wiki.test.token')) def tokenize(self, path): """Tokenizes a text file.""" diff --git a/advanced_source/static_quantization_tutorial.rst b/advanced_source/static_quantization_tutorial.rst index 79f76b805e1..9652b28783c 100644 --- a/advanced_source/static_quantization_tutorial.rst +++ b/advanced_source/static_quantization_tutorial.rst @@ -32,7 +32,7 @@ We'll start by doing the necessary imports: import sys import torch.quantization - # # Setup warnings + # Set up warnings import warnings warnings.filterwarnings( action='ignore', @@ -196,9 +196,7 @@ Note: this code is taken from nn.init.zeros_(m.bias) def forward(self, x): - x = self.quant(x) - x = self.features(x) x = x.mean([2, 3]) x = self.classifier(x) @@ -314,25 +312,22 @@ in this data. These functions mostly come from .. code:: python def prepare_data_loaders(data_path): - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) dataset = torchvision.datasets.ImageNet( - data_path, split="train", - transforms.Compose([ - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize, - ])) + data_path, split="train", transform=transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) dataset_test = torchvision.datasets.ImageNet( - data_path, split="val", - transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - normalize, - ])) + data_path, split="val", transform=transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])) train_sampler = torch.utils.data.RandomSampler(dataset) test_sampler = torch.utils.data.SequentialSampler(dataset_test) @@ -630,4 +625,4 @@ and quantization-aware training - describing what they do "under the hood" and h them in PyTorch. Thanks for reading! As always, we welcome any feedback, so please create an issue -`here `_ if you have any. \ No newline at end of file +`here `_ if you have any. diff --git a/prototype_source/fx_graph_mode_ptq_static.rst b/prototype_source/fx_graph_mode_ptq_static.rst index 2fc872b7d98..e40d39a5517 100644 --- a/prototype_source/fx_graph_mode_ptq_static.rst +++ b/prototype_source/fx_graph_mode_ptq_static.rst @@ -75,7 +75,7 @@ Download the `torchvision resnet18 model