 #

 import torch
-from torchao.quantization import change_linear_weights_to_int8_dqtensors
+from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
+from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5
 from segment_anything import sam_model_registry
 from torch.utils.benchmark import Timer

@@ -156,9 +157,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # in memory bound situations where the benefit comes from loading less
 # weight data, rather than doing less computation. The torchao APIs:
 #
-# ``change_linear_weights_to_int8_dqtensors``,
-# ``change_linear_weights_to_int8_woqtensors`` or
-# ``change_linear_weights_to_int4_woqtensors``
+# ``int8_dynamic_activation_int8_weight()``,
+# ``int8_weight_only()`` or
+# ``int4_weight_only()``
 #
 # can be used to easily apply the desired quantization technique and then
 # once the model is compiled with ``torch.compile`` with ``max-autotune``, quantization is
@@ -170,7 +171,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # ``apply_weight_only_int8_quant`` instead as drop in replacement for the two
 # above (no replacement for int4).
 #
-# The difference between the two APIs is that ``change_linear_weights`` API
+# The difference between the two APIs is that ``int8_dynamic_activation`` API
 # alters the weight tensor of the linear module so instead of doing a
 # normal linear, it does a quantized operation. This is helpful when you
 # have non-standard linear ops that do more than one thing. The ``apply``
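Note (outside the diff): a minimal sketch of how the ``quantize_`` API named above is applied, using a hypothetical toy model rather than the SAM block; only the imports and the call pattern come from this change, and it assumes a CUDA device and a torchao build that ships ``quantize_``:

    import torch
    from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
    from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5

    # Toy stand-in for the SAM block used in the tutorial (sizes are illustrative).
    toy = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024),
        torch.nn.ReLU(),
        torch.nn.Linear(1024, 1024),
    ).cuda().to(torch.bfloat16)

    # Swap the Linear weights for quantized tensor subclasses in place;
    # int8_weight_only() or int4_weight_only() (same module) select the weight-only variants.
    quantize_(toy, int8_dynamic_activation_int8_weight())
    if not TORCH_VERSION_AT_LEAST_2_5:
        unwrap_tensor_subclass(toy)  # needed for subclass + compile on older PyTorch
    toy_c = torch.compile(toy, mode="max-autotune")
    out = toy_c(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))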
@@ -185,7 +186,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model, image = get_sam_model(only_one_block, batchsize)
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -220,7 +224,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
 torch._inductor.config.force_fuse_int_mm_with_mul = True
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the fused quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -251,7 +258,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.coordinate_descent_check_all_directions = True
 torch._inductor.config.force_fuse_int_mm_with_mul = True
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -280,7 +290,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model, image = get_sam_model(False, batchsize)
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")