Skip to content

Commit 81471a7

Browse files
committed
llama : refactor samplers (wip)
ggml-ci
1 parent 9630a50 commit 81471a7

File tree

4 files changed

+79
-72
lines changed

4 files changed

+79
-72
lines changed

common/sampling.cpp

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -116,38 +116,38 @@ std::string llama_sampling_order_print(const gpt_sampling_params & params) {
116116

117117
std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
118118
switch (sampler_type) {
119-
case llama_sampler_type::TOP_K: return "top_k";
120-
case llama_sampler_type::TFS_Z: return "tfs_z";
121-
case llama_sampler_type::TYPICAL_P: return "typical_p";
122-
case llama_sampler_type::TOP_P: return "top_p";
123-
case llama_sampler_type::MIN_P: return "min_p";
124-
case llama_sampler_type::TEMPERATURE: return "temperature";
119+
case LLAMA_SAMPLER_TYPE_TOP_K: return "top_k";
120+
case LLAMA_SAMPLER_TYPE_TFS_Z: return "tfs_z";
121+
case LLAMA_SAMPLER_TYPE_TYPICAL_P: return "typical_p";
122+
case LLAMA_SAMPLER_TYPE_TOP_P: return "top_p";
123+
case LLAMA_SAMPLER_TYPE_MIN_P: return "min_p";
124+
case LLAMA_SAMPLER_TYPE_TEMPERATURE: return "temperature";
125125
default : return "";
126126
}
127127
}
128128

129129
std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
130130
std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
131-
{"top_k", llama_sampler_type::TOP_K},
132-
{"top_p", llama_sampler_type::TOP_P},
133-
{"typical_p", llama_sampler_type::TYPICAL_P},
134-
{"min_p", llama_sampler_type::MIN_P},
135-
{"tfs_z", llama_sampler_type::TFS_Z},
136-
{"temperature", llama_sampler_type::TEMPERATURE}
131+
{"top_k", LLAMA_SAMPLER_TYPE_TOP_K},
132+
{"top_p", LLAMA_SAMPLER_TYPE_TOP_P},
133+
{"typical_p", LLAMA_SAMPLER_TYPE_TYPICAL_P},
134+
{"min_p", LLAMA_SAMPLER_TYPE_MIN_P},
135+
{"tfs_z", LLAMA_SAMPLER_TYPE_TFS_Z},
136+
{"temperature", LLAMA_SAMPLER_TYPE_TEMPERATURE}
137137
};
138138

139139
// since samplers names are written multiple ways
140140
// make it ready for both system names and input names
141141
std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
142-
{"top-k", llama_sampler_type::TOP_K},
143-
{"top-p", llama_sampler_type::TOP_P},
144-
{"nucleus", llama_sampler_type::TOP_P},
145-
{"typical-p", llama_sampler_type::TYPICAL_P},
146-
{"typical", llama_sampler_type::TYPICAL_P},
147-
{"min-p", llama_sampler_type::MIN_P},
148-
{"tfs-z", llama_sampler_type::TFS_Z},
149-
{"tfs", llama_sampler_type::TFS_Z},
150-
{"temp", llama_sampler_type::TEMPERATURE}
142+
{"top-k", LLAMA_SAMPLER_TYPE_TOP_K},
143+
{"top-p", LLAMA_SAMPLER_TYPE_TOP_P},
144+
{"nucleus", LLAMA_SAMPLER_TYPE_TOP_P},
145+
{"typical-p", LLAMA_SAMPLER_TYPE_TYPICAL_P},
146+
{"typical", LLAMA_SAMPLER_TYPE_TYPICAL_P},
147+
{"min-p", LLAMA_SAMPLER_TYPE_MIN_P},
148+
{"tfs-z", LLAMA_SAMPLER_TYPE_TFS_Z},
149+
{"tfs", LLAMA_SAMPLER_TYPE_TFS_Z},
150+
{"temp", LLAMA_SAMPLER_TYPE_TEMPERATURE}
151151
};
152152

153153
std::vector<llama_sampler_type> sampler_types;
@@ -172,12 +172,12 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
172172

