From a3eb12f53f297af040c191e354942f6728477277 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Fri, 9 May 2025 11:17:41 +0200
Subject: [PATCH] server : PoC implementation of "interim" server

---
 common/arg.cpp          |  44 +++---
 common/arg.h            |   3 +
 common/common.h         |   2 -
 tools/server/server.cpp | 320 ++++++++++++++++++++++++++++------------
 4 files changed, 248 insertions(+), 121 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 9f87e9910b540..8660b219fae47 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -834,6 +834,26 @@ static std::string get_all_kv_cache_types() {
 // CLI argument parsing functions
 //
 
+// handle model and download
+void common_params_handle_models(enum llama_example cur_ex, common_params & params) {
+    auto res = common_params_handle_model(params.model, params.hf_token, "");
+    if (params.no_mmproj) {
+        params.mmproj = {};
+    } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
+        // optionally, handle mmproj model when -hf is specified
+        params.mmproj = res.mmproj;
+    }
+    // only download mmproj if the current example is using it
+    for (auto & ex : mmproj_examples) {
+        if (cur_ex == ex) {
+            common_params_handle_model(params.mmproj, params.hf_token, "");
+            break;
+        }
+    }
+    common_params_handle_model(params.speculative.model, params.hf_token, "");
+    common_params_handle_model(params.vocoder.model, params.hf_token, "");
+}
+
 static bool common_params_parse_ex(int argc, char ** argv, common_params_context & ctx_arg) {
     std::string arg;
     const std::string arg_prefix = "--";
@@ -933,24 +953,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
     }
 
     // handle model and download
-    {
-        auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH);
-        if (params.no_mmproj) {
-            params.mmproj = {};
-        } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
-            // optionally, handle mmproj model when -hf is specified
-            params.mmproj = res.mmproj;
-        }
-        // only download mmproj if the current example is using it
-        for (auto & ex : mmproj_examples) {
-            if (ctx_arg.ex == ex) {
-                common_params_handle_model(params.mmproj, params.hf_token, "");
-                break;
-            }
-        }
-        common_params_handle_model(params.speculative.model, params.hf_token, "");
-        common_params_handle_model(params.vocoder.model, params.hf_token, "");
-    }
+    common_params_handle_models(ctx_arg.ex, params);
 
     if (params.escape) {
         string_process_escapes(params.prompt);
@@ -2486,10 +2489,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"-m", "--model"}, "FNAME",
         ex == LLAMA_EXAMPLE_EXPORT_LORA ?
             std::string("model path from which to load base model")
-            : string_format(
-                "model path (default: `models/$filename` with filename from `--hf-file` "
-                "or `--model-url` if set, otherwise %s)", DEFAULT_MODEL_PATH
-            ),
+            : "model path (default: `models/$filename` with filename from `--hf-file` or `--model-url` if set)",
         [](common_params & params, const std::string & value) {
             params.model.path = value;
         }
diff --git a/common/arg.h b/common/arg.h
index 70bea100fd4f2..11eb80f6978b7 100644
--- a/common/arg.h
+++ b/common/arg.h
@@ -80,6 +80,9 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
 common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
 bool common_has_curl();
 
+// handle model and download
+void common_params_handle_models(enum llama_example cur_ex, common_params & params);
+
 struct common_remote_params {
     std::vector<std::string> headers;
     long timeout = 0; // CURLOPT_TIMEOUT, in seconds ; 0 means no timeout
diff --git a/common/common.h b/common/common.h
index 90702245463cb..394b97a64208d 100644
--- a/common/common.h
+++ b/common/common.h
@@ -23,8 +23,6 @@
         fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
     } while(0)
 
-#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
-
 struct common_adapter_lora_info {
     std::string path;
     float scale;
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 06788bbdc8545..52b87532e52c2 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -3450,18 +3450,137 @@ inline void signal_handler(int signal) {
     shutdown_handler(signal);
 }
 
-int main(int argc, char ** argv) {
-    // own arguments required by this example
-    common_params params;
-
-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
-        return 1;
+//
+// HTTP server
+//
+
+static void res_error(httplib::Response & res, const json & error_data) {
+    json final_response {{"error", error_data}};
+    res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
+    res.status = json_value(error_data, "code", 500);
+}
+
+static void res_ok(httplib::Response & res, const json & data) {
+    res.set_content(safe_json_to_str(data), MIMETYPE_JSON);
+    res.status = 200;
+}
+
+struct server_http {
+    std::unique_ptr<httplib::Server> svr;
+    std::unique_ptr<std::thread> http_thread;
+
+    server_http(const common_params & params) {
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+        if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
+            LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
+            svr.reset(
+                new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
+            );
+        } else {
+            LOG_INF("Running without SSL\n");
+            svr.reset(new httplib::Server());
+        }
+#else
+        if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
+            LOG_ERR("Server is built without SSL support\n");
+            exit(1);
+        }
+        svr.reset(new httplib::Server());
+#endif
+
+        svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
+            std::string message;
+            try {
+                std::rethrow_exception(ep);
+            } catch (const std::exception & e) {
+                message = e.what();
+            } catch (...) {
+                message = "Unknown Exception";
+            }
+
+            try {
+                json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
+                LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
+                res_error(res, formatted_error);
+            } catch (const std::exception & e) {
+                LOG_ERR("got another exception: %s | while handling exception: %s\n", e.what(), message.c_str());
+            }
+        });
     }
 
-    common_init();
+    bool start(const common_params & params) {
+        if (http_thread) {
+            GGML_ABORT("HTTP server is already running");
+        }
+        int n_threads_http = params.n_threads_http;
+        if (n_threads_http < 1) {
+            // +2 threads for monitoring endpoints
+            n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
+        }
+        LOG_INF("%s: n_threads_http = %d\n", __func__, n_threads_http);
+        // capture by value: the thread pool is created on the listener thread,
+        // which may outlive this stack frame
+        svr->new_task_queue = [n_threads_http] { return new httplib::ThreadPool(n_threads_http); };
+
+        bool was_bound = false;
+        if (string_ends_with(std::string(params.hostname), ".sock")) {
+            LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
+            svr->set_address_family(AF_UNIX);
+            // bind_to_port requires a second arg, any value other than 0 should
+            // simply get ignored
+            was_bound = svr->bind_to_port(params.hostname, 8080);
+        } else {
+            LOG_INF("%s: binding port with default address family\n", __func__);
+            // bind HTTP listen port
+            if (params.port <= 0) {
+                LOG_ERR("%s: port number %d is invalid\n", __func__, params.port);
+                return false;
+            } else {
+                was_bound = svr->bind_to_port(params.hostname, params.port);
+            }
+        }
+
+        if (!was_bound) {
+            LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
+            return false;
+        }
+
+        // run the HTTP server in a thread
+        http_thread.reset(new std::thread([&]() { svr->listen_after_bind(); }));
+        svr->wait_until_ready();
+
+        LOG_INF("%s: HTTP server started, hostname: %s, port: %d, n_threads_http = %d\n", __func__, params.hostname.c_str(), params.port, n_threads_http);
+
+        return true;
+    }
+
+    void stop() {
+        LOG_INF("%s: stopping HTTP server\n", __func__);
+        if (svr) {
+            svr->stop();
+            if (http_thread) {
+                http_thread->join();
+                http_thread.reset();
+            }
+            svr.reset();
+        }
+    }
+
+    httplib::Server * operator ->() {
+        return svr.get();
+    }
+
+    operator bool() {
+        return svr != nullptr;
+    }
+};
+
+// main server
+static int run_main_server(common_params & params) {
     // struct that contains llama context and inference
     server_context ctx_server;
 
+    server_http svr(params);
+    std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
+
     llama_backend_init();
     llama_numa_init(params.numa);
@@ -3471,42 +3590,7 @@ int main(int argc, char ** argv) {
     LOG_INF("%s\n", common_params_get_system_info(params).c_str());
     LOG_INF("\n");
 
-    std::unique_ptr<httplib::Server> svr;
-#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
-    if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
-        LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
-        svr.reset(
-            new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
-        );
-    } else {
-        LOG_INF("Running without SSL\n");
-        svr.reset(new httplib::Server());
-    }
-#else
-    if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
-        LOG_ERR("Server is built without SSL support\n");
-        return 1;
-    }
-    svr.reset(new httplib::Server());
-#endif
-
-    std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
-
-    svr->set_default_headers({{"Server", "llama.cpp"}});
-    svr->set_logger(log_server_request);
-
-    auto res_error = [](httplib::Response & res, const json & error_data) {
-        json final_response {{"error", error_data}};
-        res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
-        res.status = json_value(error_data, "code", 500);
-    };
-
-    auto res_ok = [](httplib::Response & res, const json & data) {
-        res.set_content(safe_json_to_str(data), MIMETYPE_JSON);
-        res.status = 200;
-    };
-
-    svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
+    svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
         std::string message;
         try {
             std::rethrow_exception(ep);
         } catch (const std::exception & e) {
@@ -3525,7 +3609,7 @@ int main(int argc, char ** argv) {
         }
     });
 
-    svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
+    svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
         if (res.status == 404) {
             res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
         }
@@ -3555,7 +3639,7 @@ int main(int argc, char ** argv) {
     // Middlewares
    //
 
-    auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
+    auto middleware_validate_api_key = [&params](const httplib::Request & req, httplib::Response & res) {
        static const std::unordered_set<std::string> public_endpoints = {
             "/health",
             "/models",
@@ -3591,7 +3675,7 @@ int main(int argc, char ** argv) {
         return false;
     };
 
-    auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
+    auto middleware_server_state = [&state](const httplib::Request & req, httplib::Response & res) {
         server_state current_state = state.load();
         if (current_state == SERVER_STATE_LOADING_MODEL) {
             auto tmp = string_split<std::string>(req.path, '.');
@@ -3780,7 +3864,7 @@ int main(int argc, char ** argv) {
         res.status = 200; // HTTP OK
     };
 
-    const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
+    const auto handle_slots_save = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
         json request_data = json::parse(req.body);
         std::string filename = request_data.at("filename");
         if (!fs_validate_filename(filename)) {
@@ -3812,7 +3896,7 @@ int main(int argc, char ** argv) {
         res_ok(res, result->to_json());
     };
 
-    const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
+    const auto handle_slots_restore = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
         json request_data = json::parse(req.body);
         std::string filename = request_data.at("filename");
         if (!fs_validate_filename(filename)) {
@@ -3845,7 +3929,7 @@ int main(int argc, char ** argv) {
         res_ok(res, result->to_json());
     };
 
-    const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
+    const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
         int task_id = ctx_server.queue_tasks.get_new_id();
         {
             server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
@@ -3868,7 +3952,7 @@ int main(int argc, char ** argv) {
         res_ok(res, result->to_json());
     };
 
-    const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_slots_action = [&params, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
         if (params.slot_save_path.empty()) {
             res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
@@ -3897,7 +3981,7 @@ int main(int argc, char ** argv) {
         }
     };
 
-    const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+    const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
         // this endpoint is publicly available, please only return what is safe to be exposed
         json data = {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
@@ -3917,7 +4001,7 @@ int main(int argc, char ** argv) {
         res_ok(res, data);
     };
 
-    const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_props_change = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         if (!ctx_server.params_base.endpoint_props) {
             res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
             return;
@@ -3930,7 +4014,7 @@ int main(int argc, char ** argv) {
         res_ok(res, {{ "success", true }});
     };
 
-    const auto handle_api_show = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+    const auto handle_api_show = [&ctx_server](const httplib::Request &, httplib::Response & res) {
         json data = {
             {
                 "template", common_chat_templates_source(ctx_server.chat_templates.get()),
@@ -3947,7 +4031,7 @@ int main(int argc, char ** argv) {
 
     // handle completion-like requests (completion, chat, infill)
     // we can optionally provide a custom format for partial results and final results
-    const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
+    const auto handle_completions_impl = [&ctx_server](
             server_task_type type,
             json & data,
             const std::function<bool()> & is_connection_closed,
@@ -4077,7 +4161,7 @@ int main(int argc, char ** argv) {
             OAICOMPAT_TYPE_COMPLETION);
     };
 
-    const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_infill = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
         // check model compatibility
         std::string err;
         if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
@@ -4154,7 +4238,7 @@ int main(int argc, char ** argv) {
             OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
     };
 
-    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &params, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
         LOG_DBG("request: %s\n", req.body.c_str());
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
@@ -4173,13 +4257,13 @@ int main(int argc, char ** argv) {
     };
 
     // same with handle_chat_completions, but without inference part
-    const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_apply_template = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res) {
         auto body = json::parse(req.body);
         json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
         res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
     };
 
-    const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+    const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
         json models = {
             {"object", "list"},
             {"data", {
@@ -4196,7 +4280,7 @@ int main(int argc, char ** argv) {
         res_ok(res, models);
     };
 
-    const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         const json body = json::parse(req.body);
 
         json tokens_response = json::array();
@@ -4236,7 +4320,7 @@ int main(int argc, char ** argv) {
         res_ok(res, data);
     };
 
-    const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         const json body = json::parse(req.body);
 
         std::string content;
@@ -4249,7 +4333,7 @@ int main(int argc, char ** argv) {
         res_ok(res, data);
     };
 
-    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
+    const auto handle_embeddings_impl = [&ctx_server](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
         const json body = json::parse(req.body);
 
         if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
@@ -4345,7 +4429,7 @@ int main(int argc, char ** argv) {
         handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
     };
 
-    const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
@@ -4543,12 +4627,6 @@ int main(int argc, char ** argv) {
     //
     // Start the server
     //
-    if (params.n_threads_http < 1) {
-        // +2 threads for monitoring endpoints
-        params.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
-    }
-    log_data["n_threads_http"] = std::to_string(params.n_threads_http);
-    svr->new_task_queue = [&params] { return new httplib::ThreadPool(params.n_threads_http); };
 
     // clean up function, to be called before exit
     auto clean_up = [&svr, &ctx_server]() {
@@ -4558,44 +4636,18 @@ int main(int argc, char ** argv) {
         llama_backend_free();
     };
 
-    bool was_bound = false;
-    if (string_ends_with(std::string(params.hostname), ".sock")) {
-        LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
-        svr->set_address_family(AF_UNIX);
-        // bind_to_port requires a second arg, any value other than 0 should
-        // simply get ignored
-        was_bound = svr->bind_to_port(params.hostname, 8080);
-    } else {
-        LOG_INF("%s: binding port with default address family\n", __func__);
-        // bind HTTP listen port
-        if (params.port == 0) {
-            int bound_port = svr->bind_to_any_port(params.hostname);
-            if ((was_bound = (bound_port >= 0))) {
-                params.port = bound_port;
-            }
-        } else {
-            was_bound = svr->bind_to_port(params.hostname, params.port);
-        }
-    }
-
-    if (!was_bound) {
-        LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
+    if (!svr.start(params)) {
         clean_up();
+        svr.stop();
         return 1;
     }
 
-    // run the HTTP server in a thread
-    std::thread t([&]() { svr->listen_after_bind(); });
-    svr->wait_until_ready();
-
-    LOG_INF("%s: HTTP server is listening, hostname: %s, port: %d, http threads: %d\n", __func__, params.hostname.c_str(), params.port, params.n_threads_http);
-
     // load the model
     LOG_INF("%s: loading model\n", __func__);
 
     if (!ctx_server.load_model(params)) {
         clean_up();
-        t.join();
+        svr.stop();
         LOG_ERR("%s: exiting due to model loading error\n", __func__);
         return 1;
     }
@@ -4643,7 +4695,81 @@ int main(int argc, char ** argv) {
     ctx_server.queue_tasks.start_loop();
 
     clean_up();
-    t.join();
+    svr.stop();
 
     return 0;
 }
+
+// interim server allowing to list and select models
+static int run_interim_server(const common_params & params) {
+    server_http svr(params);
+
+    std::mutex mutex;
+    std::condition_variable cv;
+    bool is_params_set = false;
+    common_params load_params;
+
+    svr->Post("/load", [&](const httplib::Request & req, httplib::Response & res) {
+        std::unique_lock<std::mutex> lock(mutex);
+        json data = json::parse(req.body);
+        std::string hf_repo = data["hf_repo"];
+
+        if (is_params_set) {
+            res_error(res, format_error_response("Model is being downloaded", ERROR_TYPE_UNAVAILABLE));
+            return;
+        }
+
+        if (hf_repo.empty()) {
+            res_error(res, format_error_response("Missing hf_repo parameter", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
+        // Load the model
+        load_params = params; // copy
+        load_params.model.hf_repo = hf_repo;
+        is_params_set = true;
+        cv.notify_all();
+    });
+
+    if (!svr.start(params)) {
+        return 1;
+    }
+    LOG_INF("Interim server started\n");
+
+    // Wait for the model to be set
+    {
+        std::unique_lock<std::mutex> lock(mutex);
+        cv.wait(lock, [&is_params_set] { return is_params_set; });
+        // make sure model is downloaded
+        common_params_handle_models(LLAMA_EXAMPLE_SERVER, load_params);
+    }
+
+    // Launch the main server with the new parameters
+    svr.stop();
+    return run_main_server(load_params);
+}
+
+int main(int argc, char ** argv) {
+    // own arguments required by this example
+    common_params params;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
+        return 1;
+    }
+
+    common_init();
+
+    if (params.model.path.empty()) {
+        // no model path provided, start the interim server
+        LOG_INF("No model path provided, starting interim server\n");
+        while (true) {
+            int res = run_interim_server(params);
+            if (res != 0) {
+                return res;
+            }
+        }
+    } else {
+        // model path provided, start the main server
+        return run_main_server(params);
+    }
+}
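
Usage sketch (not part of the patch; assumes the default port 8080, and the repo
name below is a placeholder):

    # start without -m / -hf: the interim server comes up instead of the main one
    ./llama-server

    # tell the interim server which model to fetch; it downloads the GGUF from the
    # Hugging Face hub, then relaunches as the regular server on the same host/port
    curl -X POST http://localhost:8080/load \
         -H "Content-Type: application/json" \
         -d '{"hf_repo": "<user>/<model-gguf-repo>"}'

When run_main_server later returns 0, the while (true) loop in main brings the
interim server back up, so a different model can be selected without restarting
the process.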