Commit c5fe766

Add FX Graph Mode Post Training Dynamic Quantization Tutorial (#1283)
* Add FX Graph Mode Post Training Dynamic Quantization Tutorial
* address comments
* Add FX Graph Mode Post Training Dynamic Quantization Tutorial
* address comments
1 parent b3dd09e commit c5fe766

2 files changed: +297 -1 lines changed


prototype_source/README.txt

Lines changed: 5 additions & 1 deletion
@@ -26,4 +26,8 @@ Prototype Tutorials
 7. fx_graph_mode_static_quantization.py
          FX Graph Mode Post Training Static Quantization
-         https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static_tutorial.html
+         https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static_tutorial.html
+
+8. fx_graph_mode_dynamic_quantization.py
+         FX Graph Mode Post Training Dynamic Quantization
+         https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_dynamic_tutorial.html
Lines changed: 292 additions & 0 deletions
@@ -0,0 +1,292 @@
"""
(prototype) FX Graph Mode Post Training Dynamic Quantization
==============================================================

**Author**: `Jerry Zhang <https://github.com/jerryzh168>`_

This tutorial introduces the steps to do post training dynamic quantization in graph mode based on ``torch.fx``.
We have a separate tutorial for FX Graph Mode Post Training Static Quantization (TODO: link),
and a comparison between FX Graph Mode Quantization and Eager Mode Quantization can be found in the `quantization docs <https://pytorch.org/docs/stable/quantization.html>`_ (TODO: update link to section).

tl;dr: The FX Graph Mode API for dynamic quantization looks like the following:

.. code:: python

    import torch
    from torch.quantization import default_dynamic_qconfig
    # Note that this is temporary; we'll expose these functions to torch.quantization after the official release
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    float_model.eval()
    # dynamic quantization uses a dynamic qconfig rather than a static backend qconfig
    qconfig = default_dynamic_qconfig
    qconfig_dict = {"": qconfig}
    prepared_model = prepare_fx(float_model, qconfig_dict)  # fuse modules and insert observers
    # no calibration is required for dynamic quantization
    quantized_model = convert_fx(prepared_model)  # convert the model to a dynamically quantized model

In this tutorial, we'll apply dynamic quantization to an LSTM-based next-word prediction model,
closely following the word language model from the PyTorch examples.
We will copy the code from `Dynamic Quantization on an LSTM Word Language Model <https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html>`_
and omit the descriptions.

"""


###################################################
# 1. Define the Model, Download Data and Model
# --------------------------------------------
#
# Download the `data <https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip>`_
# and unzip it to the data folder:
#
# .. code::
#
#    mkdir data
#    cd data
#    wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
#    unzip wikitext-2-v1.zip
#
# Download the model to the data folder:
#
# .. code::
#
#    wget https://s3.amazonaws.com/pytorch-tutorial-assets/word_language_model_quantize.pth
#
# Define the model:

# imports
import os
from io import open
import time
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

# Model Definition
class LSTMModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        return decoded, hidden


def init_hidden(lstm_model, bsz):
    # get the weight tensor and create hidden layer in the same device
    weight = lstm_model.encoder.weight
    # get weight from quantized model (there ``weight`` is a function that returns the weight tensor)
    if not isinstance(weight, torch.Tensor):
        weight = weight()
    device = weight.device
    nlayers = lstm_model.rnn.num_layers
    nhid = lstm_model.rnn.hidden_size
    return (torch.zeros(nlayers, bsz, nhid, device=device),
            torch.zeros(nlayers, bsz, nhid, device=device))


# Load Text Data
class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'wiki.train.tokens'))
        self.valid = self.tokenize(os.path.join(path, 'wiki.valid.tokens'))
        self.test = self.tokenize(os.path.join(path, 'wiki.test.tokens'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r', encoding="utf8") as f:
            idss = []
            for line in f:
                words = line.split() + ['<eos>']
                ids = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids).type(torch.int64))
            ids = torch.cat(idss)

        return ids

model_data_filepath = 'data/'

corpus = Corpus(model_data_filepath + 'wikitext-2')

ntokens = len(corpus.dictionary)

# Load Pretrained Model
model = LSTMModel(
    ntoken = ntokens,
    ninp = 512,
    nhid = 256,
    nlayers = 5,
)

model.load_state_dict(
    torch.load(
        model_data_filepath + 'word_language_model_quantize.pth',
        map_location=torch.device('cpu')
        )
    )

model.eval()
print(model)

bptt = 25
criterion = nn.CrossEntropyLoss()
eval_batch_size = 1

# create test data set
def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    return data.view(bsz, -1).t().contiguous()

test_data = batchify(corpus.test, eval_batch_size)

# Evaluation functions
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""

    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

def evaluate(model_, data_source):
    # Turn on evaluation mode which disables dropout.
    model_.eval()
    total_loss = 0.
    hidden = init_hidden(model_, eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output, hidden = model_(data, hidden)
            hidden = repackage_hidden(hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)

######################################################################
# 2. Post Training Dynamic Quantization
# -------------------------------------
# Now we can dynamically quantize the model.
# We can use the same functions as in post training static quantization, but with a dynamic qconfig.

from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.quantization import default_dynamic_qconfig, float_qparams_weight_only_qconfig

# Full docs for the supported qconfigs for floating point modules/ops can be found in the quantization docs (TODO: link)
# Full docs for qconfig_dict can be found in the documentation of prepare_fx (TODO: link)
qconfig_dict = {
    "object_type": [
        (nn.Embedding, float_qparams_weight_only_qconfig),
        (nn.LSTM, default_dynamic_qconfig),
        (nn.Linear, default_dynamic_qconfig)
    ]
}
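# Side note: besides "object_type", qconfig_dict also accepts a global "" entry and
# per-module "module_name" entries (a sketch only; see the prepare_fx documentation for
# the full set of supported keys). For example:
#
#     example_qconfig_dict = {
#         "": None,  # no global qconfig
#         "module_name": [("rnn", default_dynamic_qconfig)],  # "rnn" is the LSTM submodule defined above
#     }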
# Deepcopying the original model because quantization api changes the model inplace and we want
# to keep the original model for future comparison
model_to_quantize = copy.deepcopy(model)
prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
print("prepared model:", prepared_model)
quantized_model = convert_fx(prepared_model)
print("quantized model", quantized_model)


######################################################################
# For dynamically quantized objects, we didn't do anything in ``prepare_fx`` for modules,
# but will insert observers for the weights of dynamically quantizable functionals and torch ops.
# We also fuse modules such as Conv + Bn and Linear + ReLU.
#
# In convert we'll convert the float modules to dynamically quantized modules and
# convert float ops to dynamically quantized ops. We can see that in the example model,
# ``nn.Embedding``, ``nn.Linear`` and ``nn.LSTM`` are dynamically quantized.
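#
# As a quick sanity check (a minimal sketch; the exact module paths below may differ
# across PyTorch releases), we can list the submodule types that ``convert_fx`` swapped
# in and confirm that the embedding, linear and LSTM modules now come from the
# quantized namespaces.

for name, module in quantized_model.named_modules():
    if type(module).__module__.startswith("torch.nn.quantized"):
        print(name, "->", type(module))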

######################################################################
# Now we can compare the size and runtime of the quantized model.

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

print_size_of_model(model)
print_size_of_model(quantized_model)

######################################################################
# There is a 4x size reduction because we quantized all the weights
# in the model (nn.Embedding, nn.Linear and nn.LSTM) from float (4 bytes) to quantized int (1 byte).
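#
# A rough back-of-the-envelope check of that claim (a minimal sketch; the saved file also
# contains biases, scales, zero_points and other metadata, so the numbers will not match exactly):

num_params = sum(p.numel() for p in model.parameters())
print('approx. float weight size (MB):', num_params * 4 / 1e6)      # ~4 bytes per float32 weight
print('approx. quantized weight size (MB):', num_params * 1 / 1e6)  # ~1 byte per int8 weight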

torch.set_num_threads(1)

def time_model_evaluation(model, test_data):
    s = time.time()
    loss = evaluate(model, test_data)
    elapsed = time.time() - s
    print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))

time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)

#####################################################################
# There is a roughly 2x speedup for this model. Also note that the speedup
# may vary depending on model, device, build, input batch sizes, threading, etc.
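#
# For example, the timing above pins the intra-op thread count to 1. As a quick,
# machine-dependent experiment, you can raise the thread count and time both models again:

torch.set_num_threads(4)
time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)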

#####################################################################
# 3. Conclusion
# -------------
# This tutorial introduces the API for post training dynamic quantization in FX Graph Mode,
# which dynamically quantizes the same modules as Eager Mode Quantization.
