diff --git a/_static/img/bert.png b/_static/img/bert.png
new file mode 100644
index 00000000000..6e23a8acfd3
Binary files /dev/null and b/_static/img/bert.png differ
diff --git a/_static/img/quantized_transfer_learning.png b/_static/img/quantized_transfer_learning.png
new file mode 100644
index 00000000000..c138cbdb0c1
Binary files /dev/null and b/_static/img/quantized_transfer_learning.png differ
diff --git a/index.rst b/index.rst
index 4656b66976a..5817cadae94 100644
--- a/index.rst
+++ b/index.rst
@@ -247,6 +247,18 @@ Quantization (experimental)
:figure: /_static/img/qat.png
:description: :doc:`advanced/static_quantization_tutorial`
+.. customgalleryitem::
+ :tooltip: Perform quantized transfer learning with feature extractor
+ :description: :doc:`/intermediate/quantized_transfer_learning_tutorial`
+ :figure: /_static/img/quantized_transfer_learning.png
+
+.. customgalleryitem::
+ :tooltip: Convert a well-known state-of-the-art model like BERT into a dynamically quantized model
+ :description: :doc:`/intermediate/dynamic_quantization_bert_tutorial`
+ :figure: /_static/img/bert.png
+
+
+
.. raw:: html
@@ -328,7 +340,7 @@ PyTorch Fundamentals In-Depth
beginner/text_sentiment_ngrams_tutorial
beginner/torchtext_translation_tutorial
beginner/transformer_tutorial
-
+
.. toctree::
:maxdepth: 2
:includehidden:
@@ -385,6 +397,8 @@ PyTorch Fundamentals In-Depth
advanced/dynamic_quantization_tutorial
advanced/static_quantization_tutorial
+ intermediate/quantized_transfer_learning_tutorial
+ intermediate/dynamic_quantization_bert_tutorial
.. toctree::
:maxdepth: 2
diff --git a/intermediate_source/dynamic_quantization_bert_tutorial.py b/intermediate_source/dynamic_quantization_bert_tutorial.py
new file mode 100644
index 00000000000..0bdc91ce14c
--- /dev/null
+++ b/intermediate_source/dynamic_quantization_bert_tutorial.py
@@ -0,0 +1,661 @@
+# -*- coding: utf-8 -*-
+"""
+(Experimental) Dynamic Quantization on HuggingFace BERT model
+==============================================================
+**Author**: `Jianyu Huang `_
+
+**Reviewed by**: `Raghuraman Krishnamoorthi `_
+
+**Edited by**: `Jessica Lin `_
+
+"""
+
+
+######################################################################
+# Introduction
+# ============
+#
+# In this tutorial, we will apply dynamic quantization to a BERT
+# model, closely following the BERT model from the HuggingFace
+# Transformers examples (https://github.com/huggingface/transformers).
+# With this step-by-step journey, we would like to demonstrate how to
+# convert a well-known state-of-the-art model like BERT into a
+# dynamically quantized model.
+#
+# - BERT, or Bidirectional Encoder Representations from Transformers,
+#   is a new method of pre-training language representations which
+#   achieves state-of-the-art accuracy results on many popular
+#   Natural Language Processing (NLP) tasks, such as question answering,
+#   text classification, and others. The original paper can be found
+#   here: https://arxiv.org/pdf/1810.04805.pdf.
+#
+# - Dynamic quantization support in PyTorch converts a float model to a
+# quantized model with static int8 or float16 data types for the
+# weights and dynamic quantization for the activations. The activations
+# are quantized dynamically (per batch) to int8 when the weights are
+# quantized to int8.
+#
+# In PyTorch, we have the ``torch.quantization.quantize_dynamic`` API
+# (https://pytorch.org/docs/stable/quantization.html#torch.quantization.quantize_dynamic),
+# which replaces specified modules with dynamic weight-only quantized
+# versions and outputs the quantized model.
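+#
+# For illustration only (this toy snippet is not part of the tutorial
+# pipeline), the API can be applied to any float module that contains
+# ``torch.nn.Linear`` layers:
+#
+# .. code:: python
+#
+#     import torch
+#     float_model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
+#     int8_model = torch.quantization.quantize_dynamic(
+#         float_model, {torch.nn.Linear}, dtype=torch.qint8
+#     )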
+#
+# - We demonstrate the accuracy and inference performance results on the
+# Microsoft Research Paraphrase Corpus (MRPC) task
+# (https://www.microsoft.com/en-us/download/details.aspx?id=52398) in
+# the General Language Understanding Evaluation benchmark (GLUE)
+# (https://gluebenchmark.com/). The MRPC (Dolan and Brockett, 2005) is
+# a corpus of sentence pairs automatically extracted from online news
+# sources, with human annotations of whether the sentences in the pair
+# are semantically equivalent. Because the classes are imbalanced (68%
+# positive, 32% negative), we follow common practice and report both
+# accuracy and F1 score
+# (https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html).
+#   MRPC is a common NLP task for sentence pair classification, as shown
+#   below.
+#
+# .. figure:: https://gluon-nlp.mxnet.io/_images/bert-sentence-pair.png
+#    :alt: BERT for sentence pair classification
+#
+#    BERT for sentence pair classification
+#
+
+
+######################################################################
+# Setup
+# =====
+#
+# Install PyTorch and HuggingFace Transformers
+# --------------------------------------------
+#
+# To start this tutorial, let’s first follow the installation instructions
+# in the PyTorch and HuggingFace GitHub repos:
+#
+# - https://github.com/pytorch/pytorch/#installation
+# - https://github.com/huggingface/transformers#installation
+#
+# In addition, we also install the ``sklearn`` package, as we will reuse its
+# built-in F1 score calculation helper function.
+#
+# .. code:: shell
+#
+# !pip install sklearn
+# !pip install transformers
+
+
+######################################################################
+# Because we will be using the experimental parts of PyTorch, it is
+# recommended to install the latest version of torch and torchvision. You
+# can find the most recent instructions on local installation here:
+# https://pytorch.org/get-started/locally/. For example, to install on
+# Mac:
+#
+# .. code:: shell
+#
+#   !yes y | pip uninstall torch torchvision
+#   !yes y | pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html
+
+
+######################################################################
+# Import the necessary modules
+# ----------------------------
+#
+# In this step we import the necessary Python modules for the tutorial.
+#
+
+from __future__ import absolute_import, division, print_function
+
+import logging
+import numpy as np
+import os
+import random
+import sys
+import time
+import torch
+
+from argparse import Namespace
+from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
+                              TensorDataset)
+# DistributedSampler is needed by the evaluation helper below when local_rank != -1.
+from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm
+from transformers import (BertConfig, BertForSequenceClassification, BertTokenizer,)
+from transformers import glue_compute_metrics as compute_metrics
+from transformers import glue_output_modes as output_modes
+from transformers import glue_processors as processors
+from transformers import glue_convert_examples_to_features as convert_examples_to_features
+
+# Setup logging
+logger = logging.getLogger(__name__)
+logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+ datefmt = '%m/%d/%Y %H:%M:%S',
+ level = logging.WARN)
+
+logging.getLogger("transformers.modeling_utils").setLevel(
+ logging.WARN) # Reduce logging
+
+print(torch.__version__)
+# We set the number of threads to 1 to compare the single-threaded performance of FP32 and INT8 inference.
+# At the end of the tutorial, the user can use more threads by building PyTorch with the right parallel backend.
+torch.set_num_threads(1)
+print(torch.__config__.parallel_info())
+
+
+######################################################################
+# Download the dataset
+# --------------------
+#
+# Before running the MRPC task we download the GLUE data
+# (https://gluebenchmark.com/tasks) by running this script
+# (https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e,
+# https://github.com/nyu-mll/GLUE-baselines/blob/master/download_glue_data.py)
+# and unpack it to a directory ``glue_data/MRPC``.
+#
+
+# In a Jupyter/Colab runtime, run the download as shell commands
+# (add --test_labels=True if you also want the test labels):
+#
+# !wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
+# !python download_glue_data.py --data_dir='glue_data' --tasks='MRPC'
+# !ls glue_data/MRPC
+
+
+######################################################################
+# Helper functions
+# ----------------
+#
+# The helper functions are built into the transformers library. We mainly
+# use the following helper functions: one for converting the text examples
+# into the feature vectors; the other for measuring the F1 score of
+# the predicted result.
+#
+# Convert the texts into features
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# ``glue_convert_examples_to_features``
+# (https://github.com/huggingface/transformers/blob/master/transformers/data/processors/glue.py)
+# loads a data file into a list of ``InputFeatures`` and performs the
+# following steps:
+#
+# - Tokenize the input sequences;
+# - Insert [CLS] at the beginning;
+# - Insert [SEP] between the first sentence and the second sentence, and
+# at the end;
+# - Generate token type ids to indicate whether a token belongs to the
+# first sequence or the second sequence;
+#
+# F1 metric
+# ~~~~~~~~~
+#
+# The F1 score
+# (https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html)
+# can be interpreted as a weighted average of the precision and recall,
+# where an F1 score reaches its best value at 1 and worst score at 0. The
+# relative contribution of precision and recall to the F1 score are equal.
+# The formula for the F1 score is:
+#
+# F1 = 2 \* (precision \* recall) / (precision + recall)
+#
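+#
+# As a quick, illustrative sanity check of the formula (this snippet is not
+# part of the tutorial pipeline), we can call ``sklearn.metrics.f1_score``
+# directly:
+
+from sklearn.metrics import f1_score
+
+# Two of the three positive labels are recovered below: precision = 1.0,
+# recall = 2/3, so F1 = 0.8.
+print(f1_score(y_true=[1, 0, 1, 1], y_pred=[1, 0, 0, 1]))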
+
+
+######################################################################
+# Fine-tune the BERT model
+# ========================
+#
+
+
+######################################################################
+# The spirit of BERT is to pre-train the language representations and then
+# to fine-tune the deep bi-directional representations on a wide range of
+# tasks with minimal task-dependent parameters, achieving
+# state-of-the-art results. In this tutorial, we will focus on fine-tuning
+# the pre-trained BERT model to classify semantically equivalent
+# sentence pairs on the MRPC task.
+#
+# To fine-tune the pre-trained BERT model (“bert-base-uncased” model in
+# HuggingFace transformers) for the MRPC task, you can follow the command
+# in (https://github.com/huggingface/transformers/tree/master/examples):
+#
+# ::
+#
+# export GLUE_DIR=./glue_data
+# export TASK_NAME=MRPC
+# export OUT_DIR=/mnt/homedir/jianyuhuang/public/bert/$TASK_NAME/
+# python ./run_glue.py \
+# --model_type bert \
+# --model_name_or_path bert-base-uncased \
+# --task_name $TASK_NAME \
+# --do_train \
+# --do_eval \
+# --do_lower_case \
+# --data_dir $GLUE_DIR/$TASK_NAME \
+# --max_seq_length 128 \
+# --per_gpu_eval_batch_size=8 \
+# --per_gpu_train_batch_size=8 \
+# --learning_rate 2e-5 \
+# --num_train_epochs 3.0 \
+# --save_steps 100000 \
+# --output_dir $OUT_DIR
+#
+# We provide the fine-tuned BERT model for the MRPC task here (we did the
+# fine-tuning on CPUs with a total train batch size of 8):
+#
+# https://drive.google.com/drive/folders/1mGBx0t-YJAWXHbgab2f_IimaMiVHlKh-
+#
+# To save time, you can manually copy the fine-tuned BERT model for the MRPC
+# task into your Google Drive (create the same “BERT_Quant_Tutorial/MRPC”
+# folder in the Google Drive directory), and then mount your Google Drive
+# on your runtime using an authorization code, so that we can directly
+# read and write the models into Google Drive in the following steps.
+#
+
+from google.colab import drive
+drive.mount('/content/drive')
+
+!ls
+!pwd
+
+
+######################################################################
+# Set global configurations
+# -------------------------
+#
+
+
+######################################################################
+# Here we set the global configurations for evaluating the fine-tuned BERT
+# model before and after the dynamic quantization.
+#
+
+configs = Namespace()
+
+# The output directory for the fine-tuned model.
+# configs.output_dir = "/mnt/homedir/jianyuhuang/public/bert/MRPC/"
+configs.output_dir = "/content/drive/My Drive/BERT_Quant_Tutorial/MRPC/"
+# configs.output_dir = "./MRPC/"
+
+# The data directory for the MRPC task in the GLUE benchmark.
+# configs.data_dir = "/mnt/homedir/jianyuhuang/public/bert/glue_data/MRPC"
+# configs.data_dir = "./glue_data/MRPC"
+configs.data_dir = "/content/glue_data/MRPC"
+
+# The model name or path for the pre-trained model.
+configs.model_name_or_path = "bert-base-uncased"
+# The maximum length of an input sequence
+configs.max_seq_length = 128
+
+# Prepare GLUE task.
+configs.task_name = "MRPC".lower()
+configs.processor = processors[configs.task_name]()
+configs.output_mode = output_modes[configs.task_name]
+configs.label_list = configs.processor.get_labels()
+configs.model_type = "bert".lower()
+configs.do_lower_case = True
+
+# Set the device, batch size, topology, and caching flags.
+configs.device = "cpu"
+configs.per_gpu_eval_batch_size = 8
+configs.n_gpu = 0
+configs.local_rank = -1
+configs.overwrite_cache = False
+
+
+# Set random seed for reproducibility.
+def set_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+set_seed(42)
+
+
+######################################################################
+# Load the fine-tuned BERT model
+# ------------------------------
+#
+
+
+######################################################################
+# We load the tokenizer and fine-tuned BERT sequence classifier model
+# (FP32) from the ``configs.output_dir``.
+#
+
+tokenizer = BertTokenizer.from_pretrained(
+ configs.output_dir, do_lower_case=configs.do_lower_case)
+
+model = BertForSequenceClassification.from_pretrained(configs.output_dir)
+model.to(configs.device)
+
+
+######################################################################
+# Define the tokenize and evaluation function
+# -------------------------------------------
+#
+# We reuse the tokenize and evaluation function from
+# https://github.com/huggingface/transformers/blob/master/examples/run_glue.py.
+#
+
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+def evaluate(args, model, tokenizer, prefix=""):
+ # Loop to handle MNLI double evaluation (matched, mis-matched)
+ eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
+ eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)
+
+ results = {}
+ for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
+ eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
+
+ if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
+ os.makedirs(eval_output_dir)
+
+ args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
+ # Note that DistributedSampler samples randomly
+ eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
+ eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
+
+ # multi-gpu eval
+ if args.n_gpu > 1:
+ model = torch.nn.DataParallel(model)
+
+ # Eval!
+ logger.info("***** Running evaluation {} *****".format(prefix))
+ logger.info(" Num examples = %d", len(eval_dataset))
+ logger.info(" Batch size = %d", args.eval_batch_size)
+ eval_loss = 0.0
+ nb_eval_steps = 0
+ preds = None
+ out_label_ids = None
+ for batch in tqdm(eval_dataloader, desc="Evaluating"):
+ model.eval()
+ batch = tuple(t.to(args.device) for t in batch)
+
+ with torch.no_grad():
+ inputs = {'input_ids': batch[0],
+ 'attention_mask': batch[1],
+ 'labels': batch[3]}
+ if args.model_type != 'distilbert':
+ inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None # XLM, DistilBERT and RoBERTa don't use segment_ids
+ outputs = model(**inputs)
+ tmp_eval_loss, logits = outputs[:2]
+
+ eval_loss += tmp_eval_loss.mean().item()
+ nb_eval_steps += 1
+ if preds is None:
+ preds = logits.detach().cpu().numpy()
+ out_label_ids = inputs['labels'].detach().cpu().numpy()
+ else:
+ preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
+ out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
+
+ eval_loss = eval_loss / nb_eval_steps
+ if args.output_mode == "classification":
+ preds = np.argmax(preds, axis=1)
+ elif args.output_mode == "regression":
+ preds = np.squeeze(preds)
+ result = compute_metrics(eval_task, preds, out_label_ids)
+ results.update(result)
+
+ output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
+ with open(output_eval_file, "w") as writer:
+ logger.info("***** Eval results {} *****".format(prefix))
+ for key in sorted(result.keys()):
+ logger.info(" %s = %s", key, str(result[key]))
+ writer.write("%s = %s\n" % (key, str(result[key])))
+
+ return results
+
+
+def load_and_cache_examples(args, task, tokenizer, evaluate=False):
+ if args.local_rank not in [-1, 0] and not evaluate:
+ torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
+
+ processor = processors[task]()
+ output_mode = output_modes[task]
+ # Load data features from cache or dataset file
+ cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
+ 'dev' if evaluate else 'train',
+ list(filter(None, args.model_name_or_path.split('/'))).pop(),
+ str(args.max_seq_length),
+ str(task)))
+ if os.path.exists(cached_features_file) and not args.overwrite_cache:
+ logger.info("Loading features from cached file %s", cached_features_file)
+ features = torch.load(cached_features_file)
+ else:
+ logger.info("Creating features from dataset file at %s", args.data_dir)
+ label_list = processor.get_labels()
+ if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']:
+ # HACK(label indices are swapped in RoBERTa pretrained model)
+ label_list[1], label_list[2] = label_list[2], label_list[1]
+ examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
+ features = convert_examples_to_features(examples,
+ tokenizer,
+ label_list=label_list,
+ max_length=args.max_seq_length,
+ output_mode=output_mode,
+ pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
+ pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
+ pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
+ )
+ if args.local_rank in [-1, 0]:
+ logger.info("Saving features into cached file %s", cached_features_file)
+ torch.save(features, cached_features_file)
+
+ if args.local_rank == 0 and not evaluate:
+ torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
+
+ # Convert to Tensors and build dataset
+ all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
+ all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
+ all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
+ if output_mode == "classification":
+ all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
+ elif output_mode == "regression":
+ all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
+
+ dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
+ return dataset
+
+
+
+######################################################################
+# Apply the dynamic quantization
+# ==============================
+#
+# We call ``torch.quantization.quantize_dynamic`` on the model to apply
+# dynamic quantization to the HuggingFace BERT model. Specifically,
+#
+# - We specify that we want the ``torch.nn.Linear`` modules in our model to
+#   be quantized;
+# - We specify that we want weights to be converted to quantized int8
+#   values.
+#
+
+quantized_model = torch.quantization.quantize_dynamic(
+ model, {torch.nn.Linear}, dtype=torch.qint8
+)
+print(quantized_model)
+
+
+
+######################################################################
+# In the PyTorch 1.4 release, we further added per-channel quantization
+# support for dynamic quantization.
+#
+# .. figure:: https://drive.google.com/open?id=1N6P70MR6jJ2tcFnFJ2lROLSFqmiOY--g
+# :alt: Per Tensor Quantization for Weight
+#
+# Per Tensor Quantization for Weight
+#
+# .. figure:: https://drive.google.com/open?id=1nyjUKP5qtkRCJPKtUaXXwhglLMQQ0Dfs
+# :alt: Per Channel Quantization for Weight
+#
+# Per Channel Quantization for Weight
+#
+
+qconfig_dict = {
+ torch.nn.Linear: torch.quantization.per_channel_dynamic_qconfig
+}
+per_channel_quantized_model = torch.quantization.quantize_dynamic(
+ model, qconfig_dict, dtype=torch.qint8
+)
+
+
+######################################################################
+# Check the model size
+# --------------------
+#
+# Let’s first check the model size. We can observe a significant reduction
+# in model size:
+#
+
+def print_size_of_model(model):
+ torch.save(model.state_dict(), "temp.p")
+ print('Size (MB):', os.path.getsize("temp.p")/1e6)
+ os.remove('temp.p')
+
+print_size_of_model(model)
+print_size_of_model(quantized_model)
+# print_size_of_model(per_channel_quantized_model)
+
+
+
+
+######################################################################
+# The BERT model used in this tutorial (bert-base-uncased) has a
+# vocabulary size V of 30522. With an embedding size of 768, the total
+# size of the word embedding table is ~ 4 (Bytes/FP32) \* 30522 \* 768 =
+# 90 MB, and it is not quantized here. So with the help of quantization, the
+# size of the non-embedding part of the model is reduced from 350 MB
+# (FP32 model) to 90 MB (INT8 model).
+#
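+#
+# The arithmetic behind that estimate, as a quick illustrative check (using
+# the numbers quoted above rather than values read from the model):
+
+# 4 bytes per FP32 value * vocabulary size * hidden size, converted to MB
+print('Embedding table size (MB):', 4 * 30522 * 768 / 1e6)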
+
+
+######################################################################
+# Evaluate the inference accuracy and time
+# ----------------------------------------
+#
+# Next, let’s compare the inference time as well as the evaluation
+# accuracy between the original FP32 model and the INT8 model after the
+# dynamic quantization.
+#
+
+# Evaluate the original FP32 BERT model
+def time_model_evaluation(model, configs, tokenizer):
+ eval_start_time = time.time()
+ result = evaluate(configs, model, tokenizer, prefix="")
+ eval_end_time = time.time()
+ eval_duration_time = eval_end_time - eval_start_time
+ print(result)
+ print("Evaluate total time (seconds): {0:.1f}".format(eval_duration_time))
+
+time_model_evaluation(model, configs, tokenizer)
+
+# Evaluate the INT8 BERT model after the dynamic quantization
+time_model_evaluation(quantized_model, configs, tokenizer)
+
+
+######################################################################
+# Running this locally on a MacBook Pro, without quantization, inference
+# (for all 408 examples in the MRPC dataset) takes about 160 seconds, and with
+# quantization it takes just about 90 seconds. We summarize the results
+# for running the quantized BERT model inference on a MacBook Pro as
+# follows:
+#
+# ::
+#
+# | Prec | F1 score | Model Size | 1 thread | 4 threads |
+# | FP32 | 0.9019 | 438 MB | 160 sec | 85 sec |
+# | INT8 | 0.8953 | 181 MB | 90 sec | 46 sec |
+#
+# We have a 0.6% lower F1 score accuracy after applying the post-training
+# dynamic quantization on the fine-tuned BERT model for the MRPC task. As a
+# comparison, the recent paper [3] (Table 1) achieved 0.8788 by
+# applying post-training dynamic quantization and 0.8956 by applying
+# quantization-aware training. The main reason is that we support
+# asymmetric quantization in PyTorch while that paper supports
+# symmetric quantization only.
+#
+# Note that we set the number of threads to 1 for the single-thread
+# comparison in this tutorial. We also support intra-op
+# parallelization for these quantized INT8 operators. Users can now
+# enable multi-threading with ``torch.set_num_threads(N)`` (``N`` is the
+# number of intra-op parallelization threads). One preliminary requirement
+# for the intra-op parallelization support is to build PyTorch with the
+# right backend such as OpenMP, Native, or TBB
+# (https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html#build-options).
+# You can use ``torch.__config__.parallel_info()`` to check the
+# parallelization settings. On the same MacBook Pro using PyTorch with the
+# Native backend for parallelization, we can get about 46 seconds for
+# processing the evaluation of the MRPC dataset.
+#
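+#
+# For example, switching to four intra-op threads and re-running the INT8
+# evaluation would look like this (an illustrative sketch; the exact timing
+# depends on your PyTorch build and hardware):
+#
+# .. code:: python
+#
+#    torch.set_num_threads(4)
+#    print(torch.__config__.parallel_info())
+#    time_model_evaluation(quantized_model, configs, tokenizer)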
+
+# Evaluate the INT8 BERT model after the per-channel dynamic quantization
+time_model_evaluation(per_channel_quantized_model, configs, tokenizer)
+
+
+
+######################################################################
+# Serialize the quantized model
+# -----------------------------
+#
+# We can serialize and save the quantized model for the future use.
+#
+
+quantized_output_dir = configs.output_dir + "quantized/"
+if not os.path.exists(quantized_output_dir):
+ os.makedirs(quantized_output_dir)
+quantized_model.save_pretrained(quantized_output_dir)
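+
+
+######################################################################
+# If you later want to load the quantized weights back, one possible
+# approach (a sketch under the assumption that you re-create the quantized
+# module structure first; it is not the only way) is to save the
+# ``state_dict``, re-apply dynamic quantization to a freshly loaded FP32
+# model, and then restore the saved weights:
+#
+# .. code:: python
+#
+#    torch.save(quantized_model.state_dict(),
+#               os.path.join(quantized_output_dir, "quantized_state_dict.pt"))
+#
+#    loaded_model = torch.quantization.quantize_dynamic(
+#        BertForSequenceClassification.from_pretrained(configs.output_dir),
+#        {torch.nn.Linear}, dtype=torch.qint8)
+#    loaded_model.load_state_dict(
+#        torch.load(os.path.join(quantized_output_dir, "quantized_state_dict.pt")))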
+
+
+######################################################################
+# Conclusion
+# ==========
+#
+# In this tutorial, we demonstrated how to convert a
+# well-known state-of-the-art NLP model like BERT into a dynamically
+# quantized model. Dynamic quantization can reduce the size of the model
+# while only having a limited impact on accuracy.
+#
+# Thanks for reading! As always, we welcome any feedback, so please create
+# an issue here (https://github.com/pytorch/pytorch/issues) if you have
+# any.
+#
+
+
+######################################################################
+# References
+# ==========
+#
+# [1] J.Devlin, M. Chang, K. Lee and K. Toutanova, BERT: Pre-training of
+# Deep Bidirectional Transformers for Language Understanding (2018)
+#
+# [2] HuggingFace Transformers.
+# https://github.com/huggingface/transformers
+#
+# [3] O. Zafrir, G. Boudoukh, P. Izsak, & M. Wasserblat (2019). Q8BERT:
+# Quantized 8bit BERT. arXiv preprint arXiv:1910.06188.
+#
+
+
+
+
diff --git a/intermediate_source/quantized_transfer_learning_tutorial.py b/intermediate_source/quantized_transfer_learning_tutorial.py
new file mode 100644
index 00000000000..750d2c9ff29
--- /dev/null
+++ b/intermediate_source/quantized_transfer_learning_tutorial.py
@@ -0,0 +1,530 @@
+"""
+Quantized Transfer Learning for Computer Vision Tutorial
+========================================================
+
+**Author**: `Zafar Takhirov `_
+
+**Reviewed by**: `Raghuraman Krishnamoorthi `_
+
+**Edited by**: `Jessica Lin `_
+
+This tutorial builds on the original `PyTorch Transfer Learning `_
+tutorial, written by
+`Sasank Chilamkurthy `_.
+
+Transfer learning refers to techniques that make use of a pretrained model
+for application to a different dataset. Typical scenarios look as follows:
+
+1. **ConvNet as fixed feature extractor**: Here, you “freeze” [#1]_ the
+   weights of all of the network parameters except those of the final
+   several layers (aka “the head”, usually fully connected layers).
+   These last layers are replaced with new ones initialized with random
+   weights and only these layers are trained.
+2. **Finetuning the convnet**: Instead of random initialization, you
+   initialize the network with a pretrained network, like the one that
+   is trained on the imagenet 1000 dataset. The rest of the training looks
+   as usual. It is common to set the learning rate to a smaller number, as
+   the network is already considered to be trained.
+
+You can also combine the above two scenarios, and execute them both:
+First you can freeze the feature extractor, and train the head. After
+that, you can unfreeze the feature extractor (or part of it), set the
+learning rate to something smaller, and continue training.
+
+In this part you will use the first scenario – extracting the features
+using a quantized model.
+
+.. rubric:: Footnotes
+
+.. [#1] “Freezing” the model/layer means running it only in inference
+   mode, and not allowing its parameters to be updated during the training.
+
+We will start by doing the necessary imports:
+"""
+
+# imports
+import matplotlib.pyplot as plt
+import numpy as np
+import time
+import copy
+
+plt.rc('axes', labelsize=18, titlesize=18)
+plt.rc('figure', titlesize=18)
+plt.rc('font', family='DejaVu Sans', serif='Times', size=18)
+plt.rc('legend', fontsize=18)
+plt.rc('lines', linewidth=3)
+plt.rc('text', usetex=False) # TeX might not be supported
+plt.rc('xtick', labelsize=18)
+plt.rc('ytick', labelsize=18)
+
+######################################################################
+# Installing the Nightly Build
+# ----------------------------
+#
+# Because you will be using the experimental parts of PyTorch, it is
+# recommended to install the latest version of ``torch`` and
+# ``torchvision``. You can find the most recent instructions on local
+# installation `here `_.
+# For example, to install on Mac:
+#
+# .. code:: shell
+#
+# pip install numpy
+# pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
+#
+
+
+
+
+######################################################################
+# Load Data (section not needed as it is covered in the original tutorial)
+# ------------------------------------------------------------------------
+#
+# We will use ``torchvision`` and ``torch.utils.data`` packages to load
+# the data.
+#
+# The problem you are going to solve today is classifying **ants** and
+# **bees** from images. The dataset contains about 120 training images
+# each for ants and bees. There are 75 validation images for each class.
+# This is considered a very small dataset to generalize on. However, since
+# we are using transfer learning, we should be able to generalize
+# reasonably well.
+#
+# *This dataset is a very small subset of imagenet.*
+#
+# .. note:: Download the data from
+#    `here `_
+#    and extract it to the ``data`` directory.
+#
+
+import requests
+import os
+import zipfile
+
+DATA_URL = 'https://download.pytorch.org/tutorial/hymenoptera_data.zip'
+DATA_PATH = os.path.join('.', 'data')
+FILE_NAME = os.path.join(DATA_PATH, 'hymenoptera_data.zip')
+
+if not os.path.isfile(FILE_NAME):
+    print("Downloading the data...")
+    os.makedirs('data', exist_ok=True)
+    with requests.get(DATA_URL) as req:
+        # Only write the archive to disk if the request actually succeeded.
+        if 200 <= req.status_code < 300:
+            with open(FILE_NAME, 'wb') as f:
+                f.write(req.content)
+            print("Download complete!")
+        else:
+            print("Download failed!")
+else:
+    print(FILE_NAME, "already exists, skipping download...")
+
+with zipfile.ZipFile(FILE_NAME, 'r') as zip_ref:
+ print("Unzipping...")
+ zip_ref.extractall('data')
+
+DATA_PATH = os.path.join(DATA_PATH, 'hymenoptera_data')
+
+import torch
+from torchvision import transforms, datasets
+
+# Data augmentation and normalization for training
+# Just normalization for validation
+data_transforms = {
+ 'train': transforms.Compose([
+ transforms.Resize(224),
+ transforms.RandomCrop(224),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ]),
+ 'val': transforms.Compose([
+ transforms.Resize(224),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ]),
+}
+
+image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_PATH, x),
+ data_transforms[x])
+ for x in ['train', 'val']}
+dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=16,
+ shuffle=True, num_workers=8)
+ for x in ['train', 'val']}
+dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
+class_names = image_datasets['train'].classes
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+######################################################################
+# Visualize a few images
+# ^^^^^^^^^^^^^^^^^^^^^^
+#
+# Let’s visualize a few training images so as to understand the data
+# augmentations.
+#
+
+import torchvision
+
+def imshow(inp, title=None, ax=None, figsize=(5, 5)):
+ """Imshow for Tensor."""
+ inp = inp.numpy().transpose((1, 2, 0))
+ mean = np.array([0.485, 0.456, 0.406])
+ std = np.array([0.229, 0.224, 0.225])
+ inp = std * inp + mean
+ inp = np.clip(inp, 0, 1)
+ if ax is None:
+ fig, ax = plt.subplots(1, figsize=figsize)
+ ax.imshow(inp)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ if title is not None:
+ ax.set_title(title)
+
+# Get a batch of training data
+inputs, classes = next(iter(dataloaders['train']))
+
+# Make a grid from batch
+out = torchvision.utils.make_grid(inputs, nrow=4)
+
+fig, ax = plt.subplots(1, figsize=(10, 10))
+imshow(out, title=[class_names[x] for x in classes], ax=ax)
+
+
+######################################################################
+# Training the model
+# ------------------
+#
+# Now, let’s write a general function to train a model. Here, we will
+# illustrate:
+#
+# - Scheduling the learning rate
+# - Saving the best model
+#
+# In the following, parameter ``scheduler`` is an LR scheduler object from
+# ``torch.optim.lr_scheduler``.
+#
+
+def train_model(model, criterion, optimizer, scheduler, num_epochs=25, device='cpu'):
+ since = time.time()
+
+ best_model_wts = copy.deepcopy(model.state_dict())
+ best_acc = 0.0
+
+ for epoch in range(num_epochs):
+ print('Epoch {}/{}'.format(epoch, num_epochs - 1))
+ print('-' * 10)
+
+ # Each epoch has a training and validation phase
+ for phase in ['train', 'val']:
+ if phase == 'train':
+ model.train() # Set model to training mode
+ else:
+ model.eval() # Set model to evaluate mode
+
+ running_loss = 0.0
+ running_corrects = 0
+
+ # Iterate over data.
+ for inputs, labels in dataloaders[phase]:
+ inputs = inputs.to(device)
+ labels = labels.to(device)
+
+ # zero the parameter gradients
+ optimizer.zero_grad()
+
+ # forward
+ # track history if only in train
+ with torch.set_grad_enabled(phase == 'train'):
+ outputs = model(inputs)
+ _, preds = torch.max(outputs, 1)
+ loss = criterion(outputs, labels)
+
+ # backward + optimize only if in training phase
+ if phase == 'train':
+ loss.backward()
+ optimizer.step()
+
+ # statistics
+ running_loss += loss.item() * inputs.size(0)
+ running_corrects += torch.sum(preds == labels.data)
+ if phase == 'train':
+ scheduler.step()
+
+ epoch_loss = running_loss / dataset_sizes[phase]
+ epoch_acc = running_corrects.double() / dataset_sizes[phase]
+
+ print('{} Loss: {:.4f} Acc: {:.4f}'.format(
+ phase, epoch_loss, epoch_acc))
+
+ # deep copy the model
+ if phase == 'val' and epoch_acc > best_acc:
+ best_acc = epoch_acc
+ best_model_wts = copy.deepcopy(model.state_dict())
+
+ print()
+
+ time_elapsed = time.time() - since
+ print('Training complete in {:.0f}m {:.0f}s'.format(
+ time_elapsed // 60, time_elapsed % 60))
+ print('Best val Acc: {:4f}'.format(best_acc))
+
+ # load best model weights
+ model.load_state_dict(best_model_wts)
+ return model
+
+
+######################################################################
+# Visualizing the model predictions
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Generic function to display predictions for a few images
+#
+
+def visualize_model(model, rows=3, cols=3):
+ was_training = model.training
+ model.eval()
+ current_row = current_col = 0
+ fig, ax = plt.subplots(rows, cols, figsize=(cols*2, rows*2))
+
+ with torch.no_grad():
+ for idx, (imgs, lbls) in enumerate(dataloaders['val']):
+ imgs = imgs.cpu()
+ lbls = lbls.cpu()
+
+ outputs = model(imgs)
+ _, preds = torch.max(outputs, 1)
+
+ for jdx in range(imgs.size()[0]):
+ imshow(imgs.data[jdx], ax=ax[current_row, current_col])
+ ax[current_row, current_col].axis('off')
+ ax[current_row, current_col].set_title('predicted: {}'.format(class_names[preds[jdx]]))
+
+ current_col += 1
+ if current_col >= cols:
+ current_row += 1
+ current_col = 0
+ if current_row >= rows:
+ model.train(mode=was_training)
+ return
+ model.train(mode=was_training)
+
+
+######################################################################
+# Part 1. Training a Custom Classifier based on a Quantized Feature Extractor
+# ---------------------------------------------------------------------------
+#
+# In this section you will use a “frozen” quantized feature extractor, and
+# train a custom classifier head on top of it. Unlike floating point
+# models, you don’t need to set ``requires_grad=False`` for the quantized
+# model, as it has no trainable parameters. Please refer to the
+# documentation at https://pytorch.org/docs/stable/quantization.html for
+# more details.
+#
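+# For comparison, freezing a floating-point backbone would look roughly like
+# this (an illustrative sketch with a hypothetical ``float_model``; it is not
+# needed for the quantized model used here):
+#
+# .. code:: python
+#
+#    for param in float_model.parameters():
+#        param.requires_grad = False
+#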
+# Load a pretrained model: for this exercise you will be using ResNet-18
+# (https://pytorch.org/hub/pytorch_vision_resnet/).
+#
+
+import torchvision.models.quantization as models
+
+# We will need the number of input features of the `fc` layer for the new head later.
+# In the new head, the size of each output sample is set to 2.
+# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
+model_fe = models.resnet18(pretrained=True, progress=True, quantize=True)
+num_ftrs = model_fe.fc.in_features
+
+
+######################################################################
+# At this point you need to modify the pretrained model: because the model
+# has quantize/dequantize blocks at the beginning and the end, but we
+# will only use the feature extractor, the dequantization layer has to
+# move right before the linear layer (the head). The easiest way of doing
+# this is to wrap the model in ``nn.Sequential``.
+#
+# The first step is to isolate the feature extractor in the ResNet
+# model. Although in this example you are tasked to use all layers except
+# ``fc`` as the feature extractor, in reality, you can take as many parts
+# as you need. This would be useful in case you would like to replace some
+# of the convolutional layers as well.
+#
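+# For example, a hypothetical, shorter feature extractor that keeps only the
+# stem and the first two stages could look like this (illustrative only; the
+# tutorial itself uses the full extractor built below):
+#
+# .. code:: python
+#
+#    truncated_fe = nn.Sequential(
+#        model_fe.quant,     # quantize the input
+#        model_fe.conv1, model_fe.bn1, model_fe.relu, model_fe.maxpool,
+#        model_fe.layer1, model_fe.layer2,
+#        model_fe.dequant,   # dequantize before any float layers you add
+#    )
+#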
+
+
+######################################################################
+# **Notice that when isolating the feature extractor from a quantized
+# model, you have to place the quantizer at the beginning and the
+# dequantizer at the end of it.**
+#
+
+from torch import nn
+
+def create_combined_model(model_fe):
+ # Step 1. Isolate the feature extractor.
+ model_fe_features = nn.Sequential(
+ model_fe.quant, # Quantize the input
+ model_fe.conv1,
+ model_fe.bn1,
+ model_fe.relu,
+ model_fe.maxpool,
+ model_fe.layer1,
+ model_fe.layer2,
+ model_fe.layer3,
+ model_fe.layer4,
+ model_fe.avgpool,
+ model_fe.dequant, # Dequantize the output
+ )
+
+ # Step 2. Create a new "head"
+ new_head = nn.Sequential(
+ nn.Dropout(p=0.5),
+ nn.Linear(num_ftrs, 2),
+ )
+
+ # Step 3. Combine, and don't forget the quant stubs.
+ new_model = nn.Sequential(
+ model_fe_features,
+ nn.Flatten(1),
+ new_head,
+ )
+ return new_model
+
+new_model = create_combined_model(model_fe)
+
+
+######################################################################
+# .. warning:: Currently the quantized models can only be run on CPU.
+#    However, it is possible to send the non-quantized parts of the model
+#    to a GPU.
+#
+
+import torch.optim as optim
+new_model = new_model.to('cpu')
+
+criterion = nn.CrossEntropyLoss()
+
+# Note that we are only training the head.
+optimizer_ft = optim.SGD(new_model.parameters(), lr=0.01, momentum=0.9)
+
+# Decay LR by a factor of 0.1 every 7 epochs
+exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
+
+
+######################################################################
+# Train and evaluate
+# ------------------
+#
+# This step takes around 15-25 min on CPU. Because the quantized model can
+# only run on the CPU, you cannot run the training on GPU.
+#
+
+new_model = train_model(new_model, criterion, optimizer_ft, exp_lr_scheduler,
+ num_epochs=25, device='cpu')
+
+visualize_model(new_model)
+plt.tight_layout()
+
+
+######################################################################
+# Part 2. Finetuning the Quantizable Model
+# ----------------------------------------
+#
+# In this part, we fine tune the feature extractor used for transfer
+# learning, and quantize the feature extractor. Note that in both part 1
+# and 2, the feature extractor is quantized. The difference is that in
+# part 1, we use a pretrained quantized model. In this part, we create a
+# quantized feature extractor after fine tuning on the data-set of
+# interest, so this is a way to get better accuracy with transfer learning
+# while having the benefits of quantization. Note that in our specific
+# example, the training set is really small (120 images) so the benefits
+# of fine tuning the entire model are not apparent. However, the procedure
+# shown here will improve accuracy for transfer learning with larger
+# datasets.
+#
+# The pretrained feature extractor must be quantizable, i.e., we need to do
+# the following:
+#
+# 1. Fuse (Conv, BN, ReLU), (Conv, BN), and (Conv, ReLU)
+#    using ``torch.quantization.fuse_modules``.
+# 2. Connect the feature extractor with a custom head.
+#    This requires dequantizing the output of the feature extractor.
+# 3. Insert fake-quantization modules at appropriate locations
+#    in the feature extractor to mimic quantization during training.
+#
+# For step (1), we use models from ``torchvision.models.quantization``, which
+# have a member method ``fuse_model`` that fuses all the conv, bn, and
+# relu modules. In general, this would require calling the
+# ``torch.quantization.fuse_modules`` API with the list of modules to fuse
+# (a manual-fusion sketch is shown after the note below).
+#
+# Step (2) is done by the ``create_combined_model`` function that we
+# used in the previous section.
+#
+# Step (3) is achieved by using ``torch.quantization.prepare_qat``, which
+# inserts fake-quantization modules.
+#
+# In step (4), we fine tune the model with the desired custom head.
+#
+# In step (5), we convert the fine tuned model into a quantized model (only
+# the feature extractor is quantized) by calling
+# ``torch.quantization.convert``.
+#
+# .. note:: Because of the random initialization your results might differ
+#    from the results shown here.
+#
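+# As a reference, fusing modules manually (instead of calling the convenience
+# ``fuse_model`` method) would look roughly like this for the stem of
+# ResNet-18 (an illustrative sketch only):
+#
+# .. code:: python
+#
+#    torch.quantization.fuse_modules(model, [['conv1', 'bn1', 'relu']],
+#                                    inplace=True)
+#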
+
+model = models.resnet18(pretrained=True, progress=True, quantize=False) # notice `quantize=False`
+num_ftrs = model.fc.in_features
+
+# Step 1
+model.train()
+model.fuse_model()
+# Step 2
+model_ft = create_combined_model(model)
+model_ft[0].qconfig = torch.quantization.default_qat_qconfig # Use default QAT configuration
+# Step 3
+model_ft = torch.quantization.prepare_qat(model_ft, inplace=True)
+
+
+
+
+######################################################################
+# Finetuning the model
+# --------------------
+#
+# We fine tune the entire model including the feature extractor. In
+# general, this will lead to higher accuracy. However, due to the small
+# training set used here, we end up overfitting to the training set.
+#
+
+# Step 4. Fine tune the model
+
+for param in model_ft.parameters():
+ param.requires_grad = True
+
+model_ft.cuda() # We can fine-tune on GPU
+
+criterion = nn.CrossEntropyLoss()
+
+# Note that we are training everything, so we use a smaller learning rate.
+optimizer_ft = optim.SGD(model_ft.parameters(), lr=1e-3, momentum=0.9, weight_decay=0.1)
+
+# Decay LR by a factor of 0.3 every several epochs
+exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=5, gamma=0.3)
+
+model_ft_tuned = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
+ num_epochs=25, device='cuda')
+
+# Step 5. Convert to quantized model
+
+from torch.quantization import convert
+model_ft_tuned.cpu()
+
+model_quantized_and_trained = convert(model_ft_tuned, inplace=False)
+
+
+
+######################################################################
+# Let's see how the quantized model performs on a few images.
+#
+
+visualize_model(model_quantized_and_trained)
+
+plt.ioff()
+plt.tight_layout()
+plt.show()
+