Skip to content

kv-cache : add SWA support #13194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1445,6 +1445,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.n_keep = value;
}
));
add_opt(common_arg(
{"--swa-full"},
string_format("use full-size SWA cache (default: %s)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)", params.swa_full ? "true" : "false"),
[](common_params & params) {
params.swa_full = true;
}
));
add_opt(common_arg(
{"--no-context-shift"},
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
Expand Down
1 change: 1 addition & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
cparams.swa_full = params.swa_full;

if (params.reranking) {
cparams.embeddings = true;
Expand Down
1 change: 1 addition & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ struct common_params {
bool flash_attn = false; // flash attention
bool no_perf = false; // disable performance metrics
bool ctx_shift = true; // context shift on inifinite text generation
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)

bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool use_mmap = true; // use mmap for faster loads
Expand Down
28 changes: 20 additions & 8 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,11 @@ extern "C" {

// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool no_perf; // whether to measure performance timings
bool op_offload; // whether to offload host tensor operations to device
bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // use flash attention [EXPERIMENTAL]
bool no_perf; // measure performance timings
bool op_offload; // offload host tensor operations to device
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
};

// model quantization parameters
Expand Down Expand Up @@ -730,10 +731,18 @@ extern "C" {
llama_pos p1,
int d);

// Returns the smallest position present in the KV cache for the specified sequence
// This is typically non-zero only for SWA caches
// Return -1 if the sequence is empty
LLAMA_API llama_pos llama_kv_self_seq_pos_min(
struct llama_context * ctx,
llama_seq_id seq_id);

// Returns the largest position present in the KV cache for the specified sequence
// Return -1 if the sequence is empty
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id);
llama_seq_id seq_id);

// Defragment the KV cache
// This will be applied:
Expand Down Expand Up @@ -943,9 +952,12 @@ extern "C" {
// Requires KV cache.
// For encode-decoder contexts, processes the batch using the decoder.
// Positive return values does not mean a fatal error, but rather a warning.
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
// < 0 - error. the KV cache state is restored to the state before this call
// Upon non-zero return values, the KV cache state is restored to the state before this call
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
// 2 - aborted
// -1 - invalid input batch
// < -1 - error
LLAMA_API int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch);
Expand Down
36 changes: 30 additions & 6 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ llama_context::llama_context(
}

cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);

cparams.op_offload = params.op_offload;

const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
Expand Down Expand Up @@ -176,8 +177,9 @@ llama_context::llama_context(
// init the memory module
if (!hparams.vocab_only) {
llama_memory_params params_mem = {
/*.type_k =*/ params.type_k,
/*.type_v =*/ params.type_v,
/*.type_k =*/ params.type_k,
/*.type_v =*/ params.type_v,
/*.swa_full =*/ params.swa_full,
};

memory.reset(model.create_memory(params_mem, cparams));
Expand Down Expand Up @@ -947,8 +949,6 @@ int llama_context::decode(llama_batch & inp_batch) {

// find KV slot
if (!kv_self->find_slot(ubatch)) {
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);

return 1;
}

Expand Down Expand Up @@ -2093,6 +2093,7 @@ llama_context_params llama_context_default_params() {
/*.flash_attn =*/ false,
/*.no_perf =*/ true,
/*.op_offload =*/ true,
/*.swa_full =*/ true,
};

return result;
Expand Down Expand Up @@ -2467,6 +2468,15 @@ void llama_kv_self_seq_div(
kv->seq_div(seq_id, p0, p1, d);
}

llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
const auto * kv = ctx->get_kv_self();
if (!kv) {
return -1;
}

return kv->seq_pos_min(seq_id);
}

// deprecated
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_self_seq_pos_max(ctx, seq_id);
Expand All @@ -2475,7 +2485,7 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
const auto * kv = ctx->get_kv_self();
if (!kv) {
return 0;
return -1;
}

return kv->seq_pos_max(seq_id);
Expand Down Expand Up @@ -2637,7 +2647,21 @@ int32_t llama_encode(
int32_t llama_decode(
llama_context * ctx,
llama_batch batch) {
const int ret = ctx->decode(batch);
int ret = ctx->decode(batch);

// defrag and try again
// TODO: distinguish return code when we are sure that even after defrag there is no space available
if (ret == 1) {
llama_kv_self_defrag(ctx);
ret = ctx->decode(batch);

if (ret == 1) {
LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);

return ret;
}
}

if (ret != 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
}
Expand Down
Loading
Loading