Update torchao to 0.5.0 and fix GPU quantization tutorial #3069

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits on Oct 1, 2024
2 changes: 1 addition & 1 deletion .ci/docker/requirements.txt
@@ -68,5 +68,5 @@ iopath
pygame==2.6.0
pycocotools
semilearn==0.3.2
- torchao==0.0.3
+ torchao==0.5.0
segment_anything==1.0
31 changes: 22 additions & 9 deletions prototype_source/gpu_quantization_torchao_tutorial.py
@@ -44,7 +44,8 @@
#

import torch
- from torchao.quantization import change_linear_weights_to_int8_dqtensors
+ from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
+ from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5
from segment_anything import sam_model_registry
from torch.utils.benchmark import Timer

@@ -156,9 +157,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
# in memory bound situations where the benefit comes from loading less
# weight data, rather than doing less computation. The torchao APIs:
#
- # ``change_linear_weights_to_int8_dqtensors``,
- # ``change_linear_weights_to_int8_woqtensors`` or
- # ``change_linear_weights_to_int4_woqtensors``
+ # ``int8_dynamic_activation_int8_weight()``,
+ # ``int8_weight_only()`` or
+ # ``int4_weight_only()``
#
# can be used to easily apply the desired quantization technique and then
# once the model is compiled with ``torch.compile`` with ``max-autotune``, quantization is
@@ -170,7 +171,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
# ``apply_weight_only_int8_quant`` instead as drop in replacement for the two
# above (no replacement for int4).
#
- # The difference between the two APIs is that ``change_linear_weights`` API
+ # The difference between the two APIs is that ``int8_dynamic_activation`` API
# alters the weight tensor of the linear module so instead of doing a
# normal linear, it does a quantized operation. This is helpful when you
# have non-standard linear ops that do more than one thing. The ``apply``
@@ -185,7 +186,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
- change_linear_weights_to_int8_dqtensors(model)
+ quantize_(model, int8_dynamic_activation_int8_weight())
+ if not TORCH_VERSION_AT_LEAST_2_5:
+     # needed for subclass + compile to work on older versions of pytorch
+     unwrap_tensor_subclass(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -220,7 +224,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.force_fuse_int_mm_with_mul = True
- change_linear_weights_to_int8_dqtensors(model)
+ quantize_(model, int8_dynamic_activation_int8_weight())
+ if not TORCH_VERSION_AT_LEAST_2_5:
+     # needed for subclass + compile to work on older versions of pytorch
+     unwrap_tensor_subclass(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the fused quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -251,7 +258,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
- change_linear_weights_to_int8_dqtensors(model)
+ quantize_(model, int8_dynamic_activation_int8_weight())
+ if not TORCH_VERSION_AT_LEAST_2_5:
+     # needed for subclass + compile to work on older versions of pytorch
+     unwrap_tensor_subclass(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -280,7 +290,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
model, image = get_sam_model(False, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
- change_linear_weights_to_int8_dqtensors(model)
+ quantize_(model, int8_dynamic_activation_int8_weight())
+ if not TORCH_VERSION_AT_LEAST_2_5:
+     # needed for subclass + compile to work on older versions of pytorch
+     unwrap_tensor_subclass(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")