@@ -2941,9 +2941,6 @@ struct llama_sbatch_seq {
     llama_seq_id * seq_id;
     size_t offset;
     size_t length;
-
-    // helper for smoother batch API transition -- can be deprecated in the future
-    llama_seq_id all_seq_id; // used if seq_id == NULL
 };
 
 // sequence-length-aware batch splitting
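With `all_seq_id` gone, `llama_sbatch_seq` becomes a plain four-field aggregate. A sketch of how the struct reads after this hunk (the `n_seq_id` member is not visible in the hunk context; the field comments are my reading of the surrounding code, not text from the patch):

```cpp
struct llama_sbatch_seq {
    int32_t        n_seq_id; // number of sequence ids this span belongs to
    llama_seq_id * seq_id;   // points into the batch's seq_id data (nullptr only in the simple-split case)
    size_t         offset;   // start of the span in the sorted token ids
    size_t         length;   // number of consecutive tokens sharing the same seq_id set
};

// matches the aggregate initializer used later in this diff, in field order:
//     llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
```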
@@ -3038,30 +3035,18 @@ struct llama_sbatch {
         } else {
             ubatch.embd = nullptr;
         }
-        // from here on, the else branches are deprecated;
-        // they are helpers for smoother batch API transition
-        if (batch->pos) {
-            if (ubatch.equal_seqs) {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
-                }
-            } else {
-                // simple split
-                ubatch.pos = batch->pos + seq.offset;
-            }
-        } else {
+        if (ubatch.equal_seqs) {
             for (size_t i = 0; i < length; ++i) {
-                llama_pos bi = ids[seq.offset + i];
-                ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
+                ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
             }
+        } else {
+            // simple split
+            ubatch.pos = batch->pos + seq.offset;
         }
         if (ubatch.equal_seqs) {
             ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
             if (seq.seq_id) {
                 ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
-            } else {
-                GGML_ASSERT(seq.n_seq_id == 1);
-                ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
             }
         } else {
             // simple split
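`batch->pos` is now assumed to be non-null here (the `llama_batch_allocr` added at the bottom of this diff fills it in when the caller leaves it unset). A small self-contained toy, not llama.cpp code, showing what the two remaining branches do: the `equal_seqs` path gathers positions through the sorted token ids, while the simple split just aliases a contiguous slice of the caller's array.

```cpp
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> pos = {10, 11, 12, 13};   // stand-in for batch->pos, now always provided
    std::vector<int> ids = {2, 0, 3, 1};       // token order after sorting by seq/pos

    // equal_seqs path: copy through the indirection, e.g. for a span at offset 1, length 2
    std::vector<int> gathered;
    for (size_t i = 0; i < 2; ++i) {
        gathered.push_back(pos[ids[1 + i]]);   // -> pos[0], pos[3]
    }

    // simple-split path: no copy, just point into the original array at the offset
    const int * view = pos.data() + 1;         // -> {11, 12, 13}

    printf("gathered: %d %d, view[0]: %d\n", gathered[0], gathered[1], view[0]);
    return 0;
}
```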
@@ -3074,10 +3059,6 @@ struct llama_sbatch {
             }
             if (batch->seq_id) {
                 ubatch.seq_id = batch->seq_id + seq.offset;
-            } else {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
-                }
             }
         }
         if (logits_all) {
@@ -3196,7 +3177,6 @@ struct llama_sbatch {
             s.seq_id = nullptr;
             s.offset = 0;
             s.length = n_tokens;
-            s.all_seq_id = batch.all_seq_id;
             return;
         }
         std::sort(ids.begin(), ids.end(),
@@ -3219,7 +3199,7 @@ struct llama_sbatch {
                     if (batch.pos) {
                         return batch.pos[a] < batch.pos[b];
                     }
-                    // no pos, sort by id (assuming batch.all_pos_1 is positive)
+                    // no pos, sort by id
                     return a < b;
                 }
                 // shared prompts go first
@@ -3229,30 +3209,25 @@ struct llama_sbatch {
         // init seq
         llama_sbatch_seq * last_seq = nullptr;
 
-        if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
-            for (size_t i = 0; i < n_tokens; ++i) {
-                const size_t bi = ids[i];
-                const int32_t n_seqs = batch.n_seq_id[bi];
-                llama_seq_id * seq_ids = batch.seq_id[bi];
-                if (last_seq != nullptr) {
-                    bool same = n_seqs == last_seq->n_seq_id;
-                    for (int32_t j = 0; same && j < n_seqs; ++j) {
-                        if (seq_ids[j] != last_seq->seq_id[j]) {
-                            same = false;
-                        }
-                    }
-                    if (same) {
-                        last_seq->length += 1;
-                        continue;
+        for (size_t i = 0; i < n_tokens; ++i) {
+            const size_t bi = ids[i];
+            const int32_t n_seqs = batch.n_seq_id[bi];
+            llama_seq_id * seq_ids = batch.seq_id[bi];
+            if (last_seq != nullptr) {
+                bool same = n_seqs == last_seq->n_seq_id;
+                for (int32_t j = 0; same && j < n_seqs; ++j) {
+                    if (seq_ids[j] != last_seq->seq_id[j]) {
+                        same = false;
                     }
                 }
-                llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
-                seq.push_back(new_seq);
-                last_seq = &seq.back();
+                if (same) {
+                    last_seq->length += 1;
+                    continue;
+                }
             }
-        } else {
-            llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
+            llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
             seq.push_back(new_seq);
+            last_seq = &seq.back();
         }
         // keep shared prompts first at the end, then sort by length descending.
         std::sort(seq.begin(), seq.end(),
@@ -21069,9 +21044,7 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
 
 struct llama_batch llama_batch_get_one(
         llama_token * tokens,
-        int32_t n_tokens,
-        llama_pos pos_0,
-        llama_seq_id seq_id) {
+        int32_t n_tokens) {
     return {
         /*n_tokens =*/ n_tokens,
         /*tokens =*/ tokens,
@@ -21080,9 +21053,6 @@ struct llama_batch llama_batch_get_one(
         /*n_seq_id =*/ nullptr,
         /*seq_id =*/ nullptr,
         /*logits =*/ nullptr,
-        /*all_pos_0 =*/ pos_0,
-        /*all_pos_1 =*/ 1,
-        /*all_seq_id =*/ seq_id,
     };
 }
 
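Caller-side, the change means `pos_0` and `seq_id` simply disappear from the call. A migration sketch, assuming the usual context/model setup (`ctx`, `tokens` and `n_tokens` are placeholders, not names defined in this diff):

```cpp
#include "llama.h"

// before this PR: llama_batch_get_one(tokens, n_tokens, /*pos_0=*/0, /*seq_id=*/0)
// after: positions, seq ids and logits flags are defaulted inside llama_decode/llama_encode
static int32_t decode_prompt(llama_context * ctx, llama_token * tokens, int32_t n_tokens) {
    llama_batch batch = llama_batch_get_one(tokens, n_tokens);
    return llama_decode(ctx, batch); // 0 on success
}
```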
@@ -21095,9 +21065,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
         /*n_seq_id =*/ nullptr,
         /*seq_id =*/ nullptr,
         /*logits =*/ nullptr,
-        /*all_pos_0 =*/ 0,
-        /*all_pos_1 =*/ 0,
-        /*all_seq_id =*/ 0,
     };
 
     if (embd) {
@@ -21133,10 +21100,60 @@ void llama_batch_free(struct llama_batch batch) {
     if (batch.logits) free(batch.logits);
 }
 
+// temporarily allocate memory for the input batch if needed
+struct llama_batch_allocr {
+    static constexpr llama_seq_id default_seq_id = 0;
+    std::array<llama_seq_id, 1> seq_id_0 = {default_seq_id};
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id *> seq_id;
+    std::vector<int8_t>         logits;
+    // fill in any fields left unset by the caller (e.g. a batch from llama_batch_get_one)
+    struct llama_batch get_fulfilled_batch(struct llama_context * ctx, struct llama_batch in_batch) {
+        struct llama_batch batch = in_batch;
+        if (!batch.pos) {
+            // determine the last position in the KV cache for the default sequence
+            llama_pos last_pos = -1; // stays -1 if the cache holds nothing for this sequence
+            for (const auto & cell : ctx->kv_self.cells) {
+                if (cell.seq_id.find(default_seq_id) != cell.seq_id.end()) {
+                    last_pos = std::max(last_pos, cell.pos);
+                }
+            }
+            pos.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                pos[i] = last_pos + 1 + i;
+            }
+            batch.pos = pos.data();
+        }
+        if (!batch.n_seq_id) {
+            n_seq_id.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                n_seq_id[i] = seq_id_0.size();
+            }
+            batch.n_seq_id = n_seq_id.data();
+        }
+        if (!batch.seq_id) {
+            seq_id.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                seq_id[i] = seq_id_0.data();
+            }
+            batch.seq_id = seq_id.data();
+        }
+        if (!batch.logits) {
+            // by default, only the last token needs logits
+            logits.resize(batch.n_tokens, false);
+            logits[logits.size() - 1] = true;
+            batch.logits = logits.data();
+        }
+        return batch;
+    }
+};
+
 int32_t llama_encode(
         struct llama_context * ctx,
         struct llama_batch batch) {
-    const int ret = llama_encode_internal(*ctx, batch);
+    llama_batch_allocr batch_allocr;
+    const int ret = llama_encode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch));
     if (ret < 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }
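One detail worth calling out: `batch_allocr` is declared in the caller's scope (here, inside `llama_encode`) because the returned `llama_batch` only borrows pointers into the vectors the helper owns, so the helper has to outlive the internal call. A minimal generic sketch of that lifetime pattern, not llama.cpp code:

```cpp
#include <cstdio>
#include <vector>

struct view { const int * pos; int n; };

struct filler {
    std::vector<int> storage;                  // owned backing memory
    view fill(view in, int start_pos) {
        if (!in.pos) {
            storage.resize(in.n);
            for (int i = 0; i < in.n; i++) {
                storage[i] = start_pos + i;    // consecutive default positions
            }
            in.pos = storage.data();           // borrowed pointer, valid while `filler` lives
        }
        return in;
    }
};

int main() {
    filler f;                                  // must stay in scope while the view is used
    view v = f.fill({nullptr, 3}, 5);
    printf("%d %d %d\n", v.pos[0], v.pos[1], v.pos[2]);  // prints: 5 6 7
    return 0;
}
```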
@@ -21147,7 +21164,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         struct llama_context * ctx,
         struct llama_batch batch) {
-    const int ret = llama_decode_internal(*ctx, batch);
+    llama_batch_allocr batch_allocr;
+    const int ret = llama_decode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch));
     if (ret < 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }