|
| 1 | +Ease-of-use quantization for PyTorch with Intel® Neural Compressor |
| 2 | +================================================================== |
| 3 | + |
| 4 | +Overview |
| 5 | +-------- |
| 6 | + |
| 7 | +Most deep learning applications are using 32-bits of floating-point precision |
| 8 | +for inference. But low precision data types, especially int8, are getting more |
| 9 | +focus due to significant performance boost. One of the essential concerns on |
| 10 | +adopting low precision is how to easily mitigate the possible accuracy loss |
| 11 | +and reach predefined accuracy requirement. |
| 12 | + |
| 13 | +Intel® Neural Compressor aims to address the aforementioned concern by extending |
| 14 | +PyTorch with accuracy-driven automatic tuning strategies to help user quickly find |
| 15 | +out the best quantized model on Intel hardware, including Intel Deep Learning |
| 16 | +Boost (`Intel DL Boost <https://www.intel.com/content/www/us/en/artificial-intelligence/deep-learning-boost.html>`_) |
| 17 | +and Intel Advanced Matrix Extensions (`Intel AMX <https://www.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-amx-instructions/intrinsics-for-amx-tile-instructions.html>`_). |
| 18 | + |
| 19 | +Intel® Neural Compressor has been released as an open-source project |
| 20 | +at `Github <https://github.com/intel/neural-compressor>`_. |
| 21 | + |
| 22 | +Features |
| 23 | +-------- |
| 24 | + |
| 25 | +- **Ease-of-use Python API:** Intel® Neural Compressor provides simple frontend |
| 26 | + Python APIs and utilities for users to do neural network compression with few |
| 27 | + line code changes. |
| 28 | + Typically, only 5 to 6 clauses are required to be added to the original code. |
| 29 | + |
| 30 | +- **Quantization:** Intel® Neural Compressor supports accuracy-driven automatic |
| 31 | + tuning process on post-training static quantization, post-training dynamic |
| 32 | + quantization, and quantization-aware training on PyTorch fx graph mode and |
| 33 | + eager model. |
| 34 | + |
| 35 | +*This tutorial mainly focuses on the quantization part. As for how to use Intel® |
| 36 | +Neural Compressor to do pruning and distillation, please refer to corresponding |
| 37 | +documents in the Intel® Neural Compressor github repo.* |
| 38 | + |
| 39 | +Getting Started |
| 40 | +--------------- |
| 41 | + |
| 42 | +Installation |
| 43 | +~~~~~~~~~~~~ |
| 44 | + |
| 45 | +.. code:: bash |
| 46 | +
|
| 47 | + # install stable version from pip |
| 48 | + pip install neural-compressor |
| 49 | +
|
| 50 | + # install nightly version from pip |
| 51 | + pip install -i https://test.pypi.org/simple/ neural-compressor |
| 52 | +
|
| 53 | + # install stable version from from conda |
| 54 | + conda install neural-compressor -c conda-forge -c intel |
| 55 | +
|
| 56 | +*Supported python versions are 3.6 or 3.7 or 3.8 or 3.9* |
| 57 | + |
| 58 | +Usages |
| 59 | +~~~~~~ |
| 60 | + |
| 61 | +Minor code changes are required for users to get started with Intel® Neural Compressor |
| 62 | +quantization API. Both PyTorch fx graph mode and eager mode are supported. |
| 63 | + |
| 64 | +Intel® Neural Compressor takes a FP32 model and a yaml configuration file as inputs. |
| 65 | +To construct the quantization process, users can either specify the below settings via |
| 66 | +the yaml configuration file or python APIs: |
| 67 | + |
| 68 | +1. Calibration Dataloader (Needed for static quantization) |
| 69 | +2. Evaluation Dataloader |
| 70 | +3. Evaluation Metric |
| 71 | + |
| 72 | +Intel® Neural Compressor supports some popular dataloaders and evaluation metrics. For |
| 73 | +how to configure them in yaml configuration file, user could refer to `Built-in Datasets |
| 74 | +<https://github.com/intel/neural-compressor/blob/master/docs/dataset.md>`_. |
| 75 | + |
| 76 | +If users want to use a self-developed dataloader or evaluation metric, Intel® Neural |
| 77 | +Compressor supports this by the registration of customized dataloader/metric using python code. |
| 78 | + |
| 79 | +For the yaml configuration file format please refer to `yaml template |
| 80 | +<https://github.com/intel/neural-compressor/blob/master/neural_compressor/template/ptq.yaml>`_. |
| 81 | + |
| 82 | +The code changes that are required for *Intel® Neural Compressor* are highlighted with |
| 83 | +comments in the line above. |
| 84 | + |
| 85 | +Model |
| 86 | +^^^^^ |
| 87 | + |
| 88 | +In this tutorial, the LeNet model is used to demonstrate how to deal with *Intel® Neural Compressor*. |
| 89 | + |
| 90 | +.. code-block:: python3 |
| 91 | +
|
| 92 | + # main.py |
| 93 | + import torch |
| 94 | + import torch.nn as nn |
| 95 | + import torch.nn.functional as F |
| 96 | +
|
| 97 | + # LeNet Model definition |
| 98 | + class Net(nn.Module): |
| 99 | + def __init__(self): |
| 100 | + super(Net, self).__init__() |
| 101 | + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) |
| 102 | + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) |
| 103 | + self.conv2_drop = nn.Dropout2d() |
| 104 | + self.fc1 = nn.Linear(320, 50) |
| 105 | + self.fc1_drop = nn.Dropout() |
| 106 | + self.fc2 = nn.Linear(50, 10) |
| 107 | +
|
| 108 | + def forward(self, x): |
| 109 | + x = F.relu(F.max_pool2d(self.conv1(x), 2)) |
| 110 | + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) |
| 111 | + x = x.reshape(-1, 320) |
| 112 | + x = F.relu(self.fc1(x)) |
| 113 | + x = self.fc1_drop(x) |
| 114 | + x = self.fc2(x) |
| 115 | + return F.log_softmax(x, dim=1) |
| 116 | +
|
| 117 | + model = Net() |
| 118 | + model.load_state_dict(torch.load('./lenet_mnist_model.pth')) |
| 119 | +
|
| 120 | +The pretrained model weight `lenet_mnist_model.pth` comes from |
| 121 | +`here <https://drive.google.com/drive/folders/1fn83DF14tWmit0RTKWRhPq5uVXt73e0h?usp=sharing>`_. |
| 122 | + |
| 123 | +Accuracy driven quantization |
| 124 | +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 125 | + |
| 126 | +Intel® Neural Compressor supports accuracy-driven automatic tuning to generate the optimal |
| 127 | +int8 model which meets a predefined accuracy goal. |
| 128 | + |
| 129 | +Below is an example of how to quantize a simple network on PyTorch |
| 130 | +`FX graph mode <https://pytorch.org/docs/stable/fx.html>`_ by auto-tuning. |
| 131 | + |
| 132 | +.. code-block:: yaml |
| 133 | +
|
| 134 | + # conf.yaml |
| 135 | + model: |
| 136 | + name: LeNet |
| 137 | + framework: pytorch_fx |
| 138 | +
|
| 139 | + evaluation: |
| 140 | + accuracy: |
| 141 | + metric: |
| 142 | + topk: 1 |
| 143 | +
|
| 144 | + tuning: |
| 145 | + accuracy_criterion: |
| 146 | + relative: 0.01 |
| 147 | +
|
| 148 | +.. code-block:: python3 |
| 149 | +
|
| 150 | + # main.py |
| 151 | + model.eval() |
| 152 | +
|
| 153 | + from torchvision import datasets, transforms |
| 154 | + test_loader = torch.utils.data.DataLoader( |
| 155 | + datasets.MNIST('./data', train=False, download=True, |
| 156 | + transform=transforms.Compose([ |
| 157 | + transforms.ToTensor(), |
| 158 | + ])), |
| 159 | + batch_size=1) |
| 160 | +
|
| 161 | + # launch code for Intel® Neural Compressor |
| 162 | + from neural_compressor.experimental import Quantization |
| 163 | + quantizer = Quantization("./conf.yaml") |
| 164 | + quantizer.model = model |
| 165 | + quantizer.calib_dataloader = test_loader |
| 166 | + quantizer.eval_dataloader = test_loader |
| 167 | + q_model = quantizer() |
| 168 | + q_model.save('./output') |
| 169 | +
|
| 170 | +In the `conf.yaml` file, the built-in metric `top1` of Intel® Neural Compressor is specified as |
| 171 | +the evaluation method, and `1%` relative accuracy loss is set as the accuracy target for auto-tuning. |
| 172 | +Intel® Neural Compressor will traverse all possible quantization config combinations on per-op level |
| 173 | +to find out the optimal int8 model that reaches the predefined accuracy target. |
| 174 | + |
| 175 | +Besides those built-in metrics, Intel® Neural Compressor also supports customized metric through |
| 176 | +python code: |
| 177 | + |
| 178 | +.. code-block:: yaml |
| 179 | +
|
| 180 | + # conf.yaml |
| 181 | + model: |
| 182 | + name: LeNet |
| 183 | + framework: pytorch_fx |
| 184 | +
|
| 185 | + tuning: |
| 186 | + accuracy_criterion: |
| 187 | + relative: 0.01 |
| 188 | +
|
| 189 | +.. code-block:: python3 |
| 190 | +
|
| 191 | + # main.py |
| 192 | + model.eval() |
| 193 | +
|
| 194 | + from torchvision import datasets, transforms |
| 195 | + test_loader = torch.utils.data.DataLoader( |
| 196 | + datasets.MNIST('./data', train=False, download=True, |
| 197 | + transform=transforms.Compose([ |
| 198 | + transforms.ToTensor(), |
| 199 | + ])), |
| 200 | + batch_size=1) |
| 201 | +
|
| 202 | + # define a customized metric |
| 203 | + class Top1Metric(object): |
| 204 | + def __init__(self): |
| 205 | + self.correct = 0 |
| 206 | + def update(self, output, label): |
| 207 | + pred = output.argmax(dim=1, keepdim=True) |
| 208 | + self.correct += pred.eq(label.view_as(pred)).sum().item() |
| 209 | + def reset(self): |
| 210 | + self.correct = 0 |
| 211 | + def result(self): |
| 212 | + return 100. * self.correct / len(test_loader.dataset) |
| 213 | +
|
| 214 | + # launch code for Intel® Neural Compressor |
| 215 | + from neural_compressor.experimental import Quantization |
| 216 | + quantizer = Quantization("./conf.yaml") |
| 217 | + quantizer.model = model |
| 218 | + quantizer.calib_dataloader = test_loader |
| 219 | + quantizer.eval_dataloader = test_loader |
| 220 | + quantizer.metric = Top1Metric() |
| 221 | + q_model = quantizer() |
| 222 | + q_model.save('./output') |
| 223 | +
|
| 224 | +In the above example, a `class` which contains `update()` and `result()` function is implemented |
| 225 | +to record per mini-batch result and calculate final accuracy at the end. |
| 226 | + |
| 227 | +Quantization aware training |
| 228 | +^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 229 | + |
| 230 | +Besides post-training static quantization and post-training dynamic quantization, Intel® Neural |
| 231 | +Compressor supports quantization-aware training with an accuracy-driven automatic tuning mechanism. |
| 232 | + |
| 233 | +Below is an example of how to do quantization aware training on a simple network on PyTorch |
| 234 | +`FX graph mode <https://pytorch.org/docs/stable/fx.html>`_. |
| 235 | + |
| 236 | +.. code-block:: yaml |
| 237 | +
|
| 238 | + # conf.yaml |
| 239 | + model: |
| 240 | + name: LeNet |
| 241 | + framework: pytorch_fx |
| 242 | +
|
| 243 | + quantization: |
| 244 | + approach: quant_aware_training |
| 245 | +
|
| 246 | + evaluation: |
| 247 | + accuracy: |
| 248 | + metric: |
| 249 | + topk: 1 |
| 250 | +
|
| 251 | + tuning: |
| 252 | + accuracy_criterion: |
| 253 | + relative: 0.01 |
| 254 | +
|
| 255 | +.. code-block:: python3 |
| 256 | +
|
| 257 | + # main.py |
| 258 | + model.eval() |
| 259 | +
|
| 260 | + from torchvision import datasets, transforms |
| 261 | + train_loader = torch.utils.data.DataLoader( |
| 262 | + datasets.MNIST('./data', train=True, download=True, |
| 263 | + transform=transforms.Compose([ |
| 264 | + transforms.ToTensor(), |
| 265 | + transforms.Normalize((0.1307,), (0.3081,)) |
| 266 | + ])), |
| 267 | + batch_size=64, shuffle=True) |
| 268 | + test_loader = torch.utils.data.DataLoader( |
| 269 | + datasets.MNIST('./data', train=False, download=True, |
| 270 | + transform=transforms.Compose([ |
| 271 | + transforms.ToTensor(), |
| 272 | + transforms.Normalize((0.1307,), (0.3081,)) |
| 273 | + ])), |
| 274 | + batch_size=1) |
| 275 | +
|
| 276 | + import torch.optim as optim |
| 277 | + optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.1) |
| 278 | +
|
| 279 | + def training_func(model): |
| 280 | + model.train() |
| 281 | + for epoch in range(1, 3): |
| 282 | + for batch_idx, (data, target) in enumerate(train_loader): |
| 283 | + optimizer.zero_grad() |
| 284 | + output = model(data) |
| 285 | + loss = F.nll_loss(output, target) |
| 286 | + loss.backward() |
| 287 | + optimizer.step() |
| 288 | + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( |
| 289 | + epoch, batch_idx * len(data), len(train_loader.dataset), |
| 290 | + 100. * batch_idx / len(train_loader), loss.item())) |
| 291 | +
|
| 292 | + # launch code for Intel® Neural Compressor |
| 293 | + from neural_compressor.experimental import Quantization |
| 294 | + quantizer = Quantization("./conf.yaml") |
| 295 | + quantizer.model = model |
| 296 | + quantizer.q_func = training_func |
| 297 | + quantizer.eval_dataloader = test_loader |
| 298 | + q_model = quantizer() |
| 299 | + q_model.save('./output') |
| 300 | +
|
| 301 | +Performance only quantization |
| 302 | +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 303 | + |
| 304 | +Intel® Neural Compressor supports directly yielding int8 model with dummy dataset for the |
| 305 | +performance benchmarking purpose. |
| 306 | + |
| 307 | +Below is an example of how to quantize a simple network on PyTorch |
| 308 | +`FX graph mode <https://pytorch.org/docs/stable/fx.html>`_ with a dummy dataset. |
| 309 | + |
| 310 | +.. code-block:: yaml |
| 311 | +
|
| 312 | + # conf.yaml |
| 313 | + model: |
| 314 | + name: lenet |
| 315 | + framework: pytorch_fx |
| 316 | +
|
| 317 | +.. code-block:: python3 |
| 318 | +
|
| 319 | + # main.py |
| 320 | + model.eval() |
| 321 | +
|
| 322 | + # launch code for Intel® Neural Compressor |
| 323 | + from neural_compressor.experimental import Quantization, common |
| 324 | + from neural_compressor.experimental.data.datasets.dummy_dataset import DummyDataset |
| 325 | + quantizer = Quantization("./conf.yaml") |
| 326 | + quantizer.model = model |
| 327 | + quantizer.calib_dataloader = common.DataLoader(DummyDataset([(1, 1, 28, 28)])) |
| 328 | + q_model = quantizer() |
| 329 | + q_model.save('./output') |
| 330 | +
|
| 331 | +Quantization outputs |
| 332 | +~~~~~~~~~~~~~~~~~~~~ |
| 333 | + |
| 334 | +Users could know how many ops get quantized from log printed by Intel® Neural Compressor |
| 335 | +like below: |
| 336 | + |
| 337 | +:: |
| 338 | + |
| 339 | + 2021-12-08 14:58:35 [INFO] |********Mixed Precision Statistics*******| |
| 340 | + 2021-12-08 14:58:35 [INFO] +------------------------+--------+-------+ |
| 341 | + 2021-12-08 14:58:35 [INFO] | Op Type | Total | INT8 | |
| 342 | + 2021-12-08 14:58:35 [INFO] +------------------------+--------+-------+ |
| 343 | + 2021-12-08 14:58:35 [INFO] | quantize_per_tensor | 2 | 2 | |
| 344 | + 2021-12-08 14:58:35 [INFO] | Conv2d | 2 | 2 | |
| 345 | + 2021-12-08 14:58:35 [INFO] | max_pool2d | 1 | 1 | |
| 346 | + 2021-12-08 14:58:35 [INFO] | relu | 1 | 1 | |
| 347 | + 2021-12-08 14:58:35 [INFO] | dequantize | 2 | 2 | |
| 348 | + 2021-12-08 14:58:35 [INFO] | LinearReLU | 1 | 1 | |
| 349 | + 2021-12-08 14:58:35 [INFO] | Linear | 1 | 1 | |
| 350 | + 2021-12-08 14:58:35 [INFO] +------------------------+--------+-------+ |
| 351 | + |
| 352 | +The quantized model will be generated under `./output` directory, in which there are two files: |
| 353 | +1. best_configure.yaml |
| 354 | +2. best_model_weights.pt |
| 355 | + |
| 356 | +The first file contains the quantization configurations of each op, the second file contains |
| 357 | +int8 weights and zero point and scale info of activations. |
| 358 | + |
| 359 | +Deployment |
| 360 | +~~~~~~~~~~ |
| 361 | + |
| 362 | +Users could use the below code to load quantized model and then do inference or performance benchmark. |
| 363 | + |
| 364 | +.. code-block:: python3 |
| 365 | +
|
| 366 | + from neural_compressor.utils.pytorch import load |
| 367 | + int8_model = load('./output', model) |
| 368 | +
|
| 369 | +Tutorials |
| 370 | +--------- |
| 371 | + |
| 372 | +Please visit `Intel® Neural Compressor Github repo <https://github.com/intel/neural-compressor>`_ |
| 373 | +for more tutorials. |
0 commit comments