File tree: 1 file changed, +8 -3 lines changed
Columns: Original file line number | Diff line number | Diff line change
@@ -39,6 +39,11 @@ int main(int argc, char ** argv) {
39
39
return 1 ;
40
40
}
41
41
42
+ if (params.n_predict < -1 ) {
43
+ LOG_ERR (" %s: --n-predict must be >= -1\n " , __func__);
44
+ return 1 ;
45
+ }
46
+
42
47
common_init ();
43
48
44
49
if (params.model_draft .empty ()) {
@@ -190,8 +195,8 @@ int main(int argc, char ** argv) {
190
195
drafts[s].smpl = common_sampler_init (model_dft, params.sparams );
191
196
}
192
197
193
- llama_batch batch_dft = llama_batch_init (params. n_ctx , 0 , 1 );
194
- llama_batch batch_tgt = llama_batch_init (params. n_ctx , 0 , n_seq_dft);
198
+ llama_batch batch_dft = llama_batch_init (llama_n_batch (ctx_dft) , 0 , 1 );
199
+ llama_batch batch_tgt = llama_batch_init (llama_n_batch (ctx_tgt) , 0 , n_seq_dft);
195
200
196
201
const auto t_dec_start = ggml_time_us ();
197
202
@@ -441,7 +446,7 @@ int main(int argc, char ** argv) {
441
446
++n_past_dft;
442
447
}
443
448
444
- if (n_predict > params.n_predict || has_eos) {
449
+ if ((params. n_predict >= 0 && n_predict > params.n_predict ) || has_eos) {
445
450
break ;
446
451
}
447
452
You can’t perform that action at this time.
0 commit comments