@@ -105,6 +105,7 @@ struct slot_params {
 
     std::vector<std::string> antiprompt;
     std::vector<std::string> start_strings;
+    size_t start_string_max_len;
     std::vector<std::string> response_fields;
     bool timings_per_token   = false;
     bool post_sampling_probs = false;
@@ -247,8 +248,7 @@ struct server_task {
         // params.t_max_prompt_ms  = json_value(data, "t_max_prompt_ms",  defaults.t_max_prompt_ms); // TODO: implement
         params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
         params.response_fields  = json_value(data, "response_fields",  std::vector<std::string>());
-        params.start_strings    = json_value(data, "start_strings",    defaults.start_strings);
-
+
         params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
         params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
         params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
@@ -282,6 +282,14 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 0);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);
 
+        // start strings
+        params.start_strings = json_value(data, "start_strings", defaults.start_strings);
+        params.start_string_max_len = 0;
+        for (auto start_string : params.start_strings) {
+            params.start_string_max_len = std::max(params.start_string_max_len, start_string.size());
+        }
+
+
         // Use OpenAI API logprobs only if n_probs wasn't provided
         if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs) {
             params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
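For reference, a minimal sketch of a request that would exercise this parsing path (the prompt and values are illustrative, not taken from the PR; `json` is the nlohmann alias server.cpp already uses):

```cpp
// Hypothetical /completion request body; only "start_strings" is new.
json data = {
    {"prompt",        "Q: 2+2? Reason step by step, then answer."},
    {"n_predict",     64},
    {"start_strings", {"<answer>", "Answer:"}},
};
// After the parsing above:
//   params.start_strings        == {"<answer>", "Answer:"}
//   params.start_string_max_len == 8   // strlen("<answer>")
```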
@@ -1295,6 +1303,8 @@ struct server_slot {
 
     std::string stopping_word;
 
+    bool start_string_found = false;
+
     // sampling
     json json_schema;
 
@@ -1332,6 +1342,7 @@ struct server_slot {
        n_past      = 0;
        n_sent_text = 0;
        task_type   = SERVER_TASK_TYPE_COMPLETION;
+       start_string_found = false;
 
        generated_tokens.clear();
        generated_token_probs.clear();
@@ -2197,11 +2208,8 @@ struct server_context {
            const std::string str_test = slot.generated_text.substr(pos);
            bool send_text = true;
 
-           if (slot.n_sent_text == 0 && slot.has_next_token && !slot.params.start_strings.empty()) {
-               size_t max_start_string_size = 0;
-               for (auto start_string : slot.params.start_strings) {
-                   max_start_string_size = std::max(max_start_string_size, start_string.size());
-               }
+           if (!slot.start_string_found && slot.has_next_token && !slot.params.start_strings.empty()) {
+               size_t max_start_string_size = slot.params.start_string_max_len;
 
                size_t search_len = max_start_string_size + token_str.size();
                size_t search_pos = 0;
                if (slot.generated_text.size() > search_len) {
@@ -2224,6 +2232,7 @@ struct server_context {
                    slot.generated_text.erase(
                        slot.generated_text.begin(),
                        slot.generated_text.begin() + found_pos + found_string.size());
+                   slot.start_string_found = true;
                } else {
                    send_text = false;
                }
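Taken together, the two hunks above gate streaming on the start strings: output is suppressed (`send_text = false`) until one of `slot.params.start_strings` appears in the generated text, and on the first match everything up to and including the match is erased from `slot.generated_text`, with `start_string_found` ensuring the search runs only once per generation. A self-contained sketch of that logic under assumed names (the free function and its parameters are mine, not the PR's; the real code operates on slot state):

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Returns true once a start string has been found. On the first match it
// erases everything up to and including the match from `generated`; until
// then the caller should suppress streaming (send_text = false).
static bool start_string_gate(std::string & generated,
                              const std::vector<std::string> & start_strings,
                              size_t start_string_max_len,
                              size_t last_token_len) {
    // Only the tail of the buffer can contain a new match: a match that
    // started earlier would have been found on a previous token.
    const size_t search_len = start_string_max_len + last_token_len;
    const size_t search_pos = generated.size() > search_len
                            ? generated.size() - search_len
                            : 0;

    for (const auto & ss : start_strings) {
        const size_t found_pos = generated.find(ss, search_pos);
        if (found_pos != std::string::npos) {
            generated.erase(generated.begin(), generated.begin() + found_pos + ss.size());
            return true;
        }
    }
    return false;
}
```

With `start_strings = {"<answer>"}` and generated text `"Let me think... <answer>4"`, the first text a client receives would be `"4"`.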