
Commit f91ce94

llama : add infill sampler

ggml-ci

1 parent dcdd535

File tree

9 files changed: +294 -22 lines changed

common/common.h

Lines changed: 2 additions & 1 deletion

@@ -90,6 +90,7 @@ enum common_sampler_type {
     COMMON_SAMPLER_TYPE_TFS_Z       = 4,
     COMMON_SAMPLER_TYPE_TYPICAL_P   = 5,
     COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
+    COMMON_SAMPLER_TYPE_INFILL      = 7,
 };
 
 // dimensionality reduction methods, used by cvector-generator
@@ -130,7 +131,7 @@ struct common_sampler_params {
         COMMON_SAMPLER_TYPE_TYPICAL_P,
         COMMON_SAMPLER_TYPE_TOP_P,
         COMMON_SAMPLER_TYPE_MIN_P,
-        COMMON_SAMPLER_TYPE_TEMPERATURE
+        COMMON_SAMPLER_TYPE_TEMPERATURE,
     };
 
     std::string grammar; // optional BNF-like grammar to constrain sampling

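Note that COMMON_SAMPLER_TYPE_INFILL joins the enum but is not added to the default sampler list above, so the infill sampler only runs when a caller opts in. A minimal sketch of doing so (hypothetical caller code; it assumes the initializer list above belongs to a "samplers" member of common_sampler_params):

    common_sampler_params sparams;

    // override the default chain; infill goes last, per the guidance
    // in include/llama.h to run it after top_k/top_p
    sparams.samplers = {
        COMMON_SAMPLER_TYPE_TOP_K,
        COMMON_SAMPLER_TYPE_TOP_P,
        COMMON_SAMPLER_TYPE_INFILL,
    };
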
common/sampling.cpp

Lines changed: 8 additions & 1 deletion

@@ -193,6 +193,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             case COMMON_SAMPLER_TYPE_TEMPERATURE:
                 llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                 break;
+            case COMMON_SAMPLER_TYPE_INFILL:
+                llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
+                break;
             default:
                 GGML_ASSERT(false && "unknown sampler type");
         }
@@ -372,6 +375,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TOP_P:       return 'p';
         case COMMON_SAMPLER_TYPE_MIN_P:       return 'm';
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
+        case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
         default : return '?';
     }
 }
@@ -384,6 +388,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TOP_P:       return "top_p";
         case COMMON_SAMPLER_TYPE_MIN_P:       return "min_p";
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
+        case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
         default : return "";
     }
 }
@@ -396,6 +401,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
         { "min_p",       COMMON_SAMPLER_TYPE_MIN_P },
         { "tfs_z",       COMMON_SAMPLER_TYPE_TFS_Z },
         { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
+        { "infill",      COMMON_SAMPLER_TYPE_INFILL }
     };
 
     // since samplers names are written multiple ways
@@ -441,7 +447,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P),   COMMON_SAMPLER_TYPE_TYPICAL_P },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
-        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
     };
 
     std::vector<common_sampler_type> samplers;

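With the mappings above, the new sampler resolves from both its long name and its single-character shorthand. A minimal sketch (not part of the commit; the bool parameter of common_sampler_types_from_names is assumed to be the usual allow-alternative-names flag):

    // "infill" resolves via the name table added above
    std::vector<std::string> names = { "top_k", "top_p", "infill" };
    std::vector<common_sampler_type> types = common_sampler_types_from_names(names, true);

    // the single-character form maps 'i' to the same sampler
    std::vector<common_sampler_type> same = common_sampler_types_from_chars("kpi");
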
examples/main/main.cpp

Lines changed: 17 additions & 17 deletions

@@ -569,30 +569,30 @@ int main(int argc, char ** argv) {
             if (!params.ctx_shift){
                 LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
                 break;
-            } else {
-                if (params.n_predict == -2) {
-                    LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
-                    break;
-                }
+            }
+
+            if (params.n_predict == -2) {
+                LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                break;
+            }
 
-                const int n_left    = n_past - params.n_keep;
-                const int n_discard = n_left/2;
+            const int n_left    = n_past - params.n_keep;
+            const int n_discard = n_left/2;
 
-                LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-                        n_past, n_left, n_ctx, params.n_keep, n_discard);
+            LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+                    n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-                llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+            llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+            llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
 
-                n_past -= n_discard;
+            n_past -= n_discard;
 
-                LOG_DBG("after swap: n_past = %d\n", n_past);
+            LOG_DBG("after swap: n_past = %d\n", n_past);
 
-                LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
+            LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
 
-                LOG_DBG("clear session path\n");
-                path_session.clear();
-            }
+            LOG_DBG("clear session path\n");
+            path_session.clear();
         }
     } else {
         // context extension via Self-Extend

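For reference, a worked pass through the context-shift arithmetic above, using assumed values (n_ctx = 4096, n_keep = 256, context full at n_past = 4096):

    // n_left    = n_past - n_keep = 3840
    // n_discard = n_left/2        = 1920
    llama_kv_cache_seq_rm (ctx, 0, 256,  256 + 1920);   // drop KV cells [256, 2176)
    llama_kv_cache_seq_add(ctx, 0, 2176, 4096, -1920);  // slide [2176, 4096) left by 1920
    // n_past -= n_discard;  ->  generation resumes at position 2176
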
include/llama.h

Lines changed: 28 additions & 0 deletions

@@ -953,6 +953,12 @@ extern "C" {
                                int32_t   lstrip,
                                   bool   special);
 
+    // check if token0 is contained as a prefix in token1
+    LLAMA_API bool llama_token_is_prefix(
+              const struct llama_model * model,
+                           llama_token   token0,
+                           llama_token   token1);
+
     /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
     /// @param text The char pointer must be large enough to hold the resulting text.
     /// @return Returns the number of chars/bytes on success, no more than text_len_max.
@@ -1145,6 +1151,28 @@ extern "C" {
                                int32_t   n_logit_bias,
                 const llama_logit_bias * logit_bias);
 
+    // this sampler is meant to be used for fill-in-the-middle infilling
+    // it's supposed to be used after top_k + top_p sampling
+    //
+    // 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
+    // 2. combine probs of tokens that have the same prefix
+    //
+    // example:
+    //
+    // - before:
+    //   "hel":   0.5
+    //   "hell":  0.2
+    //   "hello": 0.1
+    //   "dummy": 0.1
+    //
+    // - after:
+    //   "hel":   0.8
+    //   "dummy": 0.1
+    //
+    // 3. discard non-EOG tokens with low prob
+    // 4. if no tokens are left -> pick EOT
+    //
+    LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
 
     // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
     LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);

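A minimal usage sketch of the new public API (not part of the commit; the chain composition and parameter values are illustrative assumptions, following the header comment's advice to run the infill sampler after top_k/top_p):

    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_top_p(0.90f, 1));
    llama_sampler_chain_add(chain, llama_sampler_init_infill(model));            // added by this commit
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); // final pick

    const llama_token id = llama_sampler_sample(chain, ctx, -1);

    llama_sampler_free(chain);
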
src/llama-sampling.cpp

Lines changed: 201 additions & 0 deletions

@@ -1644,6 +1644,207 @@ struct llama_sampler * llama_sampler_init_logit_bias(
     };
 }
 
+// infill
+
+//#define GGML_DEBUG_SAMPLER_INFILL
+
+struct llama_sampler_infill {
+    const struct llama_vocab * vocab;
+};
+
+static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
+    return "infill";
+}
+
+static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_infill *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
+#if defined(GGML_DEBUG_SAMPLER_INFILL)
+#define LOG_DBG_CUR LLAMA_LOG_DEBUG
+#else
+#define LOG_DBG_CUR(...)
+#endif
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    float p_txt_sum = 0.0f;
+    float p_eog_sum = 0.0f;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            p_eog_sum += cur_p->data[i].p;
+        } else {
+            p_txt_sum += cur_p->data[i].p;
+        }
+    }
+
+    const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
+
+    LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
+
+    if (3*p_eog_sum*cur_p->size > p_txt_sum) {
+        LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
+
+        // keep just the EOG tokens
+        const auto size_org = cur_p->size;
+
+        cur_p->size = 0;
+
+        float p_sum = 0.0f;
+
+        for (size_t i = 0; i < size_org; ++i) {
+            if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+                p_sum += cur_p->data[i].p;
+
+                cur_p->data[cur_p->size++] = cur_p->data[i];
+            }
+        }
+
+        // normalize probs
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].p /= p_sum;
+        }
+
+        return;
+    }
+
+    size_t n_combined = 0; GGML_UNUSED(n_combined);
+
+    // combine tokens with common prefix
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        for (size_t j = 0; j < cur_p->size; ++j) {
+            if (cur_p->data[i].logit == -INFINITY) {
+                break;
+            }
+
+            if (i == j || cur_p->data[j].logit == -INFINITY) {
+                continue;
+            }
+
+            if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
+                if (cur_p->data[i].p > cur_p->data[j].p) {
+                    cur_p->data[i].p += cur_p->data[j].p;
+                    cur_p->data[j].logit = -INFINITY;
+                    cur_p->data[j].p     = 0.0f;
+                } else {
+                    cur_p->data[j].p += cur_p->data[i].p;
+                    cur_p->data[i].logit = -INFINITY;
+                    cur_p->data[i].p     = 0.0f;
+                }
+
+                n_combined++;
+            }
+        }
+    }
+
+    size_t n_non_eog = 0;
+
+    size_t size_org = cur_p->size;
+
+    float p_sum = 0.0f;
+    float thold = 0.2f;
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        if (!is_eog) {
+            ++n_non_eog;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        // keep this token
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
+
+    // if no non-EOG tokens are left -> reduce cur_p to single EOT token
+    if (n_non_eog == 0) {
+        cur_p->size = 1;
+        cur_p->data[0].id    = llama_token_eot_impl(*ctx->vocab);
+        cur_p->data[0].logit = 1.0f;
+
+        return;
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    size_org = cur_p->size;
+    p_sum = 0.0f;
+    thold = 1.0/(n_non_eog + 1);
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+#undef LOG_DBG_CUR
+}
+
+static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
+    return llama_sampler_init_infill_impl(*ctx->vocab);
+}
+
+static void llama_sampler_infill_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_infill *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_infill_i = {
+    /* .name   = */ llama_sampler_infill_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_infill_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_infill_clone,
+    /* .free   = */ llama_sampler_infill_free,
+};
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_infill_i,
+        /* .ctx   = */ new llama_sampler_infill {
+            /* .vocab = */ &vocab,
+        },
+    };
+}
+
 // utils
 
 uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {

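To make the flow of llama_sampler_infill_apply concrete, a worked trace with assumed candidate probabilities:

    // assume 4 candidates survive top_k/top_p:
    //   "hel" 0.55, "hell" 0.25, "hello" 0.15, <EOT> 0.05
    //
    // EOG gate:   p_txt_sum = 0.95, p_eog_sum = 0.05
    //             3 * 0.05 * 4 = 0.60 <= 0.95  ->  do not force EOG, continue
    // merge:      "hell" and "hello" are prefixed by "hel" -> "hel" p = 0.95
    // thold 0.2:  drops the two zeroed, merged entries; n_non_eog = 1
    // final gate: thold = 1/(n_non_eog + 1) = 0.5 keeps "hel";
    //             EOG tokens always survive the threshold checks
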
src/llama-sampling.h

Lines changed: 3 additions & 2 deletions

@@ -4,8 +4,6 @@
 
 #include "llama-grammar.h"
 
-#include <unordered_map>
-
 struct llama_vocab;
 struct llama_grammar;
 
@@ -27,3 +25,6 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
         const struct llama_vocab & vocab,
                       const char * grammar_str,
                       const char * grammar_root);
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab);

src/llama-vocab.cpp

Lines changed: 17 additions & 0 deletions

@@ -1858,6 +1858,23 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
     return 0;
 }
 
+bool llama_token_is_prefix_impl(
+        const struct llama_vocab & vocab,
+                     llama_token   token0,
+                     llama_token   token1) {
+    char text_buf_0[128];
+    char text_buf_1[128];
+
+    const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
+    const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);
+
+    if (len0 <= 0 || len1 <= 0) {
+        return false;
+    }
+
+    return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
+}
+
 int32_t llama_detokenize_impl(
         const struct llama_vocab & vocab,
         const llama_token * tokens,

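The prefix test operates on the detokenized text of the two tokens. An equivalent string-level check (illustrative sketch only; the real implementation obtains each piece via llama_token_to_piece_impl, capped at 127 bytes by the stack buffers):

    #include <cstring>

    // is s0 a prefix of s1?
    static bool is_prefix(const char * s0, int len0, const char * s1, int len1) {
        if (len0 <= 0 || len1 <= 0) {
            return false; // mirrors the guard for failed detokenization
        }
        return len0 <= len1 && memcmp(s0, s1, len0) == 0;
    }

    // is_prefix("hel", 3, "hello", 5) -> true
    // is_prefix("hello", 5, "hel", 3) -> false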