From 0899f3494bb751a15d99522a846e87fe4c032433 Mon Sep 17 00:00:00 2001
From: Svetlana Karslioglu
Date: Mon, 30 Sep 2024 11:25:07 -0700
Subject: [PATCH 1/6] Update torchao to 0.4.0 and fix GPU quantization tutorial

---
 .ci/docker/requirements.txt                           |  2 +-
 prototype_source/gpu_quantization_torchao_tutorial.py | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt
index 2384fb1b00e..14104155b75 100644
--- a/.ci/docker/requirements.txt
+++ b/.ci/docker/requirements.txt
@@ -68,5 +68,5 @@ iopath
 pygame==2.6.0
 pycocotools
 semilearn==0.3.2
-torchao==0.0.3
+torchao==0.4.0
 segment_anything==1.0

diff --git a/prototype_source/gpu_quantization_torchao_tutorial.py b/prototype_source/gpu_quantization_torchao_tutorial.py
index 4050a88e56e..8767f4aca6a 100644
--- a/prototype_source/gpu_quantization_torchao_tutorial.py
+++ b/prototype_source/gpu_quantization_torchao_tutorial.py
@@ -44,7 +44,7 @@
 #
 
 import torch
-from torchao.quantization import change_linear_weights_to_int8_dqtensors
+from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
 from segment_anything import sam_model_registry
 from torch.utils.benchmark import Timer
 
@@ -156,9 +156,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # in memory bound situations where the benefit comes from loading less
 # weight data, rather than doing less computation. The torchao APIs:
 #
-# ``change_linear_weights_to_int8_dqtensors``,
-# ``change_linear_weights_to_int8_woqtensors`` or
-# ``change_linear_weights_to_int4_woqtensors``
+# ``int8_dynamic_activation_int8_weight()``,
+# ``int8_dynamic_activation_int8_semi_sparse_weight`` or
+# ``int8_dynamic_activation_int4_weight``
 #
 # can be used to easily apply the desired quantization technique and then
 # once the model is compiled with ``torch.compile`` with ``max-autotune``, quantization is
@@ -185,7 +185,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model, image = get_sam_model(only_one_block, batchsize)
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")

From 913d43e8817c51791fb17b4fd34542b73629ce2f Mon Sep 17 00:00:00 2001
From: Svetlana Karslioglu
Date: Mon, 30 Sep 2024 11:34:29 -0700
Subject: [PATCH 2/6] Update

---
 prototype_source/gpu_quantization_torchao_tutorial.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/prototype_source/gpu_quantization_torchao_tutorial.py b/prototype_source/gpu_quantization_torchao_tutorial.py
index 8767f4aca6a..23d1fab0077 100644
--- a/prototype_source/gpu_quantization_torchao_tutorial.py
+++ b/prototype_source/gpu_quantization_torchao_tutorial.py
@@ -170,7 +170,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # ``apply_weight_only_int8_quant`` instead as drop in replacement for the two
 # above (no replacement for int4).
 #
-# The difference between the two APIs is that ``change_linear_weights`` API
+# The difference between the two APIs is that ``int8_dynamic_activation`` API
 # alters the weight tensor of the linear module so instead of doing a
 # normal linear, it does a quantized operation. This is helpful when you
 # have non-standard linear ops that do more than one thing. The ``apply``
@@ -220,7 +220,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
 torch._inductor.config.force_fuse_int_mm_with_mul = True
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the fused quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -251,7 +251,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.coordinate_descent_check_all_directions = True
 torch._inductor.config.force_fuse_int_mm_with_mul = True
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -280,7 +280,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
     model, image = get_sam_model(False, batchsize)
     model = model.to(torch.bfloat16)
     image = image.to(torch.bfloat16)
-    change_linear_weights_to_int8_dqtensors(model)
+    quantize_(model, int8_dynamic_activation_int8_weight())
     model_c = torch.compile(model, mode='max-autotune')
     quant_res = benchmark(model_c, image)
     print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")

From f88728336ccc4ac6b9cd970f74e4862896d435fa Mon Sep 17 00:00:00 2001
From: Svetlana Karslioglu
Date: Mon, 30 Sep 2024 13:21:32 -0700
Subject: [PATCH 3/6] Try 0.5.0

---
 .ci/docker/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt
