
kv-cache : simplify #13746


Draft · ggerganov wants to merge 12 commits into master from gg/kv-cache-simplify-part3
Conversation

ggerganov (Member) commented May 24, 2025

cont #13706 (comment), #13194

Main goal here is to simplify the abstract interface of struct llama_kv_cache.

Overview

Changes to the internal struct llama_kv_cache abstract interface:

  • Remove llama_kv_cache::commit()
  • Remove llama_kv_cache::restore()
  • Remove llama_kv_cache::sbatch_init()
  • Remove llama_kv_cache::ubatch_next()
  • Remove llama_kv_cache::find_slot()
  • Remove llama_kv_cache_guard
  • Add:
--- llama-memory.h

    class llama_memory_decode_state_i {
    public:
        virtual ~llama_memory_decode_state_i() = default;
    
        // consume the next ubatch from the decode state
        // return nullptr if we are done
        virtual llama_ubatch * next() = 0;
    
        // TODO: this might get reworked in the future when refactoring llama_batch
        virtual std::vector<int64_t> & out_ids() = 0;
    
        virtual llama_memory_status get_status() const = 0;
    };
    
    using llama_memory_decode_state_ptr = std::unique_ptr<llama_memory_decode_state_i>;

--- llama-kv-cache.h

    // split the input batch into a set of ubatches and verify that they can fit into the cache
    // check the llama_memory_decode_state_i::get_status() for the result
    virtual llama_memory_decode_state_ptr llama_kv_cache::init(
            const llama_batch & batch,
            uint32_t n_ubatch,
            bool embd_pooled,
            bool logits_all) = 0;

This new interface changes the logic in llama_decode() to first make sure that the input batch can fit into the cache, and only then start processing the ubatches. This check correctly takes SWA masking into account and also makes sure that the cache is not modified before we start the actual computation.

note: the latter is not yet true for the recurrent cache - see comments in the code
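For illustration, here is a minimal sketch of how the new llama_decode() flow can drive this interface (hedged; kv_self and LLAMA_MEMORY_STATUS_SUCCESS are assumed names, not taken verbatim from the PR):

// sketch of the new decode flow, using only the interface quoted above
llama_memory_decode_state_ptr dstate = kv_self->init(batch, n_ubatch, embd_pooled, logits_all);

if (dstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
    // the batch cannot fit - the cache has not been modified yet, so we can simply bail out
    return -1;
}

// only now do we start to apply ubatches to the cache and run the computation
while (llama_ubatch * ubatch = dstate->next()) {
    // ... find a slot for *ubatch, build the graph, compute, write the outputs ...
}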

Another important update in this PR is that the find_slot() logic for unified caches is improved. Previously, we looked for a slot (i.e. a set of contiguous cells) that was empty in order to place the ubatch in it. We now allow the slot to contain data from the same or another sequence, as long as it is masked (either by causality or by SWA):

// keep track of what the minimum sequence positions would be if we accept the ubatch
llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];

for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
    seq_pos_min[s] = cells.seq_pos_min(s);
}

bool found = true;

for (uint32_t i = 0; i < n_tokens; i++) {
    const llama_pos    pos    = ubatch.pos[i];
    const llama_seq_id seq_id = ubatch.seq_id[i][0];

    // can we use this cell? either:
    // - the cell is empty
    // - the cell is occupied only by one sequence:
    //   - mask causally, if the sequence is the same as the one we are inserting
    //   - mask SWA, using current max pos for that sequence in the cache
    //     always insert in the cell with minimum pos
    bool can_use = cells.is_empty(head_cur + i);

    if (!can_use && cells.seq_count(head_cur + i) == 1) {
        const llama_pos pos_cell = cells.pos_get(head_cur + i);

        // causal mask
        if (cells.seq_has(head_cur + i, seq_id)) {
            can_use = pos_cell >= pos;
        }

        if (!can_use) {
            const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);

            // SWA mask
            if (pos_cell == seq_pos_min[seq_id_cell] &&
                is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                seq_pos_min[seq_id_cell]++;
                can_use = true;
            }
        }
    }

    if (!can_use) {
        found = false;

        head_cur += i + 1;
        n_tested += i + 1;

        break;
    }
}

This change is needed for the next PR, which will optimize the SWA cache to use just n_swa + n_ubatch cells, and it also has some other nice properties. For example, we no longer have to explicitly prune tokens after successful batch processing, which simplifies the logic significantly and allows us to re-enable speculative decoding for SWA models (this will also be done in the next PR).

The worst-case graph reserve logic is also refactored and significantly simplified.

There are also some changes to llama-batch, but these are mainly to patch things up so that the KV cache refactor can land first. There is no need to review the llama-batch changes in detail - the code there will be reworked soon.


With this refactor, I think the struct llama_kv_cache interface is getting close to final. I still don't like the llama_kv_cache::set_full() mechanism and will try to find a way to avoid it. I am also not sure whether the llama_kv_cache::update(llama_context) call is really necessary - it could probably be absorbed into the llama_kv_cache::init() call, but then the logic there might become too overloaded.

TODO

  • Adapt the recurrent cache to the new interface
  • Test optimization workflow

Next PRs

ggerganov force-pushed the gg/kv-cache-simplify-part3 branch from d23f887 to 8323e23 on May 24, 2025 14:06
Base automatically changed from gg/kv-cache-simplify-part2 to master on May 25, 2025 13:34
ggerganov force-pushed the gg/kv-cache-simplify-part3 branch from c1434b8 to 1eec34a on May 25, 2025 13:42
ggerganov marked this pull request as ready for review on May 25, 2025 14:50
ggerganov requested a review from ngxson as a code owner on May 25, 2025 14:50
ggerganov (Member, Author):

This PR should not cause any performance changes and the numerical results should be mostly the same (with some small exceptions due to the new logic in find_slot()).

Would appreciate some testing and reports for regressions. Thanks.

ggerganov requested a review from slaren on May 25, 2025 14:52
ngxson (Collaborator) commented May 25, 2025

I re-ran the ppl test from #13194 (comment).

master at aa50ba4

OK:   Final estimate: PPL = 7.8002 +/- 0.17654   ggml-org/gemma-3-4b-it-GGUF:Q4_K_M
OK:   Final estimate: PPL = 37.6848 +/- 1.03389   bartowski/gemma-2-9b-it-GGUF:Q4_K_M
OK:   Final estimate: PPL = 5.9658 +/- 0.11216   lmstudio-community/Phi-3.1-mini-128k-instruct-GGUF:Q4_K_M
OK:   Final estimate: PPL = 5.2653 +/- 0.09581   bartowski/CohereForAI_c4ai-command-a-03-2025-GGUF:IQ1_M
OK:   Final estimate: PPL = 7.3320 +/- 0.16048   unsloth/Llama-4-Scout-17B-16E-Instruct-GGUF:IQ1_S

This PR:

OK:   Final estimate: PPL = 7.8003 +/- 0.17654   ggml-org/gemma-3-4b-it-GGUF:Q4_K_M
OK:   Final estimate: PPL = 37.6620 +/- 1.03339   bartowski/gemma-2-9b-it-GGUF:Q4_K_M
OK:   Final estimate: PPL = 5.9658 +/- 0.11216   lmstudio-community/Phi-3.1-mini-128k-instruct-GGUF:Q4_K_M
OK:   Final estimate: PPL = 5.2642 +/- 0.09577   bartowski/CohereForAI_c4ai-command-a-03-2025-GGUF:IQ1_M
OK:   Final estimate: PPL = 7.3302 +/- 0.16037   unsloth/Llama-4-Scout-17B-16E-Instruct-GGUF:IQ1_S

Some results changed very slightly, so I'm not sure if this is expected.

ggerganov (Member, Author):

Yes, I think this difference is expected for SWA models (note that SWA is currently disabled for Phi, so there is no difference there). It's caused by the different order in which we place the data in memory due to the find_slot() updates. The results should become identical with --swa-full - can you confirm?

ngxson (Collaborator) commented May 25, 2025

Yes, that's right. I added --swa-full and now it becomes identical to the master version:

OK:   Final estimate: PPL = 7.8002 +/- 0.17654   ggml-org/gemma-3-4b-it-GGUF:Q4_K_M
OK:   Final estimate: PPL = 37.7017 +/- 1.03468   bartowski/gemma-2-9b-it-GGUF:Q4_K_M
OK:   Final estimate: PPL = 5.9658 +/- 0.11216   lmstudio-community/Phi-3.1-mini-128k-instruct-GGUF:Q4_K_M
OK:   Final estimate: PPL = 5.2654 +/- 0.09581   bartowski/CohereForAI_c4ai-command-a-03-2025-GGUF:IQ1_M
OK:   Final estimate: PPL = 7.3320 +/- 0.16048   unsloth/Llama-4-Scout-17B-16E-Instruct-GGUF:IQ1_S

Edit: except for gemma-2-9b-it-GGUF


ngxson (Collaborator) commented May 26, 2025

I re-ran the test and the ppl stays the same as in my last comment.

Btw, just thinking: is it possible (and would it be useful) to add a ppl test mode that uses the KV remove API?

ggerganov (Member, Author):

I re-ran the test and the ppl stays the same as in my last comment.

The bartowski/gemma-2-9b-it-GGUF:Q4_K_M model produces the same PPL on master and on this PR with this command:

./bin/llama-perplexity -hf bartowski/gemma-2-9b-it-GGUF:Q4_K_M -f ./wikitext-2-raw/wiki.test.raw -c 16384 -fa --chunks 2 --swa-full

Maybe your reference value on master is outdated?

Btw, just thinking: is it possible (and would it be useful) to add a ppl test mode that uses the KV remove API?

Can you clarify?

ngxson (Collaborator) commented May 26, 2025

I can't run the ppl test right now, but if you get the correct result, then yes, I think it could be a problem on my side.

Btw, just thinking: is it possible (and would it be useful) to add a ppl test mode that uses the KV remove API?

Can you clarify?

Currently, AFAIU the ppl test simply evaluates the text chunk by chunk, only going forward. For example, if I have 3 chunks 1-2-3, they will be evaluated in the order 1-2-3.

But what we also want to test is, for example:

  • Evaluate chunk 1, 2
  • Remove chunk 2 from memory
  • Evaluate chunk 2, 3

So I expect the ppl to be the same as just doing 1-2-3.
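A hedged sketch of what such a test mode could look like, using the public llama_kv_self_seq_rm() API; eval_chunk() is a hypothetical helper that decodes one chunk of tokens at its absolute positions into sequence 0:

#include "llama.h"

#include <vector>

// hypothetical helper: decode chunk `idx` (n_chunk tokens at absolute positions
// [idx*n_chunk, (idx + 1)*n_chunk)) into sequence 0 and accumulate its log-probs
static void eval_chunk(llama_context * ctx, const std::vector<llama_token> & tokens, int idx, int n_chunk);

static void ppl_with_kv_remove(llama_context * ctx, const std::vector<llama_token> & tokens, int n_chunk) {
    eval_chunk(ctx, tokens, 0, n_chunk); // evaluate chunk 1
    eval_chunk(ctx, tokens, 1, n_chunk); // evaluate chunk 2

    // remove chunk 2 from the memory (sequence 0, positions [1*n_chunk, 2*n_chunk))
    llama_kv_self_seq_rm(ctx, 0, 1*n_chunk, 2*n_chunk);

    eval_chunk(ctx, tokens, 1, n_chunk); // re-evaluate chunk 2
    eval_chunk(ctx, tokens, 2, n_chunk); // evaluate chunk 3

    // the resulting ppl should match evaluating chunks 1-2-3 in order
}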

slaren (Member) commented May 26, 2025

How does this recover from a failed call to graph_compute? What is the replacement for commit/restore?

ggerganov (Member, Author):

How does this recover from a failed call to graph_compute? What is the replacement for commit/restore?

There are some tricky scenarios in which we could have overwritten some of the data in the cache by the time the error occurs (i.e. we have processed the first few ubatches, but not all of them yet). Before (i.e. on master), we allowed ubatches to be placed only in empty slots, so we could simply mark the cells back as empty and recover in such cases. But with the new logic this is no longer guaranteed, because we allow ubatches to be placed in masked cells. This new logic is quite beneficial because it will enable smaller caches for SWA (i.e. n_swa + n_ubatch vs n_swa + n_batch cells) and we also don't have to explicitly prune SWA-masked tokens after a successful batch, which allows us to seamlessly do short rollbacks. The latter is needed for speculative decoding (#13747) and for cases where the last generated chat response contains a few extra newlines, which are then discarded by the Web UI. In the latter case, if we pruned all tokens strictly by the SWA window (as is currently done on master), this would cause a full reprocessing of the context, while with the new logic we can still roll back and have all the necessary cache data available for reuse.

I think that on a compute error, the KV cache should be assumed to be in an undefined state and the application should take the necessary steps to recover (i.e. by clearing it and reprocessing the context that is currently needed). Later on, this reprocessing will become seamless, once we start storing the necessary token/embedding information and add the logic for auto-reprocessing whatever is currently missing from the cache.

slaren (Member) commented May 26, 2025

I am mostly concerned about the abort callback functionality. Errors in the backend are likely to be unrecoverable, but I am not sure if the abort functionality makes sense if it leaves the cache in a bad state.

ggerganov (Member, Author):

I admit that I had completely forgotten about the abort callback. Let me see if we can do something about this.

ggerganov force-pushed the gg/kv-cache-simplify-part3 branch from 0b73da5 to 2252eef on May 27, 2025 13:11
ggerganov marked this pull request as draft on May 27, 2025 13:32
ggerganov (Member, Author):

Drafting for now as I want to do some more testing and think about the abort mechanism.


std::vector<llama_ubatch> ubatches;

while (sbatch.n_tokens > 0) {
    ubatches.push_back(sbatch.split_simple(n_ubatch));
Contributor:

If I'm reading this right, it's not possible to use split_equal with the unified cache after this refactor. If so, I think this will cause problems with the eventual hybrid cache, where split_equal is required for recurrent child caches (cc @compilade since you pointed that out to me earlier).

ggerganov (Member, Author):

It is possible to use split_equal with the unified cache - it is just not used here because there is no reason for it. The decision of how to split the batch into ubatches is implemented per cache type. If the hybrid cache requires split_equal, then its init() method should use that. This code here will never be called for the hybrid cache.
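For what it's worth, a hedged sketch of what that could look like inside a hypothetical llama_kv_cache_hybrid::init(), mirroring the split_simple loop quoted above:

// hypothetical sketch, not part of this PR: a hybrid cache choosing equal splits,
// since its recurrent child cache requires them
std::vector<llama_ubatch> ubatches;

while (sbatch.n_tokens > 0) {
    ubatches.push_back(sbatch.split_equal(n_ubatch));
}

// both child caches would then be asked whether they can fit this same set of ubatches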

}
std::vector<llama_ubatch> ubatches;

while (sbatch.n_tokens > 0) {
gabe-l-hart (Contributor) commented May 27, 2025

In trying to update my hybrid implementation to match this PR, I think the logic around creating ubatches is going to make things tricky. I've attempted to implement it in the abstract, where the parent cache doesn't know anything about the child caches and never downcasts them unless asked to by an owner that does know about the child types. The challenge crops up here with the splitting logic: recurrent caches require split_equal and, with this change, unified requires split_simple, so the ubatches themselves are not guaranteed to be identical across child caches.

Reading what you've got here for iswa, it looks like you're essentially unrolling the logic in llama_kv_cache_unified::init and performing it pairwise across the two child caches based on the explicit knowledge that they're both llama_kv_cache_unified types and therefore expose prepare. I could do something similar in llama_kv_cache_hybrid::init, but not without a lot of conditional logic and dynamic casting to call the appropriate flavor of batch splitting and prepare. The two possible solutions I can imagine that would avoid this would be:

  1. Add an argument to the abstract definition of init to allow the caller to specify the split type
  2. Make prepare virtual in llama_kv_cache and update the implementation in llama_kv_cache_recurrent to also return a vector of heads (though I'm not clear what that would mean for the recurrent cache).

I think my personal preference would be (1), which I may try on my branch to see how it works. The other alternative would be to scrap the idea of keeping llama_kv_cache_hybrid abstract and instead have it explicitly own two child caches, one unified and one recurrent. I'd love to avoid this in order to enable arbitrary future hybrid styles, like mixes of SWA, unified, recurrent, etc. all within one model.

ggerganov (Member, Author):

  1. Add an argument to the abstract definition of init to allow the caller to specify the split type
  2. Make prepare virtual in llama_kv_cache and update the implementation in llama_kv_cache_recurrent to also return a vector of heads (though I'm not clear what that would mean for the recurrent cache).

I've considered these options and they don't work, for the same reason: the produced decoding state needs to keep cache-specific information, like the head positions and potentially other data for other types of caches. Abstracting this at the llama_kv_cache level would simply move the same logic that we currently have into the llama_memory_decode_state.

Additionally, I think the caller should not know about sbatches. With the current design, we init() with a generic batch and receive a set of ubatches. The entire logic for how to produce these ubatches is contained inside the KV cache implementation.

The other alternative would be to scrap the idea of keeping llama_kv_cache_hybrid abstract and instead explicitly have it own two child caches, one unified and one recurrent.

This should work.

ggerganov mentioned this pull request on May 28, 2025
ggerganov force-pushed the gg/kv-cache-simplify-part3 branch 2 times, most recently from be635a7 to 7dc61c2 on May 28, 2025 10:54
ggerganov force-pushed the gg/kv-cache-simplify-part3 branch from 7dc61c2 to a3ebf0a on May 28, 2025 11:47
ggerganov (Member, Author):

I am mostly concerned about the abort callback functionality. Errors in the backend are likely to be unrecoverable, but I am not sure if the abort functionality makes sense if it leaves the cache in a bad state.

@slaren With the current proposal, we have the following invariant that should always be true:

For each sequence id seq_id, all positions in the range [llama_kv_self_seq_pos_min(ctx, seq_id), llama_kv_self_seq_pos_max(ctx, seq_id)] are guaranteed to be present in the memory of the context.

This remains true even after aborting the processing of a batch. This way, after the abort, the user code can query the context for the min/max pos of each sequence, decide which tokens from the input batch weren't processed, and take the appropriate action.

To achieve that, with a3ebf0a we now call llama_kv_self_seq_rm() when an abort occurs. We call this for all sequences and positions participating in the ubatch that was aborted. This should effectively put the KV cache in a state from which new computations can continue, as long as they respect the pos_min and pos_max values that can be retrieved from the llama_kv_self_seq_ API.

The same logic is also applied when a compute error occurs, although in that case there is no guarantee that other state remains healthy after such errors.

Let me know if this sounds good.
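To illustrate the invariant from the user side, here is a hedged sketch of a retry pattern after an aborted decode (make_tail_batch() is a hypothetical helper, not an existing API):

// sketch: retry a batch after llama_decode() was interrupted via the abort callback
int decode_with_retry(llama_context * ctx, const llama_batch & batch, llama_seq_id seq_id) {
    if (llama_decode(ctx, batch) == 0) {
        return 0; // fully processed
    }

    // everything up to pos_max for this sequence is guaranteed to still be in the cache
    const llama_pos pos_max = llama_kv_self_seq_pos_max(ctx, seq_id);

    // rebuild a batch containing only the tokens with pos > pos_max and submit it again
    const llama_batch tail = make_tail_batch(batch, pos_max + 1); // hypothetical helper

    return llama_decode(ctx, tail);
}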

slaren (Member) left a review comment:

The solution seems good.

I still think that n and head should not be part of the KV state. It leads to very confusing code. For example, the new process function does not receive or compute any KV allocation information, so how does it even work? Well, it doesn't really work by itself - it depends on the state of the object, which must be set up before calling this function. Making functions as close as possible to pure functions that do not depend on any external state greatly improves code readability.

ggerganov (Member, Author):

I still think that n and head should not be part of the KV state.

I tried to address this concern in the latest commit. I still need to do the same update for the recurrent cache. Let me know if this seems better now.

ggerganov force-pushed the gg/kv-cache-simplify-part3 branch from 1942427 to a592c13 on May 28, 2025 17:34
slaren (Member) commented May 28, 2025

It's definitely clearer now, but the fundamental problem is still the same: n_kv and head are still part of the state of llama_kv_cache_unified. I think the place where it would make more sense to have this state is llama_memory_state, but I understand that would require changing a lot of the graph building functions, and it is not strictly necessary.

ggerganov (Member, Author):

Yes, I see it now - it might not be too difficult to do. Will try to do that.

// TODO: improve to accept cells that are masked by the SWA
if (!cells.is_empty(head + i)) {
    const llama_pos    pos    = ubatch.pos[i];
    const llama_seq_id seq_id = ubatch.seq_id[i][0];
Contributor:

In local testing trying to use split_equal with the unified portion of the hybrid cache, this line is causing problems. Specifically, I think it conflicts with the logic in llama_sbatch::add_seq_to_ubatch (here) where the ubatch.seq_id is only populated with a non-null value at position ubatch.n_seqs. I'll keep digging to see if there's a simple solution, but wanted to flag this in case it's an easy fix on your end.

Contributor:

This is probably not the right solution, but this "fixed" the issue on my Granite 4 branch:

Suggested change:

-    const llama_seq_id seq_id = ubatch.seq_id[i][0];
+    const llama_seq_id seq_id = ubatch.seq_id[i] == nullptr ? ubatch.seq_id[0][0] : ubatch.seq_id[i][0];

compilade (Collaborator) commented May 28, 2025

The intended way to traverse ubatch.seq_id since the splits were introduced in #8526 is by using ubatch.n_seqs, not ubatch.n_tokens. In simple splits, ubatch.n_seqs is equal to ubatch.n_tokens. Fixing this loop (and also the one in apply_ubatch) should make it work properly with equal splits too.

ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits

See also the comments explaining the sizes of the arrays in ubatch:

struct llama_ubatch {
    bool equal_seqs;
    // TODO: whole_seqs for embeddings?

    uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
    uint32_t n_seq_tokens; // tokens per sequence
    uint32_t n_seqs;

    llama_token  *  token;    // [n_tokens]
    float        *  embd;     // [n_embd, n_tokens]
    llama_pos    *  pos;      // [n_tokens]
    int32_t      *  n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
    llama_seq_id ** seq_id;   // [n_seqs] // TODO: become llama_seq_id * seq_id;
    int8_t       *  output;   // [n_tokens]
};

Collaborator:

Correction: using ubatch.n_seqs for traversal only applies to ubatch.n_seq_id and ubatch.seq_id (in case anyone here relies on notifications and missed the edit in my previous comment).

gabe-l-hart (Contributor) commented May 28, 2025

Ok, that's really helpful. Since this loop indexes into both pos and seq_id, which in this case have different lengths, I'm not quite following the relationship that should be used to extract the seq_id for a given pos element. I think if ubatch.n_seqs > 1 here, that would automatically disqualify all of the other logic around reusing full cells?

compilade (Collaborator) commented May 28, 2025

I'm not quite following the relationship that should be used to extract seq_id for the given pos element.

Usually, this is relatively simple, since the number of tokens per sequence is known in ubatch.n_seq_tokens (for simple splits, this is always 1). In fact, here it should be possible to use:

Suggested change:

-    const llama_seq_id seq_id = ubatch.seq_id[i][0];
+    const llama_seq_id seq_id = ubatch.seq_id[i / ubatch.n_seq_tokens][0];

although there is another approach without divisions, using nested loops, which would change the indexing of ubatch.pos[i] to ubatch.pos[s * ubatch.n_seq_tokens + j], where s is in [0, ubatch.n_seqs) and j is in [0, ubatch.n_seq_tokens).
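For clarity, a small sketch of that nested-loop traversal (illustrative only, not a proposed patch):

for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
    const llama_seq_id seq_id = ubatch.seq_id[s][0];

    for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
        const uint32_t  i   = s*ubatch.n_seq_tokens + j; // flat token index
        const llama_pos pos = ubatch.pos[i];

        // ... per-cell checks against (pos, seq_id), as in find_slot() ...
    }
}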

Contributor:

I can confirm that with this fix, the model safely produces output on my Granite 4 branch (no comment on cache correctness though!)

Comment on lines +594 to +604
const llama_seq_id seq_id = ubatch.seq_id[i][0];

// can we use this cell? either:
// - the cell is empty
// - the cell is occupied only by one sequence:
//   - mask causally, if the sequence is the same as the one we are inserting
//   - mask SWA, using current max pos for that sequence in the cache
//     always insert in the cell with minimum pos
bool can_use = cells.is_empty(head_cur + i);

if (!can_use && cells.seq_count(head_cur + i) == 1) {
Contributor:

that would automatically disqualify all of the other logic around reusing full cells?

Assuming this is correct, I think this would be the correct approach?

Suggested change:

-    const llama_seq_id seq_id = ubatch.seq_id[i][0];
-
-    // can we use this cell? either:
-    // - the cell is empty
-    // - the cell is occupied only by one sequence:
-    //   - mask causally, if the sequence is the same as the one we are inserting
-    //   - mask SWA, using current max pos for that sequence in the cache
-    //     always insert in the cell with minimum pos
-    bool can_use = cells.is_empty(head_cur + i);
-
-    if (!can_use && cells.seq_count(head_cur + i) == 1) {
+    // can we use this cell? either:
+    // - the cell is empty
+    // - the cell is occupied only by one sequence:
+    //   - mask causally, if the sequence is the same as the one we are inserting
+    //   - mask SWA, using current max pos for that sequence in the cache
+    //     always insert in the cell with minimum pos
+    bool can_use = cells.is_empty(head_cur + i);
+
+    if (!can_use && cells.seq_count(head_cur + i) == 1 && ubatch.n_seqs == 1) {
+        const llama_seq_id seq_id = ubatch.seq_id[0][0];

Contributor:

That diff is gross, but it just adds an extra condition to the outer check, testing whether ubatch.n_seqs == 1, and then always uses ubatch.seq_id[0][0].

compilade (Collaborator) commented May 28, 2025

@gabe-l-hart It should not be necessary to limit this branch to when ubatch.n_seqs is 1. This almost never happens for simple splits anyway, except when n_ubatch is 1.

See #13746 (comment).

Contributor:

Race condition! Thanks thanks

ggerganov force-pushed the gg/kv-cache-simplify-part3 branch from 37cec43 to 825efad on May 29, 2025 12:50
ggerganov force-pushed the gg/kv-cache-simplify-part3 branch from 825efad to eed741e on May 29, 2025 13:22
ggerganov (Member, Author):

@slaren In eed741e I think I managed to extract the head and n_kv state from the KV cache object completely. It is now contained in the llama_memory_state objects.

We no longer pass the memory object when building the compute graphs. Instead, we prepare a memory state for each ubatch and we pass this state to the graph building context. The memory state carries the necessary information about the current head and n_kv.

I was also able to elegantly replace the llama_kv_cache::set_full() concept with llama_memory_state_ptr llama_kv_cache::init_full(), which makes more sense semantically. I plan to apply the same idea to replace the sched_defrag() method in a follow-up PR, in order to extract the defrag_info state from the KV cache in a similar way.

Sorry for the large diff again. Let me know if you have any follow-up comments or suggestions.
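For context, a hedged sketch of the semantic shift from set_full() to init_full(); the reserve call shown is an assumed name, not an actual function:

// before: the cache object itself was mutated to pretend it is full
// kv_self->set_full();

// after: the cache stays untouched and a self-contained state object describes the full view
llama_memory_state_ptr mstate_full = kv_self->init_full();

// the worst-case graph is then reserved against this state, which carries head/n_kv
// reserve_worst_case_graph(..., mstate_full.get()); // hypothetical call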
