
llama : add infill sampler #9896


Merged: 1 commit, Oct 15, 2024
4 changes: 2 additions & 2 deletions common/common.h
@@ -91,7 +91,7 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_TYPICAL_P = 5,
COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
COMMON_SAMPLER_TYPE_XTC = 7,

COMMON_SAMPLER_TYPE_INFILL = 8,
};

// dimensionality reduction methods, used by cvector-generator
@@ -136,7 +136,7 @@ struct common_sampler_params {
COMMON_SAMPLER_TYPE_TOP_P,
COMMON_SAMPLER_TYPE_MIN_P,
COMMON_SAMPLER_TYPE_XTC,
COMMON_SAMPLER_TYPE_TEMPERATURE
COMMON_SAMPLER_TYPE_TEMPERATURE,
};

std::string grammar; // optional BNF-like grammar to constrain sampling
9 changes: 8 additions & 1 deletion common/sampling.cpp
@@ -196,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
}
@@ -376,6 +379,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
case COMMON_SAMPLER_TYPE_XTC: return 'x';
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
default : return '?';
}
}
@@ -389,6 +393,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
default : return "";
}
}
@@ -402,6 +407,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "tfs_z", COMMON_SAMPLER_TYPE_TFS_Z },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
};

// since samplers names are written multiple ways
@@ -448,7 +454,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
};

std::vector<common_sampler_type> samplers;
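With the name and character mappings above in place, the new sampler can be requested like any other. A rough illustration, assuming the --samplers and --sampling-seq options exposed by common/arg.cpp at the time, where names are separated by ';' and the shorthand string uses the characters from common_sampler_type_to_chr:

llama-cli -m model.gguf --samplers "top_k;top_p;infill" ...
llama-cli -m model.gguf --sampling-seq kpi ...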
34 changes: 17 additions & 17 deletions examples/main/main.cpp
@@ -569,30 +569,30 @@ int main(int argc, char ** argv) {
            if (!params.ctx_shift){
                LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
                break;
-           } else {
-               if (params.n_predict == -2) {
-                   LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
-                   break;
-               }
+           }
+
+           if (params.n_predict == -2) {
+               LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+               break;
+           }

-               const int n_left    = n_past - params.n_keep;
-               const int n_discard = n_left/2;
+           const int n_left    = n_past - params.n_keep;
+           const int n_discard = n_left/2;

-               LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-                       n_past, n_left, n_ctx, params.n_keep, n_discard);
+           LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+                   n_past, n_left, n_ctx, params.n_keep, n_discard);

-               llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-               llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+           llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+           llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);

-               n_past -= n_discard;
+           n_past -= n_discard;

-               LOG_DBG("after swap: n_past = %d\n", n_past);
+           LOG_DBG("after swap: n_past = %d\n", n_past);

-               LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
+           LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());

-               LOG_DBG("clear session path\n");
-               path_session.clear();
-           }
+           LOG_DBG("clear session path\n");
+           path_session.clear();
        }
} else {
// context extension via Self-Extend
28 changes: 28 additions & 0 deletions include/llama.h
@@ -953,6 +953,12 @@ extern "C" {
int32_t lstrip,
bool special);

// check if token0 is contained as a prefix in token1
LLAMA_API bool llama_token_is_prefix(
const struct llama_model * model,
llama_token token0,
llama_token token1);

/// @details Convert the provided tokens into text (inverse of llama_tokenize()).
/// @param text The char pointer must be large enough to hold the resulting text.
/// @return Returns the number of chars/bytes on success, no more than text_len_max.
@@ -1148,6 +1154,28 @@ extern "C" {
int32_t n_logit_bias,
const llama_logit_bias * logit_bias);

// this sampler is meant to be used for fill-in-the-middle infilling
// it's supposed to be used after top_k + top_p sampling
//
// 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
// 2. combine probs of tokens that have the same prefix
//
// example:
//
// - before:
// "hel": 0.5
// "hell": 0.2
// "hello": 0.1
// "dummy": 0.1
//
// - after:
// "hel": 0.8
// "dummy": 0.1
//
// 3. discard non-EOG tokens with low prob
// 4. if no tokens are left -> pick EOT
//
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);

// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
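For context, a minimal usage sketch of the new sampler, not part of this diff. It assumes the existing llama_sampler_chain API from llama.h plus an already loaded model and ctx; the parameter values are illustrative only:

// build a chain that matches the intended "top_k + top_p -> infill" ordering from the comment above
llama_sampler_chain_params sparams = llama_sampler_chain_default_params();
struct llama_sampler * smpl = llama_sampler_chain_init(sparams);

llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.95f, 1));
llama_sampler_chain_add(smpl, llama_sampler_init_infill(model));            // added by this PR
llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); // final pick from the surviving candidates

const llama_token id = llama_sampler_sample(smpl, ctx, -1); // sample from the last decoded position

llama_sampler_free(smpl);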
201 changes: 201 additions & 0 deletions src/llama-sampling.cpp
@@ -1739,6 +1739,207 @@ struct llama_sampler * llama_sampler_init_logit_bias(
};
}

// infill

//#define GGML_DEBUG_SAMPLER_INFILL

struct llama_sampler_infill {
const struct llama_vocab * vocab;
};

static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
return "infill";
}

static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_infill *) smpl->ctx;

llama_sampler_softmax_impl(cur_p);

#if defined(GGML_DEBUG_SAMPLER_INFILL)
#define LOG_DBG_CUR LLAMA_LOG_DEBUG
#else
#define LOG_DBG_CUR(...)
#endif

for (size_t i = 0; i < cur_p->size; ++i) {
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}

float p_txt_sum = 0.0f;
float p_eog_sum = 0.0f;

for (size_t i = 0; i < cur_p->size; ++i) {
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
p_eog_sum += cur_p->data[i].p;
} else {
p_txt_sum += cur_p->data[i].p;
}
}

const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);

LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);

if (3*p_eog_sum*cur_p->size > p_txt_sum) {
LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);

// keep just the EOG tokens
const auto size_org = cur_p->size;

cur_p->size = 0;

float p_sum = 0.0f;

for (size_t i = 0; i < size_org; ++i) {
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
p_sum += cur_p->data[i].p;

cur_p->data[cur_p->size++] = cur_p->data[i];
}
}

// normalize probs
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].p /= p_sum;
}
Comment on lines +1802 to +1805
Member:

A few samplers do this, but I don't see the point, because every sampler that needs the probabilities calls softmax first anyway and recomputes the probabilities.

During the refactor I came to the conclusion that we only really store logits. Every time probabilities are needed, a softmax is done to get them; llama_token_data::p is only used as temporary storage for the result of the softmax and could be removed entirely.

Member Author (@ggerganov, Oct 15, 2024):

I think there is currently a scenario that uses the ps - calling the dist sampler without an explicit softmax before it. We don't do it in any of the examples, but it's technically possible?

Anyway, I agree that the p should be removed completely.

Member:

I think that should be considered a bug in the dist sampler then, because there is no way to know whether the probabilities are valid without calling softmax. So any sampler that needs them must call softmax.

Member Author (@ggerganov, Oct 15, 2024):

Yes, indeed it is a bug. There are a few places where we do:

llama_sampler * smpl = llama_sampler_chain_init(sparams);

llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sparams.top_k));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sparams.top_p, params.sparams.min_keep));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sparams.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sparams.seed));

This would render the temperature sampler useless, as it modifies only the logits. I think we should remove the explicit softmax calls in places like common/sampling.cpp:

llama_sampler_chain_add(result->chain, llama_sampler_init_softmax()); // remove this
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));

And update the dist sampler to do softmax at the start. Sounds good?

Member:

Sounds good. We should probably remove llama_sampler_init_softmax entirely, since it is useless to applications.

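For reference, a rough sketch of the proposed follow-up, doing softmax at the start of the dist sampler. This is hypothetical and not part of this diff; it assumes the dist sampler keeps its RNG in its context as ctx->rng and it needs <random> and <vector>:

static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_dist *) smpl->ctx;

    // recompute p from the logits so that samplers which only modify logits
    // (e.g. temperature) are reflected in the distribution we draw from
    llama_sampler_softmax_impl(cur_p);

    std::vector<float> probs(cur_p->size);
    for (size_t i = 0; i < cur_p->size; ++i) {
        probs[i] = cur_p->data[i].p;
    }

    std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
    cur_p->selected = dist(ctx->rng);
}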
return;
}

size_t n_combined = 0; GGML_UNUSED(n_combined);

// combine tokens with common prefix
for (size_t i = 0; i < cur_p->size; ++i) {
for (size_t j = 0; j < cur_p->size; ++j) {
if (cur_p->data[i].logit == -INFINITY) {
break;
}

if (i == j || cur_p->data[j].logit == -INFINITY) {
continue;
}

if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
if (cur_p->data[i].p > cur_p->data[j].p) {
cur_p->data[i].p += cur_p->data[j].p;
cur_p->data[j].logit = -INFINITY;
cur_p->data[j].p = 0.0f;
} else {
cur_p->data[j].p += cur_p->data[i].p;
cur_p->data[i].logit = -INFINITY;
cur_p->data[i].p = 0.0f;
}

n_combined++;
}
}
}

size_t n_non_eog = 0;

size_t size_org = cur_p->size;

float p_sum = 0.0f;
float thold = 0.2f;

cur_p->size = 0;

LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);

for (size_t i = 0; i < size_org; ++i) {
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);

if (cur_p->data[i].p < thold && !is_eog) {
continue;
}

if (!is_eog) {
++n_non_eog;
}

p_sum += cur_p->data[i].p;

// keep this token
cur_p->data[cur_p->size++] = cur_p->data[i];
}

LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);

// if no non-EOG tokens are left -> reduce cur_p to single EOT token
if (n_non_eog == 0) {
cur_p->size = 1;
cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
cur_p->data[0].logit = 1.0f;

return;
}

// normalize probs
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].p /= p_sum;

LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}

size_org = cur_p->size;
p_sum = 0.0f;
thold = 1.0/(n_non_eog + 1);

cur_p->size = 0;

LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);

for (size_t i = 0; i < size_org; ++i) {
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);

if (cur_p->data[i].p < thold && !is_eog) {
continue;
}

p_sum += cur_p->data[i].p;

cur_p->data[cur_p->size++] = cur_p->data[i];
}

// normalize probs
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].p /= p_sum;

LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}

#undef LOG_DBG_CUR
}

static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
return llama_sampler_init_infill_impl(*ctx->vocab);
}

static void llama_sampler_infill_free(struct llama_sampler * smpl) {
delete (llama_sampler_infill *) smpl->ctx;
}

static struct llama_sampler_i llama_sampler_infill_i = {
/* .name = */ llama_sampler_infill_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_infill_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_infill_clone,
/* .free = */ llama_sampler_infill_free,
};

struct llama_sampler * llama_sampler_init_infill_impl(
const struct llama_vocab & vocab) {
return new llama_sampler {
/* .iface = */ &llama_sampler_infill_i,
/* .ctx = */ new llama_sampler_infill {
/* .vocab = */ &vocab,
},
};
}

// utils

uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
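To make the pruning logic above easier to follow, here is a standalone trace of the same arithmetic on the toy candidates from the llama.h comment. This is plain C++ written for illustration only: it is not the llama.cpp API, and the <EOT> entry and all probabilities are made up. Note that the implementation's EOG shortcut applies a factor of 3 and multiplies by the candidate count, which the header comment does not spell out; the second, dynamic threshold pass of 1/(n_non_eog + 1) is noted but skipped here because it keeps both survivors for this data:

#include <cstdio>
#include <string>
#include <vector>

struct cand {
    std::string text;
    float       p;
    bool        eog;
};

int main() {
    // pretend softmax + top_k/top_p already ran
    std::vector<cand> cur = {
        { "hel",   0.50f, false },
        { "hell",  0.20f, false },
        { "hello", 0.10f, false },
        { "dummy", 0.15f, false },
        { "<EOT>", 0.05f, true  },
    };

    // 1. EOG shortcut: 3 * p_eog_sum * n_candidates > p_txt_sum -> keep only the EOG tokens
    float p_txt = 0.0f, p_eog = 0.0f;
    for (const auto & c : cur) { (c.eog ? p_eog : p_txt) += c.p; }
    if (3.0f*p_eog*cur.size() > p_txt) { printf("-> sample EOG\n"); return 0; }

    // 2. merge candidates whose text is a prefix of another candidate (the higher-prob one absorbs the other)
    for (size_t i = 0; i < cur.size(); ++i) {
        for (size_t j = 0; j < cur.size(); ++j) {
            if (i == j || cur[i].p == 0.0f || cur[j].p == 0.0f) { continue; }
            if (cur[j].text.compare(0, cur[i].text.size(), cur[i].text) == 0) {
                if (cur[i].p > cur[j].p) { cur[i].p += cur[j].p; cur[j].p = 0.0f; }
                else                     { cur[j].p += cur[i].p; cur[i].p = 0.0f; }
            }
        }
    }
    // -> "hel" absorbs "hell" and "hello": p("hel") = 0.80

    // 3. drop non-EOG candidates below the fixed 0.2 threshold, then renormalize
    const float thold = 0.2f;
    std::vector<cand> kept;
    float  p_sum     = 0.0f;
    size_t n_non_eog = 0;
    for (const auto & c : cur) {
        if (c.p < thold && !c.eog) { continue; } // "dummy" (0.15) and the merged-away tokens drop out
        kept.push_back(c);
        p_sum     += c.p;
        n_non_eog += !c.eog;
    }
    for (auto & c : kept) { c.p /= p_sum; }

    // 4. second pass with thold = 1/(n_non_eog + 1) = 0.5 keeps both survivors here
    for (const auto & c : kept) {
        printf("%-6s p = %.3f\n", c.text.c_str(), c.p); // "hel" ~0.941, "<EOT>" ~0.059
    }

    return 0;
}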
5 changes: 3 additions & 2 deletions src/llama-sampling.h
@@ -4,8 +4,6 @@

#include "llama-grammar.h"

#include <unordered_map>

struct llama_vocab;
struct llama_grammar;

@@ -27,3 +25,6 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
const struct llama_vocab & vocab,
const char * grammar_str,
const char * grammar_root);

struct llama_sampler * llama_sampler_init_infill_impl(
const struct llama_vocab & vocab);
17 changes: 17 additions & 0 deletions src/llama-vocab.cpp
@@ -1858,6 +1858,23 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
return 0;
}

bool llama_token_is_prefix_impl(
const struct llama_vocab & vocab,
llama_token token0,
llama_token token1) {
char text_buf_0[128];
char text_buf_1[128];

const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);

if (len0 <= 0 || len1 <= 0) {
return false;
}

return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
}

int32_t llama_detokenize_impl(
const struct llama_vocab & vocab,
const llama_token * tokens,