index 14104155b75..afa55889192 100644
--- a/.ci/docker/requirements.txt
+++ b/.ci/docker/requirements.txt
@@ -68,5 +68,5 @@ iopath
 pygame==2.6.0
 pycocotools
 semilearn==0.3.2
-torchao==0.4.0
+torchao==0.5.0
 segment_anything==1.0

From eded873a8f3be913c0c942e6dde217edf1b6b706 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Mon, 30 Sep 2024 21:26:17 -0700
Subject: [PATCH 4/6] some small fixes

Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
---
 prototype_source/gpu_quantization_torchao_tutorial.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/prototype_source/gpu_quantization_torchao_tutorial.py b/prototype_source/gpu_quantization_torchao_tutorial.py
index 23d1fab0077..8be185d3eb6 100644
--- a/prototype_source/gpu_quantization_torchao_tutorial.py
+++ b/prototype_source/gpu_quantization_torchao_tutorial.py
@@ -157,8 +157,8 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # weight data, rather than doing less computation. The torchao APIs:
 #
 # ``int8_dynamic_activation_int8_weight()``,
-# ``int8_dynamic_activation_int8_semi_sparse_weight`` or
-# ``int8_dynamic_activation_int4_weight``
+# ``int8_weight_only()`` or
+# ``int4_weight_only()``
 #
 # can be used to easily apply the desired quantization technique and then
 # once the model is compiled with ``torch.compile`` with ``max-autotune``, quantization is

From 52657d6785d9ccb47d4b56614ff5c5804ecd3fad Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Mon, 30 Sep 2024 21:31:10 -0700
Subject: [PATCH 5/6] add unwrap for old torch version

Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
---
 prototype_source/gpu_quantization_torchao_tutorial.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/prototype_source/gpu_quantization_torchao_tutorial.py b/prototype_source/gpu_quantization_torchao_tutorial.py
index 8be185d3eb6..19e5f73a42c 100644
--- a/prototype_source/gpu_quantization_torchao_tutorial.py
+++ b/prototype_source/gpu_quantization_torchao_tutorial.py
@@ -45,6 +45,7 @@
 
 import torch
 from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
+from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5
 from segment_anything import sam_model_registry
 from torch.utils.benchmark import Timer
 
@@ -186,6 +187,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
 quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -252,6 +256,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 torch._inductor.config.coordinate_descent_check_all_directions = True
 torch._inductor.config.force_fuse_int_mm_with_mul = True
 quantize_(model, int8_dynamic_activation_int8_weight())
+model =
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")

From 86aad1b1bb8bf073de15c305c5cbc11b8f917c94 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Mon, 30 Sep 2024 21:35:21 -0700
Subject: [PATCH 6/6] unwrap for all model runs

Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
---
 prototype_source/gpu_quantization_torchao_tutorial.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/prototype_source/gpu_quantization_torchao_tutorial.py b/prototype_source/gpu_quantization_torchao_tutorial.py
index 19e5f73a42c..f901f8abd31 100644
--- a/prototype_source/gpu_quantization_torchao_tutorial.py
+++ b/prototype_source/gpu_quantization_torchao_tutorial.py
@@ -225,6 +225,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
 image = image.to(torch.bfloat16)
 torch._inductor.config.force_fuse_int_mm_with_mul = True
 quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the fused quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
0.2f}GB") @@ -256,7 +259,9 @@ def get_sam_model(only_one_block=False, batchsize=1): torch._inductor.config.coordinate_descent_check_all_directions = True torch._inductor.config.force_fuse_int_mm_with_mul = True quantize_(model, int8_dynamic_activation_int8_weight()) -model = +if not TORCH_VERSION_AT_LEAST_2_5: + # needed for subclass + compile to work on older versions of pytorch + unwrap_tensor_subclass(model) model_c = torch.compile(model, mode='max-autotune') quant_res = benchmark(model_c, image) print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB") @@ -286,6 +291,9 @@ def get_sam_model(only_one_block=False, batchsize=1): model = model.to(torch.bfloat16) image = image.to(torch.bfloat16) quantize_(model, int8_dynamic_activation_int8_weight()) + if not TORCH_VERSION_AT_LEAST_2_5: + # needed for subclass + compile to work on older versions of pytorch + unwrap_tensor_subclass(model) model_c = torch.compile(model, mode='max-autotune') quant_res = benchmark(model_c, image) print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")