
Commit 55e4778

llama : default sampling changes + greedy update (#9897)
* llama : deprecate softmax sampler + fix dist sampler (ggml-ci)
* tests : replace macros with functions (ggml-ci)
* sampling : change temperature sampler logic (for t <= 0.0f, keep the max logit intact and set the rest to -inf)
* cont : no need for special "greedy" logic (top-k == 1 is the same)
* tests : init prob correctly
* llama : handle temp <= 0.0 in the temp_ext sampler too (ggml-ci)
* cont : avoid extra loop in temperature sampler for sub-zero temp (ggml-ci)
1 parent bc21975 commit 55e4778
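The practical upshot for callers: greedy decoding no longer needs a dedicated sampler path. With t <= 0.0f the temperature sampler keeps only the max logit (everything else becomes -inf), and the dist sampler, which now normalizes the candidates internally, then picks that token. Below is a minimal, hedged sketch of that usage via the public sampler-chain API from llama.h; the surrounding model/context setup is omitted and the seed value is just a placeholder.

```cpp
// Hedged sketch: "greedy" after this commit is just temp <= 0.0f (or top-k == 1)
// followed by the dist sampler. Assumes a valid llama_context from the usual setup.
#include "llama.h"

static llama_token sample_greedy_example(llama_context * ctx) {
    auto sparams = llama_sampler_chain_default_params();
    llama_sampler * chain = llama_sampler_chain_init(sparams);

    // temp <= 0.0f keeps only the max logit and sets the rest to -inf ...
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.0f));
    // ... so dist (which now applies softmax internally) selects that token with probability 1
    llama_sampler_chain_add(chain, llama_sampler_init_dist(1234));

    const llama_token id = llama_sampler_sample(chain, ctx, -1);

    llama_sampler_free(chain);
    return id;
}
```

Using llama_sampler_init_top_k(1) in place of the temperature sampler gives the same result, which is why the separate greedy branch below could be removed.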

File tree: 7 files changed (+201, -217 lines)

common/sampling.cpp

Lines changed: 37 additions & 51 deletions

```diff
@@ -171,60 +171,46 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             params.penalize_nl,
             params.ignore_eos));
 
-    if (params.temp > 0.0f) {
-        if (params.mirostat == 0) {
-            for (const auto & cnstr : params.samplers) {
-                switch (cnstr) {
-                    case COMMON_SAMPLER_TYPE_TOP_K:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
-                        break;
-                    case COMMON_SAMPLER_TYPE_TOP_P:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
-                        break;
-                    case COMMON_SAMPLER_TYPE_MIN_P:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
-                        break;
-                    case COMMON_SAMPLER_TYPE_XTC:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
-                        break;
-                    case COMMON_SAMPLER_TYPE_TFS_Z:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
-                        break;
-                    case COMMON_SAMPLER_TYPE_TYPICAL_P:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
-                        break;
-                    case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
-                        break;
-                    case COMMON_SAMPLER_TYPE_INFILL:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
-                        break;
-                    default:
-                        GGML_ASSERT(false && "unknown sampler type");
-                }
+    if (params.mirostat == 0) {
+        for (const auto & cnstr : params.samplers) {
+            switch (cnstr) {
+                case COMMON_SAMPLER_TYPE_TOP_K:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
+                    break;
+                case COMMON_SAMPLER_TYPE_TOP_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_MIN_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_XTC:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                    break;
+                case COMMON_SAMPLER_TYPE_TFS_Z:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_TYPICAL_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_TEMPERATURE:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                    break;
+                case COMMON_SAMPLER_TYPE_INFILL:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
+                    break;
+                default:
+                    GGML_ASSERT(false && "unknown sampler type");
             }
-            llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
-            llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
-        } else if (params.mirostat == 1) {
-            llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
-        } else if (params.mirostat == 2) {
-            llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
-        } else {
-            GGML_ASSERT(false && "unknown mirostat version");
         }
+        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+    } else if (params.mirostat == 1) {
+        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+    } else if (params.mirostat == 2) {
+        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
     } else {
-        if (params.n_probs > 0) {
-            // some use cases require to sample greedily, but still obtain the probabilities of the top tokens
-            // ref: https://github.com/ggerganov/llama.cpp/pull/9605
-            //
-            // the following will not produce exactly the same probs as applyging softmax to the full vocabulary, but
-            // it is much faster, since we avoid sorting all tokens and should give a good approximation
-            llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
-            llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
-        }
-        llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
+        GGML_ASSERT(false && "unknown mirostat version");
     }
 
     return result;
```
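For orientation, here is a hedged, hand-written sketch of roughly what the default (mirostat == 0) chain above ends up containing, with placeholder parameter values and with the logit-bias and penalty samplers from earlier in the function omitted. It illustrates the ordering, not the literal common_sampler_init code.

```cpp
// Hedged sketch of the default chain ordering after this commit; values are placeholders.
#include "llama.h"

static llama_sampler * build_default_chain_example(uint32_t seed) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_top_k   (40));               // params.top_k
    llama_sampler_chain_add(chain, llama_sampler_init_top_p   (0.95f, 1));         // params.top_p, params.min_keep
    llama_sampler_chain_add(chain, llama_sampler_init_min_p   (0.05f, 1));         // params.min_p, params.min_keep
    llama_sampler_chain_add(chain, llama_sampler_init_temp_ext(0.8f, 0.0f, 1.0f)); // temp, dynatemp_range, dynatemp_exponent

    // the chain now always ends with dist, which normalizes (softmax) internally,
    // so the explicit llama_sampler_init_softmax() step is gone
    llama_sampler_chain_add(chain, llama_sampler_init_dist(seed));

    return chain;
}
```

Because the old temp > 0.0f guard is gone, the same chain also covers greedy decoding: with temp <= 0.0f the temperature sampler leaves only the max logit, and dist then selects that token deterministically.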

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

Lines changed: 0 additions & 1 deletion

```diff
@@ -46,7 +46,6 @@ actor LlamaContext {
         let sparams = llama_sampler_chain_default_params()
         self.sampling = llama_sampler_chain_init(sparams)
         llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
-        llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
         llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
     }
 
```
examples/save-load-state/save-load-state.cpp

Lines changed: 0 additions & 3 deletions

```diff
@@ -42,7 +42,6 @@ int main(int argc, char ** argv) {
 
     llama_sampler * smpl = llama_sampler_chain_init(sparams);
 
-    llama_sampler_chain_add(smpl, llama_sampler_init_softmax());
     llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed));
 
     // tokenize prompt
@@ -107,7 +106,6 @@ int main(int argc, char ** argv) {
 
     llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
 
-    llama_sampler_chain_add(smpl2, llama_sampler_init_softmax());
     llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed));
 
     printf("\nsecond run: %s", params.prompt.c_str());
@@ -171,7 +169,6 @@ int main(int argc, char ** argv) {
 
     llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
 
-    llama_sampler_chain_add(smpl3, llama_sampler_init_softmax());
     llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed));
 
     printf("\nsingle seq run: %s", params.prompt.c_str());
```

examples/speculative/speculative.cpp

Lines changed: 0 additions & 3 deletions

```diff
@@ -185,8 +185,6 @@ int main(int argc, char ** argv) {
     // target model sampling context (reuse the llama_context's sampling instance)
     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
 
-    struct llama_sampler * softmax = llama_sampler_init_softmax();
-
     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);
 
@@ -629,7 +627,6 @@ int main(int argc, char ** argv) {
            common_sampler_free(drafts[s].smpl);
        }
 
-        llama_sampler_free(softmax);
         llama_batch_free(batch_dft);
 
         llama_free(ctx_tgt);
```

include/llama.h

Lines changed: 7 additions & 3 deletions

```diff
@@ -217,6 +217,7 @@ extern "C" {
 
     typedef struct llama_token_data_array {
         // TODO: consider SoA
+        // NOTE: this pointer can be modified by the samplers
         llama_token_data * data;
         size_t size;
         int64_t selected; // this is the index in the data array (i.e. not the token id)
@@ -1069,12 +1070,13 @@ extern "C" {
 
     // available samplers:
 
-    LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void);
-    LLAMA_API struct llama_sampler * llama_sampler_init_dist   (uint32_t seed);
+    LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
+    LLAMA_API struct llama_sampler * llama_sampler_init_dist  (uint32_t seed);
 
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
     /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
-    LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
+    DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
+        "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)");
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
     LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
@@ -1090,6 +1092,8 @@ extern "C" {
 
     /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
     LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);
+
+    /// #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf
     LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t);
 
     /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
```
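Because llama_sampler_init_softmax is now deprecated and llama_sampler_init_dist normalizes the candidates itself (see the src/llama-sampling.cpp hunks below), existing callers can simply drop the explicit softmax step. A small, hedged migration sketch:

```cpp
// Migration sketch (hedged): the softmax step before dist is no longer needed,
// because llama_sampler_init_dist now applies softmax internally.
#include "llama.h"

static void add_final_samplers(llama_sampler * chain, uint32_t seed) {
    // before this commit:
    //   llama_sampler_chain_add(chain, llama_sampler_init_softmax()); // now DEPRECATED
    //   llama_sampler_chain_add(chain, llama_sampler_init_dist(seed));

    // after this commit:
    llama_sampler_chain_add(chain, llama_sampler_init_dist(seed));
}
```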

src/llama-sampling.cpp

Lines changed: 32 additions & 9 deletions

```diff
@@ -63,6 +63,30 @@ static void llama_log_softmax(float * array, size_t size) {
 }
 */
 
+static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
+    if (temp <= 0.0f) {
+        // find the token with the highest logit and set the rest to -inf
+        size_t max_i = 0;
+        float  max_l = cur_p->data[0].logit;
+
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            if (cur_p->data[i    ].logit > max_l) {
+                cur_p->data[max_i].logit = -INFINITY;
+                max_i = i;
+                max_l = cur_p->data[i].logit;
+            } else {
+                cur_p->data[i].logit = -INFINITY;
+            }
+        }
+
+        return;
+    }
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].logit /= temp;
+    }
+}
+
 static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
     GGML_ASSERT(cur_p->size > 0);
 
@@ -427,6 +451,9 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
 
 static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
     cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
 }
 
@@ -912,9 +939,8 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
 
 static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     const auto * ctx = (llama_sampler_temp *) smpl->ctx;
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        cur_p->data[i].logit /= ctx->temp;
-    }
+
+    llama_sampler_temp_impl(cur_p, ctx->temp);
 }
 
 static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
@@ -961,6 +987,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
     if (ctx->delta > 0) {
         const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
         const float max_temp = ctx->temp + ctx->delta;
+
         float exponent_val = ctx->exponent;
 
         // no need to do anything if there is only one (or zero) candidates
@@ -998,9 +1025,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
 #endif
 
         // Apply the dynamically calculated temperature scaling
-        for (size_t i = 0; i < cur_p->size; ++i) {
-            cur_p->data[i].logit /= dyn_temp;
-        }
+        llama_sampler_temp_impl(cur_p, dyn_temp);
 
         // Re-compute softmax probabilities after scaling logits with dynamic temperature
         const double max_l_double = cur_p->data[0].logit;
@@ -1024,9 +1049,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
         }
 #endif
     } else {
-        for (size_t i = 0; i < cur_p->size; ++i) {
-            cur_p->data[i].logit /= ctx->temp;
-        }
+        llama_sampler_temp_impl(cur_p, ctx->temp);
     }
 }
 
```
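To make the new t <= 0.0f behavior concrete, here is a small standalone toy (not the library code) that mirrors the argmax-preserving branch of llama_sampler_temp_impl on a plain vector of logits:

```cpp
// Standalone toy mirroring the new t <= 0 behavior: only the argmax keeps its logit,
// every other logit becomes -inf, so softmax + dist then selects that token with probability 1.
#include <cmath>
#include <cstdio>
#include <vector>

static void temp_toy(std::vector<float> & logits, float temp) {
    if (temp <= 0.0f) {
        size_t max_i = 0;
        for (size_t i = 1; i < logits.size(); ++i) {
            if (logits[i] > logits[max_i]) {
                logits[max_i] = -INFINITY; // previous max is no longer the winner
                max_i = i;
            } else {
                logits[i] = -INFINITY;
            }
        }
        return;
    }
    for (float & l : logits) {
        l /= temp; // regular temperature scaling for t > 0
    }
}

int main() {
    std::vector<float> logits = {1.0f, 3.5f, 2.0f};
    temp_toy(logits, 0.0f);
    for (float l : logits) {
        printf("%g ", l); // prints: -inf 3.5 -inf
    }
    printf("\n");
    return 0;
}
```

After this transform, softmax assigns probability 1 to the surviving token, which is why the dist sampler now doubles as the greedy path.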
