@@ -86,73 +86,20 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
def model_forward(model, x, input_pos):
    return model(x, input_pos)

-def speculative_decode(
-    model: Transformer,
-    draft_model: Transformer,
-    cur_token: torch.Tensor,
-    input_pos: int,
-    speculate_k: int,
-    **sampling_kwargs
-) -> torch.Tensor:
-    # draft model inference sequentially
-    device = cur_token.device
-    orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
-    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)
-
-    draft_tokens = torch.cat(draft_tokens)
-    # parallel inference on target model using draft tokens
-    target_logits = model_forward(
-        model,
-        torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
-        torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device)
-    )
-    target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
-    draft_probs = torch.stack(draft_probs)
-    # q: target prob, p: draft prob
-    # q >= p: always accept draft token
-    # q < p: q/p prob to accept draft token
-    p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
-    q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
-    accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k] / p)
-    rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()
-
-    if rejected_locations.shape[0] == 0:  # All draft tokens have been accepted
-        accept_length = speculate_k + 1
-        last_token = multinomial_sample_one_no_sync(target_probs[-1])
-        # fill last token into draft model
-        model_forward(
-            draft_model,
-            draft_tokens[-1].view(1, -1),
-            orig_input_pos + speculate_k,
-        )
-        return torch.cat([draft_tokens, last_token])
-    else:
-        accept_length = rejected_locations[0].item()
-        p = draft_probs[accept_length]
-        q = target_probs[accept_length]
-        new = q - p
-        new = torch.where(new > 0, new, 0.0)
-        new = new / new.sum()
-        next_token = multinomial_sample_one_no_sync(new)
-        return torch.cat([draft_tokens[:accept_length], next_token])
-
@torch.no_grad()
def generate(
    model: Transformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    *,
    interactive: bool,
-    draft_model: Transformer,
-    speculate_k: Optional[int] = 8,
    callback = lambda x: x,
    **sampling_kwargs
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """

-    is_speculative = draft_model is not None
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(0)
    T_new = T + max_new_tokens
@@ -162,11 +109,8 @@ def generate(
        max_seq_length = min(T_new, model.config.block_size)

    device, dtype = prompt.device, prompt.dtype
-    max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
-        if is_speculative and draft_model is not model:
-            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
@@ -175,37 +119,14 @@ def generate(
    input_pos = torch.arange(0, T, device=device)

    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
-    if is_speculative:
-        prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
    seq[T] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
-    accept_counts = [0] * (speculate_k + 1)
-
-    if is_speculative:
-        input_pos = input_pos.item()  # for speculative decoding easier to keep on host
-        while input_pos < T_new - 1:
-            cur_token = next_token.view(())
-
-            next_tokens = speculative_decode(
-                model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
-            )

-            accept_counts[len(next_tokens) - 1] += 1
-            num_added = min(T_new - input_pos - 1, len(next_tokens))
-            seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[:num_added]
-            for i in next_tokens[:num_added,]:
-                callback(i)
-            input_pos = input_pos + num_added
-            next_token = next_tokens[-1]
-    else:
-        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
-        seq[T + 1:] = torch.cat(generated_tokens)
+    generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
+    seq[T + 1:] = torch.cat(generated_tokens)

-    generate_stats = {
-        'accept_counts': accept_counts
-    }
-    return seq, generate_stats
+    return seq

def encode_tokens(tokenizer, string, bos=True, device='cuda'):
    tokens = tokenizer.encode(string)
@@ -223,15 +144,6 @@ def _load_model(checkpoint_path, device, precision, use_tp):
        simple_quantizer = WeightOnlyBit8QuantHandler(model, torch.int8)
        model = simple_quantizer.convert_for_runtime()

-    if "int4" in str(checkpoint_path):
-        print("Using int4 quantization!")
-        path_comps = checkpoint_path.name.split(".")
-        assert path_comps[-2].startswith("g")
-        groupsize = int(path_comps[-2][1:])
-        from quantize import WeightOnlyInt4QuantHandler
-        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
-        model = simple_quantizer.convert_for_runtime()
-
    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    model.load_state_dict(checkpoint, assign=True)
@@ -252,12 +164,10 @@ def main(
    max_new_tokens: int = 100,
    top_k: int = 200,
    temperature: float = 0.8,
-    checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
+    checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
    compile: bool = True,
    compile_prefill: bool = False,
    profile: Optional[Path] = None,
-    draft_checkpoint_path: Optional[Path] = None,
-    speculate_k: int = 5,
    device='cuda',
) -> None:
    """Generates text samples based on a pre-trained Transformer model and tokenizer.
@@ -277,18 +187,12 @@ def main(
    print(f"Using device={device}")
    precision = torch.bfloat16
-    is_speculative = draft_checkpoint_path is not None
    is_chat = "chat" in str(checkpoint_path)

    print("Loading model ...")
    t0 = time.time()
    model = _load_model(checkpoint_path, device, precision, use_tp)

-    if is_speculative:
-        draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
-    else:
-        draft_model = None
-
    device_sync(device=device) # MKG
    print(f"Time to load model: {time.time() - t0:.02f} seconds")
@@ -299,14 +203,7 @@ def main(
    torch.manual_seed(1234)
    model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
    if compile:
-        if is_speculative and use_tp: # and ("cuda" in device):
-            torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
-        if model.config.moe:
-            torch._inductor.config.assert_indirect_indexing = False
-
-        if is_speculative:
-            global model_forward, logits_to_prob
-            model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
+        torch._inductor.config.assert_indirect_indexing = False

        global decode_one_token, prefill
        decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
@@ -318,7 +215,6 @@ def main(
    aggregate_metrics = {
        'tokens_per_sec': [],
-        'accept_counts': [],
    }
    start = -1 if compile else 0
@@ -355,18 +251,15 @@ def callback(x):
            torch.profiler._utils._init_for_cuda_graphs()
            prof = torch.profiler.profile()
        with prof:
-            y, metrics = generate(
+            y = generate(
                model,
                encoded,
                max_new_tokens,
-                draft_model=draft_model,
-                speculate_k=speculate_k,
                interactive=interactive,
                callback=callback,
                temperature=temperature,
                top_k=top_k,
            )
-        aggregate_metrics['accept_counts'].append(metrics['accept_counts'])
        if i == -1:
            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
            continue
@@ -387,12 +280,6 @@ def callback(x):
        aggregate_metrics['tokens_per_sec'].append(tokens_sec)
        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
-    print("==========")
-    if is_speculative:
-        counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
-        acceptance_probs = [i / sum(counts_aggregated) for i in counts_aggregated]
-        print(f"Acceptance probs: {acceptance_probs}")
-        print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)]) / sum(counts_aggregated)}")

    print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
@@ -412,13 +299,10 @@ def callback(x):
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
    parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
-    parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
-    parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
    parser.add_argument('--device', type=str, default="cuda", help='device to use')

    args = parser.parse_args()
    main(
        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
-        args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
-        args.speculate_k, args.device
+        args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device
    )
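
For orientation, a minimal sketch of how a caller consumes the simplified generate after this change: it no longer takes draft_model or speculate_k, and it returns only the token sequence rather than a (seq, generate_stats) tuple. The wrapper name run_once is hypothetical; model and encoded are assumed to come from _load_model and encode_tokens in the same script, as in main above.

import torch

def run_once(model, encoded: torch.Tensor, max_new_tokens: int = 100) -> torch.Tensor:
    # generate() now returns a single 1-D tensor holding the prompt followed by the
    # newly sampled tokens; there is no accept_counts bookkeeping left to unpack.
    y = generate(
        model,
        encoded,
        max_new_tokens,
        interactive=False,
        callback=lambda x: x,   # same no-op default as the signature above
        temperature=0.8,        # forwarded via **sampling_kwargs
        top_k=200,
    )
    return y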