Clean up mixtral-moe #118

Merged (1 commit, Feb 29, 2024)
130 changes: 7 additions & 123 deletions mixtral-moe/generate.py
@@ -86,73 +86,20 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
def model_forward(model, x, input_pos):
return model(x, input_pos)

def speculative_decode(
model: Transformer,
draft_model: Transformer,
cur_token: torch.Tensor,
input_pos: int,
speculate_k: int,
**sampling_kwargs
) -> torch.Tensor:
# draft model inference sequentially
device = cur_token.device
orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)

draft_tokens = torch.cat(draft_tokens)
# parallel inference on target model using draft tokens
target_logits = model_forward(
model,
torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device)
)
target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
draft_probs = torch.stack(draft_probs)
# q: target prob, p: draft prob
# q >= p: always accept draft token
# q < p: q/p prob to accept draft token
p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p)
rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()

if rejected_locations.shape[0] == 0: # All draft tokens have been accepted
accept_length = speculate_k + 1
last_token = multinomial_sample_one_no_sync(target_probs[-1])
# fill last token into draft model
model_forward(
draft_model,
draft_tokens[-1].view(1, -1),
orig_input_pos + speculate_k,
)
return torch.cat([draft_tokens, last_token])
else:
accept_length = rejected_locations[0].item()
p = draft_probs[accept_length]
q = target_probs[accept_length]
new = q - p
new = torch.where(new > 0, new, 0.0)
new = new / new.sum()
next_token = multinomial_sample_one_no_sync(new)
return torch.cat([draft_tokens[:accept_length], next_token])

@torch.no_grad()
def generate(
model: Transformer,
prompt: torch.Tensor,
max_new_tokens: int,
*,
interactive: bool,
draft_model: Transformer,
speculate_k: Optional[int] = 8,
callback = lambda x: x,
**sampling_kwargs
) -> torch.Tensor:
"""
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
"""

is_speculative = draft_model is not None
# create an empty tensor of the expected final shape and fill in the current tokens
T = prompt.size(0)
T_new = T + max_new_tokens
@@ -162,11 +109,8 @@ def generate(
max_seq_length = min(T_new, model.config.block_size)

device, dtype = prompt.device, prompt.dtype
max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
with torch.device(device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
if is_speculative and draft_model is not model:
draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

# create an empty tensor of the expected final shape and fill in the current tokens
empty = torch.empty(T_new, dtype=dtype, device=device)
@@ -175,37 +119,14 @@ def generate(
input_pos = torch.arange(0, T, device=device)

next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
if is_speculative:
prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
seq[T] = next_token

input_pos = torch.tensor([T], device=device, dtype=torch.int)
accept_counts = [0] * (speculate_k + 1)

if is_speculative:
input_pos = input_pos.item() # for speculative decoding easier to keep on host
while input_pos < T_new - 1:
cur_token = next_token.view(())

next_tokens = speculative_decode(
model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
)

accept_counts[len(next_tokens) - 1] += 1
num_added = min(T_new - input_pos - 1, len(next_tokens))
seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added]
for i in next_tokens[: num_added,]:
callback(i)
input_pos = input_pos + num_added
next_token = next_tokens[-1]
else:
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
seq[T + 1:] = torch.cat(generated_tokens)
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
seq[T + 1:] = torch.cat(generated_tokens)

generate_stats = {
'accept_counts': accept_counts
}
return seq, generate_stats
return seq

def encode_tokens(tokenizer, string, bos=True, device='cuda'):
tokens = tokenizer.encode(string)
@@ -223,15 +144,6 @@ def _load_model(checkpoint_path, device, precision, use_tp):
simple_quantizer = WeightOnlyBit8QuantHandler(model, torch.int8)
model = simple_quantizer.convert_for_runtime()

if "int4" in str(checkpoint_path):
print("Using int4 quantization!")
path_comps = checkpoint_path.name.split(".")
assert path_comps[-2].startswith("g")
groupsize = int(path_comps[-2][1:])
from quantize import WeightOnlyInt4QuantHandler
simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
model = simple_quantizer.convert_for_runtime()

checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)

@@ -252,12 +164,10 @@ def main(
max_new_tokens: int = 100,
top_k: int = 200,
temperature: float = 0.8,
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
compile: bool = True,
compile_prefill: bool = False,
profile: Optional[Path] = None,
draft_checkpoint_path: Optional[Path] = None,
speculate_k: int = 5,
device='cuda',
) -> None:
"""Generates text samples based on a pre-trained Transformer model and tokenizer.
@@ -277,18 +187,12 @@ def main(

print(f"Using device={device}")
precision = torch.bfloat16
is_speculative = draft_checkpoint_path is not None
is_chat = "chat" in str(checkpoint_path)

print("Loading model ...")
t0 = time.time()
model = _load_model(checkpoint_path, device, precision, use_tp)

if is_speculative:
draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
else:
draft_model = None

device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")

@@ -299,14 +203,7 @@ def main(
torch.manual_seed(1234)
model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
if compile:
if is_speculative and use_tp: # and ("cuda" in device):
torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
if model.config.moe:
torch._inductor.config.assert_indirect_indexing = False

if is_speculative:
global model_forward, logits_to_prob
model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
torch._inductor.config.assert_indirect_indexing = False

global decode_one_token, prefill
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
@@ -318,7 +215,6 @@ def main(

aggregate_metrics = {
'tokens_per_sec': [],
'accept_counts': [],
}
start = -1 if compile else 0

@@ -355,18 +251,15 @@ def callback(x):
torch.profiler._utils._init_for_cuda_graphs()
prof = torch.profiler.profile()
with prof:
y, metrics = generate(
y = generate(
model,
encoded,
max_new_tokens,
draft_model=draft_model,
speculate_k=speculate_k,
interactive=interactive,
callback=callback,
temperature=temperature,
top_k=top_k,
)
aggregate_metrics['accept_counts'].append(metrics['accept_counts'])
if i == -1:
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
continue
@@ -387,12 +280,6 @@ def callback(x):
aggregate_metrics['tokens_per_sec'].append(tokens_sec)
print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
print("==========")
if is_speculative:
counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated]
print(f"Acceptance probs: {acceptance_probs}")
print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}")

print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
@@ -412,13 +299,10 @@ def callback(x):
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
parser.add_argument('--device', type=str, default="cuda", help='device to use')

args = parser.parse_args()
main(
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
args.speculate_k, args.device
args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device
)
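For orientation, here is a condensed sketch of the non-speculative generation path that `mixtral-moe/generate.py` is left with after this cleanup. It is pieced together from the kept lines of the diff above, leans on the module's existing `prefill` and `decode_n_tokens` helpers (not shown here), and elides the interactive-mode cap on `max_seq_length`; treat it as a reading aid, not the verbatim file contents.

```python
import torch

@torch.no_grad()
def generate(model, prompt, max_new_tokens, *, interactive, callback=lambda x: x, **sampling_kwargs):
    # Sketch only: `model` is the repo's Transformer, and `prefill` /
    # `decode_n_tokens` are the helpers defined earlier in generate.py.
    T = prompt.size(0)
    T_new = T + max_new_tokens
    # (the interactive branch that caps max_seq_length is elided here)
    max_seq_length = min(T_new, model.config.block_size)

    device, dtype = prompt.device, prompt.dtype
    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    # create an empty tensor of the expected final shape and fill in the current tokens
    seq = torch.empty(T_new, dtype=dtype, device=device)
    seq[:T] = prompt
    input_pos = torch.arange(0, T, device=device)

    # prefill on the prompt, then decode the remaining tokens one at a time
    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
    seq[T] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    generated_tokens, _ = decode_n_tokens(
        model, next_token.view(1, -1), input_pos, max_new_tokens - 1,
        callback=callback, **sampling_kwargs,
    )
    seq[T + 1:] = torch.cat(generated_tokens)
    return seq
```

With the draft model gone, `generate` returns just the token sequence, which is why the call site in `main` changes from `y, metrics = generate(...)` to `y = generate(...)` and the acceptance-count bookkeeping disappears.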
30 changes: 3 additions & 27 deletions mixtral-moe/model.py
@@ -29,7 +29,6 @@ class ModelArgs:
head_dim: int = 64
rope_base: float = 10000
norm_eps: float = 1e-5
moe: bool = False
num_experts: int = 8
num_activated_experts: int = 2

@@ -53,13 +52,7 @@ def from_name(cls, name: str):


transformer_configs = {
"Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2, moe=True),
"CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000),
"7B": dict(n_layer=32, n_head=32, dim=4096),
"13B": dict(n_layer=40, n_head=40, dim=5120),
"30B": dict(n_layer=60, n_head=52, dim=6656),
"34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000), # CodeLlama-34B-Python-hf
"70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672),
"Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2),
}

class KVCache(nn.Module):
@@ -129,19 +122,13 @@ class TransformerBlock(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.attention = Attention(config)
if config.moe:
self.block_sparse_moe = MOEFeedForward(config)
else:
self.feed_forward = FeedForward(config)
self.block_sparse_moe = MOEFeedForward(config)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)

def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
if hasattr(self, "block_sparse_moe"):
out = h + self.block_sparse_moe(self.ffn_norm(h))
else:
out = h + self.feed_forward(self.ffn_norm(h))
out = h + self.block_sparse_moe(self.ffn_norm(h))
return out


@@ -197,17 +184,6 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
return y


class FeedForward(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

def forward(self, x: Tensor) -> Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))


class ConditionalFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
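On the model side, every `TransformerBlock` in `mixtral-moe/model.py` now goes through the MoE feed-forward unconditionally, so the `moe` flag, the dense `FeedForward` class, and the `hasattr` dispatch in `forward` all go away. A minimal sketch of the resulting block, assembled from the kept lines above (with `Attention`, `MOEFeedForward`, `RMSNorm`, and `ModelArgs` assumed to be the module's existing definitions):

```python
from torch import Tensor, nn

# Attention, MOEFeedForward, RMSNorm, ModelArgs are the existing classes in model.py.
class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)              # existing attention module
        self.block_sparse_moe = MOEFeedForward(config)  # the only feed-forward path now
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
        # pre-norm attention with residual, then pre-norm MoE feed-forward with residual
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        return h + self.block_sparse_moe(self.ffn_norm(h))
```

Since `transformer_configs` now contains only the Mixtral-8x7B-v0.1 entry, the `moe=True` switch is no longer needed to select this path.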