
Commit 1bce7e8

llama : use n_swa + n_ubatch cells for SWA cache + auto-batch
ggml-ci
1 parent 2252eef commit 1bce7e8

File tree: 9 files changed, +111 −174 lines

examples/parallel/parallel.cpp

Lines changed: 9 additions & 43 deletions

@@ -164,6 +164,8 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    params.n_batch = params.n_ctx;
+
     common_init();
 
     // number of simultaneous "clients" to simulate
@@ -356,59 +358,23 @@
             break;
         }
 
-        // process in chunks of params.n_batch
-        int32_t n_batch = params.n_batch;
-
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            // experiment: process in powers of 2
-            //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
-            //    n_batch /= 2;
-            //    i -= n_batch;
-            //    continue;
-            //}
-
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
-
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-            };
-
-            const int ret = llama_decode(ctx, batch_view);
-            if (ret != 0) {
-                if (n_batch == 1 || ret < 0) {
-                    // if you get here, it means the KV cache is full - try increasing it via the context size
-                    LOG_ERR("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
-                    return 1;
-                }
-
-                LOG_ERR("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
-
-                n_cache_miss += 1;
-
-                // retry with half the batch size to try to find a free slot in the KV cache
-                n_batch /= 2;
-                i -= n_batch;
-
-                continue;
+        {
+            if (const auto ret = llama_decode(ctx, batch) != 0) {
+                LOG_ERR("%s : failed to decode the batch, n_tokens = %d, ret = %d\n", __func__, batch.n_tokens, ret);
+                return 1;
             }
 
-            LOG_DBG("%s : decoded batch of %d tokens\n", __func__, n_tokens);
+            LOG_DBG("%s : decoded batch of %d tokens\n", __func__, batch.n_tokens);
 
             for (auto & client : clients) {
-                if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
+                if (client.seq_id == -1) {
                     continue;
                 }
 
                 //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
                 //    client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 
-                const llama_token id = common_sampler_sample(client.smpl, ctx, client.i_batch - i);
+                const llama_token id = common_sampler_sample(client.smpl, ctx, client.i_batch);
 
                 common_sampler_accept(client.smpl, id, true);
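With this change the example no longer slices its batch into params.n_batch chunks; it sets params.n_batch = params.n_ctx up front and hands the whole batch to llama_decode(), which splits it into ubatches and retries internally. A minimal sketch of that caller-side pattern, assuming an initialized llama_context * ctx and the common.h batch helpers (the token/position values are illustrative, not the example's real data):

    // minimal sketch of the caller-side pattern after this change (illustrative
    // token/pos values); assumes llama.h and common.h are available
    static int decode_whole_batch(llama_context * ctx) {
        llama_batch batch = llama_batch_init(512, 0, 1); // capacity, no embeddings, 1 seq id per token

        common_batch_clear(batch);
        for (llama_pos pos = 0; pos < 32; ++pos) {
            // token id 0 as a stand-in; request logits only for the last position
            common_batch_add(batch, 0, pos, { 0 }, pos == 31);
        }

        // no manual n_batch chunking: llama_decode() handles ubatch splitting internally
        const int ret = llama_decode(ctx, batch);
        if (ret < 0) {
            LOG_ERR("%s: fatal decode error, ret = %d\n", __func__, ret);
        } else if (ret == 1) {
            LOG_WRN("%s: no KV cache slot found for a batch of %d tokens\n", __func__, batch.n_tokens);
        }

        llama_batch_free(batch);
        return ret;
    }

Return code 1 is the only non-fatal case left for the caller to handle: it means no KV cache slot was found even after the library's internal defrag and ubatch-halving retries.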

src/llama-context.cpp

Lines changed: 54 additions & 36 deletions

@@ -420,9 +420,9 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-void llama_context::kv_self_update() {
+bool llama_context::kv_self_update() {
     if (!memory) {
-        return;
+        return false;
     }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
@@ -438,7 +438,11 @@
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
         }
+
+        return true;
     }
+
+    return false;
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
@@ -891,25 +895,53 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();
 
-    auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!decode_state) {
-        return -2;
-    }
+    llama_memory_decode_state_ptr decode_state;
 
-    switch (decode_state->get_status()) {
-        case LLAMA_MEMORY_STATUS_SUCCESS:
-            {
-            } break;
-        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
-            {
-                // not a fatal error, we can re-try with a different batch
-                return 1;
-            }
-        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
-            {
-                return -2;
-            }
-    }
+    bool did_defrag = false;
+    auto n_ubatch = cparams.n_ubatch;
+
+    do {
+        decode_state = kv_self->init(batch, n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!decode_state) {
+            return -2;
+        }
+
+        switch (decode_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                } break;
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+                {
+                    if (!did_defrag) {
+                        did_defrag = true;
+
+                        kv_self->defrag_sched(-1.0f);
+                        if (kv_self_update()) {
+                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+
+                            continue;
+                        }
+                    }
+
+                    if (n_ubatch > 1) {
+                        n_ubatch /= 2;
+
+                        LLAMA_LOG_DEBUG("%s: failed to find free space in the KV cache, retrying with smaller ubatch size: n_ubatch = %d\n", __func__, n_ubatch);
+                        continue;
+                    }
+
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+                    return 1;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    return -2;
+                }
+        }
+
+        break;
+    } while(true);
 
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -2588,22 +2620,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
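The new do { ... } while(true) loop amounts to a three-step retry policy when the memory module fails to prepare the batch: force one defragmentation pass (and retry only if kv_self_update() reports the cache actually changed), then halve the ubatch size down to 1, and only then give up with return code 1. A standalone sketch of that policy, with hypothetical try_fit/force_defrag callbacks standing in for kv_self->init(...) and defrag_sched(-1.0f) + kv_self_update() (not the library's API):

    #include <cstdint>
    #include <functional>

    // illustration of the retry policy used by the new decode loop;
    // try_fit and force_defrag are hypothetical stand-ins
    int decode_with_retries(uint32_t n_ubatch,
                            const std::function<bool(uint32_t)> & try_fit,
                            const std::function<bool()>         & force_defrag) {
        bool did_defrag = false;

        while (true) {
            if (try_fit(n_ubatch)) {
                return 0;             // every ubatch found a KV cache slot
            }

            if (!did_defrag) {
                did_defrag = true;
                if (force_defrag()) { // retry only if the cache actually changed
                    continue;
                }
            }

            if (n_ubatch > 1) {
                n_ubatch /= 2;        // smaller ubatches fit fragmented free space more easily
                continue;
            }

            return 1;                 // nothing fits even at n_ubatch == 1 -> caller sees ret == 1
        }
    }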

src/llama-context.h

Lines changed: 2 additions & 1 deletion

@@ -47,8 +47,9 @@ struct llama_context {
     llama_kv_cache * get_kv_self();
     const llama_kv_cache * get_kv_self() const;
 
+    // return true of the KV cache was updated
     // TODO: remove
-    void kv_self_update();
+    bool kv_self_update();
 
     enum llama_pooling_type pooling_type() const;

src/llama-kv-cache.cpp

Lines changed: 2 additions & 2 deletions

@@ -1738,14 +1738,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
                  bool   swa_full,
              uint32_t   kv_size,
              uint32_t   n_seq_max,
-             uint32_t   n_batch,
+             uint32_t   n_ubatch,
              uint32_t   n_pad) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
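With the constructor taking n_ubatch, the SWA cache only needs to hold the attention window for each sequence plus the tokens of one ubatch in flight, rounded up to the cache padding. A worked example of the sizing formula with illustrative numbers (not taken from any particular model configuration):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // same rounding as GGML_PAD: round x up to a multiple of n (n is a power of two here)
    static uint32_t pad_up(uint32_t x, uint32_t n) { return (x + n - 1) & ~(n - 1); }

    int main() {
        // illustrative values, not from a real model config
        const uint32_t n_swa     = 4096;  // sliding-window size
        const uint32_t n_seq_max = 4;     // parallel sequences
        const uint32_t n_ubatch  = 512;   // micro-batch size
        const uint32_t n_pad     = 256;   // cache padding
        const uint32_t kv_size   = 32768; // base (non-SWA) cache size

        const uint32_t size_swa = std::min(kv_size, pad_up(n_swa*n_seq_max + n_ubatch, n_pad));

        // 4096*4 + 512 = 16896, already a multiple of 256 -> size_swa = 16896 cells
        printf("size_swa = %u\n", size_swa);
        return 0;
    }

Since the examples now set n_batch to the full context, sizing the SWA cache by n_ubatch instead of n_batch keeps it close to n_swa*n_seq_max cells rather than growing with the context size.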

src/llama-kv-cache.h

Lines changed: 1 addition & 1 deletion

@@ -251,7 +251,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
                  bool   swa_full,
              uint32_t   kv_size,
              uint32_t   n_seq_max,
-             uint32_t   n_batch,
+             uint32_t   n_ubatch,
              uint32_t   n_pad);
 
     ~llama_kv_cache_unified_iswa() = default;

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion

@@ -13234,7 +13234,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         params.swa_full,
                         cparams.n_ctx,
                         cparams.n_seq_max,
-                        cparams.n_batch,
+                        cparams.n_ubatch,
                         padding);
             } else {
                 GGML_ASSERT(!hparams.is_swa_any());

tools/batched-bench/batched-bench.cpp

Lines changed: 12 additions & 35 deletions

@@ -21,6 +21,8 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    params.n_batch = params.n_ctx;
+
    common_init();
 
    int is_pp_shared = params.is_pp_shared;
@@ -61,48 +63,21 @@
 
    llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
 
-    // decode in batches of ctx_params.n_batch tokens
-    auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
-
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-            };
-
-            const int ret = llama_decode(ctx, batch_view);
-            if (ret != 0) {
-                LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
-                return false;
-            }
-
-            llama_synchronize(ctx);
-        }
-
-        return true;
-    };
-
    // warm up
    {
        for (int i = 0; i < 16; ++i) {
            common_batch_add(batch, 0, i, { 0 }, false);
        }
 
-        if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
-            LOG_ERR("%s: llama_decode() failed\n", __func__);
+        if (const auto ret = llama_decode(ctx, batch)) {
+            LOG_ERR("%s: llama_decode() failed, ret = %d\n", __func__, ret);
            return 1;
        }
    }
 
    if (!params.batched_bench_output_jsonl) {
        LOG("\n");
-        LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+        LOG("%s: n_kv_max = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
        LOG("\n");
        LOG("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
        LOG("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
@@ -134,9 +109,11 @@
 
            llama_kv_self_clear(ctx);
 
-            if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
-                LOG_ERR("%s: llama_decode() failed\n", __func__);
-                return 1;
+            if (batch.n_tokens > 0) {
+                if (const auto ret = llama_decode(ctx, batch) != 0) {
+                    LOG_ERR("%s: llama_decode() failed, ret = %d\n", __func__, ret);
+                    return 1;
+                }
            }
 
            if (is_pp_shared) {
@@ -156,8 +133,8 @@
                    common_batch_add(batch, 0, pp + i, { j }, true);
                }
 
-                if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
-                    LOG_ERR("%s: llama_decode() failed\n", __func__);
+                if (const auto ret = llama_decode(ctx, batch) != 0) {
+                    LOG_ERR("%s: llama_decode() failed, ret = %d\n", __func__, ret);
                    return 1;
                }
            }

tools/perplexity/perplexity.cpp

Lines changed: 6 additions & 5 deletions

@@ -856,7 +856,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
     double acc = 0.0f;
 
     const int n_ctx   = llama_n_ctx(ctx);
-    const int n_batch = params.n_batch;
+    const int n_batch = n_ctx;
 
     const int n_vocab = llama_vocab_n_tokens(vocab);
 
@@ -1154,7 +1154,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
     LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__);
 
     const int n_ctx   = llama_n_ctx(ctx);
-    const int n_batch = params.n_batch;
+    const int n_batch = n_ctx;
 
     const int n_vocab = llama_vocab_n_tokens(vocab);
 
@@ -1508,7 +1508,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
     LOG("\ntask\tacc_norm\n");
 
     const int n_ctx   = llama_n_ctx(ctx);
-    const int n_batch = params.n_batch;
+    const int n_batch = n_ctx;
 
     const int n_vocab = llama_vocab_n_tokens(vocab);
 
@@ -1732,7 +1732,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
         return;
     }
 
-    const int n_batch = params.n_batch;
+    const int n_batch = params.n_ctx;
     const int num_batches = (n_ctx + n_batch - 1)/n_batch;
     const int nv = 2*((n_vocab + 1)/2) + 4;
     const bool add_bos = llama_vocab_get_add_bos(vocab);
@@ -1982,12 +1982,13 @@ int main(int argc, char ** argv) {
     common_init();
 
     const int32_t n_ctx = params.n_ctx;
-
     if (n_ctx <= 0) {
         LOG_ERR("%s: perplexity tool requires '--ctx-size' > 0\n", __func__);
         return 1;
     }
 
+    params.n_batch = n_ctx;
+
     const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
 
     if (ppl) {
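With n_batch pinned to the context size, the chunking arithmetic in kl_divergence() (and the analogous loops in the other scores) degenerates to one batch per context window: num_batches = (n_ctx + n_batch - 1)/n_batch is a ceiling division and evaluates to 1 whenever n_batch == n_ctx. A tiny check of that arithmetic with an illustrative context size:

    #include <cstdio>

    int main() {
        // ceiling division used by the perplexity tool; with n_batch == n_ctx it is always 1
        const int n_ctx   = 512;     // illustrative context size
        const int n_batch = n_ctx;   // what this commit sets
        const int num_batches = (n_ctx + n_batch - 1)/n_batch;

        printf("num_batches = %d\n", num_batches); // prints 1
        return 0;
    }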
