@@ -20,19 +20,20 @@ We'll start by doing the necessary imports:

 .. code:: python

-    import numpy as np
-    import torch
-    import torch.nn as nn
-    import torchvision
-    from torch.utils.data import DataLoader
-    from torchvision import datasets
-    import torchvision.transforms as transforms
-    import os
-    import time
-    import sys
-    import torch.quantization
-
-    # # Setup warnings
+    import os
+    import sys
+    import time
+    import numpy as np
+
+    import torch
+    import torch.nn as nn
+    from torch.utils.data import DataLoader
+
+    import torchvision
+    from torchvision import datasets
+    import torchvision.transforms as transforms
+
+    # Set up warnings
     import warnings
     warnings.filterwarnings(
         action='ignore',
@@ -41,7 +42,7 @@ We'll start by doing the necessary imports:
     )
     warnings.filterwarnings(
         action='default',
-        module=r'torch.quantization'
+        module=r'torch.ao.quantization'
     )

     # Specify random seed for repeatable results
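
The recurring change in this diff is the move from the deprecated ``torch.quantization`` namespace to ``torch.ao.quantization``. A minimal sketch, not part of the tutorial itself, of an import that tolerates both sides of the migration (the alias ``tq`` is our own choice):

.. code:: python

    # Sketch: prefer the new namespace, fall back on older PyTorch releases
    # (assumes torch.ao.quantization exists from PyTorch 1.10 onwards).
    try:
        import torch.ao.quantization as tq
    except ImportError:
        import torch.quantization as tq

    print(tq.get_default_qconfig('fbgemm'))
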
@@ -62,7 +63,7 @@ Note: this code is taken from

 .. code:: python

-    from torch.quantization import QuantStub, DeQuantStub
+    from torch.ao.quantization import QuantStub, DeQuantStub

     def _make_divisible(v, divisor, min_value=None):
         """
@@ -196,9 +197,7 @@ Note: this code is taken from
                     nn.init.zeros_(m.bias)

         def forward(self, x):
-
             x = self.quant(x)
-
             x = self.features(x)
             x = x.mean([2, 3])
             x = self.classifier(x)
@@ -210,11 +209,11 @@ Note: this code is taken from
         def fuse_model(self):
             for m in self.modules():
                 if type(m) == ConvBNReLU:
-                    torch.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
+                    torch.ao.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
                 if type(m) == InvertedResidual:
                     for idx in range(len(m.conv)):
                         if type(m.conv[idx]) == nn.Conv2d:
-                            torch.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
+                            torch.ao.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
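
The fusion call changed above can be tried in isolation. A sketch on a hypothetical ``Conv2d``-``BatchNorm2d``-``ReLU`` sequence, where the child module names ``'0'``, ``'1'``, ``'2'`` come from ``nn.Sequential``:

.. code:: python

    import torch.nn as nn
    from torch.ao.quantization import fuse_modules

    m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    m.eval()  # eager-mode fusion for inference expects eval mode
    fuse_modules(m, ['0', '1', '2'], inplace=True)
    print(m)  # m[0] is now a fused ConvReLU2d; m[1] and m[2] are Identity
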

 2. Helper functions
 -------------------
@@ -314,25 +313,22 @@ in this data. These functions mostly come from
 .. code:: python

     def prepare_data_loaders(data_path):
-
         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
         dataset = torchvision.datasets.ImageNet(
-            data_path, split="train",
-            transforms.Compose([
-                transforms.RandomResizedCrop(224),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                normalize,
-            ]))
+            data_path, split="train", transform=transforms.Compose([
+                transforms.RandomResizedCrop(224),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                normalize,
+            ]))
         dataset_test = torchvision.datasets.ImageNet(
-            data_path, split="val",
-            transforms.Compose([
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                normalize,
-            ]))
+            data_path, split="val", transform=transforms.Compose([
+                transforms.Resize(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                normalize,
+            ]))

         train_sampler = torch.utils.data.RandomSampler(dataset)
         test_sampler = torch.utils.data.SequentialSampler(dataset_test)
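
The substantive fix in this hunk is the ``transform=`` keyword: depending on the torchvision version, a positional ``Compose`` is either swallowed by a different parameter or rejected with a ``TypeError``, so the preprocessing never ran. A self-contained sketch of the corrected pattern, using ``FakeData`` as a stand-in dataset so it runs without the ImageNet archive:

.. code:: python

    import torchvision.transforms as transforms
    from torchvision import datasets

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    # Binding the pipeline to ``transform=`` is what makes it take effect.
    dataset = datasets.FakeData(size=8, transform=preprocess)
    img, target = dataset[0]
    print(img.shape)  # torch.Size([3, 224, 224])
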
@@ -424,9 +420,9 @@ values to floats - and then back to ints - between every operation, resulting in

     # Specify quantization configuration
     # Start with simple min/max range estimation and per-tensor quantization of weights
-    myModel.qconfig = torch.quantization.default_qconfig
+    myModel.qconfig = torch.ao.quantization.default_qconfig
     print(myModel.qconfig)
-    torch.quantization.prepare(myModel, inplace=True)
+    torch.ao.quantization.prepare(myModel, inplace=True)

     # Calibrate first
     print('Post Training Quantization Prepare: Inserting Observers')
@@ -437,7 +433,7 @@ values to floats - and then back to ints - between every operation, resulting in
     print('Post Training Quantization: Calibration done')

     # Convert to quantized model
-    torch.quantization.convert(myModel, inplace=True)
+    torch.ao.quantization.convert(myModel, inplace=True)
     print('Post Training Quantization: Convert done')
     print('\n Inverted Residual Block: After fusion and quantization, note fused modules: \n\n',myModel.features[1].conv)
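
The three eager-mode steps shown in these hunks (insert observers, calibrate, convert) can be exercised end to end on a toy module. A sketch with a placeholder model ``M``, not the tutorial's network:

.. code:: python

    import torch
    import torch.nn as nn
    import torch.ao.quantization as tq

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.quant, self.dequant = tq.QuantStub(), tq.DeQuantStub()
            self.fc = nn.Linear(4, 4)

        def forward(self, x):
            return self.dequant(self.fc(self.quant(x)))

    m = M().eval()
    m.qconfig = tq.default_qconfig
    tq.prepare(m, inplace=True)   # 1. insert observers
    m(torch.randn(8, 4))          # 2. calibrate on representative data
    tq.convert(m, inplace=True)   # 3. swap in quantized modules
    print(m.fc)                   # a quantized Linear after conversion
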
@@ -462,12 +458,12 @@ quantizing for x86 architectures. This configuration does the following:
     per_channel_quantized_model = load_model(saved_model_dir + float_model_file)
     per_channel_quantized_model.eval()
     per_channel_quantized_model.fuse_model()
-    per_channel_quantized_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
+    per_channel_quantized_model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
     print(per_channel_quantized_model.qconfig)

-    torch.quantization.prepare(per_channel_quantized_model, inplace=True)
+    torch.ao.quantization.prepare(per_channel_quantized_model, inplace=True)
     evaluate(per_channel_quantized_model,criterion, data_loader, num_calibration_batches)
-    torch.quantization.convert(per_channel_quantized_model, inplace=True)
+    torch.ao.quantization.convert(per_channel_quantized_model, inplace=True)
     top1, top5 = evaluate(per_channel_quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
     print('Evaluation accuracy on %d images, %2.2f' % (num_eval_batches * eval_batch_size, top1.avg))
     torch.jit.save(torch.jit.script(per_channel_quantized_model), saved_model_dir + scripted_quantized_model_file)
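
The only functional difference from the first quantized run is the qconfig. To see what ``'fbgemm'`` changes, printing both configurations is instructive (the exact observer classes can vary by PyTorch version):

.. code:: python

    import torch.ao.quantization as tq

    print(tq.default_qconfig)                # per-tensor min/max observers
    print(tq.get_default_qconfig('fbgemm'))  # histogram observer for activations,
                                             # per-channel observer for weights
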
@@ -539,13 +535,13 @@ We fuse modules as before
     qat_model.fuse_model()

     optimizer = torch.optim.SGD(qat_model.parameters(), lr=0.0001)
-    qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
+    qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')

 Finally, ``prepare_qat`` performs the "fake quantization", preparing the model for quantization-aware training

 .. code:: python

-    torch.quantization.prepare_qat(qat_model, inplace=True)
+    torch.ao.quantization.prepare_qat(qat_model, inplace=True)
     print('Inverted Residual Block: After preparation for QAT, note fake-quantization modules \n',qat_model.features[1].conv)

 Training a quantized model with high accuracy requires accurate modeling of numerics at
@@ -565,13 +561,13 @@ inference. For quantization aware training, therefore, we modify the training lo
         train_one_epoch(qat_model, criterion, optimizer, data_loader, torch.device('cpu'), num_train_batches)
         if nepoch > 3:
             # Freeze quantizer parameters
-            qat_model.apply(torch.quantization.disable_observer)
+            qat_model.apply(torch.ao.quantization.disable_observer)
         if nepoch > 2:
             # Freeze batch norm mean and variance estimates
             qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

         # Check the accuracy after each epoch
-        quantized_model = torch.quantization.convert(qat_model.eval(), inplace=False)
+        quantized_model = torch.ao.quantization.convert(qat_model.eval(), inplace=False)
         quantized_model.eval()
         top1, top5 = evaluate(quantized_model,criterion, data_loader_test, neval_batches=num_eval_batches)
         print('Epoch %d :Evaluation accuracy on %d images, %2.2f' % (nepoch, num_eval_batches * eval_batch_size, top1.avg))
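
Condensed, the QAT recipe in this hunk is: attach a QAT qconfig, ``prepare_qat``, train with fake quantization in place, then ``convert``. A self-contained sketch on a placeholder model; a real model also needs ``QuantStub``/``DeQuantStub`` around the quantized region, as in the tutorial:

.. code:: python

    import torch
    import torch.nn as nn
    import torch.ao.quantization as tq

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).train()
    model.qconfig = tq.get_default_qat_qconfig('fbgemm')
    tq.prepare_qat(model, inplace=True)  # insert fake-quant + observers

    for _ in range(2):  # stand-in for the real training loop
        model(torch.randn(2, 3, 32, 32)).sum().backward()

    quantized = tq.convert(model.eval(), inplace=False)  # real int8 modules
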
@@ -630,4 +626,4 @@ and quantization-aware training - describing what they do "under the hood" and h
 them in PyTorch.

 Thanks for reading! As always, we welcome any feedback, so please create an issue
-`here <https://github.com/pytorch/pytorch/issues>`_ if you have any.
+`here <https://github.com/pytorch/pytorch/issues>`_ if you have any.