Skip to content

Commit b264edd

Browse files
committed
llama : fix Mamba pooled embeddings with multiple sequences
Until the pooled embeddings are refactored to allow splitting across ubatches for causal embeddings, recurrent models can only process a single sequence per ubatch when calculating pooled embeddings.
1 parent 652e9b0 commit b264edd

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

src/llama.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15154,6 +15154,8 @@ static int llama_decode_internal(
1515415154
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
1515515155
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
1515615156

15157+
lctx.embd_seq.clear();
15158+
1515715159
// count outputs
1515815160
if (batch_all.logits && !embd_pooled) {
1515915161
for (uint32_t i = 0; i < n_tokens_all; ++i) {
@@ -15177,8 +15179,19 @@ static int llama_decode_internal(
1517715179
};
1517815180

1517915181
while (lctx.sbatch.n_tokens > 0) {
15180-
// For now, only use equal splits for recurrent model architectures
15181-
llama_ubatch ubatch = kv_self.recurrent ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
15182+
llama_ubatch ubatch;
15183+
if (kv_self.recurrent) {
15184+
if (embd_pooled) {
15185+
// Pooled embeddings cannot be split across ubatches (yet)
15186+
ubatch = lctx.sbatch.split_seq(n_ubatch);
15187+
} else {
15188+
// recurrent model architectures are easier to implement
15189+
// with equal-length sequences
15190+
ubatch = lctx.sbatch.split_equal(n_ubatch);
15191+
}
15192+
} else {
15193+
ubatch = lctx.sbatch.split_simple(n_ubatch);
15194+
}
1518215195
const uint32_t n_tokens = ubatch.n_tokens;
1518315196

1518415197
// count the outputs in this u_batch
@@ -15316,9 +15329,8 @@ static int llama_decode_internal(
1531615329
case LLAMA_POOLING_TYPE_CLS:
1531715330
case LLAMA_POOLING_TYPE_LAST:
1531815331
{
15319-
// extract sequence embeddings
15332+
// extract sequence embeddings (cleared before processing each batch)
1532015333
auto & embd_seq_out = lctx.embd_seq;
15321-
embd_seq_out.clear();
1532215334

1532315335
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
1532415336
const llama_seq_id seq_id = ubatch.seq_id[s][0];

0 commit comments

Comments (0)