
Commit f115cba

sampling : drop cfg + simplify more
ggml-ci
1 parent 6e49744 commit f115cba

16 files changed: +90 -282 lines changed


common/common.cpp

Lines changed: 2 additions & 34 deletions
@@ -310,7 +310,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         string_process_escapes(params.prompt);
         string_process_escapes(params.input_prefix);
         string_process_escapes(params.input_suffix);
-        string_process_escapes(sparams.cfg_negative_prompt);
         for (auto & antiprompt : params.antiprompt) {
             string_process_escapes(antiprompt);
         }
@@ -321,8 +320,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         params.kv_overrides.back().key[0] = 0;
     }

-    if (params.sparams.seed == LLAMA_DEFAULT_SEED) {
-        params.sparams.seed = time(NULL);
+    if (sparams.seed == LLAMA_DEFAULT_SEED) {
+        sparams.seed = time(NULL);
     }

     return true;
@@ -665,30 +664,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         sparams.mirostat_tau = std::stof(argv[i]);
         return true;
     }
-    if (arg == "--cfg-negative-prompt") {
-        CHECK_ARG
-        sparams.cfg_negative_prompt = argv[i];
-        return true;
-    }
-    if (arg == "--cfg-negative-prompt-file") {
-        CHECK_ARG
-        std::ifstream file(argv[i]);
-        if (!file) {
-            fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
-            invalid_param = true;
-            return true;
-        }
-        std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(sparams.cfg_negative_prompt));
-        if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') {
-            sparams.cfg_negative_prompt.pop_back();
-        }
-        return true;
-    }
-    if (arg == "--cfg-scale") {
-        CHECK_ARG
-        sparams.cfg_scale = std::stof(argv[i]);
-        return true;
-    }
     if (arg == "-b" || arg == "--batch-size") {
         CHECK_ARG
         params.n_batch = std::stoi(argv[i]);
@@ -1577,11 +1552,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n"
         "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
         "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
-    options.push_back({ "main", " --cfg-negative-prompt PROMPT",
-        "negative prompt to use for guidance (default: '%s')", sparams.cfg_negative_prompt.c_str() });
-    options.push_back({ "main", " --cfg-negative-prompt-file FNAME",
-        "negative prompt file to use for guidance" });
-    options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });
     options.push_back({ "main", " --chat-template JINJA_TEMPLATE",
         "set custom jinja chat template (default: template taken from model's metadata)\n"
         "if suffix/prefix are specified, template will be disabled\n"
@@ -3258,8 +3228,6 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l

     fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str());
     fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch);
-    yaml_dump_string_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str());
-    fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale);
     fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
     fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);

common/sampling.cpp

Lines changed: 6 additions & 87 deletions
@@ -29,7 +29,6 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa
     lp.mirostat = params.mirostat;
     lp.mirostat_tau = params.mirostat_tau;
     lp.mirostat_eta = params.mirostat_eta;
-    lp.cfg_scale = params.cfg_scale;
     lp.penalize_nl = params.penalize_nl;
     lp.ignore_eos = params.ignore_eos;

@@ -51,9 +50,6 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {

 void llama_sampling_reset(llama_sampling_context * ctx) {
     llama_sampling_reset(ctx->smpl);
-
-    ctx->cur.clear();
-    ctx->org.clear();
 }

 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
@@ -219,61 +215,11 @@ static void sampler_queue(
     }
 }

-llama_token_data_array llama_sampling_prepare(
+void llama_sampling_prepare(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
         int idx) {
-    const gpt_sampling_params & params = ctx_sampling->params;
-
-    auto & cur = ctx_sampling->cur;
-
-    // Get a pointer to the logits
-    float * logits = llama_get_logits_ith(ctx_main, idx);
-
-    // apply params.logit_bias map
-    for (const auto & logit_bias : params.logit_bias) {
-        logits[logit_bias.token] += logit_bias.bias;
-    }
-
-    if (params.ignore_eos) {
-        logits[llama_token_eos(llama_get_model(ctx_main))] = -INFINITY;
-    }
-
-    llama_sampling * smpl = ctx_sampling->smpl;
-
-    if (ctx_cfg) {
-        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
-        llama_sampling_cfg(smpl, logits, logits_guidance);
-    }
-
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
-
-    cur.resize(n_vocab);
-
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
-
-    // apply penalties
-    {
-        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
-
-        llama_sampling_penalties(smpl, &cur_p);
-
-        if (!params.penalize_nl) {
-            for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
-                    cur_p.data[idx].logit = nl_logit;
-                    break;
-                }
-            }
-        }
-    }
-
-    return cur_p;
+    llama_sampling_set_logits(ctx_sampling->smpl, llama_get_logits_ith(ctx_main, idx));
 }

 static llama_token llama_sampling_sample(
@@ -325,41 +271,14 @@ static llama_token llama_sampling_sample(
 llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
         int idx) {
-    llama_token_data_array cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx);
-
-    if (ctx_sampling->params.grammar.empty()) {
-        return llama_sampling_sample(ctx_sampling, &cur_p);
-    }
+    llama_sampling_prepare(ctx_sampling, ctx_main, idx);

-    // TODO: this logic is confusing, try to figure out a better way to handle this
+    auto * cur_p = llama_sampling_get_candidates(ctx_sampling->smpl);

-    // store the original candidates
-    ctx_sampling->org = ctx_sampling->cur;
-    llama_token_data_array org_p = { ctx_sampling->org.data(), ctx_sampling->org.size(), false };
+    llama_sampling_grammar(ctx_sampling->smpl, cur_p);

-    llama_token id = llama_sampling_sample(ctx_sampling, &cur_p);
-
-    // Create an array with a single token data element for the sampled id
-    llama_token_data single_token_data = { id, 1.0f, 0.0f };
-    llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
-
-    // Apply grammar constraints to the single token
-    llama_sampling_grammar(ctx_sampling->smpl, &single_token_data_array);
-
-    // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
-    const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
-
-    if (!is_valid) {
-        llama_sampling_grammar(ctx_sampling->smpl, &org_p);
-
-        id = llama_sampling_sample(ctx_sampling, &org_p);
-
-        ctx_sampling->cur = std::move(ctx_sampling->org);
-    }
-
-    return id;
+    return llama_sampling_sample(ctx_sampling, cur_p);
 }

 void llama_sampling_accept(
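With the cur/org buffers removed from llama_sampling_context, code that used to build its own llama_token_data_array from ctx_sampling->cur now has to go through the sampler object instead. A minimal sketch of that pattern, using only calls touched by this commit and assuming llama_sampling_get_candidates returns a llama_token_data_array pointer (the candidate-inspection loop is illustrative, not part of the commit):

    // after llama_decode() has produced logits for the batch:
    llama_sampling_prepare(ctx_sampling, ctx, 0);  // feeds the logits at index 0 into ctx_sampling->smpl

    // the candidate array is now owned by the llama_sampling object, not by the common context
    llama_token_data_array * cur_p = llama_sampling_get_candidates(ctx_sampling->smpl);

    // optionally constrain the candidates with the grammar before picking a token
    if (!ctx_sampling->params.grammar.empty()) {
        llama_sampling_grammar(ctx_sampling->smpl, cur_p);
    }

    // illustrative only: peek at the first few candidates
    for (size_t i = 0; i < cur_p->size && i < 5; ++i) {
        fprintf(stderr, "candidate %zu: token %d, logit %.3f\n", i, (int) cur_p->data[i].id, cur_p->data[i].logit);
    }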

common/sampling.h

Lines changed: 2 additions & 13 deletions
@@ -50,11 +50,6 @@ typedef struct gpt_sampling_params {

     std::string grammar; // optional BNF-like grammar to constrain sampling

-    // Classifier-Free Guidance
-    // https://arxiv.org/abs/2306.17806
-    std::string cfg_negative_prompt; // string to help guidance
-    float cfg_scale = 1.f; // how strong is guidance
-
     std::vector<llama_logit_bias> logit_bias; // logit biases to apply
 } gpt_sampling_params;

@@ -65,9 +60,6 @@ struct llama_sampling_context {
     gpt_sampling_params params;

     llama_sampling * smpl;
-
-    std::vector<llama_token_data> cur;
-    std::vector<llama_token_data> org;
 };

 // Create a new sampling context instance.
@@ -101,11 +93,10 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
 std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string);

 // Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
-llama_token_data_array llama_sampling_prepare(
+void llama_sampling_prepare(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        int idx = 0);
+        int idx);

 // this is a common sampling function used across the examples for convenience
 // it can serve as a starting point for implementing your own sampling function
@@ -117,7 +108,6 @@ llama_token_data_array llama_sampling_prepare(
 // - ctx_sampling: sampling-specific context
 //
 // optional:
-// - ctx_cfg: context to use for classifier-free guidance
 // - idx: sample from llama_get_logits_ith(ctx, idx)
 //
 // returns:
@@ -131,7 +121,6 @@ llama_token_data_array llama_sampling_prepare(
 llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
         int idx = -1);

 void llama_sampling_accept(
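On the caller side the change is mechanical: every call site simply drops the ctx_cfg argument, as the example diffs below show. A sketch of the common per-token loop with the updated signatures, assuming ctx_sampling and ctx have already been set up and llama_decode has been called:

    // pick the next token from the last set of logits (idx defaults to -1)
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx);

    // update repetition-penalty and grammar state with the accepted token
    llama_sampling_accept(ctx_sampling, id, true);

    if (llama_token_is_eog(llama_get_model(ctx), id)) {
        // end of generation
    }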

examples/infill/infill.cpp

Lines changed: 1 addition & 1 deletion
@@ -417,7 +417,7 @@ int main(int argc, char ** argv) {
         embd.clear();

         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
-            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
+            const llama_token id = llama_sampling_sample(ctx_sampling, ctx);

             llama_sampling_accept(ctx_sampling, id, true);

examples/llava/llava-cli.cpp

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n
 static const char * sample(struct llama_sampling_context * ctx_sampling,
                            struct llama_context * ctx_llama,
                            int * n_past) {
-    const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
+    const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama);
     llama_sampling_accept(ctx_sampling, id, true);
     static std::string ret;
     if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {

examples/llava/minicpmv-cli.cpp

Lines changed: 1 addition & 1 deletion
@@ -166,7 +166,7 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
 static const char * sample(struct llama_sampling_context * ctx_sampling,
                            struct llama_context * ctx_llama,
                            int * n_past) {
-    const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
+    const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama);
     llama_sampling_accept(ctx_sampling, id, true);
     static std::string ret;
     if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {

examples/lookahead/lookahead.cpp

Lines changed: 3 additions & 3 deletions
@@ -158,7 +158,7 @@ int main(int argc, char ** argv) {

         // sample first token
         {
-            id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
+            id = llama_sampling_sample(ctx_sampling, ctx, 0);

             llama_sampling_accept(ctx_sampling, id, true);

@@ -283,7 +283,7 @@ int main(int argc, char ** argv) {
             }

             // sample the next token
-            id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
+            id = llama_sampling_sample(ctx_sampling, ctx, i_batch);

             llama_sampling_accept(ctx_sampling, id, true);

@@ -360,7 +360,7 @@ int main(int argc, char ** argv) {
                 if (v == 0) {
                     // sample from the last level
                     for (int i = 0; i < W; i++) {
-                        tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
+                        tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
                     }
                 } else {
                     for (int i = 0; i < W; i++) {

examples/lookup/lookup.cpp

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@ int main(int argc, char ** argv){
     int i_dft = 0;
     while (true) {
         // sample from the target model
-        llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
+        llama_token id = llama_sampling_sample(ctx_sampling, ctx, i_dft);

         llama_sampling_accept(ctx_sampling, id, true);

0 commit comments
