Skip to content

Commit 0274e6b

Browse files
committed
Control vectors in server
1 parent a0e584d commit 0274e6b

File tree

2 files changed

+111
-1
lines changed

2 files changed

+111
-1
lines changed

examples/server/server.cpp

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,6 @@ struct server_response {
624624
}
625625
}
626626
};
627-
628627
struct server_context {
629628
llama_model * model = nullptr;
630629
llama_context * ctx = nullptr;
@@ -2700,6 +2699,35 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
27002699
break;
27012700
}
27022701
params.kv_overrides.push_back(kvo);
2702+
} else if (arg == "--control-vector") {
2703+
if (++i >= argc) {
2704+
invalid_param = true;
2705+
break;
2706+
}
2707+
params.control_vectors.push_back({ 1.0f, argv[i], });
2708+
} else if (arg == "--control-vector-scaled") {
2709+
if (++i >= argc) {
2710+
invalid_param = true;
2711+
break;
2712+
}
2713+
const char* fname = argv[i];
2714+
if (++i >= argc) {
2715+
invalid_param = true;
2716+
break;
2717+
}
2718+
params.control_vectors.push_back({ std::stof(argv[i]), fname, });
2719+
} else if (arg == "--control-vector-layer-range") {
2720+
if (++i >= argc) {
2721+
invalid_param = true;
2722+
break;
2723+
}
2724+
params.control_vector_layer_start = std::stoi(argv[i]);
2725+
if (++i >= argc) {
2726+
invalid_param = true;
2727+
break;
2728+
}
2729+
params.control_vector_layer_end = std::stoi(argv[i]);
2730+
break;
27032731
} else {
27042732
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
27052733
server_print_usage(argv[0], default_params, default_sparams);
@@ -3148,6 +3176,81 @@ int main(int argc, char ** argv) {
31483176
res.status = 200; // HTTP OK
31493177
};
31503178

3179+
// GET /control-vectors
//
// Reports the control vectors currently recorded in `params` (file name and
// strength of each entry) together with the configured layer range, as a
// JSON object of the form:
//   { "vectors": [ { "fname": ..., "strength": ... }, ... ],
//     "layer_start": ..., "layer_end": ... }
//
// Neither `req` nor the captured `ctx_server` is consulted; the response is
// built purely from `params`.
const auto handle_get_control_vectors = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res) {
    json vector_list = json::array();
    for (const auto & cv : params.control_vectors) {
        json entry = json::object();
        entry["fname"]    = cv.fname;
        entry["strength"] = cv.strength;
        vector_list.push_back(entry);
    }

    json payload = json::object();
    payload["vectors"]     = vector_list;
    payload["layer_start"] = params.control_vector_layer_start;
    payload["layer_end"]   = params.control_vector_layer_end;

    res.set_content(payload.dump(), "application/json; charset=utf-8");
};
3195+
3196+
// POST /control-vectors
//
// Replaces the set of control vectors recorded in `params` and applies the
// resulting vector to the server's llama context.
//
// Request body (JSON):
//   { "vectors": [ { "fname": <string>, "strength": <number> }, ... ] }
// Each array element is converted to llama_control_vector_load_info via
// item.get<>().
//
// On success the response is whatever handle_get_control_vectors produces for
// the updated state. On failure an error is sent through res_error and
// params.control_vectors is left untouched (the layer range may already have
// been normalized, see below).
//
// NOTE(review): json::parse and item.get<>() throw on malformed input and
// nothing in this handler catches those exceptions - presumably a
// server-level exception handler deals with them; confirm.
// NOTE(review): this mutates `params` and the llama context from the HTTP
// handler; whether that is synchronized with in-flight inference is not
// visible here - confirm.
const auto handle_set_control_vectors = [&ctx_server, &res_error, &params, &handle_get_control_vectors](const httplib::Request & req, httplib::Response & res) {
    res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

    const json data = json::parse(req.body);

    if (!data.contains("vectors") || !data.at("vectors").is_array()) {
        std::cerr << "No vectors passed\n";
        res_error(res, format_error_response("No vectors passed", ERROR_TYPE_SERVER));
        return;
    }

    // The vectors the client asked for; these become the new recorded state
    // once they have been applied successfully.
    std::vector<llama_control_vector_load_info> requested;
    for (const auto & item : data.at("vectors")) {
        requested.push_back(item.get<llama_control_vector_load_info>());
    }

    // The list actually loaded is the requested vectors followed by every
    // currently recorded vector with its strength negated.
    // NOTE(review): cancelling the old vectors this way is only meaningful if
    // llama_control_vector_apply accumulates onto what is already applied
    // rather than overwriting it - confirm against the llama.cpp API.
    std::vector<llama_control_vector_load_info> to_load = requested;
    to_load.reserve(requested.size() + params.control_vectors.size());
    for (const auto & active : params.control_vectors) {
        to_load.push_back({ -active.strength, active.fname });
    }

    const auto cvec = llama_control_vector_load(to_load);
    if (cvec.n_embd == -1) {
        res_error(res, format_error_response("Could not load control vector", ERROR_TYPE_SERVER));
        return;
    }

    // Non-positive bounds mean "not configured": fall back to the full layer
    // range of the loaded model. The normalized values are stored in params.
    if (params.control_vector_layer_start <= 0) {
        params.control_vector_layer_start = 1;
    }
    if (params.control_vector_layer_end <= 0) {
        params.control_vector_layer_end = llama_n_layer(ctx_server.model);
    }

    const int err = llama_control_vector_apply(ctx_server.ctx,
                                               cvec.data.data(),
                                               cvec.data.size(),
                                               cvec.n_embd,
                                               params.control_vector_layer_start,
                                               params.control_vector_layer_end);
    if (err) {
        std::cerr << "Could not apply control vector\n";
        res_error(res, format_error_response("Could not apply control vector", ERROR_TYPE_SERVER));
        return;
    }

    // Record only the requested vectors (not the negated entries) as the new
    // active set, then answer with the same payload as the GET endpoint.
    params.control_vectors = requested;
    handle_get_control_vectors(req, res);
};
3253+
31513254
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
31523255
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
31533256
json data = {
@@ -3497,8 +3600,10 @@ int main(int argc, char ** argv) {
34973600
svr->Get ("/health", handle_health);
34983601
svr->Get ("/slots", handle_slots);
34993602
svr->Get ("/metrics", handle_metrics);
3603+
svr->Get ("/control-vectors", handle_get_control_vectors);
35003604
svr->Get ("/props", handle_props);
35013605
svr->Get ("/v1/models", handle_models);
3606+
svr->Post("/control-vectors", handle_set_control_vectors);
35023607
svr->Post("/completion", handle_completions); // legacy
35033608
svr->Post("/completions", handle_completions);
35043609
svr->Post("/v1/completions", handle_completions);

examples/server/utils.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,3 +615,8 @@ static json format_error_response(const std::string & message, const enum error_
615615
{"type", type_str},
616616
};
617617
}
618+
619+
// Conversion hook that lets the JSON library build a
// llama_control_vector_load_info from a JSON object, e.g. via
// item.get<llama_control_vector_load_info>().
//
// Reads the required keys "strength" and "fname"; j.at() throws if either
// key is missing, and get_to() throws if the value has an incompatible type.
//
// Declared static because this lives in a header: without internal linkage
// every translation unit including the header would emit its own external
// definition and the link would fail with a duplicate symbol. This also
// matches the other static helpers in this header.
static void from_json(const json & j, llama_control_vector_load_info & l) {
    j.at("strength").get_to(l.strength);
    j.at("fname").get_to(l.fname);
}

0 commit comments

Comments
 (0)