Cache based tokenization for the server input prompts #12067

Open · wants to merge 2 commits into master
61 changes: 60 additions & 1 deletion examples/server/server.cpp
@@ -4,6 +4,7 @@
#include "common.h"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "llama-vocab.h"
#include "log.h"
#include "sampling.h"
#include "speculative.h"
@@ -3841,7 +3842,65 @@ int main(int argc, char ** argv) {
// TODO: this log can become very long, put it behind a flag or think about a more compact format
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());

std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
std::vector<llama_tokens> tokenized_prompts; // start of new tokenization code based on caches; it may need optimizations and bug fixes
if (prompt.is_string()) { // attempt tokenization based on the slot token caches first, only for prompts consisting of a single string
llama_tokens cache_based_tokenization;
std::string prompt_string = prompt.get<std::string>();
size_t max_prompt_match_in_chars = 0;

SRV_DBG("Attempting slot cache based tokenization of the prompt, total prompt length %lu characters.\n", prompt_string.size());
for (size_t slot_index = 0; slot_index < ctx_server.slots.size(); slot_index++) {
size_t prompt_index = 0;
size_t cache_index = 0;
llama_tokens partially_tokenized_prompt;
const llama_tokens & cache_tokens = ctx_server.slots[slot_index].cache_tokens; // accessing the slot caches like this might be unsafe

if (cache_tokens.size() > 0) {
SRV_DBG("Slot %ld has %lu cached tokens, attempting prompt tokenization based on them.\n", slot_index, cache_tokens.size());
for (cache_index = 0; cache_index < cache_tokens.size() && prompt_index < prompt_string.size(); cache_index++) {
llama_token token = cache_tokens[cache_index];
const std::string token_string = common_token_to_piece(ctx_server.vocab, token, true);
size_t token_size = token_string.size();

if (prompt_index + token_size <= prompt_string.size() && prompt_string.compare(prompt_index, token_size, token_string) == 0) {
prompt_index += token_size;
partially_tokenized_prompt.push_back(token);
} else if (cache_index == 0) { // the first cached token might not appear in the prompt text (e.g. a BOS token), so accept it unconditionally. This might cause issues.
partially_tokenized_prompt.push_back(token);
} else {
break;
}
}

if (prompt_index > max_prompt_match_in_chars) { // the tokenization based on this slot matches more characters than the previous best match
max_prompt_match_in_chars = prompt_index;
cache_based_tokenization = partially_tokenized_prompt;
}
}
}

if (max_prompt_match_in_chars > 0) { // if some of the prompt was tokenized based on the slot caches
std::string remaining_string = prompt_string.substr(max_prompt_match_in_chars);
std::vector<llama_token> remaining_prompt_tokens = common_tokenize(ctx_server.vocab, remaining_string, true, true); // tokenize the rest of the prompt normally

SRV_DBG("The slot caches based tokenization has produced %lu tokens and the regular tokenization an additional %lu tokens for a total of %lu.\n",
cache_based_tokenization.size(), remaining_prompt_tokens.size(), cache_based_tokenization.size() + remaining_prompt_tokens.size());

// concatenate the additional tokens to the cached tokens, but skip a leading BOS token if the tokenizer added one, as we don't need a BOS in the middle of the sequence
const size_t skip_bos = (!remaining_prompt_tokens.empty() && remaining_prompt_tokens.front() == llama_vocab_bos(ctx_server.vocab)) ? 1 : 0;
if (remaining_prompt_tokens.size() > skip_bos) {
cache_based_tokenization.insert(cache_based_tokenization.end(), remaining_prompt_tokens.begin() + skip_bos, remaining_prompt_tokens.end());
}

tokenized_prompts.push_back(cache_based_tokenization);
} else {
SRV_DBG("Partial tokenization of the %lu character long prompt based on slot caches was not possible.\n", prompt_string.size());
}
}

if (tokenized_prompts.empty()) { // if the slot token cache based tokenization was not possible, tokenize the prompt normally
tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
} // end of new tokenization code based on caches

tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(type);
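
For reference, the core of the change is the prefix-matching loop over each slot's cached tokens. The sketch below is a minimal, self-contained approximation of that loop under simplified assumptions: token is a plain int, detokenize is a hypothetical callback standing in for common_token_to_piece, and match_cached_prefix is a name introduced here for illustration only, not part of this PR or of llama.cpp.

// Minimal sketch (hypothetical names, not part of the PR): match a cached token
// sequence against the start of a prompt string, the way the loop above does.
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

using token = int;

// Returns the longest run of tokens from `cache` whose detokenized pieces form a
// prefix of `prompt`; `matched_chars` reports how many characters they cover.
static std::vector<token> match_cached_prefix(
        const std::vector<token> & cache,
        const std::string & prompt,
        const std::function<std::string(token)> & detokenize,
        size_t & matched_chars) {
    std::vector<token> out;
    size_t pos = 0;
    for (size_t i = 0; i < cache.size() && pos < prompt.size(); ++i) {
        const std::string piece = detokenize(cache[i]);
        if (prompt.compare(pos, piece.size(), piece) == 0) {
            pos += piece.size();
            out.push_back(cache[i]);
        } else if (i == 0) {
            out.push_back(cache[i]); // tolerate a leading BOS-like token that has no text in the prompt
        } else {
            break;
        }
    }
    matched_chars = pos;
    return out;
}

int main() {
    const std::vector<token> cache = { 1, 42, 43, 44 }; // 1 plays the role of a BOS token
    const auto detok = [](token t) {                    // toy detokenizer
        if (t ==  1) return std::string("");
        if (t == 42) return std::string("Hello");
        if (t == 43) return std::string(", ");
        return std::string("world");
    };
    size_t matched = 0;
    const std::vector<token> prefix = match_cached_prefix(cache, "Hello, world and beyond", detok, matched);
    // prefix == {1, 42, 43, 44}, matched == 12; only " and beyond" would need regular tokenization
    return 0;
}

The server loop in the hunk above runs this kind of matching against every slot, keeps the result that covers the most characters of the prompt, and sends only the uncovered suffix through the regular tokenizer.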