diff --git a/prototype_source/pt2e_quant_ptq.rst b/prototype_source/pt2e_quant_ptq.rst
index 0fe713f8abe..4873bce7d55 100644
--- a/prototype_source/pt2e_quant_ptq.rst
+++ b/prototype_source/pt2e_quant_ptq.rst
@@ -51,7 +51,6 @@ The PyTorch 2 export quantization API looks like this:
 .. code:: python

     import torch
-    from torch._export import capture_pre_autograd_graph
     class M(torch.nn.Module):
        def __init__(self):
           super().__init__()
@@ -65,9 +64,9 @@ The PyTorch 2 export quantization API looks like this:
     m = M().eval()

     # Step 1. program capture
-    # NOTE: this API will be updated to torch.export API in the future, but the captured
-    # result shoud mostly stay the same
-    m = capture_pre_autograd_graph(m, *example_inputs)
+    # This API is available for PyTorch 2.5+; for lower PyTorch versions,
+    # please check the `Export the model with torch.export` section
+    m = torch.export.export_for_training(m, example_inputs).module()
     # we get a model with aten ops

@@ -77,7 +76,7 @@ The PyTorch 2 export quantization API looks like this:
        convert_pt2e,
     )

-    from torch.ao.quantization.quantizer import (
+    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
       XNNPACKQuantizer,
       get_symmetric_quantization_config,
     )
@@ -280,10 +279,7 @@ and rename it to ``data/resnet18_pretrained_float.pth``.
         return model

     def print_size_of_model(model):
-        if isinstance(model, torch.jit.RecursiveScriptModule):
-            torch.jit.save(model, "temp.p")
-        else:
-            torch.jit.save(torch.jit.script(model), "temp.p")
+        torch.save(model.state_dict(), "temp.p")
         print("Size (MB):", os.path.getsize("temp.p")/1e6)
         os.remove("temp.p")
@@ -351,18 +347,28 @@ Here is how you can use ``torch.export`` to export the model:

 .. code-block:: python

-    from torch._export import capture_pre_autograd_graph
-
     example_inputs = (torch.rand(2, 3, 224, 224),)
-    exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)
+    # for PyTorch 2.5+
+    exported_model = torch.export.export_for_training(model_to_quantize, example_inputs).module()
+
+    # for PyTorch 2.4 and before
+    # from torch._export import capture_pre_autograd_graph
+    # exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)
+
     # or capture with dynamic dimensions
+    # for PyTorch 2.5+
+    dynamic_shapes = tuple(
+      {0: torch.export.Dim("dim")} if i == 0 else None
+      for i in range(len(example_inputs))
+    )
+    exported_model = torch.export.export_for_training(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module()
+
+    # for PyTorch 2.4 and before
+    # dynamic_shape API may vary as well
     # from torch._export import dynamic_dim
     # exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs, constraints=[dynamic_dim(example_inputs[0], 0)])

-``capture_pre_autograd_graph`` is a short term API, it will be updated to use the offical ``torch.export`` API when that is ready.
-
-
 Import the Backend Specific Quantizer and Configure how to Quantize the Model
 ------------------------------------------------------------------------------
@@ -454,7 +460,7 @@ we offer in the long term might change based on feedback from PyTorch users.
             out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
         return out_i8

-* Reference Quantized Model Representation (available in the nightly build)
+* Reference Quantized Model Representation

   We will have a special representation for selected ops, for example, quantized linear. Other ops are represented as ``dq -> float32_op -> q`` and ``q/dq`` are decomposed into more primitive operators.
   You can get this representation by using ``convert_pt2e(..., use_reference_representation=True)``.
@@ -485,8 +491,6 @@ Now we can compare the size and model accuracy with baseline model.
 .. code-block:: python

     # Baseline model size and accuracy
-    scripted_float_model_file = "resnet18_scripted.pth"
-
     print("Size of baseline model")
     print_size_of_model(float_model)
@@ -495,6 +499,8 @@ Now we can compare the size and model accuracy with baseline model.

     # Quantized model size and accuracy
     print("Size of model after quantization")
+    # export again to remove unused weights
+    quantized_model = torch.export.export_for_training(quantized_model, example_inputs).module()
     print_size_of_model(quantized_model)

     top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
diff --git a/prototype_source/pt2e_quant_ptq_x86_inductor.rst b/prototype_source/pt2e_quant_ptq_x86_inductor.rst
new file mode 100644
index 00000000000..39214a51749
--- /dev/null
+++ b/prototype_source/pt2e_quant_ptq_x86_inductor.rst
@@ -0,0 +1,10 @@
+Quantization in PyTorch 2.0 Export Tutorial
+===========================================
+
+This tutorial has been moved.
+
+Redirecting in 3 seconds...
+
+.. raw:: html
+
+
diff --git a/prototype_source/pt2e_quant_qat.rst b/prototype_source/pt2e_quant_qat.rst
index d716af5fec8..8f11b0730c5 100644
--- a/prototype_source/pt2e_quant_qat.rst
+++ b/prototype_source/pt2e_quant_qat.rst
@@ -18,7 +18,7 @@ to the post training quantization (PTQ) flow for the most part:
       prepare_qat_pt2e,
       convert_pt2e,
     )
-    from torch.ao.quantization.quantizer import (
+    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
       XNNPACKQuantizer,
       get_symmetric_quantization_config,
     )
@@ -36,9 +36,9 @@ to the post training quantization (PTQ) flow for the most part:
     m = M()

     # Step 1. program capture
-    # NOTE: this API will be updated to torch.export API in the future, but the captured
-    # result shoud mostly stay the same
-    m = capture_pre_autograd_graph(m, *example_inputs)
+    # This API is available for PyTorch 2.5+; for lower PyTorch versions,
+    # please check the `Export the model with torch.export` section
+    m = torch.export.export_for_training(m, example_inputs).module()
     # we get a model with aten ops

     # Step 2. quantization-aware training
@@ -272,24 +272,35 @@ Here is how you can use ``torch.export`` to export the model:
     from torch._export import capture_pre_autograd_graph

     example_inputs = (torch.rand(2, 3, 224, 224),)
-    exported_model = capture_pre_autograd_graph(float_model, example_inputs)
+    # for PyTorch 2.5+
+    exported_model = torch.export.export_for_training(float_model, example_inputs).module()
+    # for PyTorch 2.4 and before
+    # from torch._export import capture_pre_autograd_graph
+    # exported_model = capture_pre_autograd_graph(float_model, example_inputs)

 .. code:: python

     # or, to capture with dynamic dimensions:
-    from torch._export import dynamic_dim
-
-    example_inputs = (torch.rand(2, 3, 224, 224),)
-    exported_model = capture_pre_autograd_graph(
-        float_model,
-        example_inputs,
-        constraints=[dynamic_dim(example_inputs[0], 0)],
+    # for PyTorch 2.5+
+    dynamic_shapes = tuple(
+        {0: torch.export.Dim("dim")} if i == 0 else None
+        for i in range(len(example_inputs))
     )
-
-.. note::
-
-   ``capture_pre_autograd_graph`` is a short term API, it will be updated to use the offical ``torch.export`` API when that is ready.
-
+    exported_model = torch.export.export_for_training(float_model, example_inputs, dynamic_shapes=dynamic_shapes).module()
+
+    # for PyTorch 2.4 and before
+    # dynamic_shape API may vary as well
+    # from torch._export import dynamic_dim
+
+    # example_inputs = (torch.rand(2, 3, 224, 224),)
+    # exported_model = capture_pre_autograd_graph(
+    #     float_model,
+    #     example_inputs,
+    #     constraints=[dynamic_dim(example_inputs[0], 0)],
+    # )
+
 Import the Backend Specific Quantizer and Configure how to Quantize the Model
 ------------------------------------------------------------------------------

diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst
new file mode 100644
index 00000000000..43fd190e995
--- /dev/null
+++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst
@@ -0,0 +1,10 @@
+Quantization in PyTorch 2.0 Export Tutorial
+===========================================
+
+This tutorial has been moved.
+
+Redirecting in 3 seconds...
+
+.. raw:: html
+
+
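For reference, the end-to-end PTQ flow these patches converge on looks roughly like the sketch below. It assumes PyTorch 2.5+ and substitutes a toy single-linear module and a one-batch calibration pass for the tutorials' ResNet pipeline, so the module, input shapes, and variable names are illustrative only.

.. code-block:: python

    import torch
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    # hypothetical toy model standing in for the tutorials' ResNet
    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(1024, 1000)

        def forward(self, x):
            return self.linear(x)

    float_model = ToyModel().eval()
    example_inputs = (torch.randn(4, 1024),)

    # Step 1. program capture (PyTorch 2.5+)
    exported_model = torch.export.export_for_training(float_model, example_inputs).module()

    # Step 2. insert observers according to the quantizer's configuration
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())
    prepared_model = prepare_pt2e(exported_model, quantizer)

    # Step 3. calibration (a single batch here; use representative data in practice)
    prepared_model(*example_inputs)

    # Step 4. convert to the quantized model
    quantized_model = convert_pt2e(prepared_model)

To keep the batch dimension symbolic, the same capture step accepts the ``dynamic_shapes`` argument built with ``torch.export.Dim``, as shown in the updated tutorials above.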