From e57d8c532b76df7c79c6ea7bd76aef6d169262bc Mon Sep 17 00:00:00 2001 From: matteo serva Date: Wed, 30 Apr 2025 11:56:17 +0200 Subject: [PATCH 01/16] support --start-string --- common/arg.cpp | 8 ++++++++ common/common.h | 1 + tools/server/server.cpp | 43 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/common/arg.cpp b/common/arg.cpp index 5080aa2fcbffd..c417a1185970c 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2848,6 +2848,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex else { std::invalid_argument("invalid value"); } } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK")); + add_opt(common_arg( + {"--start-string"}, "STRING", + "Start outputting tokens only when the start string has been reached", + [](common_params & params, const std::string & value) { + params.start_strings.resize(1); + params.start_strings[0] = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_START_STRING")); add_opt(common_arg( {"--chat-template"}, "JINJA_TEMPLATE", string_format( diff --git a/common/common.h b/common/common.h index cfe1b72786795..adb72c4d0e5b4 100644 --- a/common/common.h +++ b/common/common.h @@ -366,6 +366,7 @@ struct common_params { bool use_jinja = false; // NOLINT bool enable_chat_template = true; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + std::vector start_strings; std::vector api_keys; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index c580ec123299c..b74832f161499 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -104,6 +104,7 @@ struct slot_params { std::vector lora; std::vector antiprompt; + std::vector start_strings; std::vector response_fields; bool timings_per_token = false; bool post_sampling_probs = false; @@ -161,6 +162,7 @@ struct slot_params { {"mirostat", sampling.mirostat}, {"mirostat_tau", sampling.mirostat_tau}, {"mirostat_eta", sampling.mirostat_eta}, + {"start", start_strings}, {"stop", antiprompt}, {"max_tokens", n_predict}, // User configured n_predict {"n_keep", n_keep}, @@ -229,6 +231,7 @@ struct server_task { slot_params defaults; defaults.sampling = params_base.sampling; defaults.speculative = params_base.speculative; + defaults.start_strings = params_base.start_strings; // enabling this will output extra debug information in the HTTP responses from the server params.verbose = params_base.verbosity > 9; @@ -244,6 +247,7 @@ struct server_task { //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); params.response_fields = json_value(data, "response_fields", std::vector()); + params.start_strings = json_value(data, "start_strings", defaults.start_strings); params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); @@ -1998,6 +2002,7 @@ struct server_context { SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); slot.params.sampling = params_base.sampling; + slot.params.start_strings = params_base.start_strings; slot.callback_on_release = [this](int) { queue_tasks.pop_deferred_task(); @@ -2192,6 +2197,42 @@ struct server_context { const std::string str_test = slot.generated_text.substr(pos); bool send_text = true; + if(slot.n_sent_text == 0 && slot.has_next_token && !slot.params.start_strings.empty()) + { 
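+ // gate the output on the start strings: nothing is sent to the client until
+ // one of them appears in the generated text; only the tail of generated_text
+ // that the newest token could have completed needs to be searched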
+ size_t max_start_string_size = 0; + for(auto start_string: slot.params.start_strings) + { + max_start_string_size = std::max(max_start_string_size, start_string.size()); + } + size_t search_len = max_start_string_size + token_str.size(); + size_t search_pos = 0; + if(slot.generated_text.size() > search_len) + { + search_pos = slot.generated_text.size() - search_len; + } + + auto found_pos = slot.generated_text.npos; + bool found = false; + std::string found_string; + for(auto start_string: slot.params.start_strings) + { + found_pos = slot.generated_text.find(start_string,search_pos); + if(found_pos != slot.generated_text.npos) { + found = true; + found_string = start_string; + break; + } + } + + if(found && slot.generated_text.size() > (found_pos + found_string.size()) ) { + slot.generated_text.erase( + slot.generated_text.begin(), + slot.generated_text.begin() + found_pos + found_string.size()); + } else { + send_text = false; + } + } + size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); if (stop_pos != std::string::npos) { slot.generated_text.erase( @@ -2200,7 +2241,7 @@ struct server_context { pos = std::min(slot.n_sent_text, slot.generated_text.size()); } else if (slot.has_next_token) { stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); - send_text = stop_pos == std::string::npos; + send_text = send_text && stop_pos == std::string::npos; } // check if there is any token to predict From 5c0c03658925f00253ac11c10626890ddae2cb28 Mon Sep 17 00:00:00 2001 From: matteo serva Date: Wed, 30 Apr 2025 17:41:45 +0200 Subject: [PATCH 02/16] can set start-string multiple times, doc --- common/arg.cpp | 5 ++--- tools/server/README.md | 1 + 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index c417a1185970c..ac9d9f20262f1 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2850,10 +2850,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK")); add_opt(common_arg( {"--start-string"}, "STRING", - "Start outputting tokens only when the start string has been reached", + "Start outputting tokens only when at least one start string has been reached. Can be set multiple times.", [](common_params & params, const std::string & value) { - params.start_strings.resize(1); - params.start_strings[0] = value; + params.start_strings.push_back(value); } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_START_STRING")); add_opt(common_arg( diff --git a/tools/server/README.md b/tools/server/README.md index 0ec786ea76f7a..e344b248ccd99 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -160,6 +160,7 @@ The project is under active development, and we are [looking for feedback and co | `--props` | enable changing global properties via POST /props (default: disabled)
(env: LLAMA_ARG_ENDPOINT_PROPS) |
| `--no-slots` | disables slots monitoring endpoint<br/>(env: LLAMA_ARG_NO_ENDPOINT_SLOTS) |
| `--slot-save-path PATH` | path to save slot kv cache (default: disabled) |
+| `--start-string STRING` | the response is not sent to the client until at least one start string has been reached; can be set multiple times<br/>(env: LLAMA_ARG_START_STRING) |
| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>list of built-in templates:<br/>chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, exaone3, gemma, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, monarch, openchat, orion, phi3, rwkv-world, vicuna, vicuna-orca, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)
| | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) | From 513c419a31cdd26968e0214026100abcb9011e5a Mon Sep 17 00:00:00 2001 From: matteo serva Date: Wed, 30 Apr 2025 18:08:05 +0200 Subject: [PATCH 03/16] added doc for client parameter --- tools/server/README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tools/server/README.md b/tools/server/README.md index e344b248ccd99..68fddae254b6d 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -391,6 +391,9 @@ By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to re `stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`. +`start_strings`: Specify a JSON array of starting strings. +The output of the model is discarded until the first start string is reached, the matching string is not included in the completion. Default: `[]` + `stop`: Specify a JSON array of stopping strings. These words will not be included in the completion, so make sure to add them to the prompt for the next iteration. Default: `[]` From a4b5247e7e3f2ddd01b354de6cb3bf267ba3849e Mon Sep 17 00:00:00 2001 From: matteo serva Date: Wed, 30 Apr 2025 18:37:32 +0200 Subject: [PATCH 04/16] remove whitespaces --- tools/server/server.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index b74832f161499..afcc889fbc4b6 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2199,7 +2199,7 @@ struct server_context { if(slot.n_sent_text == 0 && slot.has_next_token && !slot.params.start_strings.empty()) { - size_t max_start_string_size = 0; + size_t max_start_string_size = 0; for(auto start_string: slot.params.start_strings) { max_start_string_size = std::max(max_start_string_size, start_string.size()); @@ -2218,12 +2218,12 @@ struct server_context { { found_pos = slot.generated_text.find(start_string,search_pos); if(found_pos != slot.generated_text.npos) { - found = true; + found = true; found_string = start_string; break; } } - + if(found && slot.generated_text.size() > (found_pos + found_string.size()) ) { slot.generated_text.erase( slot.generated_text.begin(), From e4f48641d3bbc9977308bdb0629090140250c8cc Mon Sep 17 00:00:00 2001 From: matteo serva Date: Wed, 30 Apr 2025 18:58:33 +0200 Subject: [PATCH 05/16] use correct coding style --- tools/server/server.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index afcc889fbc4b6..914af94af1797 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2197,25 +2197,21 @@ struct server_context { const std::string str_test = slot.generated_text.substr(pos); bool send_text = true; - if(slot.n_sent_text == 0 && slot.has_next_token && !slot.params.start_strings.empty()) - { + if(slot.n_sent_text == 0 && slot.has_next_token && !slot.params.start_strings.empty()) { size_t max_start_string_size = 0; - for(auto start_string: slot.params.start_strings) - { + for(auto start_string: slot.params.start_strings) { max_start_string_size = std::max(max_start_string_size, start_string.size()); } size_t search_len = max_start_string_size + token_str.size(); size_t search_pos = 0; - if(slot.generated_text.size() > search_len) - { + if(slot.generated_text.size() > search_len) { search_pos = slot.generated_text.size() - search_len; } 
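 // try each start string in turn, stopping at the first one found in the search window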
auto found_pos = slot.generated_text.npos; bool found = false; std::string found_string; - for(auto start_string: slot.params.start_strings) - { + for(auto start_string: slot.params.start_strings) { found_pos = slot.generated_text.find(start_string,search_pos); if(found_pos != slot.generated_text.npos) { found = true; From a7349d177e9692f60b017294af328f1009778e8b Mon Sep 17 00:00:00 2001 From: matteo serva Date: Wed, 30 Apr 2025 21:04:26 +0200 Subject: [PATCH 06/16] Added tests for start string feature --- .../server/tests/unit/test_chat_completion.py | 82 +++++++++++++++++++ tools/server/tests/utils.py | 3 + 2 files changed, 85 insertions(+) diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 491cb3a5df636..9feda210dc59b 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -309,3 +309,85 @@ def test_logprobs_stream(): assert token.top_logprobs is not None assert len(token.top_logprobs) > 0 assert aggregated_text == output_text + + +def test_startstring_serverconfig(): + global server + server.jinja = False + server.start_string=" 9 " + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 32, + "messages": [ + {"role": "user", "content": "List the numbers from 1 to 100"}, + ], + "grammar": "root ::= \"1 2 3 4 5 6 7 8 9 10 11 12\"", + }) + assert res.status_code == 200, res.body + choice = res.body["choices"][0] + content = choice["message"]["content"] + print(content) + assert content.startswith("10 ") + +def test_startstring_clientconfig(): + global server + server.jinja = False + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 32, + "messages": [ + {"role": "user", "content": "List the numbers from 1 to 100"}, + ], + "grammar": "root ::= \"1 2 3 4 5 6 7 8 9 10 11 12\"", + "start_strings": ["10"] + }) + assert res.status_code == 200, res.body + choice = res.body["choices"][0] + content = choice["message"]["content"] + assert content.startswith(" 11") + + +def test_startstring_clientconfig_stream(): + global server + server.jinja = False + server.start() + max_tokens=64 + system_prompt="" + user_prompt="" + res = server.make_stream_request("POST", "/chat/completions", data={ + "max_tokens": max_tokens, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + "grammar": "root ::= \"1 2 3 4 5 6 7 8 9 10 11 12\" .+", + "start_strings": ["10"], + "stream": True, + }) + + content = "" + last_cmpl_id = None + for data in res: + choice = data["choices"][0] + if choice["finish_reason"] not in ["stop", "length"]: + delta = choice["delta"]["content"] + content += delta + assert content.startswith(" 11") + + +def test_startstring_clientconfig_multiple(): + global server + server.jinja = False + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 32, + "messages": [ + {"role": "user", "content": "List the numbers from 1 to 100"}, + ], + "grammar": "root ::= \"1 2 3 4 5 6 7 8 9 10 11 12\"", + "start_strings": ["10","9"] + }) + assert res.status_code == 200, res.body + choice = res.body["choices"][0] + content = choice["message"]["content"] + assert content.startswith(" 10") diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index 4dc2062a8e5b9..d2e93bf859b46 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -88,6 +88,7 @@ class 
ServerProcess: chat_template: str | None = None chat_template_file: str | None = None server_path: str | None = None + start_string: str | None = None # session variables process: subprocess.Popen | None = None @@ -194,6 +195,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: server_args.extend(["--chat-template", self.chat_template]) if self.chat_template_file: server_args.extend(["--chat-template-file", self.chat_template_file]) + if self.start_string: + server_args.extend(["--start_string", self.start_string]) args = [str(arg) for arg in [server_path, *server_args]] print(f"tests: starting server with: {' '.join(args)}") From b843667809388b2bc12a35faff6a02fe9d1cbf65 Mon Sep 17 00:00:00 2001 From: matteo serva Date: Wed, 30 Apr 2025 21:05:50 +0200 Subject: [PATCH 07/16] fixed formatting --- tools/server/server.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 914af94af1797..e4a0e6320bb20 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2197,30 +2197,30 @@ struct server_context { const std::string str_test = slot.generated_text.substr(pos); bool send_text = true; - if(slot.n_sent_text == 0 && slot.has_next_token && !slot.params.start_strings.empty()) { + if (slot.n_sent_text == 0 && slot.has_next_token && !slot.params.start_strings.empty()) { size_t max_start_string_size = 0; - for(auto start_string: slot.params.start_strings) { + for (auto start_string: slot.params.start_strings) { max_start_string_size = std::max(max_start_string_size, start_string.size()); } size_t search_len = max_start_string_size + token_str.size(); size_t search_pos = 0; - if(slot.generated_text.size() > search_len) { + if (slot.generated_text.size() > search_len) { search_pos = slot.generated_text.size() - search_len; } auto found_pos = slot.generated_text.npos; bool found = false; std::string found_string; - for(auto start_string: slot.params.start_strings) { + for (auto start_string: slot.params.start_strings) { found_pos = slot.generated_text.find(start_string,search_pos); - if(found_pos != slot.generated_text.npos) { + if (found_pos != slot.generated_text.npos) { found = true; found_string = start_string; break; } } - if(found && slot.generated_text.size() > (found_pos + found_string.size()) ) { + if (found && slot.generated_text.size() > (found_pos + found_string.size()) ) { slot.generated_text.erase( slot.generated_text.begin(), slot.generated_text.begin() + found_pos + found_string.size()); From 792604bcbcc8d6511e0ce9aa38d4eb3ed966cef5 Mon Sep 17 00:00:00 2001 From: matteo serva Date: Fri, 2 May 2025 08:19:13 +0200 Subject: [PATCH 08/16] precompute start string len, and keep start string state in slot --- tools/server/server.cpp | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index e4a0e6320bb20..34adb94296db6 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -105,6 +105,7 @@ struct slot_params { std::vector antiprompt; std::vector start_strings; + size_t start_string_max_len; std::vector response_fields; bool timings_per_token = false; bool post_sampling_probs = false; @@ -247,8 +248,7 @@ struct server_task { //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); params.response_fields = json_value(data, 
"response_fields", std::vector()); - params.start_strings = json_value(data, "start_strings", defaults.start_strings); - + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); @@ -282,6 +282,14 @@ struct server_task { params.speculative.n_min = std::max(params.speculative.n_min, 0); params.speculative.n_max = std::max(params.speculative.n_max, 0); + // start strings + params.start_strings = json_value(data, "start_strings", defaults.start_strings); + params.start_string_max_len = 0; + for (auto start_string: params.start_strings) { + params.start_string_max_len = std::max(params.start_string_max_len, start_string.size()); + } + + // Use OpenAI API logprobs only if n_probs wasn't provided if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); @@ -1295,6 +1303,8 @@ struct server_slot { std::string stopping_word; + bool start_string_found = false; + // sampling json json_schema; @@ -1332,6 +1342,7 @@ struct server_slot { n_past = 0; n_sent_text = 0; task_type = SERVER_TASK_TYPE_COMPLETION; + start_string_found = false; generated_tokens.clear(); generated_token_probs.clear(); @@ -2197,11 +2208,8 @@ struct server_context { const std::string str_test = slot.generated_text.substr(pos); bool send_text = true; - if (slot.n_sent_text == 0 && slot.has_next_token && !slot.params.start_strings.empty()) { - size_t max_start_string_size = 0; - for (auto start_string: slot.params.start_strings) { - max_start_string_size = std::max(max_start_string_size, start_string.size()); - } + if (!slot.start_string_found && slot.has_next_token && !slot.params.start_strings.empty()) { + size_t max_start_string_size = slot.params.start_string_max_len; size_t search_len = max_start_string_size + token_str.size(); size_t search_pos = 0; if (slot.generated_text.size() > search_len) { @@ -2224,6 +2232,7 @@ struct server_context { slot.generated_text.erase( slot.generated_text.begin(), slot.generated_text.begin() + found_pos + found_string.size()); + slot.start_string_found = true; } else { send_text = false; } From 124a92d9446caa81a2545c1aa03b52b37aace77b Mon Sep 17 00:00:00 2001 From: matteo serva Date: Fri, 2 May 2025 08:59:25 +0200 Subject: [PATCH 09/16] refactor the substring search function --- tools/server/server.cpp | 17 +++++------------ tools/server/utils.hpp | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 34adb94296db6..b2eb1a8d90946 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2216,19 +2216,12 @@ struct server_context { search_pos = slot.generated_text.size() - search_len; } - auto found_pos = slot.generated_text.npos; - bool found = false; - std::string found_string; - for (auto start_string: slot.params.start_strings) { - found_pos = slot.generated_text.find(start_string,search_pos); - if (found_pos != slot.generated_text.npos) { - found = true; - found_string = start_string; - break; - } - } + std::pair search_result = find_first_substring(slot.generated_text,slot.params.start_strings, search_pos); + bool found = search_result.first != std::string::npos; - if (found && slot.generated_text.size() > (found_pos + found_string.size()) ) { + if (found) { + auto found_pos = search_result.first; + std::string 
found_string = slot.params.start_strings[search_result.second]; slot.generated_text.erase( slot.generated_text.begin(), slot.generated_text.begin() + found_pos + found_string.size()); diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index b497959fd8689..2c21213ad644a 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -489,6 +489,25 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin return std::string::npos; } +/* returns a pair containing the position and index of first element found. + returns npos and -1 if not found */ +static std::pair find_first_substring(const std::string &haystack, const std::vector & needles, size_t search_pos = 0) +{ + size_t found_pos = std::string::npos; + int found_idx = -1; + + for (unsigned int i = 0; i < needles.size(); ++i) { + const std::string & start_string = needles[i]; + found_pos = haystack.find(start_string,search_pos); + if (found_pos != std::string::npos) { + found_idx = i; + break; + } + } + + return std::pair(found_pos,found_idx); +} + // TODO: reuse llama_detokenize template static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { From 5f99f8ae048bda03765946a7b4a9a845418c642a Mon Sep 17 00:00:00 2001 From: matteo serva Date: Fri, 2 May 2025 09:01:02 +0200 Subject: [PATCH 10/16] fix comments --- tools/server/server.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index b2eb1a8d90946..50c85fc0ba663 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2201,13 +2201,14 @@ struct server_context { // check if there is incomplete UTF-8 character at the end bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); - // search stop word and delete it + if (!incomplete) { size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); const std::string str_test = slot.generated_text.substr(pos); bool send_text = true; + // Handle the start strings if (!slot.start_string_found && slot.has_next_token && !slot.params.start_strings.empty()) { size_t max_start_string_size = slot.params.start_string_max_len; size_t search_len = max_start_string_size + token_str.size(); @@ -2231,6 +2232,7 @@ struct server_context { } } + // search stop word and delete it size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); if (stop_pos != std::string::npos) { slot.generated_text.erase( From 521868feacb4f273f978df5cdafb7a8fd0a281c9 Mon Sep 17 00:00:00 2001 From: matteo serva Date: Fri, 2 May 2025 09:02:57 +0200 Subject: [PATCH 11/16] cleaning --- tools/server/server.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 50c85fc0ba663..6f7678b4afc26 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -248,7 +248,7 @@ struct server_task { //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); params.response_fields = json_value(data, "response_fields", std::vector()); - + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); @@ -288,7 +288,7 @@ struct server_task { for (auto start_string: params.start_strings) { params.start_string_max_len = 
std::max(params.start_string_max_len, start_string.size()); } - + // Use OpenAI API logprobs only if n_probs wasn't provided if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ From 8a12c963fe801f70551591d5eda53b1e77014533 Mon Sep 17 00:00:00 2001 From: matteo serva Date: Fri, 2 May 2025 09:38:30 +0200 Subject: [PATCH 12/16] fix substring pos calculation --- tools/server/utils.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 2c21213ad644a..69b8c5b2eb1c7 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -498,10 +498,10 @@ static std::pair find_first_substring(const std::string &haystack, for (unsigned int i = 0; i < needles.size(); ++i) { const std::string & start_string = needles[i]; - found_pos = haystack.find(start_string,search_pos); - if (found_pos != std::string::npos) { + auto needle_pos = haystack.find(start_string,search_pos); + if (needle_pos != std::string::npos && (found_pos == std::string::npos || needle_pos < found_pos) ) { + found_pos = needle_pos; found_idx = i; - break; } } From 89d0c7ae5b56cc8263d07094293351ad029f918e Mon Sep 17 00:00:00 2001 From: matteo serva Date: Sat, 3 May 2025 17:19:06 +0200 Subject: [PATCH 13/16] initial refactoring of token processing --- tools/server/server.cpp | 54 ++++++++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 6f7678b4afc26..c99770910abc6 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2196,11 +2196,39 @@ struct server_context { if (slot.params.return_tokens) { slot.generated_tokens.push_back(result.tok); } - slot.has_next_token = true; + + + // SECTION: compute conditions on generated tokens so far + slot.has_next_token = true; // check if there is incomplete UTF-8 character at the end bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); + bool token_budget_exhausted = slot.n_decoded > 0 && !slot.has_budget(params_base); + bool start_string_missing = !slot.params.start_strings.empty() && !slot.start_string_found; + bool full_stop_reached = false; + bool partial_stop_reached = false; + + // search start strings + if (start_string_missing && !incomplete && slot.has_next_token) { + size_t max_start_string_size = slot.params.start_string_max_len; + size_t search_len = max_start_string_size + token_str.size(); + size_t search_pos = 0; + if (slot.generated_text.size() > search_len) { + search_pos = slot.generated_text.size() - search_len; + } + std::pair search_result = find_first_substring(slot.generated_text,slot.params.start_strings, search_pos); + bool start_string_found = search_result.first != std::string::npos; + if (start_string_found) { + auto found_pos = search_result.first; + std::string found_string = slot.params.start_strings[search_result.second]; + slot.generated_text.erase( + slot.generated_text.begin(), + slot.generated_text.begin() + found_pos + found_string.size()); + slot.start_string_found = true; + start_string_missing = false; + } + } if (!incomplete) { size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); @@ -2209,27 +2237,9 @@ struct server_context { bool send_text = true; // Handle the start strings - if (!slot.start_string_found && slot.has_next_token && !slot.params.start_strings.empty()) { - size_t max_start_string_size = slot.params.start_string_max_len; - size_t search_len = max_start_string_size + 
token_str.size(); - size_t search_pos = 0; - if (slot.generated_text.size() > search_len) { - search_pos = slot.generated_text.size() - search_len; - } - - std::pair search_result = find_first_substring(slot.generated_text,slot.params.start_strings, search_pos); - bool found = search_result.first != std::string::npos; - - if (found) { - auto found_pos = search_result.first; - std::string found_string = slot.params.start_strings[search_result.second]; - slot.generated_text.erase( - slot.generated_text.begin(), - slot.generated_text.begin() + found_pos + found_string.size()); - slot.start_string_found = true; - } else { - send_text = false; - } + if (start_string_missing) + { + send_text = false; } // search stop word and delete it From 0c65d40cb2144890eb636ab8edb80ee6622f31b4 Mon Sep 17 00:00:00 2001 From: matteo serva Date: Sat, 3 May 2025 17:23:45 +0200 Subject: [PATCH 14/16] refactoring the find_first_substring function --- tools/server/server.cpp | 4 ++-- tools/server/utils.hpp | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index c99770910abc6..8ab2186d8c61c 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2217,11 +2217,11 @@ struct server_context { search_pos = slot.generated_text.size() - search_len; } - std::pair search_result = find_first_substring(slot.generated_text,slot.params.start_strings, search_pos); + std::pair search_result = find_first_substring(slot.generated_text,slot.params.start_strings, search_pos); bool start_string_found = search_result.first != std::string::npos; if (start_string_found) { auto found_pos = search_result.first; - std::string found_string = slot.params.start_strings[search_result.second]; + std::string found_string = search_result.second; slot.generated_text.erase( slot.generated_text.begin(), slot.generated_text.begin() + found_pos + found_string.size()); diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 69b8c5b2eb1c7..208c492ecff88 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -491,21 +491,21 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin /* returns a pair containing the position and index of first element found. 
returns npos and -1 if not found */ -static std::pair find_first_substring(const std::string &haystack, const std::vector & needles, size_t search_pos = 0) +static std::pair find_first_substring(const std::string &haystack, const std::vector & needles, size_t search_pos = 0) { size_t found_pos = std::string::npos; - int found_idx = -1; + std::string found_str = ""; for (unsigned int i = 0; i < needles.size(); ++i) { const std::string & start_string = needles[i]; auto needle_pos = haystack.find(start_string,search_pos); if (needle_pos != std::string::npos && (found_pos == std::string::npos || needle_pos < found_pos) ) { found_pos = needle_pos; - found_idx = i; + found_str = start_string; } } - return std::pair(found_pos,found_idx); + return std::pair(found_pos,found_str); } // TODO: reuse llama_detokenize From 3bead57b9e462f8c0a45f631dd11053d341a2b93 Mon Sep 17 00:00:00 2001 From: matteo serva Date: Sat, 3 May 2025 18:15:50 +0200 Subject: [PATCH 15/16] refactoring the process_token function --- tools/server/server.cpp | 51 ++++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 8ab2186d8c61c..a6d3a4c4e7885 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2208,7 +2208,7 @@ struct server_context { bool full_stop_reached = false; bool partial_stop_reached = false; - // search start strings + // search the start strings if (start_string_missing && !incomplete && slot.has_next_token) { size_t max_start_string_size = slot.params.start_string_max_len; size_t search_len = max_start_string_size + token_str.size(); @@ -2230,17 +2230,11 @@ struct server_context { } } + // search the stop strings if (!incomplete) { size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); const std::string str_test = slot.generated_text.substr(pos); - bool send_text = true; - - // Handle the start strings - if (start_string_missing) - { - send_text = false; - } // search stop word and delete it size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); @@ -2249,25 +2243,36 @@ struct server_context { slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end()); pos = std::min(slot.n_sent_text, slot.generated_text.size()); + full_stop_reached = true; } else if (slot.has_next_token) { stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); - send_text = send_text && stop_pos == std::string::npos; + partial_stop_reached = (stop_pos != std::string::npos); } + } - // check if there is any token to predict - if (send_text) { - // no send the stop word in the response - result.text_to_send = slot.generated_text.substr(pos, std::string::npos); - slot.n_sent_text += result.text_to_send.size(); - // add the token to slot queue and cache - } else { - result.text_to_send = ""; - } + if(full_stop_reached) + { + slot.stop = STOP_TYPE_WORD; + slot.has_next_token = false; + SLT_DBG(slot, "stopped by word, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); + } - slot.add_token(result); - if (slot.params.stream) { - send_partial_response(slot, result); - } + if(partial_stop_reached || start_string_missing) + { + result.text_to_send = ""; + } + else + { + size_t valid_generated_len = validate_utf8(slot.generated_text); + size_t available_data = valid_generated_len - slot.n_sent_text; + result.text_to_send = slot.generated_text.substr(slot.n_sent_text, available_data); + slot.n_sent_text += result.text_to_send.size(); + } + + 
slot.add_token(result); + + if (slot.params.stream && !result.text_to_send.empty()) { + send_partial_response(slot, result); } if (incomplete) { @@ -2275,7 +2280,7 @@ struct server_context { } // check the limits - if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { + if (slot.has_next_token && token_budget_exhausted) { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; From a943218d01225fb2ef80b5d1030127476d60ea55 Mon Sep 17 00:00:00 2001 From: matteo serva Date: Sat, 3 May 2025 18:19:37 +0200 Subject: [PATCH 16/16] remove empty whitespace --- tools/server/server.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index a6d3a4c4e7885..c38877d2c0ac2 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2196,7 +2196,6 @@ struct server_context { if (slot.params.return_tokens) { slot.generated_tokens.push_back(result.tok); } - // SECTION: compute conditions on generated tokens so far @@ -2250,6 +2249,7 @@ struct server_context { } } + // @ngxson all the other stop reasons should be in this function if(full_stop_reached) { slot.stop = STOP_TYPE_WORD; @@ -2257,6 +2257,7 @@ struct server_context { SLT_DBG(slot, "stopped by word, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); } + // hold the output if we are not ready if(partial_stop_reached || start_string_missing) { result.text_to_send = ""; @@ -2269,8 +2270,10 @@ struct server_context { slot.n_sent_text += result.text_to_send.size(); } + // @ngxson: add the token and its probabilities even if not valid utf8 data slot.add_token(result); + // @ngxson: we also avoid outputting the final token if it's entirely a stop word if (slot.params.stream && !result.text_to_send.empty()) { send_partial_response(slot, result); }
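
A minimal sketch of how a client could exercise the `start_strings` field added by this series, mirroring the request shape used in the tests above. The host/port and the example prompt are assumptions for illustration, not part of the patch; per the README change in PATCH 03, everything up to and including the first matched start string is discarded from the completion.

    # sketch: query a locally running llama-server built with this patch series
    import requests  # the test suite's utils.py issues its HTTP calls the same way

    res = requests.post(
        "http://127.0.0.1:8080/chat/completions",  # assumed default server address
        json={
            "max_tokens": 64,
            "messages": [{"role": "user", "content": "List the numbers from 1 to 100"}],
            # nothing is returned until "9" has been generated; the matching
            # start string itself is not included in the completion
            "start_strings": ["9"],
        },
    )
    res.raise_for_status()
    print(res.json()["choices"][0]["message"]["content"])  # e.g. " 10 11 12 ..."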