Skip to content

Control vectors in server #6289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 20 additions & 21 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2640,6 +2640,8 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n)
//

static llama_control_vector_data llama_control_vector_load_one(const llama_control_vector_load_info & load_info) {
auto start = ggml_time_ms();
printf("control vector load_one...\n");
int32_t n_tensors;

size_t n_bytes = 0;
Expand All @@ -2650,12 +2652,7 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr

// calculate size of ctx needed for tensors, ensure tensors are f32, and find max layer
{
struct ggml_init_params meta_params = {
/* .mem_size = */ ggml_tensor_overhead() * 128 + ggml_graph_overhead(),
/* .mem_buffer = */ nullptr,
/* .no_alloc = */ true,
};
ggml_context * meta_ctx = ggml_init(meta_params);
ggml_context * meta_ctx = nullptr;
struct gguf_init_params meta_gguf_params = {
/* .no_alloc = */ true,
/* .ctx = */ &meta_ctx,
Expand All @@ -2678,40 +2675,39 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr
uint32_t layer = std::stoi(name.substr(dotpos + 1));
if (layer == 0) {
fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
ggml_free(meta_ctx);
return result;
}
if (layer > max_direction_layer) {
max_direction_layer = layer;
}
} catch (...) {
fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
return result;
ggml_free(meta_ctx);
}
}

struct ggml_tensor * tensor_meta = ggml_get_tensor(meta_ctx, name.c_str());
if (tensor_meta->type != GGML_TYPE_F32 || ggml_n_dims(tensor_meta) != 1) {
fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
ggml_free(meta_ctx);
return result;
}
if (result.n_embd == -1) {
result.n_embd = ggml_nelements(tensor_meta);
} else if (ggml_nelements(tensor_meta) != result.n_embd) {
fprintf(stderr, "%s: direction tensor sizes mismatched in %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
ggml_free(meta_ctx);
return result;
}
n_bytes += ggml_nbytes(tensor_meta);
}
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
ggml_free(meta_ctx);
}

if (n_tensors == 0) {
Expand All @@ -2720,13 +2716,7 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr
}

// load and scale tensors into final control vector context
struct ggml_init_params ggml_params = {
/* .mem_size = */ ggml_tensor_overhead() * n_tensors + n_bytes,
/* .mem_buffer = */ nullptr,
/* .no_alloc = */ false,
};
struct ggml_context * ctx = ggml_init(ggml_params);

struct ggml_context * ctx = nullptr;
struct gguf_init_params params = {
/*.no_alloc = */ false,
/*.ctx = */ &ctx,
Expand Down Expand Up @@ -2759,10 +2749,17 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr
}
}

gguf_free(ctx_gguf);
ggml_free(ctx);

auto end = ggml_time_ms();
printf("control vector load_one took %ums\n", end - start);
return result;
}

llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos) {
auto start = ggml_time_ms();
printf("control vector load...\n");
llama_control_vector_data result = { -1, {} };

for (const auto & info : load_infos) {
Expand All @@ -2772,7 +2769,7 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
return result;
}
if (result.n_embd != -1 && (result.n_embd != cur.n_embd || result.data.size() != cur.data.size())) {
fprintf(stderr, "%s: control vector in %s does not match previous vector dimensions\n", __func__, info.fname.c_str());
printf("%s: control vector in %s does not match previous vector dimensions\n", __func__, info.fname.c_str());
return result;
}

Expand All @@ -2786,8 +2783,10 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
}

if (result.n_embd == -1) {
fprintf(stderr, "%s: no vectors passed\n", __func__);
printf("%s: no vectors passed\n", __func__);
}

