
kv-cache : avoid modifying recurrent cells when setting inputs #13834


Open
wants to merge 2 commits into base: gg/kv-cache-simplify-part3

Conversation

compilade
Collaborator

@compilade compilade commented May 27, 2025

NOTE: this targets #13746, not master.

@ggerganov As discussed in #9126 (comment), this ports some of the changes from #9126 to #13746 for recurrent caches.

It mostly works, but there is still something wrong somewhere, indicated by non-consecutive token positions when running mamba-130m-hf with llama-parallel:

$ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
...
find_slot: non-consecutive token position 0 after 335 for sequence 1 with 1 new tokens
...
find_slot: non-consecutive token position 0 after 346 for sequence 3 with 1 new tokens
...

This was not a problem in #9126, so it might (or might not) be related to other recent changes in how the kv cache is handled.

I'll attempt to figure out what is wrong by updating #9126 to the latest master and seeing if the problem also appears there (EDIT: it does cause the problem to appear).

Is there any recent change in how kv_cell.pos is handled that comes to mind and could cause this? (pos is not reset properly between re-uses of clients in llama-parallel, at least for recurrent models)

I'm suspecting #13598 might be related.

EDIT: it seems like with -pps, the problem isn't there, and this was the default behavior before #13598.
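
For reference, the warning comes from a position-consistency check roughly like the sketch below (illustrative names in plain C++, not the actual find_slot code): a new token for a sequence is expected to sit at exactly the last stored position plus one, so a pos that is not reset between client re-uses trips it.

```cpp
// Minimal sketch of the kind of position check behind the warning
// (illustrative names only; not the actual find_slot implementation).
#include <cstdint>
#include <cstdio>

struct cell_sketch {
    int32_t pos = -1;  // position of the last token stored for this sequence
};

// A new token is expected to be at exactly pos + 1; anything else means the
// cell position was not maintained (or reset) correctly between batches.
static bool check_consecutive(const cell_sketch & cell, int32_t new_pos, int32_t seq_id, int32_t n_new) {
    if (cell.pos >= 0 && new_pos != cell.pos + 1) {
        fprintf(stderr,
                "find_slot: non-consecutive token position %d after %d for sequence %d with %d new tokens\n",
                new_pos, cell.pos, seq_id, n_new);
        return false;
    }
    return true;
}
```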



- kv-cache : remove inp_s_mask

It was replaced with equivalent and simpler functionality
with rs_z (the first zeroed state) and the already-existing inp_s_copy.
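
Conceptually, the replacement works roughly like this minimal sketch (illustrative names in plain C++, not the actual ggml graph code): instead of multiplying each state by a 0/1 mask, one slot (rs_z) is known to be zeroed, and a state that must be cleared simply copies from that slot through the existing s_copy indices.

```cpp
// Minimal sketch of the rs_z + s_copy idea (illustrative, not the real code).
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

using state_t = std::vector<float>;

static void apply_state_copies(std::vector<state_t> & states,
                               const std::vector<size_t> & s_copy,  // source state index per destination
                               size_t rs_z) {                       // index of the first zeroed state
    std::fill(states[rs_z].begin(), states[rs_z].end(), 0.0f);      // zero the designated slot once

    std::vector<state_t> gathered(states.size());
    for (size_t i = 0; i < states.size(); ++i) {
        gathered[i] = states[s_copy[i]];  // clearing state i == pointing s_copy[i] at rs_z
    }
    states = std::move(gathered);
}
```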
Member

ggerganov commented May 27, 2025

The find_slot: non-consecutive token position 0 after 346 for sequence 3 with 1 new tokens warning was what I considered to be broken. I didn't investigate the source of the warning and assumed that it was caused by the cells being modified during compute. I assumed that because, if I disable the "preparation phase", the warning does not appear.

To clarify, the "preparation phase" is where we simply insert the ubatches into the cells to make sure they fit, and then we revert back to the initial state as if no changes were ever made. We then start to insert the ubatches and compute them one by one. So I was confused why the warning disappears when the "preparation phase" is skipped, and the only explanation I had was the updates during compute.
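
Schematically, the flow is something like this (a self-contained sketch with stand-in types, not the real llama.cpp API):

```cpp
// Rough sketch of the two-pass flow described above (illustrative types).
#include <vector>

struct ubatch_sketch { int n_tokens = 0; };

struct cache_sketch {
    int n_used = 0;
    int n_size = 64;

    bool find_slot(const ubatch_sketch & ub) {  // stand-in for the real find_slot
        if (n_used + ub.n_tokens > n_size) {
            return false;
        }
        n_used += ub.n_tokens;                  // note: this mutates the cache metadata
        return true;
    }
};

bool prepare_and_process(cache_sketch & cache, const std::vector<ubatch_sketch> & ubatches) {
    // 1. preparation phase: insert all ubatches only to verify that they fit ...
    const cache_sketch saved = cache;           // snapshot of the initial state
    bool fits = true;
    for (const auto & ub : ubatches) {
        if (!cache.find_slot(ub)) { fits = false; break; }
    }
    cache = saved;                              // ... then revert as if nothing happened
    if (!fits) {
        return false;
    }

    // 2. actual processing: insert and compute the ubatches one by one
    for (const auto & ub : ubatches) {
        cache.find_slot(ub);
        // graph build + compute for this ubatch happens here
    }
    return true;
}
```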

Comment on lines +2311 to +2316
for (const auto & ubatch : ubatches) {
    if (!find_slot(ubatch)) {
        success = false;
        break;
    }
}
Member


This is the "preparation phase".

Collaborator Author


Right. When I leave this out, the non-consecutive token position problem is still there (the only difference is that the warnings are not duplicated), so I don't think the preparation phase is the source of the problem.

When using -pps the warnings are gone, though. Maybe this was an existing problem which was surfaced when -pps stopped being the default in parallel.cpp. I'll try to find the source of the kv_cell.pos discrepancy.

Member


Yes, I confirm that the warnings are still there even without preparation - both on this branch and on the target branch. It's possible that I hallucinated that the warnings disappear without preparation.

> When using -pps the warnings are gone, though. Maybe this was an existing problem which was surfaced when -pps stopped being the default in parallel.cpp. I'll try to find the source of the kv_cell.pos discrepancy.

You are likely right. Will take a look as well tomorrow.

Collaborator Author

@compilade compilade May 28, 2025


There's another problem, which is detectable by running llama-parallel with -pps and deterministic settings and comparing the output at different -ub values. -ub 1 is always fine, but -ub 2 seems to produce weird output on this branch (like swapping answers). The default -ub of 512 also manifests the problem.

For example:

Towards the end the output is weird.

$ ./bin/llama-parallel -m /path/to/mamba-130M-hf-F16.gguf -np 5 -ns 8 --temp 0 --repeat-penalty 1.1 -ub 2 -pps
...
llama_context: constructing llama_context
llama_context: n_seq_max     = 6
llama_context: n_ctx         = 4096
llama_context: n_ctx_per_seq = 682
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 2
llama_context: causal_attn   = 1
llama_context: flash_attn    = 0
llama_context: freq_base     = 10000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_per_seq (682) < n_ctx_train (1048576) -- the full capacity of the model will not be utilized
llama_context:        CPU  output buffer size =     1.15 MiB
llama_kv_cache_recurrent: kv_size = 6, n_seq_max = 6, type_k = 'f32', type_v = 'f32', n_layer = 24
llama_kv_cache_recurrent:        CPU KV buffer size =    16.03 MiB
llama_kv_cache_recurrent: KV self size  =   16.03 MiB, K (f32):    2.53 MiB, V (f32):   13.50 MiB
llama_context:        CPU compute buffer size =     1.34 MiB
llama_context: graph nodes  = 1278
llama_context: graph splits = 1
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
No new questions so proceed with build-in defaults.


main: Simulating parallel requests from clients:
main: n_parallel = 5, n_sequences = 8, cont_batching = 1, system tokens = 260

main: Evaluating the system prompt ...

Processing requests ...

main: clearing the KV cache
Client   0, seq    0, started decoding ...
Client   1, seq    1, started decoding ...
Client   2, seq    2, started decoding ...
Client   3, seq    3, started decoding ...
Client   4, seq    4, started decoding ...
Client   0, seq   0/  8, prompt   15 t, response   26 t, time  4.30 s, speed  9.54 t/s, cache miss 0  

Input:    What is the meaning of life?
Response: Life is a series of events, and the best way to understand them are by looking at what happens in each one.

Client   2, seq   2/  8, prompt   15 t, response   26 t, time  4.30 s, speed  9.54 t/s, cache miss 0  

Input:    What is the meaning of life?
Response: Life is a series of events, and the best way to understand them are by looking at what happens in each one.

Client   0, seq    5, started decoding ...
Client   2, seq    6, started decoding ...
Client   2, seq   6/  8, prompt   21 t, response   11 t, time  2.01 s, speed 15.93 t/s, cache miss 0  

Input:    If you could have any superpower, what would it be?
Response: I would like to be a superpower.

Client   2, seq    7, started decoding ...
Client   3, seq   3/  8, prompt   26 t, response   39 t, time  6.84 s, speed  9.51 t/s, cache miss 0  

Input:    Are you familiar with the Special Theory of Relativity and can you explain it to me?
Response: I recommend the steak. It is a very good steak, and it's easy to cook with your hands on the stove or in a skillet if you have one handy at all times!

Client   2, seq   7/  8, prompt   28 t, response   10 t, time  1.39 s, speed 27.28 t/s, cache miss 0  

Input:    I want to learn how to play the piano. What would be the best way to do it?
Response: I would suggest you play the piano.

Client   4, seq   4/  8, prompt   16 t, response   49 t, time  7.70 s, speed  8.44 t/s, cache miss 0  

Input:    Recommend some interesting books to read.
Response: I recommend the book "The Golden Duck" by Richard Feynman. It is a fascinating and entertaining book that is written in a very readable style, and it is also a great way to learn about physics at school or college level!

Client   0, seq   5/  8, prompt   26 t, response   65 t, time  5.27 s, speed 17.28 t/s, cache miss 0  

Input:    Are you familiar with the Special Theory of Relativity and can you explain it to me?
Response: I am a physicist and I have been working on this theory for over 20 years. It is a very simple theory that describes the behavior of particles in a vacuum, but it has many interesting properties such as the existence or nonexistence theorems etc., which are not known to us yet because we do know about them.

Client   1, seq   1/  8, prompt   18 t, response  128 t, time 10.40 s, speed 14.04 t/s, cache miss 0  

Input:    What is the best way to cook a steak?
Response: The meaning of life is the pursuit and enjoyment that comes from living. It is a state in which one lives, and it is an important part to be alive at all times; it is a few years ago I was asked by my friend who works for me how I would like to work with him on this project. He said he would love if we could do some research into the effects of alcohol consumption among young people in the UK, and that he would be happy when we had a chance interview him about it.
I have been working at The Institute of Alcohol Studies for over 20 years now (and I am still working there), so

main: clearing the KV cache

run parameters as of 2025-05-27 23:50:32

main: n_parallel = 5, n_sequences = 8, cont_batching = 1, system tokens = 260
External prompt file: used built-in defaults
Model and path used:  /path/to/mamba-130M-hf-F16.gguf

Total prompt tokens:    165, speed: 11.41 t/s
Total gen tokens:       354, speed: 24.48 t/s
Total speed (AVG):           speed: 35.88 t/s
Cache misses:             0

llama_perf_context_print:        load time =     120.10 ms
llama_perf_context_print: prompt eval time =   13517.87 ms /   743 tokens (   18.19 ms per token,    54.96 tokens per second)
llama_perf_context_print:        eval time =     823.75 ms /    36 runs   (   22.88 ms per token,    43.70 tokens per second)
llama_perf_context_print:       total time =   14465.21 ms /   779 tokens

Notice how the last question about the steak is somehow answered as if the meaning of life was asked.

This does not happen with -ub 1 or with -np 1.


The problem does not exist in #13746, but does exist in #9126 since at least 35d06fa (but not on the corresponding commit on master), which means the problem is most likely related to how I changed the copying of the recurrent states (which is weird, because I think it did work at some point).

The fact that your branch doesn't manifest the same problem narrows this down to my new changes. I'm still not sure what exactly the root cause is; hopefully I'll find it soon.

This might or might not be related to the non-consecutive kv_cell.pos problem. It's likely a different problem.

(EDIT: it's definitely a different problem, because it still happens after bdbfb4e even though the non-consecutive kv_cell.pos problem has been fixed (the tail cell ids were not swapped correctly))

(EDIT2: Ok, I think I know what is happening: the first zeroed cell is sometimes used as a source for non-zeroed states, and this messes some things up. The fix will need to prevent that situation (EDIT3: somehow that's not sufficient; that was not the root cause))

Member


Btw, one change in #13746 that could be relevant to this is that we now allocate separate ubatch data buffers for each llama_ubatch:

https://github.com/ggml-org/llama.cpp/pull/13746/files#diff-e86f4c320ddf096b16dccbc876be6a50f6d6fc4e690b7ebba8a526cd8caa8f14R52-R63

On master, the ubatches were simply views of a single buffer, the contents of which were updated on every split. But this was not compatible with the preparation logic, so now we allocate separate buffers.
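
Conceptually the change amounts to something like this (an illustrative sketch with stand-in structs, not the actual llama_ubatch definition):

```cpp
#include <cstdint>
#include <vector>

// master: a ubatch is a view into one shared buffer, so splitting the next
// ubatch overwrites the data the previous view points at -- only one ubatch
// is valid at a time.
struct ubatch_view {
    const int32_t * token    = nullptr;  // points into a buffer owned by the batch splitter
    uint32_t        n_tokens = 0;
};

// #13746: each ubatch owns its data, so all prepared ubatches remain valid at
// the same time, which the preparation logic relies on.
struct ubatch_owned {
    std::vector<int32_t> token;          // private copy of this ubatch's tokens
    uint32_t             n_tokens = 0;
};
```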

The problem was apparently caused by how the tail cells were swapped.
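
For context, the kind of bookkeeping that has to stay consistent here looks roughly like the following (an illustrative sketch, not the actual llama_kv_cache_recurrent code): each sequence tracks the id of the cell holding its latest ("tail") state, so when two cells are swapped, any tail ids referring to them have to be swapped as well.

```cpp
#include <cstdint>
#include <vector>

// Swap two cells and keep the per-sequence tail ids in sync (illustrative).
static void swap_cells(std::vector<int32_t> & tails, int32_t cell_a, int32_t cell_b) {
    // ... the cell contents themselves would be swapped here ...

    for (auto & tail : tails) {  // tails[seq_id] == cell id of that sequence's tail state
        if (tail == cell_a) {
            tail = cell_b;
        } else if (tail == cell_b) {
            tail = cell_a;
        }
    }
}
```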

* graph : simplify logic for recurrent state copies
ggerganov force-pushed the gg/kv-cache-simplify-part3 branch 3 times, most recently from 825efad to eed741e on May 29, 2025 at 13:22