@@ -36,6 +36,7 @@ using json = nlohmann::json;
36
36
struct server_params
37
37
{
38
38
std::string hostname = " 127.0.0.1" ;
39
+ std::string api_key;
39
40
std::string public_path = " examples/server/public" ;
40
41
int32_t port = 8080 ;
41
42
int32_t read_timeout = 600 ;
@@ -1953,6 +1954,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
1953
1954
printf (" --host ip address to listen (default (default: %s)\n " , sparams.hostname .c_str ());
1954
1955
printf (" --port PORT port to listen (default (default: %d)\n " , sparams.port );
1955
1956
printf (" --path PUBLIC_PATH path from which to serve static files (default %s)\n " , sparams.public_path .c_str ());
1957
+ printf (" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n " );
1956
1958
printf (" -to N, --timeout N server read/write timeout in seconds (default: %d)\n " , sparams.read_timeout );
1957
1959
printf (" --embedding enable embedding vector output (default: %s)\n " , params.embedding ? " enabled" : " disabled" );
1958
1960
printf (" -np N, --parallel N number of slots for process requests (default: %d)\n " , params.n_parallel );
@@ -2002,6 +2004,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
2002
2004
}
2003
2005
sparams.public_path = argv[i];
2004
2006
}
2007
+ else if (arg == " --api-key" )
2008
+ {
2009
+ if (++i >= argc)
2010
+ {
2011
+ invalid_param = true ;
2012
+ break ;
2013
+ }
2014
+ sparams.api_key = argv[i];
2015
+ }
2005
2016
else if (arg == " --timeout" || arg == " -to" )
2006
2017
{
2007
2018
if (++i >= argc)
@@ -2669,6 +2680,32 @@ int main(int argc, char **argv)
2669
2680
2670
2681
httplib::Server svr;
2671
2682
2683
// validate_api_key: per-request access check that route handlers call before doing any work.
//   req - incoming request; only its Authorization header value is read here.
//   res - response object; written only on rejection (plain-text body and status 401).
//   Returns true when the request may proceed, false when it was rejected and
//   `res` has already been filled in, so the calling handler should just return.
// `sparams` is captured by reference, so this lambda must not outlive it.
+ // Middleware for API key validation
2684
+ auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
2685
+ // If API key is not set, skip validation
2686
+ if (sparams.api_key .empty ()) {
2687
+ return true ;
2688
+ }
2689
+
2690
+ // Check for API key in the header
2691
+ auto auth_header = req.get_header_value (" Authorization" );
2692
+ std::string prefix = " Bearer " ;
2693
// substr(0, n) cannot throw: position 0 is valid even for an empty header, and a
// header shorter than the prefix simply yields a shorter string that fails the comparison.
+ if (auth_header.substr (0 , prefix.size ()) == prefix) {
2694
// Everything after the prefix is treated as the candidate key, with no trimming.
+ std::string received_api_key = auth_header.substr (prefix.size ());
2695
// NOTE(review): std::string operator== is not guaranteed to run in constant time;
// consider a constant-time comparison if timing side channels are a concern.
+ if (received_api_key == sparams.api_key ) {
2696
+ return true ; // API key is valid
2697
+ }
2698
+ }
2699
+
2700
// Reached when the header has no matching prefix or carries a different key.
+ // API key is invalid or not provided
2701
+ res.set_content (" Unauthorized: Invalid API Key" , " text/plain" );
2702
+ res.status = 401 ; // Unauthorized
2703
+
2704
// The warning is logged with an empty extra-data argument, so no request details are recorded.
+ LOG_WARNING (" Unauthorized: Invalid API Key" , {});
2705
+
2706
+ return false ;
2707
+ };
2708
+
2672
2709
svr.set_default_headers ({{" Server" , " llama.cpp" },
2673
2710
{" Access-Control-Allow-Origin" , " *" },
2674
2711
{" Access-Control-Allow-Headers" , " content-type" }});
@@ -2711,8 +2748,11 @@ int main(int argc, char **argv)
2711
2748
res.set_content (data.dump (), " application/json" );
2712
2749
});
2713
2750
2714
- svr.Post (" /completion" , [&llama](const httplib::Request &req, httplib::Response &res)
2751
+ svr.Post (" /completion" , [&llama, &validate_api_key ](const httplib::Request &req, httplib::Response &res)
2715
2752
{
2753
+ if (!validate_api_key (req, res)) {
2754
+ return ;
2755
+ }
2716
2756
json data = json::parse (req.body );
2717
2757
const int task_id = llama.request_completion (data, false , false , -1 );
2718
2758
if (!json_value (data, " stream" , false )) {
@@ -2799,8 +2839,11 @@ int main(int argc, char **argv)
2799
2839
});
2800
2840
2801
2841
// TODO: add mount point without "/v1" prefix -- how?
2802
- svr.Post (" /v1/chat/completions" , [&llama](const httplib::Request &req, httplib::Response &res)
2842
+ svr.Post (" /v1/chat/completions" , [&llama, &validate_api_key ](const httplib::Request &req, httplib::Response &res)
2803
2843
{
2844
+ if (!validate_api_key (req, res)) {
2845
+ return ;
2846
+ }
2804
2847
json data = oaicompat_completion_params_parse (json::parse (req.body ));
2805
2848
2806
2849
const int task_id = llama.request_completion (data, false , false , -1 );
@@ -2869,8 +2912,11 @@ int main(int argc, char **argv)
2869
2912
}
2870
2913
});
2871
2914
2872
- svr.Post (" /infill" , [&llama](const httplib::Request &req, httplib::Response &res)
2915
+ svr.Post (" /infill" , [&llama, &validate_api_key ](const httplib::Request &req, httplib::Response &res)
2873
2916
{
2917
+ if (!validate_api_key (req, res)) {
2918
+ return ;
2919
+ }
2874
2920
json data = json::parse (req.body );
2875
2921
const int task_id = llama.request_completion (data, true , false , -1 );
2876
2922
if (!json_value (data, " stream" , false )) {
@@ -3005,11 +3051,15 @@ int main(int argc, char **argv)
3005
3051
3006
3052
svr.set_error_handler ([](const httplib::Request &, httplib::Response &res)
3007
3053
{
3054
+ if (res.status == 401 )
3055
+ {
3056
+ res.set_content (" Unauthorized" , " text/plain" );
3057
+ }
3008
3058
if (res.status == 400 )
3009
3059
{
3010
3060
res.set_content (" Invalid request" , " text/plain" );
3011
3061
}
3012
- else if (res.status != 500 )
3062
+ else if (res.status == 404 )
3013
3063
{
3014
3064
res.set_content (" File Not Found" , " text/plain" );
3015
3065
res.status = 404 ;
@@ -3032,11 +3082,15 @@ int main(int argc, char **argv)
3032
3082
// to make it ctrl+clickable:
3033
3083
LOG_TEE (" \n llama server listening at http://%s:%d\n\n " , sparams.hostname .c_str (), sparams.port );
3034
3084
3035
- LOG_INFO (" HTTP server listening" , {
3036
- {" hostname" , sparams.hostname },
3037
- {" port" , sparams.port },
3038
- });
3085
+ std::unordered_map<std::string, std::string> log_data;
3086
+ log_data[" hostname" ] = sparams.hostname ;
3087
+ log_data[" port" ] = std::to_string (sparams.port );
3088
+
3089
+ if (!sparams.api_key .empty ()) {
3090
+ log_data[" api_key" ] = " api_key: ****" + sparams.api_key .substr (sparams.api_key .length () - 4 );
3091
+ }
3039
3092
3093
+ LOG_INFO (" HTTP server listening" , log_data);
3040
3094
// run the HTTP server in a thread - see comment below
3041
3095
std::thread t ([&]()
3042
3096
{
0 commit comments