diff --git a/README.md b/README.md
index 58ad9635..bab7f35d 100644
--- a/README.md
+++ b/README.md
@@ -22,10 +22,10 @@ Please check the rest of this page about benchmark of LLaMA family models.
 ### Mixtral 8x7B
 We also supported [Mixtral 8x7B](https://mistral.ai/news/mixtral-of-experts/) which is a high-quality sparse mixture of experts (MoE) model, the average token generation rates are:
 
-|                  |  1 GPU  |   2 GPU   | 4 GPU  |   8 GPU    |
+|                  |  1 GPU  |   2 GPU   | 4 GPU  |   8 GPU    |
 |------------------|---------|-----------|--------|------------|
-|baseline(bfloat16)|   OOM   |   78.75   | 118.23 |   203.69   |
-|       int8       |  56.04  |   99.91   | 149.53 |   218.48   |
+|baseline(bfloat16)|   OOM   |   96.67   | 155.35 |   227.82   |
+|       int8       |  97.92  |  155.03   | 216.87 |   279.35   |
 
 Note that the benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).
diff --git a/mixtral-moe/README.md b/mixtral-moe/README.md
index 9f7d3597..cf5e9d9b 100644
--- a/mixtral-moe/README.md
+++ b/mixtral-moe/README.md
@@ -12,11 +12,10 @@ python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
 ## Benchmarks
 Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).
 
-|                  |  1 GPU  |   2 GPU   | 4 GPU  |   8 GPU    |
+|                  |  1 GPU  |   2 GPU   | 4 GPU  |   8 GPU    |
 |------------------|---------|-----------|--------|------------|
-|baseline(bfloat16)|   OOM   |   78.75   | 118.23 |   203.69   |
-|       int8       |  56.04  |   99.91   | 149.53 |   218.48   |
-
+|baseline(bfloat16)|   OOM   |   96.67   | 155.35 |   227.82   |
+|       int8       |  97.92  |  155.03   | 216.87 |   279.35   |
 
 ## Generate Text
diff --git a/mixtral-moe/model.py b/mixtral-moe/model.py
index 85c740f1..9249ac9d 100644
--- a/mixtral-moe/model.py
+++ b/mixtral-moe/model.py
@@ -188,16 +188,16 @@ class ConditionalFeedForward(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
-        self.w2 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
+        self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))
         self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
 
     def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
-        w1_weights = self.w1[expert_indices].transpose(-1, -2) # [T, A, D, D]
-        w3_weights = self.w3[expert_indices].transpose(-1, -2) # [T, A, D, D]
+        w1_weights = self.w1[expert_indices] # [T, A, D, D]
+        w3_weights = self.w3[expert_indices] # [T, A, D, D]
         w2_weights = self.w2[expert_indices] # [T, A, D, D]
-        x1 = F.silu(torch.einsum('ti,taio -> tao', x, w1_weights))
-        x3 = torch.einsum('ti, taio -> tao', x, w3_weights)
-        expert_outs = torch.einsum('tao, taoi -> tai', (x1 * x3), w2_weights)
+        x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
+        x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
+        expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
         return expert_outs
diff --git a/mixtral-moe/quantize.py b/mixtral-moe/quantize.py
index f78acc14..6312863c 100644
--- a/mixtral-moe/quantize.py
+++ b/mixtral-moe/quantize.py
@@ -75,11 +75,11 @@ def create_quantized_state_dict(self):
cur_state_dict[f"{fqn}.weight"] = int8_weight cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype) elif isinstance(mod, ConditionalFeedForward): - num_experts, intermediate_size, dim = mod.w1.shape for weight_idx in range(0, 3): weight_name = f"w{weight_idx + 1}" scales_name = f"scales{weight_idx + 1}" weight = getattr(mod, weight_name) + num_experts, intermediate_size, dim = weight.shape bit8_weight_list = [] scales_list = [] @@ -125,20 +125,20 @@ def __init__(self, num_experts, intermediate_size, dim, target_dtype): self.target_dtype = target_dtype self.register_buffer("w1", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype)) - self.register_buffer("w2", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype)) + self.register_buffer("w2", torch.empty(num_experts, dim, intermediate_size, dtype=target_dtype)) self.register_buffer("w3", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype)) self.register_buffer("scales1", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16)) - self.register_buffer("scales2", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16)) + self.register_buffer("scales2", torch.empty(num_experts, dim, dtype=torch.bfloat16)) self.register_buffer("scales3", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16)) def forward(self, x, expert_indices): - w1_weights = (self.w1.to(x.dtype)[expert_indices] * self.scales1[expert_indices].to(x.dtype).unsqueeze(-1)).transpose(-1, -2) # [T, A, D, D] - w3_weights = (self.w3.to(x.dtype)[expert_indices] * self.scales3[expert_indices].to(x.dtype).unsqueeze(-1)).transpose(-1, -2) # [T, A, D, D] - w2_weights = (self.w2.to(x.dtype)[expert_indices] * self.scales2[expert_indices].to(x.dtype).unsqueeze(-1)) # [T, A, D, D] - x1 = F.silu(torch.einsum('ti,taio -> tao', x, w1_weights)) - x3 = torch.einsum('ti, taio -> tao', x, w3_weights) - expert_outs = torch.einsum('tao, taoi -> tai', (x1 * x3), w2_weights) + w1_weights = self.w1.to(x.dtype)[expert_indices] # [T, A, D, D] + w3_weights = self.w3.to(x.dtype)[expert_indices] # [T, A, D, D] + w2_weights = self.w2.to(x.dtype)[expert_indices] + x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights) * self.scales1[expert_indices].to(x.dtype)) + x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) * self.scales3[expert_indices].to(x.dtype) + expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) * self.scales2[expert_indices].to(x.dtype) # [T, A, D, D] return expert_outs diff --git a/mixtral-moe/scripts/convert_hf_checkpoint.py b/mixtral-moe/scripts/convert_hf_checkpoint.py index 686c4673..e659931d 100644 --- a/mixtral-moe/scripts/convert_hf_checkpoint.py +++ b/mixtral-moe/scripts/convert_hf_checkpoint.py @@ -76,9 +76,11 @@ def convert_hf_checkpoint( del final_result[key] del final_result[key.replace("wq", "wk")] del final_result[key.replace("wq", "wv")] - if "w1" in key or "w2" in key or "w3" in key: + elif "w1" in key or "w3" in key: final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous() - if "gate" in key: + elif "w2" in key: + final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous() + elif "gate" in key: final_result[key] = final_result[key].contiguous() print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") diff --git a/mixtral-moe/tp.py b/mixtral-moe/tp.py index 406f9df5..75336b58 100644 --- a/mixtral-moe/tp.py +++ b/mixtral-moe/tp.py @@ 
@@ -99,12 +99,12 @@ def shard_qkv(qkv, dim, weight_splits):
 def _apply_tp_moe_ffn(mlp: MOEFeedForward) -> None:
     mlp.cond_ffn.w1 = nn.Parameter(shard(mlp.cond_ffn.w1, 1), requires_grad=False)
     mlp.cond_ffn.w3 = nn.Parameter(shard(mlp.cond_ffn.w3, 1), requires_grad=False)
-    mlp.cond_ffn.w2 = nn.Parameter(shard(mlp.cond_ffn.w2, 1), requires_grad=False)
+    mlp.cond_ffn.w2 = nn.Parameter(shard(mlp.cond_ffn.w2, 2), requires_grad=False)
     if hasattr(mlp.cond_ffn, "scales1"):
         mlp.cond_ffn.scales1 = nn.Parameter(shard(mlp.cond_ffn.scales1, 1), requires_grad=False)
         mlp.cond_ffn.scales3 = nn.Parameter(shard(mlp.cond_ffn.scales3, 1), requires_grad=False)
-        mlp.cond_ffn.scales2 = nn.Parameter(shard(mlp.cond_ffn.scales2, 1), requires_grad=False)
+        mlp.cond_ffn.scales2 = nn.Parameter(mlp.cond_ffn.scales2, requires_grad=False)
 
     world_size = _get_world_size()
     mlp.cond_ffn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
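
Below the patch, a minimal equivalence check of the new w2 layout and einsum subscripts. This is an illustrative sketch only, not part of the diff; the toy sizes T, A, E, D, I and the random tensors are made up for demonstration.

```python
# Toy check (assumed dimensions): T tokens, A activated experts per token,
# E experts, D = dim, I = intermediate_size.
import torch
import torch.nn.functional as F

T, A, E, D, I = 3, 2, 4, 8, 16
x = torch.randn(T, D)
expert_indices = torch.randint(0, E, (T, A))

w1 = torch.randn(E, I, D)  # layout unchanged: [num_experts, intermediate_size, dim]
w3 = torch.randn(E, I, D)
w2_old = torch.randn(E, I, D)                   # old layout
w2_new = w2_old.permute(0, 2, 1).contiguous()   # new layout, as in convert_hf_checkpoint.py

# Old path: transpose the gathered expert weights, then contract.
w1_t = w1[expert_indices].transpose(-1, -2)     # [T, A, D, I]
w3_t = w3[expert_indices].transpose(-1, -2)
x1 = F.silu(torch.einsum('ti,taio -> tao', x, w1_t))
x3 = torch.einsum('ti, taio -> tao', x, w3_t)
old_out = torch.einsum('tao, taoi -> tai', x1 * x3, w2_old[expert_indices])

# New path: no transposes; the rearranged subscripts absorb the layout change.
x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1[expert_indices]))
x3 = torch.einsum('ti, taoi -> tao', x, w3[expert_indices])
new_out = torch.einsum('tao, taio -> tai', x1 * x3, w2_new[expert_indices])

print(torch.allclose(old_out, new_out, atol=1e-5))  # expected: True
```

Pre-permuting w2 at conversion time removes the per-token transposes from the decode hot path, and it matches the tp.py change above: w2 is now sharded along dim 2 (intermediate_size), while scales2, which is sized by dim in the quantized module, is kept unsharded on every rank.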