
Commit b09d465

HDCharles and svekars authored
Adding tutorial for gpu quantization using torchao (#2730)
* Adding tutorial for gpu quantization using torchao

  Summary: Optimizing and Quantizing the SAM model

  Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent e96a8a9 commit b09d465

File tree

.jenkins/download_data.py
.jenkins/metadata.json
prototype_source/gpu_quantization_torchao_tutorial.py
requirements.txt

4 files changed: +334 -3 lines changed


.jenkins/download_data.py

Lines changed: 9 additions & 1 deletion
@@ -105,6 +105,13 @@ def download_lenet_mnist() -> None:
         sha256="cb5f8e578aef96d5c1a2cc5695e1aa9bbf4d0fe00d25760eeebaaac6ebc2edcb",
     )
 
+def download_gpu_quantization_torchao() -> None:
+    # Download SAM model checkpoint for prototype_source/gpu_quantization_torchao_tutorial.py
+    download_url_to_file("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+                         prefix=PROTOTYPE_DATA_DIR,
+                         dst="sam_vit_h_4b8939.pth",
+                         sha256="a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e",
+                         )
 
 def main() -> None:
     DATA_DIR.mkdir(exist_ok=True)
@@ -122,7 +129,8 @@ def main() -> None:
         download_dcgan_data()
     if FILES_TO_RUN is None or "fgsm_tutorial" in FILES_TO_RUN:
         download_lenet_mnist()
-
+    if FILES_TO_RUN is None or "gpu_quantization_torchao_tutorial" in FILES_TO_RUN:
+        download_gpu_quantization_torchao()
 
 if __name__ == "__main__":
     main()
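Note: the download_url_to_file call above takes prefix, dst, and sha256 keyword arguments, so it is the script's own download helper rather than torch.hub.download_url_to_file (which has a different signature). A rough, hypothetical sketch of what a helper with that signature might do, for illustration only and not the repo's actual implementation:

    import hashlib
    import urllib.request
    from pathlib import Path

    def download_url_to_file(url, prefix, dst, sha256=None):
        # Download url into prefix/dst and optionally verify its SHA-256 digest.
        target = Path(prefix) / dst
        target.parent.mkdir(parents=True, exist_ok=True)
        urllib.request.urlretrieve(url, str(target))
        if sha256 is not None:
            digest = hashlib.sha256(target.read_bytes()).hexdigest()
            if digest != sha256:
                raise RuntimeError(f"sha256 mismatch for {target}: got {digest}")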

.jenkins/metadata.json

Lines changed: 11 additions & 0 deletions
@@ -28,10 +28,21 @@
     "intermediate_source/model_parallel_tutorial.py": {
         "needs": "linux.16xlarge.nvidia.gpu"
     },
+    "intermediate_source/torchvision_tutorial.py": {
+        "needs": "linux.g5.4xlarge.nvidia.gpu",
+        "_comment": "does not require a5g but needs to run before gpu_quantization_torchao_tutorial.py."
+    },
+    "advanced_source/coding_ddpg.py": {
+        "needs": "linux.g5.4xlarge.nvidia.gpu",
+        "_comment": "does not require a5g but needs to run before gpu_quantization_torchao_tutorial.py."
+    },
     "intermediate_source/torch_compile_tutorial.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu"
     },
     "intermediate_source/scaled_dot_product_attention_tutorial.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu"
+    },
+    "prototype_source/gpu_quantization_torchao_tutorial.py": {
+        "needs": "linux.g5.4xlarge.nvidia.gpu"
     }
 }
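The "needs" field tells the CI which runner class a tutorial requires, and the "_comment" entries record why two unrelated tutorials are pinned to the same runner (scheduling order relative to this tutorial). A minimal sketch of how such metadata might be consumed by a build script; the helper name and default runner below are assumptions for illustration, not the repo's actual CI code:

    import json

    DEFAULT_RUNNER = "linux.4xlarge"  # assumed default; the real default may differ

    def runner_for(tutorial_path, metadata_file=".jenkins/metadata.json"):
        # Return the runner label a tutorial needs, falling back to the default
        # when it has no explicit "needs" entry in metadata.json.
        with open(metadata_file) as f:
            metadata = json.load(f)
        return metadata.get(tutorial_path, {}).get("needs", DEFAULT_RUNNER)

    print(runner_for("prototype_source/gpu_quantization_torchao_tutorial.py"))
    # prints: linux.g5.4xlarge.nvidia.gpu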
prototype_source/gpu_quantization_torchao_tutorial.py

Lines changed: 309 additions & 0 deletions
@@ -0,0 +1,309 @@
"""
(prototype) GPU Quantization with TorchAO
======================================================

**Author**: `HDCharles <https://github.com/HDCharles>`_

In this tutorial, we will walk you through the quantization and optimization
of the popular `segment anything model <https://github.com/facebookresearch/segment-anything>`_. These
steps will mimic some of those taken to develop the
`segment-anything-fast <https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py#L15>`_
repo. This step-by-step guide demonstrates how you can
apply these techniques to speed up your own models, especially those
that use transformers. To that end, we will focus on widely applicable
techniques, such as optimizing performance with ``torch.compile`` and
quantization, and measuring their impact.

"""
19+
20+
######################################################################
21+
# Set up Your Environment
22+
# --------------------------------
23+
#
24+
# First, let's configure your environment. This guide was written for CUDA 12.1.
25+
# We have run this tutorial on an A100-PG509-200 power limited to 330.00 W. If you
26+
# are using a different hardware, you might see different performance numbers.
27+
#
28+
#
29+
# .. code-block:: bash
30+
#
31+
# > conda create -n myenv python=3.10
32+
# > pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
33+
# > pip install git+https://github.com/facebookresearch/segment-anything.git
34+
# > pip install git+https://github.com/pytorch-labs/ao.git
35+
#
36+
# Segment Anything Model checkpoint setup:
37+
#
38+
# 1. Go to the `segment-anything repo <checkpoint https://github.com/facebookresearch/segment-anything/tree/main#model-checkpoints>`_ and download the ``vit_h`` checkpoint. Alternatively, you can just use ``wget``: `wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth --directory-prefix=<path>
39+
# 2. Pass in that directory by editing the code below to say:
40+
#
41+
# .. code-block::
42+
#
43+
# {sam_checkpoint_base_path}=<path>
44+
#
45+
# This was run on an A100-PG509-200 power limited to 330.00 W
46+
#

import torch
from torchao.quantization import change_linear_weights_to_int8_dqtensors
from segment_anything import sam_model_registry
from torch.utils.benchmark import Timer

sam_checkpoint_base_path = "data"
model_type = 'vit_h'
model_name = 'sam_vit_h_4b8939.pth'
checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}"
batchsize = 16
only_one_block = True


@torch.no_grad()
def benchmark(f, *args, **kwargs):
    for _ in range(3):
        f(*args, **kwargs)
        torch.cuda.synchronize()

    torch.cuda.reset_peak_memory_stats()
    t0 = Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20)
    return {'time': res.median * 1e3, 'memory': torch.cuda.max_memory_allocated() / 1e9}

def get_sam_model(only_one_block=False, batchsize=1):
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda()
    model = sam.image_encoder.eval()
    image = torch.randn(batchsize, 3, 1024, 1024, device='cuda')

    # code to use just a single block of the model
    if only_one_block:
        model = model.blocks[0]
        image = torch.randn(batchsize, 64, 64, 1280, device='cuda')
    return model, image


######################################################################
# In this tutorial, we focus on quantizing the ``image_encoder`` because the
# inputs to it are statically sized, while the prompt encoder and mask
# decoder have variable sizes, which makes them harder to quantize.
#
# We’ll focus on just a single block at first to make the analysis easier.
#
# Let's start by measuring the baseline runtime.

try:
    model, image = get_sam_model(only_one_block, batchsize)
    fp32_res = benchmark(model, image)
    print(f"base fp32 runtime of the model is {fp32_res['time']:0.2f}ms and peak memory {fp32_res['memory']:0.2f}GB")
    # base fp32 runtime of the model is 186.16ms and peak memory 6.33GB
except Exception as e:
    print("unable to run fp32 model: ", e)


######################################################################
# We can achieve an instant performance boost by converting the model to bfloat16.
# The reason we opt for bfloat16 over fp16 is due to its dynamic range, which is comparable to
# that of fp32. Both bfloat16 and fp32 have 8 exponent bits, whereas fp16 only has 5. This
# larger dynamic range helps protect us from overflow errors and other issues that can arise
# when scaling and rescaling tensors due to quantization.
#

model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
bf16_res = benchmark(model, image)
print(f"bf16 runtime of the block is {bf16_res['time']:0.2f}ms and peak memory {bf16_res['memory']: 0.2f}GB")
# bf16 runtime of the block is 25.43ms and peak memory 3.17GB


######################################################################
# Just this quick change improves runtime by a factor of ~7x in the tests we have
# conducted (186.16ms to 25.43ms).
#
# Next, let's use ``torch.compile`` with our model to see how much the performance
# improves.
#

model_c = torch.compile(model, mode='max-autotune')
comp_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the block is {comp_res['time']:0.2f}ms and peak memory {comp_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the block is 19.95ms and peak memory 2.24GB


######################################################################
# The first time this is run, you should see a sequence of ``AUTOTUNE``
# outputs, which occur when inductor compares the performance of
# various kernel parameters for each kernel. This only happens once (unless
# you delete your cache), so if you run the cell again you should just get
# the benchmark output.
#
# ``torch.compile`` yields about another 27% improvement. This brings the
# model to a reasonable baseline where we now have to work a bit harder
# for improvements.
#
# Next, let's apply quantization. Quantization for GPUs comes in three main forms
# in `torchao <https://github.com/pytorch-labs/ao>`_, which is just native
# PyTorch+Python code. This includes:
#
# * int8 dynamic quantization
# * int8 weight-only quantization
# * int4 weight-only quantization
#
# Different models, or sometimes different layers in a model, can require different techniques.
# For models that are heavily compute-bound, dynamic quantization tends
# to work best, since it swaps the normally expensive floating-point
# matmul ops with integer versions. Weight-only quantization works better
# in memory-bound situations, where the benefit comes from loading less
# weight data rather than doing less computation. The torchao APIs:
#
# ``change_linear_weights_to_int8_dqtensors``,
# ``change_linear_weights_to_int8_woqtensors`` or
# ``change_linear_weights_to_int4_woqtensors``
#
# can be used to easily apply the desired quantization technique. Once the
# model is compiled with ``torch.compile`` and ``max-autotune``, quantization is
# complete and we can see our speedup.
#
# .. note::
#    You might experience issues with these on older versions of PyTorch. If you run
#    into an issue, you can use ``apply_dynamic_quant`` and
#    ``apply_weight_only_int8_quant`` instead as drop-in replacements for the first two
#    APIs above (there is no replacement for int4).
#
# The difference between the two API families is that the ``change_linear_weights`` APIs
# alter the weight tensor of the linear module so that, instead of performing a
# normal linear operation, it performs a quantized one. This is helpful when you
# have non-standard linear ops that do more than one thing. The ``apply``
# APIs directly swap the linear modules for quantized modules, which
# works on older versions of PyTorch but doesn’t work with non-standard linear
# modules.
#
# In this case, Segment Anything is compute-bound, so we’ll use dynamic quantization:
#

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the quantized block is 19.04ms and peak memory 3.58GB


######################################################################
# With quantization, we have improved performance a bit more but memory usage increased
# significantly.
#
# This is for two reasons:
#
# 1) Quantization adds overhead to the model
#    since we need to quantize and dequantize the input and output. For small
#    batch sizes this overhead can actually make the model go slower.
# 2) Even though we are doing a quantized matmul, such as ``int8 x int8``,
#    the result of the multiplication gets stored in an int32 tensor,
#    which is twice the size of the result from the non-quantized model.
#    If we can avoid creating this int32 tensor, our memory usage will improve a lot.
#
# We can fix #2 by fusing the integer matmul with the subsequent rescale
# operation. Since the final output will be bf16, if we immediately convert
# the int32 tensor to bf16 and store that instead, we’ll get better
# performance in terms of both runtime and memory.
#
# The way to do this is to enable the option
# ``force_fuse_int_mm_with_mul`` in the inductor config.
#

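######################################################################
# To make the memory point concrete, here is a minimal sketch of what a
# dynamically quantized matmul computes. This is an illustration only, not
# torchao's actual implementation: it assumes simple per-tensor scales and
# upcasts to int32 explicitly, whereas real kernels run an int8 matmul that
# accumulates in int32 directly.

def sketch_int8_dynamic_matmul(x_bf16, w_int8, w_scale):
    # Dynamically quantize the activation to int8 with a per-tensor scale.
    x_scale = x_bf16.abs().amax().float() / 127.0
    x_int8 = (x_bf16.float() / x_scale).round().clamp(-128, 127).to(torch.int8)
    # The integer matmul accumulates into int32 -- this is the large
    # intermediate tensor responsible for the extra memory discussed above.
    acc_int32 = x_int8.to(torch.int32) @ w_int8.to(torch.int32)
    # Rescale back to bf16. When inductor fuses this multiply into the matmul
    # kernel, the int32 tensor never needs to be written out to memory.
    return (acc_int32 * (x_scale * w_scale)).to(torch.bfloat16)

# Tiny CPU example with arbitrary shapes, just to show the dataflow.
_w = torch.randn(64, 64)
_w_scale = _w.abs().amax() / 127.0
_w_int8 = (_w / _w_scale).round().clamp(-128, 127).to(torch.int8)
_y = sketch_int8_dynamic_matmul(torch.randn(16, 64, dtype=torch.bfloat16), _w_int8, _w_scale)
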
del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the fused quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the fused quantized block is 18.78ms and peak memory 2.37GB


######################################################################
# The fusion improves performance by another small amount (about 6% over the
# baseline in total) and removes almost all of the memory increase; the
# remaining difference (2.37GB quantized vs 2.24GB unquantized) is due to
# quantization overhead that cannot be avoided.
#
# We’re still not done, though: we can apply a few general-purpose
# optimizations to get our final best-case performance.
#
# 1) We can sometimes improve performance by disabling epilogue fusion,
#    since the autotuning process can be confused by fusions and choose
#    bad kernel parameters.
# 2) We can apply coordinate descent tuning in all directions to enlarge
#    the search area for kernel parameters.
#

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the final quantized block is 18.16ms and peak memory 2.39GB


######################################################################
# As you can see, we’ve squeezed another small improvement from the model,
# taking our total improvement to over 10x compared to our original baseline. To
# get a final estimate of the impact of quantization, let's do an apples-to-apples
# comparison on the full model, since the actual improvement will
# differ block by block depending on the shapes involved.
#

try:
    del model_c, model, image
    model, image = get_sam_model(False, batchsize)
    model = model.to(torch.bfloat16)
    image = image.to(torch.bfloat16)
    model_c = torch.compile(model, mode='max-autotune')
    quant_res = benchmark(model_c, image)
    print(f"bf16 compiled runtime of the compiled full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
    # bf16 compiled runtime of the compiled full model is 729.65ms and peak memory 23.96GB

    del model_c, model, image
    model, image = get_sam_model(False, batchsize)
    model = model.to(torch.bfloat16)
    image = image.to(torch.bfloat16)
    change_linear_weights_to_int8_dqtensors(model)
    model_c = torch.compile(model, mode='max-autotune')
    quant_res = benchmark(model_c, image)
    print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
    # bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB
except Exception as e:
    print("unable to run full model: ", e)


######################################################################
# Conclusion
# -----------------
# In this tutorial, we have learned about quantization and optimization techniques,
# using the segment anything model as an example.
#
# In the end, we achieved a full-model, apples-to-apples quantization speedup
# of about 7.7% on batch size 16 (729.65ms to 677.28ms). We can push this a
# bit further by increasing the batch size and optimizing other parts of
# the model, for example, by using some form of flash attention.
#
# For more information, visit
# `torchao <https://github.com/pytorch-labs/ao>`_ and try it on your own
# models.
#
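######################################################################
# As a quick reference, the settings used above can be collected into one
# helper. This is just a convenience sketch of the recipe from this tutorial,
# assuming the same torchao API and inductor flags shown earlier; remember to
# also cast your inputs to bf16 before calling the compiled model.

def optimize_with_torchao(model):
    # Run in bf16 for the dynamic-range benefits discussed earlier.
    model = model.to(torch.bfloat16)
    # Inductor flags from the final configuration above.
    torch._inductor.config.epilogue_fusion = False
    torch._inductor.config.coordinate_descent_tuning = True
    torch._inductor.config.coordinate_descent_check_all_directions = True
    torch._inductor.config.force_fuse_int_mm_with_mul = True
    # Apply int8 dynamic quantization to the linear weights, then compile.
    change_linear_weights_to_int8_dqtensors(model)
    return torch.compile(model, mode='max-autotune')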

requirements.txt

Lines changed: 5 additions & 2 deletions
@@ -6,8 +6,8 @@ sphinx-gallery==0.11.1
 sphinx_design
 docutils==0.16
 sphinx-copybutton
-tqdm
-numpy
+tqdm==4.66.1
+numpy==1.24.4
 matplotlib
 librosa
 torch
@@ -53,6 +53,7 @@ scipy==1.11.1
 numba==0.57.1
 pillow==10.2.0
 wget
+gym==0.26.2
 gym-super-mario-bros==7.4.0
 pyopengl
 gymnasium[mujoco]==0.27.0
@@ -61,3 +62,5 @@ iopath
 pygame==2.1.2
 pycocotools
 semilearn==0.3.2
+torchao==0.0.3
+segment_anything==1.0
