-
Notifications
You must be signed in to change notification settings - Fork 12k
server: feature Add Admin key parameter for slots/health/metrics #5837
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,7 @@ using json = nlohmann::json; | |
struct server_params { | ||
std::string hostname = "127.0.0.1"; | ||
std::vector<std::string> api_keys; | ||
std::vector<std::string> admin_keys; | ||
std::string public_path = "examples/server/public"; | ||
std::string chat_template = ""; | ||
int32_t port = 8080; | ||
|
@@ -2060,6 +2061,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, | |
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str()); | ||
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port); | ||
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); | ||
printf(" --admin-key ADMIN_KEY optional admin key to enhance server security. If set, requests to admin endpoints must include this key.\n"); | ||
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); | ||
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); | ||
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); | ||
|
@@ -2128,6 +2130,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, | |
} | ||
sparams.public_path = argv[i]; | ||
} | ||
else if (arg == "--admin-key") | ||
{ | ||
if (++i >= argc) | ||
{ | ||
invalid_param = true; | ||
break; | ||
} | ||
sparams.admin_keys.emplace_back(argv[i]); | ||
} | ||
else if (arg == "--api-key") | ||
{ | ||
if (++i >= argc) | ||
|
@@ -2772,6 +2783,38 @@ int main(int argc, char **argv) | |
res.set_header("Access-Control-Allow-Headers", "*"); | ||
}); | ||
|
||
// Middleware for API key validation | ||
// validate_key: checks a request against the caller-supplied list of accepted keys. | ||
// Returns true when `keys` is empty, when the "Authorization" header is "Bearer <key>" with a listed key, | ||
// or when the "key" query parameter equals a listed key. Otherwise it fills `res` with a 401 plain-text | ||
// body, logs a warning and returns false. | ||
// NOTE(review): `sparams` is captured but never referenced inside this body, and `keys` is only read | ||
// (empty/begin/end), so it could likely be taken by const reference - confirm against callers. | ||
auto validate_key = [&sparams](const httplib::Request &req, httplib::Response &res, std::vector<std::string> &keys) -> bool { | ||
// No keys configured for this endpoint group: validation is skipped and every request is accepted | ||
if (keys.empty()) { | ||
return true; | ||
} | ||
|
||
// Check for the key in the Authorization header; only values starting with "Bearer " are considered | ||
auto auth_header = req.get_header_value("Authorization"); | ||
std::string prefix = "Bearer "; | ||
if (auth_header.substr(0, prefix.size()) == prefix) { | ||
// Everything after the "Bearer " prefix is compared verbatim against the configured keys | ||
std::string received_api_key = auth_header.substr(prefix.size()); | ||
if (std::find(keys.begin(), keys.end(), received_api_key) != keys.end()) { | ||
return true; // API key is valid | ||
} | ||
} | ||
|
||
// Fallback: check for the key in the "key" query parameter | ||
// NOTE(review): presumably get_param_value yields an empty string when "key" is absent, in which case | ||
// an empty entry in `keys` would match a request without any key - TODO confirm | ||
auto auth_param = req.get_param_value("key"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Query params are in clear on all http traffic scanners. I feel this hack is a security breach. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right, but authorization headers are also passed in cleartext without https. Given the nature of the api keys as they are at the moment this is no more or less secure than the current solution. Long term, specifying keys on the command line and passing them directly is not the correct solution. This is just a quick way to get some level of security on those endpoints in the server as it exists now until a proper (jwt would make sense to me) solution is in place. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am speaking about logging. Authorization header is never logged. Query params are, example: But all http monitoring tool will do the same. Note: JWT is not a security protocol, just a transparent token often used in oauth2. |
||
if (std::find(keys.begin(), keys.end(), auth_param) != keys.end()) { | ||
return true; // API key is valid | ||
} | ||
|
||
// Neither the header nor the query parameter carried a listed key: reject the request | ||
res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8"); | ||
res.status = 401; // Unauthorized | ||
|
||
LOG_WARNING("Unauthorized: Invalid API Key", {}); | ||
|
||
return false; | ||
}; | ||
|
||
svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) { | ||
server_state current_state = state.load(); | ||
switch(current_state) { | ||
|
@@ -2797,7 +2840,7 @@ int main(int argc, char **argv) | |
{"slots_idle", n_idle_slots}, | ||
{"slots_processing", n_processing_slots}}; | ||
res.status = 200; // HTTP OK | ||
if (sparams.slots_endpoint && req.has_param("include_slots")) { | ||
if (sparams.slots_endpoint && req.has_param("include_slots") && validate_key(req, res, sparams.admin_keys)) { | ||
health["slots"] = result.result_json["slots"]; | ||
} | ||
|
||
|
@@ -2822,7 +2865,10 @@ int main(int argc, char **argv) | |
}); | ||
|
||
if (sparams.slots_endpoint) { | ||
svr.Get("/slots", [&](const httplib::Request&, httplib::Response& res) { | ||
svr.Get("/slots", [&](const httplib::Request& req, httplib::Response& res) { | ||
if (!validate_key(req, res, sparams.admin_keys)) { | ||
return; | ||
} | ||
// request slots data using task queue | ||
task_server task; | ||
task.id = llama.queue_tasks.get_new_id(); | ||
|
@@ -2842,7 +2888,10 @@ int main(int argc, char **argv) | |
} | ||
|
||
if (sparams.metrics_endpoint) { | ||
svr.Get("/metrics", [&](const httplib::Request&, httplib::Response& res) { | ||
svr.Get("/metrics", [&](const httplib::Request& req, httplib::Response& res) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. /metrics must not be protected. It does not contain data and it targets prometheus which does not support authentication. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the purpose of the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you link to a security protocol this query param implements ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure what you mean. This is just passing in the same api/admin key in through There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Kindly, I meant where have you seen we can pass secret as query param in a protocol ? I am wondering if it is a security issue. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any form of authentication via a URL Query Param is considered bad practice and is inherently unsafe, even when using TLS/SSL. There's always a "gotcha" somewhere that exposes the authentication method. This is why it's usually done with a bearer token or jwt, packaged as a header or body payload, and then passed via an encrypted tunnel. You can find this all on OWASP and within the RFC Specs. This has to do with underlying mechanics of how GET and POST requests work. For a real authy method, you would use POST and not expose the Auth tokens via Query Params. Edit: It took me a bit to find it. 
They updated it since I last read it. TL;DR: Query parameters enable injection attacks. There's way more in-depth stuff that exploits query parameters. It's beginner security stuff. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The question was the opposite. We all know that this is not security here. That's why this PR is open... There is no need for beginners' explanations. |
||
if (!validate_key(req, res, sparams.admin_keys)) { | ||
return; | ||
} | ||
// request slots data using task queue | ||
task_server task; | ||
task.id = llama.queue_tasks.get_new_id(); | ||
|
@@ -3000,32 +3049,6 @@ int main(int argc, char **argv) | |
llama.validate_model_chat_template(sparams); | ||
} | ||
|
||
// Middleware for API key validation | ||
// validate_api_key: checks a request against sparams.api_keys. | ||
// Returns true when no API keys are configured or when the "Authorization" header is "Bearer <key>" | ||
// with a key found in sparams.api_keys. Otherwise it fills `res` with a 401 plain-text body, logs a | ||
// warning and returns false. | ||
// NOTE(review): the enclosing hunk header (32 lines -> 6) suggests this lambda is removed by the change | ||
// in favour of the keys-parameterised validate_key - confirm against the full diff. | ||
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { | ||
// No API keys configured: validation is skipped and every request is accepted | ||
if (sparams.api_keys.empty()) { | ||
return true; | ||
} | ||
|
||
// Check for the key in the Authorization header; only values starting with "Bearer " are considered | ||
auto auth_header = req.get_header_value("Authorization"); | ||
std::string prefix = "Bearer "; | ||
if (auth_header.substr(0, prefix.size()) == prefix) { | ||
// Everything after the "Bearer " prefix is compared verbatim against the configured keys | ||
std::string received_api_key = auth_header.substr(prefix.size()); | ||
if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) { | ||
return true; // API key is valid | ||
} | ||
} | ||
|
||
// Header missing, not a Bearer value, or key not listed: reject the request | ||
res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8"); | ||
res.status = 401; // Unauthorized | ||
|
||
LOG_WARNING("Unauthorized: Invalid API Key", {}); | ||
|
||
return false; | ||
}; | ||
|
||
// this is only called if no index.html is found in the public --path | ||
svr.Get("/", [](const httplib::Request &, httplib::Response &res) | ||
{ | ||
|
@@ -3066,10 +3089,10 @@ int main(int argc, char **argv) | |
res.set_content(data.dump(), "application/json; charset=utf-8"); | ||
}); | ||
|
||
svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) | ||
svr.Post("/completion", [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res) | ||
{ | ||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); | ||
if (!validate_api_key(req, res)) { | ||
if (!validate_key(req, res, sparams.api_keys)) { | ||
return; | ||
} | ||
json data = json::parse(req.body); | ||
|
@@ -3163,10 +3186,10 @@ int main(int argc, char **argv) | |
res.set_content(models.dump(), "application/json; charset=utf-8"); | ||
}); | ||
|
||
const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res) | ||
const auto chat_completions = [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res) | ||
{ | ||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); | ||
if (!validate_api_key(req, res)) { | ||
if (!validate_key(req, res, sparams.api_keys)) { | ||
return; | ||
} | ||
json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template); | ||
|
@@ -3246,10 +3269,10 @@ int main(int argc, char **argv) | |
svr.Post("/chat/completions", chat_completions); | ||
svr.Post("/v1/chat/completions", chat_completions); | ||
|
||
svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) | ||
svr.Post("/infill", [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res) | ||
{ | ||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); | ||
if (!validate_api_key(req, res)) { | ||
if (!validate_key(req, res, sparams.api_keys)) { | ||
return; | ||
} | ||
json data = json::parse(req.body); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OWASP recommends "denying by default". It's a pain, but the flag should be the inverse: if you want to be able to toggle security measures, they should be on by default, and the flag should disable them for specific purposes, e.g. testing or local usage. I'm not sure how this should be handled in llama.cpp or for local usage, but it matters in production.