Skip to content

Commit 11c2e46

Browse files
committed
sampling : improve mirostat implementation
ggml-ci
1 parent 39fe5a3 commit 11c2e46

File tree

7 files changed

+95
-88
lines changed

7 files changed

+95
-88
lines changed

common/sampling.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ struct gpt_sampler {
121121
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
122122
}
123123

124-
cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false };
124+
cur_p = { cur.data(), cur.size(), -1, false };
125125
}
126126
};
127127

@@ -202,17 +202,17 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
202202
GGML_ASSERT(false && "unknown sampler type");
203203
}
204204
}
205+
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
206+
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
205207
} else if (params.mirostat == 1) {
206208
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
207-
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
209+
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.seed, params.mirostat_tau, params.mirostat_eta));
208210
} else if (params.mirostat == 2) {
209211
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
210-
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
212+
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
211213
} else {
212214
GGML_ASSERT(false && "unknown mirostat version");
213215
}
214-
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
215-
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
216216
} else {
217217
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
218218
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
@@ -246,8 +246,8 @@ struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
246246
};
247247
}
248248

249-
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar) {
250-
if (apply_grammar) {
249+
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar) {
250+
if (accept_grammar) {
251251
llama_sampler_accept(gsmpl->grmr, token);
252252
}
253253

@@ -293,9 +293,9 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
293293

294294
llama_sampler_apply(chain, &cur_p);
295295

296-
const llama_token id = cur_p.data[cur_p.selected].id;
296+
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
297297

298-
GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - check your sampling configuration");
298+
const llama_token id = cur_p.data[cur_p.selected].id;
299299

300300
if (grammar_first) {
301301
return id;
@@ -304,7 +304,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
304304
// check if the sampled token fits the grammar
305305
{
306306
llama_token_data single_token_data = { id, 1.0f, 0.0f };
307-
llama_token_data_array single_token_data_array = { &single_token_data, 1, LLAMA_TOKEN_NULL, false };
307+
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
308308

309309
llama_sampler_apply(grmr, &single_token_data_array);
310310

@@ -324,7 +324,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
324324

325325
llama_sampler_apply(chain, &cur_p);
326326

327-
GGML_ASSERT(cur_p.data[cur_p.selected].id != LLAMA_TOKEN_NULL && "null token in the sampling history - check your sampling configuration");
327+
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
328328

329329
return cur_p.data[cur_p.selected].id;
330330
}

common/sampling.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl);
7070

7171
struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl);
7272

73-
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar);
73+
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
7474
void gpt_sampler_reset (struct gpt_sampler * gsmpl);
7575

7676
llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);

