Skip to content

Commit 9630a50

Browse files
committed
sampling : update naming
ggml-ci
1 parent f115cba commit 9630a50

File tree

17 files changed

+54 additions, -54 deletions

common/sampling.cpp

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,38 @@
22

33
#include "common.h"
44

5-
struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_params & params, const struct llama_model * model) {
5+
struct llama_sampling_context * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params) {
66
struct llama_sampling_context * result = new llama_sampling_context();
77

88
result->params = params;
99

1010
{
11-
auto lp = llama_sampling_default_params();
12-
13-
lp.seed = params.seed;
14-
lp.n_prev = params.n_prev;
15-
lp.n_probs = params.n_probs;
16-
lp.min_keep = params.min_keep;
17-
lp.top_k = params.top_k;
18-
lp.top_p = params.top_p;
19-
lp.min_p = params.min_p;
20-
lp.tfs_z = params.tfs_z;
21-
lp.typical_p = params.typical_p;
22-
lp.temp = params.temp;
23-
lp.dynatemp_range = params.dynatemp_range;
24-
lp.dynatemp_exponent = params.dynatemp_exponent;
25-
lp.penalty_last_n = params.penalty_last_n;
26-
lp.penalty_repeat = params.penalty_repeat;
27-
lp.penalty_freq = params.penalty_freq;
28-
lp.penalty_present = params.penalty_present;
29-
lp.mirostat = params.mirostat;
30-
lp.mirostat_tau = params.mirostat_tau;
31-
lp.mirostat_eta = params.mirostat_eta;
32-
lp.penalize_nl = params.penalize_nl;
33-
lp.ignore_eos = params.ignore_eos;
34-
35-
result->smpl = llama_sampling_init(model, lp);
36-
37-
llama_sampling_set_rng_seed (result->smpl, params.seed);
11+
auto lparams = llama_sampling_default_params();
12+
13+
lparams.seed = params.seed;
14+
lparams.n_prev = params.n_prev;
15+
lparams.n_probs = params.n_probs;
16+
lparams.min_keep = params.min_keep;
17+
lparams.top_k = params.top_k;
18+
lparams.top_p = params.top_p;
19+
lparams.min_p = params.min_p;
20+
lparams.tfs_z = params.tfs_z;
21+
lparams.typical_p = params.typical_p;
22+
lparams.temp = params.temp;
23+
lparams.dynatemp_range = params.dynatemp_range;
24+
lparams.dynatemp_exponent = params.dynatemp_exponent;
25+
lparams.penalty_last_n = params.penalty_last_n;
26+
lparams.penalty_repeat = params.penalty_repeat;
27+
lparams.penalty_freq = params.penalty_freq;
28+
lparams.penalty_present = params.penalty_present;
29+
lparams.mirostat = params.mirostat;
30+
lparams.mirostat_tau = params.mirostat_tau;
31+
lparams.mirostat_eta = params.mirostat_eta;
32+
lparams.penalize_nl = params.penalize_nl;
33+
lparams.ignore_eos = params.ignore_eos;
34+
35+
result->smpl = llama_sampling_init(model, lparams);
36+
3837
llama_sampling_set_grammar (result->smpl, params.grammar.c_str(), "root");
3938
llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data());
4039
}
@@ -248,7 +247,7 @@ static llama_token llama_sampling_sample(
248247
} else {
249248
sampler_queue(ctx_sampling, cur_p);
250249

251-
id = llama_sampling_sample(smpl, cur_p);
250+
id = llama_sampling_sample_dist(smpl, cur_p);
252251

253252
//{
254253
// const int n_top = 10;

common/sampling.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ struct llama_sampling_context {
6363
};
6464

6565
// Create a new sampling context instance.
66-
struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_params & params, const struct llama_model * model);
66+
struct llama_sampling_context * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params);
6767

6868
void llama_sampling_free(struct llama_sampling_context * ctx);
6969

