Skip to content

pt2e quantization tutorial related updates #3106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 24 additions & 18 deletions prototype_source/pt2e_quant_ptq.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ The PyTorch 2 export quantization API looks like this:
.. code:: python

import torch
from torch._export import capture_pre_autograd_graph
class M(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -65,9 +64,9 @@ The PyTorch 2 export quantization API looks like this:
m = M().eval()

# Step 1. program capture
# NOTE: this API will be updated to torch.export API in the future, but the captured
# result shoud mostly stay the same
m = capture_pre_autograd_graph(m, *example_inputs)
# This is available for pytorch 2.5+, for more details on lower pytorch versions
# please check `Export the model with torch.export` section
m = torch.export.export_for_training(m, example_inputs).module()
# we get a model with aten ops


Expand All @@ -77,7 +76,7 @@ The PyTorch 2 export quantization API looks like this:
convert_pt2e,
)

from torch.ao.quantization.quantizer import (
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
XNNPACKQuantizer,
get_symmetric_quantization_config,
)
Expand Down Expand Up @@ -280,10 +279,7 @@ and rename it to ``data/resnet18_pretrained_float.pth``.
return model

def print_size_of_model(model):
if isinstance(model, torch.jit.RecursiveScriptModule):
torch.jit.save(model, "temp.p")
else:
torch.jit.save(torch.jit.script(model), "temp.p")
torch.save(model.state_dict(), "temp.p")
print("Size (MB):", os.path.getsize("temp.p")/1e6)
os.remove("temp.p")

Expand Down Expand Up @@ -351,18 +347,28 @@ Here is how you can use ``torch.export`` to export the model:

.. code-block:: python

from torch._export import capture_pre_autograd_graph

example_inputs = (torch.rand(2, 3, 224, 224),)
exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)
# for pytorch 2.5+
exported_model = torch.export.export_for_training(model_to_quantize, example_inputs).module()

# for pytorch 2.4 and before
# from torch._export import capture_pre_autograd_graph
# exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)

# or capture with dynamic dimensions
# for pytorch 2.5+
dynamic_shapes = tuple(
{0: torch.export.Dim("dim")} if i == 0 else None
for i in range(len(example_inputs))
)
exported_model = torch.export.export_for_training(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module()

# for pytorch 2.4 and before
# dynamic_shape API may vary as well
# from torch._export import dynamic_dim
# exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs, constraints=[dynamic_dim(example_inputs[0], 0)])


``capture_pre_autograd_graph`` is a short term API, it will be updated to use the offical ``torch.export`` API when that is ready.


Import the Backend Specific Quantizer and Configure how to Quantize the Model
-----------------------------------------------------------------------------

Expand Down Expand Up @@ -454,7 +460,7 @@ we offer in the long term might change based on feedback from PyTorch users.
out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
return out_i8

* Reference Quantized Model Representation (available in the nightly build)
* Reference Quantized Model Representation

We will have a special representation for selected ops, for example, quantized linear. Other ops are represented as ``dq -> float32_op -> q`` and ``q/dq`` are decomposed into more primitive operators.
You can get this representation by using ``convert_pt2e(..., use_reference_representation=True)``.
Expand Down Expand Up @@ -485,8 +491,6 @@ Now we can compare the size and model accuracy with baseline model.
.. code-block:: python

# Baseline model size and accuracy
scripted_float_model_file = "resnet18_scripted.pth"

print("Size of baseline model")
print_size_of_model(float_model)

Expand All @@ -495,6 +499,8 @@ Now we can compare the size and model accuracy with baseline model.

# Quantized model size and accuracy
print("Size of model after quantization")
# export again to remove unused weights
quantized_model = torch.export.export_for_training(quantized_model, example_inputs).module()
print_size_of_model(quantized_model)

top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
Expand Down
10 changes: 10 additions & 0 deletions prototype_source/pt2e_quant_ptq_x86_inductor.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Quantization in PyTorch 2.0 Export Tutorial
===========================================

This tutorial has been moved.

Redirecting in 3 seconds...

.. raw:: html

<meta http-equiv="Refresh" content="3; url='https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html'" />
41 changes: 26 additions & 15 deletions prototype_source/pt2e_quant_qat.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ to the post training quantization (PTQ) flow for the most part:
prepare_qat_pt2e,
convert_pt2e,
)
from torch.ao.quantization.quantizer import (
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
XNNPACKQuantizer,
get_symmetric_quantization_config,
)
Expand All @@ -36,9 +36,9 @@ to the post training quantization (PTQ) flow for the most part:
m = M()

# Step 1. program capture
# NOTE: this API will be updated to torch.export API in the future, but the captured
# result shoud mostly stay the same
m = capture_pre_autograd_graph(m, *example_inputs)
# This is available for pytorch 2.5+, for more details on lower pytorch versions
# please check `Export the model with torch.export` section
m = torch.export.export_for_training(m, example_inputs).module()
# we get a model with aten ops

# Step 2. quantization-aware training
Expand Down Expand Up @@ -272,24 +272,35 @@ Here is how you can use ``torch.export`` to export the model:
from torch._export import capture_pre_autograd_graph

example_inputs = (torch.rand(2, 3, 224, 224),)
exported_model = capture_pre_autograd_graph(float_model, example_inputs)
# for pytorch 2.5+
exported_model = torch.export.export_for_training(float_model, example_inputs).module()
# for pytorch 2.4 and before
# from torch._export import capture_pre_autograd_graph
# exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)


.. code:: python

# or, to capture with dynamic dimensions:
from torch._export import dynamic_dim

example_inputs = (torch.rand(2, 3, 224, 224),)
exported_model = capture_pre_autograd_graph(
float_model,
example_inputs,
constraints=[dynamic_dim(example_inputs[0], 0)],
# for pytorch 2.5+
dynamic_shapes = tuple(
{0: torch.export.Dim("dim")} if i == 0 else None
for i in range(len(example_inputs))
)
.. note::

``capture_pre_autograd_graph`` is a short term API, it will be updated to use the offical ``torch.export`` API when that is ready.

exported_model = torch.export.export_for_training(float_model, example_inputs, dynamic_shapes=dynamic_shapes).module()

# for pytorch 2.4 and before
# dynamic_shape API may vary as well
# from torch._export import dynamic_dim

# example_inputs = (torch.rand(2, 3, 224, 224),)
# exported_model = capture_pre_autograd_graph(
# float_model,
# example_inputs,
# constraints=[dynamic_dim(example_inputs[0], 0)],
# )


Import the Backend Specific Quantizer and Configure how to Quantize the Model
-----------------------------------------------------------------------------
Expand Down
10 changes: 10 additions & 0 deletions prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Quantization in PyTorch 2.0 Export Tutorial
===========================================

This tutorial has been moved.

Redirecting in 3 seconds...

.. raw:: html

<meta http-equiv="Refresh" content="3; url='https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html'" />
Loading