
Commit 5e8f776

llama : move llama_batch backward-compat function to common

1 parent 1788077 commit 5e8f776

File tree

10 files changed: +124 −98 lines changed

common/common.cpp

Lines changed: 77 additions & 0 deletions

@@ -1856,6 +1856,83 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
     return result;
 }
 
+//
+// Compatibility with old API
+//
+
+struct llama_batch llama_batch_get_one(
+             llama_token * tokens,
+                   int32_t n_tokens,
+                 llama_pos pos_0,
+              llama_seq_id seq_id,
+                      bool logits_all) {
+    // because the old API does not call llama_batch_free,
+    // we assume that batches generated by llama_batch_get_one are singletons
+    static std::vector<llama_pos>      pos;
+    static std::vector<int32_t>        n_seq_id;
+    static std::array<llama_seq_id, 1> seq_id_0;
+    static std::vector<llama_seq_id *> seq_ids;
+    static std::vector<int8_t>         logits;
+    pos     .resize(n_tokens);
+    n_seq_id.resize(n_tokens);
+    seq_ids .resize(n_tokens + 1);
+    logits  .resize(n_tokens);
+    seq_id_0[0] = seq_id;
+    seq_ids [n_tokens] = nullptr;
+    llama_batch batch = {
+        /*n_tokens =*/ 0,
+        /*tokens   =*/ tokens,
+        /*embd     =*/ nullptr,
+        /*pos      =*/ pos.data(),
+        /*n_seq_id =*/ n_seq_id.data(),
+        /*seq_id   =*/ seq_ids.data(),
+        /*logits   =*/ logits.data(),
+    };
+    for (int i = 0; i < n_tokens; i++) {
+        // seq_id[i] must point at valid storage before llama_batch_add writes into it
+        batch.seq_id[i] = seq_id_0.data();
+        llama_batch_add(batch, tokens[i], pos_0 + i, { seq_id },
+                        logits_all || i == n_tokens - 1);
+    }
+    return batch;
+}
+
+struct llama_batch llama_batch_get_one(
+                   float * embd,
+                   int32_t n_tokens,
+                 llama_pos pos_0,
+              llama_seq_id seq_id,
+                      bool logits_all) {
+    // because the old API does not call llama_batch_free,
+    // we assume that batches generated by llama_batch_get_one are singletons
+    static std::vector<llama_pos>      pos;
+    static std::vector<int32_t>        n_seq_id;
+    static std::array<llama_seq_id, 1> seq_id_0;
+    static std::vector<llama_seq_id *> seq_ids;
+    static std::vector<int8_t>         logits;
+    pos     .resize(n_tokens);
+    n_seq_id.resize(n_tokens);
+    seq_ids .resize(n_tokens + 1);
+    logits  .resize(n_tokens);
+    seq_id_0[0] = seq_id;
+    seq_ids [n_tokens] = nullptr;
+    llama_batch batch = {
+        /*n_tokens =*/ n_tokens,
+        /*tokens   =*/ nullptr,
+        /*embd     =*/ embd,
+        /*pos      =*/ pos.data(),
+        /*n_seq_id =*/ n_seq_id.data(),
+        /*seq_id   =*/ seq_ids.data(),
+        /*logits   =*/ logits.data(),
+    };
+    for (int i = 0; i < n_tokens; i++) {
+        batch.pos     [i] = pos_0 + i;
+        batch.n_seq_id[i] = 1;
+        batch.seq_id  [i] = seq_id_0.data();
+        batch.logits  [i] = logits_all || i == n_tokens - 1;
+    }
+    return batch;
+}
+
 //
 // YAML utils
 //
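The relocated helper keeps the old call shape, with an optional logits_all argument on top. A minimal usage sketch, assuming an initialized llama_context (the eval_prompt wrapper is illustrative, not part of this commit):

```cpp
// Sketch: evaluating a tokenized prompt with the compat helper from common.
#include "common.h"

#include <vector>

static bool eval_prompt(llama_context * ctx, std::vector<llama_token> & tokens) {
    // logits_all defaults to false: only the last token's logits are computed
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size(),
                                            /*pos_0 =*/ 0, /*seq_id =*/ 0);
    // no llama_batch_free() here: the helper reuses static storage,
    // hence the singleton assumption in its implementation
    return llama_decode(ctx, batch) == 0;
}
```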

common/common.h

Lines changed: 20 additions & 0 deletions

@@ -542,6 +542,26 @@ static const char * const LLM_KV_SPLIT_NO = "split.no";
 static const char * const LLM_KV_SPLIT_COUNT         = "split.count";
 static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
 
+//
+// Compatibility with old API
+//
+
+// Return batch for single sequence of tokens starting at pos_0
+struct llama_batch llama_batch_get_one(
+             llama_token * tokens,
+                   int32_t n_tokens,
+                 llama_pos pos_0,
+              llama_seq_id seq_id,
+                      bool logits_all = false);
+
+// Return batch for single sequence of embeddings starting at pos_0
+struct llama_batch llama_batch_get_one(
+                   float * embd,
+                   int32_t n_tokens,
+                 llama_pos pos_0,
+              llama_seq_id seq_id,
+                      bool logits_all = false);
+
 //
 // YAML utils
 //

examples/batched-bench/batched-bench.cpp

Lines changed: 0 additions & 1 deletion

@@ -74,7 +74,6 @@ int main(int argc, char ** argv) {
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
                 batch.logits   + i,
-                0, 0, 0, // unused
             };
 
             const int ret = llama_decode(ctx, batch_view);

examples/imatrix/imatrix.cpp

Lines changed: 1 addition & 2 deletions

@@ -508,8 +508,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
             tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
         }
 
-        // TODO: use batch.logits to save computations instead of relying on logits_all == true
-        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0, true))) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return false;
         }

examples/llava/llava.cpp

Lines changed: 3 additions & 1 deletion

@@ -2,6 +2,7 @@
 #include "llava.h"
 
 #include "llama.h"
+#include "common.h"
 
 #include <algorithm>
 #include <cerrno>
@@ -409,7 +410,8 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
+        float * embd = image_embed->embed+i*n_embd;
+        llama_batch batch = llama_batch_get_one(embd, n_eval, *n_past, 0);
         if (llama_decode(ctx_llama, batch)) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return false;
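The embd overload covers chunked decoding the same way. A sketch under the same assumptions (eval_embd and its parameters are illustrative, not from this commit):

```cpp
// Sketch: feeding n_tokens embedding rows in n_batch-sized chunks,
// mirroring llava_eval_image_embed after this change.
#include "common.h"

#include <algorithm>

static bool eval_embd(llama_context * ctx, float * embd, int n_tokens, int n_embd,
                      int n_batch, int * n_past) {
    for (int i = 0; i < n_tokens; i += n_batch) {
        const int n_eval = std::min(n_tokens - i, n_batch);
        // positions n_past .. n_past + n_eval - 1, all on sequence 0
        llama_batch batch = llama_batch_get_one(embd + (size_t) i * n_embd, n_eval, *n_past, 0);
        if (llama_decode(ctx, batch) != 0) {
            return false;
        }
        *n_past += n_eval;
    }
    return true;
}
```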

examples/parallel/parallel.cpp

Lines changed: 0 additions & 1 deletion

@@ -308,7 +308,6 @@ int main(int argc, char ** argv) {
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
                 batch.logits   + i,
-                0, 0, 0, // unused
             };
 
             const int ret = llama_decode(ctx, batch_view);

examples/perplexity/perplexity.cpp

Lines changed: 2 additions & 5 deletions

@@ -410,8 +410,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         const int batch_size = std::min(end - batch_start, n_batch);
 
         //LOG_DBG("    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
-        // TODO: use llama_batch.logits instead of relying on logits_all == true
-        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0, true))) {
             //LOG_ERR("%s : failed to eval\n", __func__);
             return {tokens, -1, logit_history, prob_history};
         }
@@ -699,7 +698,6 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
                 batch.logits   + i,
-                0, 0, 0, // unused
             };
 
             const int ret = llama_decode(ctx, batch_view);
@@ -1790,8 +1788,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
             tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
         }
 
-        // TODO: use llama_batch.logits instead of relying on logits_all == true
-        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0, true))) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return;
         }

examples/server/server.cpp

Lines changed: 0 additions & 1 deletion

@@ -2283,7 +2283,6 @@ struct server_context {
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
                 batch.logits   + i,
-                0, 0, 0, // unused
             };
 
             const int ret = llama_decode(ctx, batch_view);

include/llama.h

Lines changed: 0 additions & 19 deletions

@@ -244,15 +244,6 @@ extern "C" {
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
        int8_t       *  logits; // TODO: rename this to "output"
-
-        // NOTE: helpers for smooth API transition - can be deprecated in the future
-        //       for future-proof code, use the above fields instead and ignore everything below
-        //
-        // pos[i] = all_pos_0 + i*all_pos_1
-        //
-        llama_pos    all_pos_0;  // used if pos == NULL
-        llama_pos    all_pos_1;  // used if pos == NULL
-        llama_seq_id all_seq_id; // used if seq_id == NULL
     } llama_batch;
 
     enum llama_model_kv_override_type {
@@ -775,16 +766,6 @@ extern "C" {
     // Decoding
     //
 
-    // Return batch for single sequence of tokens starting at pos_0
-    //
-    // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
-    //
-    LLAMA_API struct llama_batch llama_batch_get_one(
-                  llama_token * tokens,
-                      int32_t   n_tokens,
-                    llama_pos   pos_0,
-                 llama_seq_id   seq_id);
-
     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
     // Each token can be assigned up to n_seq_max sequence ids
     // The batch has to be freed with llama_batch_free()
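With the compat fields gone, the removed NOTE's advice applies: future-proof code builds batches explicitly via llama_batch_init and releases them with llama_batch_free. A sketch of the explicit equivalent of the old implicit-position behavior (make_batch is illustrative; llama_batch_add is the existing helper from common.h):

```cpp
// Sketch: explicit batch construction replacing the removed
// all_pos_0 / all_pos_1 / all_seq_id fields: pos[i] = pos_0 + i,
// a single sequence id, logits requested for the last token only.
#include "common.h" // llama_batch_add

#include <vector>

static llama_batch make_batch(const std::vector<llama_token> & tokens,
                              llama_pos pos_0, llama_seq_id seq_id) {
    llama_batch batch = llama_batch_init((int32_t) tokens.size(), /*embd =*/ 0, /*n_seq_max =*/ 1);
    for (size_t i = 0; i < tokens.size(); ++i) {
        llama_batch_add(batch, tokens[i], pos_0 + (llama_pos) i, { seq_id },
                        /*logits =*/ i == tokens.size() - 1);
    }
    return batch; // caller must release it with llama_batch_free()
}
```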

src/llama.cpp

Lines changed: 21 additions & 68 deletions

@@ -2941,9 +2941,6 @@ struct llama_sbatch_seq {
     llama_seq_id * seq_id;
     size_t offset;
     size_t length;
-
-    // helper for smoother batch API transition -- can be deprecated in the future
-    llama_seq_id all_seq_id; // used if seq_id == NULL
 };
 
 // sequence-length-aware batch splitting
@@ -3038,30 +3035,18 @@ struct llama_sbatch {
         } else {
             ubatch.embd = nullptr;
         }
-        // from here on, the else branches are deprecated;
-        // they are helpers for smoother batch API transition
-        if (batch->pos) {
-            if (ubatch.equal_seqs) {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
-                }
-            } else {
-                // simple split
-                ubatch.pos = batch->pos + seq.offset;
-            }
-        } else {
+        if (ubatch.equal_seqs) {
             for (size_t i = 0; i < length; ++i) {
-                llama_pos bi = ids[seq.offset + i];
-                ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
+                ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
             }
+        } else {
+            // simple split
+            ubatch.pos = batch->pos + seq.offset;
         }
         if (ubatch.equal_seqs) {
             ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
             if (seq.seq_id) {
                 ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
-            } else {
-                GGML_ASSERT(seq.n_seq_id == 1);
-                ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
             }
         } else {
             // simple split
@@ -3074,10 +3059,6 @@ struct llama_sbatch {
             }
             if (batch->seq_id) {
                 ubatch.seq_id = batch->seq_id + seq.offset;
-            } else {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
-                }
             }
         }
         if (logits_all) {
@@ -3196,7 +3177,6 @@ struct llama_sbatch {
             s.seq_id = nullptr;
             s.offset = 0;
             s.length = n_tokens;
-            s.all_seq_id = batch.all_seq_id;
             return;
         }
         std::sort(ids.begin(), ids.end(),
@@ -3219,7 +3199,7 @@ struct llama_sbatch {
                 if (batch.pos) {
                     return batch.pos[a] < batch.pos[b];
                 }
-                // no pos, sort by id (assuming batch.all_pos_1 is positive)
+                // no pos, sort by id
                 return a < b;
             }
             // shared prompts go first
@@ -3229,30 +3209,25 @@ struct llama_sbatch {
         // init seq
         llama_sbatch_seq * last_seq = nullptr;
 
-        if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
-            for (size_t i = 0; i < n_tokens; ++i) {
-                const size_t bi = ids[i];
-                const int32_t n_seqs = batch.n_seq_id[bi];
-                llama_seq_id * seq_ids = batch.seq_id[bi];
-                if (last_seq != nullptr) {
-                    bool same = n_seqs == last_seq->n_seq_id;
-                    for (int32_t j = 0; same && j < n_seqs; ++j) {
-                        if (seq_ids[j] != last_seq->seq_id[j]) {
-                            same = false;
-                        }
-                    }
-                    if (same) {
-                        last_seq->length += 1;
-                        continue;
+        for (size_t i = 0; i < n_tokens; ++i) {
+            const size_t bi = ids[i];
+            const int32_t n_seqs = batch.n_seq_id[bi];
+            llama_seq_id * seq_ids = batch.seq_id[bi];
+            if (last_seq != nullptr) {
+                bool same = n_seqs == last_seq->n_seq_id;
+                for (int32_t j = 0; same && j < n_seqs; ++j) {
+                    if (seq_ids[j] != last_seq->seq_id[j]) {
+                        same = false;
                     }
                 }
-                llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
-                seq.push_back(new_seq);
-                last_seq = &seq.back();
+                if (same) {
+                    last_seq->length += 1;
+                    continue;
+                }
             }
-        } else {
-            llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
+            llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
             seq.push_back(new_seq);
+            last_seq = &seq.back();
         }
         // keep shared prompts first at the end, then sort by length descending.
         std::sort(seq.begin(), seq.end(),
@@ -21067,25 +21042,6 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
     ctx->cparams.causal_attn = causal_attn;
 }
 
-struct llama_batch llama_batch_get_one(
-             llama_token * tokens,
-                   int32_t n_tokens,
-                 llama_pos pos_0,
-              llama_seq_id seq_id) {
-    return {
-        /*n_tokens   =*/ n_tokens,
-        /*tokens     =*/ tokens,
-        /*embd       =*/ nullptr,
-        /*pos        =*/ nullptr,
-        /*n_seq_id   =*/ nullptr,
-        /*seq_id     =*/ nullptr,
-        /*logits     =*/ nullptr,
-        /*all_pos_0  =*/ pos_0,
-        /*all_pos_1  =*/ 1,
-        /*all_seq_id =*/ seq_id,
-    };
-}
-
 struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
     llama_batch batch = {
         /*n_tokens =*/ 0,
@@ -21095,9 +21051,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
         /*n_seq_id   =*/ nullptr,
        /*seq_id     =*/ nullptr,
         /*logits     =*/ nullptr,
-        /*all_pos_0  =*/ 0,
-        /*all_pos_1  =*/ 0,
-        /*all_seq_id =*/ 0,
     };
 
     if (embd) {
