
Commit f9ee24b

add more explanation of limitations
1 parent 4b86060 commit f9ee24b

File tree

1 file changed: +35 −15 lines

prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst

Lines changed: 35 additions & 15 deletions
@@ -1,7 +1,7 @@
-(prototype) Quantization in PyTorch 2.0 Export Tutorial (Work in Progress)
+(Work in Progress) Quantization in PyTorch 2.0 Export Tutorial
 ==============================================================
 
-**Author**: `Leslie Fang <https://github.com/leslie-fang-intel>`_, `Weiwen Xia <https://github.com/Xia-Weiwen>`__, `Jiong Gong <https://github.com/jgong5>`__
+**Author**: `Leslie Fang <https://github.com/leslie-fang-intel>`_, `Weiwen Xia <https://github.com/Xia-Weiwen>`__, `Jiong Gong <https://github.com/jgong5>`__, `Jerry Zhang <https://github.com/jerryzh168>`__
 
 Today we have `FX Graph Mode
 Quantization <https://pytorch.org/docs/stable/quantization.html#prototype-fx-graph-mode-quantization>`__
@@ -20,12 +20,12 @@ Prerequisites:
 - `Understanding of the quantization concepts in PyTorch <https://pytorch.org/docs/master/quantization.html#quantization-api-summary>`__
 - `Understanding of FX Graph Mode post training static quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html>`__
 - `Understanding of BackendConfig in PyTorch Quantization FX Graph Mode <https://pytorch.org/tutorials/prototype/backend_config_tutorial.html?highlight=backend>`__
-- `Understanding of QConfigMapping in PyTorch Quantization FX Graph Mode <https://pytorch.org/tutorials/prototype/backend_config_tutorial.html#set-up-qconfigmapping-that-satisfies-the-backend-constraints>`__
+- `Understanding of QConfig and QConfigMapping in PyTorch Quantization FX Graph Mode <https://pytorch.org/tutorials/prototype/backend_config_tutorial.html#set-up-qconfigmapping-that-satisfies-the-backend-constraints>`__
 
 Previously in ``FX Graph Mode Quantization`` we used ``QConfigMapping`` for users to specify how they want their model to be quantized,
 and ``BackendConfig`` to specify the ways of quantization supported by their backend.
 This API covers most use cases relatively well, but the main problem is that it is not fully extensible
-with two main limitations:
+without involvement of the quantization team:
 
 - Limitation around expressing quantization intentions for complicated operator patterns, such as the
   ``conv add`` fusion with the oneDNN library discussed in `issue-96288 <https://github.com/pytorch/pytorch/issues/96288>`__.
@@ -34,6 +34,15 @@ with two main limitations:
 - Limitation around supporting users' advanced quantization intentions for their model. For example, if a backend
   developer only wants to quantize inputs and outputs when the ``linear`` has a third input, it requires joint work
   between the quantization team and the backend developer.
+- Currently we use ``QConfigMapping`` and ``BackendConfig`` as separate objects. ``QConfigMapping`` describes the user's
+  intention of how they want their model to be quantized, while ``BackendConfig`` describes what kind of quantization a
+  backend supports. ``BackendConfig`` is backend specific, but ``QConfigMapping`` is not, so a user can provide a
+  ``QConfigMapping`` that is incompatible with a specific ``BackendConfig``. This is not a great user experience. Ideally,
+  we can structure this better by making both the configuration (``QConfigMapping``) and the quantization capability
+  (``BackendConfig``) backend specific, so there will be less confusion about incompatibilities.
+- Currently in ``QConfig`` we expose observer/fake_quant classes as objects for the user to configure quantization.
+  This increases what the user needs to care about: not only the ``dtype`` but also how the observation should
+  happen (see the sketch after this list). These details could potentially be hidden from the user so that the user
+  interface is simpler.
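
To make the last point concrete, below is a minimal sketch of the current FX-mode configuration, where the user must
choose observer classes and their arguments rather than only a target ``dtype`` (the specific observers are
illustrative, not a recommendation)::

    import torch
    from torch.ao.quantization import QConfig, QConfigMapping
    from torch.ao.quantization.observer import (
        HistogramObserver,
        PerChannelMinMaxObserver,
    )

    # The user picks observer classes and their arguments, not just dtypes.
    qconfig = QConfig(
        activation=HistogramObserver.with_args(reduce_range=True),
        weight=PerChannelMinMaxObserver.with_args(
            dtype=torch.qint8, qscheme=torch.per_channel_symmetric
        ),
    )

    # QConfigMapping is not backend specific, so nothing prevents this mapping
    # from being paired with an incompatible BackendConfig.
    qconfig_mapping = QConfigMapping().set_global(qconfig)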
 
 To address these scalability issues,
 `Quantizer <https://github.com/pytorch/pytorch/blob/3e988316b5976df560c51c998303f56a234a6a1f/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L160>`__
@@ -127,24 +136,34 @@ Taking QNNPackQuantizer as an example, the overall Quantization 2.0 flow could b
 
     # Step 4: Lower Reference Quantized Model into the backend
 
-Inside the Quantizer, we will use the ``QuantizationAnnotation API``
-to convey user's intent for what quantization spec to use and how to
-observe certain tensor values in the prepare step. Now, we will have a step-by-step
-tutorial for how to use the ``QuantizationAnnotation API`` with different types of
+The Quantizer uses the annotation API to convey quantization intent for different operators/patterns.
+The annotation API uses ``QuantizationSpec``
+(`definition is here <https://github.com/pytorch/pytorch/blob/1ca2e993af6fa6934fca35da6970308ce227ddc7/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L38>`__)
+to convey the intent of how a tensor should be quantized,
+e.g. dtype, bitwidth, min and max values, symmetric vs. asymmetric, etc.
+Furthermore, the annotation API also allows the quantizer to specify how a
+tensor value should be observed, e.g. with ``MinMaxObserver``, ``HistogramObserver``,
+or some customized observer.
+
+``QuantizationSpec`` is used to annotate a node's output tensor or input tensors. Annotating
+input tensors is equivalent to annotating the edges of the graph, while annotating the output tensor is
+equivalent to annotating the node. Thus the annotation API requires the quantizer to annotate the nodes
+(output tensors) or edges (input tensors) of the graph.
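
As an illustration, a spec conveying such an intent might be constructed as follows; this is a minimal sketch whose
field names follow the prototype definition linked above and may change as the API evolves::

    import torch
    from torch.ao.quantization._pt2e.quantizer.quantizer import QuantizationSpec
    from torch.ao.quantization.observer import HistogramObserver

    # Intent: signed int8, per-tensor affine quantization, observed with a
    # HistogramObserver during the prepare step.
    act_qspec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
    )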
+
+Now, we will give a step-by-step tutorial on how to use the annotation API with different types of
 ``QuantizationSpec``.
 
 1. Annotate common operator patterns
 --------------------------------------------------------
 
 In order to use the quantized patterns/operators, e.g. ``quantized add``,
-backend developers will have intent to quantize (as expressed by
-`QuantizationSpec <https://github.com/pytorch/pytorch/blob/1ca2e993af6fa6934fca35da6970308ce227ddc7/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L38>`__
-) inputs, output of the pattern. Following is an example flow (take ``add`` operator as example)
+backend developers will intend to quantize (as expressed by ``QuantizationSpec``)
+the inputs and output of the pattern. The following is an example flow (taking the ``add`` operator as an example)
 of how this intent is conveyed in the quantization workflow with the annotation API.
 
 - Step 1: Identify the original floating point pattern in the FX graph. There are
-several ways to identify this pattern: User may use a pattern matcher (e.g. SubgraphMatcher)
-to match the operator pattern; User may go through the nodes from start to the end and compare
+  several ways to identify this pattern: the Quantizer may use a pattern matcher (e.g. SubgraphMatcher)
+  to match the operator pattern, or it may go through the nodes from start to end and compare
   each node's target type to match the operator pattern. In this example, we can use
   `get_source_partitions <https://github.com/pytorch/pytorch/blob/07104ca99c9d297975270fb58fda786e60b49b38/torch/fx/passes/utils/source_matcher_utils.py#L51>`__
   to match this pattern, as sketched below. The original floating point ``add`` pattern only contains a single ``add`` node.
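
For instance, matching the ``add`` partitions could look like the following sketch, where ``gm`` is a placeholder for
the ``GraphModule`` captured in the export step::

    import operator

    import torch
    from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

    # Returns a dict mapping each wanted source to its matched partitions;
    # every SourcePartition records the matched nodes plus their input and
    # output nodes.
    add_partitions = get_source_partitions(gm.graph, [operator.add, torch.add])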
@@ -177,8 +196,9 @@ of how this intent is conveyed in the quantization workflow with annotation API.
 - Step 3: Annotate the inputs and output of the pattern with
   `QuantizationAnnotation <https://github.com/pytorch/pytorch/blob/07104ca99c9d297975270fb58fda786e60b49b38/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L144>`__.
   ``QuantizationAnnotation`` is a ``dataclass`` with several fields: the ``input_qspec_map`` field is of class ``Dict``
-to map each input ``Node`` to a ``QuantizationSpec``; ``output_qspec`` field expresses the ``QuantizationSpec`` used for
-output node; ``_annotated`` field indicates if this node has already been annotated by quantizer.
+  and maps each input ``Node`` to a ``QuantizationSpec``, i.e. it annotates each input edge with that ``QuantizationSpec``;
+  the ``output_qspec`` field expresses the ``QuantizationSpec`` used to
+  annotate the output node; the ``_annotated`` field indicates whether this node has already been annotated by the quantizer.
   In this example, we will create the ``QuantizationAnnotation`` object with the ``QuantizationSpec`` objects
   created in step 2 above for the two inputs and one output of the ``add`` node, as sketched below.
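
A sketch of this annotation, reusing the illustrative ``act_qspec`` from earlier; ``add_node``, ``input1``, and
``input2`` are placeholders for the matched node and its argument nodes, and the ``quantization_annotation`` meta key
follows the prototype quantizer implementations::

    from torch.ao.quantization._pt2e.quantizer.quantizer import (
        QuantizationAnnotation,
    )

    # Annotate both input edges and the output of the matched ``add`` node.
    add_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={input1: act_qspec, input2: act_qspec},
        output_qspec=act_qspec,
        _annotated=True,
    )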
