@@ -267,13 +267,18 @@ static llama_token_data_array llama_sampling_prepare_impl(
267
267
268
268
const int n_vocab = llama_n_vocab (llama_get_model (ctx_main));
269
269
270
+ // repetition penalties
270
271
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n ;
271
272
const float penalty_repeat = params.penalty_repeat ;
272
273
const float penalty_freq = params.penalty_freq ;
273
274
const float penalty_present = params.penalty_present ;
274
-
275
275
const bool penalize_nl = params.penalize_nl ;
276
276
277
+ // DRY sampler parameters
278
+ const float dry_multiplier = params.dry_multiplier ;
279
+ const float dry_base = params.dry_base ;
280
+ const int dry_allowed_length = params.dry_allowed_length ;
281
+
277
282
auto & prev = ctx_sampling->prev ;
278
283
auto & cur = ctx_sampling->cur ;
279
284
@@ -309,10 +314,20 @@ static llama_token_data_array llama_sampling_prepare_impl(
309
314
if (penalty_tokens_used_size) {
310
315
const float nl_logit = logits[llama_token_nl (llama_get_model (ctx_main))];
311
316
317
+ // repetition penalties
312
318
llama_sample_repetition_penalties (ctx_main, &cur_p,
313
319
penalty_tokens.data () + penalty_tokens.size () - penalty_tokens_used_size,
314
320
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
315
321
322
+ // DRY penalties (multiplier > 0 means enabled)
323
+ if (dry_multiplier > 0 .0f ) {
324
+ llama_sample_dry (ctx_main, &cur_p,
325
+ penalty_tokens.data () + penalty_tokens.size () - penalty_tokens_used_size,
326
+ penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
327
+ params.dry_sequence_breakers .data (), params.dry_sequence_breakers .size ());
328
+ }
329
+
330
+
316
331
if (!penalize_nl) {
317
332
for (size_t idx = 0 ; idx < cur_p.size ; idx++) {
318
333
if (cur_p.data [idx].id == llama_token_nl (llama_get_model (ctx_main))) {
0 commit comments