Skip to content

Commit 861ad6f

Browse files
committed
cont
ggml-ci
1 parent 81471a7 commit 861ad6f

File tree

9 files changed

+66
-45
lines changed

9 files changed

+66
-45
lines changed

common/common.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -584,12 +584,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
584584
if (arg == "--samplers") {
585585
CHECK_ARG
586586
const auto sampler_names = string_split(argv[i], ';');
587-
sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true);
587+
sparams.samplers = llama_sampling_types_from_names(sampler_names, true);
588588
return true;
589589
}
590590
if (arg == "--sampling-seq") {
591591
CHECK_ARG
592-
sparams.samplers_sequence = llama_sampling_types_from_chars(argv[i]);
592+
sparams.samplers = llama_sampling_types_from_chars(argv[i]);
593593
return true;
594594
}
595595
if (arg == "--top-p") {
@@ -1438,9 +1438,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
14381438

14391439
std::string sampler_type_chars;
14401440
std::string sampler_type_names;
1441-
for (const auto sampler_type : sparams.samplers_sequence) {
1442-
sampler_type_chars += static_cast<char>(sampler_type);
1443-
sampler_type_names += llama_sampling_type_to_str(sampler_type) + ";";
1441+
for (const auto & sampler : sparams.samplers) {
1442+
sampler_type_chars += llama_sampling_type_to_chr(sampler);
1443+
sampler_type_names += llama_sampling_type_to_str(sampler) + ";";
14441444
}
14451445
sampler_type_names.pop_back();
14461446

common/sampling.cpp

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_model * m
3232
lparams.penalize_nl = params.penalize_nl;
3333
lparams.ignore_eos = params.ignore_eos;
3434

35+
lparams.n_samplers = params.samplers.size();
36+
3537
result->smpl = llama_sampling_init(model, lparams);
3638

3739
llama_sampling_set_grammar (result->smpl, params.grammar.c_str(), "root");
@@ -101,7 +103,7 @@ std::string llama_sampling_print(const gpt_sampling_params & params) {
101103
std::string llama_sampling_order_print(const gpt_sampling_params & params) {
102104
std::string result = "CFG -> Penalties ";
103105
if (params.mirostat == 0) {
104-
for (auto sampler_type : params.samplers_sequence) {
106+
for (auto sampler_type : params.samplers) {
105107
const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
106108
if (!sampler_type_name.empty()) {
107109
result += "-> " + sampler_type_name + " ";
@@ -114,6 +116,18 @@ std::string llama_sampling_order_print(const gpt_sampling_params & params) {
114116
return result;
115117
}
116118

119+
char llama_sampling_type_to_chr(llama_sampler_type sampler_type) {
120+
switch (sampler_type) {
121+
case LLAMA_SAMPLER_TYPE_TOP_K: return 'k';
122+
case LLAMA_SAMPLER_TYPE_TFS_Z: return 'f';
123+
case LLAMA_SAMPLER_TYPE_TYPICAL_P: return 'y';
124+
case LLAMA_SAMPLER_TYPE_TOP_P: return 'p';
125+
case LLAMA_SAMPLER_TYPE_MIN_P: return 'm';
126+
case LLAMA_SAMPLER_TYPE_TEMPERATURE: return 't';
127+
default : return '?';
128+
}
129+
}
130+
117131
std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
118132
switch (sampler_type) {
119133
case LLAMA_SAMPLER_TYPE_TOP_K: return "top_k";
@@ -128,26 +142,26 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
128142

129143
std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
130144
std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
131-
{"top_k", LLAMA_SAMPLER_TYPE_TOP_K},
132-
{"top_p", LLAMA_SAMPLER_TYPE_TOP_P},
133-
{"typical_p", LLAMA_SAMPLER_TYPE_TYPICAL_P},
134-
{"min_p", LLAMA_SAMPLER_TYPE_MIN_P},
135-
{"tfs_z", LLAMA_SAMPLER_TYPE_TFS_Z},
136-
{"temperature", LLAMA_SAMPLER_TYPE_TEMPERATURE}
145+
{ "top_k", LLAMA_SAMPLER_TYPE_TOP_K },
146+
{ "top_p", LLAMA_SAMPLER_TYPE_TOP_P },
147+
{ "typical_p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
148+
{ "min_p", LLAMA_SAMPLER_TYPE_MIN_P },
149+
{ "tfs_z", LLAMA_SAMPLER_TYPE_TFS_Z },
150+
{ "temperature", LLAMA_SAMPLER_TYPE_TEMPERATURE },
137151
};
138152

139153
// since samplers names are written multiple ways
140154
// make it ready for both system names and input names
141155
std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
142-
{"top-k", LLAMA_SAMPLER_TYPE_TOP_K},
143-
{"top-p", LLAMA_SAMPLER_TYPE_TOP_P},
144-
{"nucleus", LLAMA_SAMPLER_TYPE_TOP_P},
145-
{"typical-p", LLAMA_SAMPLER_TYPE_TYPICAL_P},
146-
{"typical", LLAMA_SAMPLER_TYPE_TYPICAL_P},
147-
{"min-p", LLAMA_SAMPLER_TYPE_MIN_P},
148-
{"tfs-z", LLAMA_SAMPLER_TYPE_TFS_Z},
149-
{"tfs", LLAMA_SAMPLER_TYPE_TFS_Z},
150-
{"temp", LLAMA_SAMPLER_TYPE_TEMPERATURE}
156+
{ "top-k", LLAMA_SAMPLER_TYPE_TOP_K },
157+
{ "top-p", LLAMA_SAMPLER_TYPE_TOP_P },
158+
{ "nucleus", LLAMA_SAMPLER_TYPE_TOP_P },
159+
{ "typical-p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
160+
{ "typical", LLAMA_SAMPLER_TYPE_TYPICAL_P },
161+
{ "min-p", LLAMA_SAMPLER_TYPE_MIN_P },
162+
{ "tfs-z", LLAMA_SAMPLER_TYPE_TFS_Z },
163+
{ "tfs", LLAMA_SAMPLER_TYPE_TFS_Z },
164+
{ "temp", LLAMA_SAMPLER_TYPE_TEMPERATURE },
151165
};
152166

153167
std::vector<llama_sampler_type> sampler_types;
@@ -172,12 +186,12 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
172186

173187
std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) {
174188
std::unordered_map<char, llama_sampler_type> sampler_name_map {
175-
{'k', LLAMA_SAMPLER_TYPE_TOP_K},
176-
{'p', LLAMA_SAMPLER_TYPE_TOP_P},
177-
{'y', LLAMA_SAMPLER_TYPE_TYPICAL_P},
178-
{'m', LLAMA_SAMPLER_TYPE_MIN_P},
179-
{'f', LLAMA_SAMPLER_TYPE_TFS_Z},
180-
{'t', LLAMA_SAMPLER_TYPE_TEMPERATURE}
189+
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_K), LLAMA_SAMPLER_TYPE_TOP_K },
190+
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TFS_Z), LLAMA_SAMPLER_TYPE_TFS_Z },
191+
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TYPICAL_P), LLAMA_SAMPLER_TYPE_TYPICAL_P },
192+
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_P), LLAMA_SAMPLER_TYPE_TOP_P },
193+
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_MIN_P), LLAMA_SAMPLER_TYPE_MIN_P },
194+
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TEMPERATURE), LLAMA_SAMPLER_TYPE_TEMPERATURE }
181195
};
182196

