@@ -45,6 +45,7 @@
 
 import torch
 from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
+from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5
 from segment_anything import sam_model_registry
 from torch.utils.benchmark import Timer
 
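The `benchmark` helper called in the hunks below is defined elsewhere in the tutorial and not shown in this diff. As a rough sketch only, a Timer-based helper that matches how it is used here (returning runtime in milliseconds under 'time' and peak CUDA memory in gigabytes under 'memory'; the argument names are illustrative) could look like:

import torch
from torch.utils.benchmark import Timer

def benchmark(f, *args, **kwargs):
    f(*args, **kwargs)  # warmup call so compilation/autotuning is not measured
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    t = Timer(stmt="f(*args, **kwargs)",
              globals={"f": f, "args": args, "kwargs": kwargs})
    measurement = t.blocked_autorange()  # mean is in seconds per run
    return {"time": measurement.mean * 1e3,                     # milliseconds
            "memory": torch.cuda.max_memory_allocated() / 1e9}  # gigabytes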
@@ -186,6 +187,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
 quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
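For context on the lines added above: quantize_ swaps the linear weights for torchao tensor subclasses, and per the diff's own comment, torch.compile only handles those natively from PyTorch 2.5 on; on older versions unwrap_tensor_subclass converts them back into plain parameters first. A minimal standalone sketch of the same pattern, using a toy nn.Linear stand-in for the SAM block (the toy model and shapes are illustrative, not from the tutorial):

import torch
from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5

toy = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
quantize_(toy, int8_dynamic_activation_int8_weight())  # weights become int8 tensor subclasses
if not TORCH_VERSION_AT_LEAST_2_5:
    toy = unwrap_tensor_subclass(toy)                  # flatten subclasses so compile can trace
toy_c = torch.compile(toy, mode='max-autotune')
out = toy_c(torch.randn(8, 1024, device='cuda', dtype=torch.bfloat16))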
@@ -252,6 +256,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 torch._inductor.config.coordinate_descent_check_all_directions = True
 torch._inductor.config.force_fuse_int_mm_with_mul = True
 quantize_(model, int8_dynamic_activation_int8_weight())
+model = unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
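The two inductor flags in the context lines above are what make the int8 path fast under torch.compile. Roughly, based on the flag names and their intent (treat these comments as a best-effort gloss rather than authoritative documentation):

import torch

# during coordinate-descent autotuning of kernel configs, explore all
# directions instead of stopping at the first non-improving neighbor
torch._inductor.config.coordinate_descent_check_all_directions = True
# emit one fused kernel for the int8 matmul plus the following scale
# multiply, rather than a matmul kernel and a separate pointwise multiply
torch._inductor.config.force_fuse_int_mm_with_mul = True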