diff --git a/mixtral-moe/generate.py b/mixtral-moe/generate.py
index 2f82d1ec..ffe71131 100644
--- a/mixtral-moe/generate.py
+++ b/mixtral-moe/generate.py
@@ -86,56 +86,6 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
 def model_forward(model, x, input_pos):
     return model(x, input_pos)
 
-def speculative_decode(
-    model: Transformer,
-    draft_model: Transformer,
-    cur_token: torch.Tensor,
-    input_pos: int,
-    speculate_k: int,
-    **sampling_kwargs
-) -> torch.Tensor:
-    # draft model inference sequentially
-    device = cur_token.device
-    orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
-    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)
-
-    draft_tokens = torch.cat(draft_tokens)
-    # parallel inference on target model using draft tokens
-    target_logits = model_forward(
-        model,
-        torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
-        torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device)
-    )
-    target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
-    draft_probs = torch.stack(draft_probs)
-    # q: target prob, p: draft prob
-    # q >= p: always accept draft token
-    # q < p: q/p prob to accept draft token
-    p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
-    q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
-    accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p)
-    rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()
-
-    if rejected_locations.shape[0] == 0: # All draft tokens have been accepted
-        accept_length = speculate_k + 1
-        last_token = multinomial_sample_one_no_sync(target_probs[-1])
-        # fill last token into draft model
-        model_forward(
-            draft_model,
-            draft_tokens[-1].view(1, -1),
-            orig_input_pos + speculate_k,
-        )
-        return torch.cat([draft_tokens, last_token])
-    else:
-        accept_length = rejected_locations[0].item()
-        p = draft_probs[accept_length]
-        q = target_probs[accept_length]
-        new = q - p
-        new = torch.where(new > 0, new, 0.0)
-        new = new / new.sum()
-        next_token = multinomial_sample_one_no_sync(new)
-        return torch.cat([draft_tokens[:accept_length], next_token])
-
 @torch.no_grad()
 def generate(
     model: Transformer,
@@ -143,8 +93,6 @@ def generate(
     max_new_tokens: int,
     *,
     interactive: bool,
-    draft_model: Transformer,
-    speculate_k: Optional[int] = 8,
     callback = lambda x: x,
     **sampling_kwargs
 ) -> torch.Tensor:
@@ -152,7 +100,6 @@ def generate(
     """
     Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
""" - is_speculative = draft_model is not None # create an empty tensor of the expected final shape and fill in the current tokens T = prompt.size(0) T_new = T + max_new_tokens @@ -162,11 +109,8 @@ def generate( max_seq_length = min(T_new, model.config.block_size) device, dtype = prompt.device, prompt.dtype - max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length with torch.device(device): model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - if is_speculative and draft_model is not model: - draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) # create an empty tensor of the expected final shape and fill in the current tokens empty = torch.empty(T_new, dtype=dtype, device=device) @@ -175,37 +119,14 @@ def generate( input_pos = torch.arange(0, T, device=device) next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs) - if is_speculative: - prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs) seq[T] = next_token input_pos = torch.tensor([T], device=device, dtype=torch.int) - accept_counts = [0] * (speculate_k + 1) - - if is_speculative: - input_pos = input_pos.item() # for speculative decoding easier to keep on host - while input_pos < T_new - 1: - cur_token = next_token.view(()) - - next_tokens = speculative_decode( - model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs - ) - accept_counts[len(next_tokens) - 1] += 1 - num_added = min(T_new - input_pos - 1, len(next_tokens)) - seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added] - for i in next_tokens[: num_added,]: - callback(i) - input_pos = input_pos + num_added - next_token = next_tokens[-1] - else: - generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) - seq[T + 1:] = torch.cat(generated_tokens) + generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) + seq[T + 1:] = torch.cat(generated_tokens) - generate_stats = { - 'accept_counts': accept_counts - } - return seq, generate_stats + return seq def encode_tokens(tokenizer, string, bos=True, device='cuda'): tokens = tokenizer.encode(string) @@ -223,15 +144,6 @@ def _load_model(checkpoint_path, device, precision, use_tp): simple_quantizer = WeightOnlyBit8QuantHandler(model, torch.int8) model = simple_quantizer.convert_for_runtime() - if "int4" in str(checkpoint_path): - print("Using int4 quantization!") - path_comps = checkpoint_path.name.split(".") - assert path_comps[-2].startswith("g") - groupsize = int(path_comps[-2][1:]) - from quantize import WeightOnlyInt4QuantHandler - simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) - model = simple_quantizer.convert_for_runtime() - checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) model.load_state_dict(checkpoint, assign=True) @@ -252,12 +164,10 @@ def main( max_new_tokens: int = 100, top_k: int = 200, temperature: float = 0.8, - checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), + checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"), compile: bool = True, compile_prefill: bool = False, profile: Optional[Path] = None, - draft_checkpoint_path: Optional[Path] = None, - speculate_k: int = 5, device='cuda', ) -> None: """Generates text samples based on a pre-trained Transformer model and tokenizer. 
@@ -277,18 +187,12 @@ def main(
 
     print(f"Using device={device}")
     precision = torch.bfloat16
-    is_speculative = draft_checkpoint_path is not None
     is_chat = "chat" in str(checkpoint_path)
 
     print("Loading model ...")
     t0 = time.time()
     model = _load_model(checkpoint_path, device, precision, use_tp)
 
-    if is_speculative:
-        draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
-    else:
-        draft_model = None
-
     device_sync(device=device) # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
 
@@ -299,14 +203,7 @@ def main(
     torch.manual_seed(1234)
     model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
     if compile:
-        if is_speculative and use_tp: # and ("cuda" in device):
-            torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
-        if model.config.moe:
-            torch._inductor.config.assert_indirect_indexing = False
-
-        if is_speculative:
-            global model_forward, logits_to_prob
-            model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
+        torch._inductor.config.assert_indirect_indexing = False
 
         global decode_one_token, prefill
         decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
@@ -318,7 +215,6 @@ def main(
 
     aggregate_metrics = {
         'tokens_per_sec': [],
-        'accept_counts': [],
     }
     start = -1 if compile else 0
 
@@ -355,18 +251,15 @@ def callback(x):
             torch.profiler._utils._init_for_cuda_graphs()
             prof = torch.profiler.profile()
         with prof:
-            y, metrics = generate(
+            y = generate(
                 model,
                 encoded,
                 max_new_tokens,
-                draft_model=draft_model,
-                speculate_k=speculate_k,
                 interactive=interactive,
                 callback=callback,
                 temperature=temperature,
                 top_k=top_k,
             )
-            aggregate_metrics['accept_counts'].append(metrics['accept_counts'])
         if i == -1:
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
             continue
@@ -387,12 +280,6 @@ def callback(x):
         aggregate_metrics['tokens_per_sec'].append(tokens_sec)
         print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
         print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
-    print("==========")
-    if is_speculative:
-        counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
-        acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated]
-        print(f"Acceptance probs: {acceptance_probs}")
-        print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}")
 
     print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
     print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
@@ -412,13 +299,10 @@ def callback(x):
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
-    parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
-    parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
     parser.add_argument('--device', type=str, default="cuda", help='device to use')
 
     args = parser.parse_args()
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
-        args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
-        args.speculate_k, args.device
+        args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device
     )
diff --git a/mixtral-moe/model.py b/mixtral-moe/model.py
index 1aae4521..85c740f1 100644
--- a/mixtral-moe/model.py
+++ b/mixtral-moe/model.py
@@ -29,7 +29,6 @@ class ModelArgs:
     head_dim: int = 64
     rope_base: float = 10000
     norm_eps: float = 1e-5
-    moe: bool = False
     num_experts: int = 8
     num_activated_experts: int = 2
 
@@ -53,13 +52,7 @@ def from_name(cls, name: str):
 
 
 transformer_configs = {
-    "Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2, moe=True),
-    "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000),
-    "7B": dict(n_layer=32, n_head=32, dim=4096),
-    "13B": dict(n_layer=40, n_head=40, dim=5120),
-    "30B": dict(n_layer=60, n_head=52, dim=6656),
-    "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000), # CodeLlama-34B-Python-hf
-    "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672),
+    "Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2),
 }
 
 class KVCache(nn.Module):
@@ -129,19 +122,13 @@ class TransformerBlock(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.attention = Attention(config)
-        if config.moe:
-            self.block_sparse_moe = MOEFeedForward(config)
-        else:
-            self.feed_forward = FeedForward(config)
+        self.block_sparse_moe = MOEFeedForward(config)
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
 
     def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
-        if hasattr(self, "block_sparse_moe"):
-            out = h + self.block_sparse_moe(self.ffn_norm(h))
-        else:
-            out = h + self.feed_forward(self.ffn_norm(h))
+        out = h + self.block_sparse_moe(self.ffn_norm(h))
         return out
 
 
@@ -197,17 +184,6 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
         return y
 
 
-class FeedForward(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
-        super().__init__()
-        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
-        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
-        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
-
-
 class ConditionalFeedForward(nn.Module):
     def __init__(self, config):
         super().__init__()
diff --git a/mixtral-moe/quantize.py b/mixtral-moe/quantize.py
index 17b1ae5a..f78acc14 100644
--- a/mixtral-moe/quantize.py
+++ b/mixtral-moe/quantize.py
@@ -9,13 +9,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from sentencepiece import SentencePieceProcessor
-
-try:
-    from GPTQ import GenericGPTQRunner, InputRecorder
-    from eval import get_task_dict, evaluate
-except:
-    pass
 
 from model import Transformer, ConditionalFeedForward
 
@@ -55,246 +48,6 @@ def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
 
     return quant, scales, zero_points
 
-def get_group_qparams(w, n_bit=4, groupsize=128):
-    # needed for GPTQ with padding
-    if groupsize > w.shape[-1]:
-        groupsize = w.shape[-1]
-    assert groupsize > 1
-    assert w.shape[-1] % groupsize == 0
-    assert w.dim() == 2
-
-    to_quant = w.reshape(-1, groupsize)
-    assert torch.isnan(to_quant).sum() == 0
-
-    max_val = to_quant.amax(dim=1, keepdim=True)
-    min_val = to_quant.amin(dim=1, keepdim=True)
-    max_int = 2**n_bit - 1
-    scales = (max_val - min_val).clamp(min=1e-6) / max_int
-    zeros = min_val + scales * (2 ** (n_bit - 1))
-    return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
-        torch.bfloat16
-    ).reshape(w.shape[0], -1)
-
-
-def pack_scales_and_zeros(scales, zeros):
-    assert scales.shape == zeros.shape
-    assert scales.dtype == torch.bfloat16
-    assert zeros.dtype == torch.bfloat16
-    return (
-        torch.cat(
-            [
-                scales.reshape(scales.size(0), scales.size(1), 1),
-                zeros.reshape(zeros.size(0), zeros.size(1), 1),
-            ],
-            2,
-        )
-        .transpose(0, 1)
-        .contiguous()
-    )
-
-
-def unpack_scales_and_zeros(scales_and_zeros):
-    assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
-    assert scales_and_zeros.dtype == torch.float
-    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
-
-
-def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
-    assert groupsize > 1
-    # needed for GPTQ single column quantize
-    if groupsize > w.shape[-1] and scales.shape[-1] == 1:
-        groupsize = w.shape[-1]
-
-    assert w.shape[-1] % groupsize == 0
-    assert w.dim() == 2
-
-    to_quant = w.reshape(-1, groupsize)
-    assert torch.isnan(to_quant).sum() == 0
-
-    scales = scales.reshape(-1, 1)
-    zeros = zeros.reshape(-1, 1)
-    min_val = zeros - scales * (2 ** (n_bit - 1))
-    max_int = 2**n_bit - 1
-    min_int = 0
-    w_int32 = (
-        to_quant.sub(min_val)
-        .div(scales)
-        .round()
-        .clamp_(min_int, max_int)
-        .to(torch.int32)
-        .reshape_as(w)
-    )
-
-    return w_int32
-
-
-def group_quantize_tensor(w, n_bit=4, groupsize=128):
-    scales, zeros = get_group_qparams(w, n_bit, groupsize)
-    w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
-    scales_and_zeros = pack_scales_and_zeros(scales, zeros)
-    return w_int32, scales_and_zeros
-
-
-def group_dequantize_tensor_from_qparams(
-    w_int32, scales, zeros, n_bit=4, groupsize=128
-):
-    assert groupsize > 1
-    # needed for GPTQ single column dequantize
-    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
-        groupsize = w_int32.shape[-1]
-    assert w_int32.shape[-1] % groupsize == 0
-    assert w_int32.dim() == 2
-
-    w_int32_grouped = w_int32.reshape(-1, groupsize)
-    scales = scales.reshape(-1, 1)
-    zeros = zeros.reshape(-1, 1)
-
-    w_dq = (
-        w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
-    )
-    return w_dq
-
-
-def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
-    scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
-    return group_dequantize_tensor_from_qparams(
-        w_int32, scales, zeros, n_bit, groupsize
-    )
-
-class QuantHandler:
-    def __init__(self, mod):
-        self.mod = mod
-
-    def create_quantized_state_dict(self) -> "StateDict":
-        pass
-
-    def convert_for_runtime(self) -> "nn.Module":
-        pass
-
-class GPTQQuantHandler(QuantHandler):
-    """
-    This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
-    Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement
-    __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime.
-
-    The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and
-    create_quantized_state_dict. Here is a description of each function.
-
-    get_qparams_func:
-        A function that calculates the quantization qparams for an input tensor.
-        Args:
-            weight: A 2d weight tensor with non-integer dtype.
-        Returns:
-            qparams: it can have any format but will need to be handled by the other defined functions below.
-
-    quantize_func:
-        A function that applies quantization to an input tensor. It should be noted
-        that this function needs to be able to handle quantizing the entire weight tensor, a single group,
-        or a single column.
-        Args:
-            weight: A 2d weight tensor with non-integer dtype.
-            qparams: the output from get_qparams_func
-        Returns:
-            quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
-
-
-    dequantize_func:
-        A function that dequantizes an input quantized weight tensor. It should be noted
-        that this function needs to be able to handle dequantizing the entire weight tensor, a single group,
-        or a single column.
-        Args:
-            quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
-            qparams: the output from get_qparams_func
-        Returns:
-            weight: A 2d weight tensor with non-integer dtype.
-
-    combine_qparams_list_func:
-        A function that combines several qparams into one qparam.
-        Args:
-            qparams_list: a list of qparams objects, each obtained by calling get_qparams_func
-            on a single group from a weight tensor
-        Returns:
-            qparams: an object of the same format as the qparams above.
-
-    skip_layer_func:
-        A function that determines which linear layers should be skipped during GPTQ
-        Args:
-            weight: A 2d weight tensor with non-integer dtype.
-        Returns:
-            skip: boolean indicating whether layer should be skipped
-
-    make_names_and_values_dict_func:
-        A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they
-        should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here.
-        Args:
-            quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
-            qparams: the output from get_qparams_func
-        Returns:
-            names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the
-            corresponding quantized weights and qparams.
- """ - def __init__(self): - assert self.mod is not None - assert self.get_qparams_func is not None - assert self.quantize_func is not None - assert self.dequantize_func is not None - assert self.combine_qparams_list_func is not None - assert self.make_names_and_values_dict_func is not None - - @staticmethod - def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput": - input_recorder = InputRecorder( - model, - tokenizer, - calibration_seq_length, - pad_calibration_inputs, - ) - task_dict = get_task_dict(calibration_tasks) - print("Obtaining GPTQ calibration inputs on: ", calibration_tasks) - evaluate( - input_recorder, - task_dict, - limit=calibration_limit, - ) - inputs = input_recorder.get_recorded_inputs() - print(f"Obtained {len(inputs[0].values)} calibration samples") - return inputs - - @torch.no_grad() - def create_quantized_state_dict( - self, - tokenizer, - blocksize, - percdamp, - groupsize, - calibration_tasks, - calibration_limit, - calibration_seq_length, - pad_calibration_inputs, - ) -> "StateDict": - inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) - print("Tracing model for GPTQ") - GPTQ_runner = GenericGPTQRunner( - self.mod, - inputs, - blocksize, - percdamp, - groupsize, - ).configure_quantization_mode( - self.get_qparams_func, - self.quantize_func, - self.dequantize_func, - self.combine_qparams_list_func, - self.make_names_and_values_dict_func, - self.skip_layer_func - ) - - print("Applying GPTQ to weights") - GPTQ_runner.run() - return GPTQ_runner.get_quantized_state_dict() - - def convert_for_runtime(self) -> "nn.Module": - pass ##### Weight-only int8 per-channel quantized code ###### @@ -388,189 +141,10 @@ def forward(self, x, expert_indices): expert_outs = torch.einsum('tao, taoi -> tai', (x1 * x3), w2_weights) return expert_outs -##### weight only int4 per channel groupwise quantized code ###### - -def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): - weight_int32, scales_and_zeros = group_quantize_tensor( - weight_bf16, n_bit=4, groupsize=groupsize - ) - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles) - return weight_int4pack, scales_and_zeros - - -def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): - origin_x_size = x.size() - x = x.reshape(-1, origin_x_size[-1]) - c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros) - new_shape = origin_x_size[:-1] + (out_features,) - c = c.reshape(new_shape) - return c - - -def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1): - return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 - -def replace_linear_int4(module, groupsize, inner_k_tiles, padding): - for name, child in module.named_children(): - if isinstance(child, nn.Linear): - if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles): - setattr(module, name, WeightOnlyInt4Linear( - child.in_features, child.out_features, bias=False, - groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, - )) - elif padding: - setattr(module, name, WeightOnlyInt4Linear( - child.in_features, child.out_features, bias=False, - groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, - )) - else: - replace_linear_int4(child, groupsize, inner_k_tiles, padding) - - -class WeightOnlyInt4QuantHandler: - def __init__(self, mod, 
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
-        self.mod = mod
-        self.groupsize = groupsize
-        self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
-        assert groupsize in [32, 64, 128, 256]
-        assert inner_k_tiles in [2, 4, 8]
-
-    @torch.no_grad()
-    def create_quantized_state_dict(self, use_cuda = True):
-        if use_cuda:
-            device="cuda"
-        else:
-            device="cpu"
-
-        cur_state_dict = self.mod.state_dict()
-        for fqn, mod in self.mod.named_modules():
-            if isinstance(mod, torch.nn.Linear):
-                assert not mod.bias
-                out_features = mod.out_features
-                in_features = mod.in_features
-                assert out_features % 8 == 0, "require out_features % 8 == 0"
-                print(f"linear: {fqn}, in={in_features}, out={out_features}")
-
-                weight = mod.weight.data
-                if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
-                    if self.padding:
-                        from model import find_multiple
-                        import torch.nn.functional as F
-                        print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
-                        padded_in_features = find_multiple(in_features, 1024)
-                        weight = F.pad(weight, pad=(0, padded_in_features - in_features))
-                    else:
-                        print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
-                            "and that groupsize and inner_k_tiles*16 evenly divide into it")
-                        continue
-                weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
-                    weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
-                )
-                cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
-                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')
-
-        return cur_state_dict
-
-    def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
-        return self.mod
-
-class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
-        from model import find_multiple
-        self.mod = mod
-        self.groupsize = groupsize
-        self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
-        self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
-        self.quantize_func = lambda w, qparams: \
-            group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
-        self.dequantize_func = lambda q, qparams: \
-            group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float()
-        self.combine_qparams_list_func = lambda qparams_list: \
-            [torch.cat(x, dim=1) for x in zip(*qparams_list)]
-        # skip unless padding=True or its correctly sized
-        self.skip_layer_func = lambda linear_weight: not (
-            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
-        )
-        # we need to do the padding here, both for q and the qparams if necessary
-        def make_names_and_values_dict_func(q, qparams):
-            k = q.shape[1]
-            new_k = find_multiple(k, 1024)
-            # how much we need to pad the weight
-            delta_k = new_k - q.shape[1]
-            final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
-            scales_and_zeros = pack_scales_and_zeros(*qparams)
-            # how many new groups we need for padded weight
-            delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
-            final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1)
-            return {"weight": final_q, "scales_and_zeros": final_s_and_z}
-        self.make_names_and_values_dict_func = make_names_and_values_dict_func
-        super().__init__()
-
-
-    def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
-        return self.mod
-
-class WeightOnlyInt4Linear(torch.nn.Module):
-    __constants__ = ['in_features', 'out_features']
-    in_features: int
-    out_features: int
-    weight: torch.Tensor
-
-    def __init__(
-        self, in_features: int, out_features: int,
-        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
-    ) -> None:
-        super().__init__()
-        self.padding = padding
-        if padding:
-            from model import find_multiple
-            self.origin_in_features = in_features
-            in_features = find_multiple(in_features, 1024)
-
-        self.in_features = in_features
-        self.out_features = out_features
-        assert not bias, "require bias=False"
-        self.groupsize = groupsize
-        self.inner_k_tiles = inner_k_tiles
-
-        assert out_features % 8 == 0, "require out_features % 8 == 0"
-        assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
-        self.register_buffer(
-            "weight",
-            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
-        )
-        self.register_buffer(
-            "scales_and_zeros",
-            torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
-        )
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        input = input.to(torch.bfloat16)
-        if self.padding:
-            import torch.nn.functional as F
-            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
-        return linear_forward_int4(
-            input,
-            self.weight, self.scales_and_zeros, self.out_features, self.groupsize
-        )
 
 def quantize(
-    checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
+    checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
     mode: str = 'int8',
-    # following arguments only available when setting int4 quantization.
-    groupsize: int = 128,
-    # following arguments only used for GPTQ
-    calibration_tasks: list = ["hellaswag"],
-    calibration_limit: int = 1000,
-    calibration_seq_length: int = 100,
-    pad_calibration_inputs: bool = False,
-    percdamp: float = .01,
-    blocksize: int = 128,
     label: str = '',
 ) -> None:
     assert checkpoint_path.is_file(), checkpoint_path
@@ -597,39 +171,8 @@ def quantize(
         base_name = checkpoint_path.name
         new_base_name = base_name.replace('.pth', f'{label}int8.pth')
 
-    elif mode == 'int4':
-        print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization")
-        quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
-        quantized_state_dict = quant_handler.create_quantized_state_dict()
-
-        dir_name = checkpoint_path.parent
-        base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth")
-
-    elif mode == 'int4-gptq':
-        print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...")
-        quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize)
-
-        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-        assert tokenizer_path.is_file(), tokenizer_path
-        tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
-
-        quantized_state_dict = quant_handler.create_quantized_state_dict(
-            tokenizer,
-            blocksize,
-            percdamp,
-            groupsize,
-            calibration_tasks,
-            calibration_limit,
-            calibration_seq_length,
-            pad_calibration_inputs
-        )
-
-        dir_name = checkpoint_path.parent
-        base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
     else:
-        raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
+        raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8,]")
 
     quantize_path = dir_name / new_base_name
     print(f"Writing quantized weights to {quantize_path}")
