Skip to content

Commit 58c4767

Browse files
committed
shared_ptr for mtmd_image_tokens
1 parent 7ac0b7b commit 58c4767

File tree

3 files changed

+44
-55
lines changed

3 files changed

+44
-55
lines changed

examples/llava/gemma3-cli.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,18 +185,19 @@ static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector
185185
text.text = formatted_chat.prompt;
186186
text.add_special = add_bos;
187187
text.parse_special = true;
188-
mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx.ctx_vision.get(), text, bitmaps));
189-
if (chunks == nullptr) {
190-
LOG_ERR("Unable to tokenize prompt\n");
188+
mtmd_input_chunks chunks;
189+
int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
190+
if (res != 0) {
191+
LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
191192
return 1;
192193
}
193194

194-
if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks.get(), ctx.n_past, 0, ctx.n_batch)) {
195+
if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
195196
LOG_ERR("Unable to eval prompt\n");
196197
return 1;
197198
}
198199

199-
ctx.n_past += mtmd_helper_get_n_tokens(chunks.get());
200+
ctx.n_past += mtmd_helper_get_n_tokens(chunks);
200201

201202
return 0;
202203
}

examples/llava/mtmd.cpp

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@ static uint64_t hash_vector_float(const std::vector<float> & vec) {
106106
return seed;
107107
}
108108

109-
mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
110-
const mtmd_input_text & text,
111-
const std::vector<mtmd_bitmap> & bitmaps) {
112-
mtmd_input_chunks * output = new mtmd_input_chunks;
109+
int32_t mtmd_tokenize(mtmd_context * ctx,
110+
std::vector<mtmd_input_chunk> & output,
111+
const mtmd_input_text & text,
112+
const std::vector<mtmd_bitmap> & bitmaps) {
113113
auto vocab = llama_model_get_vocab(ctx->text_model);
114114

115115
std::string prompt_modified(text.text);
@@ -124,8 +124,8 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
124124
}
125125

126126
std::vector<std::string> parts = string_split_str(text.text, ctx->image_marker);
127-
output->clear();
128-
output->reserve(parts.size());
127+
output.clear();
128+
output.reserve(parts.size());
129129

130130
size_t i_img = 0;
131131

@@ -141,14 +141,14 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
141141
std::move(tokens),
142142
{},
143143
};
144-
output->emplace_back(std::move(chunk));
144+
output.emplace_back(std::move(chunk));
145145

146146
if (&parts.back() != &part) {
147147
// add image token to middle of 2 parts
148148

149149
if (i_img >= bitmaps.size()) {
150150
LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
151-
return nullptr;
151+
return 1;
152152
}
153153

154154
// shim layer
@@ -163,10 +163,10 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
163163
bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32);
164164
if (!ok) {
165165
LOG_ERR("Unable to preprocess image\n");
166-
return nullptr;
166+
return 2;
167167
}
168168

169-
mtmd_image_tokens * image_tokens = new mtmd_image_tokens;
169+
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
170170
image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
171171
image_tokens->ny = 1; // TODO
172172
image_tokens->batch_f32 = std::move(batch_f32);
@@ -179,14 +179,14 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
179179
mtmd_input_chunk chunk{
180180
MTMD_INPUT_CHUNK_TYPE_IMAGE,
181181
{},
182-
image_tokens,
182+
std::move(image_tokens),
183183
};
184-
output->emplace_back(std::move(chunk));
184+
output.emplace_back(std::move(chunk));
185185
i_img++;
186186
}
187187
}
188188

189-
return output;
189+
return 0;
190190
}
191191

192192
void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
@@ -195,18 +195,6 @@ void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
195195
}
196196
}
197197

198-
void mtmd_input_chunks_free(mtmd_input_chunks * chunks, bool free_images) {
199-
if (free_images) {
200-
for (auto & chunk : *chunks) {
201-
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) {
202-
mtmd_image_tokens_free(chunk.tokens_image);
203-
chunk.tokens_image = nullptr;
204-
}
205-
}
206-
}
207-
delete chunks;
208-
}
209-
210198
size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) {
211199
return image_tokens->n_tokens();
212200
}
@@ -238,9 +226,9 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
238226
return ctx->image_embd_v.data();
239227
}
240228

