@@ -120,7 +120,6 @@ int main(int argc, char ** argv) {
         }
     }
 
-
     // Tokenize the prompt
     std::vector<llama_token> inp;
     inp = common_tokenize(ctx_tgt, params.prompt, true, true);
@@ -139,18 +138,6 @@ int main(int argc, char ** argv) {
         LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
     }
 
-    const int n_input = inp.size();
-
-    const auto t_enc_start = ggml_time_us();
-
-    // eval the prompt
-    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1));
-
-    // note: keep the last token separate!
-    llama_token id_last = inp.back();
-
-    int n_past = inp.size() - 1;
-
     // how many tokens to draft each time
     int n_draft = params.n_draft;
 
@@ -161,9 +148,25 @@ int main(int argc, char ** argv) {
     // used to determine end of generation
     bool has_eos = false;
 
+    // ================================================
+    // everything until here is standard initialization
+    // the relevant stuff for speculative decoding starts here
+
+    const int n_input = inp.size();
+
+    const auto t_enc_start = ggml_time_us();
+
     // target model sampling context
     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
 
+    // eval the prompt
+    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1));
+
+    // note: keep the last token separate!
+    llama_token id_last = inp.back();
+
+    int n_past = inp.size() - 1;
+
     // init the speculator
     struct common_speculative_params params_spec;
     params_spec.n_draft = n_draft;
@@ -174,6 +177,13 @@ int main(int argc, char ** argv) {
     struct common_speculative * spec = common_speculative_init(params_spec);
 
     // feed the prompt to the speculator
+    //
+    // this has to be kept synchronized with the target context
+    //
+    // TODO: simplify this by moving the context management logic in the common_speculative instance
+    // for example, the common_speculative_add_draft can pass the entire context (or part of it) and the
+    // speculator will automatically compute any new tokens that are not present in its context
+    //
     common_speculative_set_prompt(spec, inp.data(), n_input - 1);
 
     llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
@@ -188,23 +198,41 @@ int main(int argc, char ** argv) {
         common_batch_add(batch_tgt, id_last, n_past, { 0 }, true);
 
         // optionally, append draft tokens to the target batch
+        //
+        // this is the most important part of the speculation. the more probable tokens that are provided here
+        // the better the performance will be. in theory, this computation can be performed asynchronously and even
+        // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
+        // from a cache or lookup tables.
+        //
         common_speculative_add_draft(spec, batch_tgt, id_last, n_past);
 
-        // evaluate the target model on the drafted tokens
+        // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
         {
             // LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
 
             llama_decode(ctx_tgt, batch_tgt);
         }
 
-        // process the full target batch and return the accepted token based on the target sampler
+        // sample from the full target batch and return the accepted tokens based on the target sampler
+        //
+        // for each token to be accepted, the sampler would have to sample that same token
+        // in such cases, instead of decoding the sampled token as we normally do, we simply continue with the
+        // available logits from the batch and sample the next token until we run out of logits or the sampler
+        // disagrees with the draft
+        //
         const auto ids = common_speculative_sample(spec, smpl, ctx_tgt);
 
+        GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
+
         n_past    += ids.size();
         n_drafted += batch_tgt.n_tokens - 1;
         n_accept  += ids.size() - 1;
 
         // process the accepted tokens and update contexts
+        //
+        // this is the standard token post-processing that we normally do
+        // in this case, we do it for a group of accepted tokens at once
+        //
         {
             llama_token id;
             std::string token_str;
@@ -232,7 +260,7 @@ int main(int argc, char ** argv) {
                 break;
             }
 
-            LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
+            LOG_DBG("accepted %d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, id, token_str.c_str());
 
             {
                 LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
@@ -241,6 +269,7 @@ int main(int argc, char ** argv) {
                 llama_kv_cache_seq_rm(ctx_dft, 0, n_past, -1);
             }
 
+            // remember the last accepted token for the next iteration
            id_last = id;
         }
     }
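
For orientation, the generation loop that these hunks build up can be condensed as follows. This is a sketch distilled from the diff above, not the literal file contents: it assumes ctx_tgt, ctx_dft, model_tgt, spec, smpl and inp are initialized as in the diff, the batch clearing (common_batch_clear) and the llama_token_is_eog stopping check are not visible in the hunks and are assumptions, and timing, output and error handling are omitted.

    // condensed sketch of the speculative decoding loop shown in the diff above
    const int n_input = inp.size();

    // eval the prompt on the target model; keep the last prompt token separate
    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1));

    llama_token id_last = inp.back();
    int         n_past  = n_input - 1;

    // the speculator is fed the same prompt, minus the last token
    common_speculative_set_prompt(spec, inp.data(), n_input - 1);

    llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);

    while (true) {
        common_batch_clear(batch_tgt); // assumed, not shown in the hunks

        // batch = [id_last, draft0, draft1, ..., draftN-1]
        common_batch_add(batch_tgt, id_last, n_past, { 0 }, true);
        common_speculative_add_draft(spec, batch_tgt, id_last, n_past);

        // one target pass produces logits for id_last and every drafted position
        llama_decode(ctx_tgt, batch_tgt);

        // accept drafted tokens for as long as the target sampler agrees;
        // at least one token (the one sampled after id_last) is always returned
        const auto ids = common_speculative_sample(spec, smpl, ctx_tgt);

        n_past += ids.size();

        // drop the rejected tail of the draft from both KV caches
        llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
        llama_kv_cache_seq_rm(ctx_dft, 0, n_past, -1);

        // the last accepted token seeds the next iteration
        id_last = ids.back();

        if (llama_token_is_eog(model_tgt, id_last)) { // assumed stopping condition
            break;
        }
    }

The n_drafted and n_accept counters kept by the example give the draft acceptance rate (n_accept / n_drafted), which is the main knob to watch when tuning n_draft.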