@@ -643,14 +186,7 @@ def quantize(
     parser = argparse.ArgumentParser(description='Quantize a model.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
     parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
-    parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
-    parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['hellaswag'], help='tasks to do gptq calibration on, if doing gptq')
-    parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')
-    parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration')
-    parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower')
-    parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
-    parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
     parser.add_argument('--label', type=str, default='_', help='label to add to output filename')
 
     args = parser.parse_args()
-    quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)
+    quantize(args.checkpoint_path, args.mode, args.label)
diff --git a/mixtral-moe/scripts/convert_hf_checkpoint.py b/mixtral-moe/scripts/convert_hf_checkpoint.py
index d789d689..686c4673 100644
--- a/mixtral-moe/scripts/convert_hf_checkpoint.py
+++ b/mixtral-moe/scripts/convert_hf_checkpoint.py
@@ -20,86 +20,14 @@
 
 
 @torch.inference_mode()
-def convert_hf_checkpoint_llama(
-    *,
-    checkpoint_dir: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"),
-    model_name: Optional[str] = None,
-) -> None:
-    config = ModelArgs.from_name(model_name)
-    print(f"Model config {config.__dict__}")
-
-    # Load the json file containing weight mapping
-    model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
-
-    assert model_map_json.is_file()
-
-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
-
-    weight_map = {
-        "model.embed_tokens.weight": "tok_embeddings.weight",
-        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
-        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
-        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
-        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
-        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
-        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
-        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
-        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
-        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
-        "model.norm.weight": "norm.weight",
-        "lm_head.weight": "output.weight",
-    }
-    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
-
-    def permute(w, n_head):
-        dim = config.dim
-        return (
-            w.view(n_head, 2, config.head_dim // 2, dim)
-            .transpose(1, 2)
-            .reshape(config.head_dim * n_head, dim)
-        )
-
-    merged_result = {}
-    for file in sorted(bin_files):
-        state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
-        merged_result.update(state_dict)
-    final_result = {}
-    for key, value in merged_result.items():
-        if "layers" in key:
-            abstract_key = re.sub(r'(\d+)', '{}', key)
-            layer_num = re.search(r'\d+', key).group(0)
-            new_key = weight_map[abstract_key]
-            if new_key is None:
-                continue
-            new_key = new_key.format(layer_num)
-        else:
-            new_key = weight_map[key]
-
-        final_result[new_key] = value
-
-    for key in tuple(final_result.keys()):
-        if "wq" in key:
-            q = final_result[key]
-            k = final_result[key.replace("wq", "wk")]
-            v = final_result[key.replace("wq", "wv")]
-            q = permute(q, config.n_head)
-            k = permute(k, config.n_local_heads)
-            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
-            del final_result[key]
-            del final_result[key.replace("wq", "wk")]
-            del final_result[key.replace("wq", "wv")]
-    print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
-    torch.save(final_result, checkpoint_dir / "model.pth")
-
-
-@torch.inference_mode()
-def convert_hf_checkpoint_mixtral(
+def convert_hf_checkpoint(
     *,
     checkpoint_dir: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1"),
     model_name: Optional[str] = None,
 ) -> None:
+    if model_name is None:
+        model_name = checkpoint_dir.name
+
     config = ModelArgs.from_name(model_name)
     print(f"Model config {config.__dict__}")
 
@@ -157,20 +85,6 @@ def convert_hf_checkpoint_mixtral(
     torch.save(final_result, checkpoint_dir / "model.pth")
 
 
-def convert_hf_checkpoint(
-    *,
-    checkpoint_dir: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"),
-    model_name: Optional[str] = None,
-) -> None:
-    if model_name is None:
-        model_name = checkpoint_dir.name
-
-    if model_name == "Mixtral-8x7B-v0.1":
-        return convert_hf_checkpoint_mixtral(checkpoint_dir=checkpoint_dir, model_name=model_name)
-    else:
-        return convert_hf_checkpoint_llama(checkpoint_dir=checkpoint_dir, model_name=model_name)
-
-
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.')
diff --git a/mixtral-moe/tp.py b/mixtral-moe/tp.py
index 94304a6e..406f9df5 100644
--- a/mixtral-moe/tp.py
+++ b/mixtral-moe/tp.py
@@ -11,8 +11,7 @@
 from torch import nn
 from torch.distributed import _functional_collectives as funcol
 
-from model import Attention, FeedForward, MOEFeedForward, Transformer
-from quantize import WeightOnlyInt4Linear
+from model import Attention, MOEFeedForward, Transformer
 
 
 def _get_rank() -> int:
@@ -81,20 +80,11 @@ def shard_qkv(qkv, dim, weight_splits):
         # attention
         assert len(weight_splits) == 3
 
-        if isinstance(linear, WeightOnlyInt4Linear):
-            sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits])
-            linear.scales_and_zeros = shard_qkv(linear.scales_and_zeros, 1 - shard_dim, weight_splits)
-        else:
-            sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
+        sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
         if hasattr(linear, "scales") and style == "colwise":
             linear.scales = shard_qkv(linear.scales, 0, weight_splits)
     else:
         sharded_weight = shard(linear.weight, shard_dim)
-        if isinstance(linear, WeightOnlyInt4Linear):
-            linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim)
-            if style == "rowwise":
-                assert linear.scales_and_zeros.shape[0] * 32 == sharded_weight.shape[1] * sharded_weight.shape[2] * sharded_weight.shape[3]
-                assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8
         if hasattr(linear, "scales") and style == "colwise":
             linear.scales = shard(linear.scales, 0)
 
