Commit 54b62c1

Update torchao to 0.5.0 and fix GPU quantization tutorial (#3069)

Authored by svekars and HDCharles
Co-authored-by: HDCharles <charlesdavidhernandez@gmail.com>
1 parent: 8d959ca

File tree

.ci/docker/requirements.txt
prototype_source/gpu_quantization_torchao_tutorial.py

2 files changed: +23 -10 lines

.ci/docker/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -68,5 +68,5 @@ iopath
 pygame==2.6.0
 pycocotools
 semilearn==0.3.2
-torchao==0.0.3
+torchao==0.5.0
 segment_anything==1.0

prototype_source/gpu_quantization_torchao_tutorial.py

Lines changed: 22 additions & 9 deletions
@@ -44,7 +44,8 @@
 #

 import torch
-from torchao.quantization import change_linear_weights_to_int8_dqtensors
+from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
+from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5
 from segment_anything import sam_model_registry
 from torch.utils.benchmark import Timer
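Both import paths referenced in this hunk can be bridged while migrating; a minimal compatibility sketch (assuming only that one of the two torchao APIs is importable):

try:
    # torchao >= 0.5.0 entry points used by the updated tutorial
    from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
    from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5
    HAS_NEW_API = True
except ImportError:
    # pre-0.5.0 helper removed by this commit
    from torchao.quantization import change_linear_weights_to_int8_dqtensors
    HAS_NEW_API = False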

@@ -156,9 +157,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # in memory bound situations where the benefit comes from loading less
 # weight data, rather than doing less computation. The torchao APIs:
 #
-# ``change_linear_weights_to_int8_dqtensors``,
-# ``change_linear_weights_to_int8_woqtensors`` or
-# ``change_linear_weights_to_int4_woqtensors``
+# ``int8_dynamic_activation_int8_weight()``,
+# ``int8_weight_only()`` or
+# ``int4_weight_only()``
 #
 # can be used to easily apply the desired quantization technique and then
 # once the model is compiled with ``torch.compile`` with ``max-autotune``, quantization is
@@ -170,7 +171,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # ``apply_weight_only_int8_quant`` instead as drop in replacement for the two
 # above (no replacement for int4).
 #
-# The difference between the two APIs is that ``change_linear_weights`` API
+# The difference between the two APIs is that ``int8_dynamic_activation`` API
 # alters the weight tensor of the linear module so instead of doing a
 # normal linear, it does a quantized operation. This is helpful when you
 # have non-standard linear ops that do more than one thing. The ``apply``
@@ -185,7 +186,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model, image = get_sam_model(only_one_block, batchsize)
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -220,7 +224,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
 torch._inductor.config.force_fuse_int_mm_with_mul = True
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the fused quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -251,7 +258,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.coordinate_descent_check_all_directions = True
 torch._inductor.config.force_fuse_int_mm_with_mul = True
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -280,7 +290,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model, image = get_sam_model(False, batchsize)
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
