Skip to content

Commit 9c78250

Browse files
Sample interface, new samplers.
New samplers: locally typical sampling, tail free sampling, frequency and presence penalties, and Mirostat. Ignore-EOS fix: the EOS logit should be set to -inf rather than 0.
1 parent c50b628 commit 9c78250

File tree

7 files changed

+450
-131
lines changed

7 files changed

+450
-131
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
7575
# Compile flags
7676
#
7777

78-
set(CMAKE_CXX_STANDARD 11)
78+
set(CMAKE_CXX_STANDARD 20)
7979
set(CMAKE_CXX_STANDARD_REQUIRED true)
8080
set(CMAKE_C_STANDARD 11)
8181
set(CMAKE_C_STANDARD_REQUIRED true)

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ endif
3535

3636
# keep standard at C11 and C++11
3737
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
38-
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
38+
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++20 -fPIC
3939
LDFLAGS =
4040

4141
# warnings

examples/common.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
108108
break;
109109
}
110110
params.temp = std::stof(argv[i]);
111+
} else if (arg == "--tfs") {
112+
if (++i >= argc) {
113+
invalid_param = true;
114+
break;
115+
}
116+
params.tfs_z = std::stof(argv[i]);
117+
} else if (arg == "--typical") {
118+
if (++i >= argc) {
119+
invalid_param = true;
120+
break;
121+
}
122+
params.typical_p = std::stof(argv[i]);
111123
} else if (arg == "--repeat_last_n") {
112124
if (++i >= argc) {
113125
invalid_param = true;
@@ -120,6 +132,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
120132
break;
121133
}
122134
params.repeat_penalty = std::stof(argv[i]);
135+
} else if (arg == "--alpha_frequency") {
136+
if (++i >= argc) {
137+
invalid_param = true;
138+
break;
139+
}
140+
params.alpha_frequency = std::stof(argv[i]);
141+
} else if (arg == "--alpha_presence") {
142+
if (++i >= argc) {
143+
invalid_param = true;
144+
break;
145+
}
146+
params.alpha_presence = std::stof(argv[i]);
123147
} else if (arg == "-b" || arg == "--batch_size") {
124148
if (++i >= argc) {
125149
invalid_param = true;
@@ -237,6 +261,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
237261
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
238262
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
239263
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p);
264+
fprintf(stderr, " --tfs N tail free sampling (default: %.1f)\n", (double)params.tfs_z);
265+
fprintf(stderr, " --typical N locally typical sampling (default: %.1f)\n", (double)params.typical_p);
266+
fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %d)\n", params.alpha_presence);
267+
fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f)\n", (double)params.alpha_frequency);
240268
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
241269
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
242270
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);

examples/main/main.cpp

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,8 @@ int main(int argc, char ** argv) {
230230
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
231231
}
232232
}
233-
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
234-
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
233+
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n",
234+
params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp);
235235
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
236236
fprintf(stderr, "\n\n");
237237

@@ -304,23 +304,69 @@ int main(int argc, char ** argv) {
304304

305305
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
306306
// out of user input, sample next token
307-
const int32_t top_k = params.top_k;
308-
const float top_p = params.top_p;
309307
const float temp = params.temp;
308+
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
309+
const float top_p = params.top_p;
310+
const float tfs_z = params.tfs_z;
311+
const float typical_p = params.typical_p;
312+
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
310313
const float repeat_penalty = params.repeat_penalty;
314+
const float alpha_presence = params.alpha_presence;
315+
const float alpha_frequency = params.alpha_frequency;
311316

312317
llama_token id = 0;
313318

314319
{
315320
auto logits = llama_get_logits(ctx);
321+
auto n_vocab = llama_n_vocab(ctx);
316322

317323
if (params.ignore_eos) {
318-
logits[llama_token_eos()] = 0;
324+
logits[llama_token_eos()] = -INFINITY;
325+
}
326+
327+
std::vector<llama_token_data> candidates;
328+
candidates.reserve(n_vocab);
329+
for (size_t i = 0; i < n_vocab; i++) {
330+
candidates.emplace_back(i, logits[i], 0.0f);
319331
}
320332

321-
id = llama_sample_top_p_top_k(ctx,
322-
last_n_tokens.data() + n_ctx - params.repeat_last_n,
323-
params.repeat_last_n, top_k, top_p, temp, repeat_penalty);
333+
llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
334+
335+
// Apply penalties
336+
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
337+
llama_sample_repetition_penalty(&candidates_p,
338+
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
339+
last_n_repeat, repeat_penalty);
340+
llama_sample_frequency_and_presence_penalties(&candidates_p,
341+
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
342+
last_n_repeat, alpha_frequency, alpha_presence);
343+
344+
345+
#if 1
346+
if (temp <= 0) {
347+
// Greedy sampling
348+
id = llama_sample_token_greedy(ctx, &candidates_p);
349+
} else {
350+
// Temperature sampling
351+
llama_sample_top_k(&candidates_p, top_k);
352+
llama_sample_tail_free(&candidates_p, tfs_z);
353+
llama_sample_typical(&candidates_p, typical_p);
354+
llama_sample_top_p(&candidates_p, top_p);
355+
356+
llama_sample_temperature(&candidates_p, temp);
357+
// printf("`%d`", candidates_p.size);
358+
id = llama_sample_token(ctx, &candidates_p);
359+
}
360+
#else
361+
const float tau = 5.0f;
362+
static float mu = 2.0f * tau;
363+
static int k = 40;
364+
const float eta = 0.1f;
365+
const int m = 100;
366+
const float N = n_vocab;
367+
id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
368+
// id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu);
369+
#endif
324370

325371
last_n_tokens.erase(last_n_tokens.begin());
326372
last_n_tokens.push_back(id);

0 commit comments

Comments
 (0)