Skip to content

Commit 8dec38c

Browse files
committed
llama: implement NTK-By-Parts (NTKv2) RoPE scaling
1 parent 93356bd commit 8dec38c

File tree

7 files changed

+189
-50
lines changed

7 files changed

+189
-50
lines changed

examples/common.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
194194
break;
195195
}
196196
params.rope_freq_scale = std::stof(argv[i]);
197+
} else if (arg == "--rope-ntk-factor") {
198+
if (++i >= argc) {
199+
invalid_param = true;
200+
break;
201+
}
202+
params.rope_ntk_factor = std::stof(argv[i]);
203+
} else if (arg == "--rope-ext-factor") {
204+
if (++i >= argc) {
205+
invalid_param = true;
206+
break;
207+
}
208+
params.rope_ext_factor = std::stof(argv[i]);
197209
} else if (arg == "--memory-f32") {
198210
params.memory_f16 = false;
199211
} else if (arg == "--top-p") {
@@ -566,6 +578,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
566578
fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
567579
fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
568580
fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
581+
fprintf(stdout, " --rope-ntk-factor N RoPE NTK mix factor (default: %.1f)\n", params.rope_ntk_factor);
582+
fprintf(stdout, " --rope-ext-factor N RoPE extrapolation mix factor (default: %.1f)\n", params.rope_ext_factor);
569583
fprintf(stdout, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
570584
fprintf(stdout, " --no-penalize-nl do not penalize newline token\n");
571585
fprintf(stdout, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
@@ -657,6 +671,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
657671
lparams.embedding = params.embedding;
658672
lparams.rope_freq_base = params.rope_freq_base;
659673
lparams.rope_freq_scale = params.rope_freq_scale;
674+
lparams.rope_ntk_factor = params.rope_ntk_factor;
675+
lparams.rope_ext_factor = params.rope_ext_factor;
660676

661677
return lparams;
662678
}

examples/common.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ struct gpt_params {
3232
float rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; // rms norm epsilon
3333
float rope_freq_base = 10000.0f; // RoPE base frequency
3434
float rope_freq_scale = 1.0f; // RoPE frequency scaling factor
35+
float rope_ntk_factor = 0.0f; // RoPE NTK mix factor
36+
float rope_ext_factor = 0.0f; // RoPE extrapolation mix factor
3537

3638
// sampling parameters
3739
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

examples/server/server.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
612612
fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
613613
fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
614614
fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
615+
fprintf(stdout, " --rope-ntk-factor N RoPE NTK mix factor (default: %.1f)\n", params.rope_ntk_factor);
616+
fprintf(stdout, " --rope-ext-factor N RoPE extrapolation mix factor (default: %.1f)\n", params.rope_ext_factor);
615617
fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
616618
fprintf(stdout, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
617619
fprintf(stdout, " not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -764,6 +766,22 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
764766
}
765767
params.rope_freq_scale = std::stof(argv[i]);
766768
}
769+
else if (arg == "--rope-ntk-factor")
770+
{
771+
if (++i >= argc) {
772+
invalid_param = true;
773+
break;
774+
}
775+
params.rope_ntk_factor = std::stof(argv[i]);
776+
}
777+
else if (arg == "--rope-ext-factor")
778+
{
779+
if (++i >= argc) {
780+
invalid_param = true;
781+
break;
782+
}
783+
params.rope_ext_factor = std::stof(argv[i]);
784+
}
767785
else if (arg == "--memory-f32" || arg == "--memory_f32")
768786
{
769787
params.memory_f16 = false;

0 commit comments

Comments
 (0)