@@ -954,7 +954,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
 
@@ -971,7 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
 
@@ -1025,7 +1025,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * kv_state = get_state_unified();
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
 
@@ -1056,6 +1056,30 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     return pos_bias;
 }
 
+const llama_kv_cache_unified_state * llm_graph_context::get_state_unified() const {
+    const auto * umstate = dynamic_cast<const llama_kv_cache_unified_state *>(mstate);
+    if (!umstate) {
+        const auto hmstate = dynamic_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+        if (hmstate) {
+            umstate = hmstate->get_state_attn();
+        }
+    }
+    GGML_ASSERT(umstate);
+    return umstate;
+}
+
+const llama_kv_cache_recurrent_state * llm_graph_context::get_state_recurrent() const {
+    const auto * rmstate = dynamic_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    if (!rmstate) {
+        const auto hmstate = dynamic_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+        if (hmstate) {
+            rmstate = hmstate->get_state_recurrent();
+        }
+    }
+    GGML_ASSERT(rmstate);
+    return rmstate;
+}
+
 ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_cgraph * gf,
         ggml_tensor * q,
@@ -1231,7 +1255,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * kv_state = get_state_unified();
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
 
@@ -1268,7 +1292,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * kv_state = get_state_unified();
 
     // store to KV cache
     {
@@ -1446,7 +1470,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
         ggml_tensor * state_mask,
             int32_t   n_state,
             int32_t   n_seqs) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     const auto n_kv    = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
@@ -1478,7 +1502,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_tensor * state_mask,
  const llama_ubatch & ubatch,
                  int   il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     const auto token_shift_count = hparams.token_shift_count;
 
@@ -1499,7 +1523,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
  const llama_ubatch & ubatch,
                  int   il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
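The sketch below is a minimal, standalone illustration of the cast-then-fallback dispatch that get_state_unified() and get_state_recurrent() add: the graph context holds a base-class memory-state pointer that may be a plain unified state, a plain recurrent state, or a hybrid wrapper exposing both sub-states. The stub types here (memory_state, unified_state, recurrent_state, hybrid_state) are placeholders invented for the example, not llama.cpp types; only the shape of the dispatch mirrors the helpers in this diff.

// Standalone sketch of the dispatch pattern (assumed stub types, not llama.cpp code).
#include <cassert>
#include <cstdio>

struct memory_state { virtual ~memory_state() = default; };

struct unified_state   : memory_state { int n_kv = 32; };
struct recurrent_state : memory_state { int head = 0;  };

// Hybrid wrapper: owns one state of each kind and hands out the matching view.
struct hybrid_state : memory_state {
    unified_state   attn;
    recurrent_state recr;
    const unified_state   * get_state_attn()      const { return &attn; }
    const recurrent_state * get_state_recurrent() const { return &recr; }
};

// Mirrors the helper's logic: try the direct cast first, then fall back to the
// hybrid wrapper's attention sub-state; assert if neither applies.
static const unified_state * get_state_unified(const memory_state * mstate) {
    const auto * ustate = dynamic_cast<const unified_state *>(mstate);
    if (!ustate) {
        if (const auto * hstate = dynamic_cast<const hybrid_state *>(mstate)) {
            ustate = hstate->get_state_attn();
        }
    }
    assert(ustate && "mstate is neither unified nor hybrid");
    return ustate;
}

int main() {
    unified_state plain;
    hybrid_state  hybrid;

    // Both a plain unified state and a hybrid state resolve to a usable unified view.
    printf("plain  n_kv = %d\n", get_state_unified(&plain)->n_kv);
    printf("hybrid n_kv = %d\n", get_state_unified(&hybrid)->n_kv);
    return 0;
}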