Skip to content

Commit f64dea0

Browse files
committed
added implementation of DRY sampler
1 parent 784e11d commit f64dea0

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

common/sampling.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,13 +267,18 @@ static llama_token_data_array llama_sampling_prepare_impl(
267267

268268
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
269269

270+
// repetition penalties
270271
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
271272
const float penalty_repeat = params.penalty_repeat;
272273
const float penalty_freq = params.penalty_freq;
273274
const float penalty_present = params.penalty_present;
274-
275275
const bool penalize_nl = params.penalize_nl;
276276

277+
// DRY sampler parameters
278+
const float dry_multiplier = params.dry_multiplier;
279+
const float dry_base = params.dry_base;
280+
const int dry_allowed_length = params.dry_allowed_length;
281+
277282
auto & prev = ctx_sampling->prev;
278283
auto & cur = ctx_sampling->cur;
279284

@@ -309,10 +314,20 @@ static llama_token_data_array llama_sampling_prepare_impl(
309314
if (penalty_tokens_used_size) {
310315
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
311316

317+
// repetition penalties
312318
llama_sample_repetition_penalties(ctx_main, &cur_p,
313319
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
314320
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
315321

322+
// DRY penalties (multiplier > 0 means enabled)
323+
if(dry_multiplier > 0.0f) {
324+
llama_sample_dry(ctx_main, &cur_p,
325+
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
326+
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
327+
params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
328+
}
329+
330+
316331
if (!penalize_nl) {
317332
for (size_t idx = 0; idx < cur_p.size; idx++) {
318333
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {

common/sampling.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ typedef struct llama_sampling_params {
4141
float mirostat_eta = 0.10f; // learning rate
4242
bool penalize_nl = false; // consider newlines as a repeatable token
4343
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
44+
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
45+
float dry_base = 1.75f;
46+
int dry_allowed_length = 2;
4447

4548
std::vector<llama_sampler_type> samplers_sequence = {
4649
llama_sampler_type::TOP_K,
@@ -61,6 +64,7 @@ typedef struct llama_sampling_params {
6164
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
6265

6366
std::vector<llama_token> penalty_prompt_tokens;
67+
std::vector<llama_token> dry_sequence_breakers; // sequence breakers for the DRY sampler
6468
bool use_penalty_prompt_tokens = false;
6569
} llama_sampling_params;
6670

llama.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,18 @@ extern "C" {
924924
float p,
925925
size_t min_keep);
926926

927+
/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
928+
LLAMA_API void llama_sample_dry(
929+
struct llama_context * ctx,
930+
llama_token_data_array * candidates,
931+
const llama_token * last_tokens,
932+
int last_tokens_size,
933+
float dry_base,
934+
float dry_multiplier,
935+
int dry_allowed_length,
936+
const llama_token * seq_breakers,
937+
int seq_breakers_size);
938+
927939
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
928940
LLAMA_API void llama_sample_tail_free(
929941
struct llama_context * ctx,

0 commit comments

Comments (0)