(prototype) PyTorch BackendConfig Tutorial
==========================================
**Author**: `Andrew Or <https://github.com/andrewor14>`_

The BackendConfig API enables developers to integrate their backends
with PyTorch quantization. It is currently only supported in FX graph
mode quantization, but support may be extended to other modes of
quantization in the future. In this tutorial, we will demonstrate how to
use this API to customize quantization support for specific backends.
For more information on the motivation and implementation details behind
BackendConfig, please refer to this
`README <https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config>`__.

BackendConfig API Specification
-------------------------------

At a high level, BackendConfig specifies the quantization behavior for
each supported operator pattern (e.g. linear, conv-bn-relu). The API is
broken down into the following class hierarchy:

- `BackendConfig <https://pytorch.org/docs/stable/generated/torch.ao.quantization.backend_config.BackendConfig.html>`__:
  The main class to be passed to the prepare and convert functions
- `BackendPatternConfig <https://pytorch.org/docs/stable/generated/torch.ao.quantization.backend_config.BackendPatternConfig.html>`__:
  Config object that specifies quantization behavior for a given
  operator pattern. Each BackendConfig consists of many of these.
- `DTypeConfig <https://pytorch.org/docs/stable/generated/torch.ao.quantization.backend_config.DTypeConfig.html>`__:
  Config object that specifies the supported data types and constraints
  (if any) for input and output activations, weights, and biases. Each
  BackendPatternConfig consists of one or more of these.
- `DTypeWithConstraints <https://pytorch.org/docs/stable/generated/torch.ao.quantization.backend_config.DTypeWithConstraints.html>`__:
  Constraints imposed by the backend on the quantization parameters
  (scale and zero point) and ranges when quantizing to a given data
  type. Each DTypeConfig consists of many of these.

The pattern specified in BackendPatternConfig follows the format
described `here <https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md#pattern-specification>`__.
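
For instance, a pattern can be a single op (such as a module or a functional),
or a tuple of ops that should be matched (and possibly fused) together. The
following is a minimal sketch of the two pattern forms used later in this
tutorial; refer to the linked README for the authoritative specification:

.. code:: ipython3

    import torch
    from torch.ao.quantization.backend_config import BackendPatternConfig

    # A pattern can be a single op ...
    single_op_pattern_config = BackendPatternConfig(torch.nn.Linear)

    # ... or a tuple of ops to be matched together, e.g. conv2d followed by
    # relu (see the end-to-end example below for how this pattern is fused)
    fused_pattern_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU))
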
BackendPatternConfig Specification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

set_observation_type
^^^^^^^^^^^^^^^^^^^^

Observation type here refers to how observers (or QuantDeQuantStubs) are
placed in the graph. There are two observation types:

- ``OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT`` (default): the output
  observer instance will be different from the input. This is the most
  common observation type.
- ``OUTPUT_SHARE_OBSERVER_WITH_INPUT``: the output observer instance
  will be the same as the input. This is useful for operators like ``cat``.

Note: This will be renamed in the near future, since we will soon insert
QuantDeQuantStubs with observers (and fake quantizes) attached instead
of observers themselves.
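
For example, a backend that wants the output of ``cat`` to reuse its inputs'
quantization parameters could configure the pattern as follows. This is only
a sketch: the choice of ``torch.cat`` as the pattern and the quint8 dtypes
are illustrative assumptions, not part of the backend defined later in this
tutorial.

.. code:: ipython3

    import torch
    from torch.ao.quantization.backend_config import (
        BackendPatternConfig,
        DTypeConfig,
        ObservationType,
    )

    # The output of cat shares the same observer (and therefore the same scale
    # and zero point) as its inputs, so no requantization is needed across the op
    cat_config = BackendPatternConfig(torch.cat) \
        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \
        .add_dtype_config(DTypeConfig(input_dtype=torch.quint8, output_dtype=torch.quint8))
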
set_dtype_configs / add_dtype_config
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Each operator pattern may support one or more sets of
input/output/weight/bias data types, and each set may have its own
constraints. These requirements are captured in DTypeConfigs, which will
be described in more detail in the next section.
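
For example, a backend that supports both static int8 quantization and fp16
for the same pattern could attach two DTypeConfigs to it. This is a minimal
sketch; whether a given backend actually supports fp16 is an assumption made
here for illustration.

.. code:: ipython3

    import torch
    from torch.ao.quantization.backend_config import BackendPatternConfig, DTypeConfig

    int8_dtype_config = DTypeConfig(
        input_dtype=torch.quint8,
        output_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float,
    )

    fp16_dtype_config = DTypeConfig(
        input_dtype=torch.float16,
        output_dtype=torch.float16,
        weight_dtype=torch.float16,
        bias_dtype=torch.float16,
    )

    # Add the supported dtype configs one at a time with add_dtype_config,
    # or all at once with set_dtype_configs
    linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \
        .add_dtype_config(int8_dtype_config) \
        .add_dtype_config(fp16_dtype_config)
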
set_root_module / set_reference_quantized_module
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When we construct the reference quantized model during the convert
phase, the root modules (e.g. ``torch.nn.Linear`` for
``torch.ao.nn.intrinsic.LinearReLU``) will be swapped to the
corresponding reference quantized modules (e.g.
``torch.ao.nn.quantized.reference.Linear``). This allows custom backends
to specify custom reference quantized module implementations to match
the numerics of their lowered operators. Since this is a one-to-one
mapping, both the root module and the reference quantized module must be
specified in the same BackendPatternConfig in order for the conversion
to take place.
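
Continuing the linear-relu example above, a minimal sketch of this pairing
might look like the following. The module choices shown here mirror PyTorch's
native reference modules; a custom backend could substitute its own reference
implementation instead.

.. code:: ipython3

    import torch
    from torch.ao.quantization.backend_config import BackendPatternConfig

    fused_linear_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.LinearReLU) \
        .set_root_module(torch.nn.Linear) \
        .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)
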
set_fuser_method
^^^^^^^^^^^^^^^^

As an optimization, operator patterns such as (``torch.nn.Linear``,
``torch.nn.ReLU``) may be fused into ``torch.ao.nn.intrinsic.LinearReLU``.
``set_fuser_method`` specifies the function through which this fusion is
performed. The first argument of this function is ``is_qat``, and the
rest of the arguments are the items in the tuple pattern, e.g. the fuser
method for the above pattern will have three arguments: ``is_qat``,
``linear``, and ``relu``. See `this
example <https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6>`__
for a slightly more complicated usage.
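
A minimal fuser method for the (``torch.nn.Linear``, ``torch.nn.ReLU``)
pattern could look like the sketch below; error handling and any
QAT-specific fusion logic are omitted.

.. code:: ipython3

    import torch

    def fuse_linear_relu(is_qat, linear, relu):
        """Fuse a Linear module followed by a ReLU module into a single LinearReLU."""
        return torch.ao.nn.intrinsic.LinearReLU(linear, relu)

    # This function would then be registered on the pattern config:
    #     BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU)) \
    #         .set_fuser_method(fuse_linear_relu)
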
set_fused_module
^^^^^^^^^^^^^^^^

This is used to identify fused weighted modules (e.g.
``torch.ao.nn.intrinsic.LinearReLU``) that need to be converted to
reference quantized modules.

Data Type Restrictions
~~~~~~~~~~~~~~~~~~~~~~

Each DTypeConfig attached to a BackendPatternConfig represents a set of
supported data types for the input/output activations, weights, and
biases. These data types are matched against the ones specified in the
user’s QConfig. If there is a match, and the QConfig satisfies the
constraints specified in the DTypeConfig (if any), then we will quantize
the given pattern using this DTypeConfig. Otherwise, the QConfig is
ignored and the pattern will not be quantized.

``input_dtype``, ``output_dtype``, and ``weight_dtype`` can each be
specified either as a plain ``torch.dtype`` or as a
``DTypeWithConstraints`` (see the sketch after this list). The
constraints currently supported are:

- **quant_min_lower_bound** and **quant_max_upper_bound**: Lower and upper
  bounds for the minimum and maximum quantized values respectively. If the
  QConfig’s ``quant_min`` and ``quant_max`` fall outside this range, then
  the QConfig will be ignored.
- **scale_min_lower_bound** and **scale_max_upper_bound**: Lower and
  upper bounds for the minimum and maximum scale values respectively. If
  the QConfig’s minimum scale value (currently exposed as ``eps``) falls
  below the lower bound, then the QConfig will be ignored. Note that the
  upper bound is currently not enforced.
- **scale_exact_match** and **zero_point_exact_match**: Exact match
  requirements for scale and zero point, to be used for operators with
  fixed quantization parameters such as sigmoid and tanh. If the observer
  specified in the QConfig is neither ``FixedQParamsObserver`` nor
  ``FixedQParamsFakeQuantize``, or if the quantization parameters don't
  match, then the QConfig will be ignored.
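
For example, a backend could require fixed quantization parameters for
sigmoid outputs by combining ``scale_exact_match`` and
``zero_point_exact_match`` with a ``FixedQParamsObserver``. This is a
sketch only: the scale of 1/256 and zero point of 0 follow the convention
used by PyTorch's native backend, and a custom backend may require
different values.

.. code:: ipython3

    import torch
    from torch.ao.quantization import FixedQParamsObserver
    from torch.ao.quantization.backend_config import (
        BackendPatternConfig,
        DTypeConfig,
        DTypeWithConstraints,
        ObservationType,
    )

    # The backend requires sigmoid outputs to use exactly scale=1/256, zero_point=0
    fixed_qparams_dtype = DTypeWithConstraints(
        dtype=torch.quint8,
        scale_exact_match=1.0 / 256.0,
        zero_point_exact_match=0,
    )

    sigmoid_config = BackendPatternConfig(torch.nn.Sigmoid) \
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
        .add_dtype_config(DTypeConfig(
            input_dtype=torch.quint8,
            output_dtype=fixed_qparams_dtype,
        ))

    # To match, the QConfig must use a fixed qparams observer with the same values
    fixed_qparams_observer = FixedQParamsObserver.with_args(
        scale=1.0 / 256.0,
        zero_point=0,
        dtype=torch.quint8,
        quant_min=0,
        quant_max=255,
    )
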
End-to-End Example
------------------

Here we define a simple example BackendConfig that supports the following
three features:

- Quantizing linear (quint8 activations, qint8 weights)
- Fusing conv2d followed by relu
- Quantizing the fused conv2d-relu (quint8 activations, qint8 weights)

In this section, we will run through an example model with these
patterns and show how to use this BackendConfig in ``prepare_fx`` and
``convert_fx``.

.. code:: ipython3

    import torch
    from torch.ao.quantization import (
        default_weight_observer,
        get_default_qconfig_mapping,
        MinMaxObserver,
        QConfig,
        QConfigMapping,
    )
    from torch.ao.quantization.backend_config import (
        BackendConfig,
        BackendPatternConfig,
        DTypeConfig,
        DTypeWithConstraints,
        ObservationType,
    )
    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
|
.. code:: ipython3

    # ======================
    # Custom BackendConfig
    # ======================

    quint8_with_constraints = DTypeWithConstraints(
        dtype=torch.quint8,
        quant_min_lower_bound=0,
        quant_max_upper_bound=255,
        scale_min_lower_bound=2 ** -12,
    )

    weighted_int8_dtype_config = DTypeConfig(
        input_dtype=quint8_with_constraints,
        output_dtype=quint8_with_constraints,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float,
    )

    def fuse_conv2d_relu(is_qat, conv, relu):
        """Return a fused ConvReLU2d from individual conv and relu modules."""
        return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu)

    # For quantizing Linear
    linear_config = BackendPatternConfig(torch.nn.Linear) \
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
        .add_dtype_config(weighted_int8_dtype_config) \
        .set_root_module(torch.nn.Linear) \
        .set_qat_module(torch.nn.qat.Linear) \
        .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)

    # For fusing Conv2d + ReLU into ConvReLU2d
    conv_relu_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU)) \
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
        .add_dtype_config(weighted_int8_dtype_config) \
        .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \
        .set_fuser_method(fuse_conv2d_relu)

    # For quantizing ConvReLU2d
    fused_conv_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.ConvReLU2d) \
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
        .add_dtype_config(weighted_int8_dtype_config) \
        .set_root_module(torch.nn.Conv2d) \
        .set_qat_module(torch.ao.nn.intrinsic.qat.ConvReLU2d) \
        .set_reference_quantized_module(torch.ao.nn.quantized.reference.Conv2d)

    backend_config = BackendConfig("my_backend") \
        .set_backend_pattern_config(linear_config) \
        .set_backend_pattern_config(conv_relu_config) \
        .set_backend_pattern_config(fused_conv_relu_config)
|
.. code:: ipython3

    # ====================
    # Example user model
    # ====================

    class MyModel(torch.nn.Module):
        def __init__(self, use_bn: bool):
            super().__init__()
            self.linear = torch.nn.Linear(10, 3)
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.bn = torch.nn.BatchNorm2d(3)
            self.relu = torch.nn.ReLU()
            self.sigmoid = torch.nn.Sigmoid()
            self.use_bn = use_bn

        def forward(self, x):
            x = self.linear(x)
            x = self.conv(x)
            if self.use_bn:
                x = self.bn(x)
            x = self.relu(x)
            x = self.sigmoid(x)
            return x
|
.. code:: ipython3

    # =======================
    # Custom QConfigMapping
    # =======================

    # Define a QConfig that satisfies the constraints specified in DTypeConfig
    # Note: Here we use a quant_max of 127, but this could be up to 255 (see `quint8_with_constraints`)
    activation_observer = MinMaxObserver.with_args(quant_min=0, quant_max=127, eps=2 ** -12)
    qconfig = QConfig(activation=activation_observer, weight=default_weight_observer)

    # Note: All individual items of a fused pattern, e.g. Conv2d and ReLU in
    # (Conv2d, ReLU), must have the same QConfig
    qconfig_mapping = QConfigMapping() \
        .set_object_type(torch.nn.Linear, qconfig) \
        .set_object_type(torch.nn.Conv2d, qconfig) \
        .set_object_type(torch.nn.BatchNorm2d, qconfig) \
        .set_object_type(torch.nn.ReLU, qconfig)
|
.. code:: ipython3

    # =====================
    # Prepare and Convert
    # =====================

    example_inputs = (torch.rand(1, 3, 10, 10, dtype=torch.float),)
    model = MyModel(use_bn=False)
    prepared = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config)
    prepared(*example_inputs)  # calibrate
    converted = convert_fx(prepared, backend_config=backend_config)
|
.. parsed-literal::

    # Both linear and conv-relu are quantized
    >>> print(converted)

    GraphModule(
      (linear): QuantizedLinear(in_features=10, out_features=3, scale=0.012136868201196194, zero_point=67, qscheme=torch.per_tensor_affine)
      (conv): QuantizedConvReLU2d(3, 3, kernel_size=(3, 3), stride=(1, 1), scale=0.0029353597201406956, zero_point=0)
      (sigmoid): Sigmoid()
    )

    def forward(self, x):
        linear_input_scale_0 = self.linear_input_scale_0
        linear_input_zero_point_0 = self.linear_input_zero_point_0
        quantize_per_tensor = torch.quantize_per_tensor(x, linear_input_scale_0, linear_input_zero_point_0, torch.quint8); x = linear_input_scale_0 = linear_input_zero_point_0 = None
        linear = self.linear(quantize_per_tensor); quantize_per_tensor = None
        conv = self.conv(linear); linear = None
        dequantize_2 = conv.dequantize(); conv = None
        sigmoid = self.sigmoid(dequantize_2); dequantize_2 = None
        return sigmoid
|
.. code:: ipython3

    # ================================================
    # Prepare and Convert (only linear is quantized)
    # ================================================

    # As an experiment, here we modify the model to use conv-bn-relu instead of conv-relu,
    # but use the same BackendConfig, which doesn't know how to fuse or quantize conv-bn-relu.
    example_inputs = (torch.rand(1, 3, 10, 10, dtype=torch.float),)
    model = MyModel(use_bn=True)
    prepared = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config)
    prepared(*example_inputs)  # calibrate
    converted = convert_fx(prepared, backend_config=backend_config)
|
.. parsed-literal::

    # Only linear is quantized
    # conv-bn-relu is neither fused nor quantized
    >>> print(converted)

    GraphModule(
      (linear): QuantizedLinear(in_features=10, out_features=3, scale=0.015307803638279438, zero_point=95, qscheme=torch.per_tensor_affine)
      (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (sigmoid): Sigmoid()
    )

    def forward(self, x):
        linear_input_scale_0 = self.linear_input_scale_0
        linear_input_zero_point_0 = self.linear_input_zero_point_0
        quantize_per_tensor = torch.quantize_per_tensor(x, linear_input_scale_0, linear_input_zero_point_0, torch.quint8); x = linear_input_scale_0 = linear_input_zero_point_0 = None
        linear = self.linear(quantize_per_tensor); quantize_per_tensor = None
        dequantize_1 = linear.dequantize(); linear = None
        conv = self.conv(dequantize_1); dequantize_1 = None
        bn = self.bn(conv); conv = None
        relu = self.relu(bn); bn = None
        sigmoid = self.sigmoid(relu); relu = None
        return sigmoid
|
.. code:: ipython3

    # ============================================
    # Prepare and Convert (nothing is quantized)
    # ============================================

    # As an experiment, here we use the default QConfigMapping that doesn't satisfy the
    # dtype restrictions specified in the backend. As a result, nothing is quantized.
    example_inputs = (torch.rand(1, 3, 10, 10, dtype=torch.float),)
    model = MyModel(use_bn=True)
    prepared = prepare_fx(model, get_default_qconfig_mapping(), example_inputs, backend_config=backend_config)
    prepared(*example_inputs)  # calibrate
    converted = convert_fx(prepared, backend_config=backend_config)
|
.. parsed-literal::

    # Nothing is quantized
    >>> print(converted)

    GraphModule(
      (linear): Linear(in_features=10, out_features=3, bias=True)
      (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (sigmoid): Sigmoid()
    )

    def forward(self, x):
        linear = self.linear(x); x = None
        conv = self.conv(linear); linear = None
        bn = self.bn(conv); conv = None
        relu = self.relu(bn); bn = None
        sigmoid = self.sigmoid(relu); relu = None
        return sigmoid
|
Built-in BackendConfigs
-----------------------

PyTorch quantization supports a few built-in native BackendConfigs under
the ``torch.ao.quantization.backend_config`` namespace, any of which can be
passed directly to the prepare and convert functions as shown in the sketch
after this list:

- `get_fbgemm_backend_config <https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/fbgemm.py>`__:
  for server target settings
- `get_qnnpack_backend_config <https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/qnnpack.py>`__:
  for mobile and edge device target settings, also supports XNNPACK
  quantized ops
- `get_native_backend_config <https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/native.py>`__
  (default): a BackendConfig that supports a union of the operator
  patterns supported in the FBGEMM and QNNPACK BackendConfigs
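
For example, to target QNNPACK/XNNPACK you would fetch the corresponding
BackendConfig and pass it to the prepare and convert calls. This is a minimal
sketch that reuses ``MyModel`` from the end-to-end example above.

.. code:: ipython3

    import torch
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.backend_config import get_qnnpack_backend_config
    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

    model = MyModel(use_bn=False)  # defined in the end-to-end example above
    example_inputs = (torch.rand(1, 3, 10, 10, dtype=torch.float),)

    # Use the built-in QNNPACK BackendConfig and its matching default QConfigMapping
    qnnpack_backend_config = get_qnnpack_backend_config()
    prepared = prepare_fx(model, get_default_qconfig_mapping("qnnpack"), example_inputs,
                          backend_config=qnnpack_backend_config)
    prepared(*example_inputs)  # calibrate
    converted = convert_fx(prepared, backend_config=qnnpack_backend_config)
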
There are also other BackendConfigs under development (e.g. for
TensorRT and x86), but these are still mostly experimental at the
moment. If the user wishes to integrate a new, custom backend with
PyTorch’s quantization API, they may define their own BackendConfigs
using the same set of APIs used to define the natively supported
ones, as in the example above.

Further Reading
---------------

How BackendConfig is used in FX graph mode quantization:
https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/README.md

Motivation and implementation details behind BackendConfig:
https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md

Early design of BackendConfig:
https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md