@@ -230,8 +230,8 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
         }
     }
-    fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
-        params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+    fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n",
+        params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp);
     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     fprintf(stderr, "\n\n");
@@ -304,23 +304,69 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             // out of user input, sample next token
-            const int32_t top_k = params.top_k;
-            const float top_p = params.top_p;
             const float temp = params.temp;
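+            // top_k <= 0 falls back to the full vocabulary; repeat_last_n < 0 penalizes over the entire context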
+            const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+            const float top_p = params.top_p;
+            const float tfs_z = params.tfs_z;
+            const float typical_p = params.typical_p;
+            const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
             const float repeat_penalty = params.repeat_penalty;
+            const float alpha_presence = params.alpha_presence;
+            const float alpha_frequency = params.alpha_frequency;
 
             llama_token id = 0;
 
             {
                 auto logits = llama_get_logits(ctx);
+                auto n_vocab = llama_n_vocab(ctx);
 
                 if (params.ignore_eos) {
-                    logits[llama_token_eos()] = 0;
+                    logits[llama_token_eos()] = -INFINITY;
+                }
+
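+                // collect every vocabulary token together with its logit as a sampling candidate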
+                std::vector<llama_token_data> candidates;
+                candidates.reserve(n_vocab);
+                for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                    candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
                 }
 
-                id = llama_sample_top_p_top_k(ctx,
-                        last_n_tokens.data() + n_ctx - params.repeat_last_n,
-                        params.repeat_last_n, top_k, top_p, temp, repeat_penalty);
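+                // wrap the candidates in a llama_token_data_array so the samplers can filter it in place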
+                llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+                // Apply penalties
+                auto last_n_repeat = std::min(std::min((int) last_n_tokens.size(), repeat_last_n), n_ctx);
+                llama_sample_repetition_penalty(&candidates_p,
+                        last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                        last_n_repeat, repeat_penalty);
+                llama_sample_frequency_and_presence_penalties(&candidates_p,
+                        last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                        last_n_repeat, alpha_frequency, alpha_presence);
+
+#if 1
+                if (temp <= 0) {
+                    // Greedy sampling: always pick the most likely token
+                    id = llama_sample_token_greedy(ctx, &candidates_p);
+                } else {
+                    // Temperature sampling: each filter below narrows candidates_p in place
+                    llama_sample_top_k(&candidates_p, top_k);
+                    llama_sample_tail_free(&candidates_p, tfs_z);
+                    llama_sample_typical(&candidates_p, typical_p);
+                    llama_sample_top_p(&candidates_p, top_p);
+
+                    llama_sample_temperature(&candidates_p, temp);
+                    // printf("`%d`", candidates_p.size);
+                    id = llama_sample_token(ctx, &candidates_p);
+                }
+#else
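+                // Mirostat sampling (disabled branch), following the Mirostat paper:
+                // tau = target surprise, eta = learning rate, m = number of top
+                // candidates used to estimate the distribution; mu is the adaptive
+                // surprise threshold, updated after each sampled token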
+                const float tau = 5.0f;
+                static float mu = 2.0f * tau;
+                static int k = 40;
+                const float eta = 0.1f;
+                const int m = 100;
+                const float N = n_vocab;
+                id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
+                // id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu);
+#endif
 
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);