Skip to content

Commit acee6c7

Browse files
committed
[quant][pt2e] Update save/load model for pt2 export quant tutorial
Summary: att Test Plan: visual inspection of generated pages Reviewers: Subscribers: Tasks: Tags:
1 parent 32d8341 commit acee6c7

File tree

1 file changed

+35
-28
lines changed

1 file changed

+35
-28
lines changed

prototype_source/pt2e_quant_ptq_static.rst

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,10 @@ Now we can compare the size and model accuracy with baseline model.
508508
target device, it's just a representation of quantized computation in ATen
509509
operators.
510510

511+
.. note::
512+
The weights are still in fp32 right now, we may do constant propagation for quantize op to
513+
get integer weights in the future
514+
511515
If you want to get better accuracy or performance, try configuring
512516
``quantizer`` in different ways, and each ``quantizer`` will have its own way
513517
of configuration, so please consult the documentation for the
@@ -520,45 +524,48 @@ Save and Load Quantized Model
520524
We'll show how to save and load the quantized model.
521525

522526
.. code-block:: python
523-
524-
# 1. Save state_dict
525-
pt2e_quantized_model_file_path = saved_model_dir + "resnet18_pt2e_quantized.pth"
526-
torch.save(quantized_model.state_dict(), pt2e_quantized_model_file_path)
527-
528-
# Get a reference output
527+
# 0. Store reference output for example inputs and check evaluation accuracy
529528
example_inputs = (next(iter(data_loader))[0],)
530529
ref = quantized_model(*example_inputs)
530+
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
531+
print("[before serialization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
531532
532-
# 2. Initialize the quantized model and Load state_dict
533-
# Rerun all steps to get a quantized model
534-
model_to_quantize = load_model(saved_model_dir + float_model_file).to("cpu")
535-
model_to_quantize.eval()
536-
from torch._export import capture_pre_autograd_graph
537-
538-
exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)
539-
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
540-
XNNPACKQuantizer,
541-
get_symmetric_quantization_config,
542-
)
533+
# 1. Export the model and Save ExportedProgram
534+
pt2e_quantized_model_file_path = saved_model_dir + "resnet18_pt2e_quantized.pth"
535+
# capture the model to get an ExportedProgram
536+
quantized_ep = torch.export.export(quantized_model, example_inputs)
537+
# use torch.export.save to save an ExportedProgram
538+
torch.export.save(quantized_ep, pt2e_quantized_model_file_path)
543539
544-
quantizer = XNNPACKQuantizer()
545-
quantizer.set_global(get_symmetric_quantization_config())
546-
prepared_model = prepare_pt2e(exported_model, quantizer)
547-
prepared_model(*example_inputs)
548-
loaded_quantized_model = convert_pt2e(prepared_model)
549540
550-
# load the state_dict from saved file to intialized model
551-
loaded_quantized_model.load_state_dict(torch.load(pt2e_quantized_model_file_path))
541+
# 2. Load the saved ExportedProgram
542+
loaded_quantized_ep = torch.export.load(pt2e_quantized_model_file_path)
543+
loaded_quantized_model = loaded_quantized_ep.module()
552544
553-
# Sanity check with sample data
545+
# 3. Check results for example inputs and checke evaluation accuracy again
554546
res = loaded_quantized_model(*example_inputs)
555-
556-
# 3. Evaluate the loaded quantized model
547+
print("diff:", ref - res)
548+
557549
top1, top5 = evaluate(loaded_quantized_model, criterion, data_loader_test)
558550
print("[after serialization/deserialization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
559551
552+
553+
Output:
554+
.. code-block:: python
555+
[before serialization] Evaluation accuracy on test dataset: 79.82, 94.55
556+
diff: tensor([[0., 0., 0., ..., 0., 0., 0.],
557+
[0., 0., 0., ..., 0., 0., 0.],
558+
[0., 0., 0., ..., 0., 0., 0.],
559+
...,
560+
[0., 0., 0., ..., 0., 0., 0.],
561+
[0., 0., 0., ..., 0., 0., 0.],
562+
[0., 0., 0., ..., 0., 0., 0.]])
563+
564+
[after serialization/deserialization] Evaluation accuracy on test dataset: 79.82, 94.55
565+
566+
560567
Debugging the Quantized Model
561-
----------------------------
568+
------------------------------
562569

563570
You can use `Numeric Suite <https://pytorch.org/docs/stable/quantization-accuracy-debugging.html#numerical-debugging-tooling-prototype>`_
564571
that can help with debugging in eager mode and FX graph mode. The new version of

0 commit comments

Comments
 (0)