diff --git a/common/arg.cpp b/common/arg.cpp
index 5080aa2fcbffd..ac9d9f20262f1 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2848,6 +2848,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             else { std::invalid_argument("invalid value"); }
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
+    add_opt(common_arg(
+        {"--start-string"}, "STRING",
+        "Start sending tokens to the client only after one of the start strings has been generated. Can be set multiple times.",
+        [](common_params & params, const std::string & value) {
+            params.start_strings.push_back(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_START_STRING"));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
         string_format(
diff --git a/common/common.h b/common/common.h
index cfe1b72786795..adb72c4d0e5b4 100644
--- a/common/common.h
+++ b/common/common.h
@@ -366,6 +366,7 @@ struct common_params {
     bool use_jinja = false; // NOLINT
     bool enable_chat_template = true;
     common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+    std::vector<std::string> start_strings;
 
     std::vector<std::string> api_keys;
diff --git a/tools/server/README.md b/tools/server/README.md
index 0ec786ea76f7a..68fddae254b6d 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -160,6 +160,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--props` | enable changing global properties via POST /props (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_PROPS) |
 | `--no-slots` | disables slots monitoring endpoint<br/>(env: LLAMA_ARG_NO_ENDPOINT_SLOTS) |
 | `--slot-save-path PATH` | path to save slot kv cache (default: disabled) |
+| `--start-string STRING` | the response is not sent to the client until one of the start strings has been generated; can be set multiple times<br/>(env: LLAMA_ARG_START_STRING) |
 | `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>list of built-in templates:<br/>chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, exaone3, gemma, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, monarch, openchat, orion, phi3, rwkv-world, vicuna, vicuna-orca, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
 | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)<br/> |
 | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
@@ -390,6 +391,9 @@ By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to re
 
 `stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`.
 
+`start_strings`: Specify a JSON array of start strings.
+The model output is discarded until the first start string has been generated; the matching string itself is not included in the completion. Default: `[]`
+
 `stop`: Specify a JSON array of stopping strings.
 These words will not be included in the completion, so make sure to add them to the prompt for the next iteration. Default: `[]`
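For reference, a client request using the new `start_strings` field could look like the sketch below. This is an illustration only, not part of the patch: it assumes a `llama-server` instance listening on `localhost:8080` and uses the third-party `requests` package.

```python
# Illustrative only: exercise the new "start_strings" request field against a
# locally running llama-server (assumed to be reachable at localhost:8080).
import requests

resp = requests.post(
    "http://localhost:8080/completion",
    json={
        "prompt": "List the numbers from 1 to 100: ",
        "n_predict": 32,
        # nothing is returned until one of these strings has been generated;
        # the matching string itself is dropped from the completion
        "start_strings": [" 9 "],
    },
)
resp.raise_for_status()
print(resp.json()["content"])
```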
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index c580ec123299c..c38877d2c0ac2 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -104,6 +104,8 @@ struct slot_params {
     std::vector<common_adapter_lora_info> lora;
 
     std::vector<std::string> antiprompt;
+    std::vector<std::string> start_strings;
+    size_t start_string_max_len = 0;
     std::vector<std::string> response_fields;
     bool timings_per_token = false;
     bool post_sampling_probs = false;
@@ -161,6 +163,7 @@ struct slot_params {
             {"mirostat",           sampling.mirostat},
             {"mirostat_tau",       sampling.mirostat_tau},
             {"mirostat_eta",       sampling.mirostat_eta},
+            {"start",              start_strings},
             {"stop",               antiprompt},
             {"max_tokens",         n_predict}, // User configured n_predict
             {"n_keep",             n_keep},
@@ -229,6 +232,7 @@ struct server_task {
         slot_params defaults;
         defaults.sampling    = params_base.sampling;
         defaults.speculative = params_base.speculative;
+        defaults.start_strings = params_base.start_strings;
 
         // enabling this will output extra debug information in the HTTP responses from the server
         params.verbose = params_base.verbosity > 9;
@@ -278,6 +282,14 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 0);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);
 
+        // start strings: remember the longest one to bound the search window later
+        params.start_strings = json_value(data, "start_strings", defaults.start_strings);
+        params.start_string_max_len = 0;
+        for (const auto & start_string : params.start_strings) {
+            params.start_string_max_len = std::max(params.start_string_max_len, start_string.size());
+        }
+
+
         // Use OpenAI API logprobs only if n_probs wasn't provided
         if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
             params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
@@ -1291,6 +1303,8 @@ struct server_slot {
 
     std::string stopping_word;
 
+    bool start_string_found = false;
+
     // sampling
     json json_schema;
@@ -1328,6 +1342,7 @@ struct server_slot {
         n_past            = 0;
         n_sent_text       = 0;
         task_type         = SERVER_TASK_TYPE_COMPLETION;
+        start_string_found = false;
 
         generated_tokens.clear();
         generated_token_probs.clear();
@@ -1998,6 +2013,7 @@ struct server_context {
         SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
 
         slot.params.sampling = params_base.sampling;
+        slot.params.start_strings = params_base.start_strings;
 
         slot.callback_on_release = [this](int) {
             queue_tasks.pop_deferred_task();
@@ -2180,43 +2196,86 @@ struct server_context {
         if (slot.params.return_tokens) {
             slot.generated_tokens.push_back(result.tok);
         }
-        slot.has_next_token = true;
+        // SECTION: compute conditions on the tokens generated so far
+
+        slot.has_next_token = true;
 
         // check if there is incomplete UTF-8 character at the end
         bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
+        bool token_budget_exhausted = slot.n_decoded > 0 && !slot.has_budget(params_base);
+        bool start_string_missing   = !slot.params.start_strings.empty() && !slot.start_string_found;
+        bool full_stop_reached      = false;
+        bool partial_stop_reached   = false;
+
+        // search for the start strings
+        if (start_string_missing && !incomplete && slot.has_next_token) {
+            // a new match can only appear in the tail of the generated text
+            size_t max_start_string_size = slot.params.start_string_max_len;
+            size_t search_len = max_start_string_size + token_str.size();
+            size_t search_pos = 0;
+            if (slot.generated_text.size() > search_len) {
+                search_pos = slot.generated_text.size() - search_len;
+            }
+
+            std::pair<size_t, std::string> search_result = find_first_substring(slot.generated_text, slot.params.start_strings, search_pos);
+            bool start_string_found = search_result.first != std::string::npos;
+            if (start_string_found) {
+                auto found_pos = search_result.first;
+                std::string found_string = search_result.second;
+                // discard everything up to and including the start string
+                slot.generated_text.erase(
+                    slot.generated_text.begin(),
+                    slot.generated_text.begin() + found_pos + found_string.size());
+                slot.start_string_found = true;
+                start_string_missing = false;
+            }
+        }
 
-        // search stop word and delete it
+        // search for the stop strings
         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
 
             const std::string str_test = slot.generated_text.substr(pos);
-            bool send_text = true;
 
+            // search stop word and delete it
             size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
             if (stop_pos != std::string::npos) {
                 slot.generated_text.erase(
                     slot.generated_text.begin() + pos + stop_pos,
                     slot.generated_text.end());
                 pos = std::min(slot.n_sent_text, slot.generated_text.size());
+                full_stop_reached = true;
             } else if (slot.has_next_token) {
                 stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
-                send_text = stop_pos == std::string::npos;
+                partial_stop_reached = (stop_pos != std::string::npos);
             }
+        }
 
-            // check if there is any token to predict
-            if (send_text) {
-                // no send the stop word in the response
-                result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
-                slot.n_sent_text += result.text_to_send.size();
-                // add the token to slot queue and cache
-            } else {
-                result.text_to_send = "";
-            }
+        // @ngxson: all the other stop reasons should be handled in this function as well
+        if (full_stop_reached) {
+            slot.stop = STOP_TYPE_WORD;
+            slot.has_next_token = false;
+            SLT_DBG(slot, "stopped by word, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
+        }
 
-            slot.add_token(result);
-            if (slot.params.stream) {
-                send_partial_response(slot, result);
-            }
+        // hold the output back while we are not ready to send it
+        if (partial_stop_reached || start_string_missing) {
+            result.text_to_send = "";
+        } else {
+            // send only the valid UTF-8 part that has not been sent yet
+            size_t valid_generated_len = validate_utf8(slot.generated_text);
+            size_t available_data      = valid_generated_len - slot.n_sent_text;
+            result.text_to_send = slot.generated_text.substr(slot.n_sent_text, available_data);
+            slot.n_sent_text   += result.text_to_send.size();
+        }
+
+        // @ngxson: add the token and its probabilities even if the text is not yet valid UTF-8
+        slot.add_token(result);
+
+        // @ngxson: we also avoid outputting the final token if it is entirely a stop word
+        if (slot.params.stream && !result.text_to_send.empty()) {
+            send_partial_response(slot, result);
         }
 
         if (incomplete) {
@@ -2224,7 +2283,7 @@
         }
 
         // check the limits
-        if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
+        if (slot.has_next_token && token_budget_exhausted) {
             slot.stop = STOP_TYPE_LIMIT;
             slot.has_next_token = false;
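To make the token-processing changes above easier to follow, here is a small, self-contained Python model of the intended hold-and-flush behaviour. It is an illustrative sketch only (names and structure are mine, not the server's): output is buffered until one of the start strings appears in a tail window of the generated text, the match is stripped, and only then are chunks released.

```python
# Illustrative Python model of the start-string gating added above; not the C++ implementation.

def find_first_substring(haystack: str, needles: list[str], search_pos: int = 0):
    """Return (position, needle) of the earliest match at or after search_pos, or (None, "")."""
    found_pos, found_str = None, ""
    for needle in needles:
        pos = haystack.find(needle, search_pos)
        if pos != -1 and (found_pos is None or pos < found_pos):
            found_pos, found_str = pos, needle
    return found_pos, found_str


def stream_with_start_strings(tokens, start_strings):
    """Yield output chunks, discarding everything up to and including the first start string."""
    generated = ""               # text generated so far (after the start string has been stripped)
    n_sent = 0                   # how much of `generated` has already been released
    started = not start_strings  # with no start strings, release text from the first token on
    max_len = max((len(s) for s in start_strings), default=0)

    for tok in tokens:
        generated += tok
        if not started:
            # a new match can only appear in the tail: the last max_len + len(tok) characters
            search_pos = max(0, len(generated) - (max_len + len(tok)))
            pos, match = find_first_substring(generated, start_strings, search_pos)
            if pos is None:
                continue  # keep holding the output back
            generated = generated[pos + len(match):]
            started = True
        if n_sent < len(generated):
            chunk = generated[n_sent:]
            n_sent = len(generated)
            yield chunk


# tokens "1 2 3 4" with start string "2 ": only "3 4 " is released
print("".join(stream_with_start_strings(["1 ", "2 ", "3 ", "4 "], ["2 "])))
```

The real code additionally has to coordinate this with stop-string handling, UTF-8 validation and streaming, which is what the hunk above does.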
diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py
index 491cb3a5df636..9feda210dc59b 100644
--- a/tools/server/tests/unit/test_chat_completion.py
+++ b/tools/server/tests/unit/test_chat_completion.py
@@ -309,3 +309,85 @@ def test_logprobs_stream():
             assert token.top_logprobs is not None
             assert len(token.top_logprobs) > 0
     assert aggregated_text == output_text
+
+
+def test_startstring_serverconfig():
+    global server
+    server.jinja = False
+    server.start_string = " 9 "
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": 32,
+        "messages": [
+            {"role": "user", "content": "List the numbers from 1 to 100"},
+        ],
+        "grammar": "root ::= \"1 2 3 4 5 6 7 8 9 10 11 12\"",
+    })
+    assert res.status_code == 200, res.body
+    choice = res.body["choices"][0]
+    content = choice["message"]["content"]
+    assert content.startswith("10 ")
+
+
+def test_startstring_clientconfig():
+    global server
+    server.jinja = False
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": 32,
+        "messages": [
+            {"role": "user", "content": "List the numbers from 1 to 100"},
+        ],
+        "grammar": "root ::= \"1 2 3 4 5 6 7 8 9 10 11 12\"",
+        "start_strings": ["10"]
+    })
+    assert res.status_code == 200, res.body
+    choice = res.body["choices"][0]
+    content = choice["message"]["content"]
+    assert content.startswith(" 11")
+
+
+def test_startstring_clientconfig_stream():
+    global server
+    server.jinja = False
+    server.start()
+    max_tokens = 64
+    system_prompt = ""
+    user_prompt = ""
+    res = server.make_stream_request("POST", "/chat/completions", data={
+        "max_tokens": max_tokens,
+        "messages": [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ],
+        "grammar": "root ::= \"1 2 3 4 5 6 7 8 9 10 11 12\" .+",
+        "start_strings": ["10"],
+        "stream": True,
+    })
+
+    content = ""
+    for data in res:
+        choice = data["choices"][0]
+        if choice["finish_reason"] not in ["stop", "length"]:
+            delta = choice["delta"]["content"]
+            content += delta
+    assert content.startswith(" 11")
+
+
+def test_startstring_clientconfig_multiple():
+    global server
+    server.jinja = False
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": 32,
+        "messages": [
+            {"role": "user", "content": "List the numbers from 1 to 100"},
+        ],
+        "grammar": "root ::= \"1 2 3 4 5 6 7 8 9 10 11 12\"",
+        "start_strings": ["10", "9"]
+    })
+    assert res.status_code == 200, res.body
+    choice = res.body["choices"][0]
+    content = choice["message"]["content"]
+    assert content.startswith(" 10")
diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py
index 4dc2062a8e5b9..d2e93bf859b46 100644
--- a/tools/server/tests/utils.py
+++ b/tools/server/tests/utils.py
@@ -88,6 +88,7 @@ class ServerProcess:
    chat_template: str | None = None
    chat_template_file: str | None = None
    server_path: str | None = None
+    start_string: str | None = None
 
    # session variables
    process: subprocess.Popen | None = None
@@ -194,6 +195,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
            server_args.extend(["--chat-template", self.chat_template])
        if self.chat_template_file:
            server_args.extend(["--chat-template-file", self.chat_template_file])
+        if self.start_string:
+            server_args.extend(["--start-string", self.start_string])
        args = [str(arg) for arg in [server_path, *server_args]]
        print(f"tests: starting server with: {' '.join(args)}")
diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp
index b497959fd8689..208c492ecff88 100644
--- a/tools/server/utils.hpp
+++ b/tools/server/utils.hpp
@@ -489,6 +489,25 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin
     return std::string::npos;
 }
 
+// returns the position and the matched needle of the earliest match found in the haystack;
+// returns {std::string::npos, ""} if none of the needles occurs
+static std::pair<size_t, std::string> find_first_substring(const std::string & haystack, const std::vector<std::string> & needles, size_t search_pos = 0) {
+    size_t found_pos = std::string::npos;
+    std::string found_str;
+
+    for (const std::string & needle : needles) {
+        const size_t needle_pos = haystack.find(needle, search_pos);
+        if (needle_pos != std::string::npos && (found_pos == std::string::npos || needle_pos < found_pos)) {
+            found_pos = needle_pos;
+            found_str = needle;
+        }
+    }
+
+    return {found_pos, found_str};
+}
+
 // TODO: reuse llama_detokenize
 template <class Iter>
 static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {