Support start strings, the opposite of stop tokens. #13214

Draft · wants to merge 16 commits into master
7 changes: 7 additions & 0 deletions common/arg.cpp
@@ -2848,6 +2848,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
else { throw std::invalid_argument("invalid value"); }
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
add_opt(common_arg(
{"--start-string"}, "STRING",
"Start outputting tokens only when at least one start string has been reached. Can be set multiple times.",
[](common_params & params, const std::string & value) {
params.start_strings.push_back(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_START_STRING"));
add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
string_format(
1 change: 1 addition & 0 deletions common/common.h
@@ -366,6 +366,7 @@ struct common_params {
bool use_jinja = false; // NOLINT
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
std::vector<std::string> start_strings;

std::vector<std::string> api_keys;

4 changes: 4 additions & 0 deletions tools/server/README.md
@@ -160,6 +160,7 @@ The project is under active development, and we are [looking for feedback and co
| `--props` | enable changing global properties via POST /props (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_PROPS) |
| `--no-slots` | disables slots monitoring endpoint<br/>(env: LLAMA_ARG_NO_ENDPOINT_SLOTS) |
| `--slot-save-path PATH` | path to save slot kv cache (default: disabled) |
| `--start-string STRING` | the response is not sent to the client until at least one start string has been generated; can be specified multiple times<br/>(env: LLAMA_ARG_START_STRING) |
| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>list of built-in templates:<br/>chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, exaone3, gemma, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, monarch, openchat, orion, phi3, rwkv-world, vicuna, vicuna-orca, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)<br/> |
| `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
@@ -390,6 +391,9 @@ By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to re

`stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`.

`start_strings`: Specify a JSON array of start strings.
Model output is withheld until the first start string has been generated; the matching string itself is not included in the completion. Default: `[]`
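
A minimal sketch of a client request using `start_strings` (the prompt, the `<answer>` marker, and the default `http://localhost:8080` address are illustrative assumptions, not server defaults):

```python
import requests

# Everything generated before (and including) the first "<answer>" is discarded;
# "stop" still ends generation as before.
payload = {
    "prompt": "Think step by step, then give the final reply after the string <answer>.",
    "n_predict": 128,
    "start_strings": ["<answer>"],
    "stop": ["</answer>"],
}

res = requests.post("http://localhost:8080/completion", json=payload, timeout=60)
print(res.json()["content"])  # begins with the text generated after "<answer>"
```

With `"stream": true` the same rule applies: no tokens are streamed to the client until a start string has been generated.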

`stop`: Specify a JSON array of stopping strings.
These words will not be included in the completion, so make sure to add them to the prompt for the next iteration. Default: `[]`

95 changes: 77 additions & 18 deletions tools/server/server.cpp
@@ -104,6 +104,8 @@ struct slot_params {
std::vector<common_adapter_lora_info> lora;

std::vector<std::string> antiprompt;
std::vector<std::string> start_strings;
size_t start_string_max_len = 0; // length of the longest start string
std::vector<std::string> response_fields;
bool timings_per_token = false;
bool post_sampling_probs = false;
@@ -161,6 +163,7 @@ struct slot_params {
{"mirostat", sampling.mirostat},
{"mirostat_tau", sampling.mirostat_tau},
{"mirostat_eta", sampling.mirostat_eta},
{"start", start_strings},
{"stop", antiprompt},
{"max_tokens", n_predict}, // User configured n_predict
{"n_keep", n_keep},
@@ -229,6 +232,7 @@ struct server_task {
slot_params defaults;
defaults.sampling = params_base.sampling;
defaults.speculative = params_base.speculative;
defaults.start_strings = params_base.start_strings;

// enabling this will output extra debug information in the HTTP responses from the server
params.verbose = params_base.verbosity > 9;
@@ -278,6 +282,14 @@
params.speculative.n_min = std::max(params.speculative.n_min, 0);
params.speculative.n_max = std::max(params.speculative.n_max, 0);

// start strings
params.start_strings = json_value(data, "start_strings", defaults.start_strings);
// precompute the longest start string so generation only has to search a bounded tail window
params.start_string_max_len = 0;
for (const auto & start_string : params.start_strings) {
params.start_string_max_len = std::max(params.start_string_max_len, start_string.size());
}

// Use OpenAI API logprobs only if n_probs wasn't provided
if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
@@ -1291,6 +1303,8 @@ struct server_slot {

std::string stopping_word;

bool start_string_found = false;

// sampling
json json_schema;

@@ -1328,6 +1342,7 @@
n_past = 0;
n_sent_text = 0;
task_type = SERVER_TASK_TYPE_COMPLETION;
start_string_found = false;

generated_tokens.clear();
generated_token_probs.clear();
@@ -1998,6 +2013,7 @@ struct server_context {
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);

slot.params.sampling = params_base.sampling;
slot.params.start_strings = params_base.start_strings;

slot.callback_on_release = [this](int) {
queue_tasks.pop_deferred_task();
@@ -2180,51 +2196,94 @@
if (slot.params.return_tokens) {
slot.generated_tokens.push_back(result.tok);
}

// SECTION: compute conditions on generated tokens so far

slot.has_next_token = true;
// check if there is an incomplete UTF-8 character at the end
bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
bool token_budget_exhausted = slot.n_decoded > 0 && !slot.has_budget(params_base);
bool start_string_missing = !slot.params.start_strings.empty() && !slot.start_string_found;
bool full_stop_reached = false;
bool partial_stop_reached = false;

// search the start strings
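// only re-scan the tail of the generated text: a start string can straddle the
// boundary between previously generated text and the newly decoded token, so the
// window is the longest start string plus the current token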
if (start_string_missing && !incomplete && slot.has_next_token) {
size_t max_start_string_size = slot.params.start_string_max_len;
size_t search_len = max_start_string_size + token_str.size();
size_t search_pos = 0;
if (slot.generated_text.size() > search_len) {
search_pos = slot.generated_text.size() - search_len;
}

std::pair<size_t, std::string> search_result = find_first_substring(slot.generated_text, slot.params.start_strings, search_pos);
bool start_string_found = search_result.first != std::string::npos;
if (start_string_found) {
auto found_pos = search_result.first;
std::string found_string = search_result.second;
slot.generated_text.erase(
slot.generated_text.begin(),
slot.generated_text.begin() + found_pos + found_string.size());
slot.start_string_found = true;
start_string_missing = false;
}
}

// search the stop strings
if (!incomplete) {
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());

const std::string str_test = slot.generated_text.substr(pos);
bool send_text = true;

// search stop word and delete it
size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
if (stop_pos != std::string::npos) {
slot.generated_text.erase(
slot.generated_text.begin() + pos + stop_pos,
slot.generated_text.end());
pos = std::min(slot.n_sent_text, slot.generated_text.size());
full_stop_reached = true;
} else if (slot.has_next_token) {
stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
send_text = stop_pos == std::string::npos;
partial_stop_reached = (stop_pos != std::string::npos);
}
}

// check if there is any token to predict
if (send_text) {
// do not send the stop word in the response
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
slot.n_sent_text += result.text_to_send.size();
// add the token to slot queue and cache
} else {
result.text_to_send = "";
}
// @ngxson: all the other stop reasons should be in this function
if (full_stop_reached) {
slot.stop = STOP_TYPE_WORD;
slot.has_next_token = false;
SLT_DBG(slot, "stopped by word, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
}

slot.add_token(result);
if (slot.params.stream) {
send_partial_response(slot, result);
}
// hold the output if we are not ready
if (partial_stop_reached || start_string_missing) {
result.text_to_send = "";
} else {
size_t valid_generated_len = validate_utf8(slot.generated_text);
size_t available_data = valid_generated_len - slot.n_sent_text;
result.text_to_send = slot.generated_text.substr(slot.n_sent_text, available_data);
slot.n_sent_text += result.text_to_send.size();
}

// @ngxson: add the token and its probabilities even if not valid utf8 data
slot.add_token(result);

// @ngxson: we also avoid outputting the final token if it's entirely a stop word
if (slot.params.stream && !result.text_to_send.empty()) {
send_partial_response(slot, result);
}

if (incomplete) {
slot.has_next_token = true;
}

// check the limits
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
if (slot.has_next_token && token_budget_exhausted) {
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;

82 changes: 82 additions & 0 deletions tools/server/tests/unit/test_chat_completion.py
@@ -309,3 +309,85 @@ def test_logprobs_stream():
assert token.top_logprobs is not None
assert len(token.top_logprobs) > 0
assert aggregated_text == output_text


def test_startstring_serverconfig():
global server
server.jinja = False
server.start_string = " 9 "
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 32,
"messages": [
{"role": "user", "content": "List the numbers from 1 to 100"},
],
"grammar": "root ::= \"1 2 3 4 5 6 7 8 9 10 11 12\"",
})
assert res.status_code == 200, res.body
choice = res.body["choices"][0]
content = choice["message"]["content"]
print(content)
assert content.startswith("10 ")

def test_startstring_clientconfig():
global server
server.jinja = False
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 32,
"messages": [
{"role": "user", "content": "List the numbers from 1 to 100"},
],
"grammar": "root ::= \"1 2 3 4 5 6 7 8 9 10 11 12\"",
"start_strings": ["10"]
})
assert res.status_code == 200, res.body
choice = res.body["choices"][0]
content = choice["message"]["content"]
assert content.startswith(" 11")


def test_startstring_clientconfig_stream():
global server
server.jinja = False
server.start()
max_tokens = 64
system_prompt = ""
user_prompt = ""
res = server.make_stream_request("POST", "/chat/completions", data={
"max_tokens": max_tokens,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
"grammar": "root ::= \"1 2 3 4 5 6 7 8 9 10 11 12\" .+",
"start_strings": ["10"],
"stream": True,
})

content = ""
last_cmpl_id = None
for data in res:
choice = data["choices"][0]
if choice["finish_reason"] not in ["stop", "length"]:
delta = choice["delta"]["content"]
content += delta
assert content.startswith(" 11")


def test_startstring_clientconfig_multiple():
global server
server.jinja = False
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 32,
"messages": [
{"role": "user", "content": "List the numbers from 1 to 100"},
],
"grammar": "root ::= \"1 2 3 4 5 6 7 8 9 10 11 12\"",
"start_strings": ["10","9"]
})
assert res.status_code == 200, res.body
choice = res.body["choices"][0]
content = choice["message"]["content"]
assert content.startswith(" 10")
3 changes: 3 additions & 0 deletions tools/server/tests/utils.py
@@ -88,6 +88,7 @@ class ServerProcess:
chat_template: str | None = None
chat_template_file: str | None = None
server_path: str | None = None
start_string: str | None = None

# session variables
process: subprocess.Popen | None = None
@@ -194,6 +195,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
server_args.extend(["--chat-template", self.chat_template])
if self.chat_template_file:
server_args.extend(["--chat-template-file", self.chat_template_file])
if self.start_string:
server_args.extend(["--start_string", self.start_string])

args = [str(arg) for arg in [server_path, *server_args]]
print(f"tests: starting server with: {' '.join(args)}")
19 changes: 19 additions & 0 deletions tools/server/utils.hpp
@@ -489,6 +489,25 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin
return std::string::npos;
}

/* returns a pair containing the position of the first needle found in the haystack and the needle itself.
returns npos and an empty string if no needle is found */
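// example: find_first_substring("1 2 3", {"3", "2"}) returns {2, "2"}; the earliest match in the haystack wins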
static std::pair<size_t, std::string> find_first_substring(const std::string & haystack, const std::vector<std::string> & needles, size_t search_pos = 0) {
size_t found_pos = std::string::npos;
std::string found_str;

for (const std::string & needle : needles) {
const size_t needle_pos = haystack.find(needle, search_pos);
if (needle_pos != std::string::npos && (found_pos == std::string::npos || needle_pos < found_pos)) {
found_pos = needle_pos;
found_str = needle;
}
}

return {found_pos, found_str};
}

// TODO: reuse llama_detokenize
template <class Iter>
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {