
Commit fa55281

separate vision ctx and llm ctx
1 parent ff77b15 commit fa55281

7 files changed: 139 additions & 35 deletions

examples/vision/vision.cpp

Lines changed: 11 additions & 3 deletions
@@ -120,6 +120,14 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    llama_vision_context_params vparams = llama_vision_context_default_params();
+    vparams.n_threads = llama_n_threads(ctx);
+    llama_vision_context * vctx = llama_vision_init_from_model(model, vparams);
+    if (!vctx) {
+        LOG_ERR("model does not have vision encoder\n");
+        return 1;
+    }
+
     struct common_sampler * smpl = common_sampler_init(model, params.sampling);
 
     llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
@@ -136,12 +144,12 @@ int main(int argc, char ** argv) {
         }
         llama_vision_bitmap * img = load_image_from_file(img_path);
         LOG_INF("loaded image %s, size = %d x %d\n", img_path, img->nx, img->ny);
-        img_tokens = llama_vision_tokenize(ctx, img);
+        img_tokens = llama_vision_tokenize(vctx, img);
         if (!img_tokens) {
             LOG_ERR("failed to create image tokens\n");
             return 1;
         }
-        if (llama_vision_encode(ctx, img_tokens)) {
+        if (llama_vision_encode(vctx, img_tokens)) {
             LOG_ERR("failed to encode image\n");
             return 1;
         }
@@ -163,7 +171,7 @@ int main(int argc, char ** argv) {
                 return 1;
             }
         } else {
-            auto * img_embd = llama_vision_get_output_tensor(ctx);
+            auto * img_embd = llama_vision_get_output_tensor(vctx);
             // std::vector<float> output_debug(ggml_nelements(img_embd));
             // ggml_backend_tensor_get(img_embd, output_debug.data(), 0, ggml_nbytes(img_embd));
             // for (int row = 0; row < 10; row++) {
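Taken together, these hunks move the example off the shared llama_context and onto a dedicated vision context. A condensed sketch of the resulting flow, assuming model, ctx, img_path, and load_image_from_file() from the surrounding example code (error checks trimmed; this is a re-assembly of the calls shown above, not additional code from the commit):

// condensed usage sketch; `model`, `ctx`, `img_path`, and load_image_from_file()
// are assumed from the rest of the example
llama_vision_context_params vparams = llama_vision_context_default_params();
vparams.n_threads = llama_n_threads(ctx);
llama_vision_context * vctx = llama_vision_init_from_model(model, vparams); // nullptr if the model has no vision encoder

llama_vision_bitmap * img        = load_image_from_file(img_path);
llama_vision_tokens * img_tokens = llama_vision_tokenize(vctx, img);        // preprocess into patches/slices
if (llama_vision_encode(vctx, img_tokens) == 0) {                           // run the vision encoder
    struct ggml_tensor * img_embd = llama_vision_get_output_tensor(vctx);   // embeddings to hand to the LLM
    // ... feed img_embd into the text decoding path ...
}

llama_vision_tokens_free(img_tokens);
llama_vision_bitmap_free(img);
llama_vision_free(vctx); // vision state now lives outside the llama_context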

include/llama.h

Lines changed: 20 additions & 3 deletions
@@ -229,6 +229,8 @@ extern "C" {
         bool sorted;
     } llama_token_data_array;
 
+    struct llama_vision_context;
+
     // Structure represents the basic input unit of vision model
     // This can be a processed image or slices of images under the hood
     struct llama_vision_tokens;
@@ -365,6 +367,10 @@ extern "C" {
         void * abort_callback_data;
     };
 
+    struct llama_vision_context_params {
+        int32_t n_threads;
+    };
+
     // model quantization parameters
     typedef struct llama_model_quantize_params {
         int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
@@ -402,6 +408,7 @@ extern "C" {
     // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
     LLAMA_API struct llama_model_params llama_model_default_params(void);
     LLAMA_API struct llama_context_params llama_context_default_params(void);
+    LLAMA_API struct llama_vision_context_params llama_vision_context_default_params(void);
     LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void);
     LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
 
@@ -1297,20 +1304,30 @@ extern "C" {
     // Vision API
     //
 
+    // Vision context
+    LLAMA_API struct llama_vision_context * llama_vision_init_from_model(
+                  const struct llama_model * model,
+            struct llama_vision_context_params params);
+    LLAMA_API void llama_vision_free(struct llama_vision_context * ctx);
+
     // Container for RGB bitmap
     LLAMA_API struct llama_vision_bitmap * llama_vision_bitmap_init(uint32_t nx, uint32_t ny);
     LLAMA_API void llama_vision_bitmap_free(struct llama_vision_bitmap * bmp);
 
     // Create image tokens from the RGB bitmap
-    LLAMA_API struct llama_vision_tokens * llama_vision_tokenize(struct llama_context * ctx, llama_vision_bitmap * bmp);
+    LLAMA_API struct llama_vision_tokens * llama_vision_tokenize(
+            struct llama_vision_context * ctx,
+             struct llama_vision_bitmap * bmp);
    LLAMA_API void llama_vision_tokens_free(struct llama_vision_tokens * img_tokens);
 
     // User must reserve N number of tokens in tokenized text prompt for each image
     // LLAMA_API int32_t llama_vision_get_n_tokens(const llama_vision_img_tokens * img_tokens);
 
     // Encode patches into embeddings
-    LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, struct llama_vision_tokens * img_tokens);
-    LLAMA_API struct ggml_tensor * llama_vision_get_output_tensor(struct llama_context * ctx);
+    LLAMA_API int32_t llama_vision_encode(
+            struct llama_vision_context * ctx,
+             struct llama_vision_tokens * img_tokens);
+    LLAMA_API struct ggml_tensor * llama_vision_get_output_tensor(struct llama_vision_context * ctx);
 
     //
     // Model split

src/llama-arch.cpp

Lines changed: 2 additions & 2 deletions
@@ -1576,8 +1576,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_V_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_V_ENC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_V_ENC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_V_PRE_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_V_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_V_PRE_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}},
+    {LLM_TENSOR_V_POST_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
     {LLM_TENSOR_V_RESMPL_POS_EMBD_K, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_ADD}},
     {LLM_TENSOR_V_RESMPL_ATTN_Q, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_V_RESMPL_ATTN_K, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},

src/llama-context.h

Lines changed: 0 additions & 3 deletions
@@ -108,9 +108,6 @@ struct llama_context {
     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
-
-    // vision
-    llama_vision_context vctx;
 };
 
 // TODO: make these methods of llama_context

src/llama-vision.cpp

Lines changed: 98 additions & 14 deletions
@@ -982,7 +982,7 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_
     }
 
     // alloc memory for graph
-    bool ok = ggml_backend_sched_alloc_graph(ctx.sched, gf);
+    bool ok = ggml_backend_sched_alloc_graph(ctx.sched.get(), gf);
     if (!ok) {
         LLAMA_LOG_ERROR("failed to alloc memory for graph\n");
         return -1;
@@ -1064,7 +1064,7 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_
     // compute
     LLAMA_LOG_DEBUG("%s: compute start\n", __func__);
     int64_t t_start = ggml_time_ms();
-    ggml_backend_sched_graph_compute(ctx.sched, gf);
+    ggml_backend_sched_graph_compute(ctx.sched.get(), gf);
 
     // the last node is the embedding tensor
     struct ggml_tensor * output_node = ggml_graph_node(gf, -1);
@@ -1091,6 +1091,92 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_
 ////////////////////////////////////////////////////////////////////////////////////////
 // public API
 
+struct llama_vision_context_params llama_vision_context_default_params() {
+    return {
+        /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
+    };
+}
+
+struct llama_vision_context * llama_vision_init_from_model(const struct llama_model * model, struct llama_vision_context_params params) {
+    if (!model->has_vision) {
+        return nullptr;
+    }
+
+    llama_vision_context * ctx = new llama_vision_context;
+    ctx->model = &model->vit;
+
+    // TODO: this looks ugly, mostly copied from llama.cpp, refactor it in the future
+
+    // init backends
+    {
+        // add CPU backend
+        ctx->backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
+        if (ctx->backend_cpu == nullptr) {
+            LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
+            llama_vision_free(ctx);
+            return nullptr;
+        }
+        ctx->backends.emplace_back(ctx->backend_cpu);
+
+        // create a list of the set_n_threads functions in the backends
+        for (auto & backend : ctx->backends) {
+            ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
+            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+            if (reg) {
+                auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+                ggml_backend_set_n_threads_fn(backend.get(), params.n_threads);
+            }
+        }
+    }
+
+    // scheduler and compute buffers
+    {
+        // buffer types used for the compute buffer of each backend
+        std::vector<ggml_backend_buffer_type_t> backend_buft;
+        std::vector<ggml_backend_t> backend_ptrs;
+        for (auto & backend : ctx->backends) {
+            auto * buft = ggml_backend_get_default_buffer_type(backend.get());
+            auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
+            if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model->devices.empty()) {
+                // use the host buffer of the first device CPU for faster transfer of the intermediate state
+                auto * dev = model->devices[0];
+                auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
+                if (host_buft) {
+                    buft = host_buft;
+                }
+            }
+            backend_buft.push_back(buft);
+            backend_ptrs.push_back(backend.get());
+        }
+
+        const size_t max_nodes = model->max_nodes();
+
+        // buffer used to store the computation graph and the tensor meta data
+        ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+        // TODO: support pipeline_parallel
+        const bool pipeline_parallel = false;
+
+        ctx->sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
+
+        if (pipeline_parallel) {
+            LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched.get()));
+        }
+    }
+
+    const size_t max_nodes = VISION_GRAPH_MAX_NODE; // TODO: make it dynamic
+    ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+    return ctx;
+}
+
+void llama_vision_free(struct llama_vision_context * ctx) {
+    if (ctx->ctx_ggml) {
+        ggml_free(ctx->ctx_ggml);
+    }
+    delete ctx;
+}
+
 struct llama_vision_bitmap * llama_vision_bitmap_init(uint32_t nx, uint32_t ny) {
     llama_vision_bitmap * bmp = new llama_vision_bitmap;
     bmp->nx = nx;
@@ -1105,16 +1191,15 @@ void llama_vision_bitmap_free(llama_vision_bitmap * bmp) {
 }
 
 struct llama_vision_tokens * llama_vision_tokenize(
-        struct llama_context * ctx,
-        llama_vision_bitmap * bmp) {
-    llama_vision_context & vctx = ctx->vctx;
-    switch (vctx.model->hparams.arch) {
+        struct llama_vision_context * ctx,
+        struct llama_vision_bitmap * bmp) {
+    switch (ctx->model->hparams.arch) {
         case LLM_ARCH_VISION_LLAVA:
         case LLM_ARCH_VISION_MOBILEVLM:
         case LLM_ARCH_VISION_IDEFICS3:
-            return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp));
+            return new llama_vision_tokens(llama_vision_processor_llava(*ctx).tokenize(*bmp));
         case LLM_ARCH_VISION_MINICPMV:
-            return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp));
+            return new llama_vision_tokens(llama_vision_processor_llava(*ctx).tokenize(*bmp));
         default:
             GGML_ASSERT(false && "unsupported arch");
     }
@@ -1124,19 +1209,18 @@ void llama_vision_tokens_free(llama_vision_tokens * p) {
     delete p;
 }
 
-int32_t llama_vision_encode(struct llama_context * ctx, llama_vision_tokens * p) {
+int32_t llama_vision_encode(struct llama_vision_context * ctx, struct llama_vision_tokens * p) {
     if (p->buf.empty()) {
         LLAMA_LOG_ERROR("%s: nothing to encode\n", __func__);
         return -1;
     }
 
-    llama_vision_context & vctx = ctx->vctx;
-    auto & hparams = vctx.model->hparams;
+    auto & hparams = ctx->model->hparams;
     switch (hparams.mm_patch_merge_type) {
         case MM_PATCH_MERGE_FLAT:
             {
                 // flat / default llava-1.5 type embedding
-                int32_t encoded = llama_vision_encode_impl(vctx, *p);
+                int32_t encoded = llama_vision_encode_impl(*ctx, *p);
                 if (encoded != 0) {
                     LLAMA_LOG_ERROR("Unable to encode image\n");
                     return encoded;
@@ -1154,8 +1238,8 @@ int32_t llama_vision_encode(struct llama_context * ctx, llama_vision_tokens * p)
     return 0;
 }
 
-struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx) {
-    return ctx->vctx.output;
+struct ggml_tensor * llama_vision_get_output_tensor(struct llama_vision_context * ctx) {
+    return ctx->output;
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////

src/llama-vision.h

Lines changed: 5 additions & 2 deletions
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "ggml.h"
+#include "ggml-cpp.h"
 #include "llama.h"
 #include "llama-arch.h"
 
@@ -142,12 +143,14 @@ struct llama_vision_model {
 struct llama_vision_context {
     // memory buffers used to evaluate the model
     std::vector<uint8_t> buf_compute_meta;
-    ggml_backend_sched_t sched = nullptr;
-    struct ggml_context * ctx_ggml = nullptr;
+    ggml_backend_sched_ptr sched;
+    std::vector<ggml_backend_ptr> backends;
+    ggml_backend_t backend_cpu;
 
     const llama_vision_model * model;
 
     // temporary output data, to be picked up by llama_decode()
+    struct ggml_context * ctx_ggml = nullptr;
     struct ggml_tensor * output;
 };
 
src/llama.cpp

Lines changed: 3 additions & 8 deletions
@@ -8460,7 +8460,9 @@ static int llama_prepare_sbatch(
     // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    GGML_ASSERT((batch.token && !batch.embd && !batch.embd_tensor)
+             || (!batch.token && batch.embd && !batch.embd_tensor)
+             || (!batch.token && !batch.embd && batch.embd_tensor)); // NOLINT
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
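The widened assertion encodes a three-way exclusivity rule: a batch now carries exactly one of token IDs, flat embeddings, or an embedding tensor (the form the vision encoder output takes). A restatement of that invariant as a standalone predicate, for illustration only; the helper name is hypothetical and not part of the commit:

// sketch of the invariant enforced by the assert above: exactly one input kind is set
static bool llama_batch_has_one_input(const llama_batch & batch) {
    const int n_set = (batch.token       != nullptr)
                    + (batch.embd        != nullptr)
                    + (batch.embd_tensor != nullptr);
    return n_set == 1;
}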
@@ -9893,13 +9895,6 @@ struct llama_context * llama_init_from_model(
         }
     }
 
-    if (model->has_vision) {
-        ctx->vctx.model = &model->vit;
-        ctx->vctx.sched = ctx->sched.get();
-        const size_t max_nodes = VISION_GRAPH_MAX_NODE; // TODO: make it dynamic
-        ctx->vctx.buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
-    }
-
     return ctx;
 }
 