Skip to content

Commit 88ae895

Browse files
server : add optional API Key Authentication example (#4441)
* Add API key authentication for enhanced server-client security
* server : to snake_case
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent ee4725a commit 88ae895

File tree

3 files changed

+70
-10
lines changed

3 files changed

+70
-10
lines changed

examples/server/public/completion.js

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ export async function* llama(prompt, params = {}, config = {}) {
3434
headers: {
3535
'Connection': 'keep-alive',
3636
'Content-Type': 'application/json',
37-
'Accept': 'text/event-stream'
37+
'Accept': 'text/event-stream',
38+
...(params.api_key ? {'Authorization': `Bearer ${params.api_key}`} : {})
3839
},
3940
signal: controller.signal,
4041
});

examples/server/public/index.html

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,8 @@
235235
grammar: '',
236236
n_probs: 0, // no completion_probabilities,
237237
image_data: [],
238-
cache_prompt: true
238+
cache_prompt: true,
239+
api_key: ''
239240
})
240241

241242
/* START: Support for storing prompt templates and parameters in browsers LocalStorage */
@@ -790,6 +791,10 @@
790791
<fieldset>
791792
${IntField({ label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs })}
792793
</fieldset>
794+
<fieldset>
795+
<label for="api_key">API Key</label>
796+
<input type="text" name="api_key" value="${params.value.api_key}" placeholder="Enter API key" oninput=${updateParams} />
797+
</fieldset>
793798
</details>
794799
</form>
795800
`

examples/server/server.cpp

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ using json = nlohmann::json;
3636
struct server_params
3737
{
3838
std::string hostname = "127.0.0.1";
39+
std::string api_key;
3940
std::string public_path = "examples/server/public";
4041
int32_t port = 8080;
4142
int32_t read_timeout = 600;
@@ -1953,6 +1954,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
19531954
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
19541955
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
19551956
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
1957+
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
19561958
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
19571959
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
19581960
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
@@ -2002,6 +2004,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
20022004
}
20032005
sparams.public_path = argv[i];
20042006
}
2007+
else if (arg == "--api-key")
2008+
{
2009+
if (++i >= argc)
2010+
{
2011+
invalid_param = true;
2012+
break;
2013+
}
2014+
sparams.api_key = argv[i];
2015+
}
20052016
else if (arg == "--timeout" || arg == "-to")
20062017
{
20072018
if (++i >= argc)
@@ -2669,6 +2680,32 @@ int main(int argc, char **argv)
26692680

26702681
httplib::Server svr;
26712682

2683+
// Middleware for API key validation
2684+
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
2685+
// If API key is not set, skip validation
2686+
if (sparams.api_key.empty()) {
2687+
return true;
2688+
}
2689+
2690+
// Check for API key in the header
2691+
auto auth_header = req.get_header_value("Authorization");
2692+
std::string prefix = "Bearer ";
2693+
if (auth_header.substr(0, prefix.size()) == prefix) {
2694+
std::string received_api_key = auth_header.substr(prefix.size());
2695+
if (received_api_key == sparams.api_key) {
2696+
return true; // API key is valid
2697+
}
2698+
}
2699+
2700+
// API key is invalid or not provided
2701+
res.set_content("Unauthorized: Invalid API Key", "text/plain");
2702+
res.status = 401; // Unauthorized
2703+
2704+
LOG_WARNING("Unauthorized: Invalid API Key", {});
2705+
2706+
return false;
2707+
};
2708+
26722709
svr.set_default_headers({{"Server", "llama.cpp"},
26732710
{"Access-Control-Allow-Origin", "*"},
26742711
{"Access-Control-Allow-Headers", "content-type"}});
@@ -2711,8 +2748,11 @@ int main(int argc, char **argv)
27112748
res.set_content(data.dump(), "application/json");
27122749
});
27132750

2714-
svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
2751+
svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
27152752
{
2753+
if (!validate_api_key(req, res)) {
2754+
return;
2755+
}
27162756
json data = json::parse(req.body);
27172757
const int task_id = llama.request_completion(data, false, false, -1);
27182758
if (!json_value(data, "stream", false)) {
@@ -2799,8 +2839,11 @@ int main(int argc, char **argv)
27992839
});
28002840

28012841
// TODO: add mount point without "/v1" prefix -- how?
2802-
svr.Post("/v1/chat/completions", [&llama](const httplib::Request &req, httplib::Response &res)
2842+
svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
28032843
{
2844+
if (!validate_api_key(req, res)) {
2845+
return;
2846+
}
28042847
json data = oaicompat_completion_params_parse(json::parse(req.body));
28052848

28062849
const int task_id = llama.request_completion(data, false, false, -1);
@@ -2869,8 +2912,11 @@ int main(int argc, char **argv)
28692912
}
28702913
});
28712914

2872-
svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
2915+
svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
28732916
{
2917+
if (!validate_api_key(req, res)) {
2918+
return;
2919+
}
28742920
json data = json::parse(req.body);
28752921
const int task_id = llama.request_completion(data, true, false, -1);
28762922
if (!json_value(data, "stream", false)) {
@@ -3005,11 +3051,15 @@ int main(int argc, char **argv)
30053051

30063052
svr.set_error_handler([](const httplib::Request &, httplib::Response &res)
30073053
{
3054+
if (res.status == 401)
3055+
{
3056+
res.set_content("Unauthorized", "text/plain");
3057+
}
30083058
if (res.status == 400)
30093059
{
30103060
res.set_content("Invalid request", "text/plain");
30113061
}
3012-
else if (res.status != 500)
3062+
else if (res.status == 404)
30133063
{
30143064
res.set_content("File Not Found", "text/plain");
30153065
res.status = 404;
@@ -3032,11 +3082,15 @@ int main(int argc, char **argv)
30323082
// to make it ctrl+clickable:
30333083
LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
30343084

3035-
LOG_INFO("HTTP server listening", {
3036-
{"hostname", sparams.hostname},
3037-
{"port", sparams.port},
3038-
});
3085+
std::unordered_map<std::string, std::string> log_data;
3086+
log_data["hostname"] = sparams.hostname;
3087+
log_data["port"] = std::to_string(sparams.port);
3088+
3089+
if (!sparams.api_key.empty()) {
3090+
log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
3091+
}
30393092

3093+
LOG_INFO("HTTP server listening", log_data);
30403094
// run the HTTP server in a thread - see comment below
30413095
std::thread t([&]()
30423096
{

0 commit comments

Comments
 (0)