17
17
except :
18
18
pass
19
19
20
- from model import Transformer
20
+ from model import Transformer , find_multiple
21
21
22
22
##### Quantization Primitives ######
23
23
@@ -365,29 +365,27 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, grou
365
365
def _check_linear_int4_k (k , groupsize = 1 , inner_k_tiles = 1 ):
366
366
return k % groupsize == 0 and k % (inner_k_tiles * 16 ) == 0
367
367
368
- def replace_linear_int4 (module , groupsize , inner_k_tiles , padding ):
368
+ def _calc_padded_size_linear_int4 (k , groupsize = 1 , inner_k_tiles = 1 ):
369
+ return find_multiple (k , groupsize , inner_k_tiles * 16 )
370
+
371
+ def replace_linear_int4 (module , groupsize , inner_k_tiles , padding_allowed ):
369
372
for name , child in module .named_children ():
370
373
if isinstance (child , nn .Linear ):
371
- if _check_linear_int4_k (child .in_features , groupsize , inner_k_tiles ):
374
+ if _check_linear_int4_k (child .in_features , groupsize , inner_k_tiles ) or padding_allowed :
372
375
setattr (module , name , WeightOnlyInt4Linear (
373
376
child .in_features , child .out_features , bias = False ,
374
- groupsize = groupsize , inner_k_tiles = inner_k_tiles , padding = False ,
375
- ))
376
- elif padding :
377
- setattr (module , name , WeightOnlyInt4Linear (
378
- child .in_features , child .out_features , bias = False ,
379
- groupsize = groupsize , inner_k_tiles = inner_k_tiles , padding = True ,
377
+ groupsize = groupsize , inner_k_tiles = inner_k_tiles ,
380
378
))
381
379
else :
382
- replace_linear_int4 (child , groupsize , inner_k_tiles , padding )
380
+ replace_linear_int4 (child , groupsize , inner_k_tiles , padding_allowed )
383
381
384
382
385
383
class WeightOnlyInt4QuantHandler :
386
- def __init__ (self , mod , groupsize = 128 , inner_k_tiles = 8 , padding = True ):
384
+ def __init__ (self , mod , groupsize = 128 , inner_k_tiles = 8 , padding_allowed = True ):
387
385
self .mod = mod
388
386
self .groupsize = groupsize
389
387
self .inner_k_tiles = inner_k_tiles
390
- self .padding = padding
388
+ self .padding_allowed = padding_allowed
391
389
assert groupsize in [32 , 64 , 128 , 256 ]
392
390
assert inner_k_tiles in [2 , 4 , 8 ]
393
391
@@ -409,11 +407,9 @@ def create_quantized_state_dict(self, use_cuda = True):
409
407
410
408
weight = mod .weight .data
411
409
if not _check_linear_int4_k (in_features , self .groupsize , self .inner_k_tiles ):
412
- if self .padding :
413
- from model import find_multiple
414
- import torch .nn .functional as F
410
+ if self .padding_allowed :
415
411
print (f"warning: { fqn } is padded to satisfy in_features % 1024 == 0" )
416
- padded_in_features = find_multiple (in_features , 1024 )
412
+ padded_in_features = _calc_padded_size_linear_int4 (in_features , 1024 )
417
413
weight = F .pad (weight , pad = (0 , padded_in_features - in_features ))
418
414
else :
419
415
print (f"warning: { fqn } is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
@@ -428,31 +424,30 @@ def create_quantized_state_dict(self, use_cuda = True):
428
424
return cur_state_dict
429
425
430
426
def convert_for_runtime (self ):
431
- replace_linear_int4 (self .mod , self .groupsize , self .inner_k_tiles , self .padding )
427
+ replace_linear_int4 (self .mod , self .groupsize , self .inner_k_tiles , self .padding_allowed )
432
428
return self .mod
433
429
434
430
class WeightOnlyInt4GPTQQuantHandler (GPTQQuantHandler ):
435
- def __init__ (self , mod , groupsize = 128 , inner_k_tiles = 8 , padding = True ):
436
- from model import find_multiple
431
+ def __init__ (self , mod , groupsize = 128 , inner_k_tiles = 8 , padding_allowed = True ):
437
432
self .mod = mod
438
433
self .groupsize = groupsize
439
434
self .inner_k_tiles = inner_k_tiles
440
- self .padding = padding
435
+ self .padding_allowed = padding_allowed
441
436
self .get_qparams_func = lambda w : get_group_qparams (w , 4 , groupsize )
442
437
self .quantize_func = lambda w , qparams : \
443
438
group_quantize_tensor_from_qparams (w , qparams [0 ], qparams [1 ], 4 , groupsize )
444
439
self .dequantize_func = lambda q , qparams : \
445
440
group_dequantize_tensor_from_qparams (q , qparams [0 ], qparams [1 ], 4 , groupsize ).float ()
446
441
self .combine_qparams_list_func = lambda qparams_list : \
447
442
[torch .cat (x , dim = 1 ) for x in zip (* qparams_list )]
448
- # skip unless padding =True or its correctly sized
443
+ # skip unless padding_allowed =True or its correctly sized
449
444
self .skip_layer_func = lambda linear_weight : not (
450
- _check_linear_int4_k (linear_weight .shape [- 1 ], groupsize , inner_k_tiles ) or padding
445
+ _check_linear_int4_k (linear_weight .shape [- 1 ], groupsize , inner_k_tiles ) or padding_allowed
451
446
)
452
447
# we need to do the padding here, both for q and the qparams if necessary
453
448
def make_names_and_values_dict_func (q , qparams ):
454
449
k = q .shape [1 ]
455
- new_k = find_multiple (k , 1024 )
450
+ new_k = _calc_padded_size_linear_int4 (k , groupsize , inner_k_tiles )
456
451
# how much we need to pad the weight
457
452
delta_k = new_k - q .shape [1 ]
458
453
final_q = torch .ops .aten ._convert_weight_to_int4pack (F .pad (q , pad = (0 , delta_k )), inner_k_tiles )
@@ -466,7 +461,7 @@ def make_names_and_values_dict_func(q, qparams):
466
461
467
462
468
463
def convert_for_runtime (self ):
469
- replace_linear_int4 (self .mod , self .groupsize , self .inner_k_tiles , self .padding )
464
+ replace_linear_int4 (self .mod , self .groupsize , self .inner_k_tiles , self .padding_allowed )
470
465
return self .mod
471
466
472
467
class WeightOnlyInt4Linear (torch .nn .Module ):
@@ -477,17 +472,16 @@ class WeightOnlyInt4Linear(torch.nn.Module):
477
472
478
473
def __init__ (
479
474
self , in_features : int , out_features : int ,
480
- bias = True , device = None , dtype = None , groupsize : int = 128 , inner_k_tiles : int = 8 , padding : bool = True ,
475
+ bias = True , device = None , dtype = None , groupsize : int = 128 , inner_k_tiles : int = 8 ,
481
476
) -> None :
482
477
super ().__init__ ()
483
- self .padding = padding
484
- if padding :
485
- from model import find_multiple
486
- self .origin_in_features = in_features
487
- in_features = find_multiple (in_features , 1024 )
488
478
479
+ # always pad if needed since it becomes a noop at runtime if not needed
480
+ self .origin_in_features = in_features
481
+ in_features = _calc_padded_size_linear_int4 (in_features , groupsize , inner_k_tiles )
489
482
self .in_features = in_features
490
483
self .out_features = out_features
484
+
491
485
assert not bias , "require bias=False"
492
486
self .groupsize = groupsize
493
487
self .inner_k_tiles = inner_k_tiles
@@ -505,9 +499,7 @@ def __init__(
505
499
506
500
def forward (self , input : torch .Tensor ) -> torch .Tensor :
507
501
input = input .to (torch .bfloat16 )
508
- if self .padding :
509
- import torch .nn .functional as F
510
- input = F .pad (input , pad = (0 , self .in_features - self .origin_in_features ))
502
+ input = F .pad (input , pad = (0 , self .in_features - self .origin_in_features ))
511
503
return linear_forward_int4 (
512
504
input ,
513
505
self .weight , self .scales_and_zeros , self .out_features , self .groupsize
0 commit comments