Commit a28d763 (parent: e7c86fd)

[pt2e][quant] Update some docs for pt2 export quantization

Summary: .
Test Plan: CI generated docs
Reviewers:
Subscribers:
Tasks:
Tags:

2 files changed (+63, -16 lines)
prototype_source/pt2e_quant_ptq_static.rst

Lines changed: 34 additions & 15 deletions
@@ -430,23 +430,42 @@ Convert the Calibrated Model to a Quantized Model
     print(quantized_model)

 .. note::
-   The model produced here also has some improvements over the previous
-   `representations <https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md>`_ in FX graph mode quantization. Previously, all quantized operators were represented as ``dequantize -> fp32_op -> quantize``; in the new flow, we choose to represent some of the operators with integer computation so that it is closer to the computation that happens in hardware.
-   For example, here is how we plan to represent a quantized linear operator:
+   At this step, we currently have two representations that you can choose from, but the exact representation
+   we offer in the long term might change based on feedback from users.

-   .. code-block:: python
+   * Q/DQ Representation (default)
+     As in the previous `representations <https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md>`_ documentation, all quantized operators are represented as ``dequantize -> fp32_op -> quantize``.

-      def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point, bias_int32, bias_scale, bias_zero_point, output_scale, output_zero_point):
-          x_int16 = x_int8.to(torch.int16)
-          weight_int16 = weight_int8.to(torch.int16)
-          acc_int32 = torch.ops.out_dtype(torch.mm, torch.int32, (x_int16 - x_zero_point), (weight_int16 - weight_zero_point))
-          acc_rescaled_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Scalar, torch.int32, acc_int32, x_scale * weight_scale / output_scale)
-          bias_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Scalar, torch.int32, bias_int32 - bias_zero_point, bias_scale / output_scale)
-          out_int8 = torch.ops.aten.clamp(acc_rescaled_int32 + bias_int32 + output_zero_point, qmin, qmax).to(torch.int8)
-          return out_int8
-
-   For more details, please see:
-   `Quantized Model Representation <https://docs.google.com/document/d/17h-OEtD4o_hoVuPqUFsdm5uo7psiNMY8ThN03F9ZZwg/edit>`_.
+   .. code-block:: python
+
+      def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point, bias_fp32, output_scale, output_zero_point):
+          x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+              x_int8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+          weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+              weight_int8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
+          weight_permuted = torch.ops.aten.permute_copy.default(weight_fp32, [1, 0])
+          out_fp32 = torch.ops.aten.addmm.default(bias_fp32, x_fp32, weight_permuted)
+          out_int8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+              out_fp32, output_scale, output_zero_point, out_quant_min, out_quant_max, torch.int8)
+          return out_int8
+
+   * Reference Quantized Model Representation (WIP, expected to be ready at the end of August): we have a special representation for selected ops (e.g. quantized linear); other ops are represented as ``dq -> float32_op -> q``, and q/dq are decomposed into more primitive operators.
+
+     You can get this representation with ``convert_pt2e(..., use_reference_representation=True)``.
+
+   .. code-block:: python
+
+      # Reference quantized pattern for quantized linear
+      def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point, bias_fp32, output_scale, output_zero_point):
+          x_int16 = x_int8.to(torch.int16)
+          weight_int16 = weight_int8.to(torch.int16)
+          acc_int32 = torch.ops.out_dtype(torch.mm, torch.int32, (x_int16 - x_zero_point), (weight_int16 - weight_zero_point))
+          acc_rescaled_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Scalar, torch.int32, acc_int32, x_scale * weight_scale / output_scale)
+          bias_scale = x_scale * weight_scale
+          bias_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Tensor, torch.int32, bias_fp32, bias_scale / output_scale)
+          out_int8 = torch.ops.aten.clamp(acc_rescaled_int32 + bias_int32 + output_zero_point, qmin, qmax).to(torch.int8)
+          return out_int8
+
+   Please see `here <https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/pt2e/representation/rewrite.py>`_ for the most up-to-date reference representations.


 Checking Model Size and Accuracy Evaluation
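
For context on the note changed above, here is a minimal sketch of where ``convert_pt2e`` sits in the PT2 export quantization flow and how the reference-representation flag mentioned in this diff would be selected. The toy module, the ``XNNPACKQuantizer`` choice, and the ``capture_pre_autograd_graph`` capture step are illustrative assumptions, not part of this commit; the capture entry point in particular has moved across PyTorch releases.

.. code-block:: python

   import torch
   # capture API assumed for illustration; substitute your version's export entry point
   from torch._export import capture_pre_autograd_graph
   from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
   from torch.ao.quantization.quantizer.xnnpack_quantizer import (
       XNNPACKQuantizer,
       get_symmetric_quantization_config,
   )

   class ToyLinear(torch.nn.Module):
       def __init__(self):
           super().__init__()
           self.linear = torch.nn.Linear(16, 8)

       def forward(self, x):
           return self.linear(x)

   example_inputs = (torch.randn(2, 16),)
   m = ToyLinear().eval()

   # Capture an ATen-level graph of the model.
   m = capture_pre_autograd_graph(m, example_inputs)

   # Annotate with a quantizer and insert observers.
   quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
   m = prepare_pt2e(m, quantizer)

   # Calibrate with representative data.
   m(*example_inputs)

   # Default Q/DQ representation.
   quantized = convert_pt2e(m)

   # Reference representation (WIP flag named in the diff above).
   # quantized = convert_pt2e(m, use_reference_representation=True)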

prototype_source/pt2e_quantizer.rst

Lines changed: 29 additions & 1 deletion
@@ -9,13 +9,15 @@ Prerequisites:
 ^^^^^^^^^^^^^^^^

 Required:
+
 - `Torchdynamo concepts in PyTorch <https://pytorch.org/docs/stable/dynamo/index.html>`__

 - `Quantization concepts in PyTorch <https://pytorch.org/docs/master/quantization.html#quantization-api-summary>`__

 - `(prototype) PyTorch 2.0 Export Post Training Static Quantization <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_static.html>`__

 Optional:
+
 - `FX Graph Mode post training static quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html>`__

 - `BackendConfig in PyTorch Quantization FX Graph Mode <https://pytorch.org/tutorials/prototype/backend_config_tutorial.html?highlight=backend>`__
@@ -141,7 +143,33 @@ parameters can be shared among some tensors explicitly. Two typical use cases ar

 ``SharedQuantizationSpec`` is designed for this use case to annotate tensors whose quantization
 parameters are shared with other tensors. Input of ``SharedQuantizationSpec`` is an ``EdgeOrNode`` object which
-can be an input edge or an output value.
+can be an input edge or an output value.
+
+.. note::
+
+   * Sharing is transitive
+
+     Some tensors might effectively end up using a shared quantization spec because (1) two nodes/edges are
+     configured to use ``SharedQuantizationSpec`` and (2) there is existing sharing for some of the nodes.
+
+     For example, let's say we have two conv nodes, conv1 and conv2, and both of them feed into a cat
+     node: ``cat([conv1_out, conv2_out], ...)``. Let's say the outputs of conv1 and conv2 and the first input of cat are configured
+     with the same ``QuantizationSpec``, and the second input of cat is configured to use ``SharedQuantizationSpec``
+     with the first input::
+
+       conv1_out: qspec1(dtype=torch.int8, ...)
+       conv2_out: qspec1(dtype=torch.int8, ...)
+       cat_input0: qspec1(dtype=torch.int8, ...)
+       cat_input1: SharedQuantizationSpec((conv1, cat))  # conv1 node is the first input of cat
+
+     First of all, the output of conv1 implicitly shares quantization parameters (and the observer object)
+     with the first input of cat, and the same holds for the output of conv2 and the second input of cat.
+     So since the user configures the two inputs of cat to share quantization parameters, by transitivity,
+     conv2_out and conv1_out will also share quantization parameters. In the observed graph, you
+     will see::
+
+       conv1 -> obs -> cat
+       conv2 -> obs /
+
+     and both ``obs`` will be the same observer instance.
+

 - Input edge is the connection between input node and the node consuming the input,
   so it's a ``Tuple[Node, Node]``.
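
To make the transitive-sharing note added above concrete, here is a minimal sketch of how the cat example could be annotated inside a custom quantizer's ``annotate`` step. The helper name ``annotate_cat_example``, the pattern matching that would find the nodes, and the ``HistogramObserver`` choice are illustrative assumptions, not part of this commit; the ``QuantizationSpec``, ``SharedQuantizationSpec``, and ``QuantizationAnnotation`` classes are the ones the tutorial describes.

.. code-block:: python

   import torch
   from torch.ao.quantization.observer import HistogramObserver
   from torch.ao.quantization.quantizer import (
       QuantizationAnnotation,
       QuantizationSpec,
       SharedQuantizationSpec,
   )

   def annotate_cat_example(conv1: torch.fx.Node, conv2: torch.fx.Node, cat: torch.fx.Node) -> None:
       # conv1, conv2, cat are fx Nodes located by pattern matching (elided here).
       qspec1 = QuantizationSpec(
           dtype=torch.int8,
           quant_min=-128,
           quant_max=127,
           qscheme=torch.per_tensor_affine,
           observer_or_fake_quant_ctr=HistogramObserver,
       )

       # conv1_out and conv2_out each get their own qspec1-based observer for now.
       conv1.meta["quantization_annotation"] = QuantizationAnnotation(
           output_qspec=qspec1, _annotated=True
       )
       conv2.meta["quantization_annotation"] = QuantizationAnnotation(
           output_qspec=qspec1, _annotated=True
       )

       # cat: first input uses qspec1, second input shares with the (conv1, cat) edge,
       # so by transitivity conv1_out and conv2_out end up sharing one observer.
       cat.meta["quantization_annotation"] = QuantizationAnnotation(
           input_qspec_map={
               conv1: qspec1,
               conv2: SharedQuantizationSpec((conv1, cat)),
           },
           _annotated=True,
       )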
