Skip to content

Commit 7b9e872

Browse files
committed
Add server routes for hot-reloading control vectors and reading the current control-vector composition
1 parent bd9f6b9 commit 7b9e872

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

examples/server/server.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3183,6 +3183,74 @@ int main(int argc, char ** argv) {
31833183
res.status = 200; // HTTP OK
31843184
};
31853185

3186+
const auto handle_get_control_vectors = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
3187+
json vectors = json::array();
3188+
3189+
for (const auto & vec : ctx_server.params.control_vectors) {
3190+
vectors.push_back(json {
3191+
{ "fname", vec.fname },
3192+
{ "strength", vec.strength }
3193+
});
3194+
}
3195+
json data = {
3196+
{ "vectors", vectors },
3197+
{ "layer_start", ctx_server.params.control_vector_layer_start },
3198+
{ "layer_end", ctx_server.params.control_vector_layer_end }
3199+
};
3200+
res.set_content(data.dump(), "application/json; charset=utf-8");
3201+
};
3202+
3203+
// POST /control-vectors
// Expects a JSON body of the form
//   { "vectors": [ { "fname": <string>, "strength": <number> }, ... ] }
// Loads the listed control vectors, applies the combined vector to the server
// context, records the new list in ctx_server.params and finally answers with
// the same payload as GET /control-vectors.
//
// Any failure (malformed body, missing "vectors" array, load error, apply error)
// is reported through res_error and leaves ctx_server.params untouched.
//
// NOTE(review): this mutates ctx_server.ctx / ctx_server.params directly from the
// request handler; presumably this can overlap with in-flight inference — confirm
// whether it should be routed through the server's task queue instead.
const auto handle_set_control_vectors = [&ctx_server, &res_error, &handle_get_control_vectors](const httplib::Request & req, httplib::Response & res) {
    res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

    std::vector<llama_control_vector_load_info> vec_params;
    try {
        // json::parse throws on malformed input, and get<>() throws when an entry
        // lacks "fname"/"strength" or has the wrong type. The body is client
        // controlled, so turn those into an error response instead of letting
        // the exception escape the handler.
        const json data = json::parse(req.body);

        if (!data.contains("vectors") || !data.at("vectors").is_array()) {
            std::cerr << "No vectors passed\n";
            res_error(res, format_error_response("No vectors passed", ERROR_TYPE_SERVER));
            return;
        }

        for (const auto & item : data.at("vectors")) {
            auto v = item.get<llama_control_vector_load_info>();
            std::cout << "Add vector: " << v.fname << " " << v.strength << "\n";
            vec_params.push_back(std::move(v));
        }
    } catch (const std::exception & e) {
        std::cerr << "Invalid control vector request: " << e.what() << "\n";
        res_error(res, format_error_response(std::string("Invalid control vector request: ") + e.what(), ERROR_TYPE_SERVER));
        return;
    }

    const auto cvec = llama_control_vector_load(vec_params);
    if (cvec.n_embd == -1) {
        std::cerr << "Could not load control vector\n";
        res_error(res, format_error_response("Could not load control vector", ERROR_TYPE_SERVER));
        return;
    }

    auto & params = ctx_server.params;

    // Resolve the layer range into locals first: a non-positive bound means
    // "not configured" and falls back to the full range [1, n_layer]. The params
    // are only updated once the vector has actually been applied.
    const auto layer_start = params.control_vector_layer_start <= 0
        ? 1
        : params.control_vector_layer_start;
    const auto layer_end = params.control_vector_layer_end <= 0
        ? llama_n_layer(ctx_server.model)
        : params.control_vector_layer_end;

    const int err = llama_control_vector_apply(ctx_server.ctx,
                                               cvec.data.data(),
                                               cvec.data.size(),
                                               cvec.n_embd,
                                               layer_start,
                                               layer_end);
    if (err) {
        std::cerr << "Could not apply control vector\n";
        res_error(res, format_error_response("Could not apply control vector", ERROR_TYPE_SERVER));
        return;
    }

    // Success: commit the new state so GET /control-vectors reflects it.
    params.control_vector_layer_start = layer_start;
    params.control_vector_layer_end   = layer_end;
    for (const auto & v : vec_params) {
        std::cout << "set vector param: " << v.fname << " " << v.strength << "\n";
    }
    params.control_vectors = std::move(vec_params);

    handle_get_control_vectors(req, res);
};
3252+
3253+
31863254
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
31873255
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
31883256
json data = {
@@ -3534,6 +3602,8 @@ int main(int argc, char ** argv) {
35343602
svr->Get ("/metrics", handle_metrics);
35353603
svr->Get ("/props", handle_props);
35363604
svr->Get ("/v1/models", handle_models);
3605+
svr->Get ("/control-vectors", handle_get_control_vectors);
3606+
svr->Post("/control-vectors", handle_set_control_vectors);
35373607
svr->Post("/completion", handle_completions); // legacy
35383608
svr->Post("/completions", handle_completions);
35393609
svr->Post("/v1/completions", handle_completions);

examples/server/utils.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,3 +615,8 @@ static json format_error_response(const std::string & message, const enum error_
615615
{"type", type_str},
616616
};
617617
}
618+
619+
// nlohmann::json deserialization hook for llama_control_vector_load_info.
// Reads the required "fname" and "strength" members of `j` into `l`;
// json::at / get_to throw if a member is missing or has an incompatible type.
//
// Declared `static` like the other helpers in this header so that including the
// header from more than one translation unit does not produce duplicate
// definitions at link time.
static void from_json(const json & j, llama_control_vector_load_info & l) {
    j.at("strength").get_to(l.strength);
    j.at("fname").get_to(l.fname);
}

0 commit comments

Comments
 (0)