From 71fc16bb6cd92b842f1fb7425e3db48e86ef3e07 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 15 Nov 2024 08:20:28 +0200 Subject: [PATCH 01/14] speculative : refactor and add a simpler example ggml-ci --- common/CMakeLists.txt | 2 + common/sampling.cpp | 22 ++ common/sampling.h | 2 + common/speculative.cpp | 159 ++++++++++ common/speculative.h | 33 ++ examples/CMakeLists.txt | 1 + examples/speculative-simple/CMakeLists.txt | 5 + examples/speculative-simple/README.md | 3 + .../speculative-simple/speculative-simple.cpp | 285 ++++++++++++++++++ examples/speculative/speculative.cpp | 2 +- 10 files changed, 513 insertions(+), 1 deletion(-) create mode 100644 common/speculative.cpp create mode 100644 common/speculative.h create mode 100644 examples/speculative-simple/CMakeLists.txt create mode 100644 examples/speculative-simple/README.md create mode 100644 examples/speculative-simple/speculative-simple.cpp diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 5ab1ffa1922aa..62a8a7db5652f 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -66,6 +66,8 @@ add_library(${TARGET} STATIC ngram-cache.h sampling.cpp sampling.h + speculative.cpp + speculative.h ) if (BUILD_SHARED_LIBS) diff --git a/common/sampling.cpp b/common/sampling.cpp index 7922fde47d369..fe1ef5bf9ae05 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -320,6 +320,28 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co return cur_p.data[cur_p.selected].id; } +std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const std::vector & draft, bool grammar_first) { + GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); + + std::vector result; + result.reserve(idxs.size()); + + size_t i = 0; + for (; i < draft.size(); i++) { + const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); + + if (draft[i] != id) { + break; + } + + result.push_back(id); + } + + result.push_back(common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first)); + + return result; +} + uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { return llama_sampler_get_seed(gsmpl->chain); } diff --git a/common/sampling.h b/common/sampling.h index d37f25ad37c4a..9e61690aa6fcb 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -60,6 +60,8 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam // llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); +std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const std::vector & draft, bool grammar_first = false); + uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); // helpers diff --git a/common/speculative.cpp b/common/speculative.cpp new file mode 100644 index 0000000000000..d16cc3c8eb204 --- /dev/null +++ b/common/speculative.cpp @@ -0,0 +1,159 @@ +#include "speculative.h" + +#include "log.h" +#include "common.h" +#include "sampling.h" + +#include + +struct seq_draft { +}; + +struct common_speculative { + struct common_speculative_params params; + + llama_batch batch_dft; + + struct common_sampler * smpl; + + std::vector i_batch_tgt; + + std::vector tokens; +}; + +struct common_speculative * common_speculative_init(struct common_speculative_params params) { + auto * result = new common_speculative { + /* .params = */ 
params, + /* .batch_dft = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1), + /* .smpl = */ nullptr, + /* .i_batch_tgt = */ {}, + /* .tokens = */ {}, + }; + + // TODO: optimize or pass from outside? +#if 0 + { + common_sampler_params sparams; + sparams.no_perf = false; + + sparams.top_k = 40; + sparams.top_p = 0.9; + + sparams.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + COMMON_SAMPLER_TYPE_TOP_P, + COMMON_SAMPLER_TYPE_INFILL, + }; + + result->smpl = common_sampler_init(params.model_dft, sparams); + } +#else + { + common_sampler_params sparams; + sparams.no_perf = false; + + sparams.top_k = 10; + + sparams.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + }; + + result->smpl = common_sampler_init(params.model_dft, sparams); + } +#endif + + result->batch_dft = llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1); + + return result; +} + +void common_speculative_free(struct common_speculative * spec) { + common_sampler_free(spec->smpl); + + llama_batch_free(spec->batch_dft); + + delete spec; +} + +void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens) { + llama_kv_cache_clear(spec->params.ctx_dft); + + // TODO: error handling + llama_decode(spec->params.ctx_dft, llama_batch_get_one(tokens, n_tokens)); +} + +void common_speculative_add_draft( + struct common_speculative * spec, + struct llama_batch & batch_tgt, + llama_token id_last, + int n_past) { + spec->tokens.clear(); + + spec->i_batch_tgt.clear(); + spec->i_batch_tgt.push_back(0); + + common_sampler_reset(spec->smpl); + + common_batch_clear(spec->batch_dft); + common_batch_add (spec->batch_dft, id_last, n_past, { 0 }, true); + + llama_decode(spec->params.ctx_dft, spec->batch_dft); + + // sample n_draft tokens from the draft model + for (int i = 0; i < spec->params.n_draft; ++i) { + common_batch_clear(spec->batch_dft); + + common_sampler_sample(spec->smpl, spec->params.ctx_dft, 0, true); + + const auto * cur_p = common_sampler_get_candidates(spec->smpl); + + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(spec->params.ctx_dft, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < 0.75 && spec->tokens.size() >= 0) { + break; + } + + common_sampler_accept(spec->smpl, id, true); + + spec->tokens.push_back(id); + + // add unique drafted tokens to the target batch + spec->i_batch_tgt.push_back(batch_tgt.n_tokens); + + common_batch_add(batch_tgt, id, n_past + i + 1, { 0 }, true); + + if (batch_tgt.n_tokens > spec->params.n_draft) { + break; + } + + common_batch_add(spec->batch_dft, id, n_past + i + 1, { 0 }, true); + + // evaluate the drafted tokens on the draft model + llama_decode(spec->params.ctx_dft, spec->batch_dft); + } + + // don't waste time on small batches + // TODO: do not evaluate the draft model for tha many rounds + if (batch_tgt.n_tokens < spec->params.n_min) { + batch_tgt.n_tokens = 1; + spec->tokens.resize(0); + spec->i_batch_tgt.resize(1); + } + + // print current draft sequences + LOG_DBG("draft %s\n", string_from(spec->params.ctx_dft, spec->tokens).c_str()); +} + +std::vector common_speculative_sample( + struct common_speculative * spec, + struct common_sampler * smpl, + struct llama_context * ctx_tgt) { + return common_sampler_sample_n(smpl, ctx_tgt, spec->i_batch_tgt, spec->tokens); 
+} diff --git a/common/speculative.h b/common/speculative.h new file mode 100644 index 0000000000000..0952e5e70e409 --- /dev/null +++ b/common/speculative.h @@ -0,0 +1,33 @@ +#pragma once + +#include "llama.h" + +#include + +struct common_speculative; + +struct common_speculative_params { + int n_draft = 16; + int n_min = 5; // do not add drafts smaller than this, TODO: leave this to user? + + struct llama_model * model_dft = nullptr; + + struct llama_context * ctx_dft = nullptr; +}; + +struct common_speculative * common_speculative_init(struct common_speculative_params params); + +void common_speculative_free(struct common_speculative * spec); + +void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens); + +void common_speculative_add_draft( + struct common_speculative * spec, + struct llama_batch & batch_tgt, + llama_token id_last, + int n_past); + +std::vector common_speculative_sample( + struct common_speculative * spec, + struct common_sampler * smpl, + struct llama_context * ctx_tgt); diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index d63a96c1c2547..9bd099d4ef8a5 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -50,5 +50,6 @@ else() add_subdirectory(simple) add_subdirectory(simple-chat) add_subdirectory(speculative) + add_subdirectory(speculative-simple) add_subdirectory(tokenize) endif() diff --git a/examples/speculative-simple/CMakeLists.txt b/examples/speculative-simple/CMakeLists.txt new file mode 100644 index 0000000000000..7a3a141c27994 --- /dev/null +++ b/examples/speculative-simple/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-speculative-simple) +add_executable(${TARGET} speculative-simple.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/speculative-simple/README.md b/examples/speculative-simple/README.md new file mode 100644 index 0000000000000..6f3d6dc1505ad --- /dev/null +++ b/examples/speculative-simple/README.md @@ -0,0 +1,3 @@ +# llama.cpp/examples/speculative-simple + +Demonstration of basic greedy speculative decoding diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp new file mode 100644 index 0000000000000..31a09e61df8fe --- /dev/null +++ b/examples/speculative-simple/speculative-simple.cpp @@ -0,0 +1,285 @@ +#include "arg.h" +#include "common.h" +#include "sampling.h" +#include "speculative.h" +#include "log.h" +#include "llama.h" + +#include +#include +#include +#include +#include + +#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 +#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 + +struct seq_draft { + std::vector i_batch_tgt; + + std::vector tokens; + + struct common_sampler * smpl = nullptr; +}; + +int main(int argc, char ** argv) { + common_params params; + + // needed to get candidate probs even for temp <= 0.0 + params.sparams.n_probs = 128; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) { + return 1; + } + + if (params.n_predict < -1) { + LOG_ERR("%s: --n-predict must be >= -1\n", __func__); + return 1; + } + + common_init(); + + if (params.model_draft.empty()) { + LOG_ERR("%s: --model-draft is required\n", __func__); + return 1; + } + + // init llama.cpp + llama_backend_init(); + llama_numa_init(params.numa); + + llama_model * model_tgt = NULL; + llama_model * model_dft = NULL; + + llama_context * ctx_tgt = NULL; + llama_context * 
ctx_dft = NULL; + + // load the target model + common_init_result llama_init_tgt = common_init_from_params(params); + model_tgt = llama_init_tgt.model; + ctx_tgt = llama_init_tgt.context; + + // load the draft model + params.model = params.model_draft; + params.n_gpu_layers = params.n_gpu_layers_draft; + if (params.draft_cpuparams.n_threads > 0) { + params.cpuparams.n_threads = params.draft_cpuparams.n_threads; + } + + params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads; + common_init_result llama_init_dft = common_init_from_params(params); + model_dft = llama_init_dft.model; + ctx_dft = llama_init_dft.context; + + const bool vocab_type_tgt = llama_vocab_type(model_tgt); + LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt); + + const bool vocab_type_dft = llama_vocab_type(model_dft); + LOG_DBG("vocab_type dft: %d\n", vocab_type_dft); + + if (vocab_type_tgt != vocab_type_dft) { + LOG_ERR("%s: draft model vocab type must match target model to use speculation but ", __func__); + LOG_ERR("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt); + return 1; + } + + if ( + llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) || + llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) || + llama_token_bos(model_tgt) != llama_token_bos(model_dft) || + llama_token_eos(model_tgt) != llama_token_eos(model_dft) + ) { + LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__); + return 1; + } + + { + const int n_vocab_tgt = llama_n_vocab(model_tgt); + const int n_vocab_dft = llama_n_vocab(model_dft); + const int vocab_diff = n_vocab_tgt > n_vocab_dft + ? n_vocab_tgt - n_vocab_dft + : n_vocab_dft - n_vocab_tgt; + + if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { + LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__); + LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", + n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); + return 1; + } + + for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { + const char * token_text_tgt = llama_token_get_text(model_tgt, i); + const char * token_text_dft = llama_token_get_text(model_dft, i); + if (std::strcmp(token_text_tgt, token_text_dft) != 0) { + LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__); + LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i, + common_token_to_piece(ctx_tgt, i).c_str(), + common_token_to_piece(ctx_dft, i).c_str()); + return 1; + } + } + } + + + // Tokenize the prompt + std::vector inp; + inp = common_tokenize(ctx_tgt, params.prompt, true, true); + + const int max_context_size = llama_n_ctx(ctx_tgt); + const int max_tokens_list_size = max_context_size - 4; + + if ((int) inp.size() > max_tokens_list_size) { + LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size); + return 1; + } + + LOG("\n\n"); + + for (auto id : inp) { + LOG("%s", common_token_to_piece(ctx_tgt, id).c_str()); + } + + const int n_input = inp.size(); + + const auto t_enc_start = ggml_time_us(); + + // eval the prompt + llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1)); + + // note: keep the last token separate! 
+ llama_token id_last = inp.back(); + + int n_past = inp.size() - 1; + + // how many tokens to draft each time + int n_draft = params.n_draft; + + int n_predict = 0; + int n_drafted = 0; + int n_accept = 0; + + // used to determine end of generation + bool has_eos = false; + + // target model sampling context + struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams); + + // init the speculator + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft; + params_spec.n_min = 5; + params_spec.model_dft = model_dft; + params_spec.ctx_dft = ctx_dft; + + struct common_speculative * spec = common_speculative_init(params_spec); + + // feed the prompt to the speculator + common_speculative_set_prompt(spec, inp.data(), n_input - 1); + + llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); + + const auto t_enc_end = ggml_time_us(); + + const auto t_dec_start = ggml_time_us(); + + while (true) { + // always have a token to evaluate from before + common_batch_clear(batch_tgt); + common_batch_add (batch_tgt, id_last, n_past, { 0 }, true); + + // optionally, append draft tokens to the target batch + common_speculative_add_draft(spec, batch_tgt, id_last, n_past); + + // evaluate the target model on the drafted tokens + { + //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); + + llama_decode(ctx_tgt, batch_tgt); + } + + // process the full target batch and return the accepted token based on the target sampler + const auto ids = common_speculative_sample(spec, smpl, ctx_tgt); + + n_past += ids.size(); + n_drafted += batch_tgt.n_tokens - 1; + n_accept += ids.size() - 1; + + // process the accepted tokens and update contexts + { + llama_token id; + std::string token_str; + + for (size_t i = 0; i < ids.size(); ++i) { + id = ids[i]; + + ++n_predict; + + if (llama_token_is_eog(model_tgt, id)) { + has_eos = true; + break; + } + + token_str = common_token_to_piece(ctx_tgt, id); + + if (params.use_color && i + 1 < ids.size()) { + LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str()); + } else { + LOG("%s", token_str.c_str()); + } + } + + if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { + break; + } + + LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); + + { + LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); + + llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1); + llama_kv_cache_seq_rm(ctx_dft, 0, n_past, -1); + } + + id_last = id; + } + } + + auto t_dec_end = ggml_time_us(); + + LOG("\n\n"); + + LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); + LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); + + LOG_INF("\n"); + LOG_INF("n_draft = %d\n", n_draft); + LOG_INF("n_predict = %d\n", n_predict); + LOG_INF("n_drafted = %d\n", n_drafted); + LOG_INF("n_accept = %d\n", n_accept); + LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); + + LOG_INF("\n"); + LOG_INF("draft:\n\n"); + + llama_perf_context_print(ctx_dft); + + LOG_INF("\n"); + LOG_INF("target:\n\n"); + common_perf_print(ctx_tgt, smpl); + + common_sampler_free(smpl); + common_speculative_free(spec); + + llama_free(ctx_tgt); + llama_free_model(model_tgt); + + llama_free(ctx_dft); + llama_free_model(model_dft); + + llama_backend_free(); 
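    // rough, illustrative reading of the stats printed above (an estimate, not something this example measures):
    // with a fixed draft size d per round and an overall acceptance fraction a = n_accept / n_drafted,
    // each target-model pass yields about 1 + a*d tokens, so the ideal speedup over plain decoding is
    // roughly (1 + a*d) / (1 + c), where c is the cost of drafting d tokens relative to one target pass.
    // for example, d = 16 and a = 0.5 give about 9 tokens per target pass before accounting for the draft cost.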
+ + LOG("\n\n"); + + return 0; +} diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 6cafd8a837992..207b8ea345fea 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -12,7 +12,7 @@ #include #include -#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100 +#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 struct seq_draft { From fe043ff1ff07fdef1899778e52dafbad26037d38 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 17 Nov 2024 18:55:27 +0200 Subject: [PATCH 02/14] speculative : clean-up and add comments and TODOs [no ci] --- common/sampling.h | 11 ++++ common/speculative.cpp | 7 +-- common/speculative.h | 13 ++++ .../speculative-simple/speculative-simple.cpp | 61 ++++++++++++++----- 4 files changed, 70 insertions(+), 22 deletions(-) diff --git a/common/sampling.h b/common/sampling.h index 9e61690aa6fcb..23cfae1ac3c57 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -60,6 +60,17 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam // llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); +// generalized version of common_sampler_sample +// +// will cross-reference the sampled tokens with a batch of draft tokens +// if the sampler disagrees at some point, we stop and return the sampled tokens up to now +// +// `common_sampler_sample_n(gsmpl, ctx, { idx }, {})` is equivalent to `common_sampler_sample(gsmpl, ctx, idx)` +// +// requires: idxs.size() == draft.size() + 1 +// +// returns at least 1 token, up to idxs.size() +// std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const std::vector & draft, bool grammar_first = false); uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); diff --git a/common/speculative.cpp b/common/speculative.cpp index d16cc3c8eb204..2726760ad5bde 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -4,11 +4,6 @@ #include "common.h" #include "sampling.h" -#include - -struct seq_draft { -}; - struct common_speculative { struct common_speculative_params params; @@ -140,7 +135,7 @@ void common_speculative_add_draft( } // don't waste time on small batches - // TODO: do not evaluate the draft model for tha many rounds + // TODO: do not evaluate the draft model for that many rounds if (batch_tgt.n_tokens < spec->params.n_min) { batch_tgt.n_tokens = 1; spec->tokens.resize(0); diff --git a/common/speculative.h b/common/speculative.h index 0952e5e70e409..a2df2667a4205 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -19,8 +19,21 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa void common_speculative_free(struct common_speculative * spec); +// TODO: remove void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens); +// sample up to n_draft tokens and add them to the batch using the draft model +// +// TODO: change to: +// +// void common_speculative_add_draft( +// struct common_speculative * spec, +// struct llama_batch & batch_tgt, +// llama_token * tokens, +// int32_t n_tokens); +// +// and update the internal logic to compute only the new tokens +// void common_speculative_add_draft( struct common_speculative * spec, struct llama_batch & batch_tgt, diff --git a/examples/speculative-simple/speculative-simple.cpp 
b/examples/speculative-simple/speculative-simple.cpp index 31a09e61df8fe..aeccfd3699c89 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -120,7 +120,6 @@ int main(int argc, char ** argv) { } } - // Tokenize the prompt std::vector inp; inp = common_tokenize(ctx_tgt, params.prompt, true, true); @@ -139,18 +138,6 @@ int main(int argc, char ** argv) { LOG("%s", common_token_to_piece(ctx_tgt, id).c_str()); } - const int n_input = inp.size(); - - const auto t_enc_start = ggml_time_us(); - - // eval the prompt - llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1)); - - // note: keep the last token separate! - llama_token id_last = inp.back(); - - int n_past = inp.size() - 1; - // how many tokens to draft each time int n_draft = params.n_draft; @@ -161,9 +148,25 @@ int main(int argc, char ** argv) { // used to determine end of generation bool has_eos = false; + // ================================================ + // everything until here is standard initialization + // the relevant stuff for speculative decoding starts here + + const int n_input = inp.size(); + + const auto t_enc_start = ggml_time_us(); + // target model sampling context struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams); + // eval the prompt + llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1)); + + // note: keep the last token separate! + llama_token id_last = inp.back(); + + int n_past = inp.size() - 1; + // init the speculator struct common_speculative_params params_spec; params_spec.n_draft = n_draft; @@ -174,6 +177,13 @@ int main(int argc, char ** argv) { struct common_speculative * spec = common_speculative_init(params_spec); // feed the prompt to the speculator + // + // this has to be kept synchronized with the target context + // + // TODO: simplify this by moving the context management logic in the common_speculative instance + // for example, the common_speculative_add_draft can pass the entire context (or part of it) and the + // speculator will automatically compute any new tokens that are not present in its context + // common_speculative_set_prompt(spec, inp.data(), n_input - 1); llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); @@ -188,23 +198,41 @@ int main(int argc, char ** argv) { common_batch_add (batch_tgt, id_last, n_past, { 0 }, true); // optionally, append draft tokens to the target batch + // + // this is the most important part of the speculation. the more probable tokens that are provided here + // the better the performance will be. in theory, this computation can be performed asynchronously and even + // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens + // from a cache or lookup tables. 
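        // as a hypothetical illustration (this helper is an assumption, not an existing llama.cpp API),
        // such a non-LLM drafter could be as simple as a lookup of previously seen continuations, whose
        // result would then be appended to batch_tgt with common_batch_add() just like the model drafts:
        //
        //     #include <unordered_map>
        //
        //     // map from a token to the tokens that most recently followed it in the generated text
        //     using token_cache = std::unordered_map<llama_token, std::vector<llama_token>>;
        //
        //     // return up to n_draft cached tokens that followed id_last the last time it was seen
        //     static std::vector<llama_token> draft_from_cache(const token_cache & cache, llama_token id_last, int n_draft) {
        //         std::vector<llama_token> draft;
        //         const auto it = cache.find(id_last);
        //         if (it != cache.end()) {
        //             const int n = std::min(n_draft, (int) it->second.size());
        //             draft.assign(it->second.begin(), it->second.begin() + n);
        //         }
        //         return draft;
        //     }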
+ // common_speculative_add_draft(spec, batch_tgt, id_last, n_past); - // evaluate the target model on the drafted tokens + // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] { //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); llama_decode(ctx_tgt, batch_tgt); } - // process the full target batch and return the accepted token based on the target sampler + // sample from the full target batch and return the accepted tokens based on the target sampler + // + // for each token to be accepted, the sampler would have to sample that same token + // in such cases, instead of decoding the sampled token as we normally do, we simply continue with the + // available logits from the batch and sample the next token until we run out of logits or the sampler + // disagrees with the draft + // const auto ids = common_speculative_sample(spec, smpl, ctx_tgt); + GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token + n_past += ids.size(); n_drafted += batch_tgt.n_tokens - 1; n_accept += ids.size() - 1; // process the accepted tokens and update contexts + // + // this is the standard token post-processing that we normally do + // in this case, we do it for a group of accepted tokens at once + // { llama_token id; std::string token_str; @@ -232,7 +260,7 @@ int main(int argc, char ** argv) { break; } - LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); + LOG_DBG("accepted %d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, id, token_str.c_str()); { LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); @@ -241,6 +269,7 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_rm(ctx_dft, 0, n_past, -1); } + // remember the last accepted token for the next iteration id_last = id; } } From 0f878a657c5b01c144c769374556c198c7f484d5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 21 Nov 2024 21:27:14 +0200 Subject: [PATCH 03/14] speculative : manage context in common_speculative ggml-ci --- common/common.cpp | 72 +++++++++++- common/common.h | 14 +++ common/sampling.cpp | 22 ++++ common/sampling.h | 2 + common/speculative.cpp | 104 +++++++++++------- common/speculative.h | 27 +---- examples/server/server.cpp | 4 +- examples/server/utils.hpp | 57 ---------- .../speculative-simple/speculative-simple.cpp | 30 ++--- 9 files changed, 188 insertions(+), 144 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index d314523db4c62..43fa8a1ef67be 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -536,12 +536,12 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat [](const unsigned char c) { return !std::isprint(c); }), detokenized.end()); - buf << "\n" << std::to_string(i) - << ":token '" << detokenized << "'" - << ":pos " << std::to_string(batch.pos[i]) - << ":n_seq_id " << std::to_string(batch.n_seq_id[i]) - << ":seq_id " << std::to_string(batch.seq_id[i][0]) - << ":logits " << std::to_string(batch.logits[i]); + buf << "\n" << std::to_string(i) + << ", token '" << detokenized << "'" + << ", pos " << std::to_string(batch.pos[i]) + << ", n_seq_id " << std::to_string(batch.n_seq_id[i]) + << ", seq_id " << std::to_string(batch.seq_id[i][0]) + << ", logits " << std::to_string(batch.logits[i]); } buf << " ]"; @@ -1490,6 +1490,66 @@ void common_batch_add( batch.n_tokens++; } +// +// Token utils +// + +size_t common_lcp(const llama_tokens & a, const llama_tokens & b) { + size_t i; + for (i 
= 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} + + return i; +} + +size_t common_lcs(const llama_tokens & a, const llama_tokens & b) { + // check for empty sequences + if (a.empty() || b.empty()) { + return 0; + } + + // get the lengths of the input sequences + size_t a_len = a.size(); + size_t b_len = b.size(); + + // initialize the maximum length of the longest common subsequence (LCS) + size_t max_length = 0; + + // use two rows instead of a 2D matrix to optimize space + std::vector prev_row(b_len + 1, 0); + std::vector curr_row(b_len + 1, 0); + + // iterate through the elements of a + for (size_t i = 1; i <= a_len; i++) { + // iterate through the elements of b + for (size_t j = 1; j <= b_len; j++) { + // if elements at the current positions match + if (a[i - 1] == b[j - 1]) { + // if it's the first element of either sequences, set LCS length to 1 + if (i == 1 || j == 1) { + curr_row[j] = 1; + } else { + // increment LCS length by 1 compared to the previous element + curr_row[j] = prev_row[j - 1] + 1; + } + + // update max_length if necessary + if (curr_row[j] > max_length) { + max_length = curr_row[j]; + } + } else { + // reset LCS length if elements don't match + curr_row[j] = 0; + } + } + + // update the previous row for the next iteration + prev_row = curr_row; + } + + // return the maximum length of the LCS + return max_length; +} + // // Vocab utils // diff --git a/common/common.h b/common/common.h index 7977cc7a99a78..29d678c7bab8a 100644 --- a/common/common.h +++ b/common/common.h @@ -33,6 +33,8 @@ struct common_lora_adapter_container : common_lora_adapter_info { struct llama_lora_adapter * adapter; }; +using llama_tokens = std::vector; + // build info extern int LLAMA_BUILD_NUMBER; extern char const * LLAMA_COMMIT; @@ -461,7 +463,9 @@ struct llama_model * common_load_model_from_hf(const char * repo, const char * f // clear LoRA adapters from context, then apply new list of adapters void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora_adapters); +// // Batch utils +// void common_batch_clear(struct llama_batch & batch); @@ -472,6 +476,16 @@ void common_batch_add( const std::vector & seq_ids, bool logits); +// +// Token utils +// + +// longest common prefix +size_t common_lcp(const llama_tokens & a, const llama_tokens & b); + +// longet common subsequence +size_t common_lcs(const llama_tokens & a, const llama_tokens & b); + // // Vocab utils // diff --git a/common/sampling.cpp b/common/sampling.cpp index fe1ef5bf9ae05..f90ac8b90ef0e 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -342,6 +342,28 @@ std::vector common_sampler_sample_n(struct common_sampler * gsmpl, return result; } +std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first) { + std::vector idxs; + idxs.reserve(batch.n_tokens); + + std::vector draft; + draft.reserve(batch.n_tokens); + + for (int i = 0; i < batch.n_tokens; i++) { + if (batch.logits[i] == 0) { + continue; + } + + if (idxs.size() > 0) { + GGML_ASSERT(batch.pos[idxs.back()] + 1 == batch.pos[i]); + draft.push_back(batch.token[i]); + } + idxs.push_back(i); + } + + return common_sampler_sample_n(gsmpl, ctx, idxs, draft, grammar_first); +} + uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { return llama_sampler_get_seed(gsmpl->chain); } diff --git a/common/sampling.h b/common/sampling.h index 23cfae1ac3c57..ba496ac278feb 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -73,6 +73,8 @@ 
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co // std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const std::vector & draft, bool grammar_first = false); +std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first = false); + uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); // helpers diff --git a/common/speculative.cpp b/common/speculative.cpp index 2726760ad5bde..6acf84a239693 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -11,9 +11,7 @@ struct common_speculative { struct common_sampler * smpl; - std::vector i_batch_tgt; - - std::vector tokens; + llama_tokens prompt_last; }; struct common_speculative * common_speculative_init(struct common_speculative_params params) { @@ -21,12 +19,10 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa /* .params = */ params, /* .batch_dft = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1), /* .smpl = */ nullptr, - /* .i_batch_tgt = */ {}, - /* .tokens = */ {}, }; // TODO: optimize or pass from outside? -#if 0 +#if 1 { common_sampler_params sparams; sparams.no_perf = false; @@ -70,30 +66,79 @@ void common_speculative_free(struct common_speculative * spec) { delete spec; } -void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens) { - llama_kv_cache_clear(spec->params.ctx_dft); - - // TODO: error handling - llama_decode(spec->params.ctx_dft, llama_batch_get_one(tokens, n_tokens)); -} - void common_speculative_add_draft( struct common_speculative * spec, struct llama_batch & batch_tgt, + const llama_tokens & prompt, llama_token id_last, - int n_past) { - spec->tokens.clear(); + llama_token n_past_tgt) { - spec->i_batch_tgt.clear(); - spec->i_batch_tgt.push_back(0); + int reuse_i = 0; + int reuse_n = 0; - common_sampler_reset(spec->smpl); + const int n_ctx = llama_n_ctx(spec->params.ctx_dft) - spec->params.n_draft; + + const int i_start = std::max(0, (int) prompt.size() - n_ctx); + + for (int i = 0; i < (int) spec->prompt_last.size(); ++i) { + int cur = 0; + while (i_start + cur < (int) prompt.size() && + i + cur < (int) spec->prompt_last.size() && + prompt[i_start + cur] == spec->prompt_last[i + cur]) { + cur++; + } + + if ((cur >= spec->params.n_reuse || prompt.size() <= n_ctx) && cur > reuse_n) { + reuse_i = i; + reuse_n = cur; + } + } + + LOG_DBG("%s: reuse_i = %d, reuse_n = %d\n", __func__, reuse_i, reuse_n); + + if (reuse_n == 0) { + llama_kv_cache_clear(spec->params.ctx_dft); + + spec->prompt_last.clear(); + } else { + llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, 0, reuse_i); + llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, reuse_i + reuse_n, -1); + llama_kv_cache_seq_add(spec->params.ctx_dft, 0, reuse_i, -1, -reuse_i); + + spec->prompt_last.erase(spec->prompt_last.begin(), spec->prompt_last.begin() + reuse_i); + spec->prompt_last.erase(spec->prompt_last.begin() + reuse_n, spec->prompt_last.end()); + } + + common_batch_clear(spec->batch_dft); + + for (int i = i_start + reuse_n; i < (int) prompt.size(); ++i) { + //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt[i]); + common_batch_add(spec->batch_dft, prompt[i], i - i_start, { 0 }, false); + + spec->prompt_last.push_back(prompt[i]); + } + + const llama_pos n_past = prompt.size() - i_start; + + LOG_DBG("%s: 
n_past = %d\n", __func__, n_past); + + if (spec->batch_dft.n_tokens > 0) { + LOG_DBG("%s: draft batch: %s\n", __func__, string_from(spec->params.ctx_dft, spec->batch_dft).c_str()); + + llama_decode(spec->params.ctx_dft, spec->batch_dft); + } common_batch_clear(spec->batch_dft); common_batch_add (spec->batch_dft, id_last, n_past, { 0 }, true); + spec->prompt_last.push_back(id_last); + + LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(spec->params.ctx_dft, spec->prompt_last).c_str()); + llama_decode(spec->params.ctx_dft, spec->batch_dft); + common_sampler_reset(spec->smpl); + // sample n_draft tokens from the draft model for (int i = 0; i < spec->params.n_draft; ++i) { common_batch_clear(spec->batch_dft); @@ -111,18 +156,13 @@ void common_speculative_add_draft( const llama_token id = cur_p->data[0].id; // only collect very high-confidence draft tokens - if (cur_p->data[0].p < 0.75 && spec->tokens.size() >= 0) { + if (cur_p->data[0].p < spec->params.p_min) { break; } common_sampler_accept(spec->smpl, id, true); - spec->tokens.push_back(id); - - // add unique drafted tokens to the target batch - spec->i_batch_tgt.push_back(batch_tgt.n_tokens); - - common_batch_add(batch_tgt, id, n_past + i + 1, { 0 }, true); + common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true); if (batch_tgt.n_tokens > spec->params.n_draft) { break; @@ -132,23 +172,13 @@ void common_speculative_add_draft( // evaluate the drafted tokens on the draft model llama_decode(spec->params.ctx_dft, spec->batch_dft); + + spec->prompt_last.push_back(id); } // don't waste time on small batches // TODO: do not evaluate the draft model for that many rounds if (batch_tgt.n_tokens < spec->params.n_min) { batch_tgt.n_tokens = 1; - spec->tokens.resize(0); - spec->i_batch_tgt.resize(1); } - - // print current draft sequences - LOG_DBG("draft %s\n", string_from(spec->params.ctx_dft, spec->tokens).c_str()); -} - -std::vector common_speculative_sample( - struct common_speculative * spec, - struct common_sampler * smpl, - struct llama_context * ctx_tgt) { - return common_sampler_sample_n(smpl, ctx_tgt, spec->i_batch_tgt, spec->tokens); } diff --git a/common/speculative.h b/common/speculative.h index a2df2667a4205..b3a87e64ce4ab 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -1,14 +1,16 @@ #pragma once #include "llama.h" - -#include +#include "common.h" struct common_speculative; struct common_speculative_params { int n_draft = 16; int n_min = 5; // do not add drafts smaller than this, TODO: leave this to user? 
+ int n_reuse = 256; + + float p_min = 0.9f; struct llama_model * model_dft = nullptr; @@ -19,28 +21,11 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa void common_speculative_free(struct common_speculative * spec); -// TODO: remove -void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens); - // sample up to n_draft tokens and add them to the batch using the draft model // -// TODO: change to: -// -// void common_speculative_add_draft( -// struct common_speculative * spec, -// struct llama_batch & batch_tgt, -// llama_token * tokens, -// int32_t n_tokens); -// -// and update the internal logic to compute only the new tokens -// void common_speculative_add_draft( struct common_speculative * spec, struct llama_batch & batch_tgt, + const llama_tokens & prompt, llama_token id_last, - int n_past); - -std::vector common_speculative_sample( - struct common_speculative * spec, - struct common_sampler * smpl, - struct llama_context * ctx_tgt); + llama_token n_past_tgt); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index b8e003be9730e..b7b2cbe5a3f68 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -743,7 +743,7 @@ struct server_context { } // length of the Longest Common Subsequence between the current slot's prompt and the input prompt - int cur_lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens); + int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens); // fraction of the common subsequence length compared to the current slot's prompt length float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); @@ -1960,7 +1960,7 @@ struct server_context { if (slot.params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt - slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens); + slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens); // reuse chunks from the cached prompt by shifting their KV cache in the new position if (params.n_cache_reuse > 0) { diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index c47ed3e47a76d..1665e9dc37db6 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -24,7 +24,6 @@ #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" using json = nlohmann::ordered_json; -using llama_tokens = std::vector; #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) #define SLT_WRN(slot, fmt, ...) 
LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) @@ -439,62 +438,6 @@ static std::string gen_chatcmplid() { // other common utils // -static size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) { - size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} - - return i; -} - -static size_t longest_common_subsequence(const llama_tokens & a, const llama_tokens & b) { - // check for empty sequences - if (a.empty() || b.empty()) { - return 0; - } - - // get the lengths of the input sequences - size_t a_len = a.size(); - size_t b_len = b.size(); - - // initialize the maximum length of the longest common subsequence (LCS) - size_t max_length = 0; - - // use two rows instead of a 2D matrix to optimize space - std::vector prev_row(b_len + 1, 0); - std::vector curr_row(b_len + 1, 0); - - // iterate through the elements of a - for (size_t i = 1; i <= a_len; i++) { - // iterate through the elements of b - for (size_t j = 1; j <= b_len; j++) { - // if elements at the current positions match - if (a[i - 1] == b[j - 1]) { - // if it's the first element of either sequences, set LCS length to 1 - if (i == 1 || j == 1) { - curr_row[j] = 1; - } else { - // increment LCS length by 1 compared to the previous element - curr_row[j] = prev_row[j - 1] + 1; - } - - // update max_length if necessary - if (curr_row[j] > max_length) { - max_length = curr_row[j]; - } - } else { - // reset LCS length if elements don't match - curr_row[j] = 0; - } - } - - // update the previous row for the next iteration - prev_row = curr_row; - } - - // return the maximum length of the LCS - return max_length; -} - static bool ends_with(const std::string & str, const std::string & suffix) { return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); } diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index aeccfd3699c89..cb6c35ce1107a 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -14,14 +14,6 @@ #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 -struct seq_draft { - std::vector i_batch_tgt; - - std::vector tokens; - - struct common_sampler * smpl = nullptr; -}; - int main(int argc, char ** argv) { common_params params; @@ -165,27 +157,21 @@ int main(int argc, char ** argv) { // note: keep the last token separate! 
llama_token id_last = inp.back(); + auto prompt_dft = std::vector(inp.begin(), inp.end() - 1); + int n_past = inp.size() - 1; // init the speculator struct common_speculative_params params_spec; params_spec.n_draft = n_draft; params_spec.n_min = 5; + params_spec.n_reuse = 256; + params_spec.p_min = 0.9f; params_spec.model_dft = model_dft; params_spec.ctx_dft = ctx_dft; struct common_speculative * spec = common_speculative_init(params_spec); - // feed the prompt to the speculator - // - // this has to be kept synchronized with the target context - // - // TODO: simplify this by moving the context management logic in the common_speculative instance - // for example, the common_speculative_add_draft can pass the entire context (or part of it) and the - // speculator will automatically compute any new tokens that are not present in its context - // - common_speculative_set_prompt(spec, inp.data(), n_input - 1); - llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); const auto t_enc_end = ggml_time_us(); @@ -204,7 +190,7 @@ int main(int argc, char ** argv) { // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens // from a cache or lookup tables. // - common_speculative_add_draft(spec, batch_tgt, id_last, n_past); + common_speculative_add_draft(spec, batch_tgt, prompt_dft, id_last, n_past + 1); // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] { @@ -220,7 +206,7 @@ int main(int argc, char ** argv) { // available logits from the batch and sample the next token until we run out of logits or the sampler // disagrees with the draft // - const auto ids = common_speculative_sample(spec, smpl, ctx_tgt); + const auto ids = common_sampler_sample_n(smpl, ctx_tgt, batch_tgt); GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token @@ -266,9 +252,11 @@ int main(int argc, char ** argv) { LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1); - llama_kv_cache_seq_rm(ctx_dft, 0, n_past, -1); } + prompt_dft.push_back(id_last); + prompt_dft.insert(prompt_dft.end(), ids.begin(), ids.end() - 1); + // remember the last accepted token for the next iteration id_last = id; } From e4c122b93c5d63c732b01e5cbc1e22b2eefeee7c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 22 Nov 2024 11:05:49 +0200 Subject: [PATCH 04/14] speculative : simplify ggml-ci --- common/speculative.cpp | 166 ++++++++++++------ common/speculative.h | 12 +- .../speculative-simple/speculative-simple.cpp | 58 +----- 3 files changed, 126 insertions(+), 110 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 6acf84a239693..810fa93e4e740 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -4,21 +4,31 @@ #include "common.h" #include "sampling.h" +#include + +#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 +#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 + struct common_speculative { struct common_speculative_params params; - llama_batch batch_dft; + llama_batch batch; + struct llama_context * ctx; struct common_sampler * smpl; - llama_tokens prompt_last; + llama_tokens prompt; }; -struct common_speculative * common_speculative_init(struct common_speculative_params params) { +struct common_speculative * common_speculative_init( + struct common_speculative_params params, + struct llama_context * ctx_dft) { auto * result = new common_speculative { - /* .params = */ params, - /* .batch_dft = */ 
llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1), - /* .smpl = */ nullptr, + /* .params = */ params, + /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), + /* .ctx = */ ctx_dft, + /* .smpl = */ nullptr, + /* .prompt = */ {}, }; // TODO: optimize or pass from outside? @@ -36,7 +46,7 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa COMMON_SAMPLER_TYPE_INFILL, }; - result->smpl = common_sampler_init(params.model_dft, sparams); + result->smpl = common_sampler_init(llama_get_model(ctx_dft), sparams); } #else { @@ -49,46 +59,104 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa COMMON_SAMPLER_TYPE_TOP_K, }; - result->smpl = common_sampler_init(params.model_dft, sparams); + result->smpl = common_sampler_init(llama_get_model(ctx_dft), sparams); } #endif - result->batch_dft = llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1); - return result; } void common_speculative_free(struct common_speculative * spec) { common_sampler_free(spec->smpl); - llama_batch_free(spec->batch_dft); + llama_batch_free(spec->batch); delete spec; } +bool common_speculative_are_compatible( + const struct llama_context * ctx_tgt, + const struct llama_context * ctx_dft) { + const struct llama_model * model_tgt = llama_get_model(ctx_tgt); + const struct llama_model * model_dft = llama_get_model(ctx_dft); + + const bool vocab_type_tgt = llama_vocab_type(model_tgt); + LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); + + const bool vocab_type_dft = llama_vocab_type(model_dft); + LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft); + + if (vocab_type_tgt != vocab_type_dft) { + LOG_ERR("%s: draft model vocab type must match target model to use speculation but " + "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt); + return false; + } + + if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) || + llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) || + llama_token_bos(model_tgt) != llama_token_bos(model_dft) || + llama_token_eos(model_tgt) != llama_token_eos(model_dft) + ) { + LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__); + return false; + } + + { + const int n_vocab_tgt = llama_n_vocab(model_tgt); + const int n_vocab_dft = llama_n_vocab(model_dft); + + const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft); + + if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { + LOG_ERR("%s: draft model vocab must closely match target model to use speculation but " + "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", + __func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); + return false; + } + + for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { + const char * token_text_tgt = llama_token_get_text(model_tgt, i); + const char * token_text_dft = llama_token_get_text(model_dft, i); + if (std::strcmp(token_text_tgt, token_text_dft) != 0) { + LOG_ERR("%s: draft model vocab must match target model to use speculation but " + "token %d content differs - target '%s', draft '%s'\n", __func__, i, + common_token_to_piece(ctx_tgt, i).c_str(), + common_token_to_piece(ctx_dft, i).c_str()); + return false; + } + } + } + + return true; +} + void common_speculative_add_draft( struct common_speculative * spec, struct llama_batch & batch_tgt, - const llama_tokens & prompt, + const llama_tokens & prompt_tgt, 
llama_token id_last, llama_token n_past_tgt) { + auto & batch = spec->batch; + auto & ctx = spec->ctx; + auto & smpl = spec->smpl; + auto & prompt = spec->prompt; int reuse_i = 0; int reuse_n = 0; - const int n_ctx = llama_n_ctx(spec->params.ctx_dft) - spec->params.n_draft; + const int n_ctx = llama_n_ctx(ctx) - spec->params.n_draft; - const int i_start = std::max(0, (int) prompt.size() - n_ctx); + const int i_start = std::max(0, (int) prompt_tgt.size() - n_ctx); - for (int i = 0; i < (int) spec->prompt_last.size(); ++i) { + for (int i = 0; i < (int) prompt.size(); ++i) { int cur = 0; - while (i_start + cur < (int) prompt.size() && - i + cur < (int) spec->prompt_last.size() && - prompt[i_start + cur] == spec->prompt_last[i + cur]) { + while (i_start + cur < (int) prompt_tgt.size() && + i + cur < (int) prompt.size() && + prompt_tgt[i_start + cur] == prompt[i + cur]) { cur++; } - if ((cur >= spec->params.n_reuse || prompt.size() <= n_ctx) && cur > reuse_n) { + if ((cur >= spec->params.n_reuse || prompt_tgt.size() <= n_ctx) && cur > reuse_n) { reuse_i = i; reuse_n = cur; } @@ -97,59 +165,59 @@ void common_speculative_add_draft( LOG_DBG("%s: reuse_i = %d, reuse_n = %d\n", __func__, reuse_i, reuse_n); if (reuse_n == 0) { - llama_kv_cache_clear(spec->params.ctx_dft); + llama_kv_cache_clear(ctx); - spec->prompt_last.clear(); + prompt.clear(); } else { - llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, 0, reuse_i); - llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, reuse_i + reuse_n, -1); - llama_kv_cache_seq_add(spec->params.ctx_dft, 0, reuse_i, -1, -reuse_i); + llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i); + llama_kv_cache_seq_rm (ctx, 0, reuse_i + reuse_n, -1); + llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i); - spec->prompt_last.erase(spec->prompt_last.begin(), spec->prompt_last.begin() + reuse_i); - spec->prompt_last.erase(spec->prompt_last.begin() + reuse_n, spec->prompt_last.end()); + prompt.erase(prompt.begin(), prompt.begin() + reuse_i); + prompt.erase(prompt.begin() + reuse_n, prompt.end()); } - common_batch_clear(spec->batch_dft); + common_batch_clear(batch); - for (int i = i_start + reuse_n; i < (int) prompt.size(); ++i) { - //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt[i]); - common_batch_add(spec->batch_dft, prompt[i], i - i_start, { 0 }, false); + for (int i = i_start + reuse_n; i < (int) prompt_tgt.size(); ++i) { + //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); + common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); - spec->prompt_last.push_back(prompt[i]); + prompt.push_back(prompt_tgt[i]); } - const llama_pos n_past = prompt.size() - i_start; + const llama_pos n_past = prompt_tgt.size() - i_start; LOG_DBG("%s: n_past = %d\n", __func__, n_past); - if (spec->batch_dft.n_tokens > 0) { - LOG_DBG("%s: draft batch: %s\n", __func__, string_from(spec->params.ctx_dft, spec->batch_dft).c_str()); + if (batch.n_tokens > 0) { + LOG_DBG("%s: draft batch: %s\n", __func__, string_from(ctx, batch).c_str()); - llama_decode(spec->params.ctx_dft, spec->batch_dft); + llama_decode(ctx, batch); } - common_batch_clear(spec->batch_dft); - common_batch_add (spec->batch_dft, id_last, n_past, { 0 }, true); + common_batch_clear(batch); + common_batch_add (batch, id_last, n_past, { 0 }, true); - spec->prompt_last.push_back(id_last); + prompt.push_back(id_last); - LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(spec->params.ctx_dft, 
spec->prompt_last).c_str()); + LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(ctx, prompt).c_str()); - llama_decode(spec->params.ctx_dft, spec->batch_dft); + llama_decode(ctx, batch); - common_sampler_reset(spec->smpl); + common_sampler_reset(smpl); // sample n_draft tokens from the draft model for (int i = 0; i < spec->params.n_draft; ++i) { - common_batch_clear(spec->batch_dft); + common_batch_clear(batch); - common_sampler_sample(spec->smpl, spec->params.ctx_dft, 0, true); + common_sampler_sample(smpl, ctx, 0, true); - const auto * cur_p = common_sampler_get_candidates(spec->smpl); + const auto * cur_p = common_sampler_get_candidates(smpl); for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(spec->params.ctx_dft, cur_p->data[k].id).c_str()); + k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); } // add drafted token for each sequence @@ -160,7 +228,7 @@ void common_speculative_add_draft( break; } - common_sampler_accept(spec->smpl, id, true); + common_sampler_accept(smpl, id, true); common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true); @@ -168,12 +236,12 @@ void common_speculative_add_draft( break; } - common_batch_add(spec->batch_dft, id, n_past + i + 1, { 0 }, true); + common_batch_add(batch, id, n_past + i + 1, { 0 }, true); // evaluate the drafted tokens on the draft model - llama_decode(spec->params.ctx_dft, spec->batch_dft); + llama_decode(ctx, batch); - spec->prompt_last.push_back(id); + prompt.push_back(id); } // don't waste time on small batches diff --git a/common/speculative.h b/common/speculative.h index b3a87e64ce4ab..b657b62296308 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -11,16 +11,18 @@ struct common_speculative_params { int n_reuse = 256; float p_min = 0.9f; - - struct llama_model * model_dft = nullptr; - - struct llama_context * ctx_dft = nullptr; }; -struct common_speculative * common_speculative_init(struct common_speculative_params params); +struct common_speculative * common_speculative_init( + struct common_speculative_params params, + struct llama_context * ctx_dft); void common_speculative_free(struct common_speculative * spec); +bool common_speculative_are_compatible( + const struct llama_context * ctx_tgt, + const struct llama_context * ctx_dft); + // sample up to n_draft tokens and add them to the batch using the draft model // void common_speculative_add_draft( diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index cb6c35ce1107a..cdfd5b8868591 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -5,21 +5,14 @@ #include "log.h" #include "llama.h" -#include #include #include #include #include -#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 -#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 - int main(int argc, char ** argv) { common_params params; - // needed to get candidate probs even for temp <= 0.0 - params.sparams.n_probs = 128; - if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) { return 1; } @@ -63,55 +56,10 @@ int main(int argc, char ** argv) { model_dft = llama_init_dft.model; ctx_dft = llama_init_dft.context; - const bool vocab_type_tgt = llama_vocab_type(model_tgt); - LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt); - - const bool vocab_type_dft = llama_vocab_type(model_dft); - 
LOG_DBG("vocab_type dft: %d\n", vocab_type_dft); - - if (vocab_type_tgt != vocab_type_dft) { - LOG_ERR("%s: draft model vocab type must match target model to use speculation but ", __func__); - LOG_ERR("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt); - return 1; - } - - if ( - llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) || - llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) || - llama_token_bos(model_tgt) != llama_token_bos(model_dft) || - llama_token_eos(model_tgt) != llama_token_eos(model_dft) - ) { - LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__); + if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) { return 1; } - { - const int n_vocab_tgt = llama_n_vocab(model_tgt); - const int n_vocab_dft = llama_n_vocab(model_dft); - const int vocab_diff = n_vocab_tgt > n_vocab_dft - ? n_vocab_tgt - n_vocab_dft - : n_vocab_dft - n_vocab_tgt; - - if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { - LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__); - LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", - n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); - return 1; - } - - for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { - const char * token_text_tgt = llama_token_get_text(model_tgt, i); - const char * token_text_dft = llama_token_get_text(model_dft, i); - if (std::strcmp(token_text_tgt, token_text_dft) != 0) { - LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__); - LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i, - common_token_to_piece(ctx_tgt, i).c_str(), - common_token_to_piece(ctx_dft, i).c_str()); - return 1; - } - } - } - // Tokenize the prompt std::vector inp; inp = common_tokenize(ctx_tgt, params.prompt, true, true); @@ -167,10 +115,8 @@ int main(int argc, char ** argv) { params_spec.n_min = 5; params_spec.n_reuse = 256; params_spec.p_min = 0.9f; - params_spec.model_dft = model_dft; - params_spec.ctx_dft = ctx_dft; - struct common_speculative * spec = common_speculative_init(params_spec); + struct common_speculative * spec = common_speculative_init(params_spec, ctx_dft); llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); From 0d4d0c15599bfcc837bc078d973317e9e61887e0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 22 Nov 2024 11:31:28 +0200 Subject: [PATCH 05/14] speculative : simplify (cont) ggml-ci --- common/sampling.cpp | 23 ++-------- common/sampling.h | 5 +- common/speculative.cpp | 37 ++++++--------- common/speculative.h | 12 ++--- .../speculative-simple/speculative-simple.cpp | 46 ++++++++++++------- 5 files changed, 56 insertions(+), 67 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index f90ac8b90ef0e..75e2e5d296092 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -320,7 +320,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co return cur_p.data[cur_p.selected].id; } -std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const std::vector & draft, bool grammar_first) { +std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first) { GGML_ASSERT(idxs.size() == 
draft.size() + 1 && "idxs.size() must be draft.size() + 1"); std::vector result; @@ -342,23 +342,10 @@ std::vector common_sampler_sample_n(struct common_sampler * gsmpl, return result; } -std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first) { - std::vector idxs; - idxs.reserve(batch.n_tokens); - - std::vector draft; - draft.reserve(batch.n_tokens); - - for (int i = 0; i < batch.n_tokens; i++) { - if (batch.logits[i] == 0) { - continue; - } - - if (idxs.size() > 0) { - GGML_ASSERT(batch.pos[idxs.back()] + 1 == batch.pos[i]); - draft.push_back(batch.token[i]); - } - idxs.push_back(i); +std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { + std::vector idxs(draft.size() + 1); + for (size_t i = 0; i < idxs.size(); ++i) { + idxs[i] = i; } return common_sampler_sample_n(gsmpl, ctx, idxs, draft, grammar_first); diff --git a/common/sampling.h b/common/sampling.h index ba496ac278feb..f9b193ac8db73 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -71,9 +71,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co // // returns at least 1 token, up to idxs.size() // -std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const std::vector & draft, bool grammar_first = false); +std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first = false); -std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first = false); +// assume idxs == [ 0, 1, 2, ..., draft.size() ] +std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); diff --git a/common/speculative.cpp b/common/speculative.cpp index 810fa93e4e740..eccba93e01d5c 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -10,24 +10,19 @@ #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 struct common_speculative { - struct common_speculative_params params; - - llama_batch batch; - struct llama_context * ctx; struct common_sampler * smpl; + llama_batch batch; llama_tokens prompt; }; struct common_speculative * common_speculative_init( - struct common_speculative_params params, struct llama_context * ctx_dft) { auto * result = new common_speculative { - /* .params = */ params, - /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), /* .ctx = */ ctx_dft, /* .smpl = */ nullptr, + /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), /* .prompt = */ {}, }; @@ -130,12 +125,11 @@ bool common_speculative_are_compatible( return true; } -void common_speculative_add_draft( +llama_tokens common_speculative_gen_draft( struct common_speculative * spec, - struct llama_batch & batch_tgt, + struct common_speculative_params params, const llama_tokens & prompt_tgt, - llama_token id_last, - llama_token n_past_tgt) { + llama_token id_last) { auto & batch = spec->batch; auto & ctx = spec->ctx; auto & smpl = spec->smpl; @@ -144,7 +138,7 @@ void common_speculative_add_draft( int reuse_i = 0; int reuse_n = 0; - const int n_ctx = llama_n_ctx(ctx) - spec->params.n_draft; + const int n_ctx = llama_n_ctx(ctx) - 
params.n_draft; const int i_start = std::max(0, (int) prompt_tgt.size() - n_ctx); @@ -156,7 +150,7 @@ void common_speculative_add_draft( cur++; } - if ((cur >= spec->params.n_reuse || prompt_tgt.size() <= n_ctx) && cur > reuse_n) { + if ((cur >= params.n_reuse || prompt_tgt.size() <= n_ctx) && cur > reuse_n) { reuse_i = i; reuse_n = cur; } @@ -207,8 +201,11 @@ void common_speculative_add_draft( common_sampler_reset(smpl); + llama_tokens result; + result.reserve(params.n_draft); + // sample n_draft tokens from the draft model - for (int i = 0; i < spec->params.n_draft; ++i) { + for (int i = 0; i < params.n_draft; ++i) { common_batch_clear(batch); common_sampler_sample(smpl, ctx, 0, true); @@ -224,15 +221,15 @@ void common_speculative_add_draft( const llama_token id = cur_p->data[0].id; // only collect very high-confidence draft tokens - if (cur_p->data[0].p < spec->params.p_min) { + if (cur_p->data[0].p < params.p_min) { break; } common_sampler_accept(smpl, id, true); - common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true); + result.push_back(id); - if (batch_tgt.n_tokens > spec->params.n_draft) { + if (result.size() >= params.n_draft) { break; } @@ -244,9 +241,5 @@ void common_speculative_add_draft( prompt.push_back(id); } - // don't waste time on small batches - // TODO: do not evaluate the draft model for that many rounds - if (batch_tgt.n_tokens < spec->params.n_min) { - batch_tgt.n_tokens = 1; - } + return result; } diff --git a/common/speculative.h b/common/speculative.h index b657b62296308..9fb669fde3095 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -7,15 +7,12 @@ struct common_speculative; struct common_speculative_params { int n_draft = 16; - int n_min = 5; // do not add drafts smaller than this, TODO: leave this to user? 
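    // (descriptive note, inferred from how these fields are used in speculative.cpp elsewhere
    //  in this patch: n_draft caps how many tokens are drafted per call, n_reuse is the minimum
    //  matching prefix length required before the previous draft prompt and its KV cache are
    //  reused rather than re-evaluated, and p_min is a confidence floor, i.e. drafting stops
    //  once the draft model's top candidate falls below it)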
int n_reuse = 256; float p_min = 0.9f; }; -struct common_speculative * common_speculative_init( - struct common_speculative_params params, - struct llama_context * ctx_dft); +struct common_speculative * common_speculative_init(struct llama_context * ctx_dft); void common_speculative_free(struct common_speculative * spec); @@ -25,9 +22,8 @@ bool common_speculative_are_compatible( // sample up to n_draft tokens and add them to the batch using the draft model // -void common_speculative_add_draft( +llama_tokens common_speculative_gen_draft( struct common_speculative * spec, - struct llama_batch & batch_tgt, + struct common_speculative_params params, const llama_tokens & prompt, - llama_token id_last, - llama_token n_past_tgt); + llama_token id_last); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index cdfd5b8868591..d7e572cf84c2c 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -13,6 +13,9 @@ int main(int argc, char ** argv) { common_params params; + // minimum size of the draft to use + const int n_min = 5; + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) { return 1; } @@ -92,31 +95,29 @@ int main(int argc, char ** argv) { // everything until here is standard initialization // the relevant stuff for speculative decoding starts here - const int n_input = inp.size(); - const auto t_enc_start = ggml_time_us(); // target model sampling context struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams); // eval the prompt - llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1)); + llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); // note: keep the last token separate! llama_token id_last = inp.back(); - auto prompt_dft = std::vector(inp.begin(), inp.end() - 1); + // all tokens currently in the target context + auto prompt_tgt = std::vector(inp.begin(), inp.end() - 1); int n_past = inp.size() - 1; // init the speculator struct common_speculative_params params_spec; params_spec.n_draft = n_draft; - params_spec.n_min = 5; params_spec.n_reuse = 256; params_spec.p_min = 0.9f; - struct common_speculative * spec = common_speculative_init(params_spec, ctx_dft); + struct common_speculative * spec = common_speculative_init(ctx_dft); llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); @@ -125,21 +126,30 @@ int main(int argc, char ** argv) { const auto t_dec_start = ggml_time_us(); while (true) { - // always have a token to evaluate from before - common_batch_clear(batch_tgt); - common_batch_add (batch_tgt, id_last, n_past, { 0 }, true); - - // optionally, append draft tokens to the target batch + // optionally, generate draft tokens that can be appended to the target batch // // this is the most important part of the speculation. the more probable tokens that are provided here // the better the performance will be. in theory, this computation can be performed asynchronously and even // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens // from a cache or lookup tables. 
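        // as a hedged illustration of that last point (not part of this patch, names are
        // hypothetical), a draft source that needs no model at all could simply map the last
        // accepted token to a continuation that was observed before:
        //
        //     static std::unordered_map<llama_token, llama_tokens> draft_cache; // filled elsewhere
        //
        //     llama_tokens draft_from_cache(llama_token id_last) {
        //         const auto it = draft_cache.find(id_last);  // continuation seen after id_last, if any
        //         return it == draft_cache.end() ? llama_tokens{} : it->second; // empty draft = no speculation
        //     }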
// - common_speculative_add_draft(spec, batch_tgt, prompt_dft, id_last, n_past + 1); + llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last); + + // always have a token to evaluate from before - id_last + common_batch_clear(batch_tgt); + common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true); // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] { + // do not waste time on small drafts + if (draft.size() < n_min) { + draft.clear(); + } + + for (size_t i = 0; i < draft.size(); ++i) { + common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); + } + //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); llama_decode(ctx_tgt, batch_tgt); @@ -152,11 +162,11 @@ int main(int argc, char ** argv) { // available logits from the batch and sample the next token until we run out of logits or the sampler // disagrees with the draft // - const auto ids = common_sampler_sample_n(smpl, ctx_tgt, batch_tgt); + const auto ids = common_sampler_sample_n(smpl, ctx_tgt, draft); GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token - n_past += ids.size(); + n_past += ids.size() - 1; n_drafted += batch_tgt.n_tokens - 1; n_accept += ids.size() - 1; @@ -192,7 +202,7 @@ int main(int argc, char ** argv) { break; } - LOG_DBG("accepted %d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, id, token_str.c_str()); + LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str()); { LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); @@ -200,8 +210,8 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1); } - prompt_dft.push_back(id_last); - prompt_dft.insert(prompt_dft.end(), ids.begin(), ids.end() - 1); + prompt_tgt.push_back(id_last); + prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1); // remember the last accepted token for the next iteration id_last = id; @@ -210,6 +220,8 @@ int main(int argc, char ** argv) { auto t_dec_end = ggml_time_us(); + const int n_input = inp.size(); + LOG("\n\n"); LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); From f27ddc57d71a0c93ada329c91d17876913b54fed Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 22 Nov 2024 12:27:09 +0200 Subject: [PATCH 06/14] speculative : add --draft-min CLI arg --- common/arg.cpp | 13 ++++++++++--- common/common.h | 1 + examples/speculative-simple/speculative-simple.cpp | 5 +---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 4115b2f7511d3..35500670fa995 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -609,7 +609,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, int value) { params.n_draft = value; } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP})); + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--draft-min"}, "N", + string_format("minimum number of draft tokens to use for speculative decoding (default: %d)", params.n_draft_min), + [](common_params & params, int value) { + params.n_draft_min = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-ps", "--p-split"}, "N", 
string_format("speculative decoding split probability (default: %.1f)", (double)params.p_split), @@ -1454,7 +1461,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); } } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-sm", "--split-mode"}, "{none,layer,row}", "how to split the model across multiple GPUs, one of:\n" @@ -1599,7 +1606,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.model_draft = value; } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-mu", "--model-url"}, "MODEL_URL", "model download url (default: unused)", diff --git a/common/common.h b/common/common.h index 29d678c7bab8a..42c17ed3cf122 100644 --- a/common/common.h +++ b/common/common.h @@ -162,6 +162,7 @@ struct common_params { int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt int32_t n_draft = 5; // number of tokens to draft during speculative decoding + int32_t n_draft_min = 0; // minimum number of draft tokens to use for speculative decoding int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) int32_t n_parallel = 1; // number of parallel sequences to decode int32_t n_sequences = 1; // number of sequences to decode diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index d7e572cf84c2c..6dee64834e96e 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -13,9 +13,6 @@ int main(int argc, char ** argv) { common_params params; - // minimum size of the draft to use - const int n_min = 5; - if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) { return 1; } @@ -142,7 +139,7 @@ int main(int argc, char ** argv) { // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] { // do not waste time on small drafts - if (draft.size() < n_min) { + if (draft.size() < params.n_draft_min) { draft.clear(); } From ccc8f63f9f708b2a34f29c5c244b77f156bec927 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 22 Nov 2024 13:48:39 +0200 Subject: [PATCH 07/14] speculative : minor fixup --- examples/speculative-simple/speculative-simple.cpp | 6 ++++-- tests/test-arg-parser.cpp | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 6dee64834e96e..fb63435ab5821 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -41,8 +41,9 @@ int main(int argc, char ** argv) { // load the target model common_init_result llama_init_tgt = common_init_from_params(params); + model_tgt = llama_init_tgt.model; - ctx_tgt = llama_init_tgt.context; + ctx_tgt = llama_init_tgt.context; // load the draft model params.model = params.model_draft; @@ -53,8 +54,9 @@ int main(int argc, char ** argv) { params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads; common_init_result llama_init_dft = common_init_from_params(params); + model_dft = 
llama_init_dft.model; - ctx_dft = llama_init_dft.context; + ctx_dft = llama_init_dft.context; if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) { return 1; diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index 3665238b5a2d8..93850b0371c07 100644 --- a/tests/test-arg-parser.cpp +++ b/tests/test-arg-parser.cpp @@ -70,7 +70,7 @@ int main(void) { // non-existence arg in specific example (--draft cannot be used outside llama-speculative) argv = {"binary_name", "--draft", "123"}; - assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SERVER)); + assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_EMBEDDING)); printf("test-arg-parser: test valid usage\n\n"); From 2e197a1f21c7c9ae1707a10efce4fc748892d005 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 22 Nov 2024 16:11:25 +0200 Subject: [PATCH 08/14] make : build fixes --- Makefile | 1 + common/speculative.cpp | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 5c899438515e1..dd6d864ad513a 100644 --- a/Makefile +++ b/Makefile @@ -966,6 +966,7 @@ OBJ_COMMON = \ $(DIR_COMMON)/console.o \ $(DIR_COMMON)/ngram-cache.o \ $(DIR_COMMON)/sampling.o \ + $(DIR_COMMON)/speculative.o \ $(DIR_COMMON)/build-info.o \ $(DIR_COMMON)/json-schema-to-grammar.o diff --git a/common/speculative.cpp b/common/speculative.cpp index eccba93e01d5c..3adb9d67a7f45 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -150,7 +150,7 @@ llama_tokens common_speculative_gen_draft( cur++; } - if ((cur >= params.n_reuse || prompt_tgt.size() <= n_ctx) && cur > reuse_n) { + if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) { reuse_i = i; reuse_n = cur; } @@ -229,7 +229,7 @@ llama_tokens common_speculative_gen_draft( result.push_back(id); - if (result.size() >= params.n_draft) { + if (params.n_draft <= (int) result.size()) { break; } From be5f6110003ccf7be4fcdfb7f51527d8fc2fb50e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 24 Nov 2024 12:09:31 +0200 Subject: [PATCH 09/14] speculative : do not redraft previous drafts ggml-ci --- common/speculative.cpp | 21 ++++++++++++++----- .../speculative-simple/speculative-simple.cpp | 2 ++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 3adb9d67a7f45..4222234de8e08 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -27,7 +27,7 @@ struct common_speculative * common_speculative_init( }; // TODO: optimize or pass from outside? 
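    // (note on the hunks below: the sampler-setup block guarded by the #if is switched off,
    //  and common_speculative_gen_draft gains an early return that reuses the not-yet-consumed
    //  tail of the previous draft whenever that tail already continues with id_last, so those
    //  tokens are not drafted a second time)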
-#if 1 +#if 0 { common_sampler_params sparams; sparams.no_perf = false; @@ -156,13 +156,27 @@ llama_tokens common_speculative_gen_draft( } } - LOG_DBG("%s: reuse_i = %d, reuse_n = %d\n", __func__, reuse_i, reuse_n); + LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size()); + + llama_tokens result; + result.reserve(params.n_draft); if (reuse_n == 0) { llama_kv_cache_clear(ctx); prompt.clear(); } else { + if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) { + for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) { + result.push_back(prompt[i]); + + if (result.size() >= params.n_draft) { + break; + } + } + return result; + } + llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i); llama_kv_cache_seq_rm (ctx, 0, reuse_i + reuse_n, -1); llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i); @@ -201,9 +215,6 @@ llama_tokens common_speculative_gen_draft( common_sampler_reset(smpl); - llama_tokens result; - result.reserve(params.n_draft); - // sample n_draft tokens from the draft model for (int i = 0; i < params.n_draft; ++i) { common_batch_clear(batch); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index fb63435ab5821..98a9b35d41f20 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -134,6 +134,8 @@ int main(int argc, char ** argv) { // llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last); + //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); + // always have a token to evaluate from before - id_last common_batch_clear(batch_tgt); common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true); From d9fb3b2e0137bb86943ef4d811563ad8a586b4d3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 24 Nov 2024 12:50:17 +0200 Subject: [PATCH 10/14] speculative : fix the draft sampling ggml-ci --- common/sampling.cpp | 18 +++++++++++++----- common/sampling.h | 15 ++++++++++----- .../speculative-simple/speculative-simple.cpp | 4 +++- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 75e2e5d296092..52f4c9e226b08 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -320,7 +320,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co return cur_p.data[cur_p.selected].id; } -std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first) { +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first) { GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); std::vector result; @@ -330,25 +330,33 @@ std::vector common_sampler_sample_n(struct common_sampler * gsmpl, for (; i < draft.size(); i++) { const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); + common_sampler_accept(gsmpl, id, true); + + result.push_back(id); + if (draft[i] != id) { break; } + } + + if (i == draft.size()) { + const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); + + common_sampler_accept(gsmpl, id, true); result.push_back(id); } - result.push_back(common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first)); - return result; } -std::vector 
common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { std::vector idxs(draft.size() + 1); for (size_t i = 0; i < idxs.size(); ++i) { idxs[i] = i; } - return common_sampler_sample_n(gsmpl, ctx, idxs, draft, grammar_first); + return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first); } uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { diff --git a/common/sampling.h b/common/sampling.h index f9b193ac8db73..883d905a608da 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -62,19 +62,24 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co // generalized version of common_sampler_sample // -// will cross-reference the sampled tokens with a batch of draft tokens -// if the sampler disagrees at some point, we stop and return the sampled tokens up to now +// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match +// if the sampler disagrees at some point, we stop and return the accepted tokens up to now // -// `common_sampler_sample_n(gsmpl, ctx, { idx }, {})` is equivalent to `common_sampler_sample(gsmpl, ctx, idx)` +// common_sampler_sample_n(gsmpl, ctx, { idx }, {}); +// +// is equivalent to +// +// common_sampler_sample(gsmpl, ctx, idx); +// common_sampler_accept(gsmpl, token, true); // // requires: idxs.size() == draft.size() + 1 // // returns at least 1 token, up to idxs.size() // -std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first = false); +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first = false); // assume idxs == [ 0, 1, 2, ..., draft.size() ] -std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 98a9b35d41f20..6699e1d85c263 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -163,7 +163,9 @@ int main(int argc, char ** argv) { // available logits from the batch and sample the next token until we run out of logits or the sampler // disagrees with the draft // - const auto ids = common_sampler_sample_n(smpl, ctx_tgt, draft); + const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft); + + //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str()); GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token From c8880e786cfa44c6879b0cdc398feeaffebbfafb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 24 Nov 2024 12:53:48 +0200 Subject: [PATCH 11/14] speculative : fix compile warning --- common/speculative.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 
4222234de8e08..9d2c8504bb980 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -170,7 +170,7 @@ llama_tokens common_speculative_gen_draft( for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) { result.push_back(prompt[i]); - if (result.size() >= params.n_draft) { + if (params.n_draft <= (int) result.size()) { break; } } From 7f9cc2058c38fa78c9ea42cde671837048a68519 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 24 Nov 2024 14:55:16 +0200 Subject: [PATCH 12/14] common : refactor args ggml-ci --- common/arg.cpp | 445 +++++++++--------- common/common.cpp | 4 +- common/common.h | 28 +- common/sampling.cpp | 6 +- common/sampling.h | 2 +- common/speculative.cpp | 22 +- examples/batched/batched.cpp | 8 +- examples/infill/infill.cpp | 2 +- examples/llava/llava-cli.cpp | 2 +- examples/llava/minicpmv-cli.cpp | 2 +- examples/lookahead/lookahead.cpp | 2 +- examples/lookup/lookup-stats.cpp | 3 +- examples/lookup/lookup.cpp | 4 +- examples/main/main.cpp | 2 +- examples/parallel/parallel.cpp | 2 +- examples/retrieval/retrieval.cpp | 4 +- examples/save-load-state/save-load-state.cpp | 8 +- examples/server/server.cpp | 6 +- examples/speculative-simple/README.md | 9 + .../speculative-simple/speculative-simple.cpp | 33 +- examples/speculative/speculative.cpp | 28 +- tests/test-arg-parser.cpp | 2 +- 22 files changed, 330 insertions(+), 294 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 35500670fa995..32240f21f2469 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -233,10 +233,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context } } - postprocess_cpu_params(params.cpuparams, nullptr); + postprocess_cpu_params(params.cpuparams, nullptr); postprocess_cpu_params(params.cpuparams_batch, ¶ms.cpuparams); - postprocess_cpu_params(params.draft_cpuparams, ¶ms.cpuparams); - postprocess_cpu_params(params.draft_cpuparams_batch, ¶ms.cpuparams_batch); + + postprocess_cpu_params(params.speculative.cpuparams, ¶ms.cpuparams); + postprocess_cpu_params(params.speculative.cpuparams_batch, ¶ms.cpuparams_batch); if (params.prompt_cache_all && (params.interactive || params.interactive_first)) { throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); @@ -251,7 +252,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context for (auto & antiprompt : params.antiprompt) { string_process_escapes(antiprompt); } - for (auto & seq_breaker : params.sparams.dry_sequence_breakers) { + for (auto & seq_breaker : params.sampling.dry_sequence_breakers) { string_process_escapes(seq_breaker); } } @@ -329,7 +330,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex std::string sampler_type_chars; std::string sampler_type_names; - for (const auto & sampler : params.sparams.samplers) { + for (const auto & sampler : params.sampling.samplers) { sampler_type_chars += common_sampler_type_to_chr(sampler); sampler_type_names += common_sampler_type_to_str(sampler) + ";"; } @@ -407,26 +408,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } } )); - add_opt(common_arg( - {"-td", "--threads-draft"}, "N", - "number of threads to use during generation (default: same as --threads)", - [](common_params & params, int value) { - params.draft_cpuparams.n_threads = value; - if (params.draft_cpuparams.n_threads <= 0) { - params.draft_cpuparams.n_threads = std::thread::hardware_concurrency(); - } - } - 
).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(common_arg( - {"-tbd", "--threads-batch-draft"}, "N", - "number of threads to use during batch and prompt processing (default: same as --threads-draft)", - [](common_params & params, int value) { - params.draft_cpuparams_batch.n_threads = value; - if (params.draft_cpuparams_batch.n_threads <= 0) { - params.draft_cpuparams_batch.n_threads = std::thread::hardware_concurrency(); - } - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); add_opt(common_arg( {"-C", "--cpu-mask"}, "M", "CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")", @@ -515,115 +496,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.cpuparams_batch.poll = value; } )); - add_opt(common_arg( - {"-Cd", "--cpu-mask-draft"}, "M", - "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)", - [](common_params & params, const std::string & mask) { - params.draft_cpuparams.mask_valid = true; - if (!parse_cpu_mask(mask, params.draft_cpuparams.cpumask)) { - throw std::invalid_argument("invalid cpumask"); - } - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(common_arg( - {"-Crd", "--cpu-range-draft"}, "lo-hi", - "Ranges of CPUs for affinity. Complements --cpu-mask-draft", - [](common_params & params, const std::string & range) { - params.draft_cpuparams.mask_valid = true; - if (!parse_cpu_range(range, params.draft_cpuparams.cpumask)) { - throw std::invalid_argument("invalid range"); - } - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(common_arg( - {"--cpu-strict-draft"}, "<0|1>", - "Use strict CPU placement for draft model (default: same as --cpu-strict)", - [](common_params & params, int value) { - params.draft_cpuparams.strict_cpu = value; - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(common_arg( - {"--prio-draft"}, "N", - string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.draft_cpuparams.priority), - [](common_params & params, int prio) { - if (prio < 0 || prio > 3) { - throw std::invalid_argument("invalid value"); - } - params.draft_cpuparams.priority = (enum ggml_sched_priority) prio; - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(common_arg( - {"--poll-draft"}, "<0|1>", - "Use polling to wait for draft model work (default: same as --poll])", - [](common_params & params, int value) { - params.draft_cpuparams.poll = value; - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(common_arg( - {"-Cbd", "--cpu-mask-batch-draft"}, "M", - "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)", - [](common_params & params, const std::string & mask) { - params.draft_cpuparams_batch.mask_valid = true; - if (!parse_cpu_mask(mask, params.draft_cpuparams_batch.cpumask)) { - throw std::invalid_argument("invalid cpumask"); - } - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(common_arg( - {"-Crbd", "--cpu-range-batch-draft"}, "lo-hi", - "Ranges of CPUs for affinity. 
Complements --cpu-mask-draft-batch)", - [](common_params & params, const std::string & range) { - params.draft_cpuparams_batch.mask_valid = true; - if (!parse_cpu_range(range, params.draft_cpuparams_batch.cpumask)) { - throw std::invalid_argument("invalid cpumask"); - } - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(common_arg( - {"--cpu-strict-batch-draft"}, "<0|1>", - "Use strict CPU placement for draft model (default: --cpu-strict-draft)", - [](common_params & params, int value) { - params.draft_cpuparams_batch.strict_cpu = value; - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(common_arg( - {"--prio-batch-draft"}, "N", - string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.draft_cpuparams_batch.priority), - [](common_params & params, int prio) { - if (prio < 0 || prio > 3) { - throw std::invalid_argument("invalid value"); - } - params.draft_cpuparams_batch.priority = (enum ggml_sched_priority) prio; - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(common_arg( - {"--poll-batch-draft"}, "<0|1>", - "Use polling to wait for draft model work (default: --poll-draft)", - [](common_params & params, int value) { - params.draft_cpuparams_batch.poll = value; - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(common_arg( - {"--draft"}, "N", - string_format("number of tokens to draft for speculative decoding (default: %d)", params.n_draft), - [](common_params & params, int value) { - params.n_draft = value; - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER})); - add_opt(common_arg( - {"--draft-min"}, "N", - string_format("minimum number of draft tokens to use for speculative decoding (default: %d)", params.n_draft_min), - [](common_params & params, int value) { - params.n_draft_min = value; - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER})); - add_opt(common_arg( - {"-ps", "--p-split"}, "N", - string_format("speculative decoding split probability (default: %.1f)", (double)params.p_split), - [](common_params & params, const std::string & value) { - params.p_split = std::stof(value); - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); add_opt(common_arg( {"-lcs", "--lookup-cache-static"}, "FNAME", "path to static lookup cache to use for lookup decoding (not updated by generation)", @@ -708,7 +580,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? 
"true" : "false"), [](common_params & params) { params.no_perf = true; - params.sparams.no_perf = true; + params.sampling.no_perf = true; } ).set_env("LLAMA_ARG_NO_PERF")); add_opt(common_arg( @@ -890,155 +762,155 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()), [](common_params & params, const std::string & value) { const auto sampler_names = string_split(value, ';'); - params.sparams.samplers = common_sampler_types_from_names(sampler_names, true); + params.sampling.samplers = common_sampler_types_from_names(sampler_names, true); } ).set_sparam()); add_opt(common_arg( {"-s", "--seed"}, "SEED", - string_format("RNG seed (default: %d, use random seed for %d)", params.sparams.seed, LLAMA_DEFAULT_SEED), + string_format("RNG seed (default: %d, use random seed for %d)", params.sampling.seed, LLAMA_DEFAULT_SEED), [](common_params & params, const std::string & value) { - params.sparams.seed = std::stoul(value); + params.sampling.seed = std::stoul(value); } ).set_sparam()); add_opt(common_arg( {"--sampling-seq"}, "SEQUENCE", string_format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()), [](common_params & params, const std::string & value) { - params.sparams.samplers = common_sampler_types_from_chars(value); + params.sampling.samplers = common_sampler_types_from_chars(value); } ).set_sparam()); add_opt(common_arg( {"--ignore-eos"}, "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)", [](common_params & params) { - params.sparams.ignore_eos = true; + params.sampling.ignore_eos = true; } ).set_sparam()); add_opt(common_arg( {"--penalize-nl"}, - string_format("penalize newline tokens (default: %s)", params.sparams.penalize_nl ? "true" : "false"), + string_format("penalize newline tokens (default: %s)", params.sampling.penalize_nl ? 
"true" : "false"), [](common_params & params) { - params.sparams.penalize_nl = true; + params.sampling.penalize_nl = true; } ).set_sparam()); add_opt(common_arg( {"--temp"}, "N", - string_format("temperature (default: %.1f)", (double)params.sparams.temp), + string_format("temperature (default: %.1f)", (double)params.sampling.temp), [](common_params & params, const std::string & value) { - params.sparams.temp = std::stof(value); - params.sparams.temp = std::max(params.sparams.temp, 0.0f); + params.sampling.temp = std::stof(value); + params.sampling.temp = std::max(params.sampling.temp, 0.0f); } ).set_sparam()); add_opt(common_arg( {"--top-k"}, "N", - string_format("top-k sampling (default: %d, 0 = disabled)", params.sparams.top_k), + string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k), [](common_params & params, int value) { - params.sparams.top_k = value; + params.sampling.top_k = value; } ).set_sparam()); add_opt(common_arg( {"--top-p"}, "N", - string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sparams.top_p), + string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p), [](common_params & params, const std::string & value) { - params.sparams.top_p = std::stof(value); + params.sampling.top_p = std::stof(value); } ).set_sparam()); add_opt(common_arg( {"--min-p"}, "N", - string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sparams.min_p), + string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p), [](common_params & params, const std::string & value) { - params.sparams.min_p = std::stof(value); + params.sampling.min_p = std::stof(value); } ).set_sparam()); add_opt(common_arg( {"--xtc-probability"}, "N", - string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_probability), + string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), [](common_params & params, const std::string & value) { - params.sparams.xtc_probability = std::stof(value); + params.sampling.xtc_probability = std::stof(value); } ).set_sparam()); add_opt(common_arg( {"--xtc-threshold"}, "N", - string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sparams.xtc_threshold), + string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold), [](common_params & params, const std::string & value) { - params.sparams.xtc_threshold = std::stof(value); + params.sampling.xtc_threshold = std::stof(value); } ).set_sparam()); add_opt(common_arg( {"--typical"}, "N", - string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p), + string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sampling.typ_p), [](common_params & params, const std::string & value) { - params.sparams.typ_p = std::stof(value); + params.sampling.typ_p = std::stof(value); } ).set_sparam()); add_opt(common_arg( {"--repeat-last-n"}, "N", - string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sparams.penalty_last_n), + string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sampling.penalty_last_n), [](common_params & params, int value) { - params.sparams.penalty_last_n = value; - params.sparams.n_prev = std::max(params.sparams.n_prev, params.sparams.penalty_last_n); + 
params.sampling.penalty_last_n = value; + params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n); } ).set_sparam()); add_opt(common_arg( {"--repeat-penalty"}, "N", - string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sparams.penalty_repeat), + string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat), [](common_params & params, const std::string & value) { - params.sparams.penalty_repeat = std::stof(value); + params.sampling.penalty_repeat = std::stof(value); } ).set_sparam()); add_opt(common_arg( {"--presence-penalty"}, "N", - string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sparams.penalty_present), + string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_present), [](common_params & params, const std::string & value) { - params.sparams.penalty_present = std::stof(value); + params.sampling.penalty_present = std::stof(value); } ).set_sparam()); add_opt(common_arg( {"--frequency-penalty"}, "N", - string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sparams.penalty_freq), + string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_freq), [](common_params & params, const std::string & value) { - params.sparams.penalty_freq = std::stof(value); + params.sampling.penalty_freq = std::stof(value); } ).set_sparam()); add_opt(common_arg( {"--dry-multiplier"}, "N", - string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sparams.dry_multiplier), + string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sampling.dry_multiplier), [](common_params & params, const std::string & value) { - params.sparams.dry_multiplier = std::stof(value); + params.sampling.dry_multiplier = std::stof(value); } ).set_sparam()); add_opt(common_arg( {"--dry-base"}, "N", - string_format("set DRY sampling base value (default: %.2f)", (double)params.sparams.dry_base), + string_format("set DRY sampling base value (default: %.2f)", (double)params.sampling.dry_base), [](common_params & params, const std::string & value) { float potential_base = std::stof(value); if (potential_base >= 1.0f) { - params.sparams.dry_base = potential_base; + params.sampling.dry_base = potential_base; } } ).set_sparam()); add_opt(common_arg( {"--dry-allowed-length"}, "N", - string_format("set allowed length for DRY sampling (default: %d)", params.sparams.dry_allowed_length), + string_format("set allowed length for DRY sampling (default: %d)", params.sampling.dry_allowed_length), [](common_params & params, int value) { - params.sparams.dry_allowed_length = value; + params.sampling.dry_allowed_length = value; } ).set_sparam()); add_opt(common_arg( {"--dry-penalty-last-n"}, "N", - string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sparams.dry_penalty_last_n), + string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sampling.dry_penalty_last_n), [](common_params & params, int value) { - params.sparams.dry_penalty_last_n = value; + params.sampling.dry_penalty_last_n = value; } ).set_sparam()); add_opt(common_arg( {"--dry-sequence-breaker"}, "STRING", string_format("add sequence breaker for DRY sampling, clearing out default breakers (%s) 
in the process; use \"none\" to not use any sequence breakers\n", - params.sparams.dry_sequence_breakers.empty() ? "none" : - std::accumulate(std::next(params.sparams.dry_sequence_breakers.begin()), - params.sparams.dry_sequence_breakers.end(), - std::string("'") + (params.sparams.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sparams.dry_sequence_breakers[0]) + "'", + params.sampling.dry_sequence_breakers.empty() ? "none" : + std::accumulate(std::next(params.sampling.dry_sequence_breakers.begin()), + params.sampling.dry_sequence_breakers.end(), + std::string("'") + (params.sampling.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sampling.dry_sequence_breakers[0]) + "'", [](const std::string& a, const std::string& b) { std::string formatted_b = (b == "\n") ? "\\n" : b; return a + ", '" + formatted_b + "'"; @@ -1047,51 +919,51 @@ common_params_context common_params_parser_init(common_params & params, llama_ex static bool defaults_cleared = false; if (!defaults_cleared) { - params.sparams.dry_sequence_breakers.clear(); + params.sampling.dry_sequence_breakers.clear(); defaults_cleared = true; } if (value == "none") { - params.sparams.dry_sequence_breakers.clear(); + params.sampling.dry_sequence_breakers.clear(); } else { - params.sparams.dry_sequence_breakers.emplace_back(value); + params.sampling.dry_sequence_breakers.emplace_back(value); } } ).set_sparam()); add_opt(common_arg( {"--dynatemp-range"}, "N", - string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range), + string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range), [](common_params & params, const std::string & value) { - params.sparams.dynatemp_range = std::stof(value); + params.sampling.dynatemp_range = std::stof(value); } ).set_sparam()); add_opt(common_arg( {"--dynatemp-exp"}, "N", - string_format("dynamic temperature exponent (default: %.1f)", (double)params.sparams.dynatemp_exponent), + string_format("dynamic temperature exponent (default: %.1f)", (double)params.sampling.dynatemp_exponent), [](common_params & params, const std::string & value) { - params.sparams.dynatemp_exponent = std::stof(value); + params.sampling.dynatemp_exponent = std::stof(value); } ).set_sparam()); add_opt(common_arg( {"--mirostat"}, "N", string_format("use Mirostat sampling.\nTop K, Nucleus and Locally Typical samplers are ignored if used.\n" - "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sparams.mirostat), + "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat), [](common_params & params, int value) { - params.sparams.mirostat = value; + params.sampling.mirostat = value; } ).set_sparam()); add_opt(common_arg( {"--mirostat-lr"}, "N", - string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sparams.mirostat_eta), + string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta), [](common_params & params, const std::string & value) { - params.sparams.mirostat_eta = std::stof(value); + params.sampling.mirostat_eta = std::stof(value); } ).set_sparam()); add_opt(common_arg( {"--mirostat-ent"}, "N", - string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sparams.mirostat_tau), + string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau), [](common_params & params, const std::string & value) { - params.sparams.mirostat_tau = 
std::stof(value); + params.sampling.mirostat_tau = std::stof(value); } ).set_sparam()); add_opt(common_arg( @@ -1107,7 +979,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex try { if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); - params.sparams.logit_bias.push_back({key, bias}); + params.sampling.logit_bias.push_back({key, bias}); } else { throw std::invalid_argument("invalid input format"); } @@ -1118,9 +990,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_sparam()); add_opt(common_arg( {"--grammar"}, "GRAMMAR", - string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sparams.grammar.c_str()), + string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()), [](common_params & params, const std::string & value) { - params.sparams.grammar = value; + params.sampling.grammar = value; } ).set_sparam()); add_opt(common_arg( @@ -1134,7 +1006,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex std::copy( std::istreambuf_iterator(file), std::istreambuf_iterator(), - std::back_inserter(params.sparams.grammar) + std::back_inserter(params.sampling.grammar) ); } ).set_sparam()); @@ -1142,7 +1014,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-j", "--json-schema"}, "SCHEMA", "JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead", [](common_params & params, const std::string & value) { - params.sparams.grammar = json_schema_to_grammar(json::parse(value)); + params.sampling.grammar = json_schema_to_grammar(json::parse(value)); } ).set_sparam()); add_opt(common_arg( @@ -1451,17 +1323,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } } ).set_env("LLAMA_ARG_N_GPU_LAYERS")); - add_opt(common_arg( - {"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N", - "number of layers to store in VRAM for the draft model", - [](common_params & params, int value) { - params.n_gpu_layers_draft = value; - if (!llama_supports_gpu_offload()) { - fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n"); - fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); - } - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-sm", "--split-mode"}, "{none,layer,row}", "how to split the model across multiple GPUs, one of:\n" @@ -1600,13 +1461,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.model = value; } ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL")); - add_opt(common_arg( - {"-md", "--model-draft"}, "FNAME", - "draft model for speculative decoding (default: unused)", - [](common_params & params, const std::string & value) { - params.model_draft = value; - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-mu", "--model-url"}, "MODEL_URL", "model download url (default: unused)", @@ -2044,5 +1898,168 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } 
).set_env("LLAMA_LOG_TIMESTAMPS")); + // speculative parameters + add_opt(common_arg( + {"-td", "--threads-draft"}, "N", + "number of threads to use during generation (default: same as --threads)", + [](common_params & params, int value) { + params.speculative.cpuparams.n_threads = value; + if (params.speculative.cpuparams.n_threads <= 0) { + params.speculative.cpuparams.n_threads = std::thread::hardware_concurrency(); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"-tbd", "--threads-batch-draft"}, "N", + "number of threads to use during batch and prompt processing (default: same as --threads-draft)", + [](common_params & params, int value) { + params.speculative.cpuparams_batch.n_threads = value; + if (params.speculative.cpuparams_batch.n_threads <= 0) { + params.speculative.cpuparams_batch.n_threads = std::thread::hardware_concurrency(); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"-Cd", "--cpu-mask-draft"}, "M", + "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)", + [](common_params & params, const std::string & mask) { + params.speculative.cpuparams.mask_valid = true; + if (!parse_cpu_mask(mask, params.speculative.cpuparams.cpumask)) { + throw std::invalid_argument("invalid cpumask"); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"-Crd", "--cpu-range-draft"}, "lo-hi", + "Ranges of CPUs for affinity. Complements --cpu-mask-draft", + [](common_params & params, const std::string & range) { + params.speculative.cpuparams.mask_valid = true; + if (!parse_cpu_range(range, params.speculative.cpuparams.cpumask)) { + throw std::invalid_argument("invalid range"); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--cpu-strict-draft"}, "<0|1>", + "Use strict CPU placement for draft model (default: same as --cpu-strict)", + [](common_params & params, int value) { + params.speculative.cpuparams.strict_cpu = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--prio-draft"}, "N", + string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.speculative.cpuparams.priority), + [](common_params & params, int prio) { + if (prio < 0 || prio > 3) { + throw std::invalid_argument("invalid value"); + } + params.speculative.cpuparams.priority = (enum ggml_sched_priority) prio; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--poll-draft"}, "<0|1>", + "Use polling to wait for draft model work (default: same as --poll])", + [](common_params & params, int value) { + params.speculative.cpuparams.poll = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"-Cbd", "--cpu-mask-batch-draft"}, "M", + "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)", + [](common_params & params, const std::string & mask) { + params.speculative.cpuparams_batch.mask_valid = true; + if (!parse_cpu_mask(mask, params.speculative.cpuparams_batch.cpumask)) { + throw std::invalid_argument("invalid cpumask"); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"-Crbd", "--cpu-range-batch-draft"}, "lo-hi", + "Ranges of CPUs for affinity. 
Complements --cpu-mask-draft-batch)", + [](common_params & params, const std::string & range) { + params.speculative.cpuparams_batch.mask_valid = true; + if (!parse_cpu_range(range, params.speculative.cpuparams_batch.cpumask)) { + throw std::invalid_argument("invalid cpumask"); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--cpu-strict-batch-draft"}, "<0|1>", + "Use strict CPU placement for draft model (default: --cpu-strict-draft)", + [](common_params & params, int value) { + params.speculative.cpuparams_batch.strict_cpu = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--prio-batch-draft"}, "N", + string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.speculative.cpuparams_batch.priority), + [](common_params & params, int prio) { + if (prio < 0 || prio > 3) { + throw std::invalid_argument("invalid value"); + } + params.speculative.cpuparams_batch.priority = (enum ggml_sched_priority) prio; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--poll-batch-draft"}, "<0|1>", + "Use polling to wait for draft model work (default: --poll-draft)", + [](common_params & params, int value) { + params.speculative.cpuparams_batch.poll = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--draft-max", "--draft", "--draft-n"}, "N", + string_format("number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max), + [](common_params & params, int value) { + params.speculative.n_max = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--draft-min", "--draft-n-min"}, "N", + string_format("minimum number of draft tokens to use for speculative decoding (default: %d)", params.speculative.n_min), + [](common_params & params, int value) { + params.speculative.n_min = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--draft-p-split"}, "P", + string_format("speculative decoding split probability (default: %.1f)", (double)params.speculative.p_split), + [](common_params & params, const std::string & value) { + params.speculative.p_split = std::stof(value); + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--draft-p-min"}, "P", + string_format("minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min), + [](common_params & params, const std::string & value) { + params.speculative.p_min = std::stof(value); + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"-cd", "--ctx-size-draft"}, "N", + string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx), + [](common_params & params, int value) { + params.speculative.n_ctx = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N", + "number of layers to store in VRAM for the draft model", + [](common_params & params, int value) { + params.speculative.n_gpu_layers = value; + if (!llama_supports_gpu_offload()) { + fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n"); + fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS 
support\n"); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"-md", "--model-draft"}, "FNAME", + "draft model for speculative decoding (default: unused)", + [](common_params & params, const std::string & value) { + params.speculative.model = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); + return ctx_arg; } diff --git a/common/common.cpp b/common/common.cpp index 43fa8a1ef67be..c398329d05bf5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -925,9 +925,9 @@ struct common_init_result common_init_from_params(common_params & params) { common_lora_adapters_apply(lctx, iparams.lora_adapters); } - if (params.sparams.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) { + if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) { LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__); - params.sparams.ignore_eos = false; + params.sampling.ignore_eos = false; } if (params.warmup) { diff --git a/common/common.h b/common/common.h index 42c17ed3cf122..f354a5fbe1452 100644 --- a/common/common.h +++ b/common/common.h @@ -103,8 +103,8 @@ enum dimre_method { DIMRE_METHOD_MEAN, }; -// sampler parameters -struct common_sampler_params { +// sampling parameters +struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler int32_t n_prev = 64; // number of previous tokens to remember @@ -155,20 +155,30 @@ struct common_sampler_params { std::string print() const; }; +struct common_params_speculative { + int32_t n_ctx = 4096; // draft context size + int32_t n_max = 5; // maximum number of tokens to draft during speculative decoding + int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding + int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) + float p_split = 0.1f; // speculative decoding split probability + float p_min = 0.9f; // minimum speculative decoding probability (greedy) + + struct cpu_params cpuparams; + struct cpu_params cpuparams_batch; + + std::string model = ""; // draft model for speculative decoding // NOLINT +}; + struct common_params { int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 4096; // context size int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_draft = 5; // number of tokens to draft during speculative decoding - int32_t n_draft_min = 0; // minimum number of draft tokens to use for speculative decoding int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) int32_t n_parallel = 1; // number of parallel sequences to decode int32_t n_sequences = 1; // number of sequences to decode - float p_split = 0.1f; // speculative decoding split probability int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) - int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs int32_t grp_attn_n = 1; // group-attention factor @@ -185,8 +195,6 @@ struct common_params { struct cpu_params cpuparams; struct cpu_params 
cpuparams_batch; - struct cpu_params draft_cpuparams; - struct cpu_params draft_cpuparams_batch; ggml_backend_sched_eval_callback cb_eval = nullptr; void * cb_eval_user_data = nullptr; @@ -198,10 +206,10 @@ struct common_params { enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings - struct common_sampler_params sparams; + struct common_params_sampling sampling; + struct common_params_speculative speculative; std::string model = ""; // model path // NOLINT - std::string model_draft = ""; // draft model for speculative decoding // NOLINT std::string model_alias = "unknown"; // model alias // NOLINT std::string model_url = ""; // model url to download // NOLINT std::string hf_token = ""; // HF token // NOLINT diff --git a/common/sampling.cpp b/common/sampling.cpp index 52f4c9e226b08..0c4699a89c8b2 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -99,7 +99,7 @@ struct ring_buffer { }; struct common_sampler { - common_sampler_params params; + common_params_sampling params; struct llama_sampler * grmr; struct llama_sampler * chain; @@ -125,7 +125,7 @@ struct common_sampler { } }; -std::string common_sampler_params::print() const { +std::string common_params_sampling::print() const { char result[1024]; snprintf(result, sizeof(result), @@ -141,7 +141,7 @@ std::string common_sampler_params::print() const { return std::string(result); } -struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) { +struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); lparams.no_perf = params.no_perf; diff --git a/common/sampling.h b/common/sampling.h index 883d905a608da..348911b18888b 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -36,7 +36,7 @@ struct common_sampler; // llama_sampler API overloads -struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params); +struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params); void common_sampler_free(struct common_sampler * gsmpl); diff --git a/common/speculative.cpp b/common/speculative.cpp index 9d2c8504bb980..316ea9e1eea1c 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -29,32 +29,32 @@ struct common_speculative * common_speculative_init( // TODO: optimize or pass from outside? 
#if 0 { - common_sampler_params sparams; - sparams.no_perf = false; + common_params_sampling params; + params.no_perf = false; - sparams.top_k = 40; - sparams.top_p = 0.9; + params.top_k = 40; + params.top_p = 0.9; - sparams.samplers = { + params.samplers = { COMMON_SAMPLER_TYPE_TOP_K, COMMON_SAMPLER_TYPE_TOP_P, COMMON_SAMPLER_TYPE_INFILL, }; - result->smpl = common_sampler_init(llama_get_model(ctx_dft), sparams); + result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); } #else { - common_sampler_params sparams; - sparams.no_perf = false; + common_params_sampling params; + params.no_perf = false; - sparams.top_k = 10; + params.top_k = 10; - sparams.samplers = { + params.samplers = { COMMON_SAMPLER_TYPE_TOP_K, }; - result->smpl = common_sampler_init(llama_get_model(ctx_dft), sparams); + result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); } #endif diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 3b554033e7ee4..ba219cd4b32ae 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -68,10 +68,10 @@ int main(int argc, char ** argv) { llama_sampler * smpl = llama_sampler_chain_init(sparams); - llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sparams.top_k)); - llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sparams.top_p, params.sparams.min_keep)); - llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sparams.temp)); - llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sparams.seed)); + llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k)); + llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep)); + llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); + llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); if (ctx == NULL) { LOG_ERR("%s: error: failed to create the llama_context\n" , __func__); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 15b358dc4e854..ef700895720ff 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -73,7 +73,7 @@ int main(int argc, char ** argv) { common_init(); - auto & sparams = params.sparams; + auto & sparams = params.sampling; console::init(params.simple_io, params.use_color); atexit([]() { console::cleanup(); }); diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 1610985858fc9..2691c6e6b2dd2 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ LOG("\n"); - struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams); + struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling); if (!smpl) { LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__); exit(1); diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index cbecec343c640..e9cbb51ed90ab 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -237,7 +237,7 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm LOG_INF("\n"); - struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams); + struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling); return smpl; } diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 
3c0ccfea2ccd7..8d0ef8b3d75e6 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -115,7 +115,7 @@ int main(int argc, char ** argv) { llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); // target model sampling context - struct common_sampler * smpl = common_sampler_init(model, params.sparams); + struct common_sampler * smpl = common_sampler_init(model, params.sampling); // verification n-grams std::vector ngrams_cur(G); diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp index 7faebe7ba11fc..dff07c075c47f 100644 --- a/examples/lookup/lookup-stats.cpp +++ b/examples/lookup/lookup-stats.cpp @@ -21,7 +21,7 @@ int main(int argc, char ** argv){ common_init(); - const int n_draft = params.n_draft; + const int n_draft = params.speculative.n_max; // init llama.cpp llama_backend_init(); @@ -40,6 +40,7 @@ int main(int argc, char ** argv){ common_ngram_cache ngram_cache_context; common_ngram_cache ngram_cache_dynamic; common_ngram_cache ngram_cache_static; + int64_t t_draft_flat_us = 0; int64_t t_draft_us = 0; diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index a04728b1834cc..4d92bb2385358 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -22,7 +22,7 @@ int main(int argc, char ** argv){ common_init(); // max. number of additional tokens to draft if match is found - const int n_draft = params.n_draft; + const int n_draft = params.speculative.n_max; const bool dump_kv_cache = params.dump_kv_cache; @@ -102,7 +102,7 @@ int main(int argc, char ** argv){ bool has_eos = false; - struct common_sampler * smpl = common_sampler_init(model, params.sparams); + struct common_sampler * smpl = common_sampler_init(model, params.sampling); std::vector draft; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 7c4ce4be2abae..957451af7ce0a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -100,7 +100,7 @@ int main(int argc, char ** argv) { common_init(); - auto & sparams = params.sparams; + auto & sparams = params.sampling; // save choice to use color for later // (note for later: this is a slightly awkward choice) diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 43c8f3ed56ba9..fd2b1c0112838 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -160,7 +160,7 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < clients.size(); ++i) { auto & client = clients[i]; client.id = i; - client.smpl = common_sampler_init(model, params.sparams); + client.smpl = common_sampler_init(model, params.sampling); } std::vector tokens_system; diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 1768aae510067..e78a8596d8cfe 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -282,8 +282,8 @@ int main(int argc, char ** argv) { return a.second > b.second; }); - LOG("Top %d similar chunks:\n", params.sparams.top_k); - for (int i = 0; i < std::min(params.sparams.top_k, (int) chunks.size()); i++) { + LOG("Top %d similar chunks:\n", params.sampling.top_k); + for (int i = 0; i < std::min(params.sampling.top_k, (int) chunks.size()); i++) { LOG("filename: %s\n", chunks[similarities[i].first].filename.c_str()); LOG("filepos: %lld\n", (long long int) chunks[similarities[i].first].filepos); LOG("similarity: %f\n", similarities[i].second); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 
8c49a52a66124..2f0cf9baa32b7 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -9,7 +9,7 @@ int main(int argc, char ** argv) { common_params params; params.prompt = "The quick brown fox"; - params.sparams.seed = 1234; + params.sampling.seed = 1234; if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { return 1; @@ -42,7 +42,7 @@ int main(int argc, char ** argv) { llama_sampler * smpl = llama_sampler_chain_init(sparams); - llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed)); + llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed)); // tokenize prompt auto tokens = common_tokenize(ctx, params.prompt, true); @@ -106,7 +106,7 @@ int main(int argc, char ** argv) { llama_sampler * smpl2 = llama_sampler_chain_init(sparams); - llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed)); + llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sampling.seed)); printf("\nsecond run: %s", params.prompt.c_str()); @@ -169,7 +169,7 @@ int main(int argc, char ** argv) { llama_sampler * smpl3 = llama_sampler_chain_init(sparams); - llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed)); + llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sampling.seed)); printf("\nsingle seq run: %s", params.prompt.c_str()); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index b7b2cbe5a3f68..6c55d65c01330 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -175,7 +175,7 @@ struct server_slot { // sampling json json_schema; - struct common_sampler_params sparams; + struct common_params_sampling sparams; struct common_sampler * smpl = nullptr; llama_token sampled; @@ -687,7 +687,7 @@ struct server_context { SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); - slot.sparams = params.sparams; + slot.sparams = params.sampling; slot.callback_on_release = [this](int) { queue_tasks.pop_deferred_task(); @@ -788,7 +788,7 @@ struct server_context { bool launch_slot_with_task(server_slot & slot, const server_task & task) { slot_params default_params; // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) - auto default_sparams = params.sparams; + auto default_sparams = params.sampling; const auto & data = task.data; if (data.count("__oaicompat") != 0) { diff --git a/examples/speculative-simple/README.md b/examples/speculative-simple/README.md index 6f3d6dc1505ad..e3a6c6b4aa0bf 100644 --- a/examples/speculative-simple/README.md +++ b/examples/speculative-simple/README.md @@ -1,3 +1,12 @@ # llama.cpp/examples/speculative-simple Demonstration of basic greedy speculative decoding + +```bash +./bin/llama-speculative-simple \ + -m ../models/qwen2.5-32b-coder-instruct/ggml-model-q8_0.gguf \ + -md ../models/qwen2.5-1.5b-coder-instruct/ggml-model-q4_0.gguf \ + -f test.txt -c 0 -ngl 99 --color \ + --sampling-seq k --top-k 1 -fa --temp 0.0 \ + -ngld 99 --draft-max 16 --draft-min 5 --draft-p-min 0.9 +``` diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 6699e1d85c263..ed3e6a4661227 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -24,7 +24,7 @@ int main(int argc, char ** argv) { common_init(); - if (params.model_draft.empty()) { + if (params.speculative.model.empty()) { LOG_ERR("%s: --model-draft is 
required\n", __func__); return 1; } @@ -46,13 +46,13 @@ int main(int argc, char ** argv) { ctx_tgt = llama_init_tgt.context; // load the draft model - params.model = params.model_draft; - params.n_gpu_layers = params.n_gpu_layers_draft; - if (params.draft_cpuparams.n_threads > 0) { - params.cpuparams.n_threads = params.draft_cpuparams.n_threads; + params.model = params.speculative.model; + params.n_gpu_layers = params.speculative.n_gpu_layers; + if (params.speculative.cpuparams.n_threads > 0) { + params.cpuparams.n_threads = params.speculative.cpuparams.n_threads; } - params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads; + params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; common_init_result llama_init_dft = common_init_from_params(params); model_dft = llama_init_dft.model; @@ -66,11 +66,9 @@ int main(int argc, char ** argv) { std::vector inp; inp = common_tokenize(ctx_tgt, params.prompt, true, true); - const int max_context_size = llama_n_ctx(ctx_tgt); - const int max_tokens_list_size = max_context_size - 4; + if ((int) inp.size() > llama_n_ctx(ctx_tgt)) { + LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt)); - if ((int) inp.size() > max_tokens_list_size) { - LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size); return 1; } @@ -81,7 +79,10 @@ int main(int argc, char ** argv) { } // how many tokens to draft each time - int n_draft = params.n_draft; + int n_draft = params.speculative.n_max; + int n_draft_min = params.speculative.n_min; + + float p_min = params.speculative.p_min; int n_predict = 0; int n_drafted = 0; @@ -97,7 +98,7 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // target model sampling context - struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams); + struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); // eval the prompt llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); @@ -112,9 +113,9 @@ int main(int argc, char ** argv) { // init the speculator struct common_speculative_params params_spec; - params_spec.n_draft = n_draft; - params_spec.n_reuse = 256; - params_spec.p_min = 0.9f; + params_spec.n_draft = n_draft; + params_spec.n_reuse = 256; + params_spec.p_min = p_min; struct common_speculative * spec = common_speculative_init(ctx_dft); @@ -143,7 +144,7 @@ int main(int argc, char ** argv) { // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] { // do not waste time on small drafts - if (draft.size() < params.n_draft_min) { + if (draft.size() < n_draft_min) { draft.clear(); } diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 207b8ea345fea..eb8bb2de54a26 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -33,7 +33,7 @@ int main(int argc, char ** argv) { common_params params; // needed to get candidate probs even for temp <= 0.0 - params.sparams.n_probs = 128; + params.sampling.n_probs = 128; if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) { return 1; @@ -46,7 +46,7 @@ int main(int argc, char ** argv) { common_init(); - if (params.model_draft.empty()) { + if (params.speculative.model.empty()) { LOG_ERR("%s: --model-draft is required\n", __func__); return 1; } @@ -55,9 +55,9 @@ int main(int argc, char ** argv) { const int n_seq_dft = params.n_parallel; // probability threshold for splitting a 
draft branch (only for n_seq_dft > 1) - const float p_split = params.p_split; + const float p_draft_split = params.speculative.p_split; - std::default_random_engine rng(params.sparams.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sparams.seed); + std::default_random_engine rng(params.sampling.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sampling.seed); std::uniform_real_distribution<> u_dist; // init llama.cpp @@ -76,13 +76,13 @@ int main(int argc, char ** argv) { ctx_tgt = llama_init_tgt.context; // load the draft model - params.model = params.model_draft; - params.n_gpu_layers = params.n_gpu_layers_draft; - if (params.draft_cpuparams.n_threads > 0) { - params.cpuparams.n_threads = params.draft_cpuparams.n_threads; + params.model = params.speculative.model; + params.n_gpu_layers = params.speculative.n_gpu_layers; + if (params.speculative.cpuparams.n_threads > 0) { + params.cpuparams.n_threads = params.speculative.cpuparams.n_threads; } - params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads; + params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; common_init_result llama_init_dft = common_init_from_params(params); model_dft = llama_init_dft.model; ctx_dft = llama_init_dft.context; @@ -170,7 +170,7 @@ int main(int argc, char ** argv) { //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft)); // how many tokens to draft each time - int n_draft = params.n_draft; + int n_draft = params.speculative.n_max; int n_predict = 0; int n_drafted = 0; @@ -183,14 +183,14 @@ int main(int argc, char ** argv) { bool has_eos = false; // target model sampling context (reuse the llama_context's sampling instance) - struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams); + struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); // draft sequence data std::vector drafts(n_seq_dft); for (int s = 0; s < n_seq_dft; ++s) { // allocate llama_sampler for each draft sequence - drafts[s].smpl = common_sampler_init(model_dft, params.sparams); + drafts[s].smpl = common_sampler_init(model_dft, params.sampling); } llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); @@ -230,7 +230,7 @@ int main(int argc, char ** argv) { // for stochastic sampling, attempt to match the token with the drafted tokens { bool accept = false; - if (params.sparams.temp > 0) { + if (params.sampling.temp > 0) { // stochastic verification common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true); @@ -494,7 +494,7 @@ int main(int argc, char ** argv) { // attempt to split the branch if the probability is high enough for (int f = 1; f < 8; ++f) { - if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) { + if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) { LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur); llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index 93850b0371c07..69604b87ceec4 100644 --- a/tests/test-arg-parser.cpp +++ b/tests/test-arg-parser.cpp @@ -96,7 +96,7 @@ int main(void) { // --draft cannot be used outside llama-speculative argv = {"binary_name", "--draft", "123"}; assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SPECULATIVE)); - assert(params.n_draft == 123); + assert(params.speculative.n_max == 123); // skip this part on windows, because setenv is not supported #ifdef _WIN32 From 4eb126fff09449ab957c1646fa9e876efe15c13d Mon Sep 17 
00:00:00 2001 From: Georgi Gerganov Date: Sun, 24 Nov 2024 15:39:07 +0200 Subject: [PATCH 13/14] common : change defaults [no ci] --- common/common.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/common.h b/common/common.h index f354a5fbe1452..c9fb2b62a998f 100644 --- a/common/common.h +++ b/common/common.h @@ -157,8 +157,8 @@ struct common_params_sampling { struct common_params_speculative { int32_t n_ctx = 4096; // draft context size - int32_t n_max = 5; // maximum number of tokens to draft during speculative decoding - int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding + int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding + int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) float p_split = 0.1f; // speculative decoding split probability float p_min = 0.9f; // minimum speculative decoding probability (greedy) From 8f419181d1c20d8195148680df15b6f093cb1512 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 24 Nov 2024 19:19:12 +0200 Subject: [PATCH 14/14] common : final touches ggml-ci --- common/common.h | 2 +- common/speculative.cpp | 37 +++++++++++++------ common/speculative.h | 13 +++---- .../speculative-simple/speculative-simple.cpp | 17 +++++++-- 4 files changed, 45 insertions(+), 24 deletions(-) diff --git a/common/common.h b/common/common.h index c9fb2b62a998f..5c579b5abfe03 100644 --- a/common/common.h +++ b/common/common.h @@ -156,7 +156,7 @@ struct common_params_sampling { }; struct common_params_speculative { - int32_t n_ctx = 4096; // draft context size + int32_t n_ctx = 0; // draft context size int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) diff --git a/common/speculative.cpp b/common/speculative.cpp index 316ea9e1eea1c..fe315a2703e9c 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -142,6 +142,8 @@ llama_tokens common_speculative_gen_draft( const int i_start = std::max(0, (int) prompt_tgt.size() - n_ctx); + // reuse as much as possible from the old draft context + // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt for (int i = 0; i < (int) prompt.size(); ++i) { int cur = 0; while (i_start + cur < (int) prompt_tgt.size() && @@ -166,6 +168,8 @@ llama_tokens common_speculative_gen_draft( prompt.clear(); } else { + // this happens when a previous draft has been discarded (for example, due to being too small), but the + // target model agreed with it. 
in this case, we simply pass back the previous results to save compute if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) { for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) { result.push_back(prompt[i]); @@ -174,42 +178,51 @@ llama_tokens common_speculative_gen_draft( break; } } + return result; } - llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i); - llama_kv_cache_seq_rm (ctx, 0, reuse_i + reuse_n, -1); - llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i); + if (reuse_i > 0) { + llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i); + llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i); + + prompt.erase(prompt.begin(), prompt.begin() + reuse_i); + } + + if (reuse_n < (int) prompt.size()) { + llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1); - prompt.erase(prompt.begin(), prompt.begin() + reuse_i); - prompt.erase(prompt.begin() + reuse_n, prompt.end()); + prompt.erase(prompt.begin() + reuse_n, prompt.end()); + } } + // prepare a batch to evaluate any new tokens in the prompt common_batch_clear(batch); - for (int i = i_start + reuse_n; i < (int) prompt_tgt.size(); ++i) { + for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) { //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); prompt.push_back(prompt_tgt[i]); } - const llama_pos n_past = prompt_tgt.size() - i_start; - - LOG_DBG("%s: n_past = %d\n", __func__, n_past); - + // we should rarely end-up here during normal decoding if (batch.n_tokens > 0) { - LOG_DBG("%s: draft batch: %s\n", __func__, string_from(ctx, batch).c_str()); + //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); llama_decode(ctx, batch); } + const llama_pos n_past = prompt.size(); + + LOG_DBG("%s: n_past = %d\n", __func__, n_past); + common_batch_clear(batch); common_batch_add (batch, id_last, n_past, { 0 }, true); prompt.push_back(id_last); - LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(ctx, prompt).c_str()); + //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str()); llama_decode(ctx, batch); diff --git a/common/speculative.h b/common/speculative.h index 9fb669fde3095..50ec0344618aa 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -6,10 +6,10 @@ struct common_speculative; struct common_speculative_params { - int n_draft = 16; + int n_draft = 16; // max drafted tokens int n_reuse = 256; - float p_min = 0.9f; + float p_min = 0.9f; // min probabiliy required to accept a token in the draft }; struct common_speculative * common_speculative_init(struct llama_context * ctx_dft); @@ -21,9 +21,8 @@ bool common_speculative_are_compatible( const struct llama_context * ctx_dft); // sample up to n_draft tokens and add them to the batch using the draft model -// llama_tokens common_speculative_gen_draft( - struct common_speculative * spec, - struct common_speculative_params params, - const llama_tokens & prompt, - llama_token id_last); + struct common_speculative * spec, + struct common_speculative_params params, + const llama_tokens & prompt, + llama_token id_last); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index ed3e6a4661227..1bc7f428cf3b5 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -46,8 +46,11 @@ int main(int argc, char ** argv) { ctx_tgt = 
llama_init_tgt.context; // load the draft model - params.model = params.speculative.model; + params.model = params.speculative.model; + params.n_ctx = params.speculative.n_ctx; + params.n_batch = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch; params.n_gpu_layers = params.speculative.n_gpu_layers; + if (params.speculative.cpuparams.n_threads > 0) { params.cpuparams.n_threads = params.speculative.cpuparams.n_threads; } @@ -66,8 +69,14 @@ int main(int argc, char ** argv) { std::vector inp; inp = common_tokenize(ctx_tgt, params.prompt, true, true); - if ((int) inp.size() > llama_n_ctx(ctx_tgt)) { - LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt)); + if (llama_n_ctx(ctx_tgt) < (int) inp.size()) { + LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt)); + + return 1; + } + + if (llama_n_batch(ctx_tgt) < (int) inp.size()) { + LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt)); return 1; } @@ -114,7 +123,7 @@ int main(int argc, char ** argv) { // init the speculator struct common_speculative_params params_spec; params_spec.n_draft = n_draft; - params_spec.n_reuse = 256; + params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft; params_spec.p_min = p_min; struct common_speculative * spec = common_speculative_init(ctx_dft);