include/llama.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,7 @@ extern "C" {
10661066
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
10671067
LLAMA_API struct llama_sampler * llama_sampler_init_mirostat(
10681068
const struct llama_model * model,
1069+
uint32_t seed,
10691070
float tau,
10701071
float eta);
10711072

@@ -1075,6 +1076,7 @@ extern "C" {
10751076
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
10761077
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
10771078
LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2(
1079+
uint32_t seed,
10781080
float tau,
10791081
float eta);
10801082

src/llama-sampling.cpp

Lines changed: 68 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,17 @@
1111
#include <random>
1212
#include <unordered_map>
1313

14+
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float> & probs) {
15+
probs.resize(cur_p->size);
16+
for (size_t i = 0; i < cur_p->size; ++i) {
17+
probs[i] = cur_p->data[i].p;
18+
}
19+
20+
std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
21+
22+
return dist(rng);
23+
}
24+
1425
static void llama_log_softmax(float * array, size_t size) {
1526
float max_l = *std::max_element(array, array + size);
1627
float sum = 0.f;
@@ -456,22 +467,16 @@ struct llama_sampler_context_dist {
456467
const uint32_t seed;
457468

458469
std::mt19937 rng;
470+
471+
std::vector<float> probs; // work array
459472
};
460473

461474
static struct llama_sampler_i llama_sampler_dist_i = {
462475
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "dist"; },
463476
/* .accept = */ nullptr,
464477
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
465478
auto * ctx = (llama_sampler_context_dist *) smpl->ctx;
466-
std::vector<float> probs;
467-
probs.reserve(cur_p->size);
468-
for (size_t i = 0; i < cur_p->size; ++i) {
469-
probs.push_back(cur_p->data[i].p);
470-
}
471-
472-
std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
473-
474-
cur_p->selected = dist(ctx->rng);
479+
cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
475480
},
476481
/* .reset = */ nullptr,
477482
/* .clone = */ [](const struct llama_sampler * smpl) {
@@ -489,6 +494,7 @@ struct llama_sampler * llama_sampler_init_dist_impl(uint32_t seed) {
489494
/* .ctx = */ new llama_sampler_context_dist {
490495
/* .seed = */ seed,
491496
/* .rng = */ std::mt19937(seed),
497+
/* .probs = */ {},
492498
},
493499
};
494500
}
@@ -761,35 +767,23 @@ struct llama_sampler * llama_sampler_init_temp_ext_impl(float temp, float delta,
761767
struct llama_sampler_context_mirostat {
762768
const struct llama_vocab * vocab;
763769

770+
const uint32_t seed;
771+
764772
const float tau;
765773
const float eta;
766774

767775
const int32_t m;
768776

769777
float mu;
770778

771-
std::vector<llama_token_data> cur;
779+
std::mt19937 rng;
780+
781+
std::vector<float> probs;
772782
};
773783

774784
static struct llama_sampler_i llama_sampler_mirostat_i = {
775785
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat"; },
776-
/* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
777-
auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx;
778-
779-
int32_t idx = -1;
780-
for (size_t i = 0; i < ctx->cur.size(); ++i) {
781-
if (ctx->cur[i].id == token) {
782-
idx = i;
783-
break;
784-
}
785-
}
786-
787-
float observed_surprise = -log2f(ctx->cur[idx].p);
788-
float e = observed_surprise - ctx->tau;
789-
790-
// Update mu using the learning rate and error
791-
ctx->mu = ctx->mu - ctx->eta * e;
792-
},
786+
/* .accept = */ nullptr,
793787
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
794788
auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx;
795789

@@ -812,70 +806,66 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
812806
float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat);
813807

814808
llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
809+
llama_sampler_softmax_impl(cur_p);
815810

816-
// remember the order to be able to compute the distance later when accepting the token
817-
ctx->cur.resize(cur_p->size);
818-
for (size_t i = 0; i < cur_p->size; ++i) {
819-
ctx->cur[i] = cur_p->data[i];
820-
}
811+
const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
812+
813+
cur_p->selected = idx;
814+
815+
float observed_surprise = -log2f(cur_p->data[idx].p);
816+
float e = observed_surprise - ctx->tau;
817+
818+
// Update mu using the learning rate and error
819+
ctx->mu = ctx->mu - ctx->eta * e;
821820
},
822821
/* .reset = */ [](struct llama_sampler * smpl) {
823822
auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx;
824823
ctx->mu = 2.0f*ctx->tau;
824+
ctx->rng = std::mt19937(ctx->seed);
825825
},
826826
/* .clone = */ [](const struct llama_sampler * smpl) {
827827
const auto * ctx = (const llama_sampler_context_mirostat *) smpl->ctx;
828-
return llama_sampler_init_mirostat_impl(*ctx->vocab, ctx->tau, ctx->eta, ctx->m);
828+
return llama_sampler_init_mirostat_impl(*ctx->vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
829829
},
830830
/* .free = */ [](struct llama_sampler * smpl) {
831831
delete (llama_sampler_context_mirostat *) smpl->ctx;
832832
},
833833
};
834834

835-
struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab & vocab, float tau, float eta, int32_t m) {
835+
struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab & vocab, uint32_t seed, float tau, float eta, int32_t m) {
836836
return new llama_sampler {
837837
/* .iface = */ &llama_sampler_mirostat_i,
838838
/* .ctx = */ new llama_sampler_context_mirostat {
839839
/* .vocab = */ &vocab,
840+
/* .seed = */ seed,
840841
/* .tau = */ tau,
841842
/* .eta = */ eta,
842843
/* .m = */ m,
843844
/* .mu = */ 2.0f*tau,
844-
/* .cur = */ {},
845+
/* .rng = */ std::mt19937(seed),
846+
/* .probs = */ {},
845847
},
846848
};
847849
}
848850

849851
// mirostat v2
850852

851853
struct llama_sampler_context_mirostat_v2 {
854+
const uint32_t seed;
855+
852856
const float tau;
853857
const float eta;
854858

855859
float mu;
856860

857-
std::vector<llama_token_data> cur;
861+
std::mt19937 rng;
862+
863+
std::vector<float> probs;
858864
};
859865

860866
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
861867
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat-v2"; },
862-
/* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
863-
auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx;
864-
865-
int32_t idx = -1;
866-
for (size_t i = 0; i < ctx->cur.size(); ++i) {
867-
if (ctx->cur[i].id == token) {
868-
idx = i;
869-
break;
870-
}
871-
}
872-
873-
float observed_surprise = -log2f(ctx->cur[idx].p);
874-
float e = observed_surprise - ctx->tau;
875-
876-
// Update mu using the learning rate and error
877-
ctx->mu = ctx->mu - ctx->eta * e;
878-
},
868+
/* .accept = */ nullptr,
879869
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
880870
auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx;
881871

