@@ -225,6 +225,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
225
225
image = image.to(torch.bfloat16)
226
226
torch._inductor.config.force_fuse_int_mm_with_mul = True
227
227
quantize_(model, int8_dynamic_activation_int8_weight())
228
+ if not TORCH_VERSION_AT_LEAST_2_5:
229
+     # needed for subclass + compile to work on older versions of pytorch
230
+     unwrap_tensor_subclass(model)
228
231
model_c = torch.compile(model, mode='max-autotune')
229
232
quant_res = benchmark(model_c, image)
230
233
print(f"bf16 compiled runtime of the fused quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -256,7 +259,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
256
259
torch._inductor.config.coordinate_descent_check_all_directions = True
257
260
torch._inductor.config.force_fuse_int_mm_with_mul = True
258
261
quantize_(model, int8_dynamic_activation_int8_weight())
259
- model = unwrap_tensor_subclass(model)
262
+ if not TORCH_VERSION_AT_LEAST_2_5:
263
+     # needed for subclass + compile to work on older versions of pytorch
264
+     unwrap_tensor_subclass(model)
260
265
model_c = torch.compile(model, mode='max-autotune')
261
266
quant_res = benchmark(model_c, image)
262
267
print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -286,6 +291,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
286
291
model = model.to(torch.bfloat16)
287
292
image = image.to(torch.bfloat16)
288
293
quantize_(model, int8_dynamic_activation_int8_weight())
294
+ if not TORCH_VERSION_AT_LEAST_2_5:
295
+     # needed for subclass + compile to work on older versions of pytorch
296
+     unwrap_tensor_subclass(model)
289
297
model_c = torch.compile(model, mode='max-autotune')
290
298
quant_res = benchmark(model_c, image)
291
299
print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
0 commit comments