Commit 3f55f78

llama : auto-batch preparation (#13845)
* llama : auto-batch

ggml-ci

* context : simplify if branching
1 parent 51fa76f commit 3f55f78
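
In short: llama_context::kv_self_update() now reports whether the KV cache was actually modified, llama_context::decode() retries a failed batch preparation once after forcing a defrag, and llama_decode() no longer does its own defrag-and-retry. A return value of 1 from llama_decode() now means that no KV cache slot was found even after that internal retry, so the caller may re-submit a smaller batch. A minimal caller-side sketch of the resulting contract (handle_decode is an illustrative helper, not part of the llama.h API):

#include "llama.h"

// Hedged sketch of the llama_decode() return codes after this commit.
static bool handle_decode(llama_context * ctx, llama_batch batch) {
    const int ret = llama_decode(ctx, batch);

    if (ret == 0) {
        return true; // decoded successfully
    }

    if (ret == 1) {
        // no KV cache slot was found, even after the context forced a
        // defrag and retried once internally - not fatal, the caller can
        // rebuild a smaller batch and try again
        return false;
    }

    // negative values (e.g. -2) are fatal for this batch
    return false;
}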

5 files changed: 65 additions, 52 deletions


examples/parallel/parallel.cpp

Lines changed: 1 addition & 1 deletion
@@ -392,7 +392,7 @@ int main(int argc, char ** argv) {
                     return 1;
                 }
 
-                LOG_ERR("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
+                LOG_WRN("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
 
                 n_cache_miss += 1;
 
src/llama-context.cpp

Lines changed: 58 additions & 47 deletions
@@ -424,28 +424,33 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-void llama_context::kv_self_update() {
+bool llama_context::kv_self_update() {
     if (!memory) {
-        return;
+        return false;
     }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    if (kv_self->update(*this)) {
-        // if the KV cache did any computation, we have to reserve a new worst-case graph
-        const auto kv_state = kv_self->init_full();
-        if (!kv_state) {
-            throw std::runtime_error("failed to initialize KV cache");
-        }
+    if (!kv_self->update(*this)) {
+        // no updates have been performed
+        return false;
+    }
 
-        const uint32_t n_seqs   = cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+    // if the KV cache did any computation, we have to reserve a new worst-case graph
+    const auto kv_state = kv_self->init_full();
+    if (!kv_state) {
+        throw std::runtime_error("failed to initialize KV cache");
+    }
 
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
-        if (!gf) {
-            LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
-        }
+    const uint32_t n_seqs   = cparams.n_seq_max;
+    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
     }
+
+    return true;
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
@@ -933,24 +938,44 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();
 
-    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!kv_state) {
-        return -2;
-    }
+    llama_memory_state_ptr kv_state;
 
-    switch (kv_state->get_status()) {
-        case LLAMA_MEMORY_STATUS_SUCCESS:
-            {
-            } break;
-        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
-            {
-                // not a fatal error, we can re-try with a different batch
-                return 1;
-            }
-        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
-            {
-                return -2;
-            }
+    bool did_defrag = false;
+
+    while (true) {
+        kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!kv_state) {
+            return -2;
+        }
+
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                } break;
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+                {
+                    if (!did_defrag) {
+                        did_defrag = true;
+
+                        kv_self->defrag_sched(-1.0f);
+                        if (kv_self_update()) {
+                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+
+                            continue;
+                        }
+                    }
+
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+                    return 1;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    return -2;
+                }
+        }
+
+        break;
     }
 
     // reserve output buffer
@@ -2646,22 +2671,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
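
Taken together, the hunks above move the defrag-and-retry out of llama_decode() and into llama_context::decode(): on LLAMA_MEMORY_STATUS_FAILED_PREPARE the context now forces an immediate defrag via defrag_sched(-1.0f) and retries exactly once, and only when kv_self_update() reports that the cache actually changed. A self-contained sketch of that retry-once pattern in isolation (run_with_retry, try_prepare and compact are illustrative stand-ins for decode(), init_batch() and defrag_sched() + kv_self_update(); they are not llama.cpp API):

#include <functional>

enum class prepare_status { success, failed_prepare, failed_compute };

// Hedged sketch: retry a preparation step at most once, and only when
// compaction reports that it actually changed something.
static int run_with_retry(const std::function<prepare_status()> & try_prepare,
                          const std::function<bool()>           & compact) {
    bool did_compact = false;

    while (true) {
        switch (try_prepare()) {
            case prepare_status::success:
                return 0;  // slot found - proceed with the batch
            case prepare_status::failed_prepare:
                if (!did_compact) {
                    did_compact = true;
                    if (compact()) {
                        continue; // the cache changed - one more attempt
                    }
                }
                return 1;  // still no slot - let the caller shrink the batch
            case prepare_status::failed_compute:
                return -2; // fatal for this batch
        }
    }
}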

src/llama-context.h

Lines changed: 2 additions & 1 deletion
@@ -50,8 +50,9 @@ struct llama_context {
           llama_kv_cache * get_kv_self();
     const llama_kv_cache * get_kv_self() const;
 
+    // return true if the KV cache was updated
     // TODO: remove
-    void kv_self_update();
+    bool kv_self_update();
 
     enum llama_pooling_type pooling_type() const;
src/llama-kv-cache.cpp

Lines changed: 3 additions & 2 deletions
@@ -1809,9 +1809,10 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
     GGML_UNUSED(embd_pooled);
 
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+    // TODO: if we fail with split_simple, we should attempt different splitting strategies
+    // but to do that properly, we first have to refactor the batches to be more flexible
 
-    // TODO: if we fail with split_simple, we should attempt split_equal
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
     std::vector<llama_ubatch> ubatches;
tools/server/server.cpp

Lines changed: 1 addition & 1 deletion
@@ -3431,7 +3431,7 @@ struct server_context {
                     // retry with half the batch size to try to find a free slot in the KV cache
                     n_batch /= 2;
 
-                    SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
+                    SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
 
                     continue; // continue loop of n_batch
                 }
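
Both call sites touched by this commit (examples/parallel above and the server here) recover from a return value of 1 in the same way: halve n_batch and re-submit. A standalone sketch of that policy (decode_chunked is an illustrative helper, and the single-sequence llama_batch_get_one() batching is a simplification of what the two callers actually do):

#include <algorithm>
#include <cstdio>

#include "llama.h"

// Hedged sketch of the caller-side backoff loop: decode tokens in chunks of
// n_batch, halving the chunk size whenever no KV cache slot can be found.
static int decode_chunked(llama_context * ctx, llama_token * tokens, int32_t n_tokens, int32_t n_batch) {
    for (int32_t i = 0; i < n_tokens; ) {
        const int32_t n_cur = std::min(n_batch, n_tokens - i);

        const int ret = llama_decode(ctx, llama_batch_get_one(tokens + i, n_cur));
        if (ret == 1) {
            if (n_batch == 1) {
                return 1; // cannot shrink further - the KV cache is full
            }

            n_batch /= 2; // retry this chunk with a smaller batch
            fprintf(stderr, "retrying with n_batch = %d\n", n_batch);
            continue;
        }

        if (ret != 0) {
            return ret; // fatal error
        }

        i += n_cur; // chunk decoded - advance
    }

    return 0;
}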
