From 3dbdbcaadcee9e5d0fbe13a0d70a12824ddb72de Mon Sep 17 00:00:00 2001
From: Vladimir Nicolici
Date: Tue, 25 Feb 2025 13:21:29 +0200
Subject: [PATCH] Hybrid tokenization of chat prompts based on the slot caches.

---
 examples/server/server.cpp | 61 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 60 insertions(+), 1 deletion(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 2306dc26fe431..abf3b389ce7d0 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -4,6 +4,7 @@
 #include "common.h"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
+#include "llama-vocab.h"
 #include "log.h"
 #include "sampling.h"
 #include "speculative.h"
@@ -3841,7 +3842,65 @@ int main(int argc, char ** argv) {
             // TODO: this log can become very long, put it behind a flag or think about a more compact format
             //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
 
-            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+            std::vector<llama_tokens> tokenized_prompts; // start of new tokenization code based on caches; it may need optimizations and bug fixes
+            if (prompt.is_string()) { // attempt tokenization based on the slot token caches first, only for prompts consisting of a single string
+                llama_tokens cache_based_tokenization;
+                std::string prompt_string = prompt.get<std::string>();
+                size_t max_prompt_match_in_chars = 0;
+
+                SRV_DBG("Attempting slot cache based tokenization of the prompt, total prompt length %lu characters.\n", prompt_string.size());
+                for (size_t slot_index = 0; slot_index < ctx_server.slots.size(); slot_index++) {
+                    size_t prompt_index = 0;
+                    size_t cache_index = 0;
+                    llama_tokens partially_tokenized_prompt;
+                    llama_tokens cache_tokens = ctx_server.slots[slot_index].cache_tokens; // accessing the caches like this might be unsafe
+
+                    if (cache_tokens.size() > 0) {
+                        SRV_DBG("Slot %ld has %lu cached tokens, attempting prompt tokenization based on them.\n", slot_index, cache_tokens.size());
+                        for (cache_index = 0; cache_index < cache_tokens.size() && prompt_index < prompt_string.size(); cache_index++) {
+                            llama_token token = cache_tokens[cache_index];
+                            const std::string token_string = common_token_to_piece(ctx_server.vocab, token, true);
+                            size_t token_size = token_string.size();
+
+                            if (prompt_index + token_size <= prompt_string.size() && prompt_string.compare(prompt_index, token_size, token_string) == 0) {
+                                prompt_index += token_size;
+                                partially_tokenized_prompt.push_back(token);
+                            } else if (cache_index == 0) { // the first token from the cache doesn't have to be in the prompt, as it might be a BOS token, so just add it. This might cause issues.
+                                partially_tokenized_prompt.push_back(token);
+                            } else {
+                                break;
+                            }
+                        }
+
+                        if (prompt_index > max_prompt_match_in_chars) { // the tokenization based on this slot matches more characters than the previous best match
+                            max_prompt_match_in_chars = prompt_index;
+                            cache_based_tokenization = partially_tokenized_prompt;
+                        }
+                    }
+                }
+
+                if (max_prompt_match_in_chars > 0) { // if some of the prompt was tokenized based on the slot caches
+                    std::string remaining_string = prompt_string.substr(max_prompt_match_in_chars);
+                    std::vector<llama_token> remaining_prompt_tokens = common_tokenize(ctx_server.vocab, remaining_string, true, true); // tokenize the rest of the prompt normally
+
+                    SRV_DBG("The slot cache based tokenization has produced %lu tokens and the regular tokenization an additional %lu tokens for a total of %lu.\n",
+                        cache_based_tokenization.size(), remaining_prompt_tokens.size(), cache_based_tokenization.size() + remaining_prompt_tokens.size());
+
+                    // concatenate the additional tokens to the cached tokens, but skip the additional BOS, as we don't need one in the middle of the tokens. This might cause issues.
+                    if (remaining_prompt_tokens.size() > 1) {
+                        cache_based_tokenization.insert(cache_based_tokenization.end(), remaining_prompt_tokens.begin() + 1, remaining_prompt_tokens.end());
+                    }
+
+                    tokenized_prompts.push_back(cache_based_tokenization);
+                } else {
+                    SRV_DBG("Partial tokenization of the %lu character long prompt based on slot caches was not possible.\n", prompt_string.size());
+                }
+            }
+
+            if (tokenized_prompts.empty()) { // if the slot token cache based tokenization was not possible, tokenize the prompt normally
+                tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+            } // end of new tokenization code based on caches
+
             tasks.reserve(tokenized_prompts.size());
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task = server_task(type);
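
Note (not part of the patch): the hunk above performs a greedy character-level prefix match between the detokenized tokens of each slot cache and the incoming prompt string, keeps the best match across all slots, tokenizes only the uncovered tail with common_tokenize, and falls back to tokenize_input_prompts when nothing matches. The following is a minimal standalone sketch of that matching loop only; match_cached_prefix, the "<BOS>" placeholder and the hard-coded token pieces are illustrative stand-ins for common_token_to_piece output, not llama.cpp API.

// Standalone illustration of the greedy prefix match used in the patch (not llama.cpp code).
// Token pieces are plain strings standing in for common_token_to_piece() output.
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Returns how many characters of `prompt` are covered by reusing `cached_pieces`
// in order, plus the indices of the pieces that matched.
static std::pair<size_t, std::vector<size_t>> match_cached_prefix(
        const std::string & prompt, const std::vector<std::string> & cached_pieces) {
    size_t pos = 0;
    std::vector<size_t> used;
    for (size_t i = 0; i < cached_pieces.size() && pos < prompt.size(); i++) {
        const std::string & piece = cached_pieces[i];
        if (pos + piece.size() <= prompt.size() && prompt.compare(pos, piece.size(), piece) == 0) {
            pos += piece.size();
            used.push_back(i);
        } else if (i == 0) {
            // mirror the patch: the first cached token may be a BOS marker with no
            // textual counterpart in the prompt, so it is accepted unconditionally
            used.push_back(i);
        } else {
            break; // stop at the first mismatch; the rest is tokenized normally
        }
    }
    return {pos, used};
}

int main() {
    const std::string prompt = "Hello world, how are you today?";
    const std::vector<std::string> cached_pieces = {"<BOS>", "Hello", " world", ",", " how"};

    auto [chars, used] = match_cached_prefix(prompt, cached_pieces);
    std::cout << "reused " << used.size() << " cached tokens covering " << chars
              << " characters; remaining text: \"" << prompt.substr(chars) << "\"\n";
    // in the patch, the remaining text would go through the regular tokenizer
    // (common_tokenize) and the two token lists would then be concatenated
}

Building the sketch with any C++17 compiler and running it prints how many cached tokens were reused and which part of the prompt would still need regular tokenization; in the patch, this loop runs once per server slot and only the best-covering result is kept.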