Support MiniCPM-2B-128k #6602
@@ -287,6 +287,7 @@ enum llm_kv {
     LLM_KV_EXPERT_USED_COUNT,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
+    LLM_KV_TIE_LM_HEAD,

     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -364,6 +365,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_USED_COUNT,       "%s.expert_used_count" },
     { LLM_KV_POOLING_TYPE,            "%s.pooling_type"      },
     { LLM_KV_LOGIT_SCALE,             "%s.logit_scale"       },
+    { LLM_KV_TIE_LM_HEAD,             "%s.tie_lm_head"       },

     { LLM_KV_ATTENTION_HEAD_COUNT,    "%s.attention.head_count"    },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
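For context, the %s in these format strings is expanded with the architecture name, so for MiniCPM the new metadata key should end up in the GGUF file as minicpm.tie_lm_head (the "minicpm" prefix is assumed here, not quoted from the PR). A minimal sketch of that expansion — not the loader's actual LLM_KV helper:

```cpp
#include <cstdio>
#include <string>

// Hypothetical stand-in for how the "%s.*" format strings in LLM_KV_NAMES
// are expanded with the architecture name to form the concrete GGUF key.
static std::string kv_name(const char * fmt, const char * arch) {
    char buf[256];
    std::snprintf(buf, sizeof(buf), fmt, arch);
    return buf;
}

int main() {
    // Prints "minicpm.tie_lm_head" (architecture string assumed for illustration).
    std::printf("%s\n", kv_name("%s.tie_lm_head", "minicpm").c_str());
    return 0;
}
```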
@@ -1845,6 +1847,8 @@ struct llama_hparams {
     float f_logit_scale = 0.0f;

     bool causal_attn  = true;
     bool need_kq_pos  = false;
+    bool tie_lm_head  = true;
     bool use_alibi    = false; // currently, we need KQ_pos data for ALiBi-based models

     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
@@ -3786,6 +3790,7 @@ static void llm_load_hparams(
     ml.get_key(LLM_KV_BLOCK_COUNT,       hparams.n_layer);
     ml.get_key(LLM_KV_EXPERT_COUNT,      hparams.n_expert,      false);
     ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
+    ml.get_key(LLM_KV_TIE_LM_HEAD,       hparams.tie_lm_head,   false);

     GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
     GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
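The trailing false should make the key optional, mirroring the other optional keys in this hunk, so GGUF files converted before this change simply keep the tie_lm_head = true default from llama_hparams and behave as before. A simplified sketch of that lookup-with-default semantics — not the real llama_model_loader, and the key name is assumed:

```cpp
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

// Simplified stand-in for ml.get_key(): when the key is missing and not
// required, the destination keeps its current (default) value.
static bool get_key(const std::map<std::string, bool> & kv,
                    const std::string & name, bool & dst, bool required) {
    const auto it = kv.find(name);
    if (it == kv.end()) {
        if (required) {
            throw std::runtime_error("key not found in model file: " + name);
        }
        return false; // absent: dst keeps its default
    }
    dst = it->second;
    return true;
}

int main() {
    std::map<std::string, bool> metadata;  // e.g. a GGUF converted before this change
    bool tie_lm_head = true;               // default from llama_hparams
    get_key(metadata, "minicpm.tie_lm_head", tie_lm_head, /*required =*/ false);
    std::cout << "tie_lm_head = " << std::boolalpha << tie_lm_head << "\n"; // still true
    return 0;
}
```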
@@ -4870,20 +4875,12 @@ static bool llm_load_tensors(
             case LLM_ARCH_MINICPM:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    if (!hparams.tie_lm_head) {
+                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
+                    }
Comment on lines +4878 to +4880:

We already handle tied tensors a few lines below. Maybe you simply have to remove the […]

Done.

Hello @ggerganov, we are currently running into a problem while adapting our new model: for inputs shorter than 4k the evaluation results are consistent with vllm, but for inputs longer than 4k they are not as good as vllm. What could be the reason? The evaluation uses the example/server method, with the following startup and request parameters: […]

request data: […]

vllm param: […]

Why do you use […]? https://huggingface.co/openbmb/MiniCPM-2B-128k/blob/main/config.json#L34 Also, this model seems to use some rope scaling: https://huggingface.co/openbmb/MiniCPM-2B-128k/blob/main/config.json#L25 You need to apply the same thing when starting the […]

The model currently uses DynamicNTKScalingRotaryEmbedding. How should I pass the parameters?

It can run up to 64k without NTK scaling.

I think we don't support it. Anyway, isn't 64k context length enough?

Ok, thanks for your quick reply.
-                    // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        if (model.arch != LLM_ARCH_MINICPM){
-                            model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
-                            // if output is NULL, init from the input tok embed
-                            if (model.output == NULL) {
-                                model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                                ml.n_created--; // artificial tensor
-                                ml.size_data += ggml_nbytes(model.output);
-                            }
-                        }
-                    }
+                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});

                     for (int i = 0; i < n_layer; ++i) {
                         ggml_context * ctx_layer = ctx_for_layer(i);
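On the rope-scaling question in the review thread above: the 128k model's config points to Dynamic NTK scaling, which per the discussion is not supported, so ggerganov's suggestion is to apply an equivalent rope scaling when starting the server. As a rough illustration of what linear frequency scaling does to the rotation angles — a simplified sketch, not llama.cpp's RoPE implementation, with example values only:

```cpp
#include <cmath>
#include <cstdio>

// Angle used by RoPE for dimension pair i at position `pos`.
// Linear scaling multiplies the position (equivalently the frequency) by
// freq_scale < 1, compressing positions back into the trained range.
static float rope_angle(int pos, int i, int n_dims, float freq_base, float freq_scale) {
    const float inv_freq = std::pow(freq_base, -2.0f * (float) i / (float) n_dims);
    return (float) pos * freq_scale * inv_freq;
}

int main() {
    const int   n_dims     = 64;       // head dimension (example value)
    const float freq_base  = 10000.0f; // rope theta (example value)
    const float freq_scale = 0.25f;    // example scale, i.e. 4x context stretch

    // Same position with and without scaling, for the highest-frequency pair:
    std::printf("pos 8192, unscaled: %f\n", rope_angle(8192, 0, n_dims, freq_base, 1.0f));
    std::printf("pos 8192, scaled  : %f\n", rope_angle(8192, 0, n_dims, freq_base, freq_scale));
    return 0;
}
```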
@@ -10041,7 +10038,11 @@ struct llm_build_context {
         cb(cur, "lmhead_scaling", -1);

         // lm_head
-        cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
+        if (hparams.tie_lm_head) {
+            cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
+        } else {
+            cur = ggml_mul_mat(ctx0, model.output, cur);
+        }
         cb(cur, "result_output", -1);

         ggml_build_forward_expand(gf, cur);
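Conceptually, this last hunk chooses between two output projections: with a tied head the token-embedding matrix is reused to produce the logits, otherwise the dedicated output tensor is used. A self-contained sketch of that distinction using plain loops rather than ggml — the matrices and the tie_lm_head value below are made-up examples:

```cpp
#include <cstdio>
#include <vector>

// logits[v] = sum_k W[v][k] * h[k], where W is either the token-embedding
// matrix (tied head) or a separate output matrix (untied head).
static std::vector<float> project(const std::vector<std::vector<float>> & W,
                                  const std::vector<float> & h) {
    std::vector<float> logits(W.size(), 0.0f);
    for (size_t v = 0; v < W.size(); ++v) {
        for (size_t k = 0; k < h.size(); ++k) {
            logits[v] += W[v][k] * h[k];
        }
    }
    return logits;
}

int main() {
    const bool tie_lm_head = false; // hypothetical value read from the model metadata

    const std::vector<std::vector<float>> tok_embd = {{0.1f, 0.2f}, {0.3f, 0.4f}}; // n_vocab x n_embd
    const std::vector<std::vector<float>> output   = {{1.0f, 0.0f}, {0.0f, 1.0f}}; // separate lm_head weights
    const std::vector<float> h = {0.5f, -0.5f};                                    // final hidden state

    // Tied: reuse the embedding matrix; untied: use the dedicated output matrix.
    const std::vector<float> logits = project(tie_lm_head ? tok_embd : output, h);
    std::printf("logits: %f %f\n", logits[0], logits[1]);
    return 0;
}
```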
The name should be tie_word_embeddings, as in huggingface's config.
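If the key is renamed to mirror that field, the conversion step would presumably just copy tie_word_embeddings from the model's config.json, treating a missing field as true (the transformers default). A rough sketch of that mapping, assuming the nlohmann/json library is available; the actual llama.cpp conversion lives in the Python convert script, so this is only illustrative:

```cpp
#include <fstream>
#include <iostream>
#include <nlohmann/json.hpp>

int main() {
    // Read huggingface's config.json for the model being converted.
    std::ifstream f("config.json");
    const nlohmann::json config = nlohmann::json::parse(f);

    // transformers treats a missing field as tie_word_embeddings = true.
    const bool tie = config.value("tie_word_embeddings", true);

    // Key name as added by this PR; the review suggests matching the HF field name instead.
    std::cout << "minicpm.tie_lm_head = " << std::boolalpha << tie << "\n";
    return 0;
}
```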