
Support MiniCPM-2B-128k #6602

Status: Closed (wants to merge 8 commits)
2 changes: 2 additions & 0 deletions convert-hf-to-gguf.py
@@ -1434,6 +1434,8 @@ def set_gguf_parameters(self):
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_file_type(self.ftype)
if "tie_lm_head" in self.hparams:
Review comment: The name should be tie_word_embeddings in the Hugging Face config.

self.gguf_writer.add_tie_lm_head(self.hparams["tie_lm_head"])

def set_vocab(self):
self._set_vocab_llama_hf()
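
For illustration, a minimal sketch of what the reviewer's suggestion could look like on the conversion side (hypothetical helper, not part of this PR; defaulting to True when the field is absent is an assumption):

import json

def read_tie_word_embeddings(config_path: str) -> bool:
    # Hypothetical helper: read the standard Hugging Face flag from config.json.
    with open(config_path, "r", encoding="utf-8") as f:
        hparams = json.load(f)
    # Assume tied embeddings when the field is missing.
    return bool(hparams.get("tie_word_embeddings", True))

# In the converter this value would feed the new writer call, e.g.:
#   self.gguf_writer.add_tie_lm_head(read_tie_word_embeddings("MiniCPM-2B-128k/config.json"))
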
3 changes: 3 additions & 0 deletions gguf-py/gguf/constants.py
@@ -43,6 +43,7 @@ class LLM:
EXPERT_USED_COUNT = "{arch}.expert_used_count"
POOLING_TYPE = "{arch}.pooling_type"
LOGIT_SCALE = "{arch}.logit_scale"
TIE_LM_HEAD = "{arch}.tie_lm_head"

class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
@@ -612,6 +613,7 @@ class MODEL_TENSOR(IntEnum):
],
MODEL_ARCH.MINICPM: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
@@ -916,6 +918,7 @@ def get_type(val: Any) -> GGUFValueType:
KEY_FEED_FORWARD_LENGTH = Keys.LLM.FEED_FORWARD_LENGTH
KEY_USE_PARALLEL_RESIDUAL = Keys.LLM.USE_PARALLEL_RESIDUAL
KEY_TENSOR_DATA_LAYOUT = Keys.LLM.TENSOR_DATA_LAYOUT
KEY_TIE_LM_HEAD = Keys.LLM.TIE_LM_HEAD

# attention
KEY_ATTENTION_HEAD_COUNT = Keys.Attention.HEAD_COUNT
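
As a quick check, once this patch is applied the templated key resolves per architecture (minimal sketch; "minicpm" is the arch string gguf-py uses for MODEL_ARCH.MINICPM):

from gguf.constants import Keys

# The "{arch}" template expands to a per-architecture metadata key,
# e.g. "minicpm.tie_lm_head" for the MiniCPM architecture.
print(Keys.LLM.TIE_LM_HEAD.format(arch="minicpm"))  # minicpm.tie_lm_head
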
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -416,6 +416,9 @@ def add_feed_forward_length(self, length: int) -> None:

def add_parallel_residual(self, use: bool) -> None:
self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)

def add_tie_lm_head(self, tie_lm_head: bool) -> None:
self.add_bool(Keys.LLM.TIE_LM_HEAD.format(arch=self.arch), tie_lm_head)

def add_head_count(self, count: int) -> None:
self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
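
A minimal usage sketch of the new writer method (the file name and the call sequence are illustrative assumptions based on gguf-py's usual header/KV write flow):

import gguf

# Hypothetical example: store the new boolean key in a GGUF file's metadata.
writer = gguf.GGUFWriter("minicpm-test.gguf", arch="minicpm")
writer.add_tie_lm_head(True)  # written as the bool key "minicpm.tie_lm_head"
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()
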
27 changes: 14 additions & 13 deletions llama.cpp
@@ -287,6 +287,7 @@ enum llm_kv {
LLM_KV_EXPERT_USED_COUNT,
LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,
LLM_KV_TIE_LM_HEAD,

LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -364,6 +365,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
{ LLM_KV_POOLING_TYPE , "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_TIE_LM_HEAD, "%s.tie_lm_head" },

{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -1845,6 +1847,8 @@ struct llama_hparams {
float f_logit_scale = 0.0f;

bool causal_attn = true;
bool need_kq_pos = false;
bool tie_lm_head = true;
bool use_alibi = false; // currently, we need KQ_pos data for ALiBi-based models

enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
@@ -3786,6 +3790,7 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer);
ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
ml.get_key(LLM_KV_TIE_LM_HEAD, hparams.tie_lm_head, false);

GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
@@ -4870,20 +4875,12 @@ static bool llm_load_tensors(
case LLM_ARCH_MINICPM:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
if (!hparams.tie_lm_head){
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
}
Comment on lines +4878 to +4880

Member: We already handle tied tensors a few lines below. Maybe you simply have to remove the if (model.arch != LLM_ARCH_MINICPM){ check and this model would work?

Contributor Author: Done.

Contributor Author: Hello @ggerganov, we are currently running into a problem while adapting our new model: evaluation results at context lengths below 4k are consistent with vllm, but at context lengths above 4k the results are worse than vllm's. What might be the reason for this? Evaluation uses example/server, with the following startup command and request parameters:

./server -m MiniCPM-2B-128k/ggml-model-f16.gguf --chat-template chatml --rope-freq-base 4129032.258 --host 0.0.0.0 -c 12000

request data:

   data = {"stream": False,
            "n_predict": max_token,
            "temperature": 0.3,
            "stop": ["<|im_end|>", "</s>"],
            "repeat_last_n": 256,
            "repeat_penalty": 1.0,
            "top_k": 40,
            "top_p": 0.5,
            "min_p": 0.05,
            "tfs_z": 1,
            "typical_p": 1,
            "presence_penalty": 0,
            "frequency_penalty": 0,
            "mirostat": 0,
            "mirostat_tau": 5,
            "mirostat_eta": 0.1,
            "grammar": "", "n_probs": 0, "min_keep": 0, "image_data": [], "cache_prompt": True,
            "api_key": "",
            "prompt": f"<|im_start|>user{prompt}<|im_end|><|im_start|>assistant\n"
            }

vllm param:

params_dict = {
    "n": 1,
    "best_of": None,
    "presence_penalty": 0.0,
    "frequency_penalty": 0.0,
    "repetition_penalty": 1.0,
    "temperature": 0.3,
    "top_p": 0.5,
    "top_k": -1,
    "use_beam_search": False,
    "length_penalty": 1.0,
    "early_stopping": False,
    "stop": ["<|im_end|>", "</s>"],
    "stop_token_ids": None,
    "ignore_eos": False,
    "logprobs": None,
    "prompt_logprobs": None,
    "skip_special_tokens": False
}
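
For reference, a minimal sketch of sending such a payload to the llama.cpp server's /completion endpoint (host, port, and the abbreviated payload are assumptions; the full parameter set above applies the same way):

import requests

data = {
    "stream": False,
    "n_predict": 512,  # stand-in for max_token
    "temperature": 0.3,
    "stop": ["<|im_end|>", "</s>"],
    "cache_prompt": True,
    "prompt": "<|im_start|>userHello<|im_end|><|im_start|>assistant\n",
}

# The server defaults to port 8080 unless --port is given.
resp = requests.post("http://localhost:8080/completion", json=data, timeout=600)
print(resp.json().get("content", ""))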

Member: Why do you use --rope-freq-base 4129032.258 when the config specifies 1e6?

https://huggingface.co/openbmb/MiniCPM-2B-128k/blob/main/config.json#L34

Also, this model seems to use some rope scaling:

https://huggingface.co/openbmb/MiniCPM-2B-128k/blob/main/config.json#L25

You need to apply the same thing when starting the server.

Contributor Author: The model currently uses DynamicNTKScalingRotaryEmbedding. How should I pass the parameters?
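
For context, a rough sketch of how transformers' DynamicNTKScalingRotaryEmbedding recomputes the RoPE base as the sequence grows (approximate reconstruction of the Hugging Face formula; the example values are illustrative assumptions, not MiniCPM's actual config). Because the base depends on the current sequence length, a single static --rope-freq-base cannot reproduce it exactly:

def dynamic_ntk_base(base: float, dim: int, seq_len: int,
                     max_position_embeddings: int, scaling_factor: float) -> float:
    # The base is only rescaled once seq_len exceeds the trained context length.
    if seq_len <= max_position_embeddings:
        return base
    return base * (
        (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))

# Illustrative values only (head dim, trained context, and scaling factor are assumptions):
print(dynamic_ntk_base(base=1_000_000.0, dim=64, seq_len=12_000,
                       max_position_embeddings=4096, scaling_factor=1.0))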

Contributor: It can run up to 64k without NTK scaling.

Contributor: @foldl @zkh2016 Hi, do we support Dynamic NTK scaling in llama.cpp? Thanks.

Contributor: Ok, thanks for your quick reply.


// output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
if (model.arch != LLM_ARCH_MINICPM){
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
// if output is NULL, init from the input tok embed
if (model.output == NULL) {
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
ml.n_created--; // artificial tensor
ml.size_data += ggml_nbytes(model.output);
}
}
}
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});

for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
@@ -10041,7 +10038,11 @@ struct llm_build_context {
cb(cur, "lmhead_scaling", -1);

// lm_head
cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
if (hparams.tie_lm_head){
cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
}else{
cur = ggml_mul_mat(ctx0, model.output, cur);
}
cb(cur, "result_output", -1);

ggml_build_forward_expand(gf, cur);