@@ -508,6 +508,10 @@ Now we can compare the size and model accuracy with baseline model.
 target device, it's just a representation of quantized computation in ATen
 operators.

+.. note::
+   The weights are still in fp32 right now. We may do constant propagation for the
+   quantize op to get integer weights in the future.
+
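+   A minimal sketch to confirm this (assuming ``quantized_model`` is the module
+   returned by ``convert_pt2e`` earlier in the tutorial):
+
+   .. code-block:: python
+
+      # print the dtypes in the converted model's state dict; the weight
+      # tensors still report floating-point dtypes at this point
+      for name, tensor in quantized_model.state_dict().items():
+          print(name, tensor.dtype)
+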
 If you want to get better accuracy or performance, try configuring
 ``quantizer`` in different ways, and each ``quantizer`` will have its own way
 of configuration, so please consult the documentation for the
@@ -520,45 +524,48 @@ Save and Load Quantized Model
 We'll show how to save and load the quantized model.

 .. code-block:: python

-    # 1. Save state_dict
-    pt2e_quantized_model_file_path = saved_model_dir + "resnet18_pt2e_quantized.pth"
-    torch.save(quantized_model.state_dict(), pt2e_quantized_model_file_path)
-
-    # Get a reference output
+    # 0. Store reference output for example inputs and check evaluation accuracy
     example_inputs = (next(iter(data_loader))[0],)
     ref = quantized_model(*example_inputs)
+    top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
+    print("[before serialization] Evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))

-    # 2. Initialize the quantized model and load state_dict
-    # Rerun all steps to get a quantized model
-    model_to_quantize = load_model(saved_model_dir + float_model_file).to("cpu")
-    model_to_quantize.eval()
-    from torch._export import capture_pre_autograd_graph
-
-    exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)
-    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
-        XNNPACKQuantizer,
-        get_symmetric_quantization_config,
-    )
+    # 1. Export the model and save the ExportedProgram
+    pt2e_quantized_model_file_path = saved_model_dir + "resnet18_pt2e_quantized.pth"
+    # capture the model to get an ExportedProgram
+    quantized_ep = torch.export.export(quantized_model, example_inputs)
+    # use torch.export.save to save an ExportedProgram
+    torch.export.save(quantized_ep, pt2e_quantized_model_file_path)
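+    # note: the saved file now holds a full serialized ExportedProgram (graph and
+    # weights), not just a state_dict as in the removed torch.save flow above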

-    quantizer = XNNPACKQuantizer()
-    quantizer.set_global(get_symmetric_quantization_config())
-    prepared_model = prepare_pt2e(exported_model, quantizer)
-    prepared_model(*example_inputs)
-    loaded_quantized_model = convert_pt2e(prepared_model)

-    # load the state_dict from the saved file into the initialized model
-    loaded_quantized_model.load_state_dict(torch.load(pt2e_quantized_model_file_path))
+    # 2. Load the saved ExportedProgram
+    loaded_quantized_ep = torch.export.load(pt2e_quantized_model_file_path)
+    loaded_quantized_model = loaded_quantized_ep.module()
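+    # ExportedProgram.module() returns a runnable nn.Module (a GraphModule),
+    # so the loaded model can be called just like quantized_model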

-    # Sanity check with sample data
+    # 3. Check results for example inputs and check evaluation accuracy again
     res = loaded_quantized_model(*example_inputs)
-
-    # 3. Evaluate the loaded quantized model
+    print("diff:", ref - res)
+
     top1, top5 = evaluate(loaded_quantized_model, criterion, data_loader_test)
     print("[after serialization/deserialization] Evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))

+
+Output:
+
+.. code-block:: python
+
+    [before serialization] Evaluation accuracy on test dataset: 79.82, 94.55
+    diff: tensor([[0., 0., 0.,  ..., 0., 0., 0.],
+            [0., 0., 0.,  ..., 0., 0., 0.],
+            [0., 0., 0.,  ..., 0., 0., 0.],
+            ...,
+            [0., 0., 0.,  ..., 0., 0., 0.],
+            [0., 0., 0.,  ..., 0., 0., 0.],
+            [0., 0., 0.,  ..., 0., 0., 0.]])
+
+    [after serialization/deserialization] Evaluation accuracy on test dataset: 79.82, 94.55
+
+
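+If you'd rather keep the serialized program in memory, ``torch.export.save`` and
+``torch.export.load`` also work with file-like objects; a minimal sketch of an
+in-memory round trip with ``io.BytesIO``:
+
+.. code-block:: python
+
+    import io
+
+    # serialize the ExportedProgram into an in-memory buffer
+    buffer = io.BytesIO()
+    torch.export.save(quantized_ep, buffer)
+
+    # rewind and load it back without touching the filesystem
+    buffer.seek(0)
+    loaded_quantized_ep = torch.export.load(buffer)
+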

 Debugging the Quantized Model
-----------------------------
+------------------------------

 You can use `Numeric Suite <https://pytorch.org/docs/stable/quantization-accuracy-debugging.html#numerical-debugging-tooling-prototype>`_
 that can help with debugging in eager mode and FX graph mode. The new version of