
Commit 88873a6

Clean up mixtral-moe
1 parent b262949 commit 88873a6


5 files changed: +20 additions, -737 deletions


mixtral-moe/generate.py

Lines changed: 7 additions & 123 deletions
@@ -86,73 +86,20 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
 def model_forward(model, x, input_pos):
     return model(x, input_pos)
 
-def speculative_decode(
-    model: Transformer,
-    draft_model: Transformer,
-    cur_token: torch.Tensor,
-    input_pos: int,
-    speculate_k: int,
-    **sampling_kwargs
-) -> torch.Tensor:
-    # draft model inference sequentially
-    device = cur_token.device
-    orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
-    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)
-
-    draft_tokens = torch.cat(draft_tokens)
-    # parallel inference on target model using draft tokens
-    target_logits = model_forward(
-        model,
-        torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
-        torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device)
-    )
-    target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
-    draft_probs = torch.stack(draft_probs)
-    # q: target prob, p: draft prob
-    # q >= p: always accept draft token
-    # q < p: q/p prob to accept draft token
-    p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
-    q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
-    accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p)
-    rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()
-
-    if rejected_locations.shape[0] == 0: # All draft tokens have been accepted
-        accept_length = speculate_k + 1
-        last_token = multinomial_sample_one_no_sync(target_probs[-1])
-        # fill last token into draft model
-        model_forward(
-            draft_model,
-            draft_tokens[-1].view(1, -1),
-            orig_input_pos + speculate_k,
-        )
-        return torch.cat([draft_tokens, last_token])
-    else:
-        accept_length = rejected_locations[0].item()
-        p = draft_probs[accept_length]
-        q = target_probs[accept_length]
-        new = q - p
-        new = torch.where(new > 0, new, 0.0)
-        new = new / new.sum()
-        next_token = multinomial_sample_one_no_sync(new)
-        return torch.cat([draft_tokens[:accept_length], next_token])
-
 @torch.no_grad()
 def generate(
     model: Transformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
     *,
     interactive: bool,
-    draft_model: Transformer,
-    speculate_k: Optional[int] = 8,
     callback = lambda x: x,
     **sampling_kwargs
 ) -> torch.Tensor:
     """
     Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
     """
 
-    is_speculative = draft_model is not None
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(0)
     T_new = T + max_new_tokens
@@ -162,11 +109,8 @@ def generate(
     max_seq_length = min(T_new, model.config.block_size)
 
     device, dtype = prompt.device, prompt.dtype
-    max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
     with torch.device(device):
         model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
-        if is_speculative and draft_model is not model:
-            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
 
     # create an empty tensor of the expected final shape and fill in the current tokens
     empty = torch.empty(T_new, dtype=dtype, device=device)
@@ -175,37 +119,14 @@ def generate(
     input_pos = torch.arange(0, T, device=device)
 
     next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
-    if is_speculative:
-        prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
     seq[T] = next_token
 
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
-    accept_counts = [0] * (speculate_k + 1)
-
-    if is_speculative:
-        input_pos = input_pos.item() # for speculative decoding easier to keep on host
-        while input_pos < T_new - 1:
-            cur_token = next_token.view(())
-
-            next_tokens = speculative_decode(
-                model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
-            )
 
-            accept_counts[len(next_tokens) - 1] += 1
-            num_added = min(T_new - input_pos - 1, len(next_tokens))
-            seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added]
-            for i in next_tokens[: num_added,]:
-                callback(i)
-            input_pos = input_pos + num_added
-            next_token = next_tokens[-1]
-    else:
-        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
-        seq[T + 1:] = torch.cat(generated_tokens)
+    generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
+    seq[T + 1:] = torch.cat(generated_tokens)
 
-    generate_stats = {
-        'accept_counts': accept_counts
-    }
-    return seq, generate_stats
+    return seq
 
 def encode_tokens(tokenizer, string, bos=True, device='cuda'):
     tokens = tokenizer.encode(string)
@@ -223,15 +144,6 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         simple_quantizer = WeightOnlyBit8QuantHandler(model, torch.int8)
         model = simple_quantizer.convert_for_runtime()
 
-    if "int4" in str(checkpoint_path):
-        print("Using int4 quantization!")
-        path_comps = checkpoint_path.name.split(".")
-        assert path_comps[-2].startswith("g")
-        groupsize = int(path_comps[-2][1:])
-        from quantize import WeightOnlyInt4QuantHandler
-        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
-        model = simple_quantizer.convert_for_runtime()
-
     checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
     model.load_state_dict(checkpoint, assign=True)
 
@@ -252,12 +164,10 @@ def main(
     max_new_tokens: int = 100,
     top_k: int = 200,
    temperature: float = 0.8,
-    checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
+    checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
    compile: bool = True,
    compile_prefill: bool = False,
    profile: Optional[Path] = None,
-    draft_checkpoint_path: Optional[Path] = None,
-    speculate_k: int = 5,
    device='cuda',
 ) -> None:
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
@@ -277,18 +187,12 @@
 
     print(f"Using device={device}")
     precision = torch.bfloat16
-    is_speculative = draft_checkpoint_path is not None
     is_chat = "chat" in str(checkpoint_path)
 
     print("Loading model ...")
     t0 = time.time()
     model = _load_model(checkpoint_path, device, precision, use_tp)
 
-    if is_speculative:
-        draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
-    else:
-        draft_model = None
-
     device_sync(device=device) # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
 
@@ -299,14 +203,7 @@
     torch.manual_seed(1234)
     model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
     if compile:
-        if is_speculative and use_tp: # and ("cuda" in device):
-            torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
-        if model.config.moe:
-            torch._inductor.config.assert_indirect_indexing = False
-
-        if is_speculative:
-            global model_forward, logits_to_prob
-            model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
+        torch._inductor.config.assert_indirect_indexing = False
 
         global decode_one_token, prefill
         decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
@@ -318,7 +215,6 @@
 
     aggregate_metrics = {
         'tokens_per_sec': [],
-        'accept_counts': [],
     }
     start = -1 if compile else 0
 
@@ -355,18 +251,15 @@ def callback(x):
            torch.profiler._utils._init_for_cuda_graphs()
            prof = torch.profiler.profile()
        with prof:
-            y, metrics = generate(
+            y = generate(
                model,
                encoded,
                max_new_tokens,
-                draft_model=draft_model,
-                speculate_k=speculate_k,
                interactive=interactive,
                callback=callback,
                temperature=temperature,
                top_k=top_k,
            )
-        aggregate_metrics['accept_counts'].append(metrics['accept_counts'])
        if i == -1:
            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
            continue
@@ -387,12 +280,6 @@ def callback(x):
        aggregate_metrics['tokens_per_sec'].append(tokens_sec)
        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
-        print("==========")
-        if is_speculative:
-            counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
-            acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated]
-            print(f"Acceptance probs: {acceptance_probs}")
-            print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}")
 
    print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
@@ -412,13 +299,10 @@ def callback(x):
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
    parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
-    parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
-    parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
    parser.add_argument('--device', type=str, default="cuda", help='device to use')

    args = parser.parse_args()
    main(
        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
-        args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
-        args.speculate_k, args.device
+        args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device
    )
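
What remains after this cleanup: generate() prefills the prompt once, then decodes autoregressively with decode_n_tokens(); the draft model, speculate_k, and accept-count bookkeeping are gone. Below is a minimal, self-contained sketch of that remaining flow, not this file's code — dummy_model_forward, the 32-token vocabulary, and greedy argmax are stand-ins for the real Transformer, prefill(), and the sampling helpers.

import torch

def dummy_model_forward(x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    # Stand-in for model(x, input_pos): random logits over a tiny vocabulary.
    return torch.randn(x.size(0), x.size(1), 32)

def simple_generate(prompt: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    T = prompt.size(0)
    seq = torch.empty(T + max_new_tokens, dtype=prompt.dtype)
    seq[:T] = prompt
    # "Prefill": run the whole prompt once and pick the first new token.
    logits = dummy_model_forward(prompt.view(1, -1), torch.arange(T))
    next_token = logits[0, -1].argmax()
    seq[T] = next_token
    # Plain autoregressive decode: one token in, one token out, no draft model.
    for pos in range(T + 1, T + max_new_tokens):
        logits = dummy_model_forward(next_token.view(1, 1), torch.tensor([pos - 1]))
        next_token = logits[0, -1].argmax()
        seq[pos] = next_token
    return seq

print(simple_generate(torch.tensor([1, 2, 3]), max_new_tokens=5))

With --draft_checkpoint_path and --speculate_k removed, a typical run reduces to the remaining flags shown in the diff, e.g. python generate.py --compile (optionally pointing --checkpoint_path, assuming the script's existing checkpoint flag, at the new Mixtral-8x7B-v0.1 default).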

mixtral-moe/model.py

Lines changed: 3 additions & 27 deletions
@@ -29,7 +29,6 @@ class ModelArgs:
     head_dim: int = 64
     rope_base: float = 10000
     norm_eps: float = 1e-5
-    moe: bool = False
     num_experts: int = 8
     num_activated_experts: int = 2
 
@@ -53,13 +52,7 @@ def from_name(cls, name: str):
 
 
 transformer_configs = {
-    "Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2, moe=True),
-    "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000),
-    "7B": dict(n_layer=32, n_head=32, dim=4096),
-    "13B": dict(n_layer=40, n_head=40, dim=5120),
-    "30B": dict(n_layer=60, n_head=52, dim=6656),
-    "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000), # CodeLlama-34B-Python-hf
-    "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672),
+    "Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2),
 }
 
 class KVCache(nn.Module):
@@ -129,19 +122,13 @@ class TransformerBlock(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.attention = Attention(config)
-        if config.moe:
-            self.block_sparse_moe = MOEFeedForward(config)
-        else:
-            self.feed_forward = FeedForward(config)
+        self.block_sparse_moe = MOEFeedForward(config)
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
 
     def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
-        if hasattr(self, "block_sparse_moe"):
-            out = h + self.block_sparse_moe(self.ffn_norm(h))
-        else:
-            out = h + self.feed_forward(self.ffn_norm(h))
+        out = h + self.block_sparse_moe(self.ffn_norm(h))
         return out
 
 
@@ -197,17 +184,6 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
         return y
 
 
-class FeedForward(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
-        super().__init__()
-        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
-        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
-        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
-
-
 class ConditionalFeedForward(nn.Module):
     def __init__(self, config):
         super().__init__()
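
With the moe flag, the dense FeedForward, and the non-Mixtral configs gone, every TransformerBlock here routes its FFN through MOEFeedForward unconditionally. The toy module below is not the repository's MOEFeedForward/ConditionalFeedForward implementation — only a sketch of the top-2 expert-routing idea behind it, using the num_experts=8 / num_activated_experts=2 defaults kept in ModelArgs; the layer sizes, initialization scale, and einsum layout are illustrative assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyMoEFeedForward(nn.Module):
    def __init__(self, dim: int = 16, hidden: int = 32,
                 num_experts: int = 8, num_activated_experts: int = 2) -> None:
        super().__init__()
        # Router: one score per expert for each token.
        self.gate = nn.Linear(dim, num_experts, bias=False)
        # Per-expert SwiGLU weights (w1/w3 up-projections, w2 down-projection).
        self.w1 = nn.Parameter(torch.randn(num_experts, hidden, dim) * 0.02)
        self.w3 = nn.Parameter(torch.randn(num_experts, hidden, dim) * 0.02)
        self.w2 = nn.Parameter(torch.randn(num_experts, dim, hidden) * 0.02)
        self.top_k = num_activated_experts

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (tokens, dim). Each token picks its top-k experts and mixes their
        # outputs with softmax-normalized gate weights.
        scores = self.gate(x)                                  # (tokens, num_experts)
        weights, idx = torch.topk(scores, self.top_k, dim=-1)  # (tokens, top_k)
        weights = F.softmax(weights, dim=-1)
        out = torch.zeros_like(x)
        for k in range(self.top_k):
            # Gather the k-th chosen expert's weights for every token.
            w1, w3, w2 = self.w1[idx[:, k]], self.w3[idx[:, k]], self.w2[idx[:, k]]
            h = F.silu(torch.einsum('td,thd->th', x, w1)) * torch.einsum('td,thd->th', x, w3)
            out = out + weights[:, k:k + 1] * torch.einsum('th,tdh->td', h, w2)
        return out

# Shape check: a batch of 4 token embeddings in, same shape out.
print(ToyMoEFeedForward()(torch.randn(4, 16)).shape)

Routing each token through only two of the eight expert MLPs is what keeps Mixtral's per-token compute close to that of a much smaller dense model even though all expert parameters are resident.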
