@@ -41,8 +41,8 @@ without involvement of the quantization team:
by making both configuration (``QConfigMapping``) and quantization capability (``BackendConfig``) backend
specific, so there is less confusion about incompatibilities.
- Currently, in ``QConfig`` we are exposing observer/fake_quant classes as an object for user to configure quantization.
- This increases the things that user needs to care about, e.g. not only the ``dtype`` but also how the observation should
- happen. These could potentially be hidden from user to make user interface simpler.
+ This increases the things that the user needs to care about, e.g. not only the ``dtype`` but also how the
+ observation should happen. These could potentially be hidden from the user to make the user interface simpler.

To address these scalability issues,
`Quantizer <https://github.com/pytorch/pytorch/blob/3e988316b5976df560c51c998303f56a234a6a1f/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L160>`__
@@ -136,22 +136,30 @@ Taking QNNPackQuantizer as an example, the overall Quantization 2.0 flow could b
# Step 4: Lower Reference Quantized Model into the backend
- Quantizer uses annotation API to convey quantization intent for different operators/patterns.
- Annotation API uses
+ ``Quantizer`` uses the annotation API to convey quantization intent for different operators/patterns.
+ The annotation API mainly consists of
`QuantizationSpec <https://github.com/pytorch/pytorch/blob/1ca2e993af6fa6934fca35da6970308ce227ddc7/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L38>`__
- to convey intent of how a tensor will be quantized,
+ and
+ `QuantizationAnnotation <https://github.com/pytorch/pytorch/blob/07104ca99c9d297975270fb58fda786e60b49b38/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L144>`__.
+
+ ``QuantizationSpec`` is used to convey the intent of how a tensor will be quantized,
e.g. dtype, bitwidth, min and max values, symmetric vs. asymmetric, etc.
- Furthermore, annotation API also allows quantizer to specify how a
+ Furthermore, ``QuantizationSpec`` also allows the quantizer to specify how a
tensor value should be observed, e.g. ``MinMaxObserver``, ``HistogramObserver``,
or some customized observer.
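To make this concrete, here is a minimal, self-contained sketch of what such a spec-plus-observer pairing could look like. This is not the real ``QuantizationSpec`` API; the class and field names below (``SimpleQuantizationSpec``, ``observer_ctr``, ``MinMaxLikeObserver``) are illustrative stand-ins for the idea of attaching an observer constructor to a tensor's quantization intent:

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class SimpleQuantizationSpec:
    # Illustrative fields only; the real QuantizationSpec carries more.
    dtype: str            # e.g. "int8"
    quant_min: int        # minimum quantized integer value
    quant_max: int        # maximum quantized integer value
    is_symmetric: bool    # symmetric vs. asymmetric quantization
    observer_ctr: Callable  # constructor for the observer class to use

class MinMaxLikeObserver:
    """Toy observer that tracks the running min/max of the values it sees."""
    def __init__(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, values):
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))

# A spec saying: quantize to int8, symmetrically, observing with min/max.
spec = SimpleQuantizationSpec("int8", -128, 127, True, MinMaxLikeObserver)

# The quantization flow would instantiate the observer from the spec
# and feed it tensor values during calibration.
obs = spec.observer_ctr()
obs.observe([-0.5, 2.0, 1.25])
print(obs.min_val, obs.max_val)  # -0.5 2.0
```

The point of carrying a constructor rather than an instance is that each annotated tensor gets its own fresh observer during calibration.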
- ``QuantizationSpec`` is used to annotate nodes' input tensors or output tensor. Annotating
- input tensors is equivalent of annotating edge of the graph, while annotating output tensor is
- equivalent of annotating node. Thus annotation API requires quantizer to annotate
- edges (input tensors) or nodes (output tensor) of the graph.
+ ``QuantizationAnnotation``, composed of ``QuantizationSpec`` objects, is used to annotate the input
+ tensors and the output tensor of an FX ``Node``. Annotating input tensors is equivalent to annotating
+ input edges, while annotating the output tensor is equivalent to annotating the node.
+ ``QuantizationAnnotation`` is a ``dataclass`` with several fields:
+
+ - The ``input_qspec_map`` field is of class ``Dict``, mapping each input tensor (as an input edge) to a ``QuantizationSpec``.
+ - The ``output_qspec`` field expresses the ``QuantizationSpec`` used to annotate the output tensor.
+ - The ``_annotated`` field indicates whether this node has already been annotated by the quantizer.
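The field layout described above can be sketched as a plain Python ``dataclass``. This is a simplified stand-in, not the real ``QuantizationAnnotation`` (which keys ``input_qspec_map`` by graph edges and uses real ``QuantizationSpec`` objects); the names ``SimpleAnnotation``, ``lhs``, and ``rhs`` are illustrative:

```python
from dataclasses import dataclass, field
from typing import Dict, Optional

@dataclass
class SimpleAnnotation:
    # Maps each input edge (keyed here by input name for brevity;
    # the real map is keyed by graph edges) to a quantization spec.
    input_qspec_map: Dict[str, str] = field(default_factory=dict)
    # Spec for the node's single output tensor.
    output_qspec: Optional[str] = None
    # Set once a quantizer has processed this node.
    _annotated: bool = False

# Annotate a hypothetical two-input "add" node: both inputs and the
# output quantized according to the same int8 spec.
annotation = SimpleAnnotation(
    input_qspec_map={"lhs": "int8_spec", "rhs": "int8_spec"},
    output_qspec="int8_spec",
    _annotated=True,
)
print(sorted(annotation.input_qspec_map))  # ['lhs', 'rhs']
```

Keeping per-edge specs in a map is what lets a quantizer give different inputs of the same node different quantization parameters.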
- Now, we will have a step-by-step tutorial for how to use the annotation API with different types of
- ``QuantizationSpec``.
+ Thus the annotation API requires the quantizer to annotate edges (input tensors) or
+ nodes (output tensor) of the graph. Next, we present a step-by-step tutorial on
+ how to use the annotation API with different types of ``QuantizationSpec``.
1. Annotate common operator patterns
--------------------------------------------------------
@@ -193,13 +201,8 @@ of how this intent is conveyed in the quantization workflow with annotation API.
input_act_qspec = act_quantization_spec
output_act_qspec = act_quantization_spec
- - Step 3: Annotate the inputs and output of the pattern with
- `QuantizationAnnotation <https://github.com/pytorch/pytorch/blob/07104ca99c9d297975270fb58fda786e60b49b38/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L144>`__.
- ``QuantizationAnnotation`` is a ``dataclass`` with several fields as: ``input_qspec_map`` field is of class ``Dict``
- to map each input ``Node`` to a ``QuantizationSpec``. It means to annotate each input edge with this ``QuantizationSpec``;
- ``output_qspec`` field expresses the ``QuantizationSpec`` used to
- annotate the output node; ``_annotated`` field indicates if this node has already been annotated by quantizer.
- In this example, we will create the ``QuantizationAnnotation`` object with the ``QuantizationSpec`` objects
+ - Step 3: Annotate the inputs and output of the pattern with ``QuantizationAnnotation``.
+ In this example, we will create the ``QuantizationAnnotation`` object with the ``QuantizationSpec`` objects
created in step 2 above for the two inputs and one output of the ``add`` node.

::