examples/batched/batched.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ int main(int argc, char ** argv) {
187187
llama_sampling_top_p(smpl, &candidates_p);
188188
llama_sampling_temp (smpl, &candidates_p);
189189

190-
const llama_token new_token_id = llama_sampling_sample(smpl, &candidates_p);
190+
const llama_token new_token_id = llama_sampling_sample_dist(smpl, &candidates_p);
191191

192192
//const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
193193

examples/infill/infill.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ int main(int argc, char ** argv) {
345345

346346
std::vector<llama_token> embd;
347347

348-
ctx_sampling = llama_sampling_init(sparams, model);
348+
ctx_sampling = llama_sampling_init(model, sparams);
349349

350350
while (n_remain != 0 || params.interactive) {
351351
// predict

examples/llava/llava-cli.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
191191

192192
LOG_TEE("\n");
193193

194-
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, ctx_llava->model);
194+
struct llama_sampling_context * ctx_sampling = llama_sampling_init(ctx_llava->model, params->sparams);
195195
if (!ctx_sampling) {
196196
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
197197
exit(1);

examples/llava/minicpmv-cli.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ static struct llama_sampling_context * llama_init(struct llava_context * ctx_lla
238238

239239
LOG_TEE("\n");
240240

241-
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, ctx_llava->model);
241+
struct llama_sampling_context * ctx_sampling = llama_sampling_init(ctx_llava->model, params->sparams);
242242
return ctx_sampling;
243243
}
244244

examples/lookahead/lookahead.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ int main(int argc, char ** argv) {
117117
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
118118

119119
// target model sampling context
120-
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, model);
120+
struct llama_sampling_context * ctx_sampling = llama_sampling_init(model, params.sparams);
121121

122122
// verification n-grams
123123
std::vector<ngram_data> ngrams_cur(G);

examples/lookup/lookup.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ int main(int argc, char ** argv){
104104

105105
bool has_eos = false;
106106

107-
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, model);
107+
struct llama_sampling_context * ctx_sampling = llama_sampling_init(model, params.sparams);
108108

109109
std::vector<llama_token> draft;
110110

examples/main/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ int main(int argc, char ** argv) {
494494
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
495495
}
496496

497-
ctx_sampling = llama_sampling_init(sparams, model);
497+
ctx_sampling = llama_sampling_init(model, sparams);
498498
if (!ctx_sampling) {
499499
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
500500
exit(1);

examples/parallel/parallel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
161161
for (size_t i = 0; i < clients.size(); ++i) {
162162
auto & client = clients[i];
163163
client.id = i;
164-
client.ctx_sampling = llama_sampling_init(params.sparams, model);
164+
client.ctx_sampling = llama_sampling_init(model, params.sparams);
165165
}
166166

167167
std::vector<llama_token> tokens_system;

examples/save-load-state/save-load-state.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ int main(int argc, char ** argv) {
7878
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
7979
}
8080
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
81-
auto next_token = llama_sampling_sample(smpl, &candidates_p);
81+
auto next_token = llama_sampling_sample_dist(smpl, &candidates_p);
8282
auto next_token_str = llama_token_to_piece(ctx, next_token);
8383

8484
printf("%s", next_token_str.c_str());
@@ -139,7 +139,7 @@ int main(int argc, char ** argv) {
139139
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
140140
}
141141
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
142-
auto next_token = llama_sampling_sample(smpl2, &candidates_p);
142+
auto next_token = llama_sampling_sample_dist(smpl2, &candidates_p);
143143
auto next_token_str = llama_token_to_piece(ctx2, next_token);
144144

145145
printf("%s", next_token_str.c_str());
@@ -232,7 +232,7 @@ int main(int argc, char ** argv) {
232232
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
233233
}
234234
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
235-
auto next_token = llama_sampling_sample(smpl3, &candidates_p);
235+
auto next_token = llama_sampling_sample_dist(smpl3, &candidates_p);
236236
auto next_token_str = llama_token_to_piece(ctx3, next_token);
237237

238238
printf("%s", next_token_str.c_str());

examples/server/server.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1058,7 +1058,7 @@ struct server_context {
10581058
llama_sampling_free(slot.ctx_sampling);
10591059
}
10601060

1061-
slot.ctx_sampling = llama_sampling_init(slot.sparams, model);
1061+
slot.ctx_sampling = llama_sampling_init(model, slot.sparams);
10621062
if (slot.ctx_sampling == nullptr) {
10631063
// for now, the only error that may happen here is invalid grammar
10641064
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);

examples/speculative/speculative.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ int main(int argc, char ** argv) {
176176
bool has_eos = false;
177177

178178
// target model sampling context (reuse the llama_context's sampling instance)
179-
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, model_tgt);
179+
struct llama_sampling_context * ctx_sampling = llama_sampling_init(model_tgt, params.sparams);
180180

181181
// draft sequence data
182182
std::vector<seq_draft> drafts(n_seq_dft);
@@ -187,7 +187,7 @@ int main(int argc, char ** argv) {
187187

188188
for (int s = 0; s < n_seq_dft; ++s) {
189189
// allocate llama_sampling for each draft sequence
190-
drafts[s].ctx_sampling = llama_sampling_init(params.sparams, model_dft);
190+
drafts[s].ctx_sampling = llama_sampling_init(model_dft, params.sparams);
191191
}
192192

193193
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
@@ -334,7 +334,7 @@ int main(int argc, char ** argv) {
334334
// all drafted tokens were rejected
335335
// sample from the target model
336336
LOG("all drafted tokens were rejected, sampling from residual distribution\n");
337-
token_id = llama_sampling_sample(ctx_sampling->smpl, &dist_tgt);
337+
token_id = llama_sampling_sample_dist(ctx_sampling->smpl, &dist_tgt);
338338
llama_sampling_accept(ctx_sampling, token_id, true);
339339
token_str = llama_token_to_piece(ctx_tgt, token_id);
340340
}

include/llama.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,8 +1086,8 @@ extern "C" {
10861086
struct llama_sampling * smpl,
10871087
llama_token_data_array * candidates);
10881088

1089-
/// @details Randomly selects a token from the candidates based on their probabilities
1090-
LLAMA_API llama_token llama_sampling_sample(
1089+
/// @details Randomly selects a token from the candidates based on their probability distribution.
1090+
LLAMA_API llama_token llama_sampling_sample_dist(
10911091
struct llama_sampling * smpl,
10921092
llama_token_data_array * candidates);
10931093

src/llama-sampling.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ llama_token llama_sampling_sample_mirostat_impl(struct llama_token_data_array *
532532

533533
// Sample the next word X using top-k sampling
534534
llama_sampling_top_k_impl(candidates, int(k), 1);
535-
llama_token X = llama_sampling_sample_impl(candidates, rng);
535+
llama_token X = llama_sampling_sample_dist_impl(candidates, rng);
536536

537537
// Compute error as the difference between observed surprise and target surprise value
538538
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
@@ -563,7 +563,7 @@ llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array
563563
llama_sampling_softmax_impl(candidates);
564564

565565
// Sample the next word X from the remaining words
566-
llama_token X = llama_sampling_sample_impl(candidates, rng);
566+
llama_token X = llama_sampling_sample_dist_impl(candidates, rng);
567567

568568
// Compute error as the difference between observed surprise and target surprise value
569569
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
@@ -589,18 +589,19 @@ llama_token llama_sampling_sample_greedy_impl(llama_token_data_array * candidate
589589
return result;
590590
}
591591

592-
llama_token llama_sampling_sample_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) {
592+
llama_token llama_sampling_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) {
593593
llama_sampling_softmax_impl(candidates);
594594

595595
std::vector<float> probs;
596596
probs.reserve(candidates->size);
597+
597598
for (size_t i = 0; i < candidates->size; ++i) {
598599
probs.push_back(candidates->data[i].p);
599600
}
600601

601602
std::discrete_distribution<> dist(probs.begin(), probs.end());
602-
int idx = dist(rng);
603603

604+
const int idx = dist(rng);
604605
llama_token result = candidates->data[idx].id;
605606

606607
return result;

src/llama-sampling.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ llama_token llama_sampling_sample_mirostat_impl (struct llama_token_data_array
9797
llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu);
9898

9999
llama_token llama_sampling_sample_greedy_impl (struct llama_token_data_array * candidates);
100-
llama_token llama_sampling_sample_impl (struct llama_token_data_array * candidates, std::mt19937 & rng);
100+
llama_token llama_sampling_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng);
101101

102102
void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, bool apply_grammar);
103103

src/llama.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20260,10 +20260,10 @@ llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_tok
2026020260
return res;
2026120261
}
2026220262

20263-
llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) {
20263+
llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2026420264
time_meas tm(smpl->t_sample_us);
2026520265

20266-
auto res = llama_sampling_sample_impl(candidates, smpl->rng);
20266+
auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng);
2026720267

2026820268
smpl->n_sample++;
2026920269

0 commit comments

Comments (0)