Commit f2b930e

jerryzh168 and svekars authored
pt2e quantization tutorial related updates (#3106)
* pt2e quantization tutorial related updates

  Summary:

  * prototype_source/pt2e_quant_ptq_x86_inductor.rst and prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst are redirects to the actual docs; they exist because the tutorials were renamed but the old page files are still there.
  * pt2e_quant_ptq.rst and pt2e_quant_qat.rst:
    * Update the export API
    * Update the import path of the XNNPACK quantizer
    * Update print_size_of_model

  ---------

  Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent 90525b7 commit f2b930e
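
Taken together, these edits switch program capture from ``capture_pre_autograd_graph`` to ``torch.export.export_for_training`` and move the quantizer import to ``torch.ao.quantization.quantizer.xnnpack_quantizer``. A minimal sketch of the resulting post-training quantization flow, assuming PyTorch 2.5+; the toy module, input shapes, and single-pass calibration below are illustrative stand-ins, not lines from the diff:

.. code:: python

    import torch
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    example_inputs = (torch.randn(1, 5),)
    m = M().eval()

    # Step 1. program capture (PyTorch 2.5+ path used throughout the updated tutorials)
    m = torch.export.export_for_training(m, example_inputs).module()

    # Step 2. prepare, calibrate, convert
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)
    m(*example_inputs)  # calibration with representative data
    m = convert_pt2e(m)

The QAT tutorial follows the same structure with ``prepare_qat_pt2e`` in place of ``prepare_pt2e``.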

4 files changed: +70 -33 lines changed

prototype_source/pt2e_quant_ptq.rst

Lines changed: 24 additions & 18 deletions
@@ -51,7 +51,6 @@ The PyTorch 2 export quantization API looks like this:
 .. code:: python

     import torch
-    from torch._export import capture_pre_autograd_graph
     class M(torch.nn.Module):
        def __init__(self):
           super().__init__()
@@ -65,9 +64,9 @@ The PyTorch 2 export quantization API looks like this:
     m = M().eval()

     # Step 1. program capture
-    # NOTE: this API will be updated to torch.export API in the future, but the captured
-    # result shoud mostly stay the same
-    m = capture_pre_autograd_graph(m, *example_inputs)
+    # This is available for pytorch 2.5+, for more details on lower pytorch versions
+    # please check `Export the model with torch.export` section
+    m = torch.export.export_for_training(m, example_inputs).module()
     # we get a model with aten ops

@@ -77,7 +76,7 @@ The PyTorch 2 export quantization API looks like this:
       convert_pt2e,
     )

-    from torch.ao.quantization.quantizer import (
+    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
       XNNPACKQuantizer,
       get_symmetric_quantization_config,
     )
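
Only the import path changes here; the quantizer is configured the same way as before. A short sketch under that assumption (``is_per_channel=True`` is an optional knob shown for illustration, not something this diff sets):

.. code:: python

    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    # One global config for the whole model; per-channel weight quantization is optional.
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))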
@@ -280,10 +279,7 @@ and rename it to ``data/resnet18_pretrained_float.pth``.
         return model

     def print_size_of_model(model):
-        if isinstance(model, torch.jit.RecursiveScriptModule):
-            torch.jit.save(model, "temp.p")
-        else:
-            torch.jit.save(torch.jit.script(model), "temp.p")
+        torch.save(model.state_dict(), "temp.p")
         print("Size (MB):", os.path.getsize("temp.p")/1e6)
         os.remove("temp.p")

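
The new helper sizes the model from its serialized ``state_dict`` instead of a TorchScript archive, so the same function works for the eager baseline and for the ``GraphModule`` returned by ``torch.export``. A self-contained sketch; the toy linear layer is only for illustration:

.. code:: python

    import os

    import torch

    def print_size_of_model(model):
        # Serialize parameters and buffers only; this works for eager modules
        # and exported GraphModules alike.
        torch.save(model.state_dict(), "temp.p")
        print("Size (MB):", os.path.getsize("temp.p") / 1e6)
        os.remove("temp.p")

    print_size_of_model(torch.nn.Linear(128, 128))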
@@ -351,18 +347,28 @@ Here is how you can use ``torch.export`` to export the model:

 .. code-block:: python

-   from torch._export import capture_pre_autograd_graph
-
    example_inputs = (torch.rand(2, 3, 224, 224),)
-   exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)
+   # for pytorch 2.5+
+   exported_model = torch.export.export_for_training(model_to_quantize, example_inputs).module()
+
+   # for pytorch 2.4 and before
+   # from torch._export import capture_pre_autograd_graph
+   # exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)
+
    # or capture with dynamic dimensions
+   # for pytorch 2.5+
+   dynamic_shapes = tuple(
+     {0: torch.export.Dim("dim")} if i == 0 else None
+     for i in range(len(example_inputs))
+   )
+   exported_model = torch.export.export_for_training(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module()
+
+   # for pytorch 2.4 and before
+   # dynamic_shape API may vary as well
    # from torch._export import dynamic_dim
    # exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs, constraints=[dynamic_dim(example_inputs[0], 0)])


-``capture_pre_autograd_graph`` is a short term API, it will be updated to use the offical ``torch.export`` API when that is ready.
-
-
 Import the Backend Specific Quantizer and Configure how to Quantize the Model
 -----------------------------------------------------------------------------
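
In the dynamic-shape variant, ``torch.export.Dim`` marks dimension 0 of the first input as symbolic while inputs mapped to ``None`` stay static. A sketch on PyTorch 2.5+, using an untrained torchvision ``resnet18`` as a stand-in for ``model_to_quantize`` (the stand-in model and batch sizes are assumptions, not part of the diff):

.. code:: python

    import torch
    import torchvision

    model_to_quantize = torchvision.models.resnet18(weights=None).eval()
    example_inputs = (torch.rand(2, 3, 224, 224),)

    # Mark dim 0 of the first (and only) input as dynamic.
    dynamic_shapes = tuple(
        {0: torch.export.Dim("dim")} if i == 0 else None
        for i in range(len(example_inputs))
    )
    exported_model = torch.export.export_for_training(
        model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes
    ).module()

    # The exported module now accepts a different batch size.
    out = exported_model(torch.rand(8, 3, 224, 224))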

@@ -454,7 +460,7 @@ we offer in the long term might change based on feedback from PyTorch users.
           out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
       return out_i8

-* Reference Quantized Model Representation (available in the nightly build)
+* Reference Quantized Model Representation

   We will have a special representation for selected ops, for example, quantized linear. Other ops are represented as ``dq -> float32_op -> q`` and ``q/dq`` are decomposed into more primitive operators.
   You can get this representation by using ``convert_pt2e(..., use_reference_representation=True)``.
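
The updated bullet drops the nightly-build caveat; the reference representation is requested through the ``use_reference_representation`` flag on ``convert_pt2e``. A minimal sketch, assuming the tiny model, single-pass calibration, and XNNPACK config below as stand-ins:

.. code:: python

    import torch
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    model = torch.nn.Sequential(torch.nn.Linear(5, 10)).eval()
    example_inputs = (torch.randn(1, 5),)

    exported = torch.export.export_for_training(model, example_inputs).module()
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())
    prepared_model = prepare_pt2e(exported, quantizer)
    prepared_model(*example_inputs)  # calibration

    # Selected ops (e.g. quantized linear) get a dedicated reference pattern;
    # other ops stay as dq -> float32_op -> q with q/dq decomposed further.
    reference_quantized_model = convert_pt2e(
        prepared_model, use_reference_representation=True
    )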
@@ -485,8 +491,6 @@ Now we can compare the size and model accuracy with baseline model.
 .. code-block:: python

     # Baseline model size and accuracy
-    scripted_float_model_file = "resnet18_scripted.pth"
-
     print("Size of baseline model")
     print_size_of_model(float_model)

@@ -495,6 +499,8 @@ Now we can compare the size and model accuracy with baseline model.

     # Quantized model size and accuracy
     print("Size of model after quantization")
+    # export again to remove unused weights
+    quantized_model = torch.export.export_for_training(quantized_model, example_inputs).module()
     print_size_of_model(quantized_model)

     top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
prototype_source/pt2e_quant_ptq_x86_inductor.rst

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+Quantization in PyTorch 2.0 Export Tutorial
+===========================================
+
+This tutorial has been moved.
+
+Redirecting in 3 seconds...
+
+.. raw:: html
+
+   <meta http-equiv="Refresh" content="3; url='https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html'" />

prototype_source/pt2e_quant_qat.rst

Lines changed: 26 additions & 15 deletions
@@ -18,7 +18,7 @@ to the post training quantization (PTQ) flow for the most part:
       prepare_qat_pt2e,
       convert_pt2e,
   )
-  from torch.ao.quantization.quantizer import (
+  from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     XNNPACKQuantizer,
     get_symmetric_quantization_config,
   )
@@ -36,9 +36,9 @@ to the post training quantization (PTQ) flow for the most part:
   m = M()

   # Step 1. program capture
-  # NOTE: this API will be updated to torch.export API in the future, but the captured
-  # result shoud mostly stay the same
-  m = capture_pre_autograd_graph(m, *example_inputs)
+  # This is available for pytorch 2.5+, for more details on lower pytorch versions
+  # please check `Export the model with torch.export` section
+  m = torch.export.export_for_training(m, example_inputs).module()
   # we get a model with aten ops

   # Step 2. quantization-aware training
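
For reference, a minimal sketch of the QAT flow with these updates applied, assuming PyTorch 2.5+; the toy module, the ``is_qat=True`` config argument, and the stand-in training step are illustrative assumptions, not lines from the diff:

.. code:: python

    import torch
    from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    example_inputs = (torch.randn(1, 5),)
    m = M()

    # Step 1. program capture
    m = torch.export.export_for_training(m, example_inputs).module()

    # Step 2. insert fake quantizers, then train as usual
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
    m = prepare_qat_pt2e(m, quantizer)
    m(*example_inputs)  # a single forward pass stands in for the training loop here

    # Step 3. convert the trained model to a quantized model
    m = convert_pt2e(m)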
@@ -272,24 +272,35 @@ Here is how you can use ``torch.export`` to export the model:
    from torch._export import capture_pre_autograd_graph

    example_inputs = (torch.rand(2, 3, 224, 224),)
-   exported_model = capture_pre_autograd_graph(float_model, example_inputs)
+   # for pytorch 2.5+
+   exported_model = torch.export.export_for_training(float_model, example_inputs).module()
+   # for pytorch 2.4 and before
+   # from torch._export import capture_pre_autograd_graph
+   # exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)


 .. code:: python

    # or, to capture with dynamic dimensions:
-   from torch._export import dynamic_dim

-   example_inputs = (torch.rand(2, 3, 224, 224),)
-   exported_model = capture_pre_autograd_graph(
-      float_model,
-      example_inputs,
-      constraints=[dynamic_dim(example_inputs[0], 0)],
+   # for pytorch 2.5+
+   dynamic_shapes = tuple(
+     {0: torch.export.Dim("dim")} if i == 0 else None
+     for i in range(len(example_inputs))
    )
-.. note::
-
-   ``capture_pre_autograd_graph`` is a short term API, it will be updated to use the offical ``torch.export`` API when that is ready.
-
+   exported_model = torch.export.export_for_training(float_model, example_inputs, dynamic_shapes=dynamic_shapes).module()
+
+   # for pytorch 2.4 and before
+   # dynamic_shape API may vary as well
+   # from torch._export import dynamic_dim
+
+   # example_inputs = (torch.rand(2, 3, 224, 224),)
+   # exported_model = capture_pre_autograd_graph(
+   #    float_model,
+   #    example_inputs,
+   #    constraints=[dynamic_dim(example_inputs[0], 0)],
+   # )
+

 Import the Backend Specific Quantizer and Configure how to Quantize the Model
 -----------------------------------------------------------------------------
prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+Quantization in PyTorch 2.0 Export Tutorial
+===========================================
+
+This tutorial has been moved.
+
+Redirecting in 3 seconds...
+
+.. raw:: html
+
+   <meta http-equiv="Refresh" content="3; url='https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html'" />
