"""
(prototype) FX Graph Mode Post Training Dynamic Quantization
==============================================================

**Author**: `Jerry Zhang <https://github.com/jerryzh168>`_

This tutorial introduces the steps to do post training dynamic quantization in graph mode based on ``torch.fx``.
We have a separate tutorial for FX Graph Mode Post Training Static Quantization (TODO: link);
a comparison between FX Graph Mode Quantization and Eager Mode Quantization can be found in the `quantization docs <https://pytorch.org/docs/stable/quantization.html>`_ (TODO: update link to section).

tl;dr: The FX Graph Mode API for dynamic quantization looks like the following:

.. code:: python

    import torch
    from torch.quantization import default_dynamic_qconfig
    # Note that this is temporary, we'll expose these functions to torch.quantization after the official release
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    float_model.eval()
    # use the default dynamic qconfig as the global qconfig
    qconfig = default_dynamic_qconfig
    qconfig_dict = {"": qconfig}
    prepared_model = prepare_fx(float_model, qconfig_dict)  # fuse modules and insert observers
    # no calibration is required for dynamic quantization
    quantized_model = convert_fx(prepared_model)  # convert the model to a dynamically quantized model

In this tutorial, we'll apply dynamic quantization to an LSTM-based next-word-prediction model,
closely following the word language model from the PyTorch examples.
We will copy the code from `Dynamic Quantization on an LSTM Word Language Model <https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html>`_
and omit the descriptions.

"""


###################################################
# 1. Define the Model, Download Data and Model
# --------------------------------------------
#
# Download the `data <https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip>`_
# and unzip it to the ``data`` folder:
#
# .. code::
#
#    mkdir data
#    cd data
#    wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
#    unzip wikitext-2-v1.zip
#
# Download the model to the ``data`` folder:
#
# .. code::
#
#    wget https://s3.amazonaws.com/pytorch-tutorial-assets/word_language_model_quantize.pth
#
# Define the model:

# imports
import os
from io import open
import time
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

# Model Definition
class LSTMModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        return decoded, hidden


def init_hidden(lstm_model, bsz):
    # get the weight tensor so we can create the hidden state on the same device
    weight = lstm_model.encoder.weight
    # in the quantized model, the embedding weight is exposed as a callable rather than a tensor
    if not isinstance(weight, torch.Tensor):
        weight = weight()
    device = weight.device
    nlayers = lstm_model.rnn.num_layers
    nhid = lstm_model.rnn.hidden_size
    return (torch.zeros(nlayers, bsz, nhid, device=device),
            torch.zeros(nlayers, bsz, nhid, device=device))


# Load Text Data
class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'wiki.train.tokens'))
        self.valid = self.tokenize(os.path.join(path, 'wiki.valid.tokens'))
        self.test = self.tokenize(os.path.join(path, 'wiki.test.tokens'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r', encoding="utf8") as f:
            idss = []
            for line in f:
                words = line.split() + ['<eos>']
                ids = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids).type(torch.int64))
            ids = torch.cat(idss)

        return ids

model_data_filepath = 'data/'

corpus = Corpus(model_data_filepath + 'wikitext-2')

ntokens = len(corpus.dictionary)

# Load Pretrained Model
model = LSTMModel(
    ntoken = ntokens,
    ninp = 512,
    nhid = 256,
    nlayers = 5,
)

model.load_state_dict(
    torch.load(
        model_data_filepath + 'word_language_model_quantize.pth',
        map_location=torch.device('cpu')
        )
    )

model.eval()
print(model)

bptt = 25
criterion = nn.CrossEntropyLoss()
eval_batch_size = 1

# create test data set
def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    return data.view(bsz, -1).t().contiguous()

test_data = batchify(corpus.test, eval_batch_size)

# Evaluation functions
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""

    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

def evaluate(model_, data_source):
    # Turn on evaluation mode which disables dropout.
    model_.eval()
    total_loss = 0.
    hidden = init_hidden(model_, eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output, hidden = model_(data, hidden)
            hidden = repackage_hidden(hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)

######################################################################
# 2. Post Training Dynamic Quantization
# -------------------------------------
#
# Now we can dynamically quantize the model.
# We can use the same functions as in post training static quantization, but with a dynamic ``qconfig``.

from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.quantization import default_dynamic_qconfig, float_qparams_weight_only_qconfig

# Full docs on the supported qconfigs for floating point modules/ops can be found in the quantization docs (TODO: link)
# Full docs for qconfig_dict can be found in the docs of the prepare_fx function (TODO: link)
qconfig_dict = {
    "object_type": [
        (nn.Embedding, float_qparams_weight_only_qconfig),
        (nn.LSTM, default_dynamic_qconfig),
        (nn.Linear, default_dynamic_qconfig)
    ]
}
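# Note that "object_type" is only one of the keys ``qconfig_dict`` accepts; at the
# time of writing the prepare_fx docs also describe, for example, a global default
# under "" and per-module overrides under "module_name". As a purely hypothetical
# sketch (not used in this tutorial), quantizing only the decoder Linear by its
# module name might look like:
#
#     qconfig_dict = {"module_name": [("decoder", default_dynamic_qconfig)]}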
# Deepcopy the original model because the quantization API changes the model in place and we want
# to keep the original model for future comparison
model_to_quantize = copy.deepcopy(model)
prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
print("prepared model:", prepared_model)
quantized_model = convert_fx(prepared_model)
print("quantized model", quantized_model)


######################################################################
# For dynamically quantized modules, ``prepare_fx`` does not insert anything;
# for dynamically quantizable functionals and torch ops it inserts observers for the weights.
# We also fuse modules like Conv + Bn and Linear + ReLU.
#
# In convert we'll convert the float modules to dynamically quantized modules and
# convert float ops to dynamically quantized ops. We can see that in the example model,
# ``nn.Embedding``, ``nn.Linear`` and ``nn.LSTM`` are dynamically quantized.
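#
# As a quick, optional sanity check (an illustrative sketch, not part of the
# original recipe), we can list the direct children of the converted model and
# confirm that the embedding, LSTM and linear layers were swapped for their
# quantized counterparts:

for name, module in quantized_model.named_children():
    # print the fully qualified class of each child module
    print(name, "->", type(module))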

######################################################################
# Now we can compare the size and runtime of the quantized model.

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove("temp.p")

print_size_of_model(model)
print_size_of_model(quantized_model)

######################################################################
# There is a roughly 4x size reduction because we quantized all of the weights
# in the model (``nn.Embedding``, ``nn.Linear`` and ``nn.LSTM``) from float32
# (4 bytes per element) to quantized int8 (1 byte per element).
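#
# As a rough sanity check on that figure (an illustrative sketch, not part of the
# original recipe), we can estimate the size of the float weights directly from the
# original model's parameters, at 4 bytes per float32 element. The on-disk sizes
# printed above also include serialization overhead, so the ratio will not be exactly 4.

float_param_bytes = sum(p.numel() * 4 for p in model.parameters())
print('Estimated float parameter size (MB):', float_param_bytes / 1e6)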

torch.set_num_threads(1)

def time_model_evaluation(model, test_data):
    s = time.time()
    loss = evaluate(model, test_data)
    elapsed = time.time() - s
    print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))

time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)

#####################################################################
# There is a roughly 2x speedup for this model. Also note that the speedup
# may vary depending on the model, device, build, input batch size, threading, etc.
#
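# Since threading has a large effect on CPU timings, it can help to record the
# threading configuration alongside the numbers above. A minimal sketch
# (assuming a CPU-only run):

print('intra-op threads:', torch.get_num_threads())
print(torch.__config__.parallel_info())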

######################################################################
# 3. Conclusion
# -------------
#
# This tutorial introduces the API for post training dynamic quantization in FX Graph Mode,
# which dynamically quantizes the same modules as Eager Mode Quantization.