Commit 2408739

[Quant] Add tutorial for BackendConfig

Summary: This commit adds the tutorial for the BackendConfig, the integration point for backend developers to specify the quantization behavior on a given target backend.

Note to reviewers: To see the rendered version, just go to "Files changed" > "..." > "View file".

Reviewers: jerryzh168, vkuzo
Subscribers: jerryzh168, vkuzo

1 parent db34a77 commit 2408739

1 file changed: 398 additions, 0 deletions

(prototype) PyTorch BackendConfig Tutorial
==========================================

**Author**: `Andrew Or <https://github.com/andrewor14>`_

The BackendConfig API enables developers to integrate their backends
with PyTorch quantization. It is currently only supported in FX graph
mode quantization, but support may be extended to other modes of
quantization in the future. In this tutorial, we will demonstrate how to
use this API to customize quantization support for specific backends.
For more information on the motivation and implementation details behind
BackendConfig, please refer to this
`README <https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config>`__.

BackendConfig API Specification
-------------------------------

At a high level, BackendConfig specifies the quantization behavior for
each supported operator pattern (e.g. linear, conv-bn-relu). The API is
broken down into the following class hierarchy:

- `BackendConfig <https://pytorch.org/docs/stable/generated/torch.ao.quantization.backend_config.BackendConfig.html>`__:
  The main class to be passed to the prepare and convert functions
- `BackendPatternConfig <https://pytorch.org/docs/stable/generated/torch.ao.quantization.backend_config.BackendPatternConfig.html>`__:
  Config object that specifies quantization behavior for a given
  operator pattern. Each BackendConfig consists of many of these.
- `DTypeConfig <https://pytorch.org/docs/stable/generated/torch.ao.quantization.backend_config.DTypeConfig.html>`__:
  Config object that specifies the supported data types and constraints
  (if any) for input and output activations, weights, and biases. Each
  BackendPatternConfig consists of one or more of these.
- `DTypeWithConstraints <https://pytorch.org/docs/stable/generated/torch.ao.quantization.backend_config.DTypeWithConstraints.html>`__:
  Constraints imposed by the backend on the quantization parameters
  (scale and zero point) and ranges when quantizing to a given data
  type. Each DTypeConfig consists of many of these.

The pattern specified in BackendPatternConfig follows the format
described `here <https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md#pattern-specification>`__.
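
To make the hierarchy concrete, here is a minimal sketch of how these pieces
nest together. The specific dtypes and names (e.g. ``example_backend``) are
illustrative; the end-to-end example below builds a complete BackendConfig in
the same way.

.. code:: ipython3

    import torch
    from torch.ao.quantization.backend_config import (
        BackendConfig,
        BackendPatternConfig,
        DTypeConfig,
        DTypeWithConstraints,
    )

    # DTypeWithConstraints -> DTypeConfig -> BackendPatternConfig -> BackendConfig
    act_quint8 = DTypeWithConstraints(dtype=torch.quint8)
    int8_dtype_config = DTypeConfig(
        input_dtype=act_quint8,
        output_dtype=act_quint8,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float,
    )
    linear_config = BackendPatternConfig(torch.nn.Linear) \
        .add_dtype_config(int8_dtype_config)
    backend_config = BackendConfig("example_backend") \
        .set_backend_pattern_config(linear_config)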

BackendPatternConfig Specification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

set_observation_type
^^^^^^^^^^^^^^^^^^^^

Observation type here refers to how observers (or QuantDeQuantStubs) are
placed in the graph. There are two observation types:

- ``OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT`` (default): the output
  observer instance will be different from the input. This is the most
  common observation type.
- ``OUTPUT_SHARE_OBSERVER_WITH_INPUT``: the output observer instance
  will be the same as the input. This is useful for operators like ``cat``.

Note: This will be renamed in the near future, since we will soon insert
QuantDeQuantStubs with observers (and fake quantizes) attached instead
of observers themselves.
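
For example, a pattern for an operator like ``cat``, which should reuse its
input observer for the output, might be configured as follows (a minimal
sketch; the ``cat_config`` name is illustrative):

.. code:: ipython3

    import torch
    from torch.ao.quantization.backend_config import (
        BackendPatternConfig,
        ObservationType,
    )

    # Share the input observer with the output, e.g. for torch.cat
    cat_config = BackendPatternConfig(torch.cat) \
        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)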

set_dtype_configs / add_dtype_config
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Each operator pattern may support one or more sets of
input/output/weight/bias data types, and each set may have its own
constraints. These requirements are captured in DTypeConfigs, which will
be described in more detail in the next section.
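
As a minimal sketch (the DTypeConfigs here are placeholders in the spirit of
the ones defined in the next section), the two setters can be used as follows:

.. code:: ipython3

    import torch
    from torch.ao.quantization.backend_config import BackendPatternConfig, DTypeConfig

    weighted_int8_dtype_config = DTypeConfig(
        input_dtype=torch.quint8,
        output_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float,
    )
    fp16_dtype_config = DTypeConfig(
        input_dtype=torch.float16,
        output_dtype=torch.float16,
        weight_dtype=torch.float16,
        bias_dtype=torch.float16,
    )

    # Add DTypeConfigs one at a time...
    config = BackendPatternConfig(torch.nn.Linear) \
        .add_dtype_config(weighted_int8_dtype_config) \
        .add_dtype_config(fp16_dtype_config)
    # ...or set them all at once
    config = BackendPatternConfig(torch.nn.Linear) \
        .set_dtype_configs([weighted_int8_dtype_config, fp16_dtype_config])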

set_root_module / set_reference_quantized_module
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When we construct the reference quantized model during the convert
phase, the root modules (e.g. ``torch.nn.Linear`` for
``torch.ao.nn.intrinsic.LinearReLU``) will be swapped to the
corresponding reference quantized modules (e.g.
``torch.ao.nn.quantized.reference.Linear``). This allows custom backends
to specify custom reference quantized module implementations to match
the numerics of their lowered operators. Since this is a one-to-one
mapping, both the root module and the reference quantized module must be
specified in the same BackendPatternConfig in order for the conversion
to take place.
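
As a minimal sketch (mirroring the fused linear-relu case mentioned above),
both settings are specified on the same pattern config:

.. code:: ipython3

    import torch
    from torch.ao.quantization.backend_config import BackendPatternConfig

    linear_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.LinearReLU) \
        .set_root_module(torch.nn.Linear) \
        .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)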

set_fuser_method
^^^^^^^^^^^^^^^^

As an optimization, operator patterns such as (``torch.nn.Linear``,
``torch.nn.ReLU``) may be fused into ``torch.ao.nn.intrinsic.LinearReLU``.
``set_fuser_method`` specifies the function through which this fusion is
performed. The first argument of this function is ``is_qat``, and the
rest of the arguments are the items in the tuple pattern, so the fuser
method for the above pattern will have three arguments: ``is_qat``,
``linear``, and ``relu``. See `this
example <https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6>`__
for a slightly more complicated usage.
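
Here is a minimal sketch of a fuser method for this pattern, analogous to the
``fuse_conv2d_relu`` function used in the end-to-end example below (the names
are illustrative):

.. code:: ipython3

    import torch
    from torch.ao.quantization.backend_config import BackendPatternConfig

    def fuse_linear_relu(is_qat, linear, relu):
        """Return a fused LinearReLU from individual linear and relu modules."""
        return torch.ao.nn.intrinsic.LinearReLU(linear, relu)

    linear_relu_fusion_config = BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU)) \
        .set_fused_module(torch.ao.nn.intrinsic.LinearReLU) \
        .set_fuser_method(fuse_linear_relu)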

set_fused_module
^^^^^^^^^^^^^^^^

This is used to identify fused weighted modules (e.g.
``torch.ao.nn.intrinsic.LinearReLU``) that need to be converted to
reference quantized modules.

Data Type Restrictions
~~~~~~~~~~~~~~~~~~~~~~

Each DTypeConfig attached to a BackendPatternConfig represents a set of
supported data types for the input/output activations, weights, and
biases. These data types are matched against the ones specified in the
user’s QConfig. If there is a match, and the QConfig satisfies the
constraints specified in the DTypeConfig (if any), then we will quantize
the given pattern using this DTypeConfig. Otherwise, the QConfig is
ignored and the pattern will not be quantized.

There are two ways of specifying ``input_dtype``, ``output_dtype``, and
``weight_dtype``: as a simple ``torch.dtype`` or as a
``DTypeWithConstraints``. The constraints currently supported are:

- **quant_min_lower_bound** and **quant_max_upper_bound**: Lower and upper
  bounds for the minimum and maximum quantized values respectively. If the
  QConfig’s ``quant_min`` and ``quant_max`` fall outside this range, then
  the QConfig will be ignored.
- **scale_min_lower_bound** and **scale_max_upper_bound**: Lower and
  upper bounds for the minimum and maximum scale values respectively. If
  the QConfig’s minimum scale value (currently exposed as ``eps``) falls
  below the lower bound, then the QConfig will be ignored. Note that the
  upper bound is currently not enforced.
- **scale_exact_match** and **zero_point_exact_match**: Exact match
  requirements for scale and zero point, to be used for operators with
  fixed quantization parameters such as sigmoid and tanh (see the sketch
  after this list). If the observer specified in the QConfig is neither
  ``FixedQParamsObserver`` nor ``FixedQParamsFakeQuantize``, or if the
  quantization parameters don't match, then the QConfig will be ignored.
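
Here is a minimal sketch of the last kind of constraint, assuming a backend
that requires sigmoid outputs to be quantized with a fixed scale of 1/256 and
a zero point of 0 (the values and names are illustrative):

.. code:: ipython3

    import torch
    from torch.ao.quantization.backend_config import DTypeConfig, DTypeWithConstraints

    # Constrain the output of sigmoid-like ops to fixed quantization parameters
    fixed_qparams_quint8 = DTypeWithConstraints(
        dtype=torch.quint8,
        scale_exact_match=1.0 / 256.0,
        zero_point_exact_match=0,
    )

    sigmoid_dtype_config = DTypeConfig(
        input_dtype=torch.quint8,
        output_dtype=fixed_qparams_quint8,
    )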

End-to-End Example
------------------

Here we define an example BackendConfig that supports three things:

- Quantizing linear layers to 8-bit, with unsigned (``quint8``) activations and signed (``qint8``) weights
- Fusing conv2d followed by relu
- Quantizing the fused conv2d-relu the same way

In this section, we will run through an example model with these
patterns and show how to use this BackendConfig in ``prepare_fx`` and
``convert_fx``.

.. code:: ipython3

    import torch
    from torch.ao.quantization import (
        default_weight_observer,
        get_default_qconfig_mapping,
        MinMaxObserver,
        QConfig,
        QConfigMapping,
    )
    from torch.ao.quantization.backend_config import (
        BackendConfig,
        BackendPatternConfig,
        DTypeConfig,
        DTypeWithConstraints,
        ObservationType,
    )
    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

.. code:: ipython3

    # ======================
    # Custom BackendConfig
    # ======================

    quint8_with_constraints = DTypeWithConstraints(
        dtype=torch.quint8,
        quant_min_lower_bound=0,
        quant_max_upper_bound=255,
        scale_min_lower_bound=2 ** -12,
    )

    weighted_int8_dtype_config = DTypeConfig(
        input_dtype=quint8_with_constraints,
        output_dtype=quint8_with_constraints,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float)

    def fuse_conv2d_relu(is_qat, conv, relu):
        """Return a fused ConvReLU2d from individual conv and relu modules."""
        return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu)

    # For quantizing Linear
    linear_config = BackendPatternConfig(torch.nn.Linear) \
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
        .add_dtype_config(weighted_int8_dtype_config) \
        .set_root_module(torch.nn.Linear) \
        .set_qat_module(torch.nn.qat.Linear) \
        .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)

    # For fusing Conv2d + ReLU into ConvReLU2d
    conv_relu_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU)) \
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
        .add_dtype_config(weighted_int8_dtype_config) \
        .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \
        .set_fuser_method(fuse_conv2d_relu)

    # For quantizing ConvReLU2d
    fused_conv_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.ConvReLU2d) \
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
        .add_dtype_config(weighted_int8_dtype_config) \
        .set_root_module(torch.nn.Conv2d) \
        .set_qat_module(torch.ao.nn.intrinsic.qat.ConvReLU2d) \
        .set_reference_quantized_module(torch.ao.nn.quantized.reference.Conv2d)

    backend_config = BackendConfig("my_backend") \
        .set_backend_pattern_config(linear_config) \
        .set_backend_pattern_config(conv_relu_config) \
        .set_backend_pattern_config(fused_conv_relu_config)

.. code:: ipython3

    # ====================
    # Example user model
    # ====================

    class MyModel(torch.nn.Module):
        def __init__(self, use_bn: bool):
            super().__init__()
            self.linear = torch.nn.Linear(10, 3)
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.bn = torch.nn.BatchNorm2d(3)
            self.relu = torch.nn.ReLU()
            self.sigmoid = torch.nn.Sigmoid()
            self.use_bn = use_bn

        def forward(self, x):
            x = self.linear(x)
            x = self.conv(x)
            if self.use_bn:
                x = self.bn(x)
            x = self.relu(x)
            x = self.sigmoid(x)
            return x

.. code:: ipython3

    # =======================
    # Custom QConfigMapping
    # =======================

    # Define a QConfig that satisfies the constraints specified in DTypeConfig
    # Note: Here we use a quant_max of 127, but this could be up to 255 (see `quint8_with_constraints`)
    activation_observer = MinMaxObserver.with_args(quant_min=0, quant_max=127, eps=2 ** -12)
    qconfig = QConfig(activation=activation_observer, weight=default_weight_observer)

    # Note: All individual items of a fused pattern, e.g. Conv2d and ReLU in
    # (Conv2d, ReLU), must have the same QConfig
    qconfig_mapping = QConfigMapping() \
        .set_object_type(torch.nn.Linear, qconfig) \
        .set_object_type(torch.nn.Conv2d, qconfig) \
        .set_object_type(torch.nn.BatchNorm2d, qconfig) \
        .set_object_type(torch.nn.ReLU, qconfig)

.. code:: ipython3

    # =====================
    # Prepare and Convert
    # =====================

    example_inputs = (torch.rand(1, 3, 10, 10, dtype=torch.float),)
    model = MyModel(use_bn=False)
    prepared = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config)
    prepared(*example_inputs) # calibrate
    converted = convert_fx(prepared, backend_config=backend_config)

.. parsed-literal::

    # Both linear and conv-relu are quantized
    >>> print(converted)

    GraphModule(
      (linear): QuantizedLinear(in_features=10, out_features=3, scale=0.012136868201196194, zero_point=67, qscheme=torch.per_tensor_affine)
      (conv): QuantizedConvReLU2d(3, 3, kernel_size=(3, 3), stride=(1, 1), scale=0.0029353597201406956, zero_point=0)
      (sigmoid): Sigmoid()
    )

    def forward(self, x):
        linear_input_scale_0 = self.linear_input_scale_0
        linear_input_zero_point_0 = self.linear_input_zero_point_0
        quantize_per_tensor = torch.quantize_per_tensor(x, linear_input_scale_0, linear_input_zero_point_0, torch.quint8); x = linear_input_scale_0 = linear_input_zero_point_0 = None
        linear = self.linear(quantize_per_tensor); quantize_per_tensor = None
        conv = self.conv(linear); linear = None
        dequantize_2 = conv.dequantize(); conv = None
        sigmoid = self.sigmoid(dequantize_2); dequantize_2 = None
        return sigmoid

.. code:: ipython3

    # ================================================
    # Prepare and Convert (only linear is quantized)
    # ================================================

    # As an experiment, here we modify the model to use conv-bn-relu instead of conv-relu,
    # but use the same BackendConfig, which doesn't know how to fuse or quantize conv-bn-relu.
    example_inputs = (torch.rand(1, 3, 10, 10, dtype=torch.float),)
    model = MyModel(use_bn=True)
    prepared = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config)
    prepared(*example_inputs) # calibrate
    converted = convert_fx(prepared, backend_config=backend_config)

.. parsed-literal::

    # Only linear is quantized
    # conv-bn-relu is neither fused nor quantized
    >>> print(converted)

    GraphModule(
      (linear): QuantizedLinear(in_features=10, out_features=3, scale=0.015307803638279438, zero_point=95, qscheme=torch.per_tensor_affine)
      (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (sigmoid): Sigmoid()
    )

    def forward(self, x):
        linear_input_scale_0 = self.linear_input_scale_0
        linear_input_zero_point_0 = self.linear_input_zero_point_0
        quantize_per_tensor = torch.quantize_per_tensor(x, linear_input_scale_0, linear_input_zero_point_0, torch.quint8); x = linear_input_scale_0 = linear_input_zero_point_0 = None
        linear = self.linear(quantize_per_tensor); quantize_per_tensor = None
        dequantize_1 = linear.dequantize(); linear = None
        conv = self.conv(dequantize_1); dequantize_1 = None
        bn = self.bn(conv); conv = None
        relu = self.relu(bn); bn = None
        sigmoid = self.sigmoid(relu); relu = None
        return sigmoid

.. code:: ipython3

    # ============================================
    # Prepare and Convert (nothing is quantized)
    # ============================================

    # As an experiment, here we use the default QConfigMapping that doesn't satisfy the
    # dtype restrictions specified in the backend. As a result, nothing is quantized.
    example_inputs = (torch.rand(1, 3, 10, 10, dtype=torch.float),)
    model = MyModel(use_bn=True)
    prepared = prepare_fx(model, get_default_qconfig_mapping(), example_inputs, backend_config=backend_config)
    prepared(*example_inputs) # calibrate
    converted = convert_fx(prepared, backend_config=backend_config)

.. parsed-literal::

    # Nothing is quantized
    >>> print(converted)

    GraphModule(
      (linear): Linear(in_features=10, out_features=3, bias=True)
      (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (sigmoid): Sigmoid()
    )

    def forward(self, x):
        linear = self.linear(x); x = None
        conv = self.conv(linear); linear = None
        bn = self.bn(conv); conv = None
        relu = self.relu(bn); bn = None
        sigmoid = self.sigmoid(relu); relu = None
        return sigmoid

Built-in BackendConfigs
-----------------------

PyTorch quantization supports a few built-in native BackendConfigs under
the ``torch.ao.quantization.backend_config`` namespace:

- `get_fbgemm_backend_config <https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/fbgemm.py>`__:
  for server target settings
- `get_qnnpack_backend_config <https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/qnnpack.py>`__:
  for mobile and edge device target settings, also supports XNNPACK
  quantized ops
- `get_native_backend_config <https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/native.py>`__
  (default): a BackendConfig that supports a union of the operator
  patterns supported in the FBGEMM and QNNPACK BackendConfigs

There are also other BackendConfigs under development (e.g. for
TensorRT and x86), but these are still mostly experimental at the
moment. If the user wishes to integrate a new, custom backend with
PyTorch’s quantization API, they may define their own BackendConfigs
using the same set of APIs used to define the natively supported
ones, as in the example above.
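
For example, a built-in BackendConfig can be retrieved and passed to the
prepare and convert functions in the same way as the custom one above (a
minimal sketch, assuming ``model``, ``qconfig_mapping``, and
``example_inputs`` are defined as in the earlier sections):

.. code:: ipython3

    from torch.ao.quantization.backend_config import get_qnnpack_backend_config
    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

    qnnpack_config = get_qnnpack_backend_config()
    prepared = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=qnnpack_config)
    prepared(*example_inputs) # calibrate
    converted = convert_fx(prepared, backend_config=qnnpack_config)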

Further Reading
---------------

How BackendConfig is used in FX graph mode quantization:
https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/README.md

Motivation and implementation details behind BackendConfig:
https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md

Early design of BackendConfig:
https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md
