Skip to content

Commit 010662e

Browse files
committed
token healing: refactor argument parsing
Unify `main` and `server` token healing argument handling.
1 parent 8aec522 commit 010662e

File tree

5 files changed

+61
-60
lines changed

5 files changed

+61
-60
lines changed

common/common.cpp

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,21 +1060,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
10601060
}
10611061
if (arg == "-th" || arg == "--token-healing") {
10621062
CHECK_ARG
1063-
sparams.token_healing_enabled = true;
1064-
auto & th_type = sparams.token_healing_type;
1065-
auto & th_n_rollback = sparams.token_healing_n_rollback;
10661063
std::string value(argv[i]);
1067-
/**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
1068-
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
1069-
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
1070-
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
1071-
else if (value[0] == 'r' ) {
1072-
th_type = llama_token_healing_type::ROLLBACK_MULTI;
1073-
th_n_rollback = std::stoi(value.substr(1));
1074-
if (th_n_rollback <= 0) {
1075-
sparams.token_healing_enabled = false;
1076-
}
1077-
} else { invalid_param = true; }
1064+
invalid_param = !llama_token_healing_parse_params(value, sparams.token_healing);
10781065
return true;
10791066
}
10801067
if (arg == "--override-kv") {

common/sampling.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,25 @@ void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const
154154
ctx_sampling->token_healing_prefix = prefix;
155155
}
156156

157+
bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params) {
158+
th_params.enabled = true;
159+
th_params.n_rollback = -1;
160+
/**/ if (params == "0" ) { th_params.enabled = false; }
161+
else if (params == "1" ) { th_params.type = llama_token_healing_type::ROLLBACK_LAST; }
162+
else if (params == "d1") { th_params.type = llama_token_healing_type::DYNAMIC_ONCE; }
163+
else if (params == "d" ) { th_params.type = llama_token_healing_type::DYNAMIC_MULTI; }
164+
else if (params[0] == 'r' ) {
165+
th_params.type = llama_token_healing_type::ROLLBACK_MULTI;
166+
th_params.n_rollback = std::stoi(params.substr(1));
167+
if (th_params.n_rollback <= 0) {
168+
return false;
169+
}
170+
} else {
171+
return false;
172+
}
173+
return true;
174+
}
175+
157176
//
158177
// Sampling
159178
//
@@ -551,11 +570,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
551570
cur.clear();
552571

