 #include "common.h"
 #include "sampling.h"

+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
+#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
+
 struct common_speculative {
     struct common_speculative_params params;

-    llama_batch batch_dft;
+    llama_batch batch;

+    struct llama_context * ctx;

     struct common_sampler * smpl;

-    llama_tokens prompt_last;
+    llama_tokens prompt;
 };

-struct common_speculative * common_speculative_init(struct common_speculative_params params) {
+struct common_speculative * common_speculative_init(
+        struct common_speculative_params params,
+        struct llama_context * ctx_dft) {
     auto * result = new common_speculative {
-        /* .params    = */ params,
-        /* .batch_dft = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1),
-        /* .smpl      = */ nullptr,
+        /* .params = */ params,
+        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .ctx    = */ ctx_dft,
+        /* .smpl   = */ nullptr,
+        /* .prompt = */ {},
     };

     // TODO: optimize or pass from outside?
@@ -36,7 +44,7 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
             COMMON_SAMPLER_TYPE_INFILL,
         };

-        result->smpl = common_sampler_init(params.model_dft, sparams);
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), sparams);
     }
 #else
     {
@@ -49,46 +57,104 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
             COMMON_SAMPLER_TYPE_TOP_K,
         };

-        result->smpl = common_sampler_init(params.model_dft, sparams);
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), sparams);
     }
 #endif

-    result->batch_dft = llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1);
-
     return result;
 }

 void common_speculative_free(struct common_speculative * spec) {
     common_sampler_free(spec->smpl);

-    llama_batch_free(spec->batch_dft);
+    llama_batch_free(spec->batch);

     delete spec;
 }
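
For reference, a minimal sketch of the new lifecycle, assuming a caller that has already created the draft context `ctx_dft`: the context is now passed to `common_speculative_init()` directly instead of traveling inside the params struct.

```cpp
// Sketch only: n_draft/n_reuse are fields used later in this patch; the
// surrounding setup is assumed caller code, not part of this change.
common_speculative_params params;
params.n_draft = 16;
params.n_reuse = 256;

struct common_speculative * spec = common_speculative_init(params, ctx_dft);

// ... draft tokens with common_speculative_add_draft() ...

common_speculative_free(spec);
```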

+bool common_speculative_are_compatible(
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft) {
+    const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
+    const struct llama_model * model_dft = llama_get_model(ctx_dft);
+
+    const enum llama_vocab_type vocab_type_tgt = llama_vocab_type(model_tgt);
+    LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
+
+    const enum llama_vocab_type vocab_type_dft = llama_vocab_type(model_dft);
+    LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
+
+    if (vocab_type_tgt != vocab_type_dft) {
+        LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
+                "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+        return false;
+    }
+
+    if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
+        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
+        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
+        llama_token_eos(model_tgt) != llama_token_eos(model_dft)
+    ) {
+        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
+        return false;
+    }
+
+    {
+        const int n_vocab_tgt = llama_n_vocab(model_tgt);
+        const int n_vocab_dft = llama_n_vocab(model_dft);
+
+        const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+
+        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
+                    "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    __func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            return false;
+        }
+
+        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
+            const char * token_text_dft = llama_token_get_text(model_dft, i);
+            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
+                LOG_ERR("%s: draft model vocab must match target model to use speculation but "
+                        "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+                        common_token_to_piece(ctx_tgt, i).c_str(),
+                        common_token_to_piece(ctx_dft, i).c_str());
+                return false;
+            }
+        }
+    }
+
+    return true;
+}
+
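A caller would typically gate speculation on this check once at startup, roughly as below; `ctx_tgt` and `ctx_dft` are assumed to be contexts the caller has already initialized, not code from this patch.

```cpp
// Sketch: fall back to plain decoding when the draft model is incompatible.
if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
    LOG_ERR("draft model is not compatible with the target model\n");
    return 1;
}
```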
 void common_speculative_add_draft(
         struct common_speculative * spec,
         struct llama_batch & batch_tgt,
-        const llama_tokens & prompt,
+        const llama_tokens & prompt_tgt,
         llama_token id_last,
         llama_token n_past_tgt) {
+    auto & batch  = spec->batch;
+    auto & ctx    = spec->ctx;
+    auto & smpl   = spec->smpl;
+    auto & prompt = spec->prompt;

     int reuse_i = 0;
     int reuse_n = 0;

-    const int n_ctx = llama_n_ctx(spec->params.ctx_dft) - spec->params.n_draft;
+    const int n_ctx = llama_n_ctx(ctx) - spec->params.n_draft;

-    const int i_start = std::max<int>(0, (int) prompt.size() - n_ctx);
+    const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);

-    for (int i = 0; i < (int) spec->prompt_last.size(); ++i) {
+    for (int i = 0; i < (int) prompt.size(); ++i) {
         int cur = 0;
-        while (i_start + cur < (int) prompt.size() &&
-               i + cur < (int) spec->prompt_last.size() &&
-               prompt[i_start + cur] == spec->prompt_last[i + cur]) {
+        while (i_start + cur < (int) prompt_tgt.size() &&
+               i + cur < (int) prompt.size() &&
+               prompt_tgt[i_start + cur] == prompt[i + cur]) {
             cur++;
         }

-        if ((cur >= spec->params.n_reuse || prompt.size() <= n_ctx) && cur > reuse_n) {
+        if ((cur >= spec->params.n_reuse || prompt_tgt.size() <= n_ctx) && cur > reuse_n) {
             reuse_i = i;
             reuse_n = cur;
         }
@@ -97,59 +163,59 @@ void common_speculative_add_draft(
     LOG_DBG("%s: reuse_i = %d, reuse_n = %d\n", __func__, reuse_i, reuse_n);

     if (reuse_n == 0) {
-        llama_kv_cache_clear(spec->params.ctx_dft);
+        llama_kv_cache_clear(ctx);

-        spec->prompt_last.clear();
+        prompt.clear();
     } else {
-        llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, 0, reuse_i);
-        llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, reuse_i + reuse_n, -1);
-        llama_kv_cache_seq_add(spec->params.ctx_dft, 0, reuse_i, -1, -reuse_i);
+        llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+        llama_kv_cache_seq_rm (ctx, 0, reuse_i + reuse_n, -1);
+        llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);

-        spec->prompt_last.erase(spec->prompt_last.begin(), spec->prompt_last.begin() + reuse_i);
-        spec->prompt_last.erase(spec->prompt_last.begin() + reuse_n, spec->prompt_last.end());
+        prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        prompt.erase(prompt.begin() + reuse_n, prompt.end());
     }

-    common_batch_clear(spec->batch_dft);
+    common_batch_clear(batch);

-    for (int i = i_start + reuse_n; i < (int) prompt.size(); ++i) {
-        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt[i]);
-        common_batch_add(spec->batch_dft, prompt[i], i - i_start, { 0 }, false);
+    for (int i = i_start + reuse_n; i < (int) prompt_tgt.size(); ++i) {
+        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
+        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);

-        spec->prompt_last.push_back(prompt[i]);
+        prompt.push_back(prompt_tgt[i]);
     }

-    const llama_pos n_past = prompt.size() - i_start;
+    const llama_pos n_past = prompt_tgt.size() - i_start;

     LOG_DBG("%s: n_past = %d\n", __func__, n_past);

-    if (spec->batch_dft.n_tokens > 0) {
-        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(spec->params.ctx_dft, spec->batch_dft).c_str());
+    if (batch.n_tokens > 0) {
+        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(ctx, batch).c_str());

-        llama_decode(spec->params.ctx_dft, spec->batch_dft);
+        llama_decode(ctx, batch);
     }

-    common_batch_clear(spec->batch_dft);
-    common_batch_add  (spec->batch_dft, id_last, n_past, { 0 }, true);
+    common_batch_clear(batch);
+    common_batch_add  (batch, id_last, n_past, { 0 }, true);

-    spec->prompt_last.push_back(id_last);
+    prompt.push_back(id_last);

-    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(spec->params.ctx_dft, spec->prompt_last).c_str());
+    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(ctx, prompt).c_str());

-    llama_decode(spec->params.ctx_dft, spec->batch_dft);
+    llama_decode(ctx, batch);

-    common_sampler_reset(spec->smpl);
+    common_sampler_reset(smpl);

     // sample n_draft tokens from the draft model
     for (int i = 0; i < spec->params.n_draft; ++i) {
-        common_batch_clear(spec->batch_dft);
+        common_batch_clear(batch);

-        common_sampler_sample(spec->smpl, spec->params.ctx_dft, 0, true);
+        common_sampler_sample(smpl, ctx, 0, true);

-        const auto * cur_p = common_sampler_get_candidates(spec->smpl);
+        const auto * cur_p = common_sampler_get_candidates(smpl);

         for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
             LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(spec->params.ctx_dft, cur_p->data[k].id).c_str());
+                k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
         }

         // add drafted token for each sequence
@@ -160,20 +226,20 @@ void common_speculative_add_draft(
             break;
         }

-        common_sampler_accept(spec->smpl, id, true);
+        common_sampler_accept(smpl, id, true);

         common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true);

         if (batch_tgt.n_tokens > spec->params.n_draft) {
             break;
         }

-        common_batch_add(spec->batch_dft, id, n_past + i + 1, { 0 }, true);
+        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);

         // evaluate the drafted tokens on the draft model
-        llama_decode(spec->params.ctx_dft, spec->batch_dft);
+        llama_decode(ctx, batch);

-        spec->prompt_last.push_back(id);
+        prompt.push_back(id);
     }

     // don't waste time on small batches
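
Putting the renamed pieces together, one generation step on the target side might look roughly like this; `ctx_tgt`, `prompt_tgt`, `id_last` and `n_past_tgt` are assumed caller state, not code from this patch.

```cpp
// Sketch: seed the target batch with the last sampled token, let the helper
// append up to params.n_draft drafted tokens, then decode once on the target.
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);

common_batch_clear(batch_tgt);
common_batch_add  (batch_tgt, id_last, n_past_tgt, { 0 }, true);

common_speculative_add_draft(spec, batch_tgt, prompt_tgt, id_last, n_past_tgt);

llama_decode(ctx_tgt, batch_tgt);
```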
0 commit comments