
Commit a76e73b

ggerganov and slaren authored and committed
llama : support quantum K cache (ggml-org#4312)
* llama : support quantum K cache (wip)
* metal : add F32 -> Q8_0 copy kernel
* cuda : add F32 -> Q8_0 copy kernel

ggml-ci

* cuda : use mmv kernel for quantum cache ops
* llama : pass KV cache type through API
* llama : fix build

ggml-ci

* metal : add F32 -> Q4_0 copy kernel
* metal : add F32 -> Q4_1 copy kernel
* cuda : wip
* cuda : add F32 -> Q4_0 and F32 -> Q4_1 copy kernels
* llama-bench : support type_k/type_v
* metal : use mm kernel only for quantum KV cache
* cuda : add comment
* llama : remove memory_f16 and kv_f16 flags

---------

Co-authored-by: slaren <slarengh@gmail.com>
1 parent c751152 commit a76e73b
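In practical terms, this commit replaces the old boolean f16_kv context flag with explicit per-tensor cache types. A minimal sketch of what a caller of the updated API might look like after this change — the model path is a placeholder, and only the K cache is actually quantized at this point, per the commit title:

// Sketch: create a context with a Q8_0 K cache via the new type_k/type_v fields.
// Assumes the llama.h from this commit; "model.gguf" is a placeholder path.
#include "llama.h"

int main(void) {
    llama_backend_init(false);

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("model.gguf", mparams);

    llama_context_params cparams = llama_context_default_params();
    cparams.type_k = GGML_TYPE_Q8_0; // quantize the K cache (f16_kv is removed)
    cparams.type_v = GGML_TYPE_F16;  // V cache stays f16 for now

    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // ... evaluate batches as usual ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}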

10 files changed (+575, -74 lines)

common/common.cpp

Lines changed: 38 additions & 6 deletions
@@ -279,8 +279,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             break;
         }
         params.yarn_beta_slow = std::stof(argv[i]);
-    } else if (arg == "--memory-f32") {
-        params.memory_f16 = false;
     } else if (arg == "--top-p") {
         if (++i >= argc) {
             invalid_param = true;
@@ -499,6 +497,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         params.infill = true;
     } else if (arg == "-dkvc" || arg == "--dump-kv-cache") {
         params.dump_kv_cache = true;
+    } else if (arg == "-nkvo" || arg == "--no-kv-offload") {
+        params.no_kv_offload = true;
+    } else if (arg == "-ctk" || arg == "--cache-type-k") {
+        params.cache_type_k = argv[++i];
+    } else if (arg == "-ctv" || arg == "--cache-type-v") {
+        params.cache_type_v = argv[++i];
     } else if (arg == "--multiline-input") {
         params.multiline_input = true;
     } else if (arg == "--simple-io") {
@@ -799,8 +803,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --yarn-beta-fast N        YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
     printf("  --ignore-eos              ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     printf("  --no-penalize-nl          do not penalize newline token\n");
-    printf("  --memory-f32              use f32 instead of f16 for memory key+value (default: disabled)\n");
-    printf("                            not recommended: doubles context memory required and no measurable increase in quality\n");
     printf("  --temp N                  temperature (default: %.1f)\n", (double)sparams.temp);
     printf("  --logits-all              return logits for all tokens in the batch (default: disabled)\n");
     printf("  --hellaswag               compute HellaSwag score over random tasks from datafile supplied with -f\n");
@@ -841,6 +843,12 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --verbose-prompt          print prompt before generation\n");
     printf("  -dkvc, --dump-kv-cache\n");
     printf("                            verbose print of the KV cache\n");
+    printf("  -nkvo, --no-kv-offload\n");
+    printf("                            disable KV offload\n");
+    printf("  -ctk TYPE, --cache-type-k TYPE\n");
+    printf("                            KV cache data type for K (default: %s)\n", params.cache_type_k.c_str());
+    printf("  -ctv TYPE, --cache-type-v TYPE\n");
+    printf("                            KV cache data type for V (default: %s)\n", params.cache_type_v.c_str());
     printf("  --simple-io               use basic IO for better compatibility in subprocesses and limited consoles\n");
     printf("  --lora FNAME              apply LoRA adapter (implies --no-mmap)\n");
     printf("  --lora-scaled FNAME S     apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
@@ -905,6 +913,29 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
     return mparams;
 }
 
+static ggml_type kv_cache_type_from_str(const std::string & s) {
+    if (s == "f16") {
+        return GGML_TYPE_F16;
+    }
+    if (s == "q8_0") {
+        return GGML_TYPE_Q8_0;
+    }
+    if (s == "q4_0") {
+        return GGML_TYPE_Q4_0;
+    }
+    if (s == "q4_1") {
+        return GGML_TYPE_Q4_1;
+    }
+    if (s == "q5_0") {
+        return GGML_TYPE_Q5_0;
+    }
+    if (s == "q5_1") {
+        return GGML_TYPE_Q5_1;
+    }
+
+    throw std::runtime_error("Invalid cache type: " + s);
+}
+
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
     auto cparams = llama_context_default_params();
 
@@ -914,7 +945,6 @@ struct llama_context_params llama_context_params_from_gpt_param
     cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
     cparams.mul_mat_q       = params.mul_mat_q;
     cparams.seed            = params.seed;
-    cparams.f16_kv          = params.memory_f16;
     cparams.logits_all      = params.logits_all;
     cparams.embedding       = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;
@@ -926,6 +956,9 @@ struct llama_context_params llama_context_params_from_gpt_param
     cparams.yarn_beta_slow  = params.yarn_beta_slow;
     cparams.yarn_orig_ctx   = params.yarn_orig_ctx;
 
+    cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
+    cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
+
     return cparams;
 }
 
@@ -1337,7 +1370,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     }
     fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
    fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
-    fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
     fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
     fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);

common/common.h

Lines changed: 4 additions & 2 deletions
@@ -106,7 +106,6 @@ struct gpt_params {
     size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
 
     bool mul_mat_q         = true;  // if true, use mul_mat_q kernels instead of cuBLAS
-    bool memory_f16        = true;  // use f16 instead of f32 for memory kv
     bool random_prompt     = false; // do not randomize prompt if none provided
     bool use_color         = false; // use color to distinguish generations and inputs
     bool interactive       = false; // interactive mode
@@ -132,9 +131,12 @@ struct gpt_params {
     bool infill            = false; // use infill mode
     bool dump_kv_cache     = false; // dump the KV cache contents for debugging purposes
 
+    std::string cache_type_k = "f16"; // KV cache data type for the K
+    std::string cache_type_v = "f16"; // KV cache data type for the V
+
     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector
-    std::string image = "";  // path to an image file
+    std::string image  = ""; // path to an image file
 };
 
 bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);

examples/llama-bench/llama-bench.cpp

Lines changed: 91 additions & 20 deletions
@@ -54,6 +54,13 @@ static std::vector<T> split(const std::string & str, char delim) {
     return values;
 }
 
+template<typename T, typename F>
+static std::vector<std::string> transform_to_str(const std::vector<T> & values, F f) {
+    std::vector<std::string> str_values;
+    std::transform(values.begin(), values.end(), std::back_inserter(str_values), f);
+    return str_values;
+}
+
 template<typename T>
 static T avg(const std::vector<T> & v) {
     if (v.empty()) {
@@ -127,7 +134,8 @@ struct cmd_params {
     std::vector<int> n_prompt;
     std::vector<int> n_gen;
     std::vector<int> n_batch;
-    std::vector<bool> f32_kv;
+    std::vector<ggml_type> type_k;
+    std::vector<ggml_type> type_v;
     std::vector<int> n_threads;
     std::vector<int> n_gpu_layers;
     std::vector<int> main_gpu;
@@ -143,7 +151,8 @@ static const cmd_params cmd_params_defaults = {
     /* n_prompt     */ {512},
     /* n_gen        */ {128},
     /* n_batch      */ {512},
-    /* f32_kv       */ {false},
+    /* type_k       */ {GGML_TYPE_F16},
+    /* type_v       */ {GGML_TYPE_F16},
     /* n_threads    */ {get_num_physical_cores()},
     /* n_gpu_layers */ {99},
     /* main_gpu     */ {0},
@@ -163,7 +172,8 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -p, --n-prompt <n>            (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
     printf("  -n, --n-gen <n>               (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
     printf("  -b, --batch-size <n>          (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
-    printf("  --memory-f32 <0|1>            (default: %s)\n", join(cmd_params_defaults.f32_kv, ",").c_str());
+    printf("  -ctk <t>, --cache-type-k <t>  (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
+    printf("  -ctv <t>, --cache-type-v <t>  (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
     printf("  -t, --threads <n>             (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
     printf("  -ngl, --n-gpu-layers <n>      (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
     printf("  -mg, --main-gpu <i>           (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
@@ -174,9 +184,32 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -v, --verbose                 (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
     printf("\n");
     printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
+}
 
+static ggml_type ggml_type_from_name(const std::string & s) {
+    if (s == "f16") {
+        return GGML_TYPE_F16;
+    }
+    if (s == "q8_0") {
+        return GGML_TYPE_Q8_0;
+    }
+    if (s == "q4_0") {
+        return GGML_TYPE_Q4_0;
+    }
+    if (s == "q4_1") {
+        return GGML_TYPE_Q4_1;
+    }
+    if (s == "q5_0") {
+        return GGML_TYPE_Q5_0;
+    }
+    if (s == "q5_1") {
+        return GGML_TYPE_Q5_1;
+    }
+
+    return GGML_TYPE_COUNT;
 }
 
+
 static cmd_params parse_cmd_params(int argc, char ** argv) {
     cmd_params params;
     std::string arg;
@@ -225,13 +258,38 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = split<int>(argv[i], split_delim);
             params.n_batch.insert(params.n_batch.end(), p.begin(), p.end());
-        } else if (arg == "--memory-f32") {
+        } else if (arg == "-ctk" || arg == "--cache-type-k") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<int>(argv[i], split_delim);
-            params.f32_kv.insert(params.f32_kv.end(), p.begin(), p.end());
+            auto p = split<std::string>(argv[i], split_delim);
+            std::vector<ggml_type> types;
+            for (const auto & t : p) {
+                ggml_type gt = ggml_type_from_name(t);
+                if (gt == GGML_TYPE_COUNT) {
+                    invalid_param = true;
+                    break;
+                }
+                types.push_back(gt);
+            }
+            params.type_k.insert(params.type_k.end(), types.begin(), types.end());
+        } else if (arg == "-ctv" || arg == "--cache-type-v") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = split<std::string>(argv[i], split_delim);
+            std::vector<ggml_type> types;
+            for (const auto & t : p) {
+                ggml_type gt = ggml_type_from_name(t);
+                if (gt == GGML_TYPE_COUNT) {
+                    invalid_param = true;
+                    break;
+                }
+                types.push_back(gt);
+            }
+            params.type_v.insert(params.type_v.end(), types.begin(), types.end());
         } else if (arg == "-t" || arg == "--threads") {
            if (++i >= argc) {
                 invalid_param = true;
@@ -322,7 +380,8 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.n_prompt.empty())     { params.n_prompt = cmd_params_defaults.n_prompt; }
     if (params.n_gen.empty())        { params.n_gen = cmd_params_defaults.n_gen; }
     if (params.n_batch.empty())      { params.n_batch = cmd_params_defaults.n_batch; }
-    if (params.f32_kv.empty())       { params.f32_kv = cmd_params_defaults.f32_kv; }
+    if (params.type_k.empty())       { params.type_k = cmd_params_defaults.type_k; }
+    if (params.type_v.empty())       { params.type_v = cmd_params_defaults.type_v; }
     if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
     if (params.main_gpu.empty())     { params.main_gpu = cmd_params_defaults.main_gpu; }
     if (params.mul_mat_q.empty())    { params.mul_mat_q = cmd_params_defaults.mul_mat_q; }
@@ -337,7 +396,8 @@ struct cmd_params_instance {
     int n_prompt;
     int n_gen;
     int n_batch;
-    bool f32_kv;
+    ggml_type type_k;
+    ggml_type type_v;
     int n_threads;
     int n_gpu_layers;
     int main_gpu;
@@ -366,7 +426,8 @@ struct cmd_params_instance {
 
         cparams.n_ctx = n_prompt + n_gen;
         cparams.n_batch = n_batch;
-        cparams.f16_kv = !f32_kv;
+        cparams.type_k = type_k;
+        cparams.type_v = type_v;
         cparams.mul_mat_q = mul_mat_q;
 
         return cparams;
@@ -381,15 +442,17 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
     for (const auto & mg : params.main_gpu)
     for (const auto & ts : params.tensor_split)
     for (const auto & nb : params.n_batch)
-    for (const auto & fk : params.f32_kv)
+    for (const auto & tk : params.type_k)
+    for (const auto & tv : params.type_v)
     for (const auto & mmq : params.mul_mat_q)
     for (const auto & nt : params.n_threads) {
         cmd_params_instance instance = {
             /* .model        = */ m,
             /* .n_prompt     = */ n_prompt,
             /* .n_gen        = */ n_gen,
             /* .n_batch      = */ nb,
-            /* .f32_kv       = */ fk,
+            /* .type_k       = */ tk,
+            /* .type_v       = */ tv,
             /* .n_threads    = */ nt,
             /* .n_gpu_layers = */ nl,
             /* .main_gpu     = */ mg,
@@ -411,7 +474,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & mg : params.main_gpu)
     for (const auto & ts : params.tensor_split)
     for (const auto & nb : params.n_batch)
-    for (const auto & fk : params.f32_kv)
+    for (const auto & tk : params.type_k)
+    for (const auto & tv : params.type_v)
     for (const auto & mmq : params.mul_mat_q)
     for (const auto & nt : params.n_threads) {
         for (const auto & n_prompt : params.n_prompt) {
@@ -423,7 +487,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .n_prompt     = */ n_prompt,
                 /* .n_gen        = */ 0,
                 /* .n_batch      = */ nb,
-                /* .f32_kv       = */ fk,
+                /* .type_k       = */ tk,
+                /* .type_v       = */ tv,
                 /* .n_threads    = */ nt,
                 /* .n_gpu_layers = */ nl,
                 /* .main_gpu     = */ mg,
@@ -442,7 +507,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .n_prompt     = */ 0,
                 /* .n_gen        = */ n_gen,
                 /* .n_batch      = */ nb,
-                /* .f32_kv       = */ fk,
+                /* .type_k       = */ tk,
+                /* .type_v       = */ tv,
                 /* .n_threads    = */ nt,
                 /* .n_gpu_layers = */ nl,
                 /* .main_gpu     = */ mg,
@@ -490,7 +556,8 @@ struct test {
     uint64_t model_n_params;
     int n_batch;
     int n_threads;
-    bool f32_kv;
+    ggml_type type_k;
+    ggml_type type_v;
     int n_gpu_layers;
     int main_gpu;
     bool mul_mat_q;
@@ -509,7 +576,8 @@ struct test {
         model_n_params = llama_model_n_params(lmodel);
         n_batch = inst.n_batch;
         n_threads = inst.n_threads;
-        f32_kv = inst.f32_kv;
+        type_k = inst.type_k;
+        type_v = inst.type_v;
         n_gpu_layers = inst.n_gpu_layers;
         main_gpu = inst.main_gpu;
         mul_mat_q = inst.mul_mat_q;
@@ -572,7 +640,7 @@ struct test {
         "cuda", "opencl", "metal", "gpu_blas", "blas",
         "cpu_info", "gpu_info",
         "model_filename", "model_type", "model_size", "model_n_params",
-        "n_batch", "n_threads", "f16_kv",
+        "n_batch", "n_threads", "type_k", "type_v",
         "n_gpu_layers", "main_gpu", "mul_mat_q", "tensor_split",
         "n_prompt", "n_gen", "test_time",
         "avg_ns", "stddev_ns",
@@ -622,7 +690,7 @@ struct test {
         std::to_string(cuda), std::to_string(opencl), std::to_string(metal), std::to_string(gpu_blas), std::to_string(blas),
         cpu_info, gpu_info,
         model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
-        std::to_string(n_batch), std::to_string(n_threads), std::to_string(!f32_kv),
+        std::to_string(n_batch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
         std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), tensor_split_str,
         std::to_string(n_prompt), std::to_string(n_gen), test_time,
         std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -806,8 +874,11 @@ struct markdown_printer : public printer {
         if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) {
             fields.push_back("n_batch");
         }
-        if (params.f32_kv.size() > 1 || params.f32_kv != cmd_params_defaults.f32_kv) {
-            fields.push_back("f16_kv");
+        if (params.type_k.size() > 1 || params.type_k != cmd_params_defaults.type_k) {
+            fields.push_back("type_k");
+        }
+        if (params.type_v.size() > 1 || params.type_v != cmd_params_defaults.type_v) {
+            fields.push_back("type_v");
         }
         if (params.main_gpu.size() > 1 || params.main_gpu != cmd_params_defaults.main_gpu) {
             fields.push_back("main_gpu");

examples/quantize-stats/quantize-stats.cpp

Lines changed: 0 additions & 1 deletion
@@ -322,7 +322,6 @@ int main(int argc, char ** argv) {
     auto cparams = llama_context_default_params();
     cparams.n_ctx = 256;
     cparams.seed = 1;
-    cparams.f16_kv = false;
 
     ctx = llama_new_context_with_model(model, cparams);
 

examples/server/server.cpp

Lines changed: 0 additions & 4 deletions
@@ -2109,10 +2109,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.yarn_beta_slow = std::stof(argv[i]);
         }
-        else if (arg == "--memory-f32" || arg == "--memory_f32")
-        {
-            params.memory_f16 = false;
-        }
         else if (arg == "--threads" || arg == "-t")
         {
             if (++i >= argc)
