Commit 9c21995

wip

1 parent: a56f0a9

1 file changed: src/llama.cpp (8 additions, 8 deletions)
@@ -7773,6 +7773,7 @@ static int llama_decode_impl(
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
+
     const uint32_t n_tokens_all = batch.n_tokens;
 
     const auto & model = lctx.model;
@@ -7800,9 +7801,6 @@ static int llama_decode_impl(
     }
     lctx.n_queued_tokens += n_tokens_all;
 
-    auto & kv_self = lctx.kv_self;
-    llama_kv_slot_restorer kv_slot_restorer(kv_self);
-
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = vocab.n_tokens();
 
@@ -7828,16 +7826,19 @@
         n_outputs = 1;
     }
 
-    lctx.sbatch.from_batch(batch, n_embd,
-            /* simple_split */ !kv_self.recurrent,
-            /* logits_all   */ n_outputs == n_tokens_all);
-
     // reserve output buffer
     if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
         return -2;
     };
 
+    auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer(kv_self);
+
+    lctx.sbatch.from_batch(batch, n_embd,
+            /* simple_split */ !kv_self.recurrent,
+            /* logits_all   */ n_outputs == n_tokens_all);
+
     while (lctx.sbatch.n_tokens > 0) {
         llama_ubatch ubatch;
         if (kv_self.recurrent) {
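Note: the hunk above moves the kv_self reference, the llama_kv_slot_restorer, and the sbatch split below the output-buffer reservation, so the early "return -2" path now runs before any KV-cache bookkeeping is set up. A minimal sketch of that ordering, using simplified stand-in types (kv_cache, its head field, and decode_impl below are hypothetical, not the real llama.cpp internals):

// Sketch only: simplified stand-ins, not the real llama.cpp types.
#include <cstdint>
#include <cstdio>

struct kv_cache {
    uint32_t head      = 0;     // hypothetical: next free cache slot
    bool     recurrent = false;
};

// Hypothetical restorer mirroring the role of llama_kv_slot_restorer:
// remember the cache state at construction so a failed decode can roll back.
struct kv_slot_restorer {
    kv_cache & cache;
    uint32_t   saved_head;
    explicit kv_slot_restorer(kv_cache & c) : cache(c), saved_head(c.head) {}
    void restore() { cache.head = saved_head; }
};

static int decode_impl(kv_cache & kv, uint32_t n_outputs, uint32_t capacity) {
    // 1) validate and reserve the output buffer first ...
    if (n_outputs > capacity) {
        return -2;               // early exit: no restorer exists yet
    }
    // 2) ... then set up the cache bookkeeping, as in the hunk above
    kv_slot_restorer restorer(kv);
    kv.head += n_outputs;        // stand-in for cache mutation during decode
    const bool failed = false;   // placeholder for a per-ubatch failure
    if (failed) {
        restorer.restore();      // roll the cache back on failure
        return -1;
    }
    return 0;
}

int main() {
    kv_cache kv;
    printf("ret=%d head=%u\n", decode_impl(kv, 4, 8), kv.head);  // ret=0 head=4
    return 0;
}

With the previous ordering, the restorer was constructed even when the reservation failed; after this change the failure path exits before the cache state is touched.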
@@ -8635,7 +8636,6 @@ struct llama_context * llama_init_from_model(
     cparams.rope_freq_base  = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
-    // this is necessary due to kv_self.n being padded later during inference
     cparams.n_ctx           = GGML_PAD(cparams.n_ctx, ctx->get_ctx_padding(cparams));
 
     // with causal attention, the batch size is limited by the context size
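Note: the removed line in this hunk is only a comment; the GGML_PAD call is unchanged and still rounds cparams.n_ctx up to a multiple of the context padding. A small stand-alone illustration of that round-up, using a local pad_up helper in place of the macro (the padding value 32 is an arbitrary example, not what ctx->get_ctx_padding() returns):

#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Local stand-in for GGML_PAD: round x up to the next multiple of n.
static uint32_t pad_up(uint32_t x, uint32_t n) {
    return (x + n - 1) / n * n;
}

int main() {
    const uint32_t padding = 32;  // hypothetical cache padding
    for (uint32_t n_ctx : {100u, 512u, 513u}) {
        // 100 -> 128, 512 -> 512, 513 -> 544
        printf("n_ctx %4u -> %4u\n", n_ctx, pad_up(n_ctx, padding));
    }
    return 0;
}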
