Skip to content

Commit 1bb30bf

Browse files
authored
llama : handle KV shift for recurrent models (#10402)
ggml-ci
1 parent 87a533b commit 1bb30bf

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/llama.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18211,13 +18211,13 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
1821118211
static void llama_kv_cache_update_internal(struct llama_context & lctx) {
1821218212
bool need_reserve = false;
1821318213

18214-
// apply K-shift if needed
18215-
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
18214+
if (lctx.kv_self.has_shift) {
1821618215
if (!llama_kv_cache_can_shift(&lctx)) {
18217-
GGML_ABORT("Deepseek2 does not support K-shift");
18216+
GGML_ABORT("The current context does not support K-shift");
1821818217
}
1821918218

18220-
{
18219+
// apply K-shift if needed
18220+
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
1822118221
ggml_backend_sched_reset(lctx.sched.get());
1822218222

1822318223
ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
@@ -20463,7 +20463,7 @@ void llama_kv_cache_update(struct llama_context * ctx) {
2046320463
}
2046420464

2046520465
bool llama_kv_cache_can_shift(struct llama_context * ctx) {
20466-
return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
20466+
return !ctx->kv_self.recurrent && ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
2046720467
}
2046820468

2046920469
// deprecated

0 commit comments

Comments
 (0)