@@ -893,33 +883,40 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
893883
// Normalize the probabilities of the remaining words
894884
llama_sampler_softmax_impl(cur_p);
895885

896-
// remember the order to be able to compute the distance later when accepting the token
897-
ctx->cur.resize(cur_p->size);
898-
for (size_t i = 0; i < cur_p->size; ++i) {
899-
ctx->cur[i] = cur_p->data[i];
900-
}
886+
const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
887+
888+
cur_p->selected = idx;
889+
890+
float observed_surprise = -log2f(cur_p->data[idx].p);
891+
float e = observed_surprise - ctx->tau;
892+
893+
// Update mu using the learning rate and error
894+
ctx->mu = ctx->mu - ctx->eta * e;
901895
},
902896
/* .reset = */ [](struct llama_sampler * smpl) {
903897
auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx;
904898
ctx->mu = 2.0f*ctx->tau;
899+
ctx->rng = std::mt19937(ctx->seed);
905900
},
906901
/* .clone = */ [](const struct llama_sampler * smpl) {
907902
const auto * ctx = (const llama_sampler_context_mirostat_v2 *) smpl->ctx;
908-
return llama_sampler_init_mirostat_v2_impl(ctx->tau, ctx->eta);
903+
return llama_sampler_init_mirostat_v2_impl(ctx->seed, ctx->tau, ctx->eta);
909904
},
910905
/* .free = */ [](struct llama_sampler * smpl) {
911906
delete (llama_sampler_context_mirostat_v2 *) smpl->ctx;
912907
},
913908
};
914909

915-
struct llama_sampler * llama_sampler_init_mirostat_v2_impl(float tau, float eta) {
910+
struct llama_sampler * llama_sampler_init_mirostat_v2_impl(uint32_t seed, float tau, float eta) {
916911
return new llama_sampler {
917912
/* .iface = */ &llama_sampler_mirostat_v2_i,
918913
/* .ctx = */ new llama_sampler_context_mirostat_v2 {
919-
/* .tau = */ tau,
920-
/* .eta = */ eta,
921-
/* .mu = */ 2.0f*tau,
922-
/* .cur = */ {},
914+
/* .seed = */ seed,
915+
/* .tau = */ tau,
916+
/* .eta = */ eta,
917+
/* .mu = */ 2.0f*tau,
918+
/* .rng = */ std::mt19937(seed),
919+
/* .probs = */ {},
923920
},
924921
};
925922
}
@@ -1154,9 +1151,15 @@ struct llama_sampler * llama_sampler_init_logit_bias_impl(
11541151

11551152
static struct llama_sampler_i llama_sampler_chain_i = {
11561153
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; },
1157-
/* .accept = */ [](struct llama_sampler * smpl, llama_token /*token*/) {
1154+
/* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
11581155
auto * chain = (llama_sampler_chain *) smpl->ctx;
11591156

1157+
time_meas tm(chain->t_sample_us, chain->params.no_timing);
1158+
1159+
for (auto * smpl : chain->samplers) {
1160+
llama_sampler_accept_impl(*smpl, token);
1161+
}
1162+
11601163
chain->n_sample++;
11611164
},
11621165
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {

src/llama-sampling.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,13 @@ struct llama_sampler * llama_sampler_init_temp_ext_impl (float t, float delta
5858

5959
struct llama_sampler * llama_sampler_init_mirostat_impl(
6060
const struct llama_vocab & vocab,
61+
uint32_t seed,
6162
float tau,
6263
float eta,
6364
int32_t m);
6465

6566
struct llama_sampler * llama_sampler_init_mirostat_v2_impl(
67+
uint32_t seed,
6668
float tau,
6769
float eta);
6870

src/llama.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20646,12 +20646,12 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
2064620646
return llama_sampler_init_temp_ext_impl(temp, delta, exponent);
2064720647
}
2064820648

20649-
struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, float tau, float eta) {
20650-
return llama_sampler_init_mirostat_impl(model->vocab, tau, eta, 100);
20649+
struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta) {
20650+
return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, 100);
2065120651
}
2065220652

20653-
struct llama_sampler * llama_sampler_init_mirostat_v2(float tau, float eta) {
20654-
return llama_sampler_init_mirostat_v2_impl(tau, eta);
20653+
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
20654+
return llama_sampler_init_mirostat_v2_impl(seed, tau, eta);
2065520655
}
2065620656

2065720657
struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {

0 commit comments

Comments (0)