@@ -26,56 +26,52 @@ static std::vector<std::string> split_lines(const std::string & s, const std::st
     return lines;
 }
 
-static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
+static void batch_add_seq(llama_batch_ext * batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
     size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        batch.add_text(tokens[i], i, seq_id, true);
+        llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, true);
     }
 }
 
-static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
+static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm) {
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
-    const struct llama_model * model = llama_get_model(ctx);
+    const llama_model * model = llama_get_model(ctx);
 
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_self_clear(ctx);
 
+    const int n_tokens = llama_batch_ext_get_n_tokens(batch);
+
     // run model
-    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
+    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, n_tokens, n_seq);
     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
         // encoder-only model
-        if (llama_encode_ext(ctx, batch.get()) < 0) {
+        if (llama_encode_ext(ctx, batch) < 0) {
             LOG_ERR("%s : failed to encode\n", __func__);
         }
     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
         // decoder-only model
-        if (llama_decode_ext(ctx, batch.get()) < 0) {
+        if (llama_decode_ext(ctx, batch) < 0) {
             LOG_ERR("%s : failed to decode\n", __func__);
         }
     }
 
-    for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
-        if (!batch.tokens[i].logits) {
-            continue;
-        }
-
-        const float * embd = nullptr;
-        int embd_pos = 0;
-
-        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-            // try to get token embeddings
-            embd = llama_get_embeddings_ith(ctx, i);
-            embd_pos = i;
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        for (int i = 0; i < n_tokens; i++) {
+            const float * embd = llama_get_embeddings_ith(ctx, i);
             GGML_ASSERT(embd != NULL && "failed to get token embeddings");
-        } else {
-            // try to get sequence embeddings - supported only when pooling_type is not NONE
-            embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
-            embd_pos = batch.tokens[i].seq_id;
-            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+
+            float * out = output + i * n_embd;
+            common_embd_normalize(embd, out, n_embd, embd_norm);
         }
+    } else {
+        for (int s = 0; s < n_seq; s++) {
+            const float * embd = llama_get_embeddings_seq(ctx, s);
+            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
 
-        float * out = output + embd_pos * n_embd;
-        common_embd_normalize(embd, out, n_embd, embd_norm);
+            float * out = output + s * n_embd;
+            common_embd_normalize(embd, out, n_embd, embd_norm);
+        }
     }
 }
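Taken together, the two helpers now cover the whole lifecycle of the raw `llama_batch_ext` pointer. A minimal sketch of how they compose for a pooled model, assuming an initialized `llama_context * ctx`, its embedding width `n_embd`, the includes already present in embedding.cpp, and placeholder token ids; the second argument of `llama_batch_ext_init` is assumed to be the maximum sequence count:

    llama_batch_ext * batch = llama_batch_ext_init(/*n_tokens_max=*/512, /*n_seq_max, assumed*/2);

    const std::vector<int32_t> toks_a = {101, 202, 303};  // placeholder token ids
    const std::vector<int32_t> toks_b = {404, 505};

    batch_add_seq(batch, toks_a, 0);  // sequence 0
    batch_add_seq(batch, toks_b, 1);  // sequence 1

    // with a pooled model: one normalized row per sequence
    std::vector<float> emb(2 * n_embd);
    batch_decode(ctx, batch, emb.data(), /*n_seq=*/2, n_embd, /*embd_norm=*/2);  // 2 = Euclidean norm

    llama_batch_ext_free(batch);

Note that `llama_batch_ext_add_text` takes a pointer to the sequence ids plus a count, so a single token can be assigned to several sequences at once; `batch_add_seq` passes `&seq_id, 1` for the common single-sequence case.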
@@ -171,7 +167,7 @@ int main(int argc, char ** argv) {
 
     // initialize batch
     const int n_prompts = prompts.size();
-    struct common_batch batch = common_batch(n_batch, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
 
     // count number of embeddings
     int n_embd_count = 0;
@@ -198,12 +194,12 @@ int main(int argc, char ** argv) {
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
-        if (batch.get_n_tokens() + n_toks > n_batch) {
-            float * out = emb + e * n_embd;
-            batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
-            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.get_n_tokens() : s;
+        if (llama_batch_ext_get_n_tokens(batch) + n_toks > n_batch) {
+            batch_decode(ctx, batch, emb + e * n_embd, s, n_embd, params.embd_normalize);
+
+            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? llama_batch_ext_get_n_tokens(batch) : s;
             s = 0;
-            batch.clear();
+            llama_batch_ext_clear(batch);
         }
 
         // add to batch
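The ordering inside the flush is the subtle part: the token count that advances `e` is read while the batch is still full, and `llama_batch_ext_clear()` only runs afterwards. Clearing resets the reported count to zero, so reading it after the clear would leave `e` stuck whenever pooling is `LLAMA_POOLING_TYPE_NONE`. In terms of rows of the flat `emb` array, assuming the variables as named in this file:

    // rows written into `emb` per flush:
    //   LLAMA_POOLING_TYPE_NONE -> one row per token    (e += token count of the flushed batch)
    //   any pooled type         -> one row per sequence (e += s)
    // the next flush then starts writing at emb + e * n_embd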
@@ -212,8 +208,7 @@ int main(int argc, char ** argv) {
     }
 
     // final batch
-    float * out = emb + e * n_embd;
-    batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
+    batch_decode(ctx, batch, emb + e * n_embd, s, n_embd, params.embd_normalize);
 
     if (params.embd_out.empty()) {
         LOG("\n");
@@ -318,6 +313,8 @@ int main(int argc, char ** argv) {
     LOG("\n");
     llama_perf_context_print(ctx);
 
+    llama_batch_ext_free(batch);
+
    // clean up
    llama_backend_free();
 
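With the RAII `common_batch` wrapper gone, any early-return path between `llama_batch_ext_init` and this `llama_batch_ext_free` would leak the batch. One possible guard, sketched here purely as an illustration (not part of this diff), using a `std::unique_ptr` with the free function shown above as its deleter:

    #include <memory>  // std::unique_ptr

    // frees the batch on every exit path, including early returns
    std::unique_ptr<llama_batch_ext, decltype(&llama_batch_ext_free)>
        batch_guard(llama_batch_ext_init(n_batch, 1), &llama_batch_ext_free);

    llama_batch_ext * batch = batch_guard.get();  // use the raw pointer exactly as before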