Skip to content

Commit 9ecf3e6

Browse files
authored
server : support audio input (#13714)
* server : support audio input * add audio support on webui
1 parent faaaff5 commit 9ecf3e6

File tree

12 files changed

+277
-174
lines changed

12 files changed

+277
-174
lines changed

tools/mtmd/mtmd-helper.cpp

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,7 @@ size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
1212
size_t n_tokens = 0;
1313
for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
1414
auto chunk = mtmd_input_chunks_get(chunks, i);
15-
auto chunk_type = mtmd_input_chunk_get_type(chunk);
16-
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
17-
size_t n_tokens_text;
18-
mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
19-
n_tokens += n_tokens_text;
20-
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
21-
auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
22-
n_tokens += mtmd_image_tokens_get_n_tokens(tokens_image);
23-
} else {
24-
GGML_ASSERT(false && "chunk type not supported");
25-
}
15+
n_tokens += mtmd_input_chunk_get_n_tokens(chunk);
2616
}
2717
return n_tokens;
2818
}
@@ -31,17 +21,7 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
3121
llama_pos n_pos = 0;
3222
for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
3323
auto chunk = mtmd_input_chunks_get(chunks, i);
34-
auto chunk_type = mtmd_input_chunk_get_type(chunk);
35-
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
36-
size_t n_tokens_text;
37-
mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
38-
n_pos += n_tokens_text;
39-
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
40-
auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
41-
n_pos += mtmd_image_tokens_get_n_pos(tokens_image);
42-
} else {
43-
GGML_ASSERT(false && "chunk type not supported");
44-
}
24+
n_pos += mtmd_input_chunk_get_n_pos(chunk);
4525
}
4626
return n_pos;
4727
}

tools/mtmd/mtmd.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,10 @@ const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap) {
751751
return bitmap->data.data();
752752
}
753753

754+
size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap) {
755+
return bitmap->data.size();
756+
}
757+
754758
bool mtmd_bitmap_is_audio(const mtmd_bitmap * bitmap) {
755759
return bitmap->is_audio;
756760
}

tools/mtmd/mtmd.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,12 @@ MTMD_API bool mtmd_support_audio(mtmd_context * ctx);
119119
// the data is in float format (PCM F32)
120120
MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx, uint32_t ny, const unsigned char * data);
121121
MTMD_API mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples, const float * data);
122-
MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap);
123-
MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap);
124-
MTMD_API const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap);
125-
MTMD_API bool mtmd_bitmap_is_audio(const mtmd_bitmap * bitmap);
126-
MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap);
122+
MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap);
123+
MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap);
124+
MTMD_API const unsigned char * mtmd_bitmap_get_data (const mtmd_bitmap * bitmap);
125+
MTMD_API size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap);
126+
MTMD_API bool mtmd_bitmap_is_audio (const mtmd_bitmap * bitmap);
127+
MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap);
127128
// bitmap ID is optional, but useful for KV cache tracking
128129
// these getters/setters are dedicated functions, so you can for example calculate the hash of the image based on mtmd_bitmap_get_data()
129130
MTMD_API const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap);
@@ -322,6 +323,7 @@ struct bitmap {
322323
uint32_t nx() { return mtmd_bitmap_get_nx(ptr.get()); }
323324
uint32_t ny() { return mtmd_bitmap_get_ny(ptr.get()); }
324325
const unsigned char * data() { return mtmd_bitmap_get_data(ptr.get()); }
326+
size_t n_bytes() { return mtmd_bitmap_get_n_bytes(ptr.get()); }
325327
std::string id() { return mtmd_bitmap_get_id(ptr.get()); }
326328
void set_id(const char * id) { mtmd_bitmap_set_id(ptr.get(), id); }
327329
};

tools/server/public/index.html.gz

528 Bytes
Binary file not shown.

