Commit 69982ea

speculative : refactor and add a simpler example

ggml-ci

1 parent ce2e59b

10 files changed: +513 −1 lines

common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

```diff
@@ -66,6 +66,8 @@ add_library(${TARGET} STATIC
     ngram-cache.h
     sampling.cpp
     sampling.h
+    speculative.cpp
+    speculative.h
     )
 
 if (BUILD_SHARED_LIBS)
```

common/sampling.cpp

Lines changed: 22 additions & 0 deletions

```diff
@@ -320,6 +320,28 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     return cur_p.data[cur_p.selected].id;
 }
 
+std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first) {
+    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
+
+    std::vector<llama_token> result;
+    result.reserve(idxs.size());
+
+    size_t i = 0;
+    for (; i < draft.size(); i++) {
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+
+        if (draft[i] != id) {
+            break;
+        }
+
+        result.push_back(id);
+    }
+
+    result.push_back(common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first));
+
+    return result;
+}
+
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
     return llama_sampler_get_seed(gsmpl->chain);
 }
```
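
The acceptance logic in `common_sampler_sample_n()` is the heart of the verification step, so a minimal standalone sketch may help: the target keeps draft tokens only while they match what it would have sampled itself, then always appends one token of its own, either the correction at the first mismatch or a bonus token past a fully accepted draft, so the result always holds between 1 and `draft.size() + 1` tokens. Everything below is invented for illustration (`sample_stub` and its fixed token sequence are hypothetical stand-ins for `common_sampler_sample()`); only the control flow mirrors the function above.

```cpp
#include <cstdio>
#include <vector>

using llama_token = int32_t;

// hypothetical stand-in for common_sampler_sample(): pretend the target
// model deterministically "samples" this fixed sequence, one token per index
static llama_token sample_stub(int idx) {
    static const std::vector<llama_token> target = { 10, 20, 99, 40, 50 };
    return target[idx];
}

// same control flow as common_sampler_sample_n(): accept the longest draft
// prefix that matches the target, then append one target-sampled token
static std::vector<llama_token> sample_n_stub(const std::vector<int> & idxs, const std::vector<llama_token> & draft) {
    std::vector<llama_token> result;
    result.reserve(idxs.size());

    size_t i = 0;
    for (; i < draft.size(); i++) {
        if (draft[i] != sample_stub(idxs[i])) {
            break; // first mismatch - stop accepting draft tokens
        }
        result.push_back(draft[i]);
    }

    // the correction at the mismatch, or a bonus token past the full draft
    result.push_back(sample_stub(idxs[i]));

    return result;
}

int main() {
    // the draft guesses 30 at position 2, but the target would sample 99
    const std::vector<llama_token> draft = { 10, 20, 30 };
    const std::vector<int>         idxs  = { 0, 1, 2, 3 }; // draft.size() + 1 indices

    for (const llama_token id : sample_n_stub(idxs, draft)) {
        printf("%d ", id); // prints: 10 20 99
    }
    printf("\n");
}
```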

common/sampling.h

Lines changed: 2 additions & 0 deletions

```diff
@@ -60,6 +60,8 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
 //
 llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 
+std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first = false);
+
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 
 // helpers
```

common/speculative.cpp

Lines changed: 159 additions & 0 deletions (new file)

```cpp
#include "speculative.h"

#include "log.h"
#include "common.h"
#include "sampling.h"

#include <vector>

struct seq_draft {
};

struct common_speculative {
    struct common_speculative_params params;

    llama_batch batch_dft;

    struct common_sampler * smpl;

    std::vector<int> i_batch_tgt;

    std::vector<llama_token> tokens;
};

struct common_speculative * common_speculative_init(struct common_speculative_params params) {
    auto * result = new common_speculative {
        /* .params      = */ params,
        /* .batch_dft   = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1),
        /* .smpl        = */ nullptr,
        /* .i_batch_tgt = */ {},
        /* .tokens      = */ {},
    };

    // TODO: optimize or pass from outside?
#if 0
    {
        common_sampler_params sparams;
        sparams.no_perf = false;

        sparams.top_k = 40;
        sparams.top_p = 0.9;

        sparams.samplers = {
            COMMON_SAMPLER_TYPE_TOP_K,
            COMMON_SAMPLER_TYPE_TOP_P,
            COMMON_SAMPLER_TYPE_INFILL,
        };

        result->smpl = common_sampler_init(params.model_dft, sparams);
    }
#else
    {
        common_sampler_params sparams;
        sparams.no_perf = false;

        sparams.top_k = 10;

        sparams.samplers = {
            COMMON_SAMPLER_TYPE_TOP_K,
        };

        result->smpl = common_sampler_init(params.model_dft, sparams);
    }
#endif

    return result;
}

void common_speculative_free(struct common_speculative * spec) {
    common_sampler_free(spec->smpl);

    llama_batch_free(spec->batch_dft);

    delete spec;
}

void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens) {
    llama_kv_cache_clear(spec->params.ctx_dft);

    // TODO: error handling
    llama_decode(spec->params.ctx_dft, llama_batch_get_one(tokens, n_tokens));
}

void common_speculative_add_draft(
        struct common_speculative * spec,
        struct llama_batch & batch_tgt,
        llama_token id_last,
        int n_past) {
    spec->tokens.clear();

    spec->i_batch_tgt.clear();
    spec->i_batch_tgt.push_back(0);

    common_sampler_reset(spec->smpl);

    common_batch_clear(spec->batch_dft);
    common_batch_add  (spec->batch_dft, id_last, n_past, { 0 }, true);

    llama_decode(spec->params.ctx_dft, spec->batch_dft);

    // sample n_draft tokens from the draft model
    for (int i = 0; i < spec->params.n_draft; ++i) {
        common_batch_clear(spec->batch_dft);

        common_sampler_sample(spec->smpl, spec->params.ctx_dft, 0, true);

        const auto * cur_p = common_sampler_get_candidates(spec->smpl);

        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(spec->params.ctx_dft, cur_p->data[k].id).c_str());
        }

        // add drafted token for each sequence
        const llama_token id = cur_p->data[0].id;

        // only collect very high-confidence draft tokens
        if (cur_p->data[0].p < 0.75) {
            break;
        }

        common_sampler_accept(spec->smpl, id, true);

        spec->tokens.push_back(id);

        // add unique drafted tokens to the target batch
        spec->i_batch_tgt.push_back(batch_tgt.n_tokens);

        common_batch_add(batch_tgt, id, n_past + i + 1, { 0 }, true);

        if (batch_tgt.n_tokens > spec->params.n_draft) {
            break;
        }

        common_batch_add(spec->batch_dft, id, n_past + i + 1, { 0 }, true);

        // evaluate the drafted tokens on the draft model
        llama_decode(spec->params.ctx_dft, spec->batch_dft);
    }

    // don't waste time on small batches
    // TODO: do not evaluate the draft model for that many rounds
    if (batch_tgt.n_tokens < spec->params.n_min) {
        batch_tgt.n_tokens = 1;
        spec->tokens.resize(0);
        spec->i_batch_tgt.resize(1);
    }

    // print current draft sequences
    LOG_DBG("draft %s\n", string_from(spec->params.ctx_dft, spec->tokens).c_str());
}

std::vector<llama_token> common_speculative_sample(
        struct common_speculative * spec,
        struct common_sampler * smpl,
        struct llama_context * ctx_tgt) {
    return common_sampler_sample_n(smpl, ctx_tgt, spec->i_batch_tgt, spec->tokens);
}
```
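
Worth noting in `common_speculative_add_draft()`: drafting stops early as soon as the draft model's top candidate falls below the hard-coded 0.75 probability threshold (and also once the target batch exceeds `n_draft` tokens), so speculation spends draft-model decodes only where the small model is confident. A toy standalone illustration of that cutoff, with made-up probabilities; only the stopping rule mirrors the loop above:

```cpp
#include <cstdio>
#include <vector>

int main() {
    // hypothetical top-candidate probabilities from the draft model, by position
    const std::vector<float> p_top = { 0.98f, 0.91f, 0.83f, 0.41f, 0.95f };

    const int   n_draft = 16;    // default from common_speculative_params
    const float p_min   = 0.75f; // hard-coded threshold in this commit

    int n_drafted = 0;
    for (int i = 0; i < n_draft && i < (int) p_top.size(); ++i) {
        if (p_top[i] < p_min) {
            break; // position 3 here: 0.41 < 0.75, despite the 0.95 after it
        }
        n_drafted++;
    }

    printf("drafted %d tokens\n", n_drafted); // prints: drafted 3 tokens
}
```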

common/speculative.h

Lines changed: 33 additions & 0 deletions (new file)

```cpp
#pragma once

#include "llama.h"

#include <vector>

struct common_speculative;

struct common_speculative_params {
    int n_draft = 16;
    int n_min   = 5; // do not add drafts smaller than this, TODO: leave this to user?

    struct llama_model * model_dft = nullptr;

    struct llama_context * ctx_dft = nullptr;
};

struct common_speculative * common_speculative_init(struct common_speculative_params params);

void common_speculative_free(struct common_speculative * spec);

void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens);

void common_speculative_add_draft(
        struct common_speculative * spec,
        struct llama_batch & batch_tgt,
        llama_token id_last,
        int n_past);

std::vector<llama_token> common_speculative_sample(
        struct common_speculative * spec,
        struct common_sampler * smpl,
        struct llama_context * ctx_tgt);
```
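
The example's main loop (speculative-simple.cpp) is not shown on this page, but the header implies the intended call sequence. The following is a rough sketch of a single verification step under stated assumptions, not code from the commit: it assumes the models and contexts are already set up, the prompt (minus its final token, which becomes `id_last`) has been fed to the draft via `common_speculative_set_prompt()` and to the target via a regular decode, and everything runs on sequence id 0.

```cpp
#include "common.h"
#include "sampling.h"
#include "speculative.h"

#include <vector>

// one speculative decoding step: draft with the small model, verify every
// drafted token with a single target decode, keep the accepted prefix
static size_t speculative_step(
        struct llama_context      * ctx_tgt,
        struct common_sampler     * smpl_tgt,
        struct common_speculative * spec,
        struct llama_batch        & batch_tgt,
        llama_token               & id_last,
        int                       & n_past) {
    // the target batch starts with the last accepted token at index 0,
    // matching the i_batch_tgt = { 0 } convention in speculative.cpp
    common_batch_clear(batch_tgt);
    common_batch_add  (batch_tgt, id_last, n_past, { 0 }, true);

    // extend the batch with draft tokens at positions n_past + 1, ...
    common_speculative_add_draft(spec, batch_tgt, id_last, n_past);

    // a single target pass produces logits for id_last and every draft token
    llama_decode(ctx_tgt, batch_tgt);

    // longest accepted draft prefix, plus one token sampled by the target
    const std::vector<llama_token> ids = common_speculative_sample(spec, smpl_tgt, ctx_tgt);

    for (const llama_token id : ids) {
        common_sampler_accept(smpl_tgt, id, true);
    }

    n_past += (int) ids.size();
    id_last = ids.back();

    // evict the rejected draft tokens from the target KV cache
    llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);

    return ids.size();
}
```

The key property is that however long the accepted draft is, the target model pays for one `llama_decode()` per step rather than one per token.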

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -50,5 +50,6 @@ else()
     add_subdirectory(simple)
     add_subdirectory(simple-chat)
     add_subdirectory(speculative)
+    add_subdirectory(speculative-simple)
     add_subdirectory(tokenize)
 endif()
```

examples/speculative-simple/CMakeLists.txt

Lines changed: 5 additions & 0 deletions (new file)

```cmake
set(TARGET llama-speculative-simple)
add_executable(${TARGET} speculative-simple.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
```

examples/speculative-simple/README.md

Lines changed: 3 additions & 0 deletions (new file)

```markdown
# llama.cpp/examples/speculative-simple

Demonstration of basic greedy speculative decoding
```
