Skip to content

Commit 86aad1b

Browse files
committed
unwrap for all model runs
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 52657d6 commit 86aad1b

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

prototype_source/gpu_quantization_torchao_tutorial.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
225225
image = image.to(torch.bfloat16)
226226
torch._inductor.config.force_fuse_int_mm_with_mul = True
227227
quantize_(model, int8_dynamic_activation_int8_weight())
228+
if not TORCH_VERSION_AT_LEAST_2_5:
229+
# needed for subclass + compile to work on older versions of pytorch
230+
unwrap_tensor_subclass(model)
228231
model_c = torch.compile(model, mode='max-autotune')
229232
quant_res = benchmark(model_c, image)
230233
print(f"bf16 compiled runtime of the fused quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -256,7 +259,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
256259
torch._inductor.config.coordinate_descent_check_all_directions = True
257260
torch._inductor.config.force_fuse_int_mm_with_mul = True
258261
quantize_(model, int8_dynamic_activation_int8_weight())
259-
model = unwrap_tensor_subclass(model)
262+
if not TORCH_VERSION_AT_LEAST_2_5:
263+
# needed for subclass + compile to work on older versions of pytorch
264+
unwrap_tensor_subclass(model)
260265
model_c = torch.compile(model, mode='max-autotune')
261266
quant_res = benchmark(model_c, image)
262267
print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -286,6 +291,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
286291
model = model.to(torch.bfloat16)
287292
image = image.to(torch.bfloat16)
288293
quantize_(model, int8_dynamic_activation_int8_weight())
294+
if not TORCH_VERSION_AT_LEAST_2_5:
295+
# needed for subclass + compile to work on older versions of pytorch
296+
unwrap_tensor_subclass(model)
289297
model_c = torch.compile(model, mode='max-autotune')
290298
quant_res = benchmark(model_c, image)
291299
print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")

0 commit comments

Comments
 (0)