183197
std::vector<llama_sampler_type> sampler_types;
@@ -199,10 +213,10 @@ static void sampler_queue(
199213

200214
const gpt_sampling_params & params = ctx_sampling->params;
201215

202-
const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
216+
const std::vector<llama_sampler_type> & samplers = params.samplers;
203217

204-
for (auto sampler_type : samplers_sequence) {
205-
switch (sampler_type) {
218+
for (const auto & sampler : samplers) {
219+
switch (sampler) {
206220
case LLAMA_SAMPLER_TYPE_TOP_K: llama_sampling_top_k (smpl, cur_p); break;
207221
case LLAMA_SAMPLER_TYPE_TFS_Z: llama_sampling_tail_free(smpl, cur_p); break;
208222
case LLAMA_SAMPLER_TYPE_TYPICAL_P: llama_sampling_typical (smpl, cur_p); break;

common/sampling.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ typedef struct gpt_sampling_params {
3030
bool penalize_nl = false; // consider newlines as a repeatable token
3131
bool ignore_eos = false;
3232

33-
std::vector<llama_sampler_type> samplers_sequence = {
33+
std::vector<llama_sampler_type> samplers = {
3434
LLAMA_SAMPLER_TYPE_TOP_K,
3535
LLAMA_SAMPLER_TYPE_TFS_Z,
3636
LLAMA_SAMPLER_TYPE_TYPICAL_P,
@@ -78,6 +78,7 @@ std::string llama_sampling_print(const gpt_sampling_params & params);
7878
// Print sampling order into a string
7979
std::string llama_sampling_order_print(const gpt_sampling_params & params);
8080

81+
char llama_sampling_type_to_chr(llama_sampler_type sampler_type);
8182
std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);
8283

8384
std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);

examples/batched.swift/Sources/main.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ while n_cur <= n_len {
155155
llama_sampling_top_p(smpl, &candidates_p)
156156
llama_sampling_temp (smpl, &candidates_p)
157157

158-
let new_token_id = llama_sampling_sample(smpl, &candidates_p)
158+
let new_token_id = llama_sampling_sample_dist(smpl, &candidates_p)
159159

160160
// const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
161161

examples/server/server.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,17 +1039,17 @@ struct server_context {
10391039
}
10401040

10411041
{
1042-
const auto & samplers_sequence = data.find("samplers");
1043-
if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
1042+
const auto & samplers = data.find("samplers");
1043+
if (samplers != data.end() && samplers->is_array()) {
10441044
std::vector<std::string> sampler_names;
1045-
for (const auto & sampler_name : *samplers_sequence) {
1045+
for (const auto & sampler_name : *samplers) {
10461046
if (sampler_name.is_string()) {
10471047
sampler_names.emplace_back(sampler_name);
10481048
}
10491049
}
1050-
slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false);
1050+
slot.sparams.samplers = llama_sampling_types_from_names(sampler_names, false);
10511051
} else {
1052-
slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
1052+
slot.sparams.samplers = default_sparams.samplers;
10531053
}
10541054
}
10551055

@@ -1265,10 +1265,10 @@ struct server_context {
12651265
}
12661266

12671267
json get_formated_generation(const server_slot & slot) const {
1268-
std::vector<std::string> samplers_sequence;
1269-
samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
1270-
for (const auto & sampler_type : slot.sparams.samplers_sequence) {
1271-
samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
1268+
std::vector<std::string> samplers;
1269+
samplers.reserve(slot.sparams.samplers.size());
1270+
for (const auto & sampler : slot.sparams.samplers) {
1271+
samplers.emplace_back(llama_sampling_type_to_str(sampler));
12721272
}
12731273

12741274
return json {
@@ -1302,7 +1302,7 @@ struct server_context {
13021302
{"n_probs", slot.sparams.n_probs},
13031303
{"min_keep", slot.sparams.min_keep},
13041304
{"grammar", slot.sparams.grammar},
1305-
{"samplers", samplers_sequence},
1305+
{"samplers", samplers},
13061306
};
13071307
}
13081308

include/llama.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ extern "C" {
400400
float mirostat_tau; // target entropy
401401
float mirostat_eta; // learning rate
402402

403-
// samples
403+
// samplers
404404
int32_t n_samplers;
405405
enum llama_sampler_type samplers[LLAMA_MAX_SAMPLERS];
406406

src/llama-sampling.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & voca
4040

4141
result->prev = ring_buffer<llama_token>(params.n_prev);
4242

43+
for (int i = 0; i < params.n_samplers; ++i) {
44+
result->samplers.push_back(params.samplers[i]);
45+
}
46+
4347
llama_sampling_set_rng_seed_impl(*result, params.seed);
4448

4549
return result;

src/llama-sampling.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@ struct llama_sampling {
2727

2828
const struct llama_vocab & vocab;
2929

30-
struct llama_grammar * grammar = nullptr;
30+
std::vector<llama_sampler_type> samplers;
3131

3232
ring_buffer<llama_token> prev;
3333

34+
struct llama_grammar * grammar = nullptr;
35+
3436
// mirostat sampler state
3537
float mirostat_mu;
3638

src/llama.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17430,7 +17430,7 @@ struct llama_sampling_params llama_sampling_default_params() {
1743017430
/*.mirostat_tau =*/ 5.00f,
1743117431
/*.mirostat_eta =*/ 0.10f,
1743217432
/*.n_samplers =*/ 3,
17433-
/*.samplers =*/ { LLAMA_SAMPLER_TYPE_TOP_K, LLAMA_SAMPLER_TYPE_TOP_P, LLAMA_SAMPLER_TYPE_TEMPERATURE },
17433+
/*.samplers =*/ { LLAMA_SAMPLER_TYPE_TEMPERATURE, LLAMA_SAMPLER_TYPE_TOP_K, LLAMA_SAMPLER_TYPE_TOP_P, },
1743417434
/*.penalize_nl =*/ false,
1743517435
/*.ignore_eos =*/ false,
1743617436
};

0 commit comments

Comments
 (0)