server: feature Add Admin key parameter for slots/health/metrics #5837

Open
wants to merge 3 commits into
base: master
3 changes: 2 additions & 1 deletion examples/server/README.md
@@ -43,8 +43,9 @@ see https://github.com/ggerganov/llama.cpp/issues/1437
- `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
- `--port`: Set the port to listen. Default: `8080`.
- `--path`: path from which to serve static files (default examples/server/public)
- `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys.
- `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests to `/completion`, `/infill` and `/chat/completions` must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys.
- `--api-key-file`: path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`'s.
- `--admin-key`: Set an admin key for request authorization. With an admin key set, requests to `/metrics` and `/slots` must have the Authorization header set with the api key as Bearer token. Additionally, `/health` will not show slots without the key. May be used multiple times to enable multiple valid keys.
- `--embedding`: Enable embedding extraction, Default: disabled.
- `-np N`, `--parallel N`: Set the number of slots for process requests (default: 1)
- `-cb`, `--cont-batching`: enable continuous batching (a.k.a dynamic batching) (default: disabled)
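
The README entries above say that keys may come from repeated flags and from a newline-delimited file. A minimal sketch of how such a file could be split and merged with flag-supplied keys follows. The helper names `parse_key_file` and `merge_keys` are hypothetical and do not appear in `server.cpp`; skipping blank lines and tolerating CRLF endings are assumptions, not documented behavior.

```cpp
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper, not taken from server.cpp: split the contents of a
// newline-delimited key file into keys, skipping blank lines.
std::vector<std::string> parse_key_file(const std::string &contents) {
    std::vector<std::string> keys;
    std::istringstream in(contents);
    std::string line;
    while (std::getline(in, line)) {
        if (!line.empty() && line.back() == '\r') {
            line.pop_back(); // assumption: tolerate CRLF line endings
        }
        if (!line.empty()) {
            keys.push_back(line);
        }
    }
    return keys;
}

// Hypothetical helper: combine keys given via repeated flags with keys read
// from a file, so that any of them is accepted.
std::vector<std::string> merge_keys(std::vector<std::string> from_flags,
                                    const std::vector<std::string> &from_file) {
    from_flags.insert(from_flags.end(), from_file.begin(), from_file.end());
    return from_flags;
}
```

Keeping the parsing separate from file I/O makes the behavior easy to check with an in-memory string.
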
93 changes: 58 additions & 35 deletions examples/server/server.cpp
@@ -36,6 +36,7 @@ using json = nlohmann::json;
struct server_params {
std::string hostname = "127.0.0.1";
std::vector<std::string> api_keys;
std::vector<std::string> admin_keys;
std::string public_path = "examples/server/public";
std::string chat_template = "";
int32_t port = 8080;
@@ -2060,6 +2061,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
printf(" --admin-key ADMIN_KEY optional admin key to enhance server security. If set, requests to admin endpoints must include this key.\n");
Contributor

OWASP recommends "denying by default". It's a PITA, but the flag should be the inverse. If you want to toggle security measures, they should be on by default and the flag should be set to disable them for whatever purposes, e.g. testing, local usage, etc. Not sure how it should be handled in llama.cpp or local usage, but this matters in production.
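
One way to read the "deny by default" suggestion is sketched below. The `admin_auth_policy` struct and its `allow_unauthenticated` opt-out are hypothetical names invented for illustration; the PR itself takes the opposite approach and skips validation when no key is configured.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical policy type; neither the struct nor the opt-out flag exists
// in llama.cpp.
struct admin_auth_policy {
    std::vector<std::string> keys;      // accepted admin keys
    bool allow_unauthenticated = false; // must be enabled explicitly, e.g. for local testing
};

// Deny by default: with no keys configured, access is refused unless the
// operator has explicitly opted out of authentication.
bool admin_access_allowed(const admin_auth_policy &policy, const std::string &presented_key) {
    if (policy.allow_unauthenticated) {
        return true;
    }
    if (policy.keys.empty()) {
        return false;
    }
    return std::find(policy.keys.begin(), policy.keys.end(), presented_key) != policy.keys.end();
}
```

With this shape, forgetting to configure a key locks the admin endpoints instead of opening them.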

printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
@@ -2128,6 +2130,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
sparams.public_path = argv[i];
}
else if (arg == "--admin-key")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
sparams.admin_keys.emplace_back(argv[i]);
}
else if (arg == "--api-key")
{
if (++i >= argc)
@@ -2772,6 +2783,38 @@ int main(int argc, char **argv)
res.set_header("Access-Control-Allow-Headers", "*");
});

// Middleware for API key validation
auto validate_key = [&sparams](const httplib::Request &req, httplib::Response &res, std::vector<std::string> &keys) -> bool {
// If API key is not set, skip validation
if (keys.empty()) {
return true;
}

// Check for API key in the header
auto auth_header = req.get_header_value("Authorization");
std::string prefix = "Bearer ";
if (auth_header.substr(0, prefix.size()) == prefix) {
std::string received_api_key = auth_header.substr(prefix.size());
if (std::find(keys.begin(), keys.end(), received_api_key) != keys.end()) {
return true; // API key is valid
}
}

// Check for API key in the params
auto auth_param = req.get_param_value("key");
Collaborator

Query params are visible in the clear to every HTTP traffic scanner. I feel this hack is a security breach.
Monitoring endpoints simply must not be exposed to the world, or at least the slots data can simply be disabled. They are here for debugging purposes during development only.
@ggerganov I would prefer we invest in a real security protocol like OAuth2/OpenID rather than adding this.

Contributor Author

Right, but authorization headers are also passed in cleartext without https. Given the nature of the api keys as they are at the moment this is no more or less secure than the current solution.

Long term, specifying keys on the command line and passing them directly is not the correct solution.

This is just a quick way to get some level of security on those endpoints in the server as it exists now until a proper (jwt would make sense to me) solution is in place.

Collaborator

I am talking about logging. The Authorization header is never logged; query params are. For example:

https://github.com/ggerganov/llama.cpp/blob/9731134296af3a6839cd682e51d9c2109a871de5/examples/server/server.cpp#L2691-L2705

But every HTTP monitoring tool will do the same.

Note: JWT is not a security protocol, just a transparent token often used in OAuth2.
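
To make the logging concern concrete, here is a minimal sketch of scrubbing a `key` query parameter from a request target before it is written to a log. The helper `redact_key_param` is hypothetical and not part of `server.cpp`. It assumes the parameter name is the literal, unencoded `key` and does not handle percent-encoded names. It also only covers the server's own log; proxies and monitoring tools in front of the server would still see the raw URL.

```cpp
#include <string>

// Hypothetical helper: replace the value of the `key` query parameter in a
// request target such as "/metrics?key=abc&x=1" with "REDACTED".
std::string redact_key_param(const std::string &target) {
    const std::size_t qpos = target.find('?');
    if (qpos == std::string::npos) {
        return target; // no query string, nothing to redact
    }
    std::string out = target.substr(0, qpos + 1);
    std::size_t pos = qpos + 1;
    while (pos <= target.size()) {
        std::size_t end = target.find('&', pos);
        if (end == std::string::npos) {
            end = target.size();
        }
        const std::string pair = target.substr(pos, end - pos);
        if (pair.compare(0, 4, "key=") == 0) {
            out += "key=REDACTED";
        } else {
            out += pair;
        }
        if (end < target.size()) {
            out += '&'; // keep the separator between parameters
        }
        pos = end + 1;
    }
    return out;
}
```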

if (std::find(keys.begin(), keys.end(), auth_param) != keys.end()) {
return true; // API key is valid
}

// API key is invalid or not provided
res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
res.status = 401; // Unauthorized

LOG_WARNING("Unauthorized: Invalid API Key", {});

return false;
};
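
For readers who want to exercise the decision logic of `validate_key` without an HTTP server, a standalone sketch follows. It takes the header and query-parameter values as plain strings; the function name `key_is_accepted` and its parameters are illustrative, and the response handling (401 status, warning log) from the lambda above is deliberately left out.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Standalone sketch of the validate_key decision above; names are illustrative.
bool key_is_accepted(const std::vector<std::string> &keys,
                     const std::string &auth_header,
                     const std::string &key_param) {
    if (keys.empty()) {
        return true; // no keys configured: validation is skipped
    }

    // Accept "Authorization: Bearer <key>" when <key> is in the list.
    const std::string prefix = "Bearer ";
    if (auth_header.compare(0, prefix.size(), prefix) == 0) {
        const std::string token = auth_header.substr(prefix.size());
        if (std::find(keys.begin(), keys.end(), token) != keys.end()) {
            return true;
        }
    }

    // Fallback debated in the review: the same key supplied as ?key=...
    return std::find(keys.begin(), keys.end(), key_param) != keys.end();
}
```

Passing `sparams.api_keys` or `sparams.admin_keys` as the first argument corresponds to the two call patterns in the diff.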

svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) {
server_state current_state = state.load();
switch(current_state) {
@@ -2797,7 +2840,7 @@ int main(int argc, char **argv)
{"slots_idle", n_idle_slots},
{"slots_processing", n_processing_slots}};
res.status = 200; // HTTP OK
if (sparams.slots_endpoint && req.has_param("include_slots")) {
if (sparams.slots_endpoint && req.has_param("include_slots") && validate_key(req, res, sparams.admin_keys)) {
health["slots"] = result.result_json["slots"];
}

@@ -2822,7 +2865,10 @@ int main(int argc, char **argv)
});

if (sparams.slots_endpoint) {
svr.Get("/slots", [&](const httplib::Request&, httplib::Response& res) {
svr.Get("/slots", [&](const httplib::Request& req, httplib::Response& res) {
if (!validate_key(req, res, sparams.admin_keys)) {
return;
}
// request slots data using task queue
task_server task;
task.id = llama.queue_tasks.get_new_id();
@@ -2842,7 +2888,10 @@ int main(int argc, char **argv)
}

if (sparams.metrics_endpoint) {
svr.Get("/metrics", [&](const httplib::Request&, httplib::Response& res) {
svr.Get("/metrics", [&](const httplib::Request& req, httplib::Response& res) {
Collaborator

/metrics must not be protected. It does not contain data, and it targets Prometheus, which does not support authentication.

Contributor Author

This is the purpose of the key query param.

Collaborator

Could you link to a security protocol that this query param implements?

Contributor Author

Not sure what you mean. This just passes the same api/admin key through ?key= instead of the Authorization header, for ease of configuration. It's the same as the Authorization header otherwise.

Collaborator

Kindly, I meant: where have you seen a protocol in which a secret can be passed as a query param? I am wondering if it is a security issue.
For example, in the OAuth2 implicit grant flow, this has been deprecated.

Contributor

@teleprint-me Mar 11, 2024

Any form of authentication via a URL Query Param is considered bad practice and is inherently unsafe, even when using TLS/SSL.

There's always a "gotcha" somewhere that exposes the authentication method. This is why it's usually done with a bearer token or jwt, packaged as a header or body payload, and then passed via an encrypted tunnel.

You can find this all on OWASP and within the RFC Specs. This has to do with underlying mechanics of how GET and POST requests work.

For a real auth method, you would use POST and not expose the auth tokens via query params.

Edit: It took me a bit to find it. They updated it since I last read it.

TLDR; Query parameters enable injection attacks.

There's way more in-depth stuff that exploits query parameters. It's beginner security stuff.

Collaborator

The question was the opposite. We all know that this is not security here. That's why this PR is open... There is no need for beginners' explanations.

if (!validate_key(req, res, sparams.admin_keys)) {
return;
}
// request slots data using task queue
task_server task;
task.id = llama.queue_tasks.get_new_id();
@@ -3000,32 +3049,6 @@ int main(int argc, char **argv)
llama.validate_model_chat_template(sparams);
}

// Middleware for API key validation
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
// If API key is not set, skip validation
if (sparams.api_keys.empty()) {
return true;
}

// Check for API key in the header
auto auth_header = req.get_header_value("Authorization");
std::string prefix = "Bearer ";
if (auth_header.substr(0, prefix.size()) == prefix) {
std::string received_api_key = auth_header.substr(prefix.size());
if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
return true; // API key is valid
}
}

// API key is invalid or not provided
res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
res.status = 401; // Unauthorized

LOG_WARNING("Unauthorized: Invalid API Key", {});

return false;
};

// this is only called if no index.html is found in the public --path
svr.Get("/", [](const httplib::Request &, httplib::Response &res)
{
@@ -3066,10 +3089,10 @@ int main(int argc, char **argv)
res.set_content(data.dump(), "application/json; charset=utf-8");
});

svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
svr.Post("/completion", [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
if (!validate_key(req, res, sparams.api_keys)) {
return;
}
json data = json::parse(req.body);
@@ -3163,10 +3186,10 @@ int main(int argc, char **argv)
res.set_content(models.dump(), "application/json; charset=utf-8");
});

const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
const auto chat_completions = [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
if (!validate_key(req, res, sparams.api_keys)) {
return;
}
json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
@@ -3246,10 +3269,10 @@ int main(int argc, char **argv)
svr.Post("/chat/completions", chat_completions);
svr.Post("/v1/chat/completions", chat_completions);

svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
svr.Post("/infill", [&llama, &validate_key, &sparams](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
if (!validate_key(req, res, sparams.api_keys)) {
return;
}
json data = json::parse(req.body);