Commit 5bf70c1

fixing over padding and GPTQ padding bug
Summary: We don't always need to pad to 1024; the only requirement is that groupsize and inner_k_tiles*16 divide the inner dimension.
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

[ghstack-poisoned]
1 parent eb1789b commit 5bf70c1
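A quick illustration of the rule described in the summary (the layer size and the groupsize=32, inner_k_tiles=2 settings below are hypothetical, chosen only for the example): the padded inner dimension now only has to be a multiple of the least common multiple of groupsize and inner_k_tiles*16, which is at most 256 for the supported settings, rather than a multiple of 1024.

# mirrors the patched find_multiple in model.py; standalone so it can be run anywhere
from math import gcd
from functools import reduce

def find_multiple(n: int, *args: int) -> int:
    k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))  # lcm of all requested divisors
    return n if n % k == 0 else n + k - (n % k)

# hypothetical linear layer with in_features = 3000
print(find_multiple(3000, 1024))        # old rule: pad to 3072 (72 extra columns)
print(find_multiple(3000, 32, 2 * 16))  # new rule: pad to 3008 (8 extra columns)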

File tree: 2 files changed (+30, -35 lines)


model.py (5 additions, 2 deletions)
@@ -4,15 +4,18 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import functional as F
+from math import gcd
+from functools import reduce


-def find_multiple(n: int, k: int) -> int:
+def find_multiple(n: int, *args: Tuple[int]) -> int:
+    k = reduce(lambda x,y: x*y//gcd(x,y), args+(1,))
     if n % k == 0:
         return n
     return n + k - (n % k)

quantize.py (25 additions, 33 deletions)
@@ -17,7 +17,7 @@
 except:
     pass

-from model import Transformer
+from model import Transformer, find_multiple

 ##### Quantization Primitives ######

@@ -365,29 +365,27 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, grou
 def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

-def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
+def _calc_padded_size_linear_int4(k, groupsize = 1, inner_k_tiles = 1):
+    return find_multiple(k, groupsize, inner_k_tiles*16)
+
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
-                ))
-            elif padding:
-                setattr(module, name, WeightOnlyInt4Linear(
-                    child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles,
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed)


 class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
+        self.padding_allowed = padding_allowed
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]
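For reference, a small standalone sanity check (illustrative only, not part of the commit; it assumes it is run from a checkout containing this commit so that model.find_multiple is importable, and the two helpers below are copied from the definitions above): the padded size always satisfies the int4 kernel check for every supported groupsize and inner_k_tiles value.

from model import find_multiple  # variadic after this commit

def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
    return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

def _calc_padded_size_linear_int4(k, groupsize=1, inner_k_tiles=1):
    return find_multiple(k, groupsize, inner_k_tiles * 16)

for k in (3000, 4096, 11008):
    for groupsize in (32, 64, 128, 256):
        for inner_k_tiles in (2, 4, 8):
            padded = _calc_padded_size_linear_int4(k, groupsize, inner_k_tiles)
            assert padded >= k
            assert _check_linear_int4_k(padded, groupsize, inner_k_tiles)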

@@ -409,11 +407,9 @@ def create_quantized_state_dict(self, use_cuda = True):

                 weight = mod.weight.data
                 if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
-                    if self.padding:
-                        from model import find_multiple
-                        import torch.nn.functional as F
+                    if self.padding_allowed:
                         print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
-                        padded_in_features = find_multiple(in_features, 1024)
+                        padded_in_features = _calc_padded_size_linear_int4(in_features, 1024)
                         weight = F.pad(weight, pad=(0, padded_in_features - in_features))
                     else:
                         print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
@@ -428,31 +424,30 @@ def create_quantized_state_dict(self, use_cuda = True):
         return cur_state_dict

     def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed)
         return self.mod

 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
-        from model import find_multiple
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
+        self.padding_allowed = padding_allowed
         self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
         self.quantize_func = lambda w, qparams: \
             group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
         self.dequantize_func = lambda q, qparams: \
             group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float()
         self.combine_qparams_list_func = lambda qparams_list: \
             [torch.cat(x, dim=1) for x in zip(*qparams_list)]
-        # skip unless padding=True or its correctly sized
+        # skip unless padding_allowed=True or its correctly sized
         self.skip_layer_func = lambda linear_weight: not (
-            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
+            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding_allowed
         )
         # we need to do the padding here, both for q and the qparams if necessary
         def make_names_and_values_dict_func(q, qparams):
             k = q.shape[1]
-            new_k = find_multiple(k, 1024)
+            new_k = _calc_padded_size_linear_int4(k, groupsize, inner_k_tiles)
             # how much we need to pad the weight
             delta_k = new_k - q.shape[1]
             final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
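Illustrative numbers for this hunk (reusing the hypothetical 3000-wide layer from the summary example): after this commit WeightOnlyInt4Linear pads only to the smaller aligned size, so packing the GPTQ output to a multiple of 1024, as the old code did, could produce a weight whose inner dimension disagrees with what the runtime module expects.

from model import find_multiple  # assumes a checkout containing this commit

k, groupsize, inner_k_tiles = 3000, 32, 2                    # hypothetical unaligned inner dim
old_gptq_k = find_multiple(k, 1024)                          # 3072: what the GPTQ path used to pack
runtime_k = find_multiple(k, groupsize, inner_k_tiles * 16)  # 3008: what WeightOnlyInt4Linear now expects
assert old_gptq_k != runtime_k                               # the size mismatch the new code avoids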
@@ -466,7 +461,7 @@ def make_names_and_values_dict_func(q, qparams):


     def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed)
         return self.mod

 class WeightOnlyInt4Linear(torch.nn.Module):
@@ -477,17 +472,16 @@ class WeightOnlyInt4Linear(torch.nn.Module):

     def __init__(
         self, in_features: int, out_features: int,
-        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
+        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8,
     ) -> None:
         super().__init__()
-        self.padding = padding
-        if padding:
-            from model import find_multiple
-            self.origin_in_features = in_features
-            in_features = find_multiple(in_features, 1024)

+        # always pad if needed since it becomes a noop at runtime if not needed
+        self.origin_in_features = in_features
+        in_features = _calc_padded_size_linear_int4(in_features, groupsize, inner_k_tiles)
         self.in_features = in_features
         self.out_features = out_features
+
         assert not bias, "require bias=False"
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
@@ -505,9 +499,7 @@ def __init__(

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         input = input.to(torch.bfloat16)
-        if self.padding:
-            import torch.nn.functional as F
-            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+        input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         return linear_forward_int4(
             input,
             self.weight, self.scales_and_zeros, self.out_features, self.groupsize
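The "always pad if needed since it becomes a noop at runtime if not needed" comment added in __init__ above relies on F.pad being a functional no-op when the padded and original in_features already match; a minimal check with made-up shapes:

import torch
import torch.nn.functional as F

# already aligned: zero columns appended, values unchanged
x = torch.randn(2, 4096, dtype=torch.bfloat16)
y = F.pad(x, pad=(0, 4096 - 4096))
assert y.shape == x.shape and torch.equal(y, x)

# unaligned: eight zero columns appended so the activation matches the padded weight
x = torch.randn(2, 3000, dtype=torch.bfloat16)
y = F.pad(x, pad=(0, 3008 - 3000))
assert y.shape == (2, 3008)
assert torch.all(y[:, 3000:] == 0)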