auto end = ggml_time_ms();
printf("control vector load time: %ums\n", end-start);
return result;
}
222 changes: 206 additions & 16 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ struct server_params {

std::vector<std::string> api_keys;

std::vector<llama_control_vector_load_option> control_vector_load_options;

#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
std::string ssl_key_file = "";
std::string ssl_cert_file = "";
Expand Down Expand Up @@ -2217,6 +2219,12 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" set an alias for the model, will be added as `model` field in completion response\n");
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
printf(" --control-vector FNAME\n");
printf(" add a control vector\n");
printf(" --control-vector-scaled FNAME S\n");
printf(" add a control vector with user defined scaling S\n");
printf(" --control-vector-layer-range START END\n");
printf(" layer range to apply the control vector(s) to, start and end inclusive\n");
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
printf(" --path PUBLIC_PATH path from which to serve static files (default: disabled)\n");
Expand Down Expand Up @@ -2700,6 +2708,58 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
break;
}
params.kv_overrides.push_back(kvo);
} else if (arg == "--control-vector") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.control_vectors.push_back({ 1.0f, argv[i], });
} else if (arg == "--control-vector-scaled") {
if (++i >= argc) {
invalid_param = true;
break;
}
const char* fname = argv[i];
if (++i >= argc) {
invalid_param = true;
break;
}
params.control_vectors.push_back({ std::stof(argv[i]), fname, });
} else if (arg == "--control-vector-layer-range") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.control_vector_layer_start = std::stoi(argv[i]);
if (++i >= argc) {
invalid_param = true;
break;
}
params.control_vector_layer_end = std::stoi(argv[i]);
break;
} else if (arg == "--control-vector-option") {
if (++i >= argc) {
invalid_param = true;
break;
}
std::string name = argv[i];

if (++i >= argc) {
invalid_param = true;
break;
}
std::string fname = argv[i];

size_t slen = fname.length();
bool is_dir = slen < 5 || strncmp(argv[i] + slen - 5, ".gguf", 5) != 0;

// Append path separator for dir names
if (is_dir && argv[i][slen - 1] != '/')
fname += '/';
if (is_dir && argv[i-1][slen - 1] != '/')
name += '/';
sparams.control_vector_load_options.push_back({ name, fname, is_dir });
break;
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
server_print_usage(argv[0], default_params, default_sparams);
Expand Down Expand Up @@ -3148,6 +3208,133 @@ int main(int argc, char ** argv) {
res.status = 200; // HTTP OK
};

const auto handle_control_vector_options = [&sparams](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json options = json::array();

for (const auto & opt : sparams.control_vector_load_options) {
options.push_back(opt.name);
}
res.set_content(options.dump(), "application/json; charset=utf-8");
};

const auto handle_get_control_vectors = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json vectors = json::array();

for (const auto & vec : ctx_server.params.control_vectors) {
vectors.push_back(json {
{ "fname", vec.fname },
{ "strength", vec.strength }
});
}
json data = {
{ "vectors", vectors },
{ "layer_start", ctx_server.params.control_vector_layer_start },
{ "layer_end", ctx_server.params.control_vector_layer_end }
};
res.set_content(data.dump(), "application/json; charset=utf-8");
};

const auto handle_set_control_vectors = [&ctx_server, &sparams, &res_error, &handle_get_control_vectors](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);

// vector parameters passed by user
std::vector<llama_control_vector_load_info> vec_params;
// names translated to real file names
std::vector<llama_control_vector_load_info> real_vec_params;

if (data.contains("vectors") && data["vectors"].is_array()) {
for (const auto &item : data["vectors"]) {
llama_control_vector_load_info v = item.get<llama_control_vector_load_info>();
std::string real_fname = "";
std::cout << "Check vec " << v.fname << "\n";
// check for path traversal attempt
if (v.fname.length() > 0 && v.fname[0] != '/' && v.fname[0] != '\\') {
if (v.fname.find("../") == -1 && v.fname.find("..\\") == -1 &&
v.fname.find("/..") == -1 && v.fname.find("\\..") == -1) {

// check if vector name matches allowed names
for (auto opt : sparams.control_vector_load_options) {
std::cout << "check option " << opt.name << " : " << opt.fname << " : " << opt.is_dir << "\n";
if (!opt.is_dir && opt.name == v.fname) {
std::cout << "file exact match\n";
real_fname = opt.fname;
break;
}
if (opt.is_dir && v.fname.rfind(opt.name, 0) == 0) {
std::cout << "file exact match\n";
real_fname = opt.fname + v.fname.substr(opt.name.length());
#if defined(_WIN32)
std::replace(real_fname.begin(), real_fname.end(), '/', '\\');
#endif
size_t len = real_fname.length();
if (len < 5 || real_fname.compare(len - 5, 5, ".gguf") != 0)
real_fname += ".gguf";
break;
}
}
}
}

if (real_fname.length() == 0) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res_error(res, format_error_response("Control vector not allowed", ERROR_TYPE_SERVER));
return;
}

std::cout << "Add vector: " << v.fname << " -> " << real_fname << " " << v.strength << "\n";
llama_control_vector_load_info real_info = { v.strength, real_fname };
vec_params.push_back(v);
real_vec_params.push_back(real_info);
}
} else {
std::cerr << "No vectors array passed\n";
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res_error(res, format_error_response("No vectors array passed. If you want reset to 0, send an empty array.", ERROR_TYPE_SERVER));
return;
}

