Skip to content

Commit 44c0943

Browse files
ggerganov authored and dsx1986 committed
speculative : fix handling of some input params (ggml-org#9963)
* speculative : fix batch sizes at initialization
  ggml-ci
* speculative : handle params.n_predict == -1
* speculative : limit batch size to llama_n_batch
1 parent ec2a378 commit 44c0943

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

examples/speculative/speculative.cpp

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,11 @@ int main(int argc, char ** argv) {
3939
return 1;
4040
}
4141

42+
if (params.n_predict < -1) {
43+
LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
44+
return 1;
45+
}
46+
4247
common_init();
4348

4449
if (params.model_draft.empty()) {
@@ -190,8 +195,8 @@ int main(int argc, char ** argv) {
190195
drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
191196
}
192197

193-
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
194-
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
198+
llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
199+
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);
195200

196201
const auto t_dec_start = ggml_time_us();
197202

@@ -441,7 +446,7 @@ int main(int argc, char ** argv) {
441446
++n_past_dft;
442447
}
443448

444-
if (n_predict > params.n_predict || has_eos) {
449+
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
445450
break;
446451
}
447452

0 commit comments

Comments
 (0)