From 5cbaeb0e8fa5b732f293375364b6ad6f8d5a77a9 Mon Sep 17 00:00:00 2001 From: Jianyu Huang Date: Wed, 4 Dec 2019 23:27:06 -0800 Subject: [PATCH] Update Dynamic Quant for BERT tutorial --- .../dynamic_quantization_bert_tutorial.py | 65 +++++-------------- 1 file changed, 15 insertions(+), 50 deletions(-) diff --git a/intermediate_source/dynamic_quantization_bert_tutorial.py b/intermediate_source/dynamic_quantization_bert_tutorial.py index 02ee84b36ea..341c7d05d29 100644 --- a/intermediate_source/dynamic_quantization_bert_tutorial.py +++ b/intermediate_source/dynamic_quantization_bert_tutorial.py @@ -86,8 +86,9 @@ # # .. code:: shell # -# !pip install sklearn -# !pip install transformers +# pip install sklearn +# pip install transformers +# ###################################################################### @@ -98,8 +99,10 @@ # Mac: # # .. code:: shell -# !yes y | pip uninstall torch tochvision -# !yes y | pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html +# +# yes y | pip uninstall torch tochvision +# yes y | pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html +# ###################################################################### @@ -155,13 +158,13 @@ # https://github.com/nyu-mll/GLUE-baselines/blob/master/download_glue_data.py) # and unpack it to some directory “glue_data/MRPC”. # - -# !python download_glue_data.py --data_dir='glue_data' --tasks='MRPC' --test_labels=True -!pwd -!ls -!wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py -!python download_glue_data.py --data_dir='glue_data' --tasks='MRPC' -!ls glue_data/MRPC +# +# .. code:: shell +# +# wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py +# python download_glue_data.py --data_dir='glue_data' --tasks='MRPC' +# ls glue_data/MRPC +# ###################################################################### @@ -255,9 +258,6 @@ from google.colab import drive drive.mount('/content/drive') -!ls -!pwd - ###################################################################### # Set global configurations @@ -273,13 +273,9 @@ configs = Namespace() # The output directory for the fine-tuned model. -# configs.output_dir = "/mnt/homedir/jianyuhuang/public/bert/MRPC/" configs.output_dir = "/content/drive/My Drive/BERT_Quant_Tutorial/MRPC/" -# configs.output_dir = "./MRPC/" # The data directory for the MRPC task in the GLUE benchmark. -# configs.data_dir = "/mnt/homedir/jianyuhuang/public/bert/glue_data/MRPC" -# configs.data_dir = "./glue_data/MRPC" configs.data_dir = "/content/glue_data/MRPC" # The model name or path for the pre-trained model. @@ -493,30 +489,6 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): print(quantized_model) - -###################################################################### -# In PyTorch 1.4 release, we further add the per-channel quantization -# support for dynamic quantization. -# -# .. figure:: https://drive.google.com/open?id=1N6P70MR6jJ2tcFnFJ2lROLSFqmiOY--g -# :alt: Per Tensor Quantization for Weight -# -# Per Tensor Quantization for Weight -# -# .. figure:: https://drive.google.com/open?id=1nyjUKP5qtkRCJPKtUaXXwhglLMQQ0Dfs -# :alt: Per Channel Quantization for Weight -# -# Per Channel Quantization for Weight -# - -qconfig_dict = { - torch.nn.Linear: torch.quantization.per_channel_dynamic_qconfig -} -per_channel_quantized_model = torch.quantization.quantize_dynamic( - model, qconfig_dict, dtype=torch.qint8 -) - - ###################################################################### # Check the model size # -------------------- @@ -532,9 +504,6 @@ def print_size_of_model(model): print_size_of_model(model) print_size_of_model(quantized_model) -# print_size_of_model(per_channel_quantized_model) - - ###################################################################### @@ -606,10 +575,6 @@ def time_model_evaluation(model, configs, tokenizer): # processing the evaluation of MRPC dataset. # -# Evaluate the INT8 BERT model after the per-channel dynamic quantization -time_model_evaluation(per_channel_quantized_model, configs, tokenizer) - - ###################################################################### # Serialize the quantized model @@ -621,7 +586,7 @@ def time_model_evaluation(model, configs, tokenizer): quantized_output_dir = configs.output_dir + "quantized/" if not os.path.exists(quantized_output_dir): os.makedirs(quantized_output_dir) -quantized_model.save_pretrained(quantized_output_dir) + quantized_model.save_pretrained(quantized_output_dir) ######################################################################