Skip to content

Commit 60d8af4

Browse files
committed
tests : replace macros with functions
ggml-ci
1 parent 78118c4 commit 60d8af4

File tree

1 file changed
+104 -108 lines changed

tests/test-sampling.cpp

Lines changed: 104 additions & 108 deletions
Original file line number | Diff line number | Diff line change
@@ -18,169 +18,165 @@ static void dump(const llama_token_data_array * cur_p) {
1818

1919
#define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0)
2020

21-
#define APPLY(__cnstr, __cur_p) do { \
22-
auto * cnstr = (__cnstr); \
23-
llama_sampler_apply(cnstr, (__cur_p)); \
24-
llama_sampler_free(cnstr); \
25-
} while(0)
26-
27-
#define CUR_P_FROM_PROBS() \
28-
const size_t n_vocab = probs.size(); \
29-
std::vector<llama_token_data> cur; \
30-
cur.reserve(n_vocab); \
31-
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { \
32-
const float logit = logf(probs[token_id]); \
33-
cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); \
34-
} \
35-
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }
36-
37-
static void test_temp(const std::vector<float> & probs, const std::vector<float> & expected_probs, float temp) {
38-
CUR_P_FROM_PROBS();
39-
40-
DUMP(&cur_p);
41-
APPLY(llama_sampler_init_temp(temp), &cur_p);
42-
APPLY(llama_sampler_init_dist(0), &cur_p);
43-
DUMP(&cur_p);
44-
45-
GGML_ASSERT(cur_p.size == expected_probs.size());
46-
for (size_t i = 0; i < cur_p.size; i++) {
47-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5);
21+
struct sampler_tester {
22+
sampler_tester(size_t n_vocab) {
23+
cur.reserve(n_vocab);
24+
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
25+
const float logit = logf(token_id);
26+
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
27+
}
28+
29+
cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
4830
}
49-
}
5031

51-
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
52-
CUR_P_FROM_PROBS();
32+
sampler_tester(const std::vector<float> & probs, const std::vector<float> & probs_expected) : probs_expected(probs_expected) {
33+
cur.reserve(probs.size());
34+
for (llama_token token_id = 0; token_id < (llama_token)probs.size(); token_id++) {
35+
const float logit = logf(probs[token_id]);
36+
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
37+
}
38+
39+
cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
40+
}
5341

54-
DUMP(&cur_p);
55-
APPLY(llama_sampler_init_top_k(k), &cur_p);
56-
APPLY(llama_sampler_init_dist (0), &cur_p);
57-
DUMP(&cur_p);
42+
void apply(llama_sampler * sampler) {
43+
llama_sampler_apply(sampler, &cur_p);
44+
llama_sampler_free(sampler);
45+
}
5846

59-
GGML_ASSERT(cur_p.size == expected_probs.size());
60-
for (size_t i = 0; i < cur_p.size; i++) {
61-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5);
47+
void check() {
48+
GGML_ASSERT(cur_p.size == probs_expected.size());
49+
for (size_t i = 0; i < cur_p.size; i++) {
50+
GGML_ASSERT(fabs(cur_p.data[i].p - probs_expected[i]) < 1e-5);
51+
}
6252
}
53+
54+
llama_token_data_array cur_p;
55+
56+
private:
57+
const std::vector<float> probs_expected;
58+
59+
std::vector<llama_token_data> cur;
60+
};
61+
62+
static void test_temp(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp) {
63+
sampler_tester tester(probs, probs_expected);
64+
65+
DUMP(&tester.cur_p);
66+
tester.apply(llama_sampler_init_temp(temp));
67+
tester.apply(llama_sampler_init_dist(0));
68+
DUMP(&tester.cur_p);
69+
70+
tester.check();
6371
}
6472

65-
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
66-
CUR_P_FROM_PROBS();
73+
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
74+
sampler_tester tester(probs, probs_expected);
6775

68-
DUMP(&cur_p);
69-
APPLY(llama_sampler_init_top_p(p, 1), &cur_p);
70-
APPLY(llama_sampler_init_dist (0), &cur_p);
71-
DUMP(&cur_p);
72-
DUMP(&cur_p);
76+
DUMP(&tester.cur_p);
77+
tester.apply(llama_sampler_init_top_k(k));
78+
tester.apply(llama_sampler_init_dist (0));
79+
DUMP(&tester.cur_p);
7380

74-
GGML_ASSERT(cur_p.size == expected_probs.size());
75-
for (size_t i = 0; i < cur_p.size; i++) {
76-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
77-
}
81+
tester.check();
7882
}
7983

80-
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
81-
CUR_P_FROM_PROBS();
84+
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
85+
sampler_tester tester(probs, probs_expected);
8286

83-
DUMP(&cur_p);
84-
APPLY(llama_sampler_init_tail_free(z, 1), &cur_p);
85-
DUMP(&cur_p);
87+
DUMP(&tester.cur_p);
88+
tester.apply(llama_sampler_init_top_p(p, 1));
89+
tester.apply(llama_sampler_init_dist (0));
90+
DUMP(&tester.cur_p);
8691

87-
GGML_ASSERT(cur_p.size == expected_probs.size());
88-
for (size_t i = 0; i < cur_p.size; i++) {
89-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
90-
}
92+
tester.check();
9193
}
9294

