Skip to content

Commit b43bfe8

Browse files
committed
llama: implement NTK-By-Parts (NTKv2)
1 parent d01bccd commit b43bfe8

File tree

7 files changed

+192
-53
lines changed

7 files changed

+192
-53
lines changed

examples/common.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
180180
break;
181181
}
182182
params.rope_freq_scale = std::stof(argv[i]);
183+
} else if (arg == "--rope-ntk-factor") {
184+
if (++i >= argc) {
185+
invalid_param = true;
186+
break;
187+
}
188+
params.rope_ntk_factor = std::stof(argv[i]);
189+
} else if (arg == "--rope-extrapolation-factor") {
190+
if (++i >= argc) {
191+
invalid_param = true;
192+
break;
193+
}
194+
params.rope_extrapolation_factor = std::stof(argv[i]);
183195
} else if (arg == "--memory-f32") {
184196
params.memory_f16 = false;
185197
} else if (arg == "--top-p") {
@@ -513,6 +525,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
513525
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
514526
fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
515527
fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
528+
fprintf(stderr, " --rope-ntk-factor N RoPE NTK mix factor (default: %.1f)\n", params.rope_ntk_factor);
529+
fprintf(stderr, " --rope-extrapolation-factor N\n");
530+
fprintf(stderr, " RoPE extrapolation mix factor (default: %.1f)\n", params.rope_extrapolation_factor);
516531
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
517532
fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
518533
fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
@@ -596,6 +611,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
596611
lparams.embedding = params.embedding;
597612
lparams.rope_freq_base = params.rope_freq_base;
598613
lparams.rope_freq_scale = params.rope_freq_scale;
614+
lparams.rope_ntk_factor = params.rope_ntk_factor;
615+
lparams.rope_extrapolation_factor = params.rope_extrapolation_factor;
599616

600617
return lparams;
601618
}

examples/common.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ struct gpt_params {
3535
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
3636
float rope_freq_base = 10000.0f; // RoPE base frequency
3737
float rope_freq_scale = 1.0f; // RoPE frequency scaling factor
38+
float rope_ntk_factor = 1.0f; // RoPE NTK mix factor
39+
float rope_extrapolation_factor = 1.0f; // RoPE extrapolation mix factor
3840

3941
// sampling parameters
4042
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

examples/server/server.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,9 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
610610
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
611611
fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
612612
fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
613+
fprintf(stderr, " --rope-ntk-factor N RoPE NTK mix factor (default: %.1f)\n", params.rope_ntk_factor);
614+
fprintf(stderr, " --rope-extrapolation-factor N\n");
615+
fprintf(stderr, " RoPE extrapolation mix factor (default: %.1f)\n", params.rope_extrapolation_factor);
613616
fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
614617
fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
615618
fprintf(stderr, " not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -740,6 +743,22 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
740743
}
741744
params.rope_freq_scale = std::stof(argv[i]);
742745
}
746+
else if (arg == "--rope-ntk-factor")
747+
{
748+
if (++i >= argc) {
749+
invalid_param = true;
750+
break;
751+
}
752+
params.rope_ntk_factor = std::stof(argv[i]);
753+
}
754+
else if (arg == "--rope-extrapolation-factor")
755+
{
756+
if (++i >= argc) {
757+
invalid_param = true;
758+
break;
759+
}
760+
params.rope_extrapolation_factor = std::stof(argv[i]);
761+
}
743762
else if (arg == "--memory-f32" || arg == "--memory_f32")
744763
{
745764
params.memory_f16 = false;

0 commit comments

Comments
 (0)