Commit 52657d6

add unwrap for old torch version
1 parent eded873 commit 52657d6

File tree

1 file changed: +5 -0 lines changed

prototype_source/gpu_quantization_torchao_tutorial.py

Lines changed: 5 additions & 0 deletions
@@ -45,6 +45,7 @@

 import torch
 from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
+from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5
 from segment_anything import sam_model_registry
 from torch.utils.benchmark import Timer

@@ -186,6 +187,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
 quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -252,6 +256,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 torch._inductor.config.coordinate_descent_check_all_directions = True
 torch._inductor.config.force_fuse_int_mm_with_mul = True
 quantize_(model, int8_dynamic_activation_int8_weight())
+model = unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
