
Commit d3372f4

ggerganov authored and infil00p committed
kv-cache : add SWA support (ggml-org#13194)
* kv-cache : prepare for SWA ggml-ci
* kv-cache : initial iSWA implementation ggml-ci
* kv-cache : rework error recovery logic ggml-ci
* models : fix Phi-3 SWA parameters ggml-ci
* model : adjust Granite to rope factor changes ggml-ci
* server : check if context can do shifts ggml-ci
* iswa : for now, always enable shifts (experiment) ggml-ci
* kv-cache : simplify SWA logic ggml-ci
* kv-cache : apply defrag when we fail to find slots for the batch ggml-ci
* llama : update docs about llama_decode ggml-ci
* kv-cache : update warning logs when no space for the batch is available ggml-ci
* llama : add llama_kv_self_seq_pos_min()
* kv-cache : keep track of partial SWA computes and print warnings
* server : disallow use cases involving partial SWA context ggml-ci
* llama : add param to control SWA cache size ggml-ci
* minor : clean-up ggml-ci
1 parent cfc18f6 commit d3372f4

15 files changed: +1414 / -638 lines

common/arg.cpp

Lines changed: 8 additions & 0 deletions
@@ -1445,6 +1445,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.n_keep = value;
         }
     ));
+    add_opt(common_arg(
+        {"--swa-full"},
+        string_format("use full-size SWA cache (default: %s)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)", params.swa_full ? "true" : "false"),
+        [](common_params & params) {
+            params.swa_full = true;
+        }
+    ));
     add_opt(common_arg(
         {"--no-context-shift"},
         string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),

common/common.cpp

Lines changed: 1 addition & 0 deletions
@@ -1136,6 +1136,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.flash_attn        = params.flash_attn;
     cparams.no_perf           = params.no_perf;
     cparams.op_offload        = !params.no_op_offload;
+    cparams.swa_full          = params.swa_full;
 
     if (params.reranking) {
         cparams.embeddings = true;

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -323,6 +323,7 @@ struct common_params {
     bool flash_attn       = false; // flash attention
     bool no_perf          = false; // disable performance metrics
     bool ctx_shift        = true;  // context shift on infinite text generation
+    bool swa_full         = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool use_mmap         = true;  // use mmap for faster loads
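
Taken together, the two small changes above route the flag from the common layer into the core context parameters. A minimal sketch of that plumbing (model loading and inference omitted; only the structs and the conversion helper touched by this commit are used):

```cpp
// Minimal sketch of the plumbing added in common/common.h and common/common.cpp:
// common_params.swa_full is copied into llama_context_params by
// common_context_params_to_llama(). No model is loaded here.
#include "common.h"
#include "llama.h"

#include <cstdio>

int main() {
    common_params params;
    params.swa_full = true; // what --swa-full does via the argument parser

    llama_context_params cparams = common_context_params_to_llama(params);

    printf("cparams.swa_full = %s\n", cparams.swa_full ? "true" : "false"); // prints "true"

    return 0;
}
```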

include/llama.h

Lines changed: 20 additions & 8 deletions
@@ -361,10 +361,11 @@ extern "C" {
 
         // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
         bool embeddings;  // if true, extract embeddings (together with logits)
-        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
-        bool no_perf;     // whether to measure performance timings
-        bool op_offload;  // whether to offload host tensor operations to device
+        bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
+        bool flash_attn;  // use flash attention [EXPERIMENTAL]
+        bool no_perf;     // measure performance timings
+        bool op_offload;  // offload host tensor operations to device
+        bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     };
 
     // model quantization parameters
@@ -730,10 +731,18 @@ extern "C" {
                       llama_pos   p1,
                             int   d);
 
+    // Returns the smallest position present in the KV cache for the specified sequence
+    // This is typically non-zero only for SWA caches
+    // Return -1 if the sequence is empty
+    LLAMA_API llama_pos llama_kv_self_seq_pos_min(
+            struct llama_context * ctx,
+                     llama_seq_id   seq_id);
+
     // Returns the largest position present in the KV cache for the specified sequence
+    // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
-            llama_seq_id seq_id);
+                     llama_seq_id   seq_id);
 
     // Defragment the KV cache
     // This will be applied:
@@ -943,9 +952,12 @@ extern "C" {
     // Requires KV cache.
     // For encoder-decoder contexts, processes the batch using the decoder.
     // Positive return values do not mean a fatal error, but rather a warning.
-    //    0 - success
-    //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
-    //  < 0 - error. the KV cache state is restored to the state before this call
+    // Upon non-zero return values, the KV cache state is restored to the state before this call
+    //    0 - success
+    //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+    //    2 - aborted
+    //   -1 - invalid input batch
+    //  < -1 - error
     LLAMA_API int32_t llama_decode(
             struct llama_context * ctx,
               struct llama_batch   batch);
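
To illustrate the new query pair, here is a hedged sketch (assuming an already-initialized llama_context that has decoded some tokens for sequence 0) that reports the range of positions still present for a sequence; with a sliding-window cache the minimum grows as old positions fall out of the window:

```cpp
// Sketch of the new position queries; ctx is assumed to be an initialized
// llama_context that has already decoded some tokens for the given sequence.
#include "llama.h"

#include <cstdio>

static void print_cached_range(llama_context * ctx, llama_seq_id seq_id) {
    const llama_pos p_min = llama_kv_self_seq_pos_min(ctx, seq_id); // -1 if the sequence is empty
    const llama_pos p_max = llama_kv_self_seq_pos_max(ctx, seq_id); // -1 if the sequence is empty

    if (p_min < 0) {
        printf("seq %d: no tokens in the KV cache\n", seq_id);
        return;
    }

    // with a full-size cache p_min is typically 0; with an SWA cache it grows as
    // old positions slide out of the window, which is how callers (e.g. the server)
    // can detect that part of the context is no longer available for reuse
    printf("seq %d: positions [%d, %d] are present in the KV cache\n", seq_id, p_min, p_max);
}
```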

src/llama-context.cpp

Lines changed: 30 additions & 6 deletions
@@ -93,6 +93,7 @@ llama_context::llama_context(
     }
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+
     cparams.op_offload = params.op_offload;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@@ -176,8 +177,9 @@ llama_context::llama_context(
     // init the memory module
     if (!hparams.vocab_only) {
         llama_memory_params params_mem = {
-            /*.type_k =*/ params.type_k,
-            /*.type_v =*/ params.type_v,
+            /*.type_k   =*/ params.type_k,
+            /*.type_v   =*/ params.type_v,
+            /*.swa_full =*/ params.swa_full,
         };
 
         memory.reset(model.create_memory(params_mem, cparams));
@@ -947,8 +949,6 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         // find KV slot
         if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
            return 1;
        }
 
@@ -2093,6 +2093,7 @@ llama_context_params llama_context_default_params() {
        /*.flash_attn =*/ false,
        /*.no_perf    =*/ true,
        /*.op_offload =*/ true,
+        /*.swa_full   =*/ true,
    };
 
    return result;
@@ -2467,6 +2468,15 @@ void llama_kv_self_seq_div(
    kv->seq_div(seq_id, p0, p1, d);
 }
 
+llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return -1;
+    }
+
+    return kv->seq_pos_min(seq_id);
+}
+
 // deprecated
 llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
    return llama_kv_self_seq_pos_max(ctx, seq_id);
@@ -2475,7 +2485,7 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
    const auto * kv = ctx->get_kv_self();
    if (!kv) {
-        return 0;
+        return -1;
    }
 
    return kv->seq_pos_max(seq_id);
@@ -2637,7 +2647,21 @@ int32_t llama_encode(
 int32_t llama_decode(
        llama_context * ctx,
          llama_batch   batch) {
-    const int ret = ctx->decode(batch);
+    int ret = ctx->decode(batch);
+
+    // defrag and try again
+    // TODO: distinguish return code when we are sure that even after defrag there is no space available
+    if (ret == 1) {
+        llama_kv_self_defrag(ctx);
+        ret = ctx->decode(batch);
+
+        if (ret == 1) {
+            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+            return ret;
+        }
+    }
+
    if (ret != 0) {
        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
    }
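
Putting the internal defrag-and-retry together with the documented return codes, a caller can now treat a persistent 1 as "the batch genuinely does not fit". A minimal caller-side sketch, assuming ctx and batch are prepared elsewhere:

```cpp
// Sketch of caller-side handling for the updated llama_decode() contract.
// ctx and batch are assumed to be prepared elsewhere; on any non-zero return
// the KV cache has already been restored to its state before the call, and
// the defrag-and-retry for code 1 has already happened inside llama_decode().
#include "llama.h"

#include <cstdio>

static int32_t decode_and_report(llama_context * ctx, llama_batch batch) {
    const int32_t ret = llama_decode(ctx, batch);

    switch (ret) {
        case 0:
            break; // success
        case 1:
            // even after the internal defrag there was no KV slot for the batch:
            // reduce the batch size or increase the context
            fprintf(stderr, "decode: no KV slot for batch of size %d\n", batch.n_tokens);
            break;
        case 2:
            fprintf(stderr, "decode: aborted\n");
            break;
        case -1:
            fprintf(stderr, "decode: invalid input batch\n");
            break;
        default:
            fprintf(stderr, "decode: fatal error, ret = %d\n", ret);
            break;
    }

    return ret;
}
```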