93-
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
94-
CUR_P_FROM_PROBS();
95+
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & probs_expected, float z) {
96+
sampler_tester tester(probs, probs_expected);
9597

96-
DUMP(&cur_p);
97-
APPLY(llama_sampler_init_min_p(p, 1), &cur_p);
98-
APPLY(llama_sampler_init_dist (0), &cur_p);
99-
DUMP(&cur_p);
98+
DUMP(&tester.cur_p);
99+
tester.apply(llama_sampler_init_tail_free(z, 1));
100+
DUMP(&tester.cur_p);
100101

101-
GGML_ASSERT(cur_p.size == expected_probs.size());
102-
for (size_t i = 0; i < cur_p.size; i++) {
103-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
104-
}
102+
tester.check();
105103
}
106104

107-
static void test_xtc(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p, float t) {
108-
CUR_P_FROM_PROBS();
105+
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
106+
sampler_tester tester(probs, probs_expected);
109107

110-
DUMP(&cur_p);
111-
APPLY(llama_sampler_init_xtc(p, t, 0, 0), &cur_p);
112-
DUMP(&cur_p);
108+
DUMP(&tester.cur_p);
109+
tester.apply(llama_sampler_init_min_p(p, 1));
110+
tester.apply(llama_sampler_init_dist (0));
111+
DUMP(&tester.cur_p);
113112

114-
GGML_ASSERT(cur_p.size == expected_probs.size());
115-
for (size_t i = 0; i < cur_p.size; i++) {
116-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5);
117-
}
113+
tester.check();
118114
}
119115

120-
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
121-
CUR_P_FROM_PROBS();
116+
static void test_xtc(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p, float t) {
117+
sampler_tester tester(probs, probs_expected);
122118

123-
DUMP(&cur_p);
124-
APPLY(llama_sampler_init_typical(p, 1), &cur_p);
125-
DUMP(&cur_p);
119+
DUMP(&tester.cur_p);
120+
tester.apply(llama_sampler_init_xtc(p, t, 0, 0));
121+
DUMP(&tester.cur_p);
126122

127-
GGML_ASSERT(cur_p.size == expected_probs.size());
128-
for (size_t i = 0; i < cur_p.size; i++) {
129-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
130-
}
123+
tester.check();
124+
}
125+
126+
static void test_typical(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
127+
sampler_tester tester(probs, probs_expected);
128+
129+
DUMP(&tester.cur_p);
130+
tester.apply(llama_sampler_init_typical(p, 1));
131+
DUMP(&tester.cur_p);
132+
133+
tester.check();
131134
}
132135

133136
static void test_penalties(
134137
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
135-
const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
138+
const std::vector<float> & probs_expected, float repeat_penalty, float alpha_frequency, float alpha_presence
136139
) {
137-
GGML_ASSERT(probs.size() == expected_probs.size());
140+
GGML_ASSERT(probs.size() == probs_expected.size());
138141

139-
CUR_P_FROM_PROBS();
142+
sampler_tester tester(probs, probs_expected);
140143

144+
const size_t n_vocab = probs.size();
141145
auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
142146

143147
for (size_t i = 0; i < last_tokens.size(); i++) {
144148
llama_sampler_accept(sampler, last_tokens[i]);
145149
}
146150

147-
DUMP(&cur_p);
148-
APPLY(sampler, &cur_p);
149-
APPLY(llama_sampler_init_dist(0), &cur_p);
150-
DUMP(&cur_p);
151+
DUMP(&tester.cur_p);
152+
tester.apply(sampler);
153+
tester.apply(llama_sampler_init_dist(0));
154+
DUMP(&tester.cur_p);
151155

152-
GGML_ASSERT(cur_p.size == expected_probs.size());
153-
for (size_t i = 0; i < cur_p.size; i++) {
154-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
155-
}
156+
tester.check();
156157
}
157158

158159
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
159160
) {
160-
std::vector<llama_token_data> cur;
161-
cur.reserve(n_vocab);
162-
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
163-
const float logit = logf(token_id);
164-
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
165-
}
166-
167-
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
161+
sampler_tester tester(n_vocab);
168162

169163
llama_token min_token_id = 0;
170164
const llama_token max_token_id = n_vocab-1;
171165

172166
for (auto s : samplers_sequence) {
173167
switch (s){
174-
case 'k': APPLY(llama_sampler_init_top_k(top_k), &cur_p); break;
168+
case 'k': tester.apply(llama_sampler_init_top_k(top_k)); break;
175169
case 'f': GGML_ABORT("tail_free test not implemented");
176170
case 'y': GGML_ABORT("typical test not implemented");
177-
case 'p': APPLY(llama_sampler_init_top_p(top_p, 1), &cur_p); break;
178-
case 'm': APPLY(llama_sampler_init_min_p(min_p, 1), &cur_p); break;
171+
case 'p': tester.apply(llama_sampler_init_top_p(top_p, 1)); break;
172+
case 'm': tester.apply(llama_sampler_init_min_p(min_p, 1)); break;
179173
case 't': GGML_ABORT("temperature test not implemented");
180174
default : GGML_ABORT("Unknown sampler");
181175
}
182176

183-
APPLY(llama_sampler_init_dist(0), &cur_p);
177+
tester.apply(llama_sampler_init_dist(0));
178+
179+
auto & cur_p = tester.cur_p;
184180

185181
const int size = cur_p.size;
186182

0 commit comments

Comments (0)