553572
// Constrain tokens based on the remaining token healing prefix (if any)
554-
const auto & th_type = params.token_healing_type;
555573
const auto & th_prefix = ctx_sampling->token_healing_prefix;
556-
if (params.token_healing_enabled && !th_prefix.empty()) {
557-
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
558-
th_type == llama_token_healing_type::DYNAMIC_MULTI;
574+
if (params.token_healing.enabled && !th_prefix.empty()) {
575+
const bool is_multi_step = params.token_healing.type == llama_token_healing_type::ROLLBACK_MULTI ||
576+
params.token_healing.type == llama_token_healing_type::DYNAMIC_MULTI;
559577
std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);
560578

561579
LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
@@ -634,7 +652,7 @@ void llama_sampling_accept(
634652
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
635653
}
636654

637-
if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
655+
if (ctx_sampling->params.token_healing.enabled && apply_grammar) {
638656
std::string & th_prefix = ctx_sampling->token_healing_prefix;
639657
if (!th_prefix.empty()) {
640658
const std::string new_token_piece = llama_token_to_piece(ctx_main, id);

common/sampling.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ enum class llama_token_healing_type : uint8_t {
2626
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
2727
};
2828

29+
// Token healing configuration, shared by the CLI `-th/--token-healing`
// argument and the server's `token_healing` request field (both parsed
// via llama_token_healing_parse_params).
struct llama_token_healing_params {
    bool enabled = false;                // off unless explicitly requested
    // Healing strategy used when enabled.
    // NOTE(review): the pre-refactor field defaulted to ROLLBACK_LAST;
    // confirm the change of default to DYNAMIC_MULTI is intentional.
    llama_token_healing_type type = llama_token_healing_type::DYNAMIC_MULTI;
    int n_rollback = -1; // number of tokens to roll back; -1 = no explicit limit
};
34+
2935
// sampling parameters
3036
typedef struct llama_sampling_params {
3137
int32_t n_prev = 64; // number of previous tokens to remember
@@ -70,9 +76,7 @@ typedef struct llama_sampling_params {
7076
std::vector<llama_token> penalty_prompt_tokens;
7177
bool use_penalty_prompt_tokens = false;
7278

73-
llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
74-
bool token_healing_enabled = false;
75-
int token_healing_n_rollback = -1; // number of tokens to roll back
79+
llama_token_healing_params token_healing;
7680
} llama_sampling_params;
7781

7882
// general sampler context
@@ -190,3 +194,6 @@ llama_token_healing_output llama_token_healing_rollback(
190194
int max_to_remove = -1);
191195

192196
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
197+
198+
// Helper for parsing token healing params from a string.
199+
bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params);

examples/main/main.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -276,14 +276,14 @@ int main(int argc, char ** argv) {
276276
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
277277
}
278278

279-
if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) {
280-
sparams.token_healing_enabled = false;
279+
if (sparams.token_healing.enabled && (params.conversation || !params.input_suffix.empty())) {
280+
sparams.token_healing.enabled = false;
281281
LOG("token healing: disabled due to custom suffix/conversation mode");
282282
}
283283
llama_token_healing_output token_healing_out{};
284-
if (!params.interactive_first && sparams.token_healing_enabled) {
285-
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
286-
sparams.token_healing_n_rollback);
284+
if (!params.interactive_first && sparams.token_healing.enabled) {
285+
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp,
286+
sparams.token_healing.n_rollback);
287287
}
288288

289289
// Should not run without any tokens
@@ -911,13 +911,13 @@ int main(int argc, char ** argv) {
911911
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
912912
embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
913913

914-
if (sparams.token_healing_enabled) {
914+
if (sparams.token_healing.enabled) {
915915
// Limit token healing rollback to new tokens only (otherwise would need to shift everything)
916916
const int n_new_tokens = embd_inp.size() - original_size;
917-
const int max_to_remove = sparams.token_healing_n_rollback < 0
917+
const int max_to_remove = sparams.token_healing.n_rollback < 0
918918
? n_new_tokens
919-
: std::min(sparams.token_healing_n_rollback, n_new_tokens);
920-
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, max_to_remove);
919+
: std::min(sparams.token_healing.n_rollback, n_new_tokens);
920+
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp, max_to_remove);
921921
n_bytes_to_skip = token_healing_out.prefix.size();
922922
}
923923

examples/server/server.cpp

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,31 +1087,20 @@ struct server_context {
10871087

10881088
{
10891089
const auto & token_healing_str = data.find("token_healing");
1090-
auto & th_enabled = slot.sparams.token_healing_enabled;
1091-
th_enabled = default_sparams.token_healing_enabled;
10921090
if (token_healing_str != data.end() && token_healing_str->is_string()) {
10931091
const auto value = token_healing_str->get<std::string>();
1094-
auto & th_type = slot.sparams.token_healing_type;
1095-
auto & th_n_rollback = slot.sparams.token_healing_n_rollback;
1096-
th_enabled = true;
1097-
/**/ if (value == "0" ) { th_enabled = false; }
1098-
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
1099-
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
1100-
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
1101-
else if (value[0] == 'r' ) {
1102-
th_type = llama_token_healing_type::ROLLBACK_MULTI;
1103-
th_n_rollback = std::stoi(value.substr(1));
1104-
if (th_n_rollback <= 0) {
1105-
th_enabled = false;
1106-
}
1107-
} else { th_enabled = false; }
1108-
1092+
if (!llama_token_healing_parse_params(value, slot.sparams.token_healing)) {
1093+
send_error(task, "\"token_healing\" parse error", ERROR_TYPE_INVALID_REQUEST);
1094+
return false;
1095+
}
11091096
LOG_VERBOSE("token healing", {
11101097
{"id_slot", slot.id},
1111-
{"enabled", th_enabled},
1112-
{"type", th_type},
1113-
{"n_rollback", th_n_rollback}
1098+
{"enabled", slot.sparams.token_healing.enabled},
1099+
{"type", slot.sparams.token_healing.type},
1100+
{"n_rollback", slot.sparams.token_healing.n_rollback}
11141101
});
1102+
} else {
1103+
slot.sparams.token_healing = default_sparams.token_healing;
11151104
}
11161105
}
11171106

@@ -1395,7 +1384,7 @@ struct server_context {
13951384
{"min_keep", slot.sparams.min_keep},
13961385
{"grammar", slot.sparams.grammar},
13971386
{"samplers", samplers_sequence},
1398-
{"token_healing_enabled", slot.sparams.token_healing_enabled}
1387+
{"token_healing_enabled", slot.sparams.token_healing.enabled}
13991388
};
14001389
}
14011390

@@ -2085,10 +2074,10 @@ struct server_context {
20852074
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
20862075
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
20872076

2088-
if (slot.sparams.token_healing_enabled) {
2077+
if (slot.sparams.token_healing.enabled) {
20892078
// For FIM roll back only the prefix part (i.e. cursor location)
2090-
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
2091-
prefix_tokens, slot.sparams.token_healing_n_rollback);
2079+
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type,
2080+
prefix_tokens, slot.sparams.token_healing.n_rollback);
20922081
}
20932082

20942083
auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
@@ -2107,9 +2096,9 @@ struct server_context {
21072096
} else {
21082097
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
21092098

2110-
if (slot.sparams.token_healing_enabled) {
2111-
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
2112-
prompt_tokens, slot.sparams.token_healing_n_rollback);
2099+
if (slot.sparams.token_healing.enabled) {
2100+
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type,
2101+
prompt_tokens, slot.sparams.token_healing.n_rollback);
21132102
}
21142103
}
21152104

@@ -2125,7 +2114,7 @@ struct server_context {
21252114
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
21262115
});
21272116

2128-
if (slot.sparams.token_healing_enabled) {
2117+
if (slot.sparams.token_healing.enabled) {
21292118
slot.n_th_prefix = token_healing_out.prefix.size();
21302119
LOG_VERBOSE("token healing prompt", {
21312120
{"id_slot", slot.id},
@@ -2200,7 +2189,7 @@ struct server_context {
22002189
}
22012190

22022191
llama_sampling_reset(slot.ctx_sampling);
2203-
if (slot.sparams.token_healing_enabled) {
2192+
if (slot.sparams.token_healing.enabled) {
22042193
llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
22052194
}
22062195

0 commit comments

Comments
 (0)