diff --git a/prototype_source/fx_graph_mode_ptq_static.rst b/prototype_source/fx_graph_mode_ptq_static.rst
index e2d03f9d298..410f5a116bd 100644
--- a/prototype_source/fx_graph_mode_ptq_static.rst
+++ b/prototype_source/fx_graph_mode_ptq_static.rst
@@ -311,6 +311,7 @@ The purpose for calibration is to run through some sample examples that is repre
 the statistics of the Tensors and we can later use this information to calculate quantization parameters.
 
 .. code:: python
+
     def calibrate(model, data_loader):
         model.eval()
         with torch.no_grad():
@@ -320,17 +321,19 @@ the statistics of the Tensors and we can later use this information to calculate
 
 7. Convert the Model to a Quantized Model
 -----------------------------------------
-``convert_fx`` takes a calibrated model and produces a quantized model.
+``convert_fx`` takes a calibrated model and produces a quantized model.
 
 .. code:: python
-    quantized_model = convert_fx(prepared_model)
+
+    quantized_model = convert_fx(prepared_model)
     print(quantized_model)
-
+
 8. Evaluation
 -------------
 We can now print the size and accuracy of the quantized model.
 
 .. code:: python
+
     print("Size of model before quantization")
     print_size_of_model(float_model)
     print("Size of model after quantization")
@@ -372,6 +375,7 @@ we'll first call fuse explicitly to fuse the conv and bn in the model:
 Note that ``fuse_fx`` only works in eval mode.
 
 .. code:: python
+
     fused = fuse_fx(float_model)
 
     conv1_weight_after_fuse = fused.conv1[0].weight[0]
@@ -383,6 +387,7 @@ Note that ``fuse_fx`` only works in eval mode.
 --------------------------------------------------------------------
 
 .. code:: python
+
     scripted_float_model_file = "resnet18_scripted.pth"
 
     print("Size of baseline model")
@@ -397,6 +402,7 @@ quantized in eager mode. FX graph mode and eager mode produce very similar quant
 so the expectation is that the accuracy and speedup are similar as well.
 
 .. code:: python
+
     print("Size of Fx graph mode quantized model")
     print_size_of_model(quantized_model)
     top1, top5 = evaluate(quantized_model, criterion, data_loader_test)