[pt2e][quant] Update some docs for pt2 export quantization #2530

Merged 8 commits on Aug 24, 2023
82 changes: 56 additions & 26 deletions prototype_source/pt2e_quant_ptq_static.rst
@@ -22,27 +22,27 @@ this:
\ /
\ /
—-------------------------------------------------------
| Dynamo Export |
| Export |
—-------------------------------------------------------
|
FX Graph in ATen XNNPACKQuantizer,
| or X86InductorQuantizer,
| or <Other Backend Quantizer>
| /
—--------------------------------------------------------
| prepare_pt2e |
—--------------------------------------------------------
|
Calibrate/Train
|
—--------------------------------------------------------
| convert_pt2e |
—--------------------------------------------------------
|
Reference Quantized Model
|
—--------------------------------------------------------
| Lowering |
—--------------------------------------------------------
|
Executorch, or Inductor, or <Other Backends>
@@ -53,6 +53,7 @@ The PyTorch 2.0 export quantization API looks like this:
.. code:: python

import torch
from torch._export import capture_pre_autograd_graph
class M(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -66,7 +67,9 @@ The PyTorch 2.0 export quantization API looks like this:
m = M().eval()

# Step 1. program capture
m = torch._dynamo.export(m, *example_inputs, aten_graph=True)
# NOTE: this API will be updated to the torch.export API in the future, but the captured
# result should mostly stay the same
m = capture_pre_autograd_graph(m, *example_inputs)
# we get a model with aten ops
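
The rest of this example is collapsed in the diff; a minimal sketch of the remaining steps, assuming the
``XNNPACKQuantizer`` and the ``prepare_pt2e``/``convert_pt2e`` entry points shown later in this tutorial
(illustrative only, not the exact elided code):

.. code-block:: python

    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    # Step 2. quantization: configure a backend-specific quantizer
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())

    # insert observers according to the quantizer's annotations
    m = prepare_pt2e(m, quantizer)

    # calibrate by running representative inputs through the observed model
    m(*example_inputs)

    # convert the calibrated model to a quantized model
    m = convert_pt2e(m)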


@@ -186,8 +189,6 @@ and rename it to ``data/resnet18_pretrained_float.pth``.
import numpy as np

import torch
from torch.ao.quantization import get_default_qconfig, QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, fuse_fx
import torch.nn as nn
from torch.utils.data import DataLoader

@@ -352,10 +353,16 @@ Here is how you can use ``torch.export`` to export the model:

.. code-block:: python

import torch._dynamo as torchdynamo
from torch._export import capture_pre_autograd_graph

example_inputs = (torch.rand(2, 3, 224, 224),)
exported_model, _ = torchdynamo.export(model_to_quantize, *example_inputs, aten_graph=True, tracing_mode="symbolic")
exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)
# or capture with dynamic dimensions
# from torch._export import dynamic_dim
# exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs, constraints=[dynamic_dim(example_inputs[0], 0)])


``capture_pre_autograd_graph`` is a short-term API; it will be updated to use the official ``torch.export`` API when that is ready.
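
If you want to sanity-check the capture, the result is a standard FX ``GraphModule``, so you can print its graph
(a quick check, not part of the tutorial flow):

.. code-block:: python

    # shows the ATen-level ops captured by the export step
    print(exported_model.graph)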


Import the Backend Specific Quantizer and Configure how to Quantize the Model
@@ -429,24 +436,47 @@ Convert the Calibrated Model to a Quantized Model
quantized_model = convert_pt2e(prepared_model)
print(quantized_model)
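
For reference, the steps that produce ``prepared_model`` are collapsed in this diff; a minimal sketch, assuming the
``XNNPACKQuantizer`` imported later in this file and the tutorial's calibration ``data_loader``:

.. code-block:: python

    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())
    prepared_model = prepare_pt2e(exported_model, quantizer)

    # calibration: run representative batches through the observed model
    with torch.no_grad():
        for image, _ in data_loader:
            prepared_model(image)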

.. note::
the model produced here also has some improvements over the previous
`representations <https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md>`_ in FX graph mode quantization: previously, all quantized operators were represented as ``dequantize -> fp32_op -> quantize``; in the new flow, we choose to represent some of the operators with integer computation so that it is closer to the computation that happens in hardware.
For example, here is how we plan to represent a quantized linear operator:
At this step, we currently have two representations that you can choose from, but the exact representation
we offer in the long term might change based on feedback from PyTorch users.

* Q/DQ Representation (default)

As in the previous `representations <https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md>`_ documentation, all quantized operators are represented as ``dequantize -> fp32_op -> quantize``.

.. code-block:: python

def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point, bias_fp32, output_scale, output_zero_point):
    # x_quant_min/x_quant_max etc. are the int8 quantization ranges, e.g. -128/127
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_int8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        weight_int8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
    # the actual computation happens in fp32
    weight_permuted = torch.ops.aten.permute_copy.default(weight_fp32, [1, 0])
    out_fp32 = torch.ops.aten.addmm.default(bias_fp32, x_fp32, weight_permuted)
    # the fp32 output is quantized back to int8
    out_int8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, output_scale, output_zero_point, out_quant_min, out_quant_max, torch.int8)
    return out_int8

* Reference Quantized Model Representation (WIP, expected to be ready by the end of August): we have a special representation for selected ops (for example, quantized linear); other ops are represented as ``dq -> float32_op -> q``, and ``q/dq`` are decomposed into more primitive operators.

You can get this representation by using ``convert_pt2e(..., use_reference_representation=True)``.

.. code-block:: python

# Reference Quantized Pattern for quantized linear
def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point, bias_fp32, output_scale, output_zero_point):
    x_int16 = x_int8.to(torch.int16)
    weight_int16 = weight_int8.to(torch.int16)
    # integer matmul, accumulated in int32
    acc_int32 = torch.ops.out_dtype(torch.mm, torch.int32, (x_int16 - x_zero_point), (weight_int16 - weight_zero_point))
    # fold the fp32 bias into the int32 accumulator
    bias_scale = x_scale * weight_scale
    bias_int32 = torch.ops.out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
    acc_int32 = acc_int32 + bias_int32
    # rescale to the output quantization parameters
    acc_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Scalar, torch.int32, acc_int32, x_scale * weight_scale / output_scale) + output_zero_point
    # qmin/qmax are the int8 quantization range, e.g. -128/127
    out_int8 = torch.ops.aten.clamp(acc_int32, qmin, qmax).to(torch.int8)
    return out_int8

def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point, bias_int32, bias_scale, bias_zero_point, output_scale, output_zero_point):
    x_int16 = x_int8.to(torch.int16)
    weight_int16 = weight_int8.to(torch.int16)
    acc_int32 = torch.ops.out_dtype(torch.mm, torch.int32, (x_int16 - x_zero_point), (weight_int16 - weight_zero_point))
    acc_rescaled_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Scalar, torch.int32, acc_int32, x_scale * weight_scale / output_scale)
    bias_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Scalar, torch.int32, bias_int32 - bias_zero_point, bias_scale / output_scale)
    out_int8 = torch.ops.aten.clamp(acc_rescaled_int32 + bias_int32 + output_zero_point, qmin, qmax).to(torch.int8)
    return out_int8

For more details, please see:
`Quantized Model Representation <https://docs.google.com/document/d/17h-OEtD4o_hoVuPqUFsdm5uo7psiNMY8ThN03F9ZZwg/edit>`_.
See `here <https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/pt2e/representation/rewrite.py>`_ for the most up-to-date reference representations.
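
For example, to opt into the reference representation instead of the default Q/DQ representation (a usage sketch
based on the flag mentioned above):

.. code-block:: python

    # default is the Q/DQ representation; pass the flag to get the reference representation
    quantized_model = convert_pt2e(prepared_model, use_reference_representation=True)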


Checking Model Size and Accuracy Evaluation
@@ -503,9 +533,9 @@ We'll show how to save and load the quantized model.
# Rerun all steps to get a quantized model
model_to_quantize = load_model(saved_model_dir + float_model_file).to("cpu")
model_to_quantize.eval()
import torch._dynamo as torchdynamo
from torch._export import capture_pre_autograd_graph

exported_model, _ = torchdynamo.export(model_to_quantize, *copy.deepcopy(example_inputs), aten_graph=True, tracing_mode="symbolic")
exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
XNNPACKQuantizer,
get_symmetric_quantization_config,
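The save/load code itself is collapsed in this diff. One simple way the round trip can work (a sketch, not
necessarily the exact elided code; the checkpoint file name is hypothetical) is to save the quantized model's
``state_dict`` and load it back after rerunning the steps above:

.. code-block:: python

    # after quantizing the first time, save the weights (hypothetical file name)
    torch.save(quantized_model.state_dict(), saved_model_dir + "resnet18_pt2e_quantized.pth")

    # ... rerun capture -> prepare_pt2e -> calibrate -> convert_pt2e as above, then:
    quantized_model.load_state_dict(torch.load(saved_model_dir + "resnet18_pt2e_quantized.pth"))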
39 changes: 38 additions & 1 deletion prototype_source/pt2e_quantizer.rst
@@ -9,13 +9,15 @@ Prerequisites:
^^^^^^^^^^^^^^^^

Required:

- `Torchdynamo concepts in PyTorch <https://pytorch.org/docs/stable/dynamo/index.html>`__

- `Quantization concepts in PyTorch <https://pytorch.org/docs/master/quantization.html#quantization-api-summary>`__

- `(prototype) PyTorch 2.0 Export Post Training Static Quantization <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_static.html>`__

Optional:

- `FX Graph Mode post training static quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html>`__

- `BackendConfig in PyTorch Quantization FX Graph Mode <https://pytorch.org/tutorials/prototype/backend_config_tutorial.html?highlight=backend>`__
@@ -141,7 +143,42 @@ parameters can be shared among some tensors explicitly. Two typical use cases are

``SharedQuantizationSpec`` is designed for this use case to annotate tensors whose quantization
parameters are shared with other tensors. The input of ``SharedQuantizationSpec`` is an ``EdgeOrNode`` object, which
can be an input edge or an output value.

.. note::

* Sharing is transitive

Some tensors might effectively end up using a shared quantization spec because:

* two nodes/edges are explicitly configured to use ``SharedQuantizationSpec``, or
* some nodes already share quantization parameters with other nodes.

For example, let's say we have two ``conv`` nodes, ``conv1`` and ``conv2``, and both of them feed into a ``cat``
node: ``cat([conv1_out, conv2_out], ...)``. Let's say the outputs of ``conv1`` and ``conv2`` and the first input of ``cat`` are configured
with the same ``QuantizationSpec``, and the second input of ``cat`` is configured to use ``SharedQuantizationSpec``
with the first input.

.. code-block::

conv1_out: qspec1(dtype=torch.int8, ...)
conv2_out: qspec1(dtype=torch.int8, ...)
cat_input0: qspec1(dtype=torch.int8, ...)
cat_input1: SharedQuantizationSpec((conv1, cat)) # conv1 node is the first input of cat

First of all, the output of ``conv1`` is implicitly sharing quantization parameters (and observer object)
with the first input of ``cat``, and the same is true for the output of ``conv2`` and the second input of ``cat``.
Therefore, since the user configures the two inputs of ``cat`` to share quantization parameters, by transitivity,
``conv2_out`` and ``conv1_out`` will also be sharing quantization parameters. In the observed graph, you
will see the following:

.. code-block::

conv1 -> obs -> cat
conv2 -> obs /

and both ``obs`` will be the same observer instance.
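
Concretely, an annotation along these lines could be written inside a quantizer's ``annotate`` method; a rough
sketch, assuming the ``QuantizationSpec``/``SharedQuantizationSpec``/``QuantizationAnnotation`` classes from
``torch.ao.quantization.quantizer`` and that ``conv1``, ``conv2``, and ``cat`` are the already-located FX nodes:

.. code-block:: python

    import torch
    from torch.ao.quantization.observer import HistogramObserver
    from torch.ao.quantization.quantizer import (
        QuantizationAnnotation,
        QuantizationSpec,
        SharedQuantizationSpec,
    )

    # corresponds to qspec1 above
    act_qspec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=HistogramObserver,
    )

    # conv1_out and conv2_out
    conv1.meta["quantization_annotation"] = QuantizationAnnotation(output_qspec=act_qspec, _annotated=True)
    conv2.meta["quantization_annotation"] = QuantizationAnnotation(output_qspec=act_qspec, _annotated=True)

    # cat_input0 uses its own qspec, cat_input1 shares with the (conv1, cat) input edge
    cat.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={
            conv1: act_qspec,
            conv2: SharedQuantizationSpec((conv1, cat)),
        },
        _annotated=True,
    )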


- Input edge is the connection between input node and the node consuming the input,
so it's a ``Tuple[Node, Node]``.