@@ -624,7 +624,6 @@ struct server_response {
624
624
}
625
625
}
626
626
};
627
-
628
627
struct server_context {
629
628
llama_model * model = nullptr ;
630
629
llama_context * ctx = nullptr ;
@@ -2700,6 +2699,35 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
2700
2699
break ;
2701
2700
}
2702
2701
params.kv_overrides .push_back (kvo);
2702
+ } else if (arg == " --control-vector" ) {
2703
+ if (++i >= argc) {
2704
+ invalid_param = true ;
2705
+ break ;
2706
+ }
2707
+ params.control_vectors .push_back ({ 1 .0f , argv[i], });
2708
+ } else if (arg == " --control-vector-scaled" ) {
2709
+ if (++i >= argc) {
2710
+ invalid_param = true ;
2711
+ break ;
2712
+ }
2713
+ const char * fname = argv[i];
2714
+ if (++i >= argc) {
2715
+ invalid_param = true ;
2716
+ break ;
2717
+ }
2718
+ params.control_vectors .push_back ({ std::stof (argv[i]), fname, });
2719
+ } else if (arg == " --control-vector-layer-range" ) {
2720
+ if (++i >= argc) {
2721
+ invalid_param = true ;
2722
+ break ;
2723
+ }
2724
+ params.control_vector_layer_start = std::stoi (argv[i]);
2725
+ if (++i >= argc) {
2726
+ invalid_param = true ;
2727
+ break ;
2728
+ }
2729
+ params.control_vector_layer_end = std::stoi (argv[i]);
2730
+ break ;
2703
2731
} else {
2704
2732
fprintf (stderr, " error: unknown argument: %s\n " , arg.c_str ());
2705
2733
server_print_usage (argv[0 ], default_params, default_sparams);
@@ -3148,6 +3176,81 @@ int main(int argc, char ** argv) {
3148
3176
res.status = 200 ; // HTTP OK
3149
3177
};
3150
3178
3179
+ const auto handle_get_control_vectors = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res) {
3180
+ json vectors = json::array ();
3181
+
3182
+ for (const auto & vec : params.control_vectors ) {
3183
+ vectors.push_back (json {
3184
+ { " fname" , vec.fname },
3185
+ { " strength" , vec.strength }
3186
+ });
3187
+ }
3188
+ json data = {
3189
+ { " vectors" , vectors },
3190
+ { " layer_start" , params.control_vector_layer_start },
3191
+ { " layer_end" , params.control_vector_layer_end }
3192
+ };
3193
+ res.set_content (data.dump (), " application/json; charset=utf-8" );
3194
+ };
3195
+
3196
+ const auto handle_set_control_vectors = [&ctx_server, &res_error, ¶ms, &handle_get_control_vectors](const httplib::Request & req, httplib::Response & res) {
3197
+ res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
3198
+
3199
+ json data = json::parse (req.body );
3200
+ std::vector<llama_control_vector_load_info> vec_params;
3201
+
3202
+ if (data.contains (" vectors" ) && data[" vectors" ].is_array ()) {
3203
+ for (const auto &item : data[" vectors" ]) {
3204
+ auto v = item.get <llama_control_vector_load_info>();
3205
+ // std::cout << "Add vector: " << v.fname << " " << v.strength << "\n";
3206
+ vec_params.push_back (v);
3207
+ }
3208
+ } else {
3209
+ std::cerr << " No vectors passed\n " ;
3210
+ res_error (res, format_error_response (" No vectors passed" , ERROR_TYPE_SERVER));
3211
+ return ;
3212
+ }
3213
+ for (auto v : params.control_vectors ) {
3214
+ // std::cout << "Subtract vector:" << v.fname << " " << v.strength << "\n";
3215
+ vec_params.push_back ({ -v.strength , v.fname });
3216
+ }
3217
+ const auto cvec = llama_control_vector_load (vec_params);
3218
+ if (cvec.n_embd == -1 ) {
3219
+ // std::cerr << "Could not load control vector\n";
3220
+ res_error (res, format_error_response (" Could not load control vector" , ERROR_TYPE_SERVER));
3221
+ return ;
3222
+ }
3223
+
3224
+ if (params.control_vector_layer_start <= 0 ) {
3225
+ params.control_vector_layer_start = 1 ;
3226
+ }
3227
+ if (params.control_vector_layer_end <= 0 ){
3228
+ params.control_vector_layer_end = llama_n_layer (ctx_server.model );
3229
+ }
3230
+ int err = llama_control_vector_apply (ctx_server.ctx ,
3231
+ cvec.data .data (),
3232
+ cvec.data .size (),
3233
+ cvec.n_embd ,
3234
+ params.control_vector_layer_start ,
3235
+ params.control_vector_layer_end );
3236
+ if (err) {
3237
+ std::cerr << " Could not apply control vector\n " ;
3238
+ res_error (res, format_error_response (" Could not apply control vector" , ERROR_TYPE_SERVER));
3239
+ return ;
3240
+ }
3241
+ auto s = params.control_vectors .size ();
3242
+ auto s2 = vec_params.size ();
3243
+ params.control_vectors .clear ();
3244
+ unsigned i = 0 ;
3245
+ for (auto v : vec_params) {
3246
+ if (i++ < s2 - s) {
3247
+ // std::cout << "set vector param: " << v.fname << " " << v.strength << "\n";
3248
+ params.control_vectors .push_back (v);
3249
+ }
3250
+ }
3251
+ handle_get_control_vectors (req, res);
3252
+ };
3253
+
3151
3254
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
3152
3255
res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
3153
3256
json data = {
@@ -3497,8 +3600,10 @@ int main(int argc, char ** argv) {
3497
3600
svr->Get (" /health" , handle_health);
3498
3601
svr->Get (" /slots" , handle_slots);
3499
3602
svr->Get (" /metrics" , handle_metrics);
3603
+ svr->Get (" /control-vectors" , handle_get_control_vectors);
3500
3604
svr->Get (" /props" , handle_props);
3501
3605
svr->Get (" /v1/models" , handle_models);
3606
+ svr->Post (" /control-vectors" , handle_set_control_vectors);
3502
3607
svr->Post (" /completion" , handle_completions); // legacy
3503
3608
svr->Post (" /completions" , handle_completions);
3504
3609
svr->Post (" /v1/completions" , handle_completions);
0 commit comments