const auto cvec = llama_control_vector_load(real_vec_params);

if (cvec.n_embd == -1) {
std::cerr << "Could not load control vector\n";
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res_error(res, format_error_response("Could not load control vector", ERROR_TYPE_SERVER));
return;
}

if (ctx_server.params.control_vector_layer_start <= 0) {
ctx_server.params.control_vector_layer_start = 1;
}
if (ctx_server.params.control_vector_layer_end <= 0){
ctx_server.params.control_vector_layer_end = llama_n_layer(ctx_server.model);
}

int err = llama_control_vector_apply(ctx_server.ctx,
cvec.data.data(),
cvec.data.size(),
cvec.n_embd,
ctx_server.params.control_vector_layer_start,
ctx_server.params.control_vector_layer_end);
if (err) {
std::cerr << "Could not apply control vector\n";
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res_error(res, format_error_response("Could not apply control vector", ERROR_TYPE_SERVER));
return;
}

ctx_server.params.control_vectors.clear();

for (auto v : vec_params) {
std::cout << "set vector param: " << v.fname << " " << v.strength << "\n";
ctx_server.params.control_vectors.push_back(v);
}

handle_get_control_vectors(req, res);
};


const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = {
Expand Down Expand Up @@ -3494,22 +3681,25 @@ int main(int argc, char ** argv) {
json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));

// register API routes
svr->Get ("/health", handle_health);
svr->Get ("/slots", handle_slots);
svr->Get ("/metrics", handle_metrics);
svr->Get ("/props", handle_props);
svr->Get ("/v1/models", handle_models);
svr->Post("/completion", handle_completions); // legacy
svr->Post("/completions", handle_completions);
svr->Post("/v1/completions", handle_completions);
svr->Post("/chat/completions", handle_chat_completions);
svr->Post("/v1/chat/completions", handle_chat_completions);
svr->Post("/infill", handle_infill);
svr->Post("/embedding", handle_embeddings); // legacy
svr->Post("/embeddings", handle_embeddings);
svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize);
svr->Get ("/health", handle_health);
svr->Get ("/slots", handle_slots);
svr->Get ("/metrics", handle_metrics);
svr->Get ("/props", handle_props);
svr->Get ("/v1/models", handle_models);
svr->Get ("/control-vectors", handle_get_control_vectors);
svr->Get ("/control-vector-options", handle_control_vector_options);
svr->Post("/control-vectors", handle_set_control_vectors);
svr->Post("/completion", handle_completions); // legacy
svr->Post("/completions", handle_completions);
svr->Post("/v1/completions", handle_completions);
svr->Post("/chat/completions", handle_chat_completions);
svr->Post("/v1/chat/completions", handle_chat_completions);
svr->Post("/infill", handle_infill);
svr->Post("/embedding", handle_embeddings); // legacy
svr->Post("/embeddings", handle_embeddings);
svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize);

//
// Start the server
Expand Down
11 changes: 11 additions & 0 deletions examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,3 +615,14 @@ static json format_error_response(const std::string & message, const enum error_
{"type", type_str},
};
}

static void from_json(const json& j, llama_control_vector_load_info& l) {
j.at("strength").get_to(l.strength);
j.at("fname").get_to(l.fname);
}

// A control-vector source the server operator has allowed clients to request
// (populated from --control-vector-option arguments).
struct llama_control_vector_load_option {
    std::string name;  // public name clients send in API requests
    std::string fname; // filesystem path: a .gguf file, or a directory prefix when is_dir
    bool is_dir;       // true when fname names a directory of .gguf files, not a single file
};
Loading