@@ -106,20 +96,6 @@ def shard_qkv(qkv, dim, weight_splits):
     # assert linear.weight.shape == (linear.out_features, linear.in_features)
 
 
-def _apply_tp_ffn(mlp: FeedForward) -> None:
-    assert hasattr(mlp, "w1")
-    assert hasattr(mlp, "w3")
-    assert hasattr(mlp, "w2")
-
-    _apply_tp_linear(mlp.w1, "colwise")
-    _apply_tp_linear(mlp.w3, "colwise")
-    _apply_tp_linear(mlp.w2, "rowwise")
-
-    world_size = _get_world_size()
-    mlp.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
-        output, "sum", list(range(world_size))))
-
-
 def _apply_tp_moe_ffn(mlp: MOEFeedForward) -> None:
     mlp.cond_ffn.w1 = nn.Parameter(shard(mlp.cond_ffn.w1, 1), requires_grad=False)
     mlp.cond_ffn.w3 = nn.Parameter(shard(mlp.cond_ffn.w3, 1), requires_grad=False)
@@ -167,8 +143,5 @@ def apply_tp(model: Transformer) -> None:
     _apply_tp_Transformer(model)
     for block in model.layers:
         # Apply to MLP
-        if model.config.moe:
-            _apply_tp_moe_ffn(block.block_sparse_moe)
-        else:
-            _apply_tp_ffn(block.feed_forward)
+        _apply_tp_moe_ffn(block.block_sparse_moe)
         _apply_tp_attn(block.attention)