Skip to content

Commit 5cbaeb0

Browse files
committed
Update Dynamic Quant for BERT tutorial
1 parent ab73472 commit 5cbaeb0

File tree

1 file changed

+15
-50
lines changed

1 file changed

+15
-50
lines changed

intermediate_source/dynamic_quantization_bert_tutorial.py

Lines changed: 15 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,9 @@
8686
#
8787
# .. code:: shell
8888
#
89-
# !pip install sklearn
90-
# !pip install transformers
89+
# pip install sklearn
90+
# pip install transformers
91+
#
9192

9293

9394
######################################################################
@@ -98,8 +99,10 @@
9899
# Mac:
99100
#
100101
# .. code:: shell
101-
# !yes y | pip uninstall torch tochvision
102-
# !yes y | pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html
102+
#
103+
# yes y | pip uninstall torch torchvision
104+
# yes y | pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html
105+
#
103106

104107

105108
######################################################################
@@ -155,13 +158,13 @@
155158
# https://github.com/nyu-mll/GLUE-baselines/blob/master/download_glue_data.py)
156159
# and unpack it to some directory “glue_data/MRPC”.
157160
#
158-
159-
# !python download_glue_data.py --data_dir='glue_data' --tasks='MRPC' --test_labels=True
160-
!pwd
161-
!ls
162-
!wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
163-
!python download_glue_data.py --data_dir='glue_data' --tasks='MRPC'
164-
!ls glue_data/MRPC
161+
#
162+
# .. code:: shell
163+
#
164+
# wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
165+
# python download_glue_data.py --data_dir='glue_data' --tasks='MRPC'
166+
# ls glue_data/MRPC
167+
#
165168

166169

167170
######################################################################
@@ -255,9 +258,6 @@
255258
from google.colab import drive
256259
drive.mount('/content/drive')
257260

258-
!ls
259-
!pwd
260-
261261

262262
######################################################################
263263
# Set global configurations
@@ -273,13 +273,9 @@
273273
configs = Namespace()
274274

275275
# The output directory for the fine-tuned model.
276-
# configs.output_dir = "/mnt/homedir/jianyuhuang/public/bert/MRPC/"
277276
configs.output_dir = "/content/drive/My Drive/BERT_Quant_Tutorial/MRPC/"
278-
# configs.output_dir = "./MRPC/"
279277

280278
# The data directory for the MRPC task in the GLUE benchmark.
281-
# configs.data_dir = "/mnt/homedir/jianyuhuang/public/bert/glue_data/MRPC"
282-
# configs.data_dir = "./glue_data/MRPC"
283279
configs.data_dir = "/content/glue_data/MRPC"
284280

285281
# The model name or path for the pre-trained model.
@@ -493,30 +489,6 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
493489
print(quantized_model)
494490

495491

496-
497-
######################################################################
498-
# In PyTorch 1.4 release, we further add the per-channel quantization
499-
# support for dynamic quantization.
500-
#
501-
# .. figure:: https://drive.google.com/open?id=1N6P70MR6jJ2tcFnFJ2lROLSFqmiOY--g
502-
# :alt: Per Tensor Quantization for Weight
503-
#
504-
# Per Tensor Quantization for Weight
505-
#
506-
# .. figure:: https://drive.google.com/open?id=1nyjUKP5qtkRCJPKtUaXXwhglLMQQ0Dfs
507-
# :alt: Per Channel Quantization for Weight
508-
#
509-
# Per Channel Quantization for Weight
510-
#
511-
512-
qconfig_dict = {
513-
torch.nn.Linear: torch.quantization.per_channel_dynamic_qconfig
514-
}
515-
per_channel_quantized_model = torch.quantization.quantize_dynamic(
516-
model, qconfig_dict, dtype=torch.qint8
517-
)
518-
519-
520492
######################################################################
521493
# Check the model size
522494
# --------------------
@@ -532,9 +504,6 @@ def print_size_of_model(model):
532504

533505
print_size_of_model(model)
534506
print_size_of_model(quantized_model)
535-
# print_size_of_model(per_channel_quantized_model)
536-
537-
538507

539508

540509
######################################################################
@@ -606,10 +575,6 @@ def time_model_evaluation(model, configs, tokenizer):
606575
# processing the evaluation of MRPC dataset.
607576
#
608577

609-
# Evaluate the INT8 BERT model after the per-channel dynamic quantization
610-
time_model_evaluation(per_channel_quantized_model, configs, tokenizer)
611-
612-
613578

614579
######################################################################
615580
# Serialize the quantized model
@@ -621,7 +586,7 @@ def time_model_evaluation(model, configs, tokenizer):
621586
quantized_output_dir = configs.output_dir + "quantized/"
622587
if not os.path.exists(quantized_output_dir):
623588
os.makedirs(quantized_output_dir)
624-
quantized_model.save_pretrained(quantized_output_dir)
589+
quantized_model.save_pretrained(quantized_output_dir)
625590

626591

627592
######################################################################

0 commit comments

Comments
 (0)