From f91707bbe1bec01d742c427cde805eda80088f66 Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Thu, 30 Nov 2023 14:36:02 -0500 Subject: [PATCH 1/3] llama : sanity checks for access to logits --- llama.cpp | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/llama.cpp b/llama.cpp index cb544228b9f02..5b9387fc6b5a5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1468,6 +1468,10 @@ struct llama_context { // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; +#ifndef NDEBUG + // guard against access to unset logits + std::vector logits_valid; +#endif bool logits_all = false; // input embedding (1-dimensional array: [n_embd]) @@ -5609,6 +5613,12 @@ static int llama_decode_internal( { auto & logits_out = lctx.logits; +#ifndef NDEBUG + auto & logits_valid = lctx.logits_valid; + logits_valid.clear(); + logits_valid.resize(n_vocab); +#endif + if (batch.logits) { logits_out.resize(n_vocab * n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { @@ -5616,13 +5626,22 @@ static int llama_decode_internal( continue; } memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab); +#ifndef NDEBUG + logits_valid[i] = true; +#endif } } else if (lctx.logits_all) { logits_out.resize(n_vocab * n_tokens); memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); +#ifndef NDEBUG + std::fill(logits_valid.begin(), logits_valid.end(), true); +#endif } else { logits_out.resize(n_vocab); memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab); +#ifndef NDEBUG + logits_valid[n_tokens - 1] = true; +#endif } } @@ -9465,6 +9484,7 @@ float * llama_get_logits(struct llama_context * ctx) { } float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { + assert(ctx->logits_valid.at(i)); return ctx->logits.data() + i*ctx->model.hparams.n_vocab; } From 5284e72aa615724adada7a31d2cbb5b8f8a4b78e Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Thu, 30 Nov 2023 17:36:21 -0500 Subject: [PATCH 2/3] n_vocab -> n_tokens Co-authored-by: Georgi Gerganov --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 5b9387fc6b5a5..0572e35c1a2a1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5616,7 +5616,7 @@ static int llama_decode_internal( #ifndef NDEBUG auto & logits_valid = lctx.logits_valid; logits_valid.clear(); - logits_valid.resize(n_vocab); + logits_valid.resize(n_tokens); #endif if (batch.logits) { From 9f8e60f0ee7a15b90b2f4ebee10e4af0774317e1 Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Thu, 30 Nov 2023 22:35:41 -0500 Subject: [PATCH 3/3] zero logits vector before writing new data --- llama.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llama.cpp b/llama.cpp index 0572e35c1a2a1..feaf9cd1f5d42 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5617,6 +5617,8 @@ static int llama_decode_internal( auto & logits_valid = lctx.logits_valid; logits_valid.clear(); logits_valid.resize(n_tokens); + + logits_out.clear(); #endif if (batch.logits) {