Commit 9d05381

memory : extract state into llama_memory_state

ggml-ci

1 parent: 47e570c · commit: 9d05381

8 files changed: +706 −373 lines
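The diff below threads a new llama_memory_state_i object through graph reservation and ubatch processing. As orientation, the following sketch reconstructs the rough shape of that interface purely from the call sites visible in this commit (init_full()/init_batch() on the KV cache, then apply(), get_ubatch(), next(), get_status() and out_ids() on the returned state). It is an inferred illustration, not a copy of the actual declarations in the llama-memory / llama-kv-cache headers, which may differ in detail.

    // Hypothetical reconstruction, inferred only from the call sites in this commit.
    #include <cstdint>
    #include <memory>
    #include <vector>

    struct llama_ubatch; // defined elsewhere in llama.cpp

    // stand-in for the real status enum; only the SUCCESS case is visible in this diff
    enum llama_memory_status : int { LLAMA_MEMORY_STATUS_SUCCESS = 0 /* , ... */ };

    class llama_memory_state_i {
    public:
        virtual ~llama_memory_state_i() = default;

        // consume the current ubatch and move to the next one; false when there are no more
        virtual bool next() = 0;

        // apply the state for the current ubatch (e.g. reserve KV cache slots)
        virtual bool apply() = 0;

        // output ids of the batch, used later to restore the original output order
        // (element type assumed)
        virtual std::vector<int64_t> & out_ids() = 0;

        // the ubatch that should be processed next
        virtual const llama_ubatch & get_ubatch() const = 0;

        // status reported by the memory module when the state was created
        virtual llama_memory_status get_status() const = 0;
    };

    using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;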

src/llama-context.cpp

Lines changed: 49 additions & 42 deletions
@@ -274,13 +274,16 @@ llama_context::llama_context(
         // simulate full KV cache
         llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-        kv_self->set_full();
+        const auto kv_state = kv_self->init_full();
+        if (!kv_state) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         cross.v_embd.clear();
 
         // reserve pp graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -291,7 +294,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1);
+            auto * gf = graph_reserve(1, 1, 1, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -302,7 +305,7 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -430,12 +433,15 @@ void llama_context::kv_self_update() {
 
     if (kv_self->update(*this)) {
         // if the KV cache did any computation, we have to reserve a new worst-case graph
-        kv_self->set_full();
+        const auto kv_state = kv_self->init_full();
+        if (!kv_state) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         const uint32_t n_seqs = cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
         }
@@ -633,32 +639,32 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_graph_type gtype, ggml_status * ret) {
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
+    if (mstate && !mstate->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
+        ret = GGML_STATUS_FAILED;
+        return nullptr;
+    }
+
     auto * gf = graph_init();
     if (!gf) {
         LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
-        if (ret) {
-            *ret = GGML_STATUS_FAILED;
-        }
+        ret = GGML_STATUS_FAILED;
         return nullptr;
     }
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
     if (!res) {
         LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
-        if (ret) {
-            *ret = GGML_STATUS_FAILED;
-        }
+        ret = GGML_STATUS_FAILED;
        return nullptr;
     }
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
     if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
         LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
-        if (ret) {
-            *ret = GGML_STATUS_ALLOC_FAILED;
-        }
+        ret = GGML_STATUS_ALLOC_FAILED;
         return nullptr;
     }
 
@@ -667,12 +673,12 @@ llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_gra
     const auto status = graph_compute(gf, ubatch.n_tokens > 1);
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
-        if (ret) {
-            *ret = status;
-        }
+        ret = status;
         return nullptr;
     }
 
+    ret = GGML_STATUS_SUCCESS;
+
     return res;
 }
 
@@ -748,7 +754,7 @@ int llama_context::encode(llama_batch & inp_batch) {
     cparams.causal_attn = false;
 
     ggml_status status;
-    auto res = process(ubatch, LLM_GRAPH_TYPE_ENCODER, &status);
+    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
     cparams.causal_attn = causal_attn_org;
 
@@ -927,12 +933,12 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();
 
-    auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!decode_state) {
+    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+    if (!kv_state) {
         return -2;
     }
 
-    switch (decode_state->get_status()) {
+    switch (kv_state->get_status()) {
         case LLAMA_MEMORY_STATUS_SUCCESS:
             {
             } break;
@@ -955,8 +961,8 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     int64_t n_outputs_prev = 0;
 
-    while (const auto * ubatch_ptr = decode_state->next()) {
-        const auto & ubatch = *ubatch_ptr;
+    do {
+        const auto & ubatch = kv_state->get_ubatch();
 
         // count the outputs in this u_batch
         {
@@ -979,7 +985,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         ggml_status status;
-        auto res = process(ubatch, LLM_GRAPH_TYPE_DECODER, &status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1092,7 +1098,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
 
         n_outputs_prev += n_outputs;
-    }
+    } while (kv_state->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
@@ -1101,7 +1107,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;
 
-        auto & out_ids = decode_state->out_ids();
+        auto & out_ids = kv_state->out_ids();
 
         GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
@@ -1261,7 +1267,7 @@ ggml_cgraph * llama_context::graph_init() {
     return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1281,7 +1287,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
 
     auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
 
     this->n_outputs = save_n_outputs;
 
@@ -1302,10 +1308,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 
 llm_graph_result_ptr llama_context::graph_build(
-            ggml_context * ctx,
-             ggml_cgraph * gf,
-      const llama_ubatch & ubatch,
-          llm_graph_type   gtype) {
+                  ggml_context * ctx,
+                   ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
+                llm_graph_type   gtype,
+    const llama_memory_state_i * mstate) {
     return model.build_graph(
         {
             /*.ctx =*/ ctx,
@@ -1317,7 +1324,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.backend_cpu =*/ backend_cpu,
             /*.cvec =*/ &cvec,
             /*.loras =*/ &loras,
-            /*.memory =*/ memory.get(),
+            /*.mstate =*/ mstate,
             /*.cross =*/ &cross,
             /*.n_outputs =*/ n_outputs,
             /*.cb =*/ graph_get_cb(),
@@ -2020,8 +2027,8 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
-        if (!decode_state || decode_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+        auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
            LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
            break;
        }
@@ -2033,13 +2040,13 @@
        };
 
        uint32_t pos_batch = 0;
-        while (const auto * ubatch_ptr = decode_state->next()) {
-            const auto & ubatch = *ubatch_ptr;
+        do {
+            const auto & ubatch = kv_state->get_ubatch();
 
            n_outputs = ubatch.n_tokens;
 
            auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
 
            struct ggml_context * ctx_compute_opt;
            {
@@ -2073,7 +2080,7 @@
            ggml_free(ctx_compute_opt);
 
            pos_batch += ubatch.n_tokens;
-        }
+        } while (kv_state->next());
    }
 }
 
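Taken together, the llama-context.cpp changes replace the old pull-style loop (decode_state->next() handing back the next ubatch pointer) with a do/while that drains the state created by init_batch(). The following condensed sketch illustrates the new consumption pattern in decode(); names are taken from the hunks above, while error handling and output bookkeeping are trimmed, so treat it as an illustration rather than the full decode() logic:

    // Condensed illustration of the new ubatch loop in llama_context::decode()
    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled,
                                        /* logits_all */ n_outputs_all == n_tokens_all);
    if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return -2;
    }

    do {
        // the state owns the ubatch splitting; the context only asks for the current one
        const auto & ubatch = kv_state->get_ubatch();

        ggml_status status;
        // process_ubatch() first calls kv_state->apply(), then builds and runs the graph
        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
        if (!res) {
            // on failure the caller rolls back the positions of this ubatch
            break;
        }
    } while (kv_state->next());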
src/llama-context.h

Lines changed: 15 additions & 9 deletions
@@ -18,6 +18,9 @@ struct llama_kv_cache;
 class llama_io_read_i;
 class llama_io_write_i;
 
+class llama_memory_i;
+class llama_memory_state_i;
+
 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs
     llama_context(
@@ -90,12 +93,14 @@
             int32_t il_end);
 
     // process a single ubatch with a specific graph type
+    // if memory_state is provided, it will be applied first to the context's memory
     // ret contains the status of the graph computation
     // returns nullptr only if ret != GGML_STATUS_SUCCESS
-    llm_graph_result_ptr process(
-            const llama_ubatch & ubatch,
-                llm_graph_type   gtype,
-                   ggml_status * ret);
+    llm_graph_result_ptr process_ubatch(
+            const llama_ubatch & ubatch,
+                llm_graph_type   gtype,
+          llama_memory_state_i * mstate,
+                   ggml_status & ret);
 
     int encode(llama_batch & inp_batch);
     int decode(llama_batch & inp_batch);
@@ -192,14 +197,15 @@
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
-    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs);
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
 
 private:
     llm_graph_result_ptr graph_build(
-            ggml_context * ctx,
-             ggml_cgraph * gf,
-      const llama_ubatch & ubatch,
-          llm_graph_type   gtype);
+                  ggml_context * ctx,
+                   ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
+                llm_graph_type   gtype,
+    const llama_memory_state_i * mstate);
 
     llm_graph_cb graph_get_cb() const;
 
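One consequence visible in both files is that even the worst-case graph reservation now needs a memory state: init_full() replaces the old set_full() call, and the resulting state is handed to graph_reserve()/graph_build() as a const pointer, since it is only read while the graph inputs are sized against a full cache. A compressed sketch of that path, combining the constructor and kv_self_update() hunks above (simplified, not a verbatim copy):

    // Compressed sketch of the worst-case reservation path
    const auto kv_state = kv_self->init_full(); // was: kv_self->set_full();
    if (!kv_state) {
        throw std::runtime_error("failed to initialize KV cache");
    }

    const uint32_t n_seqs   = cparams.n_seq_max;
    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

    // the state is only read during graph building, hence the const pointer parameter
    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
    if (!gf) {
        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
    }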