Commit a592c13

kv-cache : make the unified implementation more stateless
ggml-ci
1 parent a3ebf0a · commit a592c13

6 files changed: +205 −117 lines

src/llama-context.cpp

Lines changed: 23 additions & 15 deletions
```diff
@@ -633,7 +633,15 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_graph_type gtype, ggml_status * ret) {
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status * ret) {
+    if (mstate && !mstate->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
+        if (ret) {
+            *ret = GGML_STATUS_FAILED;
+        }
+        return nullptr;
+    }
+
     auto * gf = graph_init();
     if (!gf) {
         LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
@@ -748,7 +756,7 @@ int llama_context::encode(llama_batch & inp_batch) {
     cparams.causal_attn = false;
 
     ggml_status status;
-    auto res = process(ubatch, LLM_GRAPH_TYPE_ENCODER, &status);
+    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, &status);
 
     cparams.causal_attn = causal_attn_org;
 
@@ -927,12 +935,12 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();
 
-    auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!decode_state) {
+    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+    if (!kv_state) {
         return -2;
     }
 
-    switch (decode_state->get_status()) {
+    switch (kv_state->get_status()) {
         case LLAMA_MEMORY_STATUS_SUCCESS:
             {
             } break;
@@ -955,8 +963,8 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     int64_t n_outputs_prev = 0;
 
-    while (const auto * ubatch_ptr = decode_state->next()) {
-        const auto & ubatch = *ubatch_ptr;
+    do {
+        const auto & ubatch = kv_state->get_ubatch();
 
         // count the outputs in this u_batch
         {
@@ -979,7 +987,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         ggml_status status;
-        auto res = process(ubatch, LLM_GRAPH_TYPE_DECODER, &status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), &status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1092,7 +1100,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
 
         n_outputs_prev += n_outputs;
-    }
+    } while (kv_state->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
@@ -1101,7 +1109,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;
 
-        auto & out_ids = decode_state->out_ids();
+        auto & out_ids = kv_state->out_ids();
 
         GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
@@ -2020,8 +2028,8 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
-        if (!decode_state || decode_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+        auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
            LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
            break;
        }
@@ -2033,8 +2041,8 @@ void llama_context::opt_epoch_iter(
        };
 
        uint32_t pos_batch = 0;
-       while (const auto * ubatch_ptr = decode_state->next()) {
-           const auto & ubatch = *ubatch_ptr;
+       do {
+           const auto & ubatch = kv_state->get_ubatch();
 
           n_outputs = ubatch.n_tokens;
 
@@ -2073,7 +2081,7 @@ void llama_context::opt_epoch_iter(
           ggml_free(ctx_compute_opt);
 
           pos_batch += ubatch.n_tokens;
-      }
+      } while (kv_state->next());
    }
 }
 
```
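The decode and opt_epoch_iter loops above now drive a per-batch state object instead of asking the KV cache itself to remember which ubatch it is on — this is what makes the unified implementation "more stateless". The interface behind `kv_state` is declared outside this commit; the sketch below is only a hedged reconstruction of its contract inferred from the calls made in this diff (`get_status`, `get_ubatch`, `apply`, `next`, `out_ids`). Method names come from the call sites; return types and the `out_ids` element type are assumptions for illustration, not the real declaration.

```cpp
// Hedged sketch only: the real llama_memory_state_i lives elsewhere in the
// tree and is not part of this commit. Types marked "project type" come from
// the llama.cpp sources; everything else here is an assumption.
#include <cstdint>
#include <vector>

struct llama_ubatch;             // project type: a micro-batch split from the user batch
enum llama_memory_status : int;  // project type: LLAMA_MEMORY_STATUS_SUCCESS, ...

class llama_memory_state_i {
public:
    virtual ~llama_memory_state_i() = default;

    // outcome of splitting the batch and reserving space in the cache
    virtual llama_memory_status get_status() const = 0;

    // the ubatch the state is currently positioned on (starts at the first one)
    virtual const llama_ubatch & get_ubatch() const = 0;

    // commit the pending cache bookkeeping for the current ubatch;
    // llama_context::process_ubatch() calls this before building the graph
    virtual bool apply() = 0;

    // advance to the next ubatch; returns false once the batch is exhausted
    virtual bool next() = 0;

    // mapping of output rows back to the original batch order,
    // used after the loop to sort logits/embeddings
    virtual std::vector<int64_t> & out_ids() = 0;
};
```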

src/llama-context.h

Lines changed: 8 additions & 4 deletions
```diff
@@ -18,6 +18,8 @@ struct llama_kv_cache;
 class llama_io_read_i;
 class llama_io_write_i;
 
+class llama_memory_state_i;
+
 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs
     llama_context(
@@ -90,12 +92,14 @@ struct llama_context {
             int32_t il_end);
 
     // process a single ubatch with a specific graph type
+    // if memory_state is provided, it will be applied first to the context's memory
     // ret contains the status of the graph computation
     // returns nullptr only if ret != GGML_STATUS_SUCCESS
-    llm_graph_result_ptr process(
-            const llama_ubatch & ubatch,
-            llm_graph_type gtype,
-            ggml_status * ret);
+    llm_graph_result_ptr process_ubatch(
+            const llama_ubatch & ubatch,
+            llm_graph_type gtype,
+            llama_memory_state_i * mstate,
+            ggml_status * ret);
 
     int encode(llama_batch & inp_batch);
     int decode(llama_batch & inp_batch);
```
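With the extra `mstate` parameter declared above, the caller decides when cache state is committed: `encode()` passes `nullptr` because it never touches the KV cache, while `decode()` hands over the state produced by `init_batch()` so it is applied immediately before graph construction. The fragment below is a condensed, hedged view of the `decode()` loop after this change, assembled from the .cpp diff above — output counting, status handling and logits extraction from the real function are omitted.

```cpp
// Condensed from llama_context::decode() in the diff above; not a complete
// reproduction of the function.
auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled,
                                    /* logits_all */ n_outputs_all == n_tokens_all);
if (!kv_state) {
    return -2;
}

do {
    const auto & ubatch = kv_state->get_ubatch();   // already positioned on the first ubatch

    ggml_status status;
    // the state is applied inside process_ubatch(), right before the graph is built;
    // encode() calls the same function with nullptr instead of kv_state.get()
    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), &status);
    if (!res) {
        // the ubatch failed or was aborted -> the full function rolls back the
        // KV cache for that ubatch and returns an error here
        break;
    }

    // ... extract logits / embeddings for this ubatch ...
} while (kv_state->next());                          // advance; false once all ubatches are consumed
```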

src/llama-graph.cpp

Lines changed: 4 additions & 4 deletions
```diff
@@ -1029,7 +1029,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
 
-    const auto n_kv = kv_self->get_n();
+    const auto n_kv = kv_self->get_n_kv();
 
     auto & cur = inp->pos_bucket;
 
@@ -1238,7 +1238,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = kv_self->get_n();
+        const auto n_kv = kv_self->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1306,7 +1306,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
 
     {
-        const auto n_kv = kv_self->get_kv_base()->get_n();
+        const auto n_kv = kv_self->get_kv_base()->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1318,7 +1318,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
 
-        const auto n_kv = kv_self->get_kv_swa()->get_n();
+        const auto n_kv = kv_self->get_kv_swa()->get_n_kv();
 
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
```
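The llama-graph.cpp hunks are the mechanical side of the same cleanup: the graph builders call `get_n_kv()` instead of `get_n()` to ask how many KV cells the graph should attend over. As the unchanged context lines show, that value fixes the first dimension of the KQ mask; restated on its own below for clarity, with identifiers exactly as in the hunks above.

```cpp
// Restating the unchanged context lines above: the number of KV cells in view
// (get_n_kv(), formerly get_n()) is the first ggml dimension of the KQ mask,
// the second being the token count padded to GGML_KQ_MASK_PAD.
const auto n_kv = kv_self->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
```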
