
Commit cf0a43d

precompute start string len, and keep start string state in slot
1 parent 1fb4fa2 commit cf0a43d

File tree: 1 file changed, +16 −7 lines changed


examples/server/server.cpp

Lines changed: 16 additions & 7 deletions
@@ -105,6 +105,7 @@ struct slot_params {
 
     std::vector<std::string> antiprompt;
     std::vector<std::string> start_strings;
+    size_t start_string_max_len;
     std::vector<std::string> response_fields;
     bool timings_per_token = false;
     bool post_sampling_probs = false;
@@ -247,8 +248,7 @@ struct server_task {
         //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
         params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
         params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
-        params.start_strings = json_value(data, "start_strings", defaults.start_strings);
-
+
         params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
         params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
         params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
@@ -282,6 +282,14 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 0);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);
 
+        // start strings
+        params.start_strings = json_value(data, "start_strings", defaults.start_strings);
+        params.start_string_max_len = 0;
+        for (auto start_string: params.start_strings) {
+            params.start_string_max_len = std::max(params.start_string_max_len, start_string.size());
+        }
+
+
         // Use OpenAI API logprobs only if n_probs wasn't provided
         if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
             params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
@@ -1295,6 +1303,8 @@ struct server_slot {
 
     std::string stopping_word;
 
+    bool start_string_found = false;
+
     // sampling
     json json_schema;
 
@@ -1332,6 +1342,7 @@
         n_past = 0;
         n_sent_text = 0;
         task_type = SERVER_TASK_TYPE_COMPLETION;
+        start_string_found = false;
 
         generated_tokens.clear();
         generated_token_probs.clear();
@@ -2197,11 +2208,8 @@ struct server_context {
             const std::string str_test = slot.generated_text.substr(pos);
             bool send_text = true;
 
-            if (slot.n_sent_text == 0 && slot.has_next_token && !slot.params.start_strings.empty()) {
-                size_t max_start_string_size = 0;
-                for (auto start_string: slot.params.start_strings) {
-                    max_start_string_size = std::max(max_start_string_size, start_string.size());
-                }
+            if (!slot.start_string_found && slot.has_next_token && !slot.params.start_strings.empty()) {
+                size_t max_start_string_size = slot.params.start_string_max_len;
                 size_t search_len = max_start_string_size + token_str.size();
                 size_t search_pos = 0;
                 if (slot.generated_text.size() > search_len) {
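Why this window size is safe: a start string that was not already present before the latest token must overlap the newly appended bytes, so it can begin at most start_string_max_len − 1 bytes before them. Scanning only the last start_string_max_len + token_str.size() bytes of generated_text therefore catches every new match without rescanning the whole buffer on each token. A minimal sketch of that bound (the helper name is illustrative, not part of this patch):

    // Earliest offset a tail scan must start from: any match that is new
    // with this token ends inside the token's bytes, so it begins no more
    // than max_len + token_len bytes before the end of the text.
    static size_t start_string_search_pos(size_t text_len, size_t max_len, size_t token_len) {
        const size_t window = max_len + token_len;
        return text_len > window ? text_len - window : 0;
    }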
@@ -2224,6 +2232,7 @@
                     slot.generated_text.erase(
                         slot.generated_text.begin(),
                         slot.generated_text.begin() + found_pos + found_string.size());
+                    slot.start_string_found = true;
                 } else {
                     send_text = false;
                 }
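Taken together, the change gates streamed output on start strings: nothing is sent until one of params.start_strings shows up in the generated text, the match and everything before it are trimmed away, and the slot-level start_string_found flag (reset per task) keeps later tokens from re-running the search. A self-contained sketch of the same mechanism, assuming plain std::string buffers; the struct and method names are illustrative, not the server's actual API:

    #include <algorithm>
    #include <cstddef>
    #include <string>
    #include <vector>

    struct start_string_state {
        std::vector<std::string> start_strings;
        size_t max_len = 0; // precomputed once per task, as in the parsing hunk above
        bool   found   = false;

        explicit start_string_state(std::vector<std::string> strings)
            : start_strings(std::move(strings)) {
            for (const auto & s : start_strings) {
                max_len = std::max(max_len, s.size());
            }
        }

        // text: full generation so far; n_new: bytes appended by the latest token.
        // Returns true once streaming may begin; trims through the matched string.
        bool check_and_trim(std::string & text, size_t n_new) {
            if (found || start_strings.empty()) {
                return true;
            }
            // only the tail can contain a match that is new with this token
            const size_t window     = max_len + n_new;
            const size_t search_pos = text.size() > window ? text.size() - window : 0;
            for (const auto & s : start_strings) {
                const size_t pos = text.find(s, search_pos);
                if (pos != std::string::npos) {
                    text.erase(0, pos + s.size()); // drop prefix and the start string itself
                    found = true;
                    return true;
                }
            }
            return false; // keep buffering; send nothing yet
        }
    };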

0 commit comments