241-
size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks) {
229+
size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) {
242230
size_t n_tokens = 0;
243-
for (auto & chunk : *chunks) {
231+
for (auto & chunk : chunks) {
244232
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
245233
n_tokens += chunk.tokens_text.size();
246234
} else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
@@ -289,16 +277,16 @@ struct decode_embd_batch {
289277

290278
int32_t mtmd_helper_eval(mtmd_context * ctx,
291279
llama_context * lctx,
292-
mtmd_input_chunks * chunks,
280+
mtmd_input_chunks & chunks,
293281
llama_pos pos0,
294282
llama_seq_id seq_id,
295283
int32_t n_batch) {
296284
int32_t ret;
297285
llama_pos n_past = pos0;
298286
llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
299287

300-
for (auto & chunk : *chunks) {
301-
bool is_last = &chunk == &chunks->back();
288+
for (auto & chunk : chunks) {
289+
bool is_last = &chunk == &chunks.back();
302290
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
303291
// TODO @ngxson : may need to split into smaller batches
304292
text_batch.n_tokens = chunk.tokens_text.size();
@@ -327,7 +315,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
327315
if (ctx->print_timings) {
328316
LOG_INF("encoding image...\n");
329317
}
330-
ret = mtmd_encode(ctx, chunk.tokens_image);
318+
ret = mtmd_encode(ctx, chunk.tokens_image.get());
331319
if (ret != 0) {
332320
LOG_ERR("failed to encode image\n");
333321
llama_batch_free(text_batch);
@@ -337,7 +325,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
337325
LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
338326
}
339327

340-
int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image);
328+
int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
341329
float * embd = mtmd_get_output_embd(ctx);
342330
decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
343331
int64_t t1 = ggml_time_ms();
@@ -395,3 +383,7 @@ bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
395383
}
396384
return false;
397385
}
386+
387+
void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
388+
mtmd_image_tokens_free(val);
389+
}

examples/llava/mtmd.h

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,15 @@ struct mtmd_bitmap {
4141
std::vector<unsigned char> data;
4242
};
4343

44+
struct mtmd_image_tokens_deleter {
45+
void operator()(mtmd_image_tokens * val); // forward declaration
46+
};
47+
using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;
48+
4449
struct mtmd_input_chunk {
4550
mtmd_input_chunk_type type;
4651
std::vector<llama_token> tokens_text;
47-
mtmd_image_tokens * tokens_image = nullptr;
52+
mtmd_image_tokens_ptr tokens_image;
4853
};
4954

5055
using mtmd_input_chunks = std::vector<mtmd_input_chunk>;
@@ -84,15 +89,16 @@ MTMD_API void mtmd_free(mtmd_context * ctx);
8489
// 2. (image tokens)
8590
// 3. "<end_of_image>\ndescribe it in detail."
8691
// number of bitmaps must be equal to the number of image markers in the prompt
87-
// the returned value must be freed using mtmd_input_chunks_free()
8892
// this function is thread-safe (shared ctx)
89-
MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
93+
// return values:
94+
// 0 on success
95+
// 1 on number of images not matching the number of markers
96+
// 2 on image preprocessing error
97+
MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
98+
std::vector<mtmd_input_chunk> & output,
9099
const mtmd_input_text & text,
91100
const std::vector<mtmd_bitmap> & bitmaps);
92101

93-
// if free_images = true, free the image tokens ; otherwise, you must free them using mtmd_image_free()
94-
MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks, bool free_images);
95-
96102
// access mtmd_image_tokens
97103
MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
98104
MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens);
@@ -117,7 +123,7 @@ MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
117123
//
118124

119125
// helper to count the total number of tokens from a list of chunks, useful to keep track of n_past
120-
MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks);
126+
MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks);
121127

122128
// helper function that automatically:
123129
// 1. run llama_decode() on text chunks
@@ -126,7 +132,7 @@ MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks);
126132
// otherwise, returns 0 on success
127133
MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx,
128134
llama_context * lctx,
129-
mtmd_input_chunks * chunks,
135+
mtmd_input_chunks & chunks,
130136
llama_pos pos0,
131137
llama_seq_id seq_id,
132138
int32_t n_batch);
@@ -148,16 +154,6 @@ struct mtmd_context_deleter {
148154
};
149155
using mtmd_context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;
150156

151-
struct mtmd_input_chunks_deleter {
152-
void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val, true); }
153-
};
154-
using mtmd_input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, mtmd_input_chunks_deleter>;
155-
156-
struct mtmd_image_tokens_deleter {
157-
void operator()(mtmd_image_tokens * val) { mtmd_image_tokens_free(val); }
158-
};
159-
using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;
160-
161157
#else
162158

163159
static_assert(false && "C header is not yet supported by this library");

0 commit comments

Comments
 (0)