Fuse Modules Recipe
=====================================

This recipe demonstrates how to fuse a list of PyTorch modules into a single module, and how to run a performance test comparing the fused model with its non-fused version.

Introduction
------------

Before quantization is applied to a model to reduce its size and memory footprint (see `Quantization Recipe <quantization.html>`_ for details on quantization), a list of modules in the model may first be fused into a single module. Fusion is optional, but it may save on memory access, make the model run faster, and improve its accuracy.

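For instance, here is a minimal sketch of the `torch.quantization.fuse_modules` API on a toy model (this `torch.nn.Sequential` model is hypothetical and only illustrates the call; modules to fuse are referred to by their attribute names):

::

    import torch

    m = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
    m.eval()  # fusion is intended for inference, so switch to eval mode first
    # in an nn.Sequential, submodule names are their indices: '0' is the conv, '1' is the relu
    fused = torch.quantization.fuse_modules(m, [['0', '1']], inplace=False)
    print(fused)  # '0' becomes the intrinsic ConvReLU2d; '1' becomes Identity
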
Pre-requisites
--------------

PyTorch 1.6.0 or 1.7.0

Steps
--------------

Follow the steps below to fuse an example model, quantize it, script it, optimize it for mobile, save it, and test it with the Android benchmark tool.
| 21 | + |
| 22 | +1. Define the Example Model |
| 23 | +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 24 | + |
| 25 | +Use the same example model defined in the `PyTorch Mobile Performance Recipes <https://pytorch.org/tutorials/recipes/mobile_perf.html>`_: |
| 26 | + |
::

    import torch
    from torch.utils.mobile_optimizer import optimize_for_mobile

    class AnnotatedConvBnReLUModel(torch.nn.Module):
        def __init__(self):
            super(AnnotatedConvBnReLUModel, self).__init__()
            self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
            self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
            self.relu = torch.nn.ReLU(inplace=True)
            # stubs that mark where tensors enter and leave the quantized domain
            self.quant = torch.quantization.QuantStub()
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, x):
            # contiguous() returns a new tensor, so the result must be assigned
            x = x.contiguous(memory_format=torch.channels_last)
            x = self.quant(x)
            x = self.conv(x)
            x = self.bn(x)
            x = self.relu(x)
            x = self.dequant(x)
            return x

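Before fusing or quantizing anything, you can sanity-check the eager model with a random input (a quick sketch; the 1x3x224x224 input size matches the one used by the benchmark tool later in this recipe):

::

    m = AnnotatedConvBnReLUModel().eval()
    with torch.no_grad():
        out = m(torch.rand(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 5, 222, 222]) - a 3x3 conv without padding
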
2. Generate Two Models with and without `fuse_modules`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Add the following code below the model definition above and run the script:

::

    model = AnnotatedConvBnReLUModel()

    def prepare_save(model, fused):
        # use the QNNPACK backend, which targets the ARM CPUs found on mobile devices
        model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
        torch.quantization.prepare(model, inplace=True)
        torch.quantization.convert(model, inplace=True)
        torchscript_model = torch.jit.script(model)
        torchscript_model_optimized = optimize_for_mobile(torchscript_model)
        torch.jit.save(torchscript_model_optimized, "model.pt" if not fused else "model_fused.pt")

    # save the non-fused model, then fuse bn and relu and save the fused model
    prepare_save(model, False)
    model_fused = torch.quantization.fuse_modules(model, [['bn', 'relu']], inplace=False)
    prepare_save(model_fused, True)

    print(model)
    print(model_fused)

The structures of the original model and its fused version will be printed as follows:

::

    AnnotatedConvBnReLUModel(
      (conv): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (quant): QuantStub()
      (dequant): DeQuantStub()
    )

    AnnotatedConvBnReLUModel(
      (conv): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn): BNReLU2d(
        (0): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU(inplace=True)
      )
      (relu): Identity()
      (quant): QuantStub()
      (dequant): DeQuantStub()
    )

In the fused model's printed structure, the first module in the fused list (`bn`) is replaced with the fused module, and the remaining modules (`relu` in this example) are replaced with `Identity`. In addition, the non-fused and fused versions of the model, `model.pt` and `model_fused.pt`, are generated.

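Because fusion only reorganizes modules, it should not change the model's numerics before quantization. Here is a minimal sketch that checks both the structural replacement described above and the numerical equivalence, on a fresh non-quantized copy of the model (assuming the intrinsic type `torch.nn.intrinsic.BNReLU2d`, which the printed structure above shows):

::

    import torch.nn.intrinsic as nni

    eager = AnnotatedConvBnReLUModel().eval()
    eager_fused = torch.quantization.fuse_modules(eager, [['bn', 'relu']], inplace=False)
    # the first module in the fused list is replaced by the fused module, the rest by Identity
    assert isinstance(eager_fused.bn, nni.BNReLU2d)
    assert isinstance(eager_fused.relu, torch.nn.Identity)
    # fusion alone must not change the outputs
    x = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        assert torch.allclose(eager(x), eager_fused(x))
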
3. Build the Android Benchmark Tool
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Get the PyTorch source and build the Android benchmark tool as follows:

::

    git clone --recursive https://github.com/pytorch/pytorch
    cd pytorch
    git submodule update --init --recursive
    BUILD_PYTORCH_MOBILE=1 ANDROID_ABI=arm64-v8a ./scripts/build_android.sh -DBUILD_BINARY=ON

This will generate the Android benchmark binary `speed_benchmark_torch` in the `build_android/bin` folder.

4. Test and Compare the Fused and Non-Fused Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Connect your Android device, then copy `speed_benchmark_torch` and the model files to the device and run the benchmark tool on them:

::

    adb push build_android/bin/speed_benchmark_torch /data/local/tmp
    adb push model.pt /data/local/tmp
    adb push model_fused.pt /data/local/tmp
    adb shell "/data/local/tmp/speed_benchmark_torch --model=/data/local/tmp/model.pt --input_dims=1,3,224,224 --input_type=float"
    adb shell "/data/local/tmp/speed_benchmark_torch --model=/data/local/tmp/model_fused.pt --input_dims=1,3,224,224 --input_type=float"

The results from the last two commands should look similar to:

::

    Main run finished. Microseconds per iter: 6189.07. Iters per second: 161.575

and

::

    Main run finished. Microseconds per iter: 6216.65. Iters per second: 160.858

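If no Android device is at hand, a rough host-side comparison of the two saved models can be made with `torch.utils.benchmark` (a sketch under two assumptions: you are on PyTorch 1.7, where this utility was introduced, and the QNNPACK quantized engine is available in your host build):

::

    import torch
    import torch.utils.benchmark as benchmark

    torch.backends.quantized.engine = 'qnnpack'  # the models were quantized for QNNPACK
    x = torch.rand(1, 3, 224, 224)
    for path in ('model.pt', 'model_fused.pt'):
        m = torch.jit.load(path)
        timer = benchmark.Timer(stmt='m(x)', globals={'m': m, 'x': x})
        print(path, timer.timeit(100))
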
For this example model, there is not much performance difference between the fused and non-fused versions. But the same steps can be used to fuse and prepare a real-world deep model and measure the performance improvement. Keep in mind that currently `torch.quantization.fuse_modules` only fuses the following sequences of modules:

* conv, bn
* conv, bn, relu
* conv, relu
* linear, relu
* bn, relu

If any other sequence is provided to the `fuse_modules` call, it will simply be ignored.
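
For example, here is a short sketch that fuses the full `conv, bn, relu` sequence of the example model instead of only `bn, relu` (fusing through a batch norm folds it into the preceding convolution, so the model must be in eval mode):

::

    m = AnnotatedConvBnReLUModel().eval()
    m_fused = torch.quantization.fuse_modules(m, [['conv', 'bn', 'relu']], inplace=False)
    # conv becomes the intrinsic ConvReLU2d with bn folded into its weights;
    # bn and relu become Identity
    print(m_fused)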

Learn More
---------------

See `here <https://pytorch.org/docs/stable/quantization.html#preparing-model-for-quantization>`_ for the official documentation of `torch.quantization.fuse_modules`.