@@ -499,7 +499,7 @@ <h1>Source code for torch.ao.quantization.quantize</h1><div class="highlight"><p
499
499
< span class ="n "> module</ span > < span class ="o "> .</ span > < span class ="n "> _forward_hooks</ span > < span class ="o "> .</ span > < span class ="n "> move_to_end</ span > < span class ="p "> (</ span > < span class ="n "> handle</ span > < span class ="o "> .</ span > < span class ="n "> id</ span > < span class ="p "> ,</ span > < span class ="n "> last</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> )</ span >
500
500
501
501
502
- < div class =" viewcode-block " id =" add_observer_ " > < a class =" viewcode-back " href =" ../../../../generated/torch.quantization.add_observer_.html#torch.quantization.add_observer_ " > [docs] </ a > < span class ="k "> def</ span > < span class ="nf "> add_observer_</ span > < span class ="p "> (</ span > < span class ="n "> module</ span > < span class ="p "> ,</ span > < span class ="n "> qconfig_propagation_list</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span > < span class ="n "> non_leaf_module_list</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span > < span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span > < span class ="n "> custom_module_class_mapping</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ):</ span >
502
+ < span class ="k "> def</ span > < span class ="nf "> add_observer_</ span > < span class ="p "> (</ span > < span class ="n "> module</ span > < span class ="p "> ,</ span > < span class ="n "> qconfig_propagation_list</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span > < span class ="n "> non_leaf_module_list</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span > < span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span > < span class ="n "> custom_module_class_mapping</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ):</ span >
503
503
< span class ="sa "> r</ span > < span class ="sd "> """Add observer for the leaf child of the module.</ span >
504
504
505
505
< span class ="sd "> This function insert observer module to all leaf child module that</ span >
@@ -583,7 +583,7 @@ <h1>Source code for torch.ao.quantization.quantize</h1><div class="highlight"><p
583
583
< span class ="c1 "> # the output of the module, for input QuantStub will observe them</ span >
584
584
< span class ="k "> if</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="n "> module</ span > < span class ="o "> .</ span > < span class ="n "> _modules</ span > < span class ="p "> )</ span > < span class ="o "> ==</ span > < span class ="mi "> 0</ span > < span class ="ow "> and</ span > < span class ="ow "> not</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> module</ span > < span class ="p "> ,</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Sequential</ span > < span class ="p "> )</ span > \
585
585
< span class ="ow "> and</ span > < span class ="nb "> type</ span > < span class ="p "> (</ span > < span class ="n "> module</ span > < span class ="p "> )</ span > < span class ="ow "> in</ span > < span class ="n "> qconfig_propagation_list</ span > < span class ="p "> :</ span >
586
- < span class ="n "> insert_activation_post_process</ span > < span class ="p "> (</ span > < span class ="n "> module</ span > < span class ="p "> )</ span > </ div >
586
+ < span class ="n "> insert_activation_post_process</ span > < span class ="p "> (</ span > < span class ="n "> module</ span > < span class ="p "> )</ span >
587
587
588
588
< span class ="k "> def</ span > < span class ="nf "> get_unique_devices_</ span > < span class ="p "> (</ span > < span class ="n "> module</ span > < span class ="p "> ):</ span >
589
589
< span class ="k "> return</ span > < span class ="p "> {</ span > < span class ="n "> p</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="k "> for</ span > < span class ="n "> p</ span > < span class ="ow "> in</ span > < span class ="n "> module</ span > < span class ="o "> .</ span > < span class ="n "> parameters</ span > < span class ="p "> ()}</ span > < span class ="o "> |</ span > \
@@ -704,7 +704,7 @@ <h1>Source code for torch.ao.quantization.quantize</h1><div class="highlight"><p
704
704
705
705
< span class ="n "> _remove_activation_post_process</ span > < span class ="p "> (</ span > < span class ="n "> module</ span > < span class ="p "> )</ span >
706
706
707
- < span class ="k "> def</ span > < span class ="nf "> quantize</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> run_fn</ span > < span class ="p "> ,</ span > < span class ="n "> run_args</ span > < span class ="p "> ,</ span > < span class ="n "> mapping</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span > < span class ="n "> inplace</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ):</ span >
707
+ < div class =" viewcode-block " id =" quantize " > < a class =" viewcode-back " href =" ../../../../generated/torch.quantization.quantize.html#torch.quantization.quantize " > [docs] </ a > < span class ="k "> def</ span > < span class ="nf "> quantize</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> run_fn</ span > < span class ="p "> ,</ span > < span class ="n "> run_args</ span > < span class ="p "> ,</ span > < span class ="n "> mapping</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span > < span class ="n "> inplace</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ):</ span >
708
708
< span class ="sa "> r</ span > < span class ="sd "> """Quantize the input float model with post training static quantization.</ span >
709
709
710
710
< span class ="sd "> First it will prepare the model for calibration, then it calls</ span >
@@ -730,9 +730,9 @@ <h1>Source code for torch.ao.quantization.quantize</h1><div class="highlight"><p
730
730
< span class ="n "> prepare</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> inplace</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
731
731
< span class ="n "> run_fn</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="o "> *</ span > < span class ="n "> run_args</ span > < span class ="p "> )</ span >
732
732
< span class ="n "> convert</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> mapping</ span > < span class ="p "> ,</ span > < span class ="n "> inplace</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
733
- < span class ="k "> return</ span > < span class ="n "> model</ span >
733
+ < span class ="k "> return</ span > < span class ="n "> model</ span > </ div >
734
734
735
- < span class ="k "> def</ span > < span class ="nf "> quantize_dynamic</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> qconfig_spec</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> qint8</ span > < span class ="p "> ,</ span >
735
+ < div class =" viewcode-block " id =" quantize_dynamic " > < a class =" viewcode-back " href =" ../../../../generated/torch.quantization.quantize_dynamic.html#torch.quantization.quantize_dynamic " > [docs] </ a > < span class ="k "> def</ span > < span class ="nf "> quantize_dynamic</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> qconfig_spec</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> qint8</ span > < span class ="p "> ,</ span >
736
736
< span class ="n "> mapping</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span > < span class ="n "> inplace</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ):</ span >
737
737
< span class ="sa "> r</ span > < span class ="sd "> """Converts a float model to dynamic (i.e. weights-only) quantized model.</ span >
738
738
@@ -815,7 +815,7 @@ <h1>Source code for torch.ao.quantization.quantize</h1><div class="highlight"><p
815
815
< span class ="n "> model</ span > < span class ="o "> .</ span > < span class ="n "> eval</ span > < span class ="p "> ()</ span >
816
816
< span class ="n "> propagate_qconfig_</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> qconfig_spec</ span > < span class ="p "> )</ span >
817
817
< span class ="n "> convert</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> mapping</ span > < span class ="p "> ,</ span > < span class ="n "> inplace</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
818
- < span class ="k "> return</ span > < span class ="n "> model</ span >
818
+ < span class ="k "> return</ span > < span class ="n "> model</ span > </ div >
819
819
820
820
< span class ="k "> def</ span > < span class ="nf "> prepare_qat</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> mapping</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span > < span class ="n "> inplace</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ):</ span >
821
821
< span class ="sa "> r</ span > < span class ="sd "> """</ span >
@@ -845,7 +845,7 @@ <h1>Source code for torch.ao.quantization.quantize</h1><div class="highlight"><p
845
845
< span class ="n "> prepare</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> observer_non_leaf_module_list</ span > < span class ="o "> =</ span > < span class ="nb "> set</ span > < span class ="p "> (</ span > < span class ="n "> mapping</ span > < span class ="o "> .</ span > < span class ="n "> values</ span > < span class ="p "> ()),</ span > < span class ="n "> inplace</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
846
846
< span class ="k "> return</ span > < span class ="n "> model</ span >
847
847
848
- < span class ="k "> def</ span > < span class ="nf "> quantize_qat</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> run_fn</ span > < span class ="p "> ,</ span > < span class ="n "> run_args</ span > < span class ="p "> ,</ span > < span class ="n "> inplace</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ):</ span >
848
+ < div class =" viewcode-block " id =" quantize_qat " > < a class =" viewcode-back " href =" ../../../../generated/torch.quantization.quantize_qat.html#torch.quantization.quantize_qat " > [docs] </ a > < span class ="k "> def</ span > < span class ="nf "> quantize_qat</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> run_fn</ span > < span class ="p "> ,</ span > < span class ="n "> run_args</ span > < span class ="p "> ,</ span > < span class ="n "> inplace</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ):</ span >
849
849
< span class ="sa "> r</ span > < span class ="sd "> """Do quantization aware training and output a quantized model</ span >
850
850
851
851
< span class ="sd "> Args:</ span >
@@ -865,7 +865,7 @@ <h1>Source code for torch.ao.quantization.quantize</h1><div class="highlight"><p
865
865
< span class ="n "> prepare_qat</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> inplace</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
866
866
< span class ="n "> run_fn</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="o "> *</ span > < span class ="n "> run_args</ span > < span class ="p "> )</ span >
867
867
< span class ="n "> convert</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> inplace</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span >
868
- < span class ="k "> return</ span > < span class ="n "> model</ span >
868
+ < span class ="k "> return</ span > < span class ="n "> model</ span > </ div >
869
869
870
870
< span class ="k "> def</ span > < span class ="nf "> convert</ span > < span class ="p "> (</ span >
871
871
< span class ="n "> module</ span > < span class ="p "> ,</ span > < span class ="n "> mapping</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span > < span class ="n "> inplace</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span > < span class ="n "> remove_qconfig</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> ,</ span >
@@ -944,7 +944,7 @@ <h1>Source code for torch.ao.quantization.quantize</h1><div class="highlight"><p
944
944
945
945
< span class ="k "> return</ span > < span class ="n "> module</ span >
946
946
947
- < span class ="k "> def</ span > < span class ="nf "> swap_module</ span > < span class ="p "> (</ span > < span class ="n "> mod</ span > < span class ="p "> ,</ span > < span class ="n "> mapping</ span > < span class ="p "> ,</ span > < span class ="n "> custom_module_class_mapping</ span > < span class ="p "> ):</ span >
947
+ < div class =" viewcode-block " id =" swap_module " > < a class =" viewcode-back " href =" ../../../../generated/torch.quantization.swap_module.html#torch.quantization.swap_module " > [docs] </ a > < span class ="k "> def</ span > < span class ="nf "> swap_module</ span > < span class ="p "> (</ span > < span class ="n "> mod</ span > < span class ="p "> ,</ span > < span class ="n "> mapping</ span > < span class ="p "> ,</ span > < span class ="n "> custom_module_class_mapping</ span > < span class ="p "> ):</ span >
948
948
< span class ="sa "> r</ span > < span class ="sd "> """Swaps the module if it has a quantized counterpart and it has an</ span >
949
949
< span class ="sd "> `observer` attached.</ span >
950
950
@@ -984,7 +984,7 @@ <h1>Source code for torch.ao.quantization.quantize</h1><div class="highlight"><p
984
984
< span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="nb "> next</ span > < span class ="p "> (</ span > < span class ="nb "> iter</ span > < span class ="p "> (</ span > < span class ="n "> devices</ span > < span class ="p "> ))</ span > < span class ="k "> if</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="n "> devices</ span > < span class ="p "> )</ span > < span class ="o "> ></ span > < span class ="mi "> 0</ span > < span class ="k "> else</ span > < span class ="kc "> None</ span >
985
985
< span class ="k "> if</ span > < span class ="n "> device</ span > < span class ="p "> :</ span >
986
986
< span class ="n "> new_mod</ span > < span class ="o "> .</ span > < span class ="n "> to</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> )</ span >
987
- < span class ="k "> return</ span > < span class ="n "> new_mod</ span >
987
+ < span class ="k "> return</ span > < span class ="n "> new_mod</ span > </ div >
988
988
989
989
< span class ="k "> def</ span > < span class ="nf "> get_observer_dict</ span > < span class ="p "> (</ span > < span class ="n "> mod</ span > < span class ="p "> ,</ span > < span class ="n "> target_dict</ span > < span class ="p "> ,</ span > < span class ="n "> prefix</ span > < span class ="o "> =</ span > < span class ="s2 "> ""</ span > < span class ="p "> ):</ span >
990
990
< span class ="sa "> r</ span > < span class ="sd "> """Traverse the modules and save all observers into dict.</ span >
0 commit comments