Mixtral MoE improvements: transposed w2 to have reduction dim be innermost dim #128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 5 commits · Mar 10, 2024
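This PR stores `ConditionalFeedForward.w2` as `[num_experts, dim, intermediate_size]` instead of `[num_experts, intermediate_size, dim]`, so the dimension the final expert einsum reduces over (`intermediate_size`) becomes the innermost, contiguous dimension of the weight; the einsum subscripts, the int8 path, the checkpoint conversion, and the tensor-parallel sharding are updated to match. The sketch below is not part of the PR (sizes are arbitrary); it only checks that the relabeled einsum over the transposed layout gives the same result as the old formulation.

```python
# Minimal equivalence check (not from the PR; sizes are arbitrary).
# T = tokens, A = activated experts per token, E = experts, D = model dim, I = intermediate dim.
import torch

T, A, E, D, I = 2, 2, 4, 8, 16
h = torch.randn(T, A, I)                       # stands in for (x1 * x3) in the forward pass
idx = torch.randint(0, E, (T, A))              # expert_indices

w2_old = torch.randn(E, I, D)                  # old layout: [experts, intermediate, dim]
w2_new = w2_old.permute(0, 2, 1).contiguous()  # new layout: [experts, dim, intermediate]

# Old subscripts reduce over the second-to-last dim of w2; new subscripts reduce over the innermost dim.
out_old = torch.einsum('tao, taoi -> tai', h, w2_old[idx])
out_new = torch.einsum('tao, taio -> tai', h, w2_new[idx])

torch.testing.assert_close(out_old, out_new)
```

With the reduction dimension contiguous, the grouped matmuls behind the einsum read each expert's weight rows sequentially, which is presumably where the tokens/s gains in the tables below come from.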
6 changes: 3 additions & 3 deletions README.md
@@ -22,10 +22,10 @@ Please check the rest of this page about benchmark of LLaMA family models.
### Mixtral 8x7B
We also support [Mixtral 8x7B](https://mistral.ai/news/mixtral-of-experts/), a high-quality sparse mixture-of-experts (MoE) model. The average token generation rates (tokens/s) are:

 | | 1 GPU | 2 GPU | 4 GPU | 8 GPU |
 |------------------|---------|-----------|--------|------------|
-|baseline(bfloat16)| OOM | 78.75 | 118.23 | 203.69 |
-| int8 | 56.04 | 99.91 | 149.53 | 218.48 |
+|baseline(bfloat16)| OOM | 96.67 | 155.35 | 227.82 |
+| int8 | 97.92 | 155.03 | 216.87 | 279.35 |

Note that the benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).

7 changes: 3 additions & 4 deletions mixtral-moe/README.md
@@ -12,11 +12,10 @@ python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
## Benchmarks
Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).

 | | 1 GPU | 2 GPU | 4 GPU | 8 GPU |
 |------------------|---------|-----------|--------|------------|
-|baseline(bfloat16)| OOM | 78.75 | 118.23 | 203.69 |
-| int8 | 56.04 | 99.91 | 149.53 | 218.48 |
+|baseline(bfloat16)| OOM | 96.67 | 155.35 | 227.82 |
+| int8 | 97.92 | 155.03 | 216.87 | 279.35 |


## Generate Text
12 changes: 6 additions & 6 deletions mixtral-moe/model.py
@@ -188,16 +188,16 @@ class ConditionalFeedForward(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
-        self.w2 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
+        self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))
         self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))

     def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
-        w1_weights = self.w1[expert_indices].transpose(-1, -2) # [T, A, D, D]
-        w3_weights = self.w3[expert_indices].transpose(-1, -2) # [T, A, D, D]
+        w1_weights = self.w1[expert_indices] # [T, A, D, D]
+        w3_weights = self.w3[expert_indices] # [T, A, D, D]
         w2_weights = self.w2[expert_indices] # [T, A, D, D]
-        x1 = F.silu(torch.einsum('ti,taio -> tao', x, w1_weights))
-        x3 = torch.einsum('ti, taio -> tao', x, w3_weights)
-        expert_outs = torch.einsum('tao, taoi -> tai', (x1 * x3), w2_weights)
+        x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
+        x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
+        expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
         return expert_outs


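For `w1`/`w3` the stored layout does not change; the explicit `.transpose(-1, -2)` is simply folded into the einsum subscripts (`taio` → `taoi`). A small sanity check, not from the PR and with arbitrary sizes:

```python
# Sanity check (not from the PR; sizes are arbitrary): dropping .transpose(-1, -2)
# on w1/w3 and swapping the last two einsum labels gives the same activations.
import torch
import torch.nn.functional as F

T, A, E, D, I = 3, 2, 4, 8, 16
x = torch.randn(T, D)
idx = torch.randint(0, E, (T, A))
w1 = torch.randn(E, I, D)   # stored layout is unchanged: [experts, intermediate, dim]

old = F.silu(torch.einsum('ti,taio -> tao', x, w1[idx].transpose(-1, -2)))
new = F.silu(torch.einsum('ti,taoi -> tao', x, w1[idx]))

torch.testing.assert_close(old, new)
```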
18 changes: 9 additions & 9 deletions mixtral-moe/quantize.py
@@ -75,11 +75,11 @@ def create_quantized_state_dict(self):
             cur_state_dict[f"{fqn}.weight"] = int8_weight
             cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
         elif isinstance(mod, ConditionalFeedForward):
-            num_experts, intermediate_size, dim = mod.w1.shape
             for weight_idx in range(0, 3):
                 weight_name = f"w{weight_idx + 1}"
                 scales_name = f"scales{weight_idx + 1}"
                 weight = getattr(mod, weight_name)
+                num_experts, intermediate_size, dim = weight.shape

                 bit8_weight_list = []
                 scales_list = []
@@ -125,20 +125,20 @@ def __init__(self, num_experts, intermediate_size, dim, target_dtype):
         self.target_dtype = target_dtype

         self.register_buffer("w1", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))
-        self.register_buffer("w2", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))
+        self.register_buffer("w2", torch.empty(num_experts, dim, intermediate_size, dtype=target_dtype))
         self.register_buffer("w3", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))

         self.register_buffer("scales1", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))
-        self.register_buffer("scales2", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))
+        self.register_buffer("scales2", torch.empty(num_experts, dim, dtype=torch.bfloat16))
         self.register_buffer("scales3", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))

     def forward(self, x, expert_indices):
-        w1_weights = (self.w1.to(x.dtype)[expert_indices] * self.scales1[expert_indices].to(x.dtype).unsqueeze(-1)).transpose(-1, -2) # [T, A, D, D]
-        w3_weights = (self.w3.to(x.dtype)[expert_indices] * self.scales3[expert_indices].to(x.dtype).unsqueeze(-1)).transpose(-1, -2) # [T, A, D, D]
-        w2_weights = (self.w2.to(x.dtype)[expert_indices] * self.scales2[expert_indices].to(x.dtype).unsqueeze(-1)) # [T, A, D, D]
-        x1 = F.silu(torch.einsum('ti,taio -> tao', x, w1_weights))
-        x3 = torch.einsum('ti, taio -> tao', x, w3_weights)
-        expert_outs = torch.einsum('tao, taoi -> tai', (x1 * x3), w2_weights)
+        w1_weights = self.w1.to(x.dtype)[expert_indices] # [T, A, D, D]
+        w3_weights = self.w3.to(x.dtype)[expert_indices] # [T, A, D, D]
+        w2_weights = self.w2.to(x.dtype)[expert_indices]
+        x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights) * self.scales1[expert_indices].to(x.dtype))
+        x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) * self.scales3[expert_indices].to(x.dtype)
+        expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) * self.scales2[expert_indices].to(x.dtype) # [T, A, D, D]
         return expert_outs


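In the int8 forward, the per-output-channel scales are no longer multiplied into the dequantized weights before the einsum; the einsum now runs on the upcast weights in their stored layout and the scales are applied to its output, which is algebraically identical for per-row scales. A sketch of that identity for the `w1` path, assuming plain symmetric per-channel quantization (an illustration, not the PR's quantization code):

```python
# Sketch (assumes simple symmetric per-output-channel scales; not the PR's code):
# scaling the dequantized weight rows before the matmul equals scaling the matmul output.
import torch

T, A, E, D, I = 3, 2, 4, 8, 16
x = torch.randn(T, D)
idx = torch.randint(0, E, (T, A))
w1_q = torch.randint(-128, 128, (E, I, D), dtype=torch.int8)   # fake int8 weight
scales1 = torch.rand(E, I) + 0.5                               # one scale per output channel

# Old path: dequantize (weight * scale), transpose, then contract.
w_deq = (w1_q.to(x.dtype)[idx] * scales1[idx].to(x.dtype).unsqueeze(-1)).transpose(-1, -2)
old = torch.einsum('ti,taio -> tao', x, w_deq)

# New path: contract the upcast int8 weight directly, scale the per-channel outputs afterwards.
new = torch.einsum('ti,taoi -> tao', x, w1_q.to(x.dtype)[idx]) * scales1[idx].to(x.dtype)

torch.testing.assert_close(old, new)
```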
6 changes: 4 additions & 2 deletions mixtral-moe/scripts/convert_hf_checkpoint.py
@@ -76,9 +76,11 @@ def convert_hf_checkpoint(
             del final_result[key]
             del final_result[key.replace("wq", "wk")]
             del final_result[key.replace("wq", "wv")]
-        if "w1" in key or "w2" in key or "w3" in key:
+        elif "w1" in key or "w3" in key:
             final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous()
-        if "gate" in key:
+        elif "w2" in key:
+            final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous()
+        elif "gate" in key:
             final_result[key] = final_result[key].contiguous()

     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
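Checkpoint conversion now handles `w2` separately: it is reshaped to the old `[num_experts, intermediate_size, dim]` view and then permuted and made contiguous so the saved tensor matches the new `[num_experts, dim, intermediate_size]` parameter layout. A toy illustration with hypothetical sizes:

```python
# Toy illustration (hypothetical sizes) of the new w2 conversion path: reshape to the old
# [experts, intermediate, dim] view, then permute(0, 2, 1) + .contiguous() so the stored
# tensor physically matches the new [experts, dim, intermediate] parameter layout.
import torch

num_experts, intermediate_size, dim = 2, 6, 4
flat = torch.arange(num_experts * intermediate_size * dim, dtype=torch.float32)

w2 = flat.reshape(num_experts, intermediate_size, dim).permute(0, 2, 1).contiguous()
print(w2.shape)            # torch.Size([2, 4, 6])
print(w2.is_contiguous())  # True: .contiguous() materializes the transpose in memory
```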
4 changes: 2 additions & 2 deletions mixtral-moe/tp.py
@@ -99,12 +99,12 @@ def shard_qkv(qkv, dim, weight_splits):
 def _apply_tp_moe_ffn(mlp: MOEFeedForward) -> None:
     mlp.cond_ffn.w1 = nn.Parameter(shard(mlp.cond_ffn.w1, 1), requires_grad=False)
     mlp.cond_ffn.w3 = nn.Parameter(shard(mlp.cond_ffn.w3, 1), requires_grad=False)
-    mlp.cond_ffn.w2 = nn.Parameter(shard(mlp.cond_ffn.w2, 1), requires_grad=False)
+    mlp.cond_ffn.w2 = nn.Parameter(shard(mlp.cond_ffn.w2, 2), requires_grad=False)

     if hasattr(mlp.cond_ffn, "scales1"):
         mlp.cond_ffn.scales1 = nn.Parameter(shard(mlp.cond_ffn.scales1, 1), requires_grad=False)
         mlp.cond_ffn.scales3 = nn.Parameter(shard(mlp.cond_ffn.scales3, 1), requires_grad=False)
-        mlp.cond_ffn.scales2 = nn.Parameter(shard(mlp.cond_ffn.scales2, 1), requires_grad=False)
+        mlp.cond_ffn.scales2 = nn.Parameter(mlp.cond_ffn.scales2, requires_grad=False)

     world_size = _get_world_size()
     mlp.cond_ffn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
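Since the intermediate dimension of `w2` moved from dim 1 to dim 2, the tensor-parallel shard for `w2` moves with it, while `scales2` (one scale per output channel of size `dim`) is left unsharded: every rank keeps the full output dimension, computes a partial sum over its slice of the intermediate dimension, and the existing `all_reduce` forward hook adds the partials together. A rough sketch, assuming `shard(t, dim)` takes the local rank's slice along `dim` as the helper in `tp.py` does:

```python
# Rough sketch (assumes shard(t, dim) keeps the local rank's slice along `dim`,
# and that the forward hook all_reduces the expert outputs, as in tp.py).
import torch

world_size, rank = 2, 0
E, D, I = 4, 8, 16

def shard(t: torch.Tensor, dim: int) -> torch.Tensor:
    return torch.tensor_split(t, world_size, dim=dim)[rank]

w1 = torch.randn(E, I, D)    # sharded along dim 1 (intermediate) as before
w2 = torch.randn(E, D, I)    # new layout: intermediate moved to dim 2, so shard dim 2
scales2 = torch.randn(E, D)  # one scale per full output channel: no sharding needed

print(shard(w1, 1).shape)    # torch.Size([4, 8, 8])
print(shard(w2, 2).shape)    # torch.Size([4, 8, 8])
# Each rank contracts over its I/world_size slice and produces a full-D partial result;
# the all_reduce sums the partials, so applying scales2 over the full D stays correct.
```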