173173
std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) {
174174
std::unordered_map<char, llama_sampler_type> sampler_name_map {
175-
{'k', llama_sampler_type::TOP_K},
176-
{'p', llama_sampler_type::TOP_P},
177-
{'y', llama_sampler_type::TYPICAL_P},
178-
{'m', llama_sampler_type::MIN_P},
179-
{'f', llama_sampler_type::TFS_Z},
180-
{'t', llama_sampler_type::TEMPERATURE}
175+
{'k', LLAMA_SAMPLER_TYPE_TOP_K},
176+
{'p', LLAMA_SAMPLER_TYPE_TOP_P},
177+
{'y', LLAMA_SAMPLER_TYPE_TYPICAL_P},
178+
{'m', LLAMA_SAMPLER_TYPE_MIN_P},
179+
{'f', LLAMA_SAMPLER_TYPE_TFS_Z},
180+
{'t', LLAMA_SAMPLER_TYPE_TEMPERATURE}
181181
};
182182

183183
std::vector<llama_sampler_type> sampler_types;
@@ -203,12 +203,12 @@ static void sampler_queue(
203203

204204
for (auto sampler_type : samplers_sequence) {
205205
switch (sampler_type) {
206-
case llama_sampler_type::TOP_K: llama_sampling_top_k (smpl, cur_p); break;
207-
case llama_sampler_type::TFS_Z: llama_sampling_tail_free(smpl, cur_p); break;
208-
case llama_sampler_type::TYPICAL_P: llama_sampling_typical (smpl, cur_p); break;
209-
case llama_sampler_type::TOP_P: llama_sampling_top_p (smpl, cur_p); break;
210-
case llama_sampler_type::MIN_P: llama_sampling_min_p (smpl, cur_p); break;
211-
case llama_sampler_type::TEMPERATURE: llama_sampling_temp (smpl, cur_p); break;
206+
case LLAMA_SAMPLER_TYPE_TOP_K: llama_sampling_top_k (smpl, cur_p); break;
207+
case LLAMA_SAMPLER_TYPE_TFS_Z: llama_sampling_tail_free(smpl, cur_p); break;
208+
case LLAMA_SAMPLER_TYPE_TYPICAL_P: llama_sampling_typical (smpl, cur_p); break;
209+
case LLAMA_SAMPLER_TYPE_TOP_P: llama_sampling_top_p (smpl, cur_p); break;
210+
case LLAMA_SAMPLER_TYPE_MIN_P: llama_sampling_min_p (smpl, cur_p); break;
211+
case LLAMA_SAMPLER_TYPE_TEMPERATURE: llama_sampling_temp (smpl, cur_p); break;
212212
default : break;
213213
}
214214
}

common/sampling.h

Lines changed: 28 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,47 +5,38 @@
55
#include <string>
66
#include <vector>
77

8-
// sampler types
9-
enum class llama_sampler_type : char {
10-
TOP_K = 'k',
11-
TOP_P = 'p',
12-
MIN_P = 'm',
13-
TFS_Z = 'f',
14-
TYPICAL_P = 'y',
15-
TEMPERATURE = 't'
16-
};
17-
188
// sampling parameters
199
typedef struct gpt_sampling_params {
20-
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
21-
int32_t n_prev = 64; // number of previous tokens to remember
22-
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
23-
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
24-
int32_t top_k = 40; // <= 0 to use vocab size
25-
float top_p = 0.95f; // 1.0 = disabled
26-
float min_p = 0.05f; // 0.0 = disabled
27-
float tfs_z = 1.00f; // 1.0 = disabled
28-
float typical_p = 1.00f; // 1.0 = disabled
29-
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
30-
float dynatemp_range = 0.00f; // 0.0 = disabled
31-
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
32-
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
33-
float penalty_repeat = 1.00f; // 1.0 = disabled
34-
float penalty_freq = 0.00f; // 0.0 = disabled
35-
float penalty_present = 0.00f; // 0.0 = disabled
36-
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
37-
float mirostat_tau = 5.00f; // target entropy
38-
float mirostat_eta = 0.10f; // learning rate
39-
bool penalize_nl = false; // consider newlines as a repeatable token
40-
bool ignore_eos = false;
10+
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
11+
12+
int32_t n_prev = 64; // number of previous tokens to remember
13+
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
14+
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
15+
int32_t top_k = 40; // <= 0 to use vocab size
16+
float top_p = 0.95f; // 1.0 = disabled
17+
float min_p = 0.05f; // 0.0 = disabled
18+
float tfs_z = 1.00f; // 1.0 = disabled
19+
float typical_p = 1.00f; // 1.0 = disabled
20+
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
21+
float dynatemp_range = 0.00f; // 0.0 = disabled
22+
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
23+
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
24+
float penalty_repeat = 1.00f; // 1.0 = disabled
25+
float penalty_freq = 0.00f; // 0.0 = disabled
26+
float penalty_present = 0.00f; // 0.0 = disabled
27+
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
28+
float mirostat_tau = 5.00f; // target entropy
29+
float mirostat_eta = 0.10f; // learning rate
30+
bool penalize_nl = false; // consider newlines as a repeatable token
31+
bool ignore_eos = false;
4132

4233
std::vector<llama_sampler_type> samplers_sequence = {
43-
llama_sampler_type::TOP_K,
44-
llama_sampler_type::TFS_Z,
45-
llama_sampler_type::TYPICAL_P,
46-
llama_sampler_type::TOP_P,
47-
llama_sampler_type::MIN_P,
48-
llama_sampler_type::TEMPERATURE
34+
LLAMA_SAMPLER_TYPE_TOP_K,
35+
LLAMA_SAMPLER_TYPE_TFS_Z,
36+
LLAMA_SAMPLER_TYPE_TYPICAL_P,
37+
LLAMA_SAMPLER_TYPE_TOP_P,
38+
LLAMA_SAMPLER_TYPE_MIN_P,
39+
LLAMA_SAMPLER_TYPE_TEMPERATURE
4940
};
5041

5142
std::string grammar; // optional BNF-like grammar to constrain sampling

include/llama.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
4747
#define LLAMA_STATE_SEQ_VERSION 2
4848

49+
#define LLAMA_MAX_SAMPLERS 16
50+
4951
#ifdef __cplusplus
5052
extern "C" {
5153
#endif
@@ -203,6 +205,16 @@ extern "C" {
203205
LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
204206
};
205207

208+
enum llama_sampler_type {
209+
LLAMA_SAMPLER_TYPE_NONE = 0,
210+
LLAMA_SAMPLER_TYPE_TOP_K = 1,
211+
LLAMA_SAMPLER_TYPE_TOP_P = 2,
212+
LLAMA_SAMPLER_TYPE_MIN_P = 3,
213+
LLAMA_SAMPLER_TYPE_TFS_Z = 4,
214+
LLAMA_SAMPLER_TYPE_TYPICAL_P = 5,
215+
LLAMA_SAMPLER_TYPE_TEMPERATURE = 6,
216+
};
217+
206218
typedef struct llama_token_data {
207219
llama_token id; // token id
208220
float logit; // log-odds of the token
@@ -387,7 +399,10 @@ extern "C" {
387399
int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
388400
float mirostat_tau; // target entropy
389401
float mirostat_eta; // learning rate
390-
float cfg_scale; // classifier-free guidance scale
402+
403+
// samples
404+
int32_t n_samplers;
405+
enum llama_sampler_type samplers[LLAMA_MAX_SAMPLERS];
391406

392407
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
393408
bool penalize_nl; // consider newlines as a repeatable token

src/llama.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17429,7 +17429,8 @@ struct llama_sampling_params llama_sampling_default_params() {
1742917429
/*.mirostat =*/ 0,
1743017430
/*.mirostat_tau =*/ 5.00f,
1743117431
/*.mirostat_eta =*/ 0.10f,
17432-
/*.cfg_scale =*/ 1.00f,
17432+
/*.n_samplers =*/ 3,
17433+
/*.samplers =*/ { LLAMA_SAMPLER_TYPE_TOP_K, LLAMA_SAMPLER_TYPE_TOP_P, LLAMA_SAMPLER_TYPE_TEMPERATURE },
1743317434
/*.penalize_nl =*/ false,
1743417435
/*.ignore_eos =*/ false,
1743517436
};

0 commit comments

Comments (0)