Skip to content

Commit c1e7f48

Browse files
threadpool: add persistent threadpool for llama-bench
1 parent e771674 commit c1e7f48

File tree

1 file changed

+27
-10
lines changed

1 file changed

+27
-10
lines changed

examples/llama-bench/llama-bench.cpp

Lines changed: 27 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1210,8 +1210,7 @@ struct sql_printer : public printer {
12101210
}
12111211
};
12121212

1213-
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
1214-
llama_set_n_threads(ctx, n_threads, n_threads);
1213+
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch) {
12151214

12161215
const llama_model * model = llama_get_model(ctx);
12171216
const int32_t n_vocab = llama_n_vocab(model);
@@ -1233,9 +1232,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
12331232
llama_synchronize(ctx);
12341233
}
12351234

1236-
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
1237-
llama_set_n_threads(ctx, n_threads, n_threads);
1238-
1235+
static void test_gen(llama_context * ctx, int n_gen, int n_past) {
12391236
const llama_model * model = llama_get_model(ctx);
12401237
const int32_t n_vocab = llama_n_vocab(model);
12411238

@@ -1332,13 +1329,31 @@ int main(int argc, char ** argv) {
13321329

13331330
llama_kv_cache_clear(ctx);
13341331

1332+
struct ggml_threadpool_params tpp;
1333+
tpp.n_threads = t.n_threads;
1334+
1335+
// TODO: expose these via cli opts
1336+
tpp.mask_specified = false;
1337+
tpp.strict_cpu = false;
1338+
tpp.prio = 1;
1339+
tpp.poll = false;
1340+
1341+
struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp);
1342+
if (!threadpool) {
1343+
LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
1344+
exit(1);
1345+
}
1346+
1347+
llama_set_n_threads(ctx, t.n_threads, t.n_threads);
1348+
llama_attach_threadpool(ctx, threadpool);
1349+
13351350
// warmup run
13361351
if (t.n_prompt > 0) {
1337-
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
1338-
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
1352+
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch);
1353+
test_prompt(ctx, t.n_prompt, 0, t.n_batch);
13391354
}
13401355
if (t.n_gen > 0) {
1341-
test_gen(ctx, 1, 0, t.n_threads);
1356+
test_gen(ctx, 1, 0);
13421357
}
13431358

13441359
for (int i = 0; i < params.reps; i++) {
@@ -1347,10 +1362,10 @@ int main(int argc, char ** argv) {
13471362
uint64_t t_start = get_time_ns();
13481363

13491364
if (t.n_prompt > 0) {
1350-
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
1365+
test_prompt(ctx, t.n_prompt, 0, t.n_batch);
13511366
}
13521367
if (t.n_gen > 0) {
1353-
test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
1368+
test_gen(ctx, t.n_gen, t.n_prompt);
13541369
}
13551370

13561371
uint64_t t_ns = get_time_ns() - t_start;
@@ -1362,6 +1377,8 @@ int main(int argc, char ** argv) {
13621377
llama_print_timings(ctx);
13631378

13641379
llama_free(ctx);
1380+
1381+
ggml_release_threadpool(threadpool);
13651382
}
13661383

13671384
llama_free_model(lmodel);

0 commit comments

Comments
 (0)