Description
What happened?
Sometimes the part of the initial prompt that should be considered for the penalties is ignored, and only the newly generated tokens are used when calculating the penalty. For now I assume it has something to do with prompt caching (explained below).
Let's add the following debug code to llama_sample_repetition_penalties_impl, right after the token_count map is filled in:
// dump the contents of token_count (token id -> number of occurrences)
printf("------\n");
for (const auto & entry : token_count) {
    printf("[%d] = %d\n", entry.first, entry.second);
}
It prints the tokens that will be used for the penalty calculation.
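For context, here is a simplified sketch of how that map is then used to adjust the candidate logits (this is the usual repeat/frequency/presence penalty logic, not the verbatim llama.cpp code, so member and parameter names are approximate):

// Simplified sketch: only tokens present in token_count get penalized.
// Prompt tokens that never make it into the map (the bug described above)
// are therefore never penalized.
for (size_t i = 0; i < candidates->size; ++i) {
    const auto it = token_count.find(candidates->data[i].id);
    if (it == token_count.end()) {
        continue;
    }
    const int count = it->second;

    // repeat penalty: shrink positive logits, amplify negative ones
    if (candidates->data[i].logit <= 0) {
        candidates->data[i].logit *= penalty_repeat;
    } else {
        candidates->data[i].logit /= penalty_repeat;
    }

    // frequency and presence penalties
    candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
}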
After starting the server and running this:
curl -s --data '{"prompt": "Note that the file, line, and message properties are", "n_predict": 4, "repeat_penalty": 1.1, "cache_prompt": true}' http://127.0.0.1:8080/completion > /dev/null
the server log shows:
------
[0] = 64
------
[2016] = 1
[0] = 63
------
[1562] = 1
[2016] = 1
[0] = 62
------
[1278] = 1
[1562] = 1
[2016] = 1
[0] = 61
So it ignores the initial prompt and only uses the new tokens.
However, if I run the exact same query a second time, I get this:
------
[1584] = 1
[7192] = 1
[3110] = 1
[5117] = 1
[1321] = 1
[3323] = 1
[1044] = 2
[1278] = 1
[1455] = 1
[12791] = 1
[1] = 1
[0] = 52
------
[1584] = 1
[7192] = 1
[3110] = 1
[5117] = 1
[1321] = 1
[3323] = 1
[1044] = 2
[1278] = 1
[1455] = 1
[12791] = 1
[2016] = 1
[1] = 1
[0] = 51
------
[1536] = 1
[1] = 1
[2016] = 1
[1455] = 1
[12791] = 1
[1278] = 1
[3323] = 1
[1321] = 1
[5117] = 1
[3110] = 1
[0] = 50
[1044] = 2
[7192] = 1
[1584] = 1
------
[1536] = 1
[1] = 1
[2016] = 1
[1455] = 1
[12791] = 1
[1278] = 2
[3323] = 1
[1321] = 1
[5117] = 1
[3110] = 1
[0] = 49
[1044] = 2
[7192] = 1
[1584] = 1
Now the map contains all the prompt tokens, plus one new token at each step.
The bug seems related to prompt caching, because it does not happen when a cached prompt is reused. It does happen in all other cases:
- on the first inference
- when the prompt is changed
- when cache_prompt = false
I tested it with CUDA and non-CUDA builds and with two different models; the results are the same.
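To dig a bit further, one could also dump the raw penalty window before token_count is built, to check whether the prompt tokens ever reach the sampler at all. This is just a sketch: I am assuming the surrounding parameters are named last_tokens and penalty_last_n, as in the current llama_sample_repetition_penalties_impl signature; the names may differ in other versions.

// Debug sketch: print the raw window that feeds token_count.
// Place at the top of llama_sample_repetition_penalties_impl,
// before the token_count map is filled in.
printf("penalty_last_n = %zu\n", penalty_last_n);
for (size_t i = 0; i < penalty_last_n; ++i) {
    printf("last_tokens[%zu] = %d\n", i, last_tokens[i]);
}

If the prompt tokens never show up here in the broken cases, the problem is upstream of the sampler, in how the penalty window is populated by the server.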
Name and Version
./llama-server --version
version: 3565 (6e02327)
built with cc (Ubuntu 13.2.0-23ubuntu4) 13.2.0 for x86_64-linux-gnu
What operating system are you seeing the problem on?
Linux
Relevant log output
No response