Commit 4a2709f

feat: Support hybrid recurrent cache in llm_graph_context

Branch: HybridRecurrentCache
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

1 parent: b58351e

2 files changed: 36 additions, 8 deletions
src/llama-graph.cpp (32 additions, 8 deletions)

@@ -954,7 +954,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
@@ -971,7 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
@@ -1025,7 +1025,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * kv_state = get_state_unified();
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
@@ -1056,6 +1056,30 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     return pos_bias;
 }
 
+const llama_kv_cache_unified_state * llm_graph_context::get_state_unified() const {
+    const auto * umstate = dynamic_cast<const llama_kv_cache_unified_state *>(mstate);
+    if (!umstate) {
+        const auto hmstate = dynamic_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+        if (hmstate) {
+            umstate = hmstate->get_state_attn();
+        }
+    }
+    GGML_ASSERT(umstate);
+    return umstate;
+}
+
+const llama_kv_cache_recurrent_state * llm_graph_context::get_state_recurrent() const {
+    const auto * rmstate = dynamic_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    if (!rmstate) {
+        const auto hmstate = dynamic_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+        if (hmstate) {
+            rmstate = hmstate->get_state_recurrent();
+        }
+    }
+    GGML_ASSERT(rmstate);
+    return rmstate;
+}
+
 ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_cgraph * gf,
         ggml_tensor * q,
@@ -1231,7 +1255,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * kv_state = get_state_unified();
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
@@ -1268,7 +1292,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * kv_state = get_state_unified();
 
     // store to KV cache
     {
@@ -1446,7 +1470,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
         ggml_tensor * state_mask,
         int32_t n_state,
         int32_t n_seqs) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     const auto n_kv = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
@@ -1478,7 +1502,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_tensor * state_mask,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     const auto token_shift_count = hparams.token_shift_count;
@@ -1499,7 +1523,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
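
The two new helpers share one pattern: first try to downcast the generic mstate pointer to the concrete state type, and if that fails, check whether it is a hybrid state that can hand out the matching sub-state; any other state type trips the assert. A minimal self-contained sketch of that pattern follows, using stand-in types rather than the real llama.cpp classes:

#include <cassert>

// Stand-ins for the real llama.cpp state classes (hypothetical, for illustration).
struct memory_state { virtual ~memory_state() = default; };
struct unified_state   : memory_state {};
struct recurrent_state : memory_state {};

// A hybrid state owns one sub-state of each kind and hands out pointers to them.
struct hybrid_state : memory_state {
    unified_state   attn;
    recurrent_state recur;
    const unified_state   * get_state_attn()      const { return &attn;  }
    const recurrent_state * get_state_recurrent() const { return &recur; }
};

// Downcast-with-fallback: accept either the concrete type or a hybrid wrapper.
const unified_state * resolve_unified(const memory_state * mstate) {
    const auto * u = dynamic_cast<const unified_state *>(mstate);
    if (!u) {
        if (const auto * h = dynamic_cast<const hybrid_state *>(mstate)) {
            u = h->get_state_attn();
        }
    }
    assert(u); // any other state type is a programming error
    return u;
}

int main() {
    hybrid_state h;
    // The same call site works for both plain unified and hybrid states.
    const unified_state * u = resolve_unified(&h);
    (void)u;
}

With these helpers in place, call sites such as build_inp_s_copy() or build_attn_inp_kv_unified() no longer need to know whether the active cache is a plain unified/recurrent cache or the hybrid one; the previous static_cast would have been an invalid cast when mstate held a hybrid state.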

src/llama-graph.h (4 additions, 0 deletions)

@@ -531,6 +531,8 @@ struct llm_graph_context {
     // attention
     //
 
+    const llama_kv_cache_unified_state * get_state_unified() const;
+
     ggml_tensor * build_attn_mha(
             ggml_cgraph * gf,
             ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
@@ -605,6 +607,8 @@ struct llm_graph_context {
     // recurrent
     //
 
+    const llama_kv_cache_recurrent_state * get_state_recurrent() const;
+
     ggml_tensor * build_copy_mask_state(
             ggml_cgraph * gf,
             ggml_tensor * s,
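
The new declarations land in the existing "attention" and "recurrent" sections of llm_graph_context, next to the graph builders that use them. The diff itself relies on only two accessors of the hybrid state class; inferred from those calls, its interface presumably includes something like the following (a sketch, not the actual header):

// Hypothetical minimal declarations, inferred from the calls in this commit;
// the real llama_kv_cache_hybrid_recurrent_state declares much more
// (construction, sequence management, the common memory-state interface
// that makes the dynamic_cast in llama-graph.cpp possible).
struct llama_kv_cache_unified_state;
struct llama_kv_cache_recurrent_state;

struct llama_kv_cache_hybrid_recurrent_state {
    // sub-state consumed by the attention layers
    const llama_kv_cache_unified_state   * get_state_attn()      const;
    // sub-state consumed by the recurrent (e.g. Mamba/RWKV) layers
    const llama_kv_cache_recurrent_state * get_state_recurrent() const;
};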
