@@ -449,6 +449,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
449
449
cvec (params.cvec),
450
450
loras (params.loras),
451
451
memory (params.memory),
452
+ mstate (params.mstate),
452
453
cross (params.cross),
453
454
cb_func (params.cb),
454
455
res (std::make_unique<llm_graph_result>()) {
@@ -1027,9 +1028,13 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
1027
1028
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec () const {
1028
1029
const llama_kv_cache_unified * kv_self = static_cast <const llama_kv_cache_unified *>(memory);
1029
1030
1031
+ const llama_kv_cache_unified_state_i * kv_state = static_cast <const llama_kv_cache_unified_state_i *>(mstate);
1032
+
1033
+ const llama_kv_cache_unified::compute_state * cstate = kv_state ? kv_state->get_cstate () : nullptr ;
1034
+
1030
1035
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
1031
1036
1032
- const auto n_kv = kv_self->get_n_kv ();
1037
+ const auto n_kv = kv_self->get_n_kv (cstate );
1033
1038
1034
1039
auto & cur = inp->pos_bucket ;
1035
1040
@@ -1233,12 +1238,16 @@ ggml_tensor * llm_graph_context::build_attn(
1233
1238
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified () const {
1234
1239
const llama_kv_cache_unified * kv_self = static_cast <const llama_kv_cache_unified *>(memory);
1235
1240
1241
+ const llama_kv_cache_unified_state_i * kv_state = static_cast <const llama_kv_cache_unified_state_i *>(mstate);
1242
+
1243
+ const llama_kv_cache_unified::compute_state * cstate = kv_state ? kv_state->get_cstate () : nullptr ;
1244
+
1236
1245
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
1237
1246
1238
1247
{
1239
1248
GGML_ASSERT (hparams.swa_type == LLAMA_SWA_TYPE_NONE && " Use llama_kv_cache_unified_iswa for SWA" );
1240
1249
1241
- const auto n_kv = kv_self->get_n_kv ();
1250
+ const auto n_kv = kv_self->get_n_kv (cstate );
1242
1251
1243
1252
inp->self_kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1244
1253
// cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1270,17 +1279,21 @@ ggml_tensor * llm_graph_context::build_attn(
1270
1279
1271
1280
const llama_kv_cache_unified * kv_self = static_cast <const llama_kv_cache_unified *>(memory);
1272
1281
1282
+ const llama_kv_cache_unified_state_i * kv_state = static_cast <const llama_kv_cache_unified_state_i *>(mstate);
1283
+
1284
+ const llama_kv_cache_unified::compute_state * cstate = kv_state ? kv_state->get_cstate () : nullptr ;
1285
+
1273
1286
// store to KV cache
1274
1287
{
1275
- ggml_build_forward_expand (gf, kv_self->cpy_k (ctx0, k_cur, il));
1276
- ggml_build_forward_expand (gf, kv_self->cpy_v (ctx0, v_cur, il));
1288
+ ggml_build_forward_expand (gf, kv_self->cpy_k (cstate, ctx0, k_cur, il));
1289
+ ggml_build_forward_expand (gf, kv_self->cpy_v (cstate, ctx0, v_cur, il));
1277
1290
}
1278
1291
1279
1292
const auto & kq_mask = inp->get_kq_mask ();
1280
1293
1281
1294
ggml_tensor * q = q_cur;
1282
- ggml_tensor * k = kv_self->get_k (ctx0, il);
1283
- ggml_tensor * v = kv_self->get_v (ctx0, il);
1295
+ ggml_tensor * k = kv_self->get_k (cstate, ctx0, il);
1296
+ ggml_tensor * v = kv_self->get_v (cstate, ctx0, il);
1284
1297
1285
1298
ggml_tensor * cur = build_attn_mha (gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1286
1299
cb (cur, " kqv_out" , il);
@@ -1303,10 +1316,15 @@ ggml_tensor * llm_graph_context::build_attn(
1303
1316
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa () const {
1304
1317
const llama_kv_cache_unified_iswa * kv_self = static_cast <const llama_kv_cache_unified_iswa *>(memory);
1305
1318
1319
+ const llama_kv_cache_unified_iswa_state_i * kv_state = static_cast <const llama_kv_cache_unified_iswa_state_i *>(mstate);
1320
+
1321
+ const llama_kv_cache_unified::compute_state * cstate_base = kv_state ? kv_state->get_cstate_base () : nullptr ;
1322
+ const llama_kv_cache_unified::compute_state * cstate_swa = kv_state ? kv_state->get_cstate_swa () : nullptr ;
1323
+
1306
1324
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
1307
1325
1308
1326
{
1309
- const auto n_kv = kv_self->get_kv_base ()->get_n_kv ();
1327
+ const auto n_kv = kv_self->get_kv_base ()->get_n_kv (cstate_base );
1310
1328
1311
1329
inp->self_kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1312
1330
// cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1318,7 +1336,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1318
1336
{
1319
1337
GGML_ASSERT (hparams.swa_type != LLAMA_SWA_TYPE_NONE && " Use llama_kv_cache_unified for non-SWA" );
1320
1338
1321
- const auto n_kv = kv_self->get_kv_swa ()->get_n_kv ();
1339
+ const auto n_kv = kv_self->get_kv_swa ()->get_n_kv (cstate_swa );
1322
1340
1323
1341
inp->self_kq_mask_swa = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1324
1342
// cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@@ -1354,17 +1372,24 @@ ggml_tensor * llm_graph_context::build_attn(
1354
1372
1355
1373
const auto * kv = is_swa ? kv_self->get_kv_swa () : kv_self->get_kv_base ();
1356
1374
1375
+ const llama_kv_cache_unified_iswa_state_i * kv_state = static_cast <const llama_kv_cache_unified_iswa_state_i *>(mstate);
1376
+
1377
+ const llama_kv_cache_unified::compute_state * cstate_base = kv_state ? kv_state->get_cstate_base () : nullptr ;
1378
+ const llama_kv_cache_unified::compute_state * cstate_swa = kv_state ? kv_state->get_cstate_swa () : nullptr ;
1379
+
1380
+ const llama_kv_cache_unified::compute_state * cstate = is_swa ? cstate_swa : cstate_base;
1381
+
1357
1382
// store to KV cache
1358
1383
{
1359
- ggml_build_forward_expand (gf, kv->cpy_k (ctx0, k_cur, il));
1360
- ggml_build_forward_expand (gf, kv->cpy_v (ctx0, v_cur, il));
1384
+ ggml_build_forward_expand (gf, kv->cpy_k (cstate, ctx0, k_cur, il));
1385
+ ggml_build_forward_expand (gf, kv->cpy_v (cstate, ctx0, v_cur, il));
1361
1386
}
1362
1387
1363
1388
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa () : inp->get_kq_mask ();
1364
1389
1365
1390
ggml_tensor * q = q_cur;
1366
- ggml_tensor * k = kv->get_k (ctx0, il);
1367
- ggml_tensor * v = kv->get_v (ctx0, il);
1391
+ ggml_tensor * k = kv->get_k (cstate, ctx0, il);
1392
+ ggml_tensor * v = kv->get_v (cstate, ctx0, il);
1368
1393
1369
1394
ggml_tensor * cur = build_attn_mha (gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1370
1395
cb (cur, " kqv_out" , il);
0 commit comments