Skip to content

Commit 2737945

Browse files
m18coppolaMichael Coppola
and
Michael Coppola
authored
server : support for multiple api keys (#4864)
* server: added support for multiple api keys, added loading api keys from file * minor: fix whitespace * added file error handling to --api-key-file, changed code to better reflect current style * server: update README.md for --api-key-file --------- Co-authored-by: Michael Coppola <info@michaeljcoppola.com>
1 parent eab6795 commit 2737945

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

examples/server/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ Command line options:
2323
- `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
2424
- `--port`: Set the port to listen. Default: `8080`.
2525
- `--path`: path from which to serve static files (default examples/server/public)
26-
- `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token.
26+
- `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys.
27+
- `--api-key-file`: path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`'s.
2728
- `--embedding`: Enable embedding extraction, Default: disabled.
2829
- `-np N`, `--parallel N`: Set the number of slots for process requests (default: 1)
2930
- `-cb`, `--cont-batching`: enable continuous batching (a.k.a dynamic batching) (default: disabled)

examples/server/server.cpp

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ using json = nlohmann::json;
3939
struct server_params
4040
{
4141
std::string hostname = "127.0.0.1";
42-
std::string api_key;
42+
std::vector<std::string> api_keys;
4343
std::string public_path = "examples/server/public";
4444
int32_t port = 8080;
4545
int32_t read_timeout = 600;
@@ -2021,6 +2021,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
20212021
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
20222022
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
20232023
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
2024+
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
20242025
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
20252026
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
20262027
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
@@ -2081,7 +2082,28 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
20812082
invalid_param = true;
20822083
break;
20832084
}
2084-
sparams.api_key = argv[i];
2085+
sparams.api_keys.push_back(argv[i]);
2086+
}
2087+
else if (arg == "--api-key-file")
2088+
{
2089+
if (++i >= argc)
2090+
{
2091+
invalid_param = true;
2092+
break;
2093+
}
2094+
std::ifstream key_file(argv[i]);
2095+
if (!key_file) {
2096+
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
2097+
invalid_param = true;
2098+
break;
2099+
}
2100+
std::string key;
2101+
while (std::getline(key_file, key)) {
2102+
if (key.size() > 0) {
2103+
sparams.api_keys.push_back(key);
2104+
}
2105+
}
2106+
key_file.close();
20852107
}
20862108
else if (arg == "--timeout" || arg == "-to")
20872109
{
@@ -2881,8 +2903,10 @@ int main(int argc, char **argv)
28812903
log_data["hostname"] = sparams.hostname;
28822904
log_data["port"] = std::to_string(sparams.port);
28832905

2884-
if (!sparams.api_key.empty()) {
2885-
log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
2906+
if (sparams.api_keys.size() == 1) {
2907+
log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4);
2908+
} else if (sparams.api_keys.size() > 1) {
2909+
log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
28862910
}
28872911

28882912
LOG_INFO("HTTP server listening", log_data);
@@ -2912,7 +2936,7 @@ int main(int argc, char **argv)
29122936
// Middleware for API key validation
29132937
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
29142938
// If API key is not set, skip validation
2915-
if (sparams.api_key.empty()) {
2939+
if (sparams.api_keys.empty()) {
29162940
return true;
29172941
}
29182942

@@ -2921,7 +2945,7 @@ int main(int argc, char **argv)
29212945
std::string prefix = "Bearer ";
29222946
if (auth_header.substr(0, prefix.size()) == prefix) {
29232947
std::string received_api_key = auth_header.substr(prefix.size());
2924-
if (received_api_key == sparams.api_key) {
2948+
if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
29252949
return true; // API key is valid
29262950
}
29272951
}

0 commit comments

Comments
 (0)