Skip to content

Commit 282c441

Browse files
committed
Fix quantization tutorials (imports, syntax, and style)
Summary: This commit fixes the quantization tutorials such that they can be run smoothly by the user. Test Plan: Ran the updated tutorials without problem. Reviewers: jerryzh168 Subscribers: jerryzh168, supriyar ghstack-source-id: 196719d Pull Request resolved: #1763
1 parent 6d0f524 commit 282c441

File tree

3 files changed

+34
-39
lines changed

3 files changed

+34
-39
lines changed

advanced_source/dynamic_quantization_tutorial.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,9 @@ def __len__(self):
9898
class Corpus(object):
9999
def __init__(self, path):
100100
self.dictionary = Dictionary()
101-
self.train = self.tokenize(os.path.join(path, 'train.txt'))
102-
self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
103-
self.test = self.tokenize(os.path.join(path, 'test.txt'))
101+
self.train = self.tokenize(os.path.join(path, 'wiki.train.token'))
102+
self.valid = self.tokenize(os.path.join(path, 'wiki.valid.token'))
103+
self.test = self.tokenize(os.path.join(path, 'wiki.test.token'))
104104

105105
def tokenize(self, path):
106106
"""Tokenizes a text file."""

advanced_source/static_quantization_tutorial.rst

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ We'll start by doing the necessary imports:
3232
import sys
3333
import torch.quantization
3434
35-
# # Setup warnings
35+
# Set up warnings
3636
import warnings
3737
warnings.filterwarnings(
3838
action='ignore',
@@ -196,9 +196,7 @@ Note: this code is taken from
196196
nn.init.zeros_(m.bias)
197197
198198
def forward(self, x):
199-
200199
x = self.quant(x)
201-
202200
x = self.features(x)
203201
x = x.mean([2, 3])
204202
x = self.classifier(x)
@@ -314,25 +312,22 @@ in this data. These functions mostly come from
314312
.. code:: python
315313
316314
def prepare_data_loaders(data_path):
317-
318315
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
319316
std=[0.229, 0.224, 0.225])
320317
dataset = torchvision.datasets.ImageNet(
321-
data_path, split="train",
322-
transforms.Compose([
323-
transforms.RandomResizedCrop(224),
324-
transforms.RandomHorizontalFlip(),
325-
transforms.ToTensor(),
326-
normalize,
327-
]))
318+
data_path, split="train", transform=transforms.Compose([
319+
transforms.RandomResizedCrop(224),
320+
transforms.RandomHorizontalFlip(),
321+
transforms.ToTensor(),
322+
normalize,
323+
]))
328324
dataset_test = torchvision.datasets.ImageNet(
329-
data_path, split="val",
330-
transforms.Compose([
331-
transforms.Resize(256),
332-
transforms.CenterCrop(224),
333-
transforms.ToTensor(),
334-
normalize,
335-
]))
325+
data_path, split="val", transform=transforms.Compose([
326+
transforms.Resize(256),
327+
transforms.CenterCrop(224),
328+
transforms.ToTensor(),
329+
normalize,
330+
]))
336331
337332
train_sampler = torch.utils.data.RandomSampler(dataset)
338333
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
@@ -630,4 +625,4 @@ and quantization-aware training - describing what they do "under the hood" and h
630625
them in PyTorch.
631626

632627
Thanks for reading! As always, we welcome any feedback, so please create an issue
633-
`here <https://github.com/pytorch/pytorch/issues>`_ if you have any.
628+
`here <https://github.com/pytorch/pytorch/issues>`_ if you have any.

prototype_source/fx_graph_mode_ptq_static.rst

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ Download the `torchvision resnet18 model <https://github.com/pytorch/vision/blob
7575
import sys
7676
import torch.quantization
7777
78-
# Setup warnings
78+
# Set up warnings
7979
import warnings
8080
warnings.filterwarnings(
8181
action='ignore',
@@ -168,25 +168,22 @@ Download the `torchvision resnet18 model <https://github.com/pytorch/vision/blob
168168
os.remove("temp.p")
169169
170170
def prepare_data_loaders(data_path):
171-
172171
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
173172
std=[0.229, 0.224, 0.225])
174173
dataset = torchvision.datasets.ImageNet(
175-
data_path, split="train",
176-
transforms.Compose([
177-
transforms.RandomResizedCrop(224),
178-
transforms.RandomHorizontalFlip(),
179-
transforms.ToTensor(),
180-
normalize,
181-
]))
174+
data_path, split="train", transform=transforms.Compose([
175+
transforms.RandomResizedCrop(224),
176+
transforms.RandomHorizontalFlip(),
177+
transforms.ToTensor(),
178+
normalize,
179+
]))
182180
dataset_test = torchvision.datasets.ImageNet(
183-
data_path, split="val",
184-
transforms.Compose([
185-
transforms.Resize(256),
186-
transforms.CenterCrop(224),
187-
transforms.ToTensor(),
188-
normalize,
189-
]))
181+
data_path, split="val", transform=transforms.Compose([
182+
transforms.Resize(256),
183+
transforms.CenterCrop(224),
184+
transforms.ToTensor(),
185+
normalize,
186+
]))
190187
191188
train_sampler = torch.utils.data.RandomSampler(dataset)
192189
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
@@ -239,7 +236,7 @@ of the observers for activation and weight. ``qconfig_dict`` is a dictionary wit
239236
.. code:: python
240237
241238
qconfig = {
242-
" : qconfig_global,
239+
"" : qconfig_global,
243240
"sub" : qconfig_sub,
244241
"sub.fc" : qconfig_fc,
245242
"sub.conv": None
@@ -294,6 +291,7 @@ Utility functions related to ``qconfig`` can be found in the `qconfig <https://g
294291

295292
.. code:: python
296293
294+
from torch.ao.quantization.quantize_fx import prepare_fx
297295
prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
298296
299297
prepare_fx folds BatchNorm modules into previous Conv2d modules, and insert observers
@@ -326,6 +324,7 @@ the statistics of the Tensors and we can later use this information to calculate
326324

327325
.. code:: python
328326
327+
from torch.ao.quantization.quantize_fx import convert_fx
329328
quantized_model = convert_fx(prepared_model)
330329
print(quantized_model)
331330
@@ -377,6 +376,7 @@ Note that ``fuse_fx`` only works in eval mode.
377376

378377
.. code:: python
379378
379+
from torch.ao.quantization.quantize_fx import fuse_fx
380380
fused = fuse_fx(float_model)
381381
382382
conv1_weight_after_fuse = fused.conv1[0].weight[0]

0 commit comments

Comments
 (0)