tools/server/server.cpp

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1891,6 +1891,7 @@ struct server_context {
18911891
float slot_prompt_similarity = 0.0f;
18921892

18931893
common_chat_templates_ptr chat_templates;
1894+
oaicompat_parser_options oai_parser_opt;
18941895

18951896
~server_context() {
18961897
mtmd_free(mctx);
@@ -2086,6 +2087,15 @@ struct server_context {
20862087
}
20872088

20882089
metrics.init();
2090+
2091+
oai_parser_opt = {
2092+
/* use_jinja */ params_base.use_jinja,
2093+
/* prefill_assistant */ params_base.prefill_assistant,
2094+
/* reasoning_format */ params_base.reasoning_format,
2095+
/* common_chat_templates */ chat_templates.get(),
2096+
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
2097+
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
2098+
};
20892099
}
20902100

20912101
server_slot * get_slot_by_id(int id) {
@@ -4092,7 +4102,10 @@ int main(int argc, char ** argv) {
40924102
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
40934103
{ "total_slots", ctx_server.params_base.n_parallel },
40944104
{ "model_path", ctx_server.params_base.model.path },
4095-
{ "modalities", json{{"vision", ctx_server.mctx != nullptr}} }, // TODO: add more in the future
4105+
{ "modalities", json{
4106+
{"vision", ctx_server.oai_parser_opt.allow_image},
4107+
{"audio", ctx_server.oai_parser_opt.allow_audio},
4108+
} },
40964109
{ "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
40974110
{ "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
40984111
{ "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
@@ -4183,10 +4196,10 @@ int main(int argc, char ** argv) {
41834196
for (auto & file : files) {
41844197
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size()));
41854198
if (!bmp.ptr) {
4186-
throw std::runtime_error("Failed to load image");
4199+
throw std::runtime_error("Failed to load image or audio file");
41874200
}
41884201
// calculate bitmap hash (for KV caching)
4189-
std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3);
4202+
std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
41904203
bmp.set_id(hash.c_str());
41914204
bitmaps.entries.push_back(std::move(bmp));
41924205
}
@@ -4418,7 +4431,7 @@ int main(int argc, char ** argv) {
44184431
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
44194432
};
44204433

4421-
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
4434+
const auto handle_chat_completions = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
44224435
LOG_DBG("request: %s\n", req.body.c_str());
44234436
if (ctx_server.params_base.embedding) {
44244437
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
@@ -4427,13 +4440,9 @@ int main(int argc, char ** argv) {
44274440

44284441
auto body = json::parse(req.body);
44294442
std::vector<raw_buffer> files;
4430-
json data = oaicompat_completion_params_parse(
4443+
json data = oaicompat_chat_params_parse(
44314444
body,
4432-
params.use_jinja,
4433-
params.prefill_assistant,
4434-
params.reasoning_format,
4435-
ctx_server.chat_templates.get(),
4436-
ctx_server.mctx,
4445+
ctx_server.oai_parser_opt,
44374446
files);
44384447

44394448
handle_completions_impl(
@@ -4446,16 +4455,12 @@ int main(int argc, char ** argv) {
44464455
};
44474456

44484457
// same with handle_chat_completions, but without inference part
4449-
const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
4458+
const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
44504459
auto body = json::parse(req.body);
44514460
std::vector<raw_buffer> files; // dummy, unused
4452-
json data = oaicompat_completion_params_parse(
4461+
json data = oaicompat_chat_params_parse(
44534462
body,
4454-
params.use_jinja,
4455-
params.prefill_assistant,
4456-
params.reasoning_format,
4457-
ctx_server.chat_templates.get(),
4458-
ctx_server.mctx,
4463+
ctx_server.oai_parser_opt,
44594464
files);
44604465
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
44614466
};

tools/server/tests/unit/test_vision_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def create_server():
3030
("What is this:\n", "malformed", False, None),
3131
("What is this:\n", "https://google.com/404", False, None), # non-existent image
3232
("What is this:\n", "https://ggml.ai", False, None), # non-image data
33+
# TODO @ngxson : test with multiple images, no images and with audio
3334
]
3435
)
3536
def test_vision_chat_completion(prompt, image_url, success, re_content):

0 commit comments

Comments
 (0)