
Commit a0ddf1f

ggerganov and drollings authored and committed
llama : improve infill support and special token detection (ggml-org#9798)
* llama : improve infill support

ggml-ci

* llama : add more FIM token strings

ggml-ci

* server : update prompt on slot restore (ggml-org#9800)

* gguf : deprecate old FIM token KVs
1 parent bb832ca commit a0ddf1f

File tree

12 files changed (+613, -439 lines)

common/arg.cpp

Lines changed: 111 additions & 137 deletions
Large diffs are not rendered by default.

common/common.cpp

Lines changed: 17 additions & 1 deletion
@@ -12,6 +12,7 @@
 
 #include <algorithm>
 #include <cinttypes>
+#include <climits>
 #include <cmath>
 #include <codecvt>
 #include <cstdarg>
@@ -23,10 +24,10 @@
 #include <regex>
 #include <sstream>
 #include <string>
+#include <thread>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
-#include <thread>
 
 #if defined(__APPLE__) && defined(__MACH__)
 #include <sys/types.h>
@@ -400,6 +401,21 @@ std::string common_params_get_system_info(const common_params & params) {
 // String utils
 //
 
+std::string string_format(const char * fmt, ...) {
+    va_list ap;
+    va_list ap2;
+    va_start(ap, fmt);
+    va_copy(ap2, ap);
+    int size = vsnprintf(NULL, 0, fmt, ap);
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    std::vector<char> buf(size + 1);
+    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+    GGML_ASSERT(size2 == size);
+    va_end(ap2);
+    va_end(ap);
+    return std::string(buf.data(), size);
+}
+
 std::vector<std::string> string_split(std::string input, char separator) {
     std::vector<std::string> parts;
     size_t separator_pos = input.find(separator);
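For reference, a minimal standalone sketch of the two-pass printf-style formatter added above. It uses plain `assert` instead of `GGML_ASSERT` so it compiles outside the llama.cpp tree; otherwise it mirrors the new `string_format` helper and shows how the server uses it for the slot-restore message.

```cpp
#include <cassert>
#include <cstdarg>
#include <cstdio>
#include <string>
#include <vector>

// Two-pass formatter: the first vsnprintf call measures, the second one writes.
static std::string string_format_sketch(const char * fmt, ...) {
    va_list ap;
    va_list ap2;
    va_start(ap, fmt);
    va_copy(ap2, ap);
    const int size = vsnprintf(nullptr, 0, fmt, ap);   // measure only
    assert(size >= 0);
    std::vector<char> buf(size + 1);
    const int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
    assert(size2 == size);
    va_end(ap2);
    va_end(ap);
    return std::string(buf.data(), size);
}

int main() {
    // mirrors the new slot-restore message in examples/server/server.cpp
    const std::string msg = string_format_sketch("[restored %d tokens from file]", 1024);
    printf("%s\n", msg.c_str());   // prints: [restored 1024 tokens from file]
    return 0;
}
```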

common/common.h

Lines changed: 16 additions & 3 deletions
@@ -352,15 +352,28 @@ void common_init();
 
 std::string common_params_get_system_info(const common_params & params);
 
-bool parse_cpu_range(const std::string& range, bool(&boolmask)[GGML_MAX_N_THREADS]);
-bool parse_cpu_mask(const std::string& mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
-void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model = nullptr);
+bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
+bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
+void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr);
 bool set_process_priority(enum ggml_sched_priority prio);
 
 //
 // String utils
 //
 
+#ifdef __GNUC__
+#ifdef __MINGW32__
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+#else
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
+#endif
+
+LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
+std::string string_format(const char * fmt, ...);
+
 std::vector<std::string> string_split(std::string input, char separator);
 
 std::string string_strip(const std::string & str);
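As a rough illustration of what the new `LLAMA_COMMON_ATTRIBUTE_FORMAT` annotation buys: on GCC/Clang the `format` attribute lets the compiler type-check printf-style arguments at the call site. The sketch below is simplified (it drops the MinGW `gnu_printf` branch) and the `log_line` function is hypothetical, not part of common.h.

```cpp
#include <cstdarg>
#include <cstdio>

#ifdef __GNUC__
#    define ATTR_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#else
#    define ATTR_FORMAT(...)
#endif

// argument 1 is the format string, variadic arguments start at argument 2
ATTR_FORMAT(1, 2)
static void log_line(const char * fmt, ...) {
    va_list ap;
    va_start(ap, fmt);
    vfprintf(stderr, fmt, ap);
    va_end(ap);
    fputc('\n', stderr);
}

int main() {
    log_line("restored %d tokens", 42);        // OK
    // log_line("restored %d tokens", "42");   // with -Wformat this mismatch now produces a compiler warning
    return 0;
}
```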

examples/infill/infill.cpp

Lines changed: 7 additions & 7 deletions
@@ -205,11 +205,11 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> inp_pfx = common_tokenize(ctx, params.input_prefix, false);
     std::vector<llama_token> inp_sfx = common_tokenize(ctx, params.input_suffix, false);
 
-    GGML_ASSERT(llama_token_prefix(model) >= 0);
-    GGML_ASSERT(llama_token_suffix(model) >= 0);
+    GGML_ASSERT(llama_token_fim_pre(model) >= 0);
+    GGML_ASSERT(llama_token_fim_suf(model) >= 0);
 
-    inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
-    inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
+    inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model));
+    inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model));
 
     embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
     embd_end = params.spm_infill ? inp_pfx : inp_sfx;
@@ -218,7 +218,7 @@
     }
     embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
 
-    const llama_token middle_token = llama_token_middle(model);
+    const llama_token middle_token = llama_token_fim_mid(model);
     if (middle_token >= 0) {
         embd_inp.push_back(middle_token);
     }
@@ -508,8 +508,8 @@
             std::vector<llama_token> inp_pfx = common_tokenize(ctx, params.input_prefix, false);
             std::vector<llama_token> inp_sfx = common_tokenize(ctx, params.input_suffix, false);
 
-            inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
-            inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
+            inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model));
+            inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model));
 
             embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
             embd_end = params.spm_infill ? inp_pfx : inp_sfx;
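For context on what the renamed calls build, here is a minimal sketch of the resulting token layout for the two infill orderings (PSM by default, SPM when `spm_infill` is set). The token IDs are made-up placeholders; the real values come from `llama_token_fim_pre/suf/mid`, and BOS handling is omitted.

```cpp
#include <cstdio>
#include <vector>

int main() {
    // placeholder IDs, purely for illustration
    const int FIM_PRE = 1, FIM_SUF = 2, FIM_MID = 3;
    const std::vector<int> prefix = {101, 102};   // tokenized input_prefix
    const std::vector<int> suffix = {201};        // tokenized input_suffix

    for (const bool spm_infill : {false, true}) {
        // PSM: <PRE>prefix<SUF>suffix<MID>   SPM: <SUF>suffix<PRE>prefix<MID>
        std::vector<int> inp_pfx = prefix;
        std::vector<int> inp_sfx = suffix;
        inp_pfx.insert(inp_pfx.begin(), FIM_PRE);
        inp_sfx.insert(inp_sfx.begin(), FIM_SUF);

        std::vector<int> embd_inp = spm_infill ? inp_sfx : inp_pfx;
        std::vector<int> embd_end = spm_infill ? inp_pfx : inp_sfx;
        embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
        embd_inp.push_back(FIM_MID);

        printf("spm_infill=%d:", spm_infill);
        for (int t : embd_inp) printf(" %d", t);
        printf("\n");
    }
    return 0;
}
```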

examples/server/README.md

Lines changed: 1 addition & 1 deletion
@@ -526,7 +526,7 @@ Takes a prefix and a suffix and returns the predicted completion as stream.
 - `input_prefix`: Set the prefix of the code to infill.
 - `input_suffix`: Set the suffix of the code to infill.
 
-It also accepts all the options of `/completion` except `stream` and `prompt`.
+It also accepts all the options of `/completion`.
 
 ### **GET** `/props`: Get server global properties.

examples/server/server.cpp

Lines changed: 87 additions & 75 deletions
@@ -753,12 +753,7 @@ struct server_context {
         metrics.init();
     }
 
-    std::vector<llama_token> tokenize(const json & json_prompt, bool add_special) const {
-        // TODO: currently, we tokenize using special tokens by default
-        // this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
-        // but it's better compared to completely ignoring ChatML and other chat templates
-        const bool TMP_FORCE_SPECIAL = true;
-
+    std::vector<llama_token> tokenize(const json & json_prompt, bool add_special, bool parse_special) const {
         // If `add_bos` is true, we only add BOS, when json_prompt is a string,
         // or the first element of the json_prompt array is a string.
         std::vector<llama_token> prompt_tokens;
@@ -771,10 +766,10 @@
 
                 std::vector<llama_token> p;
                 if (first) {
-                    p = common_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
+                    p = common_tokenize(ctx, s, add_special, parse_special);
                     first = false;
                 } else {
-                    p = common_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
+                    p = common_tokenize(ctx, s, false, parse_special);
                 }
 
                 prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
@@ -788,7 +783,7 @@
             }
         } else {
             auto s = json_prompt.template get<std::string>();
-            prompt_tokens = common_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
+            prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
         }
 
         return prompt_tokens;
@@ -1215,7 +1210,7 @@
                     slot.params.n_predict, n_ctx_train);
         }
 
-        SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: '%s'\n", slot.n_decoded, slot.n_remaining, token_str.c_str());
+        SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
 
         return slot.has_next_token; // continue
     }
@@ -1483,9 +1478,8 @@
         if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
             data["index"] = 0;
             create_task(data, false, nullptr);
-        }
-        // otherwise, it's a multiple-prompt task, we break it into smaller tasks
-        else if (prompt.is_array()) {
+        } else if (prompt.is_array()) {
+            // otherwise, it's a multiple-prompt task, we break it into smaller tasks
             std::vector<json> prompts = prompt;
             if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                 // prompts[0] is the question
@@ -1510,9 +1504,8 @@
                     }
                 }
             }
-        }
-        // invalid case
-        else {
+        } else {
+            // invalid case
             throw std::runtime_error(error_msg);
         }
 
@@ -1785,6 +1778,9 @@
         }
         slot->cache_tokens.resize(token_count);
 
+        // TODO: maybe detokenize the slot->cache_tokens instead?
+        slot->prompt = string_format("[restored %d tokens from file]", (int) token_count);
+
         const int64_t t_end = ggml_time_us();
         const double t_restore_ms = (t_end - t_start) / 1000.0;
 
@@ -1971,70 +1967,69 @@
             slot.t_start_process_prompt = ggml_time_us();
             slot.t_start_generation = 0;
 
-            if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
-                const bool add_bos = llama_add_bos_token(model);
-                bool suff_rm_leading_spc = true;
-                if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
-                    params.input_suffix.erase(0, 1);
-                    suff_rm_leading_spc = false;
-                }
-
-                auto prefix_tokens = tokenize(slot.params.input_prefix, false);
-                auto suffix_tokens = tokenize(slot.params.input_suffix, false);
-
-                const int space_token = 29871; // TODO: this should not be hardcoded
-                if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
-                    suffix_tokens.erase(suffix_tokens.begin());
-                }
-
-                prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
-                suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
-
-                auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
-                auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
-                if (add_bos) {
-                    embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
-                }
-                embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
-
-                const llama_token middle_token = llama_token_middle(model);
-                if (middle_token >= 0) {
-                    embd_inp.push_back(middle_token);
-                }
-
-                prompt_tokens = embd_inp;
-            } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
-                // require slot.prompt to be array of 2 strings
-                if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
-                    SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
-                    slot.release();
-                    send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
-                    continue;
-                }
-
-                // prompt: [BOS]query[EOS][SEP]doc[EOS]
-                prompt_tokens.clear();
-                prompt_tokens.push_back(llama_token_bos(model));
-                {
-                    const auto part = tokenize(slot.prompt[0], false);
-                    prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
-                }
-                prompt_tokens.push_back(llama_token_eos(model));
-                prompt_tokens.push_back(llama_token_sep(model));
-                {
-                    const auto part = tokenize(slot.prompt[1], false);
-                    prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
-                }
-                prompt_tokens.push_back(llama_token_eos(model));
-            } else {
-                prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
+            switch (slot.cmpl_type) {
+                case SERVER_TASK_CMPL_TYPE_NORMAL:
+                case SERVER_TASK_CMPL_TYPE_EMBEDDING:
+                    {
+                        prompt_tokens = tokenize(slot.prompt, system_prompt.empty(), true); // add BOS if there isn't system prompt
+                    } break;
+                case SERVER_TASK_CMPL_TYPE_RERANK:
+                    {
+                        // require slot.prompt to be array of 2 strings
+                        if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
+                            SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
+                            slot.release();
+                            send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
+                            continue;
+                        }
+
+                        // prompt: [BOS]query[EOS][SEP]doc[EOS]
+                        prompt_tokens.clear();
+                        prompt_tokens.push_back(llama_token_bos(model));
+                        {
+                            const auto part = tokenize(slot.prompt[0], false, false);
+                            prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                        }
+                        prompt_tokens.push_back(llama_token_eos(model));
+                        prompt_tokens.push_back(llama_token_sep(model));
+                        {
+                            const auto part = tokenize(slot.prompt[1], false, false);
+                            prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                        }
+                        prompt_tokens.push_back(llama_token_eos(model));
+                    } break;
+                case SERVER_TASK_CMPL_TYPE_INFILL:
+                    {
+                        auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
+                        auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
+
+                        prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
+                        suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
+
+                        auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
+                        auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
+
+                        if (llama_add_bos_token(model)) {
+                            embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
+                        }
+
+                        embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
+                        embd_inp.push_back(llama_token_fim_mid(model));
+
+                        prompt_tokens = std::move(embd_inp);
+                    } break;
             }
 
             slot.n_past = 0;
             slot.n_prompt_tokens = prompt_tokens.size();
 
             SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
 
+            // print prompt tokens:
+            for (int i = 0; i < (int) prompt_tokens.size(); i++) {
+                SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+            }
+
             // empty prompt passed -> release the slot and send empty response
             if (prompt_tokens.empty()) {
                 SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
@@ -2924,7 +2919,23 @@ int main(int argc, char ** argv) {
         return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
     };
 
-    const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+        std::string err;
+        if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
+            err += "prefix token is missing. ";
+        }
+        if (llama_token_fim_suf(ctx_server.model) == LLAMA_TOKEN_NULL) {
+            err += "suffix token is missing. ";
+        }
+        if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
+            err += "middle token is missing. ";
+        }
+
+        if (!err.empty()) {
+            res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         json data = json::parse(req.body);
         return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
     };
@@ -3010,7 +3021,8 @@
         if (body.count("content") != 0) {
            const bool add_special = json_value(body, "add_special", false);
            const bool with_pieces = json_value(body, "with_pieces", false);
-           std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special);
+
+           std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special, true);
 
            if (with_pieces) {
                for (const auto& token : tokens) {
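The new handle_infill guard boils down to a capability check before accepting the request: if the model's vocabulary does not expose all three FIM tokens, the server refuses with a not-supported error. Below is a standalone sketch of the same pattern, with stub lookups and a -1 sentinel standing in for llama_token_fim_pre/suf/mid and LLAMA_TOKEN_NULL (the stub return values are made up for illustration).

```cpp
#include <cstdio>
#include <string>

// stand-ins for the real vocab queries; -1 plays the role of LLAMA_TOKEN_NULL here
static int fim_pre_token() { return 32007; }
static int fim_suf_token() { return 32008; }
static int fim_mid_token() { return -1; }   // pretend the middle token is missing

int main() {
    const int TOKEN_NULL = -1;

    std::string err;
    if (fim_pre_token() == TOKEN_NULL) { err += "prefix token is missing. "; }
    if (fim_suf_token() == TOKEN_NULL) { err += "suffix token is missing. "; }
    if (fim_mid_token() == TOKEN_NULL) { err += "middle token is missing. "; }

    if (!err.empty()) {
        // the server would return ERROR_TYPE_NOT_SUPPORTED here instead of printing
        printf("Infill is not supported by this model: %s\n", err.c_str());
        return 1;
    }

    printf("model exposes all FIM tokens, /infill can be served\n");
    return 0;
}
```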
