"""
(prototype) GPU Quantization with TorchAO
======================================================

**Author**: `HDCharles <https://github.com/HDCharles>`_

In this tutorial, we will walk you through the quantization and optimization
of the popular `segment anything model <https://github.com/facebookresearch/segment-anything>`_. These
steps will mimic some of those taken to develop the
`segment-anything-fast <https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py#L15>`_
repo. This step-by-step guide demonstrates how you can
apply these techniques to speed up your own models, especially those
that use transformers. To that end, we will focus on widely applicable
techniques, such as optimizing performance with ``torch.compile`` and
quantization, and measure their impact.

"""


######################################################################
# Set up Your Environment
# --------------------------------
#
# First, let's configure your environment. This guide was written for CUDA 12.1.
# We have run this tutorial on an A100-PG509-200 power limited to 330.00 W. If you
# are using different hardware, you might see different performance numbers.
#
# .. code-block:: bash
#
#    > conda create -n myenv python=3.10
#    > pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
#    > pip install git+https://github.com/facebookresearch/segment-anything.git
#    > pip install git+https://github.com/pytorch-labs/ao.git
#
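# To double-check that the nightly build and CUDA are picked up, you can
# optionally run a quick sanity check (the exact version string will differ
# on your machine):
#
# .. code-block:: python
#
#    import torch
#    print(torch.__version__)          # should report a recent nightly build
#    print(torch.cuda.is_available())  # should be True for this tutorial
#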
# Segment Anything Model checkpoint setup:
#
# 1. Go to the `segment-anything repo <https://github.com/facebookresearch/segment-anything/tree/main#model-checkpoints>`_ and download the ``vit_h`` checkpoint. Alternatively, you can just use ``wget``: ``wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth --directory-prefix=<path>``
# 2. Pass in that directory by editing the code below to say:
#
# .. code-block::
#
#    sam_checkpoint_base_path = <path>
#

import torch
from torchao.quantization import change_linear_weights_to_int8_dqtensors
from segment_anything import sam_model_registry
from torch.utils.benchmark import Timer

sam_checkpoint_base_path = "data"
model_type = 'vit_h'
model_name = 'sam_vit_h_4b8939.pth'
checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}"
batchsize = 16
only_one_block = True


@torch.no_grad()
def benchmark(f, *args, **kwargs):
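    # warm up: run a few iterations first so one-time setup costs do not skew the timing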
    for _ in range(3):
        f(*args, **kwargs)
    torch.cuda.synchronize()

    torch.cuda.reset_peak_memory_stats()
    t0 = Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20)
    return {'time': res.median * 1e3, 'memory': torch.cuda.max_memory_allocated() / 1e9}

def get_sam_model(only_one_block=False, batchsize=1):
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda()
    model = sam.image_encoder.eval()
    image = torch.randn(batchsize, 3, 1024, 1024, device='cuda')

    # code to use just a single block of the model
    if only_one_block:
        model = model.blocks[0]
        image = torch.randn(batchsize, 64, 64, 1280, device='cuda')
    return model, image


######################################################################
# In this tutorial, we focus on quantizing the ``image_encoder`` because the
# inputs to it are statically sized, while the prompt encoder and mask
# decoder have variable sizes, which makes them harder to quantize.
#
# We’ll focus on just a single block at first to make the analysis easier.
#
# Let's start by measuring the baseline runtime.

try:
    model, image = get_sam_model(only_one_block, batchsize)
    fp32_res = benchmark(model, image)
    print(f"base fp32 runtime of the model is {fp32_res['time']:0.2f}ms and peak memory {fp32_res['memory']:0.2f}GB")
    # base fp32 runtime of the model is 186.16ms and peak memory 6.33GB
except Exception as e:
    print("unable to run fp32 model: ", e)


######################################################################
# We can achieve an instant performance boost by converting the model to bfloat16.
# We opt for bfloat16 over fp16 because its dynamic range is comparable to
# that of fp32. Both bfloat16 and fp32 have 8 exponent bits, whereas fp16 only has 5. This
# larger dynamic range helps protect us from overflow errors and other issues that can arise
# when scaling and rescaling tensors due to quantization.
#
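# As a quick, optional illustration of that dynamic-range difference, we can
# compare what ``torch.finfo`` reports for the two dtypes (this check is not
# required for the rest of the tutorial):

print(f"bf16: max={torch.finfo(torch.bfloat16).max:.1e}, eps={torch.finfo(torch.bfloat16).eps:.1e}")
print(f"fp16: max={torch.finfo(torch.float16).max:.1e}, eps={torch.finfo(torch.float16).eps:.1e}")
# bf16 can represent values up to ~3.4e38 (the same exponent range as fp32),
# while fp16 tops out at 65504, so fp16 overflows much more easily.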

model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
bf16_res = benchmark(model, image)
print(f"bf16 runtime of the block is {bf16_res['time']:0.2f}ms and peak memory {bf16_res['memory']:0.2f}GB")
# bf16 runtime of the block is 25.43ms and peak memory 3.17GB


######################################################################
# Just this quick change improves runtime by a factor of ~7x in the tests we have
# conducted (186.16ms to 25.43ms).
#
# Next, let's use ``torch.compile`` with our model to see how much the performance
# improves.
#

model_c = torch.compile(model, mode='max-autotune')
comp_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the block is {comp_res['time']:0.2f}ms and peak memory {comp_res['memory']:0.2f}GB")
# bf16 compiled runtime of the block is 19.95ms and peak memory 2.24GB


######################################################################
# The first time this is run, you should see a sequence of ``AUTOTUNE``
# outputs, which occur while inductor compares the performance of various
# kernel parameters for each kernel. This only happens once (unless you
# delete your cache), so if you run the cell again you should just see the
# benchmark output.
#
# ``torch.compile`` yields about another 27% improvement. This brings the
# model to a reasonable baseline where we now have to work a bit harder
# for improvements.
#
# Next, let's apply quantization. Quantization for GPUs comes in three main forms
# in `torchao <https://github.com/pytorch-labs/ao>`_, which is just native
# PyTorch + Python code. The three forms are:
#
# * int8 dynamic quantization
# * int8 weight-only quantization
# * int4 weight-only quantization
#
# Different models, or sometimes different layers in a model, can require different techniques.
# For models which are heavily compute-bound, dynamic quantization tends
# to work the best since it swaps the normal expensive floating point
# matmul ops with integer versions. Weight-only quantization works better
# in memory-bound situations where the benefit comes from loading less
# weight data, rather than doing less computation. The torchao APIs:
#
# ``change_linear_weights_to_int8_dqtensors``,
# ``change_linear_weights_to_int8_woqtensors``, or
# ``change_linear_weights_to_int4_woqtensors``
#
# can be used to easily apply the desired quantization technique. Once the
# model is compiled with ``torch.compile`` using ``max-autotune``, quantization is
# complete and we can see our speedup.
#
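# For example, if your model were memory bound, the weight-only variant could
# be applied as in the sketch below. We do not run it in this tutorial, and it
# assumes the weight-only API follows the same in-place calling convention as
# the dynamic variant we use later on:
#
# .. code-block:: python
#
#    from torchao.quantization import change_linear_weights_to_int8_woqtensors
#
#    model, image = get_sam_model(only_one_block, batchsize)
#    model = model.to(torch.bfloat16)
#    image = image.to(torch.bfloat16)
#    change_linear_weights_to_int8_woqtensors(model)  # int8 weight-only quantization
#    model_c = torch.compile(model, mode='max-autotune')
#    benchmark(model_c, image)
#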
# .. note::
#    You might experience issues with these on older versions of PyTorch. If you run
#    into an issue, you can use ``apply_dynamic_quant`` and
#    ``apply_weight_only_int8_quant`` instead as drop-in replacements for the first
#    two APIs above (there is no replacement for int4).
#
# The difference between the two kinds of APIs is that the ``change_linear_weights``
# APIs alter the weight tensor of the linear module so that, instead of a
# normal linear op, a quantized operation is performed. This is helpful when you
# have non-standard linear ops that do more than one thing. The ``apply``
# APIs directly swap the linear modules for quantized modules, which
# works on older versions but doesn’t work with non-standard linear
# modules.
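#
# On an older PyTorch version, the module-swapping fallback would be used in the
# same spot as the tensor-subclass API. This is only a sketch; it assumes the
# fallback is importable from ``torchao.quantization`` and also modifies the
# model in place:
#
# .. code-block:: python
#
#    from torchao.quantization import apply_dynamic_quant
#
#    model, image = get_sam_model(only_one_block, batchsize)
#    model = model.to(torch.bfloat16)
#    image = image.to(torch.bfloat16)
#    apply_dynamic_quant(model)  # swaps the linear modules for dynamically quantized ones
#    model_c = torch.compile(model, mode='max-autotune')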
#
# In this case, Segment Anything is compute-bound, so we’ll use dynamic quantization:
#

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']:0.2f}GB")
# bf16 compiled runtime of the quantized block is 19.04ms and peak memory 3.58GB


######################################################################
# With quantization, we have improved performance a bit more, but memory usage
# increased significantly.
#
# This is for two reasons:
#
# 1) Quantization adds overhead to the model
#    since we need to quantize and dequantize the input and output. For small
#    batch sizes this overhead can actually make the model go slower.
# 2) Even though we are doing a quantized matmul, such as ``int8 x int8``,
#    the result of the multiplication gets stored in an int32 tensor
#    which is twice the size of the result from the non-quantized model.
#    If we can avoid creating this int32 tensor, our memory usage will improve a lot.
#
# We can fix #2 by fusing the integer matmul with the subsequent rescale
# operation. Since the final output will be bf16, if we immediately convert
# the int32 tensor to bf16 and store that instead, we’ll get better
# performance in terms of both runtime and memory.
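#
# As a rough back-of-envelope illustration of the storage cost (the exact
# intermediate shapes inside the block differ, so treat the numbers as
# approximate):
#
# .. code-block:: python
#
#    elems = batchsize * 64 * 64 * 1280   # elements in one block-sized activation
#    print(elems * 4 / 1e9)               # int32: 4 bytes/element, ~0.34 GB
#    print(elems * 2 / 1e9)               # bf16:  2 bytes/element, ~0.17 GB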
#
# The way to do this is to enable the ``force_fuse_int_mm_with_mul`` option in
# the inductor config.
#

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the fused quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']:0.2f}GB")
# bf16 compiled runtime of the fused quantized block is 18.78ms and peak memory 2.37GB


######################################################################
# The fusion improves performance by another small bit (about 6% over the
# baseline in total) and removes almost all of the memory increase; the
# remaining difference (2.37GB quantized vs 2.24GB unquantized) is due to
# quantization overhead and cannot be avoided.
#
# We’re not done yet, though; we can apply a few general-purpose
# optimizations to get our final best-case performance.
#
# 1) We can sometimes improve performance by disabling epilogue fusion
#    since the autotuning process can be confused by fusions and choose
#    bad kernel parameters.
# 2) We can apply coordinate descent tuning in all directions to enlarge
#    the search area for kernel parameters.
#

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']:0.2f}GB")
# bf16 compiled runtime of the final quantized block is 18.16ms and peak memory 2.39GB


######################################################################
# As you can see, we’ve squeezed another small improvement from the model,
# taking our total improvement to over 10x compared to our original fp32
# baseline. To get a final estimate of the impact of quantization, let's do an
# apples-to-apples comparison on the full model, since the actual improvement
# will differ block by block depending on the shapes involved.
#

try:
    del model_c, model, image
    model, image = get_sam_model(False, batchsize)
    model = model.to(torch.bfloat16)
    image = image.to(torch.bfloat16)
    model_c = torch.compile(model, mode='max-autotune')
    quant_res = benchmark(model_c, image)
    print(f"bf16 compiled runtime of the compiled full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']:0.2f}GB")
    # bf16 compiled runtime of the compiled full model is 729.65ms and peak memory 23.96GB

    del model_c, model, image
    model, image = get_sam_model(False, batchsize)
    model = model.to(torch.bfloat16)
    image = image.to(torch.bfloat16)
    change_linear_weights_to_int8_dqtensors(model)
    model_c = torch.compile(model, mode='max-autotune')
    quant_res = benchmark(model_c, image)
    print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']:0.2f}GB")
    # bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB
except Exception as e:
    print("unable to run full model: ", e)


######################################################################
# Conclusion
# -----------------
# In this tutorial, we have learned how to apply quantization and optimization
# techniques to the Segment Anything Model.
#
# In the end, we achieved a full-model, apples-to-apples quantization speedup
# of about 7.7% on batch size 16 (729.65ms to 677.28ms). We can push this a
# bit further by increasing the batch size and optimizing other parts of
# the model. For example, this can be done with some form of flash attention.
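#
# As a pointer in that direction, one general way to use fused attention kernels in
# PyTorch is ``torch.nn.functional.scaled_dot_product_attention``. The sketch below
# (with arbitrary example shapes) is illustrative only and is not the exact change
# made in segment-anything-fast:
#
# .. code-block:: python
#
#    import torch
#    import torch.nn.functional as F
#
#    # (batch, num_heads, seq_len, head_dim) -- hypothetical shapes for illustration
#    q = torch.randn(16, 8, 4096, 64, device='cuda', dtype=torch.bfloat16)
#    k = torch.randn(16, 8, 4096, 64, device='cuda', dtype=torch.bfloat16)
#    v = torch.randn(16, 8, 4096, 64, device='cuda', dtype=torch.bfloat16)
#
#    # can dispatch to a fused flash-attention kernel on supported GPUs
#    out = F.scaled_dot_product_attention(q, k, v)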
#
# For more information, visit
# `torchao <https://github.com/pytorch-labs/ao>`_ and try it on your own
# models.
#