Commit aca7722

revert commit "context : remove logits_all flag (ggml-org#13284)"
1 parent 2b50607 commit aca7722

File tree

7 files changed: +34 / -9 lines


common/arg.cpp

Lines changed: 7 additions & 0 deletions
@@ -2099,6 +2099,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.cache_type_v = kv_cache_type_from_str(value);
         }
     ).set_env("LLAMA_ARG_CACHE_TYPE_V"));
+    add_opt(common_arg(
+        {"--perplexity", "--all-logits"},
+        string_format("return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false"),
+        [](common_params & params) {
+            params.logits_all = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
     add_opt(common_arg(
         {"--hellaswag"},
         "compute HellaSwag score over random tasks from datafile supplied with -f",

common/common.cpp

Lines changed: 1 addition & 0 deletions
@@ -1103,6 +1103,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.n_threads         = params.cpuparams.n_threads;
     cparams.n_threads_batch   = params.cpuparams_batch.n_threads == -1 ?
                                 params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
+    cparams.logits_all        = params.logits_all;
     cparams.embeddings        = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;
     cparams.rope_freq_base    = params.rope_freq_base;

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -320,6 +320,7 @@ struct common_params {
     bool ctx_shift = true; // context shift on inifinite text generation
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
+    bool logits_all = false; // return logits for all tokens in the batch
     bool use_mmap = true; // use mmap for faster loads
     bool use_mlock = false; // use mlock to keep model in memory
     bool verbose_prompt = false; // print prompt tokens before generation

include/llama.h

Lines changed: 8 additions & 6 deletions
@@ -369,17 +369,19 @@ extern "C" {
         enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
         enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
 
-        // Abort callback
-        // if it returns true, execution of llama_decode() will be aborted
-        // currently works only with CPU execution
-        ggml_abort_callback abort_callback;
-        void * abort_callback_data;
-
         // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
+        // TODO: move at the end of the struct
+        bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
         bool no_perf;     // whether to measure performance timings
+
+        // Abort callback
+        // if it returns true, execution of llama_decode() will be aborted
+        // currently works only with CPU execution
+        ggml_abort_callback abort_callback;
+        void * abort_callback_data;
     };
 
     // model quantization parameters
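For orientation only: a minimal, hypothetical sketch (not part of this commit) of how a caller might set the restored field through the public llama.h API. The model path is a placeholder, and the flag stays deprecated in favour of marking individual tokens via llama_batch.logits.

#include "llama.h"

int main() {
    llama_backend_init();

    // "model.gguf" is a placeholder path for this sketch
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == nullptr) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.logits_all = true; // DEPRECATED: prefer setting llama_batch.logits per token

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (ctx == nullptr) {
        llama_model_free(model);
        return 1;
    }

    // ... build a batch and call llama_decode(); with logits_all set, logits for
    // every token in the batch can be read back via llama_get_logits_ith(ctx, i)

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}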

src/llama-context.cpp

Lines changed: 6 additions & 3 deletions
@@ -116,6 +116,8 @@ llama_context::llama_context(
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
+    logits_all = params.logits_all;
+
     if (!hparams.vocab_only) {
         // GPU backends
         for (auto * dev : model.devices) {
@@ -902,7 +904,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs_all += batch.logits[i] != 0;
         }
-    } else if (embd_pooled) {
+    } else if (logits_all || embd_pooled) {
         n_outputs_all = n_tokens_all;
     } else {
         // keep last output only
@@ -1865,12 +1867,13 @@ llama_context_params llama_context_default_params() {
         /*.cb_eval_user_data =*/ nullptr,
         /*.type_k =*/ GGML_TYPE_F16,
         /*.type_v =*/ GGML_TYPE_F16,
-        /*.abort_callback =*/ nullptr,
-        /*.abort_callback_data =*/ nullptr,
+        /*.logits_all =*/ false,
         /*.embeddings =*/ false,
         /*.offload_kqv =*/ true,
         /*.flash_attn =*/ false,
         /*.no_perf =*/ true,
+        /*.abort_callback =*/ nullptr,
+        /*.abort_callback_data =*/ nullptr,
     };
 
     return result;
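The decode() hunk above restores logits_all as one of the conditions that decide how many tokens of a batch produce output. A simplified sketch of that selection, for illustration only (the helper name and its standalone form are not from llama.cpp):

#include <cstdint>

// Illustrative helper mirroring the restored branch in llama_context::decode():
// how many tokens of the batch get logits computed for them.
static int64_t count_outputs(const int8_t * batch_logits, uint32_t n_tokens_all,
                             bool logits_all, bool embd_pooled) {
    int64_t n_outputs_all = 0;
    if (batch_logits) {
        // the caller marked specific tokens via llama_batch.logits
        for (uint32_t i = 0; i < n_tokens_all; ++i) {
            n_outputs_all += batch_logits[i] != 0;
        }
    } else if (logits_all || embd_pooled) {
        // the reverted context flag (or pooled embeddings) forces output for every token
        n_outputs_all = n_tokens_all;
    } else {
        // default: keep last output only
        n_outputs_all = 1;
    }
    return n_outputs_all;
}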

src/llama-context.h

Lines changed: 3 additions & 0 deletions
@@ -187,6 +187,9 @@ struct llama_context {
 
     std::unique_ptr<llama_memory_i> memory;
 
+    // TODO: remove
+    bool logits_all = false;
+
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t logits_size = 0; // capacity (of floats) for logits
     float * logits = nullptr;

tools/main/main.cpp

Lines changed: 8 additions & 0 deletions
@@ -100,6 +100,14 @@ int main(int argc, char ** argv) {
     console::init(params.simple_io, params.use_color);
     atexit([]() { console::cleanup(); });
 
+    if (params.logits_all) {
+        LOG_ERR("************\n");
+        LOG_ERR("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
+        LOG_ERR("************\n\n");
+
+        return 0;
+    }
+
     if (params.embedding) {
         LOG_ERR("************\n");
         LOG_ERR("%s: please use the 'embedding' tool for embedding calculations\n", __func__);
