From 4c6ae70cadc0452fb2863ac90a97aa72dee13d72 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 26 May 2025 18:45:30 +0300 Subject: [PATCH 1/2] llama : validate seq id batch input ggml-ci --- src/llama-context.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ad77cae20eb50..6d022b20e2e21 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -693,12 +693,18 @@ int llama_context::encode(llama_batch & inp_batch) { GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + // TODO: move the validation to the llama_batch_allocr if (batch.token) { for (int32_t i = 0; i < n_tokens; ++i) { if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); return -1; } + + if (batch.seq_id && batch.seq_id[i][0] > LLAMA_MAX_PARALLEL_SEQUENCES) { + LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES); + throw -1; + } } } @@ -887,11 +893,17 @@ int llama_context::decode(llama_batch & inp_batch) { GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + // TODO: move the validation to the llama_batch_allocr if (batch.token) { for (int64_t i = 0; i < n_tokens_all; ++i) { if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]); - throw std::runtime_error("invalid token"); + return -1; + } + + if (batch.seq_id && batch.seq_id[i][0] > LLAMA_MAX_PARALLEL_SEQUENCES) { + LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES); + return -1; } } } From 6995c7e3ba4b60cd39bf75b8b35cae57dd3b2ae8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 26 May 2025 19:11:25 +0300 Subject: [PATCH 
2/2] cont : fix the fix

ggml-ci
---
 src/llama-context.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 6d022b20e2e21..e153351af3809 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -701,7 +701,7 @@ int llama_context::encode(llama_batch & inp_batch) {
                 return -1;
             }
 
-            if (batch.seq_id && batch.seq_id[i][0] > LLAMA_MAX_PARALLEL_SEQUENCES) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
-                throw -1;
+            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
+                return -1;
             }
@@ -901,8 +901,8 @@ int llama_context::decode(llama_batch & inp_batch) {
                 return -1;
             }
 
-            if (batch.seq_id && batch.seq_id[i][0] > LLAMA_MAX_PARALLEL_SEQUENCES) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
+            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
                 return -1;
             }
         }