Skip to content

Commit 8ec0ff9

Browse files
committed
fix embeddings and retrieval
1 parent 50fb396 commit 8ec0ff9

File tree

3 files changed

+29
-36
lines changed

3 files changed

+29
-36
lines changed

examples/embedding/embedding.cpp

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,6 @@ static std::vector<std::string> split_lines(const std::string & s, const std::st
2626
return lines;
2727
}
2828

29-
// Append an entire token sequence to `batch` under sequence id `seq_id`.
// Token positions run 0..n-1; the trailing `true` mirrors the original
// call (presumably requesting output for every token — embeddings use case).
static void batch_add_seq(llama_batch_ext * batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
    llama_pos pos = 0;
    for (const int32_t token : tokens) {
        llama_batch_ext_add_text(batch, token, pos, &seq_id, 1, true);
        ++pos;
    }
}
35-
3629
static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm) {
3730
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
3831
const llama_model * model = llama_get_model(ctx);
@@ -167,7 +160,7 @@ int main(int argc, char ** argv) {
167160

168161
// initialize batch
169162
const int n_prompts = prompts.size();
170-
llama_batch_ext * batch = llama_batch_ext_init(ctx);
163+
llama_batch_ext_ptr batch(ctx);
171164

172165
// count number of embeddings
173166
int n_embd_count = 0;
@@ -194,21 +187,21 @@ int main(int argc, char ** argv) {
194187
const uint64_t n_toks = inp.size();
195188

196189
// encode if at capacity
197-
if (llama_batch_ext_get_n_tokens(batch) + n_toks > n_batch) {
198-
batch_decode(ctx, batch, emb + e * n_embd, s, n_embd, params.embd_normalize);
199-
llama_batch_ext_clear(batch);
190+
if (batch.n_tokens() + n_toks > n_batch) {
191+
batch_decode(ctx, batch.get(), emb + e * n_embd, s, n_embd, params.embd_normalize);
192+
batch.clear();
200193

201-
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? llama_batch_ext_get_n_tokens(batch) : s;
194+
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens() : s;
202195
s = 0;
203196
}
204197

205198
// add to batch
206-
batch_add_seq(batch, inp, s);
199+
batch.add_seq(inp, 0, s, true);
207200
s += 1;
208201
}
209202

210203
// final batch
211-
batch_decode(ctx, batch, emb + e * n_embd, s, n_embd, params.embd_normalize);
204+
batch_decode(ctx, batch.get(), emb + e * n_embd, s, n_embd, params.embd_normalize);
212205

213206
if (params.embd_out.empty()) {
214207
LOG("\n");
@@ -313,8 +306,6 @@ int main(int argc, char ** argv) {
313306
LOG("\n");
314307
llama_perf_context_print(ctx);
315308

316-
llama_batch_ext_free(batch);
317-
318309
// clean up
319310
llama_backend_free();
320311

examples/retrieval/retrieval.cpp

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,6 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
7474
return chunks;
7575
}
7676

77-
// Add all of `tokens` to `batch`, assigning each token to `seq_id`.
// Positions start at 0; every token keeps the `true` output flag the
// original helper passed to llama_batch_ext_add_text.
static void batch_add_seq(llama_batch_ext * batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
    for (size_t idx = 0, count = tokens.size(); idx < count; ++idx) {
        llama_batch_ext_add_text(batch, tokens[idx], (llama_pos) idx, &seq_id, 1, true);
    }
}
83-
8477
static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm = 2) {
8578
const llama_model * model = llama_get_model(ctx);
8679

@@ -213,7 +206,7 @@ int main(int argc, char ** argv) {
213206

214207
// initialize batch
215208
const int n_chunks = chunks.size();
216-
llama_batch_ext * batch = llama_batch_ext_init(ctx);
209+
llama_batch_ext_ptr batch(ctx);
217210

218211
// allocate output
219212
const int n_embd = llama_model_n_embd(model);
@@ -230,21 +223,21 @@ int main(int argc, char ** argv) {
230223
const uint64_t n_toks = inp.size();
231224

232225
// encode if at capacity
233-
if (llama_batch_ext_get_n_tokens(batch) + n_toks > n_batch) {
234-
batch_decode(ctx, batch, emb + p * n_embd, s, n_embd);
235-
llama_batch_ext_clear(batch);
226+
if (batch.n_tokens() + n_toks > n_batch) {
227+
batch_decode(ctx, batch.get(), emb + p * n_embd, s, n_embd);
228+
batch.clear();
236229

237230
p += s;
238231
s = 0;
239232
}
240233

241234
// add to batch
242-
batch_add_seq(batch, inp, s);
235+
batch.add_seq(inp, 0, s, true);
243236
s += 1;
244237
}
245238

246239
// final batch
247-
batch_decode(ctx, batch, emb + p * n_embd, s, n_embd);
240+
batch_decode(ctx, batch.get(), emb + p * n_embd, s, n_embd);
248241

249242
// save embeddings to chunks
250243
for (int i = 0; i < n_chunks; i++) {
@@ -253,7 +246,7 @@ int main(int argc, char ** argv) {
253246
chunks[i].tokens.clear();
254247
}
255248

256-
llama_batch_ext * query_batch = llama_batch_ext_init(ctx);
249+
llama_batch_ext_ptr query_batch(ctx);
257250

258251
// start loop, receive query and return top k similar chunks based on cosine similarity
259252
std::string query;
@@ -262,12 +255,12 @@ int main(int argc, char ** argv) {
262255
std::getline(std::cin, query);
263256
std::vector<int32_t> query_tokens = common_tokenize(ctx, query, true);
264257

265-
batch_add_seq(query_batch, query_tokens, 0);
258+
query_batch.add_seq(query_tokens, 0, 0, true);
266259

267260
std::vector<float> query_emb(n_embd, 0);
268-
batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
261+
batch_decode(ctx, query_batch.get(), query_emb.data(), 1, n_embd);
269262

270-
llama_batch_ext_clear(query_batch);
263+
query_batch.clear();
271264

272265
// compute cosine similarities
273266
{
@@ -296,9 +289,6 @@ int main(int argc, char ** argv) {
296289
LOG("\n");
297290
llama_perf_context_print(ctx);
298291

299-
llama_batch_ext_free(batch);
300-
llama_batch_ext_free(query_batch);
301-
302292
// clean up
303293
llama_backend_free();
304294
}

include/llama-cpp.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,18 @@ struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_de
101101
return output_id;
102102
}
103103

104+
// Return output ID of the last token. Position starts from pos0
105+
int32_t add_seq(std::vector<llama_token> & tokens, llama_pos pos0, llama_seq_id seq_id, bool output_last) {
106+
int32_t output_id = -1;
107+
for (size_t i = 0; i < tokens.size(); i++) {
108+
output_id = llama_batch_ext_add_text(this->get(), tokens[i], pos0 + i, &seq_id, 1, false);
109+
}
110+
if (output_last) {
111+
llama_batch_ext_set_output_last(this->get());
112+
}
113+
return output_id;
114+
}
115+
104116
void clear() {
105117
llama_batch_ext_clear(this->get());
106118
}

0 commit comments

Comments (0)