
Commit b226c5b

refactor llama_batch_get_one

1 parent 1788077

File tree

2 files changed: +80, -70 lines

include/llama.h

Lines changed: 7 additions & 13 deletions
@@ -232,8 +232,11 @@ extern "C" {
     // - token  : the token ids of the input (used when embd is NULL)
     // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
     // - pos    : the positions of the respective token in the sequence
+    //            (if set to NULL, the token position will be tracked automatically by llama_decode)
     // - seq_id : the sequence to which the respective token belongs
+    //            (if set to NULL, the sequence ID will be assumed to be 0)
     // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
+    //            (if set to NULL, only the logits for the last token will be returned)
     //
     typedef struct llama_batch {
         int32_t n_tokens;

@@ -244,15 +247,6 @@ extern "C" {
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
         int8_t       *  logits; // TODO: rename this to "output"
-
-        // NOTE: helpers for smooth API transition - can be deprecated in the future
-        // for future-proof code, use the above fields instead and ignore everything below
-        //
-        // pos[i] = all_pos_0 + i*all_pos_1
-        //
-        llama_pos    all_pos_0;  // used if pos == NULL
-        llama_pos    all_pos_1;  // used if pos == NULL
-        llama_seq_id all_seq_id; // used if seq_id == NULL
     } llama_batch;
 
     enum llama_model_kv_override_type {
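
With the all_pos_0/all_pos_1/all_seq_id helpers removed from llama_batch, code that previously left pos NULL and relied on the implicit pos[i] = all_pos_0 + i*all_pos_1 rule must either fill pos explicitly or leave it NULL and let llama_decode track positions automatically (the behavior documented above). A minimal migration sketch, assuming a batch built with llama_batch_init and a hypothetical starting position p0 (neither is part of this commit):

    // fill consecutive positions starting at p0, i.e. the common case that
    // the old all_pos_0/all_pos_1 convention expressed with all_pos_1 == 1
    std::vector<llama_pos> pos(batch.n_tokens);
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        pos[i] = p0 + i;
    }
    batch.pos = pos.data();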
@@ -775,15 +769,15 @@ extern "C" {
     // Decoding
     //
 
-    // Return batch for single sequence of tokens starting at pos_0
+    // Return batch for single sequence of tokens
+    // The sequence ID will be fixed to 0
+    // The position of the tokens will be tracked automatically by llama_decode
     //
     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
     //
     LLAMA_API struct llama_batch llama_batch_get_one(
                   llama_token * tokens,
-                      int32_t   n_tokens,
-                    llama_pos   pos_0,
-                 llama_seq_id   seq_id);
+                      int32_t   n_tokens);
 
     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
     // Each token can be assigned up to n_seq_max sequence ids
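
The simplified helper drops the pos_0 and seq_id parameters entirely. A hedged before/after sketch of a typical call site (ctx and the tokenized prompt are assumed to exist and are not part of this commit):

    std::vector<llama_token> prompt = /* tokenized input */;

    // before: starting position and sequence ID were passed explicitly
    //   llama_batch batch = llama_batch_get_one(prompt.data(), (int32_t) prompt.size(), 0, 0);

    // after: pos and seq_id are left NULL in the returned batch
    // and are resolved inside llama_decode
    llama_batch batch = llama_batch_get_one(prompt.data(), (int32_t) prompt.size());
    if (llama_decode(ctx, batch) != 0) {
        // handle decode failure
    }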

src/llama.cpp

Lines changed: 73 additions & 57 deletions
@@ -2941,9 +2941,6 @@ struct llama_sbatch_seq {
     llama_seq_id * seq_id;
     size_t offset;
     size_t length;
-
-    // helper for smoother batch API transition -- can be deprecated in the future
-    llama_seq_id all_seq_id; // used if seq_id == NULL
 };
 
 // sequence-length-aware batch splitting

@@ -3038,30 +3035,18 @@ struct llama_sbatch {
         } else {
             ubatch.embd = nullptr;
         }
-        // from here on, the else branches are deprecated;
-        // they are helpers for smoother batch API transition
-        if (batch->pos) {
-            if (ubatch.equal_seqs) {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
-                }
-            } else {
-                // simple split
-                ubatch.pos = batch->pos + seq.offset;
-            }
-        } else {
+        if (ubatch.equal_seqs) {
             for (size_t i = 0; i < length; ++i) {
-                llama_pos bi = ids[seq.offset + i];
-                ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
+                ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
             }
+        } else {
+            // simple split
+            ubatch.pos = batch->pos + seq.offset;
         }
         if (ubatch.equal_seqs) {
             ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
             if (seq.seq_id) {
                 ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
-            } else {
-                GGML_ASSERT(seq.n_seq_id == 1);
-                ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
             }
         } else {
             // simple split

@@ -3074,10 +3059,6 @@ struct llama_sbatch {
             }
             if (batch->seq_id) {
                 ubatch.seq_id = batch->seq_id + seq.offset;
-            } else {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
-                }
             }
         }
         if (logits_all) {

@@ -3196,7 +3177,6 @@ struct llama_sbatch {
             s.seq_id = nullptr;
             s.offset = 0;
             s.length = n_tokens;
-            s.all_seq_id = batch.all_seq_id;
             return;
         }
         std::sort(ids.begin(), ids.end(),

@@ -3219,7 +3199,7 @@ struct llama_sbatch {
                 if (batch.pos) {
                     return batch.pos[a] < batch.pos[b];
                 }
-                // no pos, sort by id (assuming batch.all_pos_1 is positive)
+                // no pos, sort by id
                 return a < b;
             }
             // shared prompts go first
@@ -3229,30 +3209,25 @@ struct llama_sbatch {
         // init seq
         llama_sbatch_seq * last_seq = nullptr;
 
-        if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
-            for (size_t i = 0; i < n_tokens; ++i) {
-                const size_t bi = ids[i];
-                const int32_t n_seqs = batch.n_seq_id[bi];
-                llama_seq_id * seq_ids = batch.seq_id[bi];
-                if (last_seq != nullptr) {
-                    bool same = n_seqs == last_seq->n_seq_id;
-                    for (int32_t j = 0; same && j < n_seqs; ++j) {
-                        if (seq_ids[j] != last_seq->seq_id[j]) {
-                            same = false;
-                        }
-                    }
-                    if (same) {
-                        last_seq->length += 1;
-                        continue;
-                    }
-                }
-                llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
-                seq.push_back(new_seq);
-                last_seq = &seq.back();
-            }
-        } else {
-            llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
+        for (size_t i = 0; i < n_tokens; ++i) {
+            const size_t bi = ids[i];
+            const int32_t n_seqs = batch.n_seq_id[bi];
+            llama_seq_id * seq_ids = batch.seq_id[bi];
+            if (last_seq != nullptr) {
+                bool same = n_seqs == last_seq->n_seq_id;
+                for (int32_t j = 0; same && j < n_seqs; ++j) {
+                    if (seq_ids[j] != last_seq->seq_id[j]) {
+                        same = false;
+                    }
+                }
+                if (same) {
+                    last_seq->length += 1;
+                    continue;
+                }
+            }
+            llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
             seq.push_back(new_seq);
+            last_seq = &seq.back();
         }
         // keep shared prompts first at the end, then sort by length descending.
         std::sort(seq.begin(), seq.end(),
@@ -21069,9 +21044,7 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
 
 struct llama_batch llama_batch_get_one(
              llama_token * tokens,
-                 int32_t   n_tokens,
-               llama_pos   pos_0,
-            llama_seq_id   seq_id) {
+                 int32_t   n_tokens) {
     return {
         /*n_tokens       =*/ n_tokens,
         /*tokens         =*/ tokens,

@@ -21080,9 +21053,6 @@ struct llama_batch llama_batch_get_one(
         /*n_seq_id       =*/ nullptr,
         /*seq_id         =*/ nullptr,
         /*logits         =*/ nullptr,
-        /*all_pos_0      =*/ pos_0,
-        /*all_pos_1      =*/ 1,
-        /*all_seq_id     =*/ seq_id,
     };
 }
 

@@ -21095,9 +21065,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
         /*n_seq_id       =*/ nullptr,
         /*seq_id         =*/ nullptr,
         /*logits         =*/ nullptr,
-        /*all_pos_0      =*/ 0,
-        /*all_pos_1      =*/ 0,
-        /*all_seq_id     =*/ 0,
     };
 
     if (embd) {

@@ -21133,10 +21100,58 @@ void llama_batch_free(struct llama_batch batch) {
     if (batch.logits)   free(batch.logits);
 }
 
+// temporarily allocate memory for the input batch if needed
+struct llama_batch_allocr {
+    static const llama_seq_id default_seq_id = 0;
+    std::array<llama_seq_id, 1> seq_id_0 = {default_seq_id};
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id *> seq_id;
+    std::vector<int8_t>         logits;
+    // fill in any fields left NULL by llama_batch_get_one
+    struct llama_batch get_fulfilled_batch(struct llama_context * ctx, struct llama_batch in_batch) {
+        struct llama_batch batch = in_batch;
+        if (!batch.pos) {
+            // determine the last position of the default sequence in the KV cache
+            llama_pos last_pos = -1;
+            for (const auto & cell : ctx->kv_self.cells) {
+                if (cell.seq_id.find(default_seq_id) != cell.seq_id.end()) {
+                    last_pos = std::max(last_pos, cell.pos);
+                }
+            }
+            // assign consecutive positions continuing from the cache
+            pos.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                pos[i] = last_pos + 1 + i;
+            }
+            batch.pos = pos.data();
+        }
+        if (!batch.n_seq_id) {
+            n_seq_id.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                n_seq_id[i] = seq_id_0.size();
+            }
+            batch.n_seq_id = n_seq_id.data();
+        }
+        if (!batch.seq_id) {
+            // every token belongs to the default sequence
+            seq_id.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                seq_id[i] = seq_id_0.data();
+            }
+            batch.seq_id = seq_id.data();
+        }
+        if (!batch.logits) {
+            // output logits for the last token only
+            logits.resize(batch.n_tokens);
+            logits[logits.size() - 1] = true;
+            batch.logits = logits.data();
+        }
+        return batch;
+    }
+};
+
 int32_t llama_encode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    const int ret = llama_encode_internal(*ctx, batch);
+    llama_batch_allocr batch_allocr;
+    const int ret = llama_encode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch));
     if (ret < 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }

@@ -21147,7 +21162,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    const int ret = llama_decode_internal(*ctx, batch);
+    llama_batch_allocr batch_allocr;
+    const int ret = llama_decode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch));
     if (ret < 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
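
To make the defaulting concrete, here is a worked example of what get_fulfilled_batch fills in (the values follow from the logic above; this is an illustration, not output from the commit). For a 3-token batch returned by llama_batch_get_one, with sequence 0 already occupying positions 0..4 in the KV cache:

    // batch.pos      -> { 5, 6, 7 }   continues after the last cached position of seq 0
    // batch.n_seq_id -> { 1, 1, 1 }   one sequence ID per token
    // batch.seq_id   -> each entry points at the shared one-element array { 0 } (seq_id_0)
    // batch.logits   -> { 0, 0, 1 }   only the last token outputs logits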
