Commit 37cec43

kv-cache : extract the graph-specific state from the object (unified)
ggml-ci
1 parent: a592c13

6 files changed: +153 −71 lines
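
In short: graph construction no longer queries the KV cache object directly for its per-batch layout; it now receives a `llama_memory_state_i` pointer and threads the derived compute state into every cache accessor. A condensed before/after sketch (illustrative only, using the accessor names visible in the diffs below):

    // before: the cache object itself carries the per-graph layout
    const auto n_kv = kv_self->get_n_kv();
    ggml_tensor * k  = kv_self->get_k(ctx0, il);

    // after: the layout comes from the graph-specific compute state
    const auto n_kv = kv_self->get_n_kv(cstate);
    ggml_tensor * k  = kv_self->get_k(cstate, ctx0, il);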

src/llama-context.cpp

Lines changed: 9 additions & 7 deletions
@@ -651,7 +651,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
         return nullptr;
     }
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
     if (!res) {
         LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
         if (ret) {
@@ -1289,7 +1289,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
 
     auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, nullptr);
 
     this->n_outputs = save_n_outputs;
 
@@ -1310,10 +1310,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 
 llm_graph_result_ptr llama_context::graph_build(
-              ggml_context * ctx,
-               ggml_cgraph * gf,
-        const llama_ubatch & ubatch,
-            llm_graph_type   gtype) {
+                  ggml_context * ctx,
+                   ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
+                llm_graph_type   gtype,
+    const llama_memory_state_i * mstate) {
     return model.build_graph(
         {
             /*.ctx =*/ ctx,
@@ -1326,6 +1327,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.cvec =*/ &cvec,
             /*.loras =*/ &loras,
             /*.memory =*/ memory.get(),
+            /*.mstate =*/ mstate,
             /*.cross =*/ &cross,
             /*.n_outputs =*/ n_outputs,
             /*.cb =*/ graph_get_cb(),
@@ -2047,7 +2049,7 @@ void llama_context::opt_epoch_iter(
             n_outputs = ubatch.n_tokens;
 
             auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
 
             struct ggml_context * ctx_compute_opt;
             {
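
For reference, the three call sites of graph_build() after this change, condensed from the hunks above (not verbatim; each line stands in for its own call site):

    // decode path (process_ubatch): forward the memory state prepared for this ubatch
    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);

    // graph_reserve(): no memory state exists yet, so the builders must tolerate a null mstate
    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, nullptr);

    // opt_epoch_iter(): the training loop owns its own KV state and passes it explicitly
    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());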

src/llama-context.h

Lines changed: 5 additions & 4 deletions
@@ -200,10 +200,11 @@ struct llama_context {
 
 private:
     llm_graph_result_ptr graph_build(
-              ggml_context * ctx,
-               ggml_cgraph * gf,
-        const llama_ubatch & ubatch,
-            llm_graph_type   gtype);
+                  ggml_context * ctx,
+                   ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
+                llm_graph_type   gtype,
+    const llama_memory_state_i * mstate);
 
     llm_graph_cb graph_get_cb() const;
 

src/llama-graph.cpp

Lines changed: 37 additions & 12 deletions
@@ -449,6 +449,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     cvec    (params.cvec),
     loras   (params.loras),
     memory  (params.memory),
+    mstate  (params.mstate),
     cross   (params.cross),
     cb_func (params.cb),
     res     (std::make_unique<llm_graph_result>()) {
@@ -1027,9 +1028,13 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
+    const llama_kv_cache_unified_state_i * kv_state = static_cast<const llama_kv_cache_unified_state_i *>(mstate);
+
+    const llama_kv_cache_unified::compute_state * cstate = kv_state ? kv_state->get_cstate() : nullptr;
+
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
 
-    const auto n_kv = kv_self->get_n_kv();
+    const auto n_kv = kv_self->get_n_kv(cstate);
 
     auto & cur = inp->pos_bucket;
 
@@ -1233,12 +1238,16 @@ ggml_tensor * llm_graph_context::build_attn(
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
+    const llama_kv_cache_unified_state_i * kv_state = static_cast<const llama_kv_cache_unified_state_i *>(mstate);
+
+    const llama_kv_cache_unified::compute_state * cstate = kv_state ? kv_state->get_cstate() : nullptr;
+
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = kv_self->get_n_kv();
+        const auto n_kv = kv_self->get_n_kv(cstate);
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1270,17 +1279,21 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
+    const llama_kv_cache_unified_state_i * kv_state = static_cast<const llama_kv_cache_unified_state_i *>(mstate);
+
+    const llama_kv_cache_unified::compute_state * cstate = kv_state ? kv_state->get_cstate() : nullptr;
+
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, kv_self->cpy_k(cstate, ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_self->cpy_v(cstate, ctx0, v_cur, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_self->get_k(ctx0, il);
-    ggml_tensor * v = kv_self->get_v(ctx0, il);
+    ggml_tensor * k = kv_self->get_k(cstate, ctx0, il);
+    ggml_tensor * v = kv_self->get_v(cstate, ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1303,10 +1316,15 @@ ggml_tensor * llm_graph_context::build_attn(
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
     const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
 
+    const llama_kv_cache_unified_iswa_state_i * kv_state = static_cast<const llama_kv_cache_unified_iswa_state_i *>(mstate);
+
+    const llama_kv_cache_unified::compute_state * cstate_base = kv_state ? kv_state->get_cstate_base() : nullptr;
+    const llama_kv_cache_unified::compute_state * cstate_swa  = kv_state ? kv_state->get_cstate_swa () : nullptr;
+
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
 
     {
-        const auto n_kv = kv_self->get_kv_base()->get_n_kv();
+        const auto n_kv = kv_self->get_kv_base()->get_n_kv(cstate_base);
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1318,7 +1336,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
 
-        const auto n_kv = kv_self->get_kv_swa()->get_n_kv();
+        const auto n_kv = kv_self->get_kv_swa()->get_n_kv(cstate_swa);
 
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@@ -1354,17 +1372,24 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
 
+    const llama_kv_cache_unified_iswa_state_i * kv_state = static_cast<const llama_kv_cache_unified_iswa_state_i *>(mstate);
+
+    const llama_kv_cache_unified::compute_state * cstate_base = kv_state ? kv_state->get_cstate_base() : nullptr;
+    const llama_kv_cache_unified::compute_state * cstate_swa  = kv_state ? kv_state->get_cstate_swa () : nullptr;
+
+    const llama_kv_cache_unified::compute_state * cstate = is_swa ? cstate_swa : cstate_base;
+
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, kv->cpy_k(cstate, ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv->cpy_v(cstate, ctx0, v_cur, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv->get_k(ctx0, il);
-    ggml_tensor * v = kv->get_v(ctx0, il);
+    ggml_tensor * k = kv->get_k(cstate, ctx0, il);
+    ggml_tensor * v = kv->get_v(cstate, ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
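
The builders above all follow the same pattern: downcast the generic mstate to the cache-specific state interface, guard it (graph_reserve() builds with mstate == nullptr), and pass the resulting compute_state into every KV-cache accessor. A minimal sketch of that pattern, assuming only the types and accessors that appear in this diff:

    // sketch: how a unified-cache builder obtains and uses the per-graph compute state
    const auto * kv_self  = static_cast<const llama_kv_cache_unified *>(memory);
    const auto * kv_state = static_cast<const llama_kv_cache_unified_state_i *>(mstate);

    // mstate may be null (e.g. during graph_reserve), so the compute state is guarded
    const llama_kv_cache_unified::compute_state * cstate = kv_state ? kv_state->get_cstate() : nullptr;

    const auto n_kv = kv_self->get_n_kv(cstate);                             // cache view size for this graph
    ggml_build_forward_expand(gf, kv_self->cpy_k(cstate, ctx0, k_cur, il));  // K write goes through the state
    ggml_tensor * k = kv_self->get_k(cstate, ctx0, il);                      // K read uses the same state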

src/llama-graph.h

Lines changed: 12 additions & 8 deletions
@@ -18,6 +18,8 @@ struct llama_ubatch;
 struct llama_cparams;
 
 class llama_memory_i;
+class llama_memory_state_i;
+
 class llama_kv_cache_unified;
 class llama_kv_cache_unified_iswa;
 class llama_kv_cache_recurrent;
@@ -383,10 +385,11 @@ struct llm_graph_params {
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;
 
-    const llama_adapter_cvec  * cvec;
-    const llama_adapter_loras * loras;
-    const llama_memory_i      * memory;
-    const llama_cross         * cross;
+    const llama_adapter_cvec   * cvec;
+    const llama_adapter_loras  * loras;
+    const llama_memory_i       * memory;
+    const llama_memory_state_i * mstate;
+    const llama_cross          * cross;
 
     int32_t n_outputs;
 
@@ -435,10 +438,11 @@ struct llm_graph_context {
 
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
-    const llama_adapter_cvec  * cvec;
-    const llama_adapter_loras * loras;
-    const llama_memory_i      * memory;
-    const llama_cross         * cross;
+    const llama_adapter_cvec   * cvec;
+    const llama_adapter_loras  * loras;
+    const llama_memory_i       * memory;
+    const llama_memory_state_i * mstate;
+    const llama_cross          * cross;
 
     const llm_graph_